Compute Library 23.08
GEMMReshapeInfo Class Reference (final)

GEMM reshape information class.

#include <Types.h>

Public Member Functions

GEMMReshapeInfo ()
    Default constructor.

GEMMReshapeInfo (int m, int n, int k, int mult_transpose1xW_width=1, int mult_interleave4x4_height=1, int depth_output_gemm3d=0, bool reinterpret_input_as_3d=false, bool broadcast_bias=false)
    Constructor.

int m () const
    Number of matrix A rows.

int n () const
    Number of matrix B columns.

int k () const
    Number of matrix A columns or matrix B rows.

int mult_transpose1xW_width () const
    Multiplication factor for the width of the 1xW transposed block.

int mult_interleave4x4_height () const
    Multiplication factor for the height of the 4x4 interleaved block.

int depth_output_gemm3d () const
    Depth (third dimension) of the output tensor to be used with the GEMM3D kernel.

bool reinterpret_input_as_3d () const
    Flag which specifies if the input tensor has to be reinterpreted as 3D.

bool broadcast_bias () const
    Flag which specifies whether to broadcast the shape of the bias tensor.

Detailed Description

GEMM reshape information class.

This class stores the information needed to reshape matrix A and matrix B.

Matrix A can only be reshaped through opencl::kernels::ClGemmReshapeLhsMatrixKernel or cpu::kernels::CpuGemmInterleave4x4Kernel.
Note: only opencl::kernels::ClGemmReshapeLhsMatrixKernel supports the optional mult_interleave4x4_height, the multiplication factor for the height of the 4x4 interleaved block.
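The 4x4 interleave can be pictured with a small, self-contained C++ sketch. `interleave4x4` is a hypothetical helper written for illustration, not the Compute Library kernel; it assumes a row-major float matrix, mult_interleave4x4_height = 1, and zero padding when the number of rows is not a multiple of 4. Each group of four consecutive rows of A becomes one output row in which the four values of every column sit next to each other.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of Compute Library): interleaves a row-major
// rows x cols matrix in blocks of 4 rows. The output has ceil(rows / 4) rows of
// length cols * 4; missing rows of the last block are padded with zeros.
std::vector<float> interleave4x4(const std::vector<float> &a, std::size_t rows, std::size_t cols)
{
    assert(a.size() == rows * cols);
    const std::size_t out_rows = (rows + 3) / 4;
    const std::size_t out_cols = cols * 4;
    std::vector<float> out(out_rows * out_cols, 0.0f);
    for (std::size_t block = 0; block < out_rows; ++block)
    {
        for (std::size_t c = 0; c < cols; ++c)
        {
            for (std::size_t r = 0; r < 4; ++r)
            {
                const std::size_t src_row = block * 4 + r;
                if (src_row < rows)
                {
                    // Element (src_row, c) lands at position c * 4 + r of output row `block`.
                    out[block * out_cols + c * 4 + r] = a[src_row * cols + c];
                }
            }
        }
    }
    return out;
}
```

For a 4x2 matrix with rows {1, 2}, {3, 4}, {5, 6}, {7, 8}, the sketch produces the single row {1, 3, 5, 7, 2, 4, 6, 8}.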

Matrix B can only be reshaped through opencl::kernels::ClGemmReshapeRhsMatrixKernel or cpu::kernels::CpuGemmTranspose1xWKernel.
Note: only opencl::kernels::ClGemmReshapeRhsMatrixKernel supports the optional mult_transpose1xW_width, the multiplication factor for the width of the 1xW transposed block.
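Similarly, here is a minimal sketch of the 1xW transpose of matrix B. `transpose1xW` is a hypothetical helper, not the Compute Library kernel. The block width `w` is passed in explicitly (for the CPU kernel we assume W = 16 bytes divided by the element size, i.e. 4 for F32), mult_transpose1xW_width is taken as 1, and columns past the end of B are zero-padded. Output row j collects the j-th 1xW slice of every row of B.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of Compute Library): the "transpose 1xW" layout.
// B is a row-major rows x cols matrix. Output row j gathers the slice
// B[k][j*w .. j*w + w - 1] of every row k, so the output has ceil(cols / w)
// rows of length rows * w. Columns past the end of B are padded with zeros.
std::vector<float> transpose1xW(const std::vector<float> &b, std::size_t rows, std::size_t cols, std::size_t w)
{
    assert(b.size() == rows * cols && w > 0);
    const std::size_t out_rows = (cols + w - 1) / w;
    const std::size_t out_cols = rows * w;
    std::vector<float> out(out_rows * out_cols, 0.0f);
    for (std::size_t j = 0; j < out_rows; ++j)
    {
        for (std::size_t k = 0; k < rows; ++k)
        {
            for (std::size_t x = 0; x < w; ++x)
            {
                const std::size_t src_col = j * w + x;
                if (src_col < cols)
                {
                    out[j * out_cols + k * w + x] = b[k * cols + src_col];
                }
            }
        }
    }
    return out;
}
```

With B = {1, 2, 3, 4}, {5, 6, 7, 8} and w = 2, the result is the row {1, 2, 5, 6} followed by the row {3, 4, 7, 8}.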

Definition at line 1697 of file Types.h.

Constructor & Destructor Documentation

◆ GEMMReshapeInfo() [1/2]

GEMMReshapeInfo ( )
inline

Default constructor.

Definition at line 1701 of file Types.h.

    : _m(1), _n(1), _k(1), _mult_transpose1xW_width(1), _mult_interleave4x4_height(1), _depth_output_gemm3d(0), _reinterpret_input_as_3d(false), _broadcast_bias(false)
{
}

◆ GEMMReshapeInfo() [2/2]

GEMMReshapeInfo ( int  m,
int  n,
int  k,
int  mult_transpose1xW_width = 1,
int  mult_interleave4x4_height = 1,
int  depth_output_gemm3d = 0,
bool  reinterpret_input_as_3d = false,
bool  broadcast_bias = false 
)
inline

Constructor.

Parameters
[in] m                          Number of matrix A rows.
[in] n                          Number of matrix B columns.
[in] k                          Number of matrix A columns or matrix B rows.
[in] mult_transpose1xW_width    (Optional) Multiplication factor for the width of the 1xW transposed block.
[in] mult_interleave4x4_height  (Optional) Multiplication factor for the height of the 4x4 interleaved block.
[in] depth_output_gemm3d        (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel. If 0, the output is not reinterpreted as 3D. Default: 0.
[in] reinterpret_input_as_3d    (Optional) Reinterpret the input as a 3D tensor. This flag should be set to true when GEMM is used to perform 1x1 convolutions with the NHWC data layout.
[in] broadcast_bias             (Optional) Broadcast the shape of the bias tensor from a vector to a matrix.

Definition at line 1718 of file Types.h.

    : _m(m), _n(n), _k(k), _mult_transpose1xW_width(mult_transpose1xW_width), _mult_interleave4x4_height(mult_interleave4x4_height), _depth_output_gemm3d(depth_output_gemm3d),
      _reinterpret_input_as_3d(reinterpret_input_as_3d), _broadcast_bias(broadcast_bias)
{
}
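The following sketch shows how the constructor arguments fit together. It uses a hypothetical stand-in class, `GemmReshapeInfoSketch`, that copies the documented signature and default values so that it compiles without Compute Library; in real code construct arm_compute::GEMMReshapeInfo instead. The example sizes (a plain 64x32 by 32x128 GEMM, and a 1x1 NHWC convolution on a 7x7 feature map with 256 input and 512 output channels) are assumptions chosen for illustration.

```cpp
#include <cassert>

// Hypothetical stand-in for arm_compute::GEMMReshapeInfo. It mirrors the
// documented constructor signature and defaults so the example builds without
// Compute Library; it is not the library's implementation.
class GemmReshapeInfoSketch
{
public:
    // Same values as the documented default constructor (see the member initializers below).
    GemmReshapeInfoSketch() = default;

    GemmReshapeInfoSketch(int m, int n, int k, int mult_transpose1xW_width = 1, int mult_interleave4x4_height = 1,
                          int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool broadcast_bias = false)
        : _m(m), _n(n), _k(k), _mult_transpose1xW_width(mult_transpose1xW_width),
          _mult_interleave4x4_height(mult_interleave4x4_height), _depth_output_gemm3d(depth_output_gemm3d),
          _reinterpret_input_as_3d(reinterpret_input_as_3d), _broadcast_bias(broadcast_bias)
    {
    }

    int  m() const { return _m; }
    int  n() const { return _n; }
    int  k() const { return _k; }
    int  mult_transpose1xW_width() const { return _mult_transpose1xW_width; }
    int  mult_interleave4x4_height() const { return _mult_interleave4x4_height; }
    int  depth_output_gemm3d() const { return _depth_output_gemm3d; }
    bool reinterpret_input_as_3d() const { return _reinterpret_input_as_3d; }
    bool broadcast_bias() const { return _broadcast_bias; }

private:
    int  _m{1};
    int  _n{1};
    int  _k{1};
    int  _mult_transpose1xW_width{1};
    int  _mult_interleave4x4_height{1};
    int  _depth_output_gemm3d{0};
    bool _reinterpret_input_as_3d{false};
    bool _broadcast_bias{false};
};

// Plain GEMM: A is 64 x 32, B is 32 x 128; every optional argument keeps its default.
inline GemmReshapeInfoSketch plain_gemm_info()
{
    return GemmReshapeInfoSketch(64, 128, 32);
}

// 1x1 convolution in NHWC (assumed sizes): a 7x7 input with 256 channels and
// 512 output channels. The input is read as 3D, m = 7 * 7 = 49 spatial
// positions, and the output is reinterpreted as 3D with depth 7.
inline GemmReshapeInfoSketch conv1x1_nhwc_info()
{
    return GemmReshapeInfoSketch(49, 512, 256, 1, 1, 7, true, false);
}
```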

Member Function Documentation

◆ broadcast_bias()

bool broadcast_bias ( ) const
inline

Flag which specifies whether to broadcast the shape of the bias tensor.

Returns
True if the shape of the bias tensor is to be broadcast.

Definition at line 1786 of file Types.h.

{
    return _broadcast_bias;
};
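When broadcast_bias() returns true, the bias is a vector that is applied as if it were repeated across the whole output matrix. Below is a minimal sketch of that effect, assuming one bias value per output column (length n) and a row-major m x n result; `add_broadcast_bias` is a hypothetical helper, not a Compute Library function.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: adds a length-n bias vector to every row of a row-major
// m x n GEMM output, i.e. treats the bias as if it were broadcast to an m x n matrix.
void add_broadcast_bias(std::vector<float> &out, std::size_t m, std::size_t n, const std::vector<float> &bias)
{
    assert(out.size() == m * n && bias.size() == n);
    for (std::size_t row = 0; row < m; ++row)
    {
        for (std::size_t col = 0; col < n; ++col)
        {
            out[row * n + col] += bias[col];
        }
    }
}
```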

◆ depth_output_gemm3d()

int depth_output_gemm3d ( ) const
inline

Depth (third dimension) of the output tensor to be used with the GEMM3D kernel.

Note
The GEMM3D kernel is used when the output has to be reinterpreted as a 3D tensor. In that case, m = depth_output_gemm3d * output_height.
Returns
the depth of the output tensor to be used with the GEMM3D kernel

Definition at line 1770 of file Types.h.

{
    return _depth_output_gemm3d;
}

Referenced by arm_compute::misc::shape_calculator::compute_mm_shape().
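The note can be checked with a small worked example: with m = 56 and depth_output_gemm3d = 8, output_height = 56 / 8 = 7. The sketch below wraps that arithmetic in a hypothetical helper, `reinterpret_output_as_3d`. It illustrates the documented relation only; it is not the code used by compute_mm_shape(), and the requirement that depth_output_gemm3d divide m exactly is our assumption.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical helper: shape of the GEMM output once it is reinterpreted as 3D,
// based on the documented relation m = depth_output_gemm3d * output_height.
struct Output3DShape
{
    int n;      // number of matrix B columns
    int height; // output_height = m / depth_output_gemm3d
    int depth;  // depth_output_gemm3d
};

Output3DShape reinterpret_output_as_3d(int m, int n, int depth_output_gemm3d)
{
    // depth_output_gemm3d == 0 means "do not reinterpret", so it is rejected here,
    // as is any depth that does not divide m exactly.
    if (depth_output_gemm3d <= 0 || m % depth_output_gemm3d != 0)
    {
        throw std::invalid_argument("m must be a positive multiple of depth_output_gemm3d");
    }
    return Output3DShape{n, m / depth_output_gemm3d, depth_output_gemm3d};
}
```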

◆ k()

int k ( ) const
inline

Number of matrix A columns or matrix B rows.

Returns
the number of matrix A columns or matrix B rows

Definition at line 1743 of file Types.h.

{
    return _k;
}
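To show how m, n and k size the operands, here is a plain reference GEMM. It is an illustrative triple loop over row-major matrices, not how Compute Library computes GEMM, and `naive_gemm` is a hypothetical name: A is m x k, B is k x n, and the result C = A * B is m x n.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference GEMM (plain triple loop) showing how m, n and k size the operands:
// A is m x k, B is k x n and C = A * B is m x n, all row-major.
std::vector<float> naive_gemm(const std::vector<float> &a, const std::vector<float> &b, int m, int n, int k)
{
    assert(a.size() == static_cast<std::size_t>(m) * k);
    assert(b.size() == static_cast<std::size_t>(k) * n);
    std::vector<float> c(static_cast<std::size_t>(m) * n, 0.0f);
    for (int i = 0; i < m; ++i)
    {
        for (int j = 0; j < n; ++j)
        {
            // k is the shared dimension: columns of A and rows of B.
            for (int p = 0; p < k; ++p)
            {
                c[i * n + j] += a[i * k + p] * b[p * n + j];
            }
        }
    }
    return c;
}
```

For A = {1, 2, 3}, {4, 5, 6} (m = 2, k = 3) and B = {7, 8}, {9, 10}, {11, 12} (k = 3, n = 2), the result is {58, 64}, {139, 154}.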

◆ m()

int m ( ) const
inline

Number of matrix A rows.

Returns
the number of matrix A rows

Definition at line 1727 of file Types.h.

{
    return _m;
}

Referenced by arm_compute::misc::shape_calculator::compute_mm_shape(), and CpuGemmMatrixMultiplyKernel::configure().

◆ mult_interleave4x4_height()

int mult_interleave4x4_height ( ) const
inline

Multiplication factor for the height of the 4x4 interleaved block.

Returns
the multiplication factor for the height of the 4x4 interleaved block

Definition at line 1759 of file Types.h.

{
    return _mult_interleave4x4_height;
}
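The effect of mult_interleave4x4_height on the shape of the reshaped matrix A can be estimated with the sketch below (a factor other than 1 applies only to the OpenCL kernel, as noted in the detailed description). It assumes that each output row packs a block of 4 * mult_interleave4x4_height rows of A, so the output is that many times wider and correspondingly shorter. `interleaved_shape` and `InterleavedShape` are hypothetical names, and the exact rule should be checked against the shape calculator of the release in use.

```cpp
#include <cassert>

// Width (columns) and height (rows) of a 2D matrix.
struct InterleavedShape
{
    int width;
    int height;
};

// Hypothetical helper. Assumption: each output row packs a block of
// 4 * mult_interleave4x4_height rows of A (zero-padded at the end).
InterleavedShape interleaved_shape(int a_width, int a_height, int mult_interleave4x4_height = 1)
{
    assert(a_width > 0 && a_height > 0 && mult_interleave4x4_height >= 1);
    const int block_height = 4 * mult_interleave4x4_height;
    return InterleavedShape{a_width * block_height, (a_height + block_height - 1) / block_height};
}
```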

◆ mult_transpose1xW_width()

int mult_transpose1xW_width ( ) const
inline

Multiplication factor for the width of the 1xW transposed block.

Returns
the multiplication factor for the width of the 1xW transposed block

Definition at line 1751 of file Types.h.

{
    return _mult_transpose1xW_width;
}
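In the same spirit, the sketch below estimates the shape of the transposed matrix B. It assumes a base block width of 16 bytes divided by the element size (4 for 32-bit floats, 16 for 8-bit types) that is then multiplied by mult_transpose1xW_width. `transposed1xW_shape` and `TransposedShape` are hypothetical names, and these assumptions should be verified against the release in use.

```cpp
#include <cassert>
#include <cstddef>

// Width (columns) and height (rows) of a 2D matrix.
struct TransposedShape
{
    std::size_t width;
    std::size_t height;
};

// Hypothetical helper. Assumptions: the base block width W is 16 bytes divided
// by the element size, and mult_transpose1xW_width scales W.
TransposedShape transposed1xW_shape(std::size_t b_width, std::size_t b_height, std::size_t element_size,
                                    std::size_t mult_transpose1xW_width = 1)
{
    assert(element_size > 0 && 16 % element_size == 0 && mult_transpose1xW_width >= 1);
    const std::size_t w = (16 / element_size) * mult_transpose1xW_width;
    // Each output row holds one 1xW slice from every row of B.
    return TransposedShape{b_height * w, (b_width + w - 1) / w};
}
```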

◆ n()

int n ( ) const
inline

Number of matrix B columns.

Returns
the number of matrix B columns

Definition at line 1735 of file Types.h.

{
    return _n;
}

Referenced by arm_compute::misc::shape_calculator::compute_mm_shape(), and CpuGemmMatrixMultiplyKernel::configure().

◆ reinterpret_input_as_3d()

bool reinterpret_input_as_3d ( ) const
inline

Flag which specifies if the input tensor has to be reinterpreted as 3D.

Returns
True if the input tensor has to be reinterpreted as a 3D tensor

Definition at line 1778 of file Types.h.

{
    return _reinterpret_input_as_3d;
};

Referenced by arm_compute::misc::shape_calculator::compute_mm_shape().
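When reinterpret_input_as_3d() is true, the 3D input is consumed as a 2D matrix A by collapsing its two spatial dimensions into the rows. Below is a minimal sketch, assuming a [k, width, height] input shape (channels innermost, as in NHWC without the batch dimension) with width varying faster than height; `input_as_matrix`, `row_of` and `MatrixDims` are hypothetical names.

```cpp
#include <cassert>

struct MatrixDims
{
    int rows; // m: one row per spatial position
    int cols; // k: one column per input channel
};

// Hypothetical helper: view a [k, width, height] input as an m x k matrix A,
// where m = width * height.
MatrixDims input_as_matrix(int k, int width, int height)
{
    assert(k > 0 && width > 0 && height > 0);
    return MatrixDims{width * height, k};
}

// Hypothetical helper: row of A that holds spatial position (x, y), assuming
// x (the width coordinate) varies fastest.
int row_of(int x, int y, int width)
{
    assert(x >= 0 && x < width && y >= 0);
    return y * width + x;
}
```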


The documentation for this class was generated from the following file: Types.h