Compute Library
 21.02
GEMMReshapeInfo Class Reference (final)

GEMM reshape information class.

#include <Types.h>

Public Member Functions

 GEMMReshapeInfo ()
 Default constructor.

 GEMMReshapeInfo (int m, int n, int k, int mult_transpose1xW_width=1, int mult_interleave4x4_height=1, int depth_output_gemm3d=0, bool reinterpret_input_as_3d=false, bool broadcast_bias=false)
 Constructor.

int m () const
 Number of matrix A rows.

int n () const
 Number of matrix B columns.

int k () const
 Number of matrix A columns or matrix B rows.

int mult_transpose1xW_width () const
 Multiplication factor for the width of the 1xW transposed block.

int mult_interleave4x4_height () const
 Multiplication factor for the height of the 4x4 interleaved block.

int depth_output_gemm3d () const
 Depth (third dimension) of the output tensor to be used with the GEMM3D kernel.

bool reinterpret_input_as_3d () const
 Flag which specifies if the input tensor has to be reinterpreted as 3D.

bool broadcast_bias () const
 Flag which specifies whether to broadcast the shape of the bias tensor.
 

Detailed Description

GEMM reshape information class.

This class stores the necessary information about matrix A and matrix B reshape.

Matrix A can only be reshaped through CLGEMMReshapeLHSMatrixKernel, NEGEMMInterleave4x4Kernel or GCGEMMInterleave4x4Kernel. Note: only CLGEMMReshapeLHSMatrixKernel optionally allows setting mult_interleave4x4_height, the multiplication factor for the height of the 4x4 interleaved block.
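To make the effect of mult_interleave4x4_height concrete, here is a minimal sketch of the resulting shape of matrix A. It assumes the common interleave layout in which each output row packs a block of 4 * mult_interleave4x4_height consecutive input rows. `InterleavedShape` and `interleaved_shape` are hypothetical names for illustration, not Compute Library API, and the sketch ignores the kernels' actual memory layout and padding.

```cpp
#include <cassert>

// Hypothetical result type: width = number of columns, height = number of rows.
struct InterleavedShape
{
    int width;
    int height;
};

// Sketch (assumption, not the library's code): shape of matrix A (k columns,
// m rows) after a 4x4 interleave whose block height is scaled by
// mult_interleave4x4_height. Each output row packs one block of
// 4 * mult_interleave4x4_height input rows, so width grows and height shrinks.
InterleavedShape interleaved_shape(int k, int m, int mult_interleave4x4_height = 1)
{
    assert(k > 0 && m > 0 && mult_interleave4x4_height >= 1);
    const int block_height = 4 * mult_interleave4x4_height;
    // Round the row count up: a partial last block still needs a (padded) output row.
    return InterleavedShape{ k * block_height, (m + block_height - 1) / block_height };
}
```

For example, an A with k = 8 columns and m = 10 rows would become 32 x 3 with the default factor, and 64 x 2 with a factor of 2.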

Matrix B can only be reshaped through CLGEMMReshapeRHSMatrixKernel, NEGEMMTranspose1xWKernel or GCGEMMTranspose1xWKernel. Note: only CLGEMMReshapeRHSMatrixKernel optionally allows setting mult_transpose1xW_width, the multiplication factor for the width of the 1xW transposed block.
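A similar sketch for matrix B shows how mult_transpose1xW_width widens the transposed block. It assumes that the base block width W is the number of elements that fit in a 16-byte (128-bit) chunk, scaled by mult_transpose1xW_width, and that each output row collects the j-th W-wide slice of every row of B. `Transposed1xWShape` and `transposed1xW_shape` are hypothetical, not Compute Library API.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical result type: width = number of columns, height = number of rows.
struct Transposed1xWShape
{
    int width;
    int height;
};

// Sketch (assumption, not the library's code): shape of matrix B (n columns,
// k rows) after a 1xW transpose with W = (16 / element_size) * mult_transpose1xW_width.
Transposed1xWShape transposed1xW_shape(int n, int k, std::size_t element_size, int mult_transpose1xW_width = 1)
{
    assert(n > 0 && k > 0 && mult_transpose1xW_width >= 1);
    assert(element_size > 0 && 16 % element_size == 0);
    // Block width W: elements per 16-byte chunk, scaled by the multiplication factor.
    const int w = static_cast<int>(16 / element_size) * mult_transpose1xW_width;
    // Each output row holds one W-wide slice from each of the k rows: k * W columns.
    // ceil(n / W) slices per input row give the number of output rows.
    return Transposed1xWShape{ k * w, (n + w - 1) / w };
}
```

With 4-byte float elements, a B with n = 10 columns and k = 8 rows would give W = 4 and a 32 x 3 result; doubling mult_transpose1xW_width gives W = 8 and a 64 x 2 result.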

Definition at line 1831 of file Types.h.

Constructor & Destructor Documentation

◆ GEMMReshapeInfo() [1/2]

GEMMReshapeInfo ( )
inline

Default constructor.

Definition at line 1835 of file Types.h.

    : _m(1), _n(1), _k(1), _mult_transpose1xW_width(1), _mult_interleave4x4_height(1), _depth_output_gemm3d(0), _reinterpret_input_as_3d(false), _broadcast_bias(false)
{
}

◆ GEMMReshapeInfo() [2/2]

GEMMReshapeInfo ( int  m,
int  n,
int  k,
int  mult_transpose1xW_width = 1,
int  mult_interleave4x4_height = 1,
int  depth_output_gemm3d = 0,
bool  reinterpret_input_as_3d = false,
bool  broadcast_bias = false 
)
inline

Constructor.

Parameters
[in] m: Number of matrix A rows
[in] n: Number of matrix B columns
[in] k: Number of matrix A columns or matrix B rows
[in] mult_transpose1xW_width: (Optional) Multiplication factor for the width of the 1xW transposed block
[in] mult_interleave4x4_height: (Optional) Multiplication factor for the height of the 4x4 interleaved block
[in] depth_output_gemm3d: (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel. If 0, the output is not reinterpreted as 3D. Default 0
[in] reinterpret_input_as_3d: (Optional) Reinterpret the input as a 3D tensor (i.e. this flag should be set to true when GEMM is used to perform 1x1 convolutions with the NHWC data layout)
[in] broadcast_bias: (Optional) Broadcast the shape of the bias tensor from a vector to a matrix

Definition at line 1852 of file Types.h.

    : _m(m), _n(n), _k(k), _mult_transpose1xW_width(mult_transpose1xW_width), _mult_interleave4x4_height(mult_interleave4x4_height), _depth_output_gemm3d(depth_output_gemm3d),
      _reinterpret_input_as_3d(reinterpret_input_as_3d), _broadcast_bias(broadcast_bias)
{
}
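The constructor's parameter order and defaults can be illustrated with a self-contained stand-in class. `GEMMReshapeInfoSketch` below is hypothetical: it only mirrors the documented signature, default values and accessors so the example compiles without the library.

```cpp
#include <cassert>

// Hypothetical stand-in mirroring the documented interface of
// GEMMReshapeInfo; it is NOT the Compute Library class.
class GEMMReshapeInfoSketch
{
public:
    // Default constructor: the member initializers below reproduce the
    // documented defaults (m = n = k = 1, factors = 1, depth 0, flags false).
    GEMMReshapeInfoSketch() = default;

    // Same parameter order and defaults as the documented constructor.
    GEMMReshapeInfoSketch(int m, int n, int k, int mult_transpose1xW_width = 1, int mult_interleave4x4_height = 1,
                          int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool broadcast_bias = false)
        : _m(m), _n(n), _k(k), _mult_transpose1xW_width(mult_transpose1xW_width),
          _mult_interleave4x4_height(mult_interleave4x4_height), _depth_output_gemm3d(depth_output_gemm3d),
          _reinterpret_input_as_3d(reinterpret_input_as_3d), _broadcast_bias(broadcast_bias)
    {
    }

    int  m() const { return _m; }
    int  n() const { return _n; }
    int  k() const { return _k; }
    int  mult_transpose1xW_width() const { return _mult_transpose1xW_width; }
    int  mult_interleave4x4_height() const { return _mult_interleave4x4_height; }
    int  depth_output_gemm3d() const { return _depth_output_gemm3d; }
    bool reinterpret_input_as_3d() const { return _reinterpret_input_as_3d; }
    bool broadcast_bias() const { return _broadcast_bias; }

private:
    int  _m{ 1 };
    int  _n{ 1 };
    int  _k{ 1 };
    int  _mult_transpose1xW_width{ 1 };
    int  _mult_interleave4x4_height{ 1 };
    int  _depth_output_gemm3d{ 0 };
    bool _reinterpret_input_as_3d{ false };
    bool _broadcast_bias{ false };
};
```

With the library itself, `GEMMReshapeInfo info(64, 32, 16);` would, per the documented defaults, leave mult_transpose1xW_width and mult_interleave4x4_height at 1, depth_output_gemm3d at 0, and both flags false.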

Member Function Documentation

◆ broadcast_bias()

bool broadcast_bias ( ) const
inline

Flag which specifies whether to broadcast the shape of the bias tensor.

Returns
True if the shape of the bias tensor is to be broadcast.

Definition at line 1920 of file Types.h.

{
    return _broadcast_bias;
};
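Arithmetically, broadcasting the bias from a vector to a matrix means the same bias vector is added to every row of the m x n GEMM output. The sketch below assumes a dense row-major float output and a bias of length n (one value per output column); `add_broadcast_bias` is a hypothetical helper, not a Compute Library kernel.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch: add a bias vector of length n to every row of a row-major m x n
// matrix, i.e. treat the bias as if its shape were broadcast from [n] to [m x n].
void add_broadcast_bias(std::vector<float> &dst, const std::vector<float> &bias, std::size_t m, std::size_t n)
{
    assert(dst.size() == m * n);
    assert(bias.size() == n);
    for(std::size_t row = 0; row < m; ++row)
    {
        for(std::size_t col = 0; col < n; ++col)
        {
            // The same bias[col] is reused for every row.
            dst[row * n + col] += bias[col];
        }
    }
}
```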

◆ depth_output_gemm3d()

int depth_output_gemm3d ( ) const
inline

Depth (third dimension) of the output tensor to be used with the GEMM3D kernel.

Note
The GEMM3D kernel is used when the output has to be reinterpreted as a 3D tensor. In that case, m = depth_output_gemm3d * output_height.
Returns
the depth of the output tensor to be used with the GEMM3D kernel

Definition at line 1904 of file Types.h.

Referenced by arm_compute::misc::shape_calculator::compute_mm_shape(), CLGEMMLowpMatrixMultiplyNativeKernel::configure(), and CLGEMMLowpMatrixMultiplyReshapedKernel::configure().

{
    return _depth_output_gemm3d;
}
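The relation m = depth_output_gemm3d * output_height from the note above can be checked with a small helper. `output_height_for_gemm3d` is hypothetical; it simply assumes that m must divide evenly by a non-zero depth and that a depth of 0 means the output stays 2D, as the constructor documentation states.

```cpp
#include <cassert>

// Sketch: recover output_height for the 3D view of the GEMM output from
// m = depth_output_gemm3d * output_height.
// A depth of 0 means the output is not reinterpreted, so all m rows stay in one plane.
int output_height_for_gemm3d(int m, int depth_output_gemm3d)
{
    assert(m > 0 && depth_output_gemm3d >= 0);
    if(depth_output_gemm3d == 0)
    {
        return m;
    }
    // m must split evenly into depth_output_gemm3d slices.
    assert(m % depth_output_gemm3d == 0);
    return m / depth_output_gemm3d;
}
```

For instance, m = 56 with depth_output_gemm3d = 7 gives output_height = 8.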

◆ k()

int k ( ) const
inline

Number of matrix A columns or matrix B rows.

Returns
the number of matrix A columns or matrix B rows

Definition at line 1877 of file Types.h.

Referenced by CLGEMMLowpMatrixMultiplyReshapedKernel::configure(), and arm_compute::operator<<().

{
    return _k;
}
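The roles of m, n and k are easiest to see in a reference GEMM. This unoptimized sketch (hypothetical `naive_gemm`, dense row-major float matrices, no reshaping) computes C (m x n) = A (m x k) * B (k x n); k is the inner dimension that the accumulation loop runs over.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference GEMM on row-major matrices: C (m x n) = A (m x k) * B (k x n).
// k is the shared dimension: columns of A and rows of B.
std::vector<float> naive_gemm(const std::vector<float> &a, const std::vector<float> &b, int m, int n, int k)
{
    assert(m > 0 && n > 0 && k > 0);
    assert(a.size() == static_cast<std::size_t>(m) * k);
    assert(b.size() == static_cast<std::size_t>(k) * n);
    std::vector<float> c(static_cast<std::size_t>(m) * n, 0.0f);
    for(int i = 0; i < m; ++i)
    {
        for(int j = 0; j < n; ++j)
        {
            float acc = 0.0f;
            for(int p = 0; p < k; ++p)
            {
                acc += a[i * k + p] * b[p * n + j]; // row i of A dot column j of B
            }
            c[i * n + j] = acc;
        }
    }
    return c;
}
```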

◆ m()

int m ( ) const
inline

Number of matrix A rows.

Returns
the number of matrix A rows

Definition at line 1861 of file Types.h.

Referenced by arm_compute::misc::shape_calculator::compute_mm_shape(), NEGEMMMatrixMultiplyKernel::configure(), and arm_compute::operator<<().

{
    return _m;
}

◆ mult_interleave4x4_height()

int mult_interleave4x4_height ( ) const
inline

Multiplication factor for the height of the 4x4 interleaved block.

Returns
the multiplication factor for the height of the 4x4 interleaved block

Definition at line 1893 of file Types.h.

Referenced by arm_compute::operator<<().

{
    return _mult_interleave4x4_height;
}

◆ mult_transpose1xW_width()

int mult_transpose1xW_width ( ) const
inline

Multiplication factor for the width of the 1xW transposed block.

Returns
the multiplication factor for the width of the 1xW transposed block

Definition at line 1885 of file Types.h.

Referenced by arm_compute::operator<<().

{
    return _mult_transpose1xW_width;
}

◆ n()

int n ( ) const
inline

Number of matrix B columns.

Returns
the number of matrix B columns

Definition at line 1869 of file Types.h.

Referenced by arm_compute::misc::shape_calculator::compute_mm_shape(), NEGEMMMatrixMultiplyKernel::configure(), and arm_compute::operator<<().

{
    return _n;
}

◆ reinterpret_input_as_3d()

bool reinterpret_input_as_3d ( ) const
inline

Flag which specifies if the input tensor has to be reinterpreted as 3D.

Returns
True if the input tensor has to be reinterpreted as a 3D tensor

Definition at line 1912 of file Types.h.

Referenced by arm_compute::misc::shape_calculator::compute_mm_shape(), and CLGEMMLowpMatrixMultiplyNativeKernel::configure().

{
    return _reinterpret_input_as_3d;
};
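One way to see why 1x1 NHWC convolutions use this flag: assuming a dense tensor (no padding) whose shape is listed innermost-first as {C, W, H, N}, each image is already laid out like a row-major (W * H) x C matrix, so collapsing W and H into m needs no data movement. The helpers below (`Gemm1x1Dims`, `gemm_dims_1x1_nhwc`, `nhwc_offset`, `gemm_a_offset`) are hypothetical and only demonstrate this index equivalence; they are not the library's shape calculator.

```cpp
#include <array>
#include <cassert>

// Hypothetical GEMM dimensions for a 1x1 convolution.
struct Gemm1x1Dims
{
    int m;     // rows of A per batch: one row per spatial position (W * H)
    int n;     // columns of B: number of filters (output channels)
    int k;     // columns of A / rows of B: input channels
    int batch; // independent GEMMs, one per image
};

// Sketch (assumption): shape = {C, W, H, N} with dimension 0 innermost.
// Reinterpreting the input as 3D collapses W and H into m.
Gemm1x1Dims gemm_dims_1x1_nhwc(const std::array<int, 4> &shape, int num_filters)
{
    assert(num_filters > 0);
    for(int d : shape)
    {
        assert(d > 0);
    }
    return Gemm1x1Dims{ shape[1] * shape[2], num_filters, shape[0], shape[3] };
}

// Linear offset of element (c, w, h, b) in a dense NHWC tensor {C, W, H, N}.
long nhwc_offset(int c, int w, int h, int b, int C, int W, int H)
{
    return c + static_cast<long>(C) * (w + static_cast<long>(W) * (h + static_cast<long>(H) * b));
}

// Linear offset of element (row, col) in batch b of dense row-major m x k matrices.
long gemm_a_offset(int row, int col, int b, int m, int k)
{
    return col + static_cast<long>(k) * (row + static_cast<long>(m) * b);
}
```

For a {3, 4, 5, 2} input and 8 filters this gives m = 20, n = 8, k = 3 and 2 batches, and every (c, w, h, b) element lands at row w + 4 * h, column c of batch b.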

The documentation for this class was generated from the following file: Types.h