76 if(validate_gemm_kernel(gemm_kernel.gemm_type))
79 return gemm_kernel.gemm_type;
84 return gemm_kernel.gemm_type;
87 inline bool validate_lhs_rhs_info_native(
const GEMMLHSMatrixInfo &lhs_info,
const GEMMRHSMatrixInfo &rhs_info,
const ITensorInfo *a,
const ITensorInfo *
b,
const GEMMReshapeInfo &reshape_info)
90 TensorInfo mm_result_s32_info{};
106 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_native(
auto_heuristics::CommonQuery query,
const ITensorInfo *a,
const ITensorInfo *
b,
const GEMMReshapeInfo &reshape_info)
111 if(validate_lhs_rhs_info_native(config.lhs_info, config.rhs_info, a,
b, reshape_info))
114 return { config.lhs_info, config.rhs_info };
119 return { config.lhs_info, config.rhs_info };
123 inline bool validate_lhs_rhs_info_reshaped_only_rhs(
const GEMMLHSMatrixInfo &lhs_info,
const GEMMRHSMatrixInfo &rhs_info,
const ITensorInfo *a,
const ITensorInfo *
b,
const ITensorInfo *output,
124 unsigned int m,
unsigned int n,
unsigned int k,
bool reinterpret_input_as_3d,
int depth_output_gemm3d)
127 TensorInfo tmp_b_info{};
139 GEMMKernelInfo gemm_kernel_info;
140 gemm_kernel_info.m = m;
141 gemm_kernel_info.n = n;
142 gemm_kernel_info.k = k;
143 gemm_kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
144 gemm_kernel_info.depth_output_gemm3d = depth_output_gemm3d;
145 gemm_kernel_info.lhs_info = lhs_info;
146 gemm_kernel_info.rhs_info = rhs_info;
148 TensorInfo output_info_copy(*output);
158 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped_only_rhs(
auto_heuristics::CommonQuery query,
bool reinterpret_input_as_3d,
int depth_output_gemm3d,
159 const ITensorInfo *a,
160 const ITensorInfo *
b,
const ITensorInfo *output)
165 if(validate_lhs_rhs_info_reshaped_only_rhs(config.lhs_info, config.rhs_info, a,
b, output, query.
m, query.
n, query.
k, reinterpret_input_as_3d, depth_output_gemm3d))
168 return { config.lhs_info, config.rhs_info };
173 return { config.lhs_info, config.rhs_info };
191 : _memory_group(std::move(memory_manager)),
205 _gemm_output_stage_multipliers(),
206 _gemm_output_stage_shifts(),
208 _original_b(nullptr),
212 _is_gemm_reshaped(true),
213 _reshape_b_only_on_first_run(false),
215 _run_output_stage(false),
216 _convert_to_qasymm8(false),
217 _run_offset_contribution(false)
233 _is_prepared =
false;
242 _b_offset = _convert_to_qasymm8 ? -128 :
b->info()->quantization_info().uniform().offset;
248 _mm_native_kernel->set_target(gpu_target);
249 _mm_reshaped_only_rhs_kernel->set_target(gpu_target);
259 const unsigned int n =
b->info()->dimension(0);
264 const auto reshape_info =
GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d);
269 if(_convert_to_qasymm8)
278 const ICLTensor *matrix_b = _convert_to_qasymm8 ? &_qasymm8_weights :
b;
279 if(_is_gemm_reshaped)
283 if(!_reshape_b_only_on_first_run)
285 _memory_group.
manage(&_tmp_b);
292 a->
info(), _convert_to_qasymm8 ? _qasymm8_weights.
info() :
b->info(), output->
info());
295 _mtx_b_reshape_kernel->configure(compile_context, _convert_to_qasymm8 ? &_qasymm8_weights :
b, &_tmp_b, rhs_info);
306 if(!_reshape_b_only_on_first_run)
308 _memory_group.
manage(&_vector_sum_col);
312 _mtx_b_reduction_kernel->configure(compile_context, _convert_to_qasymm8 ? &_qasymm8_weights :
b, &_vector_sum_col, reduction_info);
320 _memory_group.
manage(&_vector_sum_row);
323 _mtx_a_reduction_kernel->configure(compile_context, a, &_vector_sum_row, reduction_info);
327 gemm_kernel_info.
m = m;
328 gemm_kernel_info.
n = n;
329 gemm_kernel_info.
k = k;
332 gemm_kernel_info.
lhs_info = lhs_info;
333 gemm_kernel_info.
rhs_info = rhs_info;
334 gemm_kernel_info.
a_offset = _a_offset;
335 gemm_kernel_info.
b_offset = _b_offset;
353 _mm_reshaped_only_rhs_kernel->configure(compile_context, _matrix_a, matrix_b, output, gemm_kernel_info, _a_offset == 0 ?
nullptr : &_vector_sum_col,
354 _b_offset == 0 ?
nullptr : &_vector_sum_row, c, &_gemm_output_stage_multipliers, &_gemm_output_stage_shifts);
358 _run_output_stage =
true;
360 _memory_group.
manage(&_mm_result_s32);
362 if(_is_gemm_reshaped)
364 _mm_reshaped_only_rhs_kernel->configure(compile_context, _matrix_a, matrix_b, &_mm_result_s32, gemm_kernel_info);
371 _matrix_a->
info(), _convert_to_qasymm8 ? _qasymm8_weights.
info() : matrix_b->
info(), reshape_info);
374 _mm_native_kernel->configure(compile_context, _matrix_a, matrix_b, &_mm_result_s32, lhs_info, rhs_info, reshape_info);
376 _offset_contribution_output_stage_kernel->configure(compile_context, &_mm_result_s32, _a_offset == 0 ?
nullptr : &_vector_sum_col, _b_offset == 0 ?
nullptr : &_vector_sum_row, c, output,
378 _a_offset, _b_offset, gemmlowp_output_stage, &_gemm_output_stage_multipliers, &_gemm_output_stage_shifts);
386 _gemm_output_stage_multipliers.
map();
387 _gemm_output_stage_shifts.
map();
390 _gemm_output_stage_multipliers.
unmap();
391 _gemm_output_stage_shifts.
unmap();
395 _run_offset_contribution =
true;
396 if(_is_gemm_reshaped)
399 _mm_reshaped_only_rhs_kernel->configure(compile_context, _matrix_a, matrix_b, output, gemm_kernel_info);
406 a->
info(), _convert_to_qasymm8 ? _qasymm8_weights.
info() :
b->info(), reshape_info);
409 _mm_native_kernel->configure(compile_context, _matrix_a, matrix_b, output, lhs_info, rhs_info, reshape_info);
413 _offset_contribution_kernel->configure(compile_context, output, _a_offset == 0 ?
nullptr : &_vector_sum_col, _b_offset == 0 ?
nullptr : &_vector_sum_row, c, a->
info()->
dimension(0), _a_offset,
418 if(_is_gemm_reshaped)
420 if(!_reshape_b_only_on_first_run)
426 if(_a_offset != 0 && !_reshape_b_only_on_first_run)
448 int32_t b_offset =
b->quantization_info().uniform().offset;
461 const unsigned int n =
b->dimension(0);
463 const unsigned int batch_size = reinterpret_input_as_3d ? a->
dimension(3) : a->
dimension(2);
473 if(convert_to_qasymm8)
482 matrix_b_info = &tmp_b_info;
488 lhs_info = res.lhs_info;
489 rhs_info = res.rhs_info;
519 gemm_kernel_info.
m = m;
520 gemm_kernel_info.
n = n;
521 gemm_kernel_info.
k = k;
524 gemm_kernel_info.
lhs_info = lhs_info;
525 gemm_kernel_info.
rhs_info = rhs_info;
526 gemm_kernel_info.
a_offset = a_offset;
527 gemm_kernel_info.
b_offset = b_offset;
541 a_offset == 0 ?
nullptr : &info_vector_sum_col,
542 b_offset == 0 ?
nullptr : &info_vector_sum_row,
544 &gemm_output_stage_multipliers_shifts_info,
545 &gemm_output_stage_multipliers_shifts_info));
568 lhs_info = res.lhs_info;
569 rhs_info = res.rhs_info;
577 a_offset == 0 ?
nullptr : &info_vector_sum_col,
578 b_offset == 0 ?
nullptr : &info_vector_sum_row,
582 gemmlowp_output_stage,
583 &gemm_output_stage_multipliers_shifts_info,
584 &gemm_output_stage_multipliers_shifts_info));
599 lhs_info = res.lhs_info;
600 rhs_info = res.rhs_info;
610 a_offset == 0 ?
nullptr : &info_vector_sum_col,
611 b_offset == 0 ?
nullptr : &info_vector_sum_row,
613 a_offset, b_offset));
626 if(_is_gemm_reshaped)
628 if(!_reshape_b_only_on_first_run)
636 if(_a_offset != 0 && !_reshape_b_only_on_first_run)
648 if(_is_gemm_reshaped)
656 if(_run_output_stage)
661 if(_run_offset_contribution)
672 if(_convert_to_qasymm8)
678 if(_is_gemm_reshaped && _reshape_b_only_on_first_run)
689 if(_a_offset != 0 && _reshape_b_only_on_first_run)
uint8_t * ptr_to_element(const Coordinates &id) const
Return a pointer to the element at the passed coordinates.
Quantize using a fixed point multiplication.
TensorInfo * info() const override
Interface to be implemented by the child class to return the tensor's metadata.
Descriptor used by the GEMM kernels.
void prepare() override
Prepare the function for executing.
virtual size_t dimension(size_t index) const =0
Return the size of the requested dimension.
static Status validate(const ITensorInfo *mm_result, const ITensorInfo *vector_sum_col, const ITensorInfo *vector_sum_row, const ITensorInfo *bias, int32_t a_offset, int32_t b_offset)
Static function to check if given info will lead to a valid configuration of CLGEMMLowpOffsetContribu...
static Status validate(const ITensorInfo *input, const ITensorInfo *output, ConvertPolicy policy, uint32_t shift)
Static function to check if given info will lead to a valid configuration of CLDepthConvertLayerKerne...
static CLScheduler & get()
Access the scheduler singleton.
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
GPUTarget target() const
Get the target GPU.
unsigned int depth_output_gemm3d
Depth of the output tensor in case it is reinterpreted as 3D.
OpenCL kernel to reshape the RHS matrix when performing the matrix multiplication. In particular,...
GEMM reshape information class.
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
OpenCL kernel to multiply matrices with QASYMM8/QASYMM8_SIGNED data type.
virtual DataType data_type() const =0
Data type used for each element of the tensor.
TensorShape compute_mm_shape(const ITensorInfo &input0, const ITensorInfo &input1, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info)
Calculate the matrix multiplication output shape of two tensors.
bool is_used() const
Flags if the tensor is used or not.
GEMMLowpOutputStageInfo gemmlowp_output_stage() const
GEMMLowp output stage.
TensorShape compute_reductionA_shape(const ITensorInfo &b)
Calculate the reductionA shape used in GEMMLowp.
static Status validate(const ITensorInfo *mtx_b, const ITensorInfo *vector_sum_col, const GEMMLowpReductionKernelInfo &info)
Static function to check if given info will lead to a valid configuration of CLGEMMLowpMatrixBReducti...
A collection of adaptor functions that enable the auto selection between mlgo-based heuristics and de...
OpenCL kernel used to add the offset contribution after the matrix multiplication and perform the out...
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true then an error message is printed and an exception thrown.
static CLKernelLibrary & get()
Access the KernelLibrary singleton.
GEMM LHS (Left Hand Side) matrix information.
Store the tensor's metadata.
CLTensorAllocator * allocator()
Return a pointer to the tensor's allocator.
#define ARM_COMPUTE_ERROR_THROW_ON(status)
void configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, const GEMMInfo &gemm_info=GEMMInfo())
Initialise the kernel's inputs, output.
Reshaped GEMM kernel where only the rhs matrix is reshaped.
int depth_output_gemm3d() const
Depth of the output when GEMM output is reinterpreted as 3D tensor.
void run() override
Run the kernels contained in the function.
GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query)
Select gemm config based on mlgo heuristics.
CLGEMMKernelType
OpenCL GEMM kernel types.
bool is_data_type_quantized_symmetric(DataType dt)
Check if a given data type is of symmetric quantized type.
#define ARM_COMPUTE_RETURN_ERROR_ON(cond)
If the condition is true, an error is returned.
GEMMLowpOutputStageType type
GEMMLowp output stage type.
OpenCL kernel used to compute the row-vectors of sums of all the entries in each row of Matrix A.
GEMMLHSMatrixInfo lhs_info
LHS matrix information used to retrieve the number of rows processed by each thread.
void init(const TensorInfo &input, size_t alignment=0)
Initialize a tensor based on the passed TensorInfo.
Copyright (c) 2017-2021 Arm Limited.
bool is_b_reshaped() const
Flag which specifies if the matrix B has been reshaped.
void map(bool blocking=true)
Enqueue a map operation of the allocated buffer.
bool is_quantized_per_channel
GEMMLowp quantized per-channel flag.
std::vector< int32_t > gemmlowp_shifts
GEMMLowp output stage shift used for quantizing to QASYMM8.
void mark_as_unused() const
Marks a tensor as unused.
1 channel, 1 S32 per channel
void manage(IMemoryManageable *obj) override
Sets an object to be managed by the given memory group.
Interface to enqueue OpenCL kernels and get/set the OpenCL CommandQueue and ICLTuner.
unsigned int m
Number of LHS rows.
GEMMConfigResult select_default_gemm_config_native(const CommonQuery &query)
Select gemm config based on default heuristics.
static Status validate(const ITensorInfo *mm_result, const ITensorInfo *vector_sum_col, const ITensorInfo *vector_sum_row, const ITensorInfo *bias, const ITensorInfo *output, int32_t a_offset, int32_t b_offset, const GEMMLowpOutputStageInfo &output_stage, const ITensorInfo *output_multipliers, const ITensorInfo *output_shifts)
Static function to check if given info will lead to a valid configuration of CLGEMMLowpOffsetContribu...
std::string to_string(const ROIPoolingLayerInfo &pool_info)
Formatted output of the ROIPoolingInfo type.
#define ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE(fmt,...)
Log information level formatted message to the core system logger.
static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, const GEMMInfo &gemm_info=GEMMInfo())
Static function to check if given info will lead to a valid configuration of CLGEMMLowpMatrixMultiply...
unsigned int n
Number of RHS columns.
static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, const GEMMKernelInfo &gemm_info, const ITensorInfo *vector_sum_col=nullptr, const ITensorInfo *vector_sum_row=nullptr, const ITensorInfo *bias=nullptr, const ITensorInfo *output_multipliers=nullptr, const ITensorInfo *output_shifts=nullptr)
Static function to check if given info will lead to a valid configuration of CLGEMMLowpMatrixMultiply...
bool is_data_type_quantized_per_channel(DataType dt)
Check if a given data type is of per channel type.
GEMM RHS (Right Hand Side) matrix information.
int32_t b_offset
Offset to be added to each element of the matrix B.
~CLGEMMLowpMatrixMultiplyCore()
Default destructor.
quantized, asymmetric fixed-point 8-bit number unsigned
std::vector< int32_t > gemmlowp_multipliers
GEMMLowp output stage multiplier used for quantizing to QASYMM8.
UniformQuantizationInfo uniform() const
Return per layer quantization info.
GEMMLowpOutputStageInfo output_stage
GEMMLowp output stage information.
TensorShape compute_rhs_reshaped_shape(const ITensorInfo &a, const GEMMRHSMatrixInfo &rhs_info)
Calculate the Right Hand Side matrix reshaped shape.
bool auto_init_if_empty(ITensorInfo &info, const TensorShape &shape, int num_channels, DataType data_type, QuantizationInfo quantization_info=QuantizationInfo())
Auto initialize the tensor info (shape, number of channels and data type) if the current assignment i...
bool reinterpret_input_as_3d
Flag used to reinterpret the input as 3D.
virtual std::unique_ptr< T > clone() const =0
Provide a clone of the current object of class T.
GEMMLowp output stage info.
virtual ITensorInfo * info() const =0
Interface to be implemented by the child class to return the tensor's metadata.
GEMMTypeResult select_default_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
Select gemm type based on default heuristics.
cl::CommandQueue & queue()
Accessor for the associated CL command queue.
bool reinterpret_input_as_3d() const
Flag which specifies if the input tensor has to be reinterpreted as 3D.
virtual QuantizationInfo quantization_info() const =0
Get the quantization settings (scale and offset) of the tensor.
void enqueue(ICLKernel &kernel, bool flush=true)
Schedule the execution of the passed kernel if possible.
quantized, symmetric fixed-point 8-bit number
bool is_data_type_quantized_asymmetric(DataType dt)
Check if a given data type is of asymmetric quantized type.
bool is_a_reshaped() const
Flag which specifies if the matrix A has been reshaped.
static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const GEMMReshapeInfo &gemm_info)
Static function to check if given info will lead to a valid configuration of CLGEMMLowpMatrixMultiply...
quantized, symmetric per channel fixed-point 8-bit number
TensorShape compute_reductionB_shape(const ITensorInfo &a)
Calculate the reductionB shape used in GEMMLowp.
CLGEMMLowpMatrixMultiplyCore(std::shared_ptr< IMemoryManager > memory_manager=nullptr)
Constructor.
int32_t a_offset
Offset to be added to each element of the matrix A.
unsigned int k
Number of rows for the rhs matrix.
void allocate() override
Allocate size specified by TensorInfo of OpenCL memory.
Interface for the depth conversion kernel.
Memory group resources scope handling class.
Interface for OpenCL tensor.
GEMMRHSMatrixInfo rhs_info
RHS matrix information used for reshaping the RHS matrix.
virtual size_t total_size() const =0
Returns the total size of the tensor in bytes.
GPUTarget
Available GPU Targets.
Native GEMM kernel with configurable block size.
#define ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(t, c,...)
GEMMTypeResult select_mlgo_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
Select gemm type based on mlgo heuristics.
unsigned int k
Number of LHS columns or RHS rows.
static Status validate(const ITensorInfo *mtx_a, const ITensorInfo *vector_sum_row, const GEMMLowpReductionKernelInfo &info)
Static function to check if given info will lead to a valid configuration of CLGEMMLowpMatrixAReducti...
unsigned int m
Number of rows for the lhs matrix.
#define ARM_COMPUTE_RETURN_ERROR_ON_MSG(cond, msg)
If the condition is true, an error is returned.
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
Store the tensor's metadata.
bool reshape_b_only_on_first_run() const
Flag which specifies if the reshape of matrix B should be executed only for the first run.
GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &query)
Select gemm config based on mlgo heuristics.
quantized, asymmetric fixed-point 8-bit number signed
unsigned int n
Number of columns for the rhs matrix.
DataType output_data_type
Output tensor data type to use if the output is not initialized.
void unmap()
Enqueue an unmap operation of the allocated and mapped buffer.
OpenCL kernel used to add the offset contribution after the matrix multiplication.
OpenCL kernel used to compute the row-vectors of sums of all the entries in each column of Matrix B.
GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query)
Select gemm config based on default heuristics.
static Status validate(const ITensorInfo *input, const ITensorInfo *output, const GEMMRHSMatrixInfo &rhs_info)
Static function to check if given info will lead to a valid configuration of CLGEMMReshapeRHSMatrixKe...
OpenCL kernel to multiply matrices with QASYMM8 data type when only the input matrix RHS (input1) has...