Compute Library
 21.02
NEGEMMAssemblyDispatch Class Reference

Assembly kernel glue. More...

#include <NEGEMMAssemblyDispatch.h>


Data Structures

class  IFallback
 

Public Member Functions

 NEGEMMAssemblyDispatch (std::shared_ptr< IMemoryManager > memory_manager=nullptr, IWeightsManager *weights_manager=nullptr)
 Constructor. More...
 
 NEGEMMAssemblyDispatch (const NEGEMMAssemblyDispatch &)=delete
 Prevent instances of this class from being copy constructed. More...
 
NEGEMMAssemblyDispatch & operator= (const NEGEMMAssemblyDispatch &)=delete
 Prevent instances of this class from being copied. More...
 
 NEGEMMAssemblyDispatch (NEGEMMAssemblyDispatch &&)=default
 
NEGEMMAssemblyDispatch & operator= (NEGEMMAssemblyDispatch &&)=default
 
 ~NEGEMMAssemblyDispatch ()=default
 
void configure (const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, const AsmGemmInfo &info)
 If supported, create a Compute Library function; otherwise fall back to the arm_gemm function. More...
 
bool is_configured () const
 Was the function successfully configured? More...
 
void prepare () override
 Prepare the function for executing. More...
 
void run () override
 Run the kernels contained in the function. More...
 
- Public Member Functions inherited from IFunction
virtual ~IFunction ()=default
 Destructor. More...
 

Static Public Member Functions

static Status validate (const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info)
 Indicates whether or not this function can be used to process the given parameters. More...
 
static bool is_activation_supported (const ActivationLayerInfo &activation)
 Checks if activation is supported by the gemm assembly dispatcher. More...
 

Detailed Description

Assembly kernel glue.

Definition at line 58 of file NEGEMMAssemblyDispatch.h.

Constructor & Destructor Documentation

◆ NEGEMMAssemblyDispatch() [1/3]

NEGEMMAssemblyDispatch ( std::shared_ptr< IMemoryManager > memory_manager = nullptr,
IWeightsManager * weights_manager = nullptr 
)

Constructor.

Definition at line 745 of file NEGEMMAssemblyDispatch.cpp.

746  : _arm_gemm(nullptr), _memory_group(std::move(memory_manager)), _weights_manager(weights_manager)
747 {
748 }

◆ NEGEMMAssemblyDispatch() [2/3]

Prevent instances of this class from being copy constructed.

◆ NEGEMMAssemblyDispatch() [3/3]

◆ ~NEGEMMAssemblyDispatch()

~NEGEMMAssemblyDispatch ( )
default

Member Function Documentation

◆ configure()

void configure ( const ITensor * a,
const ITensor * b,
const ITensor * c,
ITensor * d,
const AsmGemmInfo & info 
)

If supported, create a Compute Library function; otherwise fall back to the arm_gemm function.

Parameters
[in]	a	Input tensor (Matrix A)
[in]	b	Input tensor (Matrix B)
[in]	c	Input tensor (Matrix C) used to pass the bias for quantized calculations
[out]	d	Output tensor to store the result of matrix multiplication. Data type supported: same as input0.
[in]	info	GEMM meta-data

Definition at line 787 of file NEGEMMAssemblyDispatch.cpp.

References AsmGemmInfo::activation_info, ARM_COMPUTE_ERROR_ON_NULLPTR, arm_compute::test::validation::b, arm_compute::BFLOAT16, ITensorInfo::data_type(), arm_compute::F16, arm_compute::F32, ITensor::info(), arm_compute::test::validation::info, arm_compute::QASYMM8, arm_compute::QASYMM8_SIGNED, arm_compute::S32, arm_compute::S8, arm_compute::U8, and NEGEMMAssemblyDispatch::validate().

788 {
790  arm_gemm::Activation act = map_to_arm_gemm_activation(info.activation_info);
791 
792  //If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
793  if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, d->info(), info))
794  {
795  return;
796  }
797 
798  switch(a->info()->data_type())
799  {
800  case DataType::F32:
801  create_arm_gemm<float, float>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
802  break;
803 #ifdef __aarch64__
804  case DataType::U8:
805  case DataType::QASYMM8:
806  if(d->info()->data_type() == DataType::S32)
807  {
808  create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
809  }
810  else
811  {
812  create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
813  }
814  break;
815  case DataType::S8:
816  case DataType::QASYMM8_SIGNED:
817  if(d->info()->data_type() == DataType::S32)
818  {
819  create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
820  }
821  else
822  {
823  create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
824  }
825  break;
826 #endif /* __aarch64__ */
827 #if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
828  case DataType::BFLOAT16:
829  create_arm_gemm<bfloat16, float>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
830  break;
831 #endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
832 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
833  case DataType::F16:
834  create_arm_gemm<float16_t, float16_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
835  break;
836 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
837  default:
838  break;
839  }
840 }

◆ is_activation_supported()

bool is_activation_supported ( const ActivationLayerInfo & activation)
static

Checks if activation is supported by the gemm assembly dispatcher.

Parameters
[in]	activation	Activation to check
Returns
True if activation is supported else false

Definition at line 781 of file NEGEMMAssemblyDispatch.cpp.

References tf_frozen_model_extractor::None.

Referenced by NEGEMM::configure(), and NEGEMMLowpMatrixMultiplyCore::configure().

782 {
783  arm_gemm::Activation act = map_to_arm_gemm_activation(activation);
784  return act.type != arm_gemm::Activation::Type::None;
785 }

◆ is_configured()

bool is_configured ( ) const

Was the function successfully configured?

Returns
True if the function is configured and ready to run

Definition at line 848 of file NEGEMMAssemblyDispatch.cpp.

849 {
850  return _arm_gemm != nullptr && _arm_gemm->is_configured();
851 }

◆ operator=() [1/2]

NEGEMMAssemblyDispatch & operator= ( const NEGEMMAssemblyDispatch & )
delete

Prevent instances of this class from being copied.

◆ operator=() [2/2]

NEGEMMAssemblyDispatch& operator= ( NEGEMMAssemblyDispatch &&  )
default

◆ prepare()

void prepare ( )
override virtual

Prepare the function for executing.

Any one-off pre-processing step required by the function is handled here.

Note
Prepare stage might not need all the function's buffers' backing memory to be available in order to execute

Reimplemented from IFunction.

Definition at line 842 of file NEGEMMAssemblyDispatch.cpp.

References ARM_COMPUTE_ERROR_ON.

843 {
844  ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
845  _arm_gemm->prepare();
846 }

◆ run()

void run ( )
override virtual

Run the kernels contained in the function.

For Neon kernels:

  • Multi-threading is used for the kernels which are parallelisable.
  • By default std::thread::hardware_concurrency() threads are used.
Note
CPPScheduler::set_num_threads() can be used to manually set the number of threads

For OpenCL kernels:

  • All the kernels are enqueued on the queue associated with CLScheduler.
  • The queue is then flushed.
Note
The function will not block until the kernels are executed. It is the user's responsibility to wait.
Will call prepare() on the first run if it hasn't been done

Implements IFunction.

Definition at line 853 of file NEGEMMAssemblyDispatch.cpp.

References ARM_COMPUTE_ERROR_ON.

854 {
855  MemoryGroupResourceScope scope_mg(_memory_group);
856 
857  ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
858  _arm_gemm->run();
859 }

◆ validate()

Status validate ( const ITensorInfo * a,
const ITensorInfo * b,
const ITensorInfo * c,
const ITensorInfo * d,
const AsmGemmInfo & info 
)
static

Indicates whether or not this function can be used to process the given parameters.

Parameters
[in]	a	Input tensor info (Matrix A)
[in]	b	Input tensor info (Matrix B)
[in]	c	Input tensor info (Matrix C) used to pass the bias for quantized calculations
[in]	d	Output tensor info to store the result of matrix multiplication. Data type supported: same as input0.
[in]	info	GEMM meta-data
Returns
a status.

Definition at line 750 of file NEGEMMAssemblyDispatch.cpp.

References ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED, ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED, ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN, ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES, ARM_COMPUTE_RETURN_ERROR_ON_MSG, ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR, ARM_COMPUTE_UNUSED, arm_compute::BFLOAT16, ITensorInfo::data_type(), ITensorInfo::element_size(), arm_compute::F16, arm_compute::F32, arm_compute::is_data_type_quantized_per_channel(), arm_compute::QASYMM8, arm_compute::QASYMM8_SIGNED, arm_compute::QSYMM8_PER_CHANNEL, arm_compute::S32, arm_compute::S8, arm_compute::U32, and arm_compute::U8.

Referenced by NEGEMMAssemblyDispatch::configure(), NEGEMM::configure(), NEGEMMConv2d::validate(), NEGEMM::validate(), and NEGEMMLowpMatrixMultiplyCore::validate().

751 {
756 
757 #ifndef __aarch64__
758  ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->element_size() == 1, "8bit integer types only supported for aarch64");
759 #endif /* __aarch64__ */
764  if(is_data_type_quantized_per_channel(b->data_type()))
765  {
767  }
768  else
769  {
771  }
772  ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F32 && d->data_type() != DataType::F32, "Only F32 output supported for F32 input");
773  ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F16 && d->data_type() != DataType::F16, "Only F16 output supported for F16 input");
774  ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::BFLOAT16 && d->data_type() != DataType::F32, "Only F32 output supported for BFLOAT16 input");
775  ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
776  ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
777  ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input");
778  return Status{};
779 }

The documentation for this class was generated from the following files:

NEGEMMAssemblyDispatch.h
NEGEMMAssemblyDispatch.cpp