24.02.1
GEMM reshape information class.
#include <Types.h>
Public Member Functions
GEMMReshapeInfo ()
Default constructor.
GEMMReshapeInfo (int m, int n, int k, int mult_transpose1xW_width=1, int mult_interleave4x4_height=1, int depth_output_gemm3d=0, bool reinterpret_input_as_3d=false, bool broadcast_bias=false)
Constructor.
int | m () const
Number of matrix A rows.
int | n () const
Number of matrix B columns.
int | k () const
Number of matrix A columns or matrix B rows.
int | mult_transpose1xW_width () const
Multiplication factor for the width of the 1xW transposed block.
int | mult_interleave4x4_height () const
Multiplication factor for the height of the 4x4 interleaved block.
int | depth_output_gemm3d () const
Depth (third dimension) of the output tensor to be used with the GEMM3D kernel.
bool | reinterpret_input_as_3d () const
Flag which specifies if the input tensor has to be reinterpreted as 3D.
bool | broadcast_bias () const
Flag which specifies whether to broadcast the shape of the bias tensor.
GEMM reshape information class.
This class stores the necessary information about matrix A and matrix B reshape.
Matrix A can only be reshaped through opencl::kernels::ClGemmReshapeLhsMatrixKernel or cpu::kernels::CpuGemmInterleave4x4Kernel. Note: mult_interleave4x4_height, the multiplication factor for the height of the 4x4 interleaved block, can optionally be set only for opencl::kernels::ClGemmReshapeLhsMatrixKernel.
Matrix B can only be reshaped through opencl::kernels::ClGemmReshapeRhsMatrixKernel or cpu::kernels::CpuGemmTranspose1xWKernel. Note: mult_transpose1xW_width, the multiplication factor for the width of the 1xW transposed block, can optionally be set only for opencl::kernels::ClGemmReshapeRhsMatrixKernel.
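To make the term "4x4 interleaved block" more concrete, the sketch below shows one plain reading of it: four consecutive rows of matrix A are stored so that the four values of each column sit next to each other. The function name `interleave4_rows_sketch` is hypothetical, the element type is fixed to `int` for brevity, and the multiplication factor is not modelled; this illustrates the idea only and is not the kernels' implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative only (hypothetical helper, not a library function).
// Takes a 4 x cols block stored row-major and emits it column by column,
// so the four values that share a column become adjacent in memory.
inline std::vector<int> interleave4_rows_sketch(const std::vector<int> &a, std::size_t cols)
{
    assert(a.size() == 4 * cols); // exactly four rows are expected

    std::vector<int> out;
    out.reserve(a.size());
    for (std::size_t c = 0; c < cols; ++c)
    {
        for (std::size_t r = 0; r < 4; ++r)
        {
            out.push_back(a[r * cols + c]);
        }
    }
    return out;
}
```

For a 4x2 block with rows (0 1), (10 11), (20 21), (30 31), the helper returns 0 10 20 30 1 11 21 31.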
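GEMMReshapeInfo (int m, int n, int k, int mult_transpose1xW_width=1, int mult_interleave4x4_height=1, int depth_output_gemm3d=0, bool reinterpret_input_as_3d=false, bool broadcast_bias=false) | inline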
Constructor.
[in] | m | Number of matrix A rows |
[in] | n | Number of matrix B columns |
[in] | k | Number of matrix A columns or matrix B rows |
[in] | mult_transpose1xW_width | (Optional) Multiplication factor for the width of the 1xW transposed block |
[in] | mult_interleave4x4_height | (Optional) Multiplication factor for the height of the 4x4 interleaved block |
[in] | depth_output_gemm3d | (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel. If 0, the output will not be reinterpreted as 3D. Default: 0 |
[in] | reinterpret_input_as_3d | (Optional) Reinterpret the input as a 3D tensor (e.g. this flag should be set to true when GEMM is used to perform 1x1 convolutions with the NHWC data layout). |
[in] | broadcast_bias | (Optional) Broadcast the shape of the bias tensor from a vector to a matrix. |
Definition at line 1798 of file Types.h.
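The parameter order and defaults listed above can be illustrated with a small stand-in type. `ReshapeInfoSketch`, `make_plain_gemm` and `make_gemm3d` are hypothetical names introduced only for this sketch; the struct copies the documented constructor signature so the example compiles on its own. Real code would include the library header and construct `GEMMReshapeInfo` directly.

```cpp
// Hypothetical stand-in that mirrors the documented constructor signature.
// It is NOT the library's class; it only demonstrates argument order and defaults.
struct ReshapeInfoSketch
{
    int  m;
    int  n;
    int  k;
    int  mult_transpose1xW_width;
    int  mult_interleave4x4_height;
    int  depth_output_gemm3d;
    bool reinterpret_input_as_3d;
    bool broadcast_bias;

    ReshapeInfoSketch(int  m_,
                      int  n_,
                      int  k_,
                      int  mult_transpose1xW_width_   = 1,
                      int  mult_interleave4x4_height_ = 1,
                      int  depth_output_gemm3d_       = 0,
                      bool reinterpret_input_as_3d_   = false,
                      bool broadcast_bias_            = false)
        : m(m_),
          n(n_),
          k(k_),
          mult_transpose1xW_width(mult_transpose1xW_width_),
          mult_interleave4x4_height(mult_interleave4x4_height_),
          depth_output_gemm3d(depth_output_gemm3d_),
          reinterpret_input_as_3d(reinterpret_input_as_3d_),
          broadcast_bias(broadcast_bias_)
    {
    }
};

// A is 64 x 128 and B is 128 x 32, so m = 64, n = 32, k = 128.
// All optional arguments keep their documented defaults.
inline ReshapeInfoSketch make_plain_gemm()
{
    return ReshapeInfoSketch(64, 32, 128);
}

// Same product, but with the output depth set to 8. The arguments are
// positional, so both multiplication factors must be spelled out first.
inline ReshapeInfoSketch make_gemm3d()
{
    return ReshapeInfoSketch(64, 32, 128, 1, 1, 8);
}
```

Note that the order is m, n, k rather than m, k, n, which is easy to get wrong when translating from an "M x K times K x N" description.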
int depth_output_gemm3d () const | inline
Depth (third dimension) of the output tensor to be used with the GEMM3D kernel.
Definition at line 1863 of file Types.h.
Referenced by arm_compute::misc::shape_calculator::compute_mm_shape(), ClGemmLowpMatrixMultiplyNativeKernel::configure(), and ClGemmLowpMatrixMultiplyReshapedKernel::configure().
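As a rough illustration of what a non-zero depth implies, the sketch below splits the M output rows evenly across the depth slices. The helper `output_extents_sketch` is hypothetical; it assumes M is divisible by the depth and lists extents width first. The library's own shape logic is in compute_mm_shape() and covers more cases (batches, reshaped inputs), so treat this as a simplified model only.

```cpp
#include <array>
#include <cassert>

// Hypothetical helper, not a library function.
// Returns the output extents as {dim0, dim1, dim2}, width first.
inline std::array<int, 3> output_extents_sketch(int m, int n, int depth_output_gemm3d)
{
    if (depth_output_gemm3d == 0)
    {
        // A depth of 0 means the output stays 2D: N columns by M rows.
        return std::array<int, 3>{{n, m, 1}};
    }

    // Assumption of this sketch: the rows divide evenly across the slices.
    assert(m % depth_output_gemm3d == 0);
    return std::array<int, 3>{{n, m / depth_output_gemm3d, depth_output_gemm3d}};
}
```

With m = 64, n = 32 and a depth of 8, the sketch yields extents 32 x 8 x 8: the same 64 rows, now viewed as 8 slices of 8 rows each.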
int k () const | inline
Number of matrix A columns or matrix B rows.
Definition at line 1836 of file Types.h.
Referenced by ClGemmLowpMatrixMultiplyNativeKernel::configure(), and ClGemmLowpMatrixMultiplyReshapedKernel::configure().
int m () const | inline
Number of matrix A rows.
Definition at line 1820 of file Types.h.
Referenced by arm_compute::misc::shape_calculator::compute_mm_shape(), ClGemmLowpMatrixMultiplyNativeKernel::configure(), ClGemmLowpMatrixMultiplyReshapedKernel::configure(), and CpuGemmMatrixMultiplyKernel::configure().
int n () const | inline
Number of matrix B columns.
Definition at line 1828 of file Types.h.
Referenced by arm_compute::misc::shape_calculator::compute_mm_shape(), ClGemmLowpMatrixMultiplyNativeKernel::configure(), ClGemmLowpMatrixMultiplyReshapedKernel::configure(), and CpuGemmMatrixMultiplyKernel::configure().
bool reinterpret_input_as_3d () const | inline
Flag which specifies if the input tensor has to be reinterpreted as 3D.
Definition at line 1871 of file Types.h.
Referenced by arm_compute::misc::shape_calculator::compute_mm_shape(), and ClGemmLowpMatrixMultiplyNativeKernel::configure().
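The 1x1 convolution case mentioned for this flag can be sketched as follows: with the NHWC layout, every spatial position of the input becomes one row of matrix A and the channels become its columns. `GemmDimsSketch` and `conv1x1_nhwc_as_gemm` are hypothetical names, and the sketch assumes a single batch, stride 1 and no padding.

```cpp
#include <cassert>

// Hypothetical container for the three GEMM dimensions.
struct GemmDimsSketch
{
    int m; // rows of matrix A
    int n; // columns of matrix B
    int k; // columns of matrix A, equal to rows of matrix B
};

// Hypothetical helper: GEMM dimensions for a 1x1 convolution over an NHWC
// input, assuming one batch, stride 1 and no padding.
inline GemmDimsSketch conv1x1_nhwc_as_gemm(int width, int height, int in_channels, int out_channels)
{
    assert(width > 0 && height > 0 && in_channels > 0 && out_channels > 0);

    GemmDimsSketch dims;
    dims.m = width * height; // one row per spatial position
    dims.n = out_channels;   // one column per output channel
    dims.k = in_channels;    // the reduction runs over input channels
    return dims;
}
```

For a 7 x 5 input with 16 input channels and 32 output channels this gives m = 35, n = 32, k = 16. Because the 35 rows come from collapsing the width and height dimensions, the input is a 3D tensor that the GEMM has to treat as a matrix, which is the situation this flag describes.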