Compute Library
 23.08
GEMMInfo Class Reference

GEMM information class.

#include <GEMMInfo.h>

Public Member Functions

 GEMMInfo () noexcept
 Default constructor.
 
 GEMMInfo (bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d=0, bool reinterpret_input_as_3d=false, bool retain_internal_weights=false, GEMMLowpOutputStageInfo gemmlowp_output_stage=GEMMLowpOutputStageInfo(), bool fp_mixed_precision=false, bool fast_math=false, bool broadcast_bias=false, const ActivationLayerInfo &activation_info=ActivationLayerInfo(), const experimental::PostOpList< ITensorInfo * > &post_ops=experimental::PostOpList< ITensorInfo * >(), bool fixed_format=false, arm_compute::WeightFormat weight_format=arm_compute::WeightFormat::UNSPECIFIED) noexcept
 Constructor.
 
bool is_a_reshaped () const
 Flag which specifies if the matrix A has been reshaped.
 
bool is_b_reshaped () const
 Flag which specifies if the matrix B has been reshaped.
 
bool reshape_b_only_on_first_run () const
 Flag which specifies if the reshape of matrix B should be executed only on the first run.
 
int depth_output_gemm3d () const
 Depth of the output when GEMM output is reinterpreted as 3D tensor.
 
bool reinterpret_input_as_3d () const
 Flag which specifies if the input tensor has to be reinterpreted as 3D.
 
bool retain_internal_weights () const
 Flag which specifies if the weights tensor has to be retained from previous run.
 
GEMMLowpOutputStageInfo gemmlowp_output_stage () const
 GEMMLowp output stage.
 
void set_gemmlowp_output_stage (GEMMLowpOutputStageInfo &output_stage)
 Sets GEMMLowp output stage.
 
bool fp_mixed_precision () const
 Flag which specifies if a wider accumulator should be used.
 
bool fast_math () const
 Flag which specifies if a shorter accumulator should be used.
 
void set_fast_math (bool fast_math)
 Set fast math flag.
 
bool broadcast_bias () const
 Flag which specifies whether to broadcast the shape of the bias tensor.
 
bool pretranspose_A () const
 Flag which specifies whether A should be pre-transposed if supported.
 
void set_pretranspose_A (bool flag)
 Set pre-transpose A flag.
 
bool pretranspose_B () const
 Flag which specifies whether B should be pre-transposed if supported.
 
void set_pretranspose_B (bool flag)
 Set pre-transpose B flag.
 
ActivationLayerInfo activation_info () const
 Activation layer to apply after the matrix multiplication.
 
void set_activation_info (const ActivationLayerInfo &activation_info)
 Set activation layer info.
 
const experimental::PostOpList< ITensorInfo * > & post_ops () const
 Post operations to apply after the matrix multiplication.
 
void set_post_ops (const experimental::PostOpList< ITensorInfo * > &post_ops)
 Set post ops.
 
bool fixed_format () const
 Flag which specifies if the GEMM operation is running fixed-format kernels.
 
void set_fixed_format (bool fixed_format)
 Set fixed-format flag.
 
arm_compute::WeightFormat weight_format () const
 
void set_weight_format (arm_compute::WeightFormat weight_format)
 Set weight format to be used.
 

Detailed Description

GEMM information class.

This class stores the information needed to compute GEMM functions.

It also records how matrix A and matrix B have been reshaped.

Definition at line 64 of file GEMMInfo.h.

Constructor & Destructor Documentation

◆ GEMMInfo() [1/2]

GEMMInfo ( )
inline noexcept

Default constructor.

Definition at line 68 of file GEMMInfo.h.

    : _is_a_reshaped(false),
      _is_b_reshaped(false),
      _reshape_b_only_on_first_run(true),
      _depth_output_gemm3d(0),
      _reinterpret_input_as_3d(false),
      _retain_internal_weights(false),
      _gemmlowp_output_stage(),
      _fast_math(false),
      _fp_mixed_precision(false),
      _broadcast_bias(false),
      _pretranspose_A(false),
      _pretranspose_B(false),
      _activation_info(),
      _post_ops(),
      _fixed_format(false),
      _weight_format(arm_compute::WeightFormat::UNSPECIFIED)
{
}

◆ GEMMInfo() [2/2]

GEMMInfo ( bool is_a_reshaped,
           bool is_b_reshaped,
           bool reshape_b_only_on_first_run,
           int depth_output_gemm3d = 0,
           bool reinterpret_input_as_3d = false,
           bool retain_internal_weights = false,
           GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(),
           bool fp_mixed_precision = false,
           bool fast_math = false,
           bool broadcast_bias = false,
           const ActivationLayerInfo &activation_info = ActivationLayerInfo(),
           const experimental::PostOpList< ITensorInfo * > &post_ops = experimental::PostOpList<ITensorInfo *>(),
           bool fixed_format = false,
           arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED
         )
inline noexcept

Constructor.

Parameters
[in] is_a_reshaped  True if the matrix A has been reshaped.
[in] is_b_reshaped  True if the matrix B has been reshaped.
[in] reshape_b_only_on_first_run  Reshape matrix B only on the first run.
[in] depth_output_gemm3d  (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel. If 0, the output will not be reinterpreted as 3D. Default is 0.
[in] reinterpret_input_as_3d  (Optional) Reinterpret the input as a 3D tensor. Set this flag to true when GEMM is used to perform 1x1 convolutions with the NHWC data layout.
[in] retain_internal_weights  (Optional) Retain the weights tensor from the previous run.
[in] gemmlowp_output_stage  (Optional) GEMMLowp output stage info.
[in] fp_mixed_precision  (Optional) Use wider accumulators (32-bit instead of 16-bit for FP16) to improve accuracy.
[in] fast_math  (Optional) Use a data type of shorter width to improve performance.
[in] broadcast_bias  (Optional) Broadcast the shape of the bias tensor from a vector to a matrix.
[in] activation_info  (Optional) Activation to apply after the matrix multiplication.
[in] post_ops  (Optional) A sequence of post operations that are performed after the main operation.
[in] fixed_format  (Optional) Select fixed-format kernels for variable weights support in GEMM. These kernels expect the weights tensor to be in a memory format that is fixed by the kernel itself. For more information, see arm_compute::WeightFormat.
[in] weight_format  (Optional) arm_compute::WeightFormat enumeration requested by the user. Default is arm_compute::WeightFormat::UNSPECIFIED.

Definition at line 106 of file GEMMInfo.h.

    : _is_a_reshaped(is_a_reshaped),
      _is_b_reshaped(is_b_reshaped),
      _reshape_b_only_on_first_run(reshape_b_only_on_first_run),
      _depth_output_gemm3d(depth_output_gemm3d),
      _reinterpret_input_as_3d(reinterpret_input_as_3d),
      _retain_internal_weights(retain_internal_weights),
      _gemmlowp_output_stage(gemmlowp_output_stage),
      _fast_math(fast_math),
      _fp_mixed_precision(fp_mixed_precision),
      _broadcast_bias(broadcast_bias),
      _pretranspose_A(false),
      _pretranspose_B(false),
      _activation_info(activation_info),
      _post_ops(post_ops),
      _fixed_format(fixed_format),
      _weight_format(weight_format)
{
}

Member Function Documentation

◆ activation_info()

ActivationLayerInfo activation_info ( ) const
inline

Activation layer to apply after the matrix multiplication.

Returns
ActivationLayerInfo object

Definition at line 262 of file GEMMInfo.h.

{
    return _activation_info;
}

Referenced by GEMMInfo::set_activation_info().
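As an illustration of an activation applied after the matrix multiplication, the sketch below post-processes an already computed GEMM output element by element. The `Act` enum and `apply_activation` function are hypothetical stand-ins, not the ActivationLayerInfo API; the sketch assumes the usual definitions ReLU(x) = max(0, x) and bounded ReLU(x) = min(a, max(0, x)).

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical stand-in for a few activation kinds (not ActivationLayerInfo).
enum class Act
{
    None,
    Relu,
    BoundedRelu
};

// Apply the activation in place to every element of a GEMM output.
// `upper` is the bound `a` of the bounded ReLU and is ignored otherwise.
void apply_activation(std::vector<float> &out, Act act, float upper = 6.0f)
{
    assert(upper >= 0.0f);
    for (float &v : out)
    {
        switch (act)
        {
            case Act::None:
                break;
            case Act::Relu:
                v = std::max(0.0f, v);
                break;
            case Act::BoundedRelu:
                v = std::min(upper, std::max(0.0f, v));
                break;
        }
    }
}
```

Carrying the activation in GEMMInfo gives a backend the option of fusing it into the GEMM instead of making a separate pass over the output like this one; whether a particular kernel does so is implementation-dependent.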

◆ broadcast_bias()

bool broadcast_bias ( ) const
inline

Flag which specifies whether to broadcast the shape of the bias tensor.

Returns
True if the shape of the bias tensor is to be broadcast.

Definition at line 222 of file GEMMInfo.h.

{
    return _broadcast_bias;
};
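A minimal sketch of what broadcasting the bias means, assuming a naive row-major GEMM: the bias is a vector with one value per output column, and the same vector is added to every row of the M x N result. The function name and loop structure are illustrative only, not the library's kernels.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive row-major GEMM: D[M x N] = A[M x K] * B[K x N] + bias,
// where `bias` holds N values that are broadcast to every one of the M rows.
std::vector<float> gemm_broadcast_bias(const std::vector<float> &a, const std::vector<float> &b,
                                       const std::vector<float> &bias,
                                       std::size_t m, std::size_t n, std::size_t k)
{
    assert(a.size() == m * k && b.size() == k * n && bias.size() == n);
    std::vector<float> d(m * n);
    for (std::size_t i = 0; i < m; ++i)
    {
        for (std::size_t j = 0; j < n; ++j)
        {
            float acc = bias[j]; // same bias value for column j in every row
            for (std::size_t p = 0; p < k; ++p)
            {
                acc += a[i * k + p] * b[p * n + j];
            }
            d[i * n + j] = acc;
        }
    }
    return d;
}
```

With A = [[1, 2], [3, 4]], B the 2 x 2 identity and bias [10, 20], each row of A gets the same [10, 20] added, giving [[11, 22], [13, 24]].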

◆ depth_output_gemm3d()

int depth_output_gemm3d ( ) const
inline

Depth of the output when GEMM output is reinterpreted as 3D tensor.

Returns
the depth of the output tensor

Definition at line 158 of file GEMMInfo.h.

{
    return _depth_output_gemm3d;
};
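A sketch of the index arithmetic behind reinterpreting a 2D GEMM output as 3D, assuming a row-major output with M rows that is viewed, without copying, as `depth` consecutive slices of M / depth rows each (for example M = W * H with depth H, as in a 1x1 convolution). The helper `row_to_3d` is hypothetical, and the exact layout used by the GEMM3D kernels may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>

// Map a row index of an M-row GEMM output to (row within slice, slice index)
// when the output is reinterpreted as `depth` slices of M / depth rows.
std::pair<std::size_t, std::size_t> row_to_3d(std::size_t row, std::size_t m, std::size_t depth)
{
    assert(depth > 0 && m % depth == 0 && row < m); // M must split evenly into slices
    const std::size_t rows_per_slice = m / depth;
    return {row % rows_per_slice, row / rows_per_slice};
}
```

A depth of 0 leaves the output as 2D, which is why the helper only accepts depth > 0. With M = 12 and depth 3, row 5 maps to row 1 of slice 1.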

◆ fast_math()

bool fast_math ( ) const
inline

Flag which specifies if a shorter accumulator should be used.

Returns
True if a shorter accumulator has to be used

Definition at line 206 of file GEMMInfo.h.

{
    return _fast_math;
};

Referenced by GEMMInfo::set_fast_math().
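fast_math only states that a shorter data type may be used. Assuming bfloat16 as that type (one common choice; which format a backend actually picks is not specified on this page), the sketch below rounds a float to the nearest bfloat16 value to show the precision that is traded for speed. The function `round_to_bf16` is hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Round a float to the nearest bfloat16 value (ties to even) and widen it back
// to float. bfloat16 keeps the float exponent but only 7 explicit mantissa bits.
// NaN inputs are not handled specially in this sketch.
float round_to_bf16(float x)
{
    static_assert(sizeof(float) == sizeof(std::uint32_t), "expects 32-bit float");
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    const std::uint32_t lsb = (bits >> 16) & 1u; // lowest bit kept in bfloat16
    bits += 0x7FFFu + lsb;                       // round to nearest, ties to even
    bits &= 0xFFFF0000u;                         // drop the low 16 mantissa bits
    float y;
    std::memcpy(&y, &bits, sizeof(y));
    return y;
}
```

For example, 1 + 2^-8 lies exactly halfway between two bfloat16 values and rounds to 1.0, so differences that a float keeps can disappear in the shorter format.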

◆ fixed_format()

bool fixed_format ( ) const
inline

Flag which specifies if the GEMM operation is running fixed-format kernels.

Returns
True if the GEMM operation is running fixed-format kernels, false otherwise.

Definition at line 294 of file GEMMInfo.h.

{
    return _fixed_format;
}

Referenced by GEMMInfo::set_fixed_format().

◆ fp_mixed_precision()

bool fp_mixed_precision ( ) const
inline

Flag which specifies if a wider accumulator should be used.

Returns
True if a wider accumulator has to be used

Definition at line 198 of file GEMMInfo.h.

{
    return _fp_mixed_precision;
};
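To show why a wider accumulator improves accuracy, the sketch below sums the same values into a narrow and a wide running total and measures each error. Standard C++ has no portable FP16 type, so float and double stand in for the 16-bit and 32-bit accumulators; the effect (rounding error growing as the narrow total gets large) is the same in kind. The names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Absolute errors of a narrow and a wide accumulator after summing `count`
// copies of `value`. float and double stand in for FP16 and FP32 here.
struct AccumulationErrors
{
    double narrow_error; // error of the float (narrow) running sum
    double wide_error;   // error of the double (wide) running sum
};

AccumulationErrors accumulation_errors(float value, std::size_t count)
{
    assert(count > 0);
    float  narrow = 0.0f;
    double wide   = 0.0;
    for (std::size_t i = 0; i < count; ++i)
    {
        narrow += value;                    // rounded to float after every add
        wide += static_cast<double>(value); // same inputs, wider accumulator
    }
    const double exact = static_cast<double>(count) * static_cast<double>(value);
    return {std::fabs(static_cast<double>(narrow) - exact), std::fabs(wide - exact)};
}
```

Summing 0.1f a million times, the float total drifts by far more than 1.0, while the double total stays within a tiny fraction of that.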

◆ gemmlowp_output_stage()

GEMMLowpOutputStageInfo gemmlowp_output_stage ( ) const
inline

GEMMLowp output stage.

Returns
the GEMMLowp output stage info

Definition at line 182 of file GEMMInfo.h.

{
    return _gemmlowp_output_stage;
};

Referenced by ClGemmLowpMatrixMultiplyCore::prepare(), and CpuGemmLowpMatrixMultiplyCore::run().
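The output stage turns the 32-bit integer accumulators of a quantized GEMM back into narrow integers. The sketch below follows the fixed-point scheme used by gemmlowp: a saturating rounding doubling high multiply by a Q0.31 multiplier, a rounding right shift, an added offset, then clamping. It is an illustration under that assumption, not ACL's implementation; the function names and parameters are hypothetical, chosen to resemble the offset, multiplier, shift and bounds an output stage describes.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>

// round(a * b / 2^31) with saturation, as in gemmlowp's
// SaturatingRoundingDoublingHighMul; b is a Q0.31 fixed-point multiplier.
std::int32_t rounding_doubling_high_mul(std::int32_t a, std::int32_t b)
{
    const std::int32_t min32 = std::numeric_limits<std::int32_t>::min();
    if (a == min32 && b == min32)
    {
        return std::numeric_limits<std::int32_t>::max(); // only overflowing case
    }
    const std::int64_t ab    = static_cast<std::int64_t>(a) * static_cast<std::int64_t>(b);
    const std::int64_t nudge = ab >= 0 ? (1LL << 30) : (1 - (1LL << 30));
    return static_cast<std::int32_t>((ab + nudge) / (1LL << 31));
}

// Divide by 2^shift, rounding to nearest with ties away from zero.
// Assumes arithmetic right shift of negative values (guaranteed since C++20).
std::int32_t rounding_divide_by_pot(std::int32_t x, int shift)
{
    assert(shift >= 0 && shift <= 31);
    const std::int32_t mask      = static_cast<std::int32_t>((1LL << shift) - 1);
    const std::int32_t remainder = x & mask;
    const std::int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
    return (x >> shift) + (remainder > threshold ? 1 : 0);
}

// Requantize one int32 accumulator into [min_bound, max_bound].
std::int32_t requantize(std::int32_t acc, std::int32_t multiplier, int shift,
                        std::int32_t offset, std::int32_t min_bound, std::int32_t max_bound)
{
    assert(min_bound <= max_bound);
    const std::int32_t scaled      = rounding_divide_by_pot(rounding_doubling_high_mul(acc, multiplier), shift);
    const std::int64_t with_offset = static_cast<std::int64_t>(scaled) + offset; // no int32 overflow
    const std::int64_t clamped =
        std::max<std::int64_t>(min_bound, std::min<std::int64_t>(max_bound, with_offset));
    return static_cast<std::int32_t>(clamped);
}
```

With a multiplier of 2^30 (0.5 in Q0.31) and a shift of 1, the effective scale is 0.25: an accumulator of 400 becomes 100, and an offset of 10 gives 110 inside the [0, 255] range.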

◆ is_a_reshaped()

bool is_a_reshaped ( ) const
inline

Flag which specifies if the matrix A has been reshaped.

Returns
True if the matrix A has been reshaped

Definition at line 132 of file GEMMInfo.h.

{
    return _is_a_reshaped;
};

◆ is_b_reshaped()

bool is_b_reshaped ( ) const
inline

Flag which specifies if the matrix B has been reshaped.

Returns
True if the matrix B has been reshaped

Definition at line 140 of file GEMMInfo.h.

{
    return _is_b_reshaped;
};

◆ post_ops()

const experimental::PostOpList<ITensorInfo *>& post_ops ( ) const
inline

Post operations to apply after the matrix multiplication.

Returns
experimental::PostOpList object

Definition at line 278 of file GEMMInfo.h.

{
    return _post_ops;
}

Referenced by GEMMInfo::set_post_ops().
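A post-op list is an ordered sequence, and the order matters: adding a tensor and then applying ReLU gives a different result from the reverse. The sketch below models that with a hypothetical `PostOp` alias for an element-wise callable; it is not the experimental::PostOpList API. For clarity each op makes its own pass over the output, whereas a fused kernel would presumably apply all of them to each element in a single pass.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical post-op: receives the current value and its flat index, so an
// element-wise op can also read a second tensor of the same shape.
using PostOp = std::function<float(float, std::size_t)>;

// Apply the post ops to the GEMM output in the order they were listed.
void apply_post_ops(std::vector<float> &out, const std::vector<PostOp> &ops)
{
    for (const PostOp &op : ops)
    {
        assert(op); // every entry must hold a callable
        for (std::size_t i = 0; i < out.size(); ++i)
        {
            out[i] = op(out[i], i);
        }
    }
}
```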

◆ pretranspose_A()

bool pretranspose_A ( ) const
inline

Flag which specifies whether A should be pre-transposed if supported.

Returns
True if A should be pre-transposed, false otherwise.

Definition at line 230 of file GEMMInfo.h.

{
    return _pretranspose_A;
};

◆ pretranspose_B()

bool pretranspose_B ( ) const
inline

Flag which specifies whether B should be pre-transposed if supported.

Returns
True if B should be pre-transposed, false otherwise.

Definition at line 246 of file GEMMInfo.h.

{
    return _pretranspose_B;
};
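A generic sketch of why pre-transposing B can help, assuming row-major storage: once B is stored as an N x K matrix, the inner product for D[i][j] reads a row of A and a row of the transposed B, both with unit stride. The function names are hypothetical, and the layout a backend actually uses when it honours this flag is not specified here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Transpose a row-major rows x cols matrix into a cols x rows matrix.
std::vector<float> transpose(const std::vector<float> &src, std::size_t rows, std::size_t cols)
{
    assert(src.size() == rows * cols);
    std::vector<float> dst(rows * cols);
    for (std::size_t r = 0; r < rows; ++r)
    {
        for (std::size_t c = 0; c < cols; ++c)
        {
            dst[c * rows + r] = src[r * cols + c];
        }
    }
    return dst;
}

// D[M x N] = A[M x K] * B, where B has been pre-transposed to bt[N x K].
// Both operands of each inner product are read with unit stride.
std::vector<float> gemm_pretransposed_b(const std::vector<float> &a, const std::vector<float> &bt,
                                        std::size_t m, std::size_t n, std::size_t k)
{
    assert(a.size() == m * k && bt.size() == n * k);
    std::vector<float> d(m * n);
    for (std::size_t i = 0; i < m; ++i)
    {
        for (std::size_t j = 0; j < n; ++j)
        {
            float acc = 0.0f;
            for (std::size_t p = 0; p < k; ++p)
            {
                acc += a[i * k + p] * bt[j * k + p];
            }
            d[i * n + j] = acc;
        }
    }
    return d;
}
```

For A = [1, 2] and B = [[1, 2, 3], [4, 5, 6]], transposing B first and then calling `gemm_pretransposed_b` yields the ordinary product [9, 12, 15].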

◆ reinterpret_input_as_3d()

bool reinterpret_input_as_3d ( ) const
inline

Flag which specifies if the input tensor has to be reinterpreted as 3D.

Returns
True if the input tensor has to be reinterpreted as 3D tensor

Definition at line 166 of file GEMMInfo.h.

{
    return _reinterpret_input_as_3d;
};
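The constructor documentation ties this flag to 1x1 convolutions in NHWC. A 1x1 convolution is a GEMM because, in NHWC, element (h, w, c) sits at ((h * W + w) * C_in + c), so the input buffer already is a row-major (H * W) x C_in matrix. The sketch below shows that equivalence with plain loops, assuming a single batch, stride 1, no padding and weights stored as a C_in x C_out matrix; the function name and weight layout are assumptions, not ACL's.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// 1x1 convolution on an NHWC input computed as out = in * weights, where the
// input buffer is read directly as a (height * width) x c_in matrix and the
// weights are a c_in x c_out matrix. The output is again NHWC: (h, w, c_out).
std::vector<float> conv1x1_nhwc_as_gemm(const std::vector<float> &in, const std::vector<float> &weights,
                                        std::size_t height, std::size_t width,
                                        std::size_t c_in, std::size_t c_out)
{
    const std::size_t m = height * width; // GEMM rows: one per spatial position
    assert(in.size() == m * c_in && weights.size() == c_in * c_out);
    std::vector<float> out(m * c_out, 0.0f);
    for (std::size_t i = 0; i < m; ++i)
    {
        for (std::size_t p = 0; p < c_in; ++p)
        {
            for (std::size_t j = 0; j < c_out; ++j)
            {
                out[i * c_out + j] += in[i * c_in + p] * weights[p * c_out + j];
            }
        }
    }
    return out;
}
```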

◆ reshape_b_only_on_first_run()

bool reshape_b_only_on_first_run ( ) const
inline

Flag which specifies if the reshape of matrix B should be executed only on the first run.

Note
This flag can be set to true when GEMM is used to accelerate a convolution layer.
Returns
True if the reshape of matrix B happens only on the first run

Definition at line 150 of file GEMMInfo.h.

{
    return _reshape_b_only_on_first_run;
};
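A plausible reading of the note above is that a convolution's weights (matrix B) are constant across runs, so their reshaped copy can be computed once and reused. The class below is a hypothetical sketch of that caching pattern, using a transpose as the reshape; it is not how ACL implements the flag, and it is only valid if B does not change between runs.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical GEMM wrapper that reshapes (here: transposes) the constant
// matrix B[K x N] once, on the first run, and reuses the result afterwards.
class CachedGemm
{
public:
    CachedGemm(std::vector<float> b, std::size_t k, std::size_t n)
        : _b(std::move(b)), _k(k), _n(n)
    {
        assert(_b.size() == k * n);
    }

    // D[M x N] = A[M x K] * B
    std::vector<float> run(const std::vector<float> &a, std::size_t m)
    {
        assert(a.size() == m * _k);
        if (!_reshaped) // first run only: build the N x K transposed copy
        {
            _bt.resize(_k * _n);
            for (std::size_t p = 0; p < _k; ++p)
            {
                for (std::size_t j = 0; j < _n; ++j)
                {
                    _bt[j * _k + p] = _b[p * _n + j];
                }
            }
            _reshaped = true;
            ++_reshape_count;
        }
        std::vector<float> d(m * _n, 0.0f);
        for (std::size_t i = 0; i < m; ++i)
        {
            for (std::size_t j = 0; j < _n; ++j)
            {
                for (std::size_t p = 0; p < _k; ++p)
                {
                    d[i * _n + j] += a[i * _k + p] * _bt[j * _k + p];
                }
            }
        }
        return d;
    }

    int reshape_count() const { return _reshape_count; }

private:
    std::vector<float> _b;
    std::vector<float> _bt;
    std::size_t        _k;
    std::size_t        _n;
    bool               _reshaped{false};
    int                _reshape_count{0};
};
```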

◆ retain_internal_weights()

bool retain_internal_weights ( ) const
inline

Flag which specifies if the weights tensor has to be retained from previous run.

Returns
True if the weights tensor has to be retained

Definition at line 174 of file GEMMInfo.h.

{
    return _retain_internal_weights;
};

◆ set_activation_info()

void set_activation_info ( const ActivationLayerInfo &activation_info)
inline

Set activation layer info.

Parameters
[in] activation_info  ActivationLayerInfo object to set

Definition at line 270 of file GEMMInfo.h.

{
    _activation_info = activation_info;
}

References GEMMInfo::activation_info().

◆ set_fast_math()

void set_fast_math ( bool  fast_math)
inline

Set fast math flag.

Parameters
[in] fast_math  Flag to set

Definition at line 214 of file GEMMInfo.h.

{
    _fast_math = fast_math;
}

References GEMMInfo::fast_math().

◆ set_fixed_format()

void set_fixed_format ( bool  fixed_format)
inline

Set fixed-format flag.

Parameters
[in] fixed_format  Whether to use fixed-format kernels

Definition at line 303 of file GEMMInfo.h.

{
    _fixed_format = fixed_format;
}

References GEMMInfo::fixed_format().

◆ set_gemmlowp_output_stage()

void set_gemmlowp_output_stage ( GEMMLowpOutputStageInfo &output_stage)
inline

Sets GEMMLowp output stage.

Parameters
[in] output_stage  Output stage to set

Definition at line 190 of file GEMMInfo.h.

{
    _gemmlowp_output_stage = output_stage;
};

References output_stage.

◆ set_post_ops()

void set_post_ops ( const experimental::PostOpList< ITensorInfo * > &  post_ops)
inline

Set post ops.

Parameters
[in] post_ops  experimental::PostOpList object to set

Definition at line 286 of file GEMMInfo.h.

{
    _post_ops = post_ops;
}

References GEMMInfo::post_ops().

◆ set_pretranspose_A()

void set_pretranspose_A ( bool  flag)
inline

Set pre-transpose A flag.

Parameters
[in] flag  Flag to set

Definition at line 238 of file GEMMInfo.h.

{
    _pretranspose_A = flag;
}

◆ set_pretranspose_B()

void set_pretranspose_B ( bool  flag)
inline

Set pre-transpose B flag.

Parameters
[in] flag  Flag to set

Definition at line 254 of file GEMMInfo.h.

{
    _pretranspose_B = flag;
}

◆ set_weight_format()

void set_weight_format ( arm_compute::WeightFormat  weight_format)
inline

Set weight format to be used.

Parameters
[in] weight_format  arm_compute::WeightFormat enumeration

Definition at line 317 of file GEMMInfo.h.

{
    _weight_format = weight_format;
}

References GEMMInfo::weight_format().

◆ weight_format()

arm_compute::WeightFormat weight_format ( ) const
inline

Definition at line 308 of file GEMMInfo.h.

{
    return _weight_format;
}

Referenced by GEMMInfo::set_weight_format().


The documentation for this class was generated from the file GEMMInfo.h.