21.05
Interface to convert the 2D Fully Connected weights from NCHW to NHWC or vice versa.
#include <CpuConvertFullyConnectedWeightsKernel.h>
Public Member Functions
CpuConvertFullyConnectedWeightsKernel ()
Default constructor.
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE (CpuConvertFullyConnectedWeightsKernel)
void configure (const ITensorInfo *src, ITensorInfo *dst, const TensorShape &original_input_shape, DataLayout data_layout)
Set the src and dst tensor.
void run_op (ITensorPack &tensors, const Window &window, const ThreadInfo &info) override
Execute the kernel on the passed window.
const char * name () const override
Name of the kernel.
Public Member Functions inherited from ICPPKernel
virtual ~ICPPKernel ()=default
Default destructor.
virtual void run (const Window &window, const ThreadInfo &info)
Execute the kernel on the passed window.
virtual void run_nd (const Window &window, const ThreadInfo &info, const Window &thread_locator)
Legacy compatibility layer for implementations which do not support thread_locator. In these cases we simply narrow the interface down to the legacy version.
Public Member Functions inherited from IKernel
IKernel ()
Constructor.
virtual ~IKernel ()=default
Destructor.
virtual bool is_parallelisable () const
Indicates whether or not the kernel is parallelisable.
virtual BorderSize border_size () const
The size of the border for that kernel.
const Window & window () const
The maximum window the kernel can be executed on.
bool is_window_configured () const
Function to check if the embedded window of this kernel has been configured.
Static Public Member Functions
static Status validate (const ITensorInfo *src, const ITensorInfo *dst, const TensorShape &original_input_shape, DataLayout data_layout)
Static function to check if given info will lead to a valid configuration of CpuConvertFullyConnectedWeightsKernel.
Interface to convert the 2D Fully Connected weights from NCHW to NHWC or vice versa.
Definition at line 44 of file CpuConvertFullyConnectedWeightsKernel.h.
Default constructor.
Definition at line 37 of file CpuConvertFullyConnectedWeightsKernel.cpp.
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE (CpuConvertFullyConnectedWeightsKernel)
void configure (const ITensorInfo *src, ITensorInfo *dst, const TensorShape &original_input_shape, DataLayout data_layout)
Set the src and dst tensor.
[in] | src | Source weights tensor info to convert. Must be 2 dimensional. Data types supported: All. |
[in] | dst | The converted weights tensor info. Shape and Data Type: Same as src. |
[in] | original_input_shape | Shape of the original src tensor (the one entering the fully connected layer). |
[in] | data_layout | The data layout the weights have been trained in. |
Definition at line 42 of file CpuConvertFullyConnectedWeightsKernel.cpp.
References ARM_COMPUTE_ERROR_ON_NULLPTR, ARM_COMPUTE_ERROR_THROW_ON, arm_compute::auto_init_if_empty(), arm_compute::calculate_max_window(), arm_compute::CHANNEL, arm_compute::test::validation::data_layout, arm_compute::test::validation::dst, arm_compute::get_data_layout_dimension_index(), arm_compute::HEIGHT, arm_compute::NCHW, arm_compute::NHWC, arm_compute::test::validation::src, CpuConvertFullyConnectedWeightsKernel::validate(), and arm_compute::WIDTH.
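To make the conversion more concrete, the following is a minimal, library-independent sketch of what moving 2D fully connected weights between NCHW and NHWC amounts to. It is not the kernel's implementation. It assumes the weights form a row-major matrix in which each row corresponds to one element of the flattened input and each column to one output, so a layout change is a permutation of rows. The function name `reorder_rows_chw_to_hwc` and its parameters are hypothetical and do not exist in the library.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper for illustration only (not part of Compute Library).
// Assumption: `weights` is a row-major matrix with (channels * plane) rows and
// `num_outputs` columns, where row i multiplies element i of the flattened input.
// Rows ordered for a CHW-flattened input are moved to the positions they would
// occupy for an HWC-flattened input.
std::vector<float> reorder_rows_chw_to_hwc(const std::vector<float> &weights,
                                           std::size_t channels,
                                           std::size_t plane, // height * width
                                           std::size_t num_outputs)
{
    std::vector<float> out(weights.size());
    for (std::size_t c = 0; c < channels; ++c)
    {
        for (std::size_t p = 0; p < plane; ++p)
        {
            const std::size_t src_row = c * plane + p;    // index in CHW order
            const std::size_t dst_row = p * channels + c; // index in HWC order
            for (std::size_t o = 0; o < num_outputs; ++o)
            {
                out[dst_row * num_outputs + o] = weights[src_row * num_outputs + o];
            }
        }
    }
    return out;
}
```

In this sketch the reverse direction (HWC to CHW) is the same loop with `src_row` and `dst_row` swapped, which is why a single `original_input_shape` plus a `data_layout` flag is enough to describe either conversion.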
const char * name () const override
Name of the kernel.
Implements ICPPKernel.
Definition at line 131 of file CpuConvertFullyConnectedWeightsKernel.cpp.
void run_op (ITensorPack &tensors, const Window &window, const ThreadInfo &info) override
Execute the kernel on the passed window.
[in] | tensors | A vector containing the tensors to operate on. |
[in] | window | Region on which to execute the kernel. (Must be a region of the window returned by window()) |
[in] | info | Info about executing thread and CPU. |
Reimplemented from ICPPKernel.
Definition at line 105 of file CpuConvertFullyConnectedWeightsKernel.cpp.
References arm_compute::ACL_DST, arm_compute::ACL_SRC, ARM_COMPUTE_ERROR, ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW, ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL, ARM_COMPUTE_UNUSED, arm_compute::test::validation::dst, ITensorPack::get_const_tensor(), ITensorPack::get_tensor(), arm_compute::test::validation::info, arm_compute::test::validation::src, and IKernel::window().
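static Status validate (const ITensorInfo *src, const ITensorInfo *dst, const TensorShape &original_input_shape, DataLayout data_layout)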
Static function to check if given info will lead to a valid configuration of CpuConvertFullyConnectedWeightsKernel.
[in] | src | Source weights tensor info to convert. Must be 2 dimensional. Data types supported: All. |
[in] | dst | The converted weights tensor info. Shape and Data Type: Same as src. |
[in] | original_input_shape | Shape of the original src tensor (the one entering the fully connected layer). |
[in] | data_layout | The data layout the weights have been trained in. |
Definition at line 70 of file CpuConvertFullyConnectedWeightsKernel.cpp.
References ARM_COMPUTE_RETURN_ERROR_ON, ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES, ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES, ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR, arm_compute::test::validation::data_layout, arm_compute::test::validation::dst, arm_compute::test::validation::src, TensorShape::total_size_lower(), and arm_compute::UNKNOWN.
Referenced by CpuConvertFullyConnectedWeightsKernel::configure(), and CpuConvertFullyConnectedWeights::validate().
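The parameter constraints documented for validate() can be pictured with a small standalone sketch. This is not the library's code. The types `Layout` and `WeightsInfo` and the function `check_conversion_args` are hypothetical stand-ins. The sketch also assumes that the second dimension of src holds the flattened input size and that a null dst simply skips the dst checks, neither of which is stated on this page.

```cpp
#include <array>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-ins for illustration only; these are not Compute Library types.
enum class Layout { Unknown, NCHW, NHWC };

struct WeightsInfo
{
    std::vector<std::size_t> shape;     // assumed order: {columns, rows}
    std::string              data_type; // e.g. "F32"
};

// Returns an empty string when the arguments look consistent,
// otherwise a short reason describing the first failed check.
std::string check_conversion_args(const WeightsInfo &src,
                                  const WeightsInfo *dst,
                                  const std::array<std::size_t, 3> &original_input_shape,
                                  Layout layout)
{
    if (src.shape.size() != 2)
    {
        return "src must be 2 dimensional";
    }
    if (layout == Layout::Unknown)
    {
        return "data layout must be known";
    }
    // Assumption: the rows of src match the flattened size of the original input.
    const std::size_t flattened =
        original_input_shape[0] * original_input_shape[1] * original_input_shape[2];
    if (src.shape[1] != flattened)
    {
        return "src rows must match the flattened input size";
    }
    // Assumption: dst is optional; when given it must mirror src.
    if (dst != nullptr)
    {
        if (dst->shape != src.shape)
        {
            return "dst shape must match src";
        }
        if (dst->data_type != src.data_type)
        {
            return "dst data type must match src";
        }
    }
    return "";
}
```

With the real API, the analogous pattern is to call the static validate() with the same arguments intended for configure() and to inspect the returned Status before configuring the kernel.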