24 #ifndef ARM_COMPUTE_NEGEMMCONVOLUTIONLAYER_H 25 #define ARM_COMPUTE_NEGEMMCONVOLUTIONLAYER_H 45 class NEWeightsReshapeKernel;
92 std::unique_ptr<NEWeightsReshapeKernel> _weights_reshape_kernel;
95 namespace weights_transformations
115 _bias_bit = (biases !=
nullptr) ? 1 : 0;
116 _func.configure(input, biases, &_output);
121 _output.allocator()->allocate();
133 _output.allocator()->free();
138 return ((0x8) | (_bias_bit << 7));
149 int32_t _bias_bit{ 0 };
253 int gemm_3d_depth = 1,
bool skip_im2col =
false);
271 std::unique_ptr<NEIm2ColKernel> _im2col_kernel;
274 std::unique_ptr<NECol2ImKernel> _col2im_kernel;
277 const ITensor *_original_weights;
278 const ITensor *_original_output;
NEConvolutionLayerReshapeWeights()
Constructor.
Base class for all functions.
Basic function to execute GEMM on Neon.
Function to reshape the weights.
Store the tensor's metadata.
NEConvolutionLayerReshapeWeights & operator=(const NEConvolutionLayerReshapeWeights &)=delete
Prevent instances of this class from being copied (as this class contains pointers).
Activation Layer Information class.
Interface for Neon tensor.
Basic function to compute the convolution layer.
Copyright (c) 2017-2021 Arm Limited.
Convolution Layer Weights Information class.
static Status validate(const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output)
Static function to check if given info will lead to a valid configuration of NEConvolutionLayerReshapeWeights.
virtual void prepare()
Prepare the function for executing.
const unsigned int num_groups
Basic implementation of the tensor interface.
Padding and stride information class.
Weights manager interface to handle weights transformations.
void run() override
Run the kernels contained in the function.
Basic function to run cpu::kernels::CpuReshapeKernel.
Class for specifying the size of an image or rectangle.
void configure(const ITensor *weights, const ITensor *biases, ITensor *output)
Set the input and output tensors.
DataLayout
[DataLayout enum definition]
~NEConvolutionLayerReshapeWeights()
Default destructor.
Basic function to execute GEMMLowpMatrixMultiplyCore on Neon.