Compute Library
 22.08
CpuGemmAssemblyDispatch Class Reference

Assembly kernel glue. More...

#include <CpuGemmAssemblyDispatch.h>

Collaboration diagram for CpuGemmAssemblyDispatch:

Data Structures

class  IFallback
 

Public Member Functions

 CpuGemmAssemblyDispatch ()
 Constructor. More...
 
 ~CpuGemmAssemblyDispatch ()=default
 Default destructor. More...
 
 ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE (CpuGemmAssemblyDispatch)
 
void configure (const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, const AsmGemmInfo &info)
 If supported, create a Compute Library function; else fall back to the arm_gemm function. More...
 
bool is_configured () const
 Was the function successfully configured? More...
 
bool isVarWeightsKernel () const
 Indicates if the convolution executes in variable weights mode. More...
 
void prepare (ITensorPack &tensors) override
 Prepare the function for executing. More...
 
void run (ITensorPack &tensors) override
 Run the kernels contained in the function. More...
 
experimental::MemoryRequirements workspace () const override
 Return the memory requirements required by the workspace. More...
 
- Public Member Functions inherited from INEOperator
 INEOperator (IRuntimeContext *ctx=nullptr)
 Constructor. More...
 
 INEOperator (const INEOperator &)=delete
 Prevent instances of this class from being copied (As this class contains pointers) More...
 
 INEOperator (INEOperator &&)=default
 Default move constructor. More...
 
INEOperator &operator= (const INEOperator &)=delete
 Prevent instances of this class from being copied (As this class contains pointers) More...
 
INEOperator &operator= (INEOperator &&)=default
 Default move assignment operator. More...
 
 ~INEOperator ()
 Default destructor. More...
 
- Public Member Functions inherited from IOperator
virtual ~IOperator ()=default
 Destructor. More...
 

Static Public Member Functions

static Status validate (const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info)
 Indicates whether or not this function can be used to process the given parameters. More...
 
static Status has_opt_impl (arm_compute::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info)
 Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters. More...
 
static bool is_activation_supported (const ActivationLayerInfo &activation)
 Checks if activation is supported by the gemm assembly dispatcher. More...
 

Detailed Description

Assembly kernel glue.

Definition at line 60 of file CpuGemmAssemblyDispatch.h.

Constructor & Destructor Documentation

◆ CpuGemmAssemblyDispatch()

Constructor.

Definition at line 689 of file CpuGemmAssemblyDispatch.cpp.

    : _arm_gemm(nullptr)
{
}

◆ ~CpuGemmAssemblyDispatch()

Default destructor.

Member Function Documentation

◆ ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE()

ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE ( CpuGemmAssemblyDispatch  )

◆ configure()

void configure ( const ITensorInfo *a,
const ITensorInfo *b,
const ITensorInfo *c,
ITensorInfo *d,
const AsmGemmInfo &info 
)

If supported, create a Compute Library function; else fall back to the arm_gemm function.

Parameters
[in]  a     Input tensor (Matrix A)
[in]  b     Input tensor (Matrix B)
[in]  c     Input tensor (Matrix C) used to pass the bias for quantized calculations
[out] d     Output tensor to store the result of matrix multiplication. Data type supported: same as input0.
[in]  info  GEMM meta-data

Definition at line 811 of file CpuGemmAssemblyDispatch.cpp.

References AsmGemmInfo::activation_info, ARM_COMPUTE_ERROR_ON_NULLPTR, arm_compute::test::validation::b, arm_compute::BFLOAT16, ITensorInfo::data_type(), arm_compute::F16, arm_compute::F32, arm_compute::test::validation::info, arm_compute::assembly_utils::map_to_arm_gemm_activation(), arm_compute::QASYMM8, arm_compute::QASYMM8_SIGNED, arm_compute::S32, arm_compute::S8, arm_compute::U8, and CpuGemmAssemblyDispatch::validate().

{
    ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
    arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(info.activation_info);

    //If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
    if(!CpuGemmAssemblyDispatch::validate(a, b, c, d, info))
    {
        return;
    }

    switch(a->data_type())
    {
        case DataType::F32:
            create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info);
            break;
#ifdef __aarch64__
        case DataType::U8:
        case DataType::QASYMM8:
            if(d->data_type() == DataType::S32)
            {
                create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info);
            }
            else
            {
                create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
            }
            break;
        case DataType::S8:
        case DataType::QASYMM8_SIGNED:
            if(d->data_type() == DataType::S32)
            {
                create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info);
            }
            else
            {
                create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info);
            }
            break;
#endif /* __aarch64__ */
#if defined(ARM_COMPUTE_ENABLE_BF16)
        case DataType::BFLOAT16:
            create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
            break;
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info);
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        default:
            break;
    }
}
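The configure() contract described above — validate first, and on an unsupported data-type combination return silently so the caller must check is_configured() — can be sketched in isolation. The types below (Dispatch, Fallback, DataType) are hypothetical stand-ins for illustration only, not the actual arm_compute API:

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-ins for the library types; illustrates the
// "silently leave unconfigured on failure" contract only.
enum class DataType { F32, F16, U8 };

struct Fallback
{
    DataType dt;
};

class Dispatch
{
public:
    // Mirrors the configure() style above: if validation fails,
    // return without creating a backend, leaving the object unconfigured.
    void configure(DataType dt)
    {
        if(dt != DataType::F32 && dt != DataType::U8)
        {
            return; // unsupported combination: caller checks is_configured()
        }
        _impl = std::make_unique<Fallback>(Fallback{dt});
    }

    bool is_configured() const
    {
        return _impl != nullptr;
    }

private:
    std::unique_ptr<Fallback> _impl;
};
```

The caller therefore never sees an exception for an unsupported type; it probes success afterwards, exactly as the comment in the real configure() body instructs.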

◆ has_opt_impl()

Status has_opt_impl ( arm_compute::WeightFormat &weight_format,
const ITensorInfo *a,
const ITensorInfo *b,
const ITensorInfo *c,
const ITensorInfo *d,
const AsmGemmInfo &info 
)
static

Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters.

This method serves the same purpose as NEGEMMConvolutionLayer::has_opt_impl, with the only caveat that the value of arm_compute::WeightFormat needs to be passed via the parameter info.

Returns
a status.

Definition at line 694 of file CpuGemmAssemblyDispatch.cpp.

References AsmGemmInfo::activation_info, GemmTuner::args, ARM_COMPUTE_ERROR_ON_NULLPTR, ARM_COMPUTE_RETURN_ERROR_ON_MSG, ARM_COMPUTE_UNUSED, arm_compute::BFLOAT16, ci, IScheduler::cpu_info(), ITensorInfo::data_type(), arm_compute::F16, arm_compute::F32, AsmGemmInfo::fast_mode, AsmGemmInfo::fixed_format, Scheduler::get(), arm_compute::assembly_utils::map_to_arm_compute_weight_format(), arm_compute::assembly_utils::map_to_arm_gemm_activation(), arm_compute::assembly_utils::map_to_arm_gemm_weight_format(), IScheduler::num_threads(), arm_compute::QASYMM8, arm_compute::QASYMM8_SIGNED, arm_compute::S32, arm_compute::S8, arm_compute::U8, AsmGemmInfo::weight_format, and GemmConfig::weight_format.

Referenced by CpuGemmAssemblyDispatch::validate().

{
    ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
    ARM_COMPUTE_UNUSED(c);
    arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(info.activation_info);
    Params p = extract_parameters(a, b, d, info);
    const CPUInfo &ci          = NEScheduler::get().cpu_info();
    unsigned int   num_threads = NEScheduler::get().num_threads();
    arm_gemm::GemmConfig cfg;
    cfg.weight_format                           = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
    arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
    arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg);
    switch(a->data_type())
    {
        case DataType::F32:
            ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
                                            "We could not find an optimized kernel for F32 input");
            break;
#ifdef __aarch64__
        case DataType::U8:
        case DataType::QASYMM8:
            if(d->data_type() == DataType::S32)
            {
                ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
                                                "We could not find an optimized kernel for U8/QASYMM8 input and S32 output");
            }
            else
            {
                ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
                                                "We could not find an optimized kernel for U8 input and U8 output");
            }
            break;
        case DataType::S8:
        case DataType::QASYMM8_SIGNED:
            if(d->data_type() == DataType::S32)
            {
                ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
                                                "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
            }
            else
            {
                ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
                                                "We could not find an optimized kernel for S8 input and S8 output");
            }
            break;
#endif /* __aarch64__ */
#if defined(ARM_COMPUTE_ENABLE_BF16)
        case DataType::BFLOAT16:
        {
            ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
                                            "We could not find an optimized kernel for BFLOAT16 input and F32 output");
            break;
        }
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
                                            "We could not find an optimized kernel for F16 input and F16 output");
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        default:
            ARM_COMPUTE_RETURN_ERROR_ON_MSG(true, "Unsupported type. Could not find a kernel");
            break;
    }
    expected_weight_format = assembly_utils::map_to_arm_compute_weight_format(arm_gemm_expected_wf);

    return Status{};
}

◆ is_activation_supported()

bool is_activation_supported ( const ActivationLayerInfo &activation )
static

Checks if activation is supported by the gemm assembly dispatcher.

Parameters
[in]activationActivation to check
Returns
True if activation is supported else false

Definition at line 805 of file CpuGemmAssemblyDispatch.cpp.

References arm_compute::assembly_utils::map_to_arm_gemm_activation(), Activation::None, and Activation::type.

Referenced by CpuGemmLowpMatrixMultiplyCore::configure().

{
    arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(activation);
    return act.type != arm_gemm::Activation::Type::None;
}

◆ is_configured()

bool is_configured ( ) const

Was the function successfully configured?

Returns
True if the function is configured and ready to run

Definition at line 872 of file CpuGemmAssemblyDispatch.cpp.

{
    return _arm_gemm && _arm_gemm->is_configured();
}

◆ isVarWeightsKernel()

bool isVarWeightsKernel ( ) const
inline

Indicates if the convolution executes in variable weights mode.

Similar to CpuGemm::isVarWeightsKernel

Definition at line 130 of file CpuGemmAssemblyDispatch.h.

References arm_compute::test::validation::run().

{
    return _arm_gemm && _arm_gemm->isVarWeightsKernel();
}

◆ prepare()

void prepare ( ITensorPack &constants )
overridevirtual

Prepare the function for executing.

Any one-off pre-processing step required by the function is handled here.

Parameters
[in] constants Vector that contains the constant tensors.
Note
Prepare stage might not need all the function's buffers' backing memory to be available in order to execute

Reimplemented from INEOperator.

Definition at line 866 of file CpuGemmAssemblyDispatch.cpp.

References ARM_COMPUTE_ERROR_ON.

{
    ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
    _arm_gemm->prepare(tensors);
}

◆ run()

void run ( ITensorPack &tensors )
overridevirtual

Run the kernels contained in the function.

Parameters
[in] tensors Vector that contains the tensors to operate on.

Reimplemented from INEOperator.

Definition at line 877 of file CpuGemmAssemblyDispatch.cpp.

References ARM_COMPUTE_ERROR_ON.

{
    ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
    _arm_gemm->run(tensors);
}
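The prepare()/run() pair above follows the usual two-phase operator lifecycle: prepare() performs one-off work (such as packing or pretransposing weights) exactly once, and run() executes per call. A self-contained sketch of that split, with a hypothetical Operator class standing in for the INEOperator hierarchy:

```cpp
#include <cassert>
#include <vector>

// Hypothetical illustration of the prepare-once / run-many lifecycle
// used by INEOperator-derived classes such as CpuGemmAssemblyDispatch.
class Operator
{
public:
    // One-off preparation: repeated calls are no-ops, so the (possibly
    // expensive) weight transformation happens exactly once.
    void prepare(std::vector<float> &weights)
    {
        if(_prepared)
        {
            return;
        }
        for(float &w : weights)
        {
            w *= 2.0f; // stand-in for packing/pretransposing weights
        }
        _prepared = true;
    }

    // Per-call execution; returns how many times it has run.
    int run()
    {
        return ++_runs;
    }

private:
    bool _prepared{ false };
    int  _runs{ 0 };
};
```

As the note on prepare() says, the prepare stage may not need all of the function's buffers to be backed by memory yet; only the constants it transforms must be available.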

◆ validate()

Status validate ( const ITensorInfo *a,
const ITensorInfo *b,
const ITensorInfo *c,
const ITensorInfo *d,
const AsmGemmInfo &info 
)
static

Indicates whether or not this function can be used to process the given parameters.

Parameters
[in] a    Input tensor info (Matrix A)
[in] b    Input tensor info (Matrix B)
[in] c    Input tensor info (Matrix C) used to pass the bias for quantized calculations
[in] d    Output tensor info to store the result of matrix multiplication. Data type supported: same as input0.
[in] info GEMM meta-data
Returns
a status.

Definition at line 764 of file CpuGemmAssemblyDispatch.cpp.

References arm_compute::ANY, ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED, ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED, ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN, ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES, ARM_COMPUTE_RETURN_ERROR_ON_MSG, ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR, ARM_COMPUTE_UNUSED, arm_compute::BFLOAT16, ITensorInfo::data_type(), ITensorInfo::element_size(), arm_compute::F16, arm_compute::F32, CpuGemmAssemblyDispatch::has_opt_impl(), arm_compute::is_data_type_quantized_per_channel(), arm_compute::QASYMM8, arm_compute::QASYMM8_SIGNED, arm_compute::QSYMM8_PER_CHANNEL, arm_compute::S32, arm_compute::S8, arm_compute::U32, arm_compute::U8, and AsmGemmInfo::weight_format.

Referenced by CpuGemmAssemblyDispatch::configure(), CpuGemm::configure(), CpuGemmDirectConv2d::validate(), CpuGemm::validate(), and CpuGemmLowpMatrixMultiplyCore::validate().

{
    ARM_COMPUTE_UNUSED(c, info);
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);

#ifndef __aarch64__
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->element_size() == 1, "8bit integer types only supported for aarch64");
#endif /* __aarch64__ */
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S8, DataType::BFLOAT16, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(b, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::S8, DataType::BFLOAT16, DataType::F16, DataType::F32);
    if(is_data_type_quantized_per_channel(b->data_type()))
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::QASYMM8_SIGNED, DataType::S8);
    }
    else
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
    }
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F32 && d->data_type() != DataType::F32, "Only F32 output supported for F32 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F16 && d->data_type() != DataType::F16, "Only F16 output supported for F16 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::BFLOAT16 && d->data_type() != DataType::F32, "Only F32 output supported for BFLOAT16 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input");
    arm_compute::WeightFormat expected_weight_format;
    const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
    if((bool)ret && expected_weight_format != arm_compute::WeightFormat::ANY)
    {
        // Correctness check: if the format expected by the kernel is
        // not "any", make sure that the one found matches the format
        // intended by the caller.
        ARM_COMPUTE_RETURN_ERROR_ON_MSG((expected_weight_format != info.weight_format),
                                        "The format expected by the kernel does not correspond with the one requested by the user.");
    }
    return ret;
}
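validate() above is built entirely from early-return macros: each ARM_COMPUTE_RETURN_ERROR_ON_MSG-style check bails out with a descriptive Status, so falling through to the end means every constraint held. The style can be reproduced in a self-contained sketch (Status, the RETURN_ERROR_ON_MSG macro and the DataType enum here are simplified stand-ins, not the arm_compute definitions):

```cpp
#include <cassert>
#include <string>

// Simplified stand-in for arm_compute::Status.
struct Status
{
    bool        ok;
    std::string msg;
};

// Simplified stand-in for ARM_COMPUTE_RETURN_ERROR_ON_MSG: if the
// condition holds, return an error Status carrying the message.
#define RETURN_ERROR_ON_MSG(cond, m) \
    do                               \
    {                                \
        if(cond)                     \
        {                            \
            return Status{false, (m)}; \
        }                            \
    } while(0)

enum class DataType { F32, F16, U8 };

// Mirrors two of the input/output pairing rules from validate().
Status validate(DataType a, DataType d)
{
    RETURN_ERROR_ON_MSG(a == DataType::F32 && d != DataType::F32,
                        "Only F32 output supported for F32 input");
    RETURN_ERROR_ON_MSG(a == DataType::F16 && d != DataType::F16,
                        "Only F16 output supported for F16 input");
    return Status{true, ""};
}
```

Because every check returns immediately, the first violated constraint determines the error message the caller sees, which is why the real validate() orders its cheapest checks (null pointers, data types) first.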

◆ workspace()

experimental::MemoryRequirements workspace ( ) const
overridevirtual

Return the workspace memory requirements.

Reimplemented from INEOperator.

Definition at line 883 of file CpuGemmAssemblyDispatch.cpp.

References ARM_COMPUTE_ERROR_ON.

{
    ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
    return _arm_gemm->workspace();
}

The documentation for this class was generated from the following files:

CpuGemmAssemblyDispatch.h
CpuGemmAssemblyDispatch.cpp