58 namespace weights_transformations
60 CLGEMMReshapeRHSMatrixKernelManaged::CLGEMMReshapeRHSMatrixKernelManaged()
96 _kernel->configure(compile_context, input, &_output, info);
123 if(
bool(gemm_kernel))
125 if(validate_gemm_kernel(gemm_kernel.gemm_type))
128 return gemm_kernel.gemm_type;
133 return gemm_kernel.gemm_type;
148 gemm_kernel_info.
lhs_info = lhs_info;
149 gemm_kernel_info.
rhs_info = rhs_info;
171 if(validate_lhs_rhs_info_reshaped_only_rhs(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info))
174 return { config.lhs_info, config.rhs_info };
179 return { config.lhs_info, config.rhs_info };
204 gemm_kernel_info.
lhs_info = lhs_info;
205 gemm_kernel_info.
rhs_info = rhs_info;
220 if(validate_lhs_rhs_info_reshaped(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info, reinterpret_input_as_3d))
223 return { config.lhs_info, config.rhs_info };
228 return { config.lhs_info, config.rhs_info };
234 : _memory_group(
std::move(memory_manager)),
235 _weights_manager(weights_manager),
245 _original_b(nullptr),
248 _reshape_b_only_on_first_run(false),
265 _mm_kernel->set_target(gpu_target);
270 _mm_kernel->configure(compile_context, a, b, c, output, alpha, beta,
false, reshape_info, gemm_info.
fp_mixed_precision(), gemm_info.
activation_info());
285 int mult_transpose1xW_width = 1;
286 int mult_interleave4x4_height = 1;
289 _reshape_lhs_kernel->set_target(gpu_target);
290 _mm_kernel->set_target(gpu_target);
294 mult_transpose1xW_width = 4;
295 mult_interleave4x4_height = 2;
301 rhs_info.
h0 = mult_transpose1xW_width;
308 lhs_info.
v0 = mult_interleave4x4_height;
312 GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d,
false, gemm_info.
broadcast_bias());
317 _memory_group.
manage(&_tmp_a);
319 if(!_reshape_b_only_on_first_run && use_mm_b)
321 _memory_group.
manage(&_tmp_b);
325 _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, reinterpret_input_as_3d);
331 _reshape_rhs_kernel_managed->configure(compile_context, b, rhs_info);
332 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->
acquire(b, _reshape_rhs_kernel_managed.get()));
336 _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);
340 _mm_kernel->configure(compile_context, &_tmp_a, reshaped_rhs, c, output, alpha, beta,
true, reshape_info, gemm_info.
fp_mixed_precision(), gemm_info.
activation_info());
347 if(!_reshape_b_only_on_first_run && use_mm_b)
376 _reshape_lhs_kernel->set_target(gpu_target);
377 _mm_kernel->set_target(gpu_target);
382 _memory_group.
manage(&_tmp_a);
384 if(!_reshape_b_only_on_first_run && use_mm_b)
386 _memory_group.
manage(&_tmp_b);
395 std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped(
auto_heuristics::CommonQuery{ gpu_target,
data_type, m, n, k, batch_size }, kernel_info, a->
info(), b->
info(),
403 _reshape_rhs_kernel_managed->configure(compile_context, b, rhs_info);
404 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->
acquire(b, _reshape_rhs_kernel_managed.get()));
408 _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);
412 _mm_reshaped_kernel->configure(compile_context, &_tmp_a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
417 if(!_reshape_b_only_on_first_run && use_mm_b)
446 _mm_kernel->set_target(gpu_target);
451 if(!_reshape_b_only_on_first_run && use_mm_b)
453 _memory_group.
manage(&_tmp_b);
460 std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped_only_rhs(
auto_heuristics::CommonQuery{ gpu_target,
data_type, m, n, k, batch_size }, kernel_info, a->
info(), b->
info(),
461 c ==
nullptr ? nullptr : c->
info(), output->
info());
466 _reshape_rhs_kernel_managed->configure(compile_context, b, rhs_info);
467 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->
acquire(b, _reshape_rhs_kernel_managed.get()));
471 _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);
479 kernel_info.has_pad_y =
false;
480 _mm_reshaped_only_rhs_kernel->configure(compile_context, a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
483 kernel_info.has_pad_y =
true;
484 _mm_reshaped_only_rhs_fallback_kernel->configure(compile_context, a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
486 if(!_reshape_b_only_on_first_run && use_mm_b)
527 int mult_transpose1xW_width = 1;
528 int mult_interleave4x4_height = 1;
533 mult_transpose1xW_width = 4;
534 mult_interleave4x4_height = 2;
540 rhs_info.
h0 = mult_transpose1xW_width;
547 lhs_info.
v0 = mult_interleave4x4_height;
583 const unsigned int batch_size = reinterpret_input_as_3d ? a->
dimension(3) : a->
dimension(2);
602 lhs_info = gemm_config.lhs_info;
603 rhs_info = gemm_config.rhs_info;
631 const unsigned int batch_size = reinterpret_input_as_3d ? a->
dimension(3) : a->
dimension(2);
650 lhs_info = gemm_config.lhs_info;
651 rhs_info = gemm_config.rhs_info;
696 const ICLTensor *c_to_use = fuse_add_c ? c :
nullptr;
698 switch(_gemm_kernel_type)
702 configure_native_v1(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
707 configure_reshaped_v1(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
712 configure_reshaped_v2(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
717 configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
734 const unsigned int batch_size = reinterpret_input_as_3d ? a->
dimension(3) : a->
dimension(2);
745 const ITensorInfo *c_to_use = fuse_add_c ? c :
nullptr;
747 switch(gemm_kernel_type)
784 switch(_gemm_kernel_type)
796 if(!_reshape_b_only_on_first_run)
801 _weights_manager->
run(_original_b, _reshape_rhs_kernel_managed.get());
817 if(!_reshape_b_only_on_first_run)
822 _weights_manager->
run(_original_b, _reshape_rhs_kernel_managed.get());
835 if(!_reshape_b_only_on_first_run)
840 _weights_manager->
run(_original_b, _reshape_rhs_kernel_managed.get());
852 bool has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);
878 _weights_manager->
run(_original_b, _reshape_rhs_kernel_managed.get());
unsigned int top
top of the border
bool broadcast_bias
Flag used to broadcast the bias addition.
static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info)
Static function to check if given info will lead to a valid configuration of CLGEMMMatrixMultiplyResh...
void prepare() override
Prepare the function for executing.
~CLGEMM()
Default destructor.
GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query)
Select gemm config based on default heuristics.
Descriptor used by the GEMM kernels.
void run() override
Run the kernels contained in the function.
virtual size_t dimension(size_t index) const =0
Return the size of the requested dimension.
CLGEMM(std::shared_ptr< IMemoryManager > memory_manager=nullptr, IWeightsManager *weights_manager=nullptr)
Default constructor.
static CLScheduler & get()
Access the scheduler singleton.
static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info)
Static function to check if given info will lead to a valid configuration of CLGEMMMatrixMultiplyResh...
unsigned int v0
Number of vertical blocks of size (m0xk0) stored on the same output row.
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
GPUTarget target() const
Get the target GPU.
unsigned int depth_output_gemm3d
Depth of the output tensor in case it is reinterpreted as 3D.
OpenCL kernel to reshape the RHS matrix when performing the matrix multiplication In particular...
GEMM reshape information class.
static Status validate(const ITensorInfo *input, const ITensorInfo *output, const GEMMLHSMatrixInfo &lhs_info, bool reinterpret_input_as_3d)
Static function to check if given info will lead to a valid configuration of CLGEMMReshapeLHSMatrixKe...
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
virtual DataType data_type() const =0
Data type used for each element of the tensor.
A collection of adaptor functions that enable the auto selection between mlgo-based heuristics and de...
unsigned int h0
Number of horizontal blocks of size (k0xn0) stored on the same output row.
bool fp_mixed_precision() const
Flag which specifies if a wider accumulator should be used.
static CLKernelLibrary & get()
Access the KernelLibrary singleton.
GEMM LHS (Left Hand Side) matrix information.
Store the tensor's metadata.
CLTensorAllocator * allocator()
Return a pointer to the tensor's allocator.
#define ARM_COMPUTE_ERROR_THROW_ON(status)
unsigned int bottom
bottom of the border
Reshaped GEMM kernel where only the rhs matrix is reshaped.
int depth_output_gemm3d() const
Depth of the output when GEMM output is reinterpreted as 3D tensor.
GPUTarget get_arch_from_target(GPUTarget target)
Helper function to get the GPU arch.
ActivationLayerInfo activation_info
Activation function to perform after the matrix multiplication.
bool retain_internal_weights() const
Flag which specifies if the weights tensor has to be retained from previous run.
CLGEMMKernelType
OpenCL GEMM kernel types.
Reshaped GEMM kernel where both lhs and rhs matrices are reshaped.
GEMMLHSMatrixInfo lhs_info
LHS matrix information used to retrieve the number of rows processed by each thread.
bool transpose
True if the (k0xn0) block has to be transposed before being stored.
bool interleave
True if the v0 (m0xk0) blocks have to be interleaved in the output row.
static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target, bool fp_mixed_precision=false, const ActivationLayerInfo &activation_info=ActivationLayerInfo())
Static function to check if given info will lead to a valid configuration of CLGEMMMatrixMultiplyKern...
Copyright (c) 2017-2021 Arm Limited.
void mark_as_unused() const
Marks a tensor as unused.
void manage(IMemoryManageable *obj) override
Sets an object to be managed by the given memory group.
bool transpose
True if the (m0xk0) block has to be transposed before being stored.
GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query)
Select gemm config based on mlgo heuristics.
Native GEMM kernel with fixed block size.
bool are_weights_managed(const ITensor *weights)
Check if the weights are managed.
Interface to enqueue OpenCL kernels and get/set the OpenCL CommandQueue and ICLTuner.
unsigned int k0
Number of partial accumulations performed by the matrix multiplication.
unsigned int m
Number of LHS rows.
#define ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE(fmt,...)
Log information level formatted message to the core system logger.
unsigned int n
Number of RHS columns.
OpenCL kernel to multiply two input matrices "A" and "B" and add a matrix "C" if provided.
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
TensorShape compute_lhs_reshaped_shape(const ITensorInfo &a, const GEMMLHSMatrixInfo &lhs_info, bool reinterpret_input_as_3d=false)
Calculate the Left Hand Side matrix reshaped shape.
GEMM RHS (Right Hand Side) matrix information.
unsigned int n0
Number of columns processed by the matrix multiplication.
OpenCL kernel to multiply matrices when only the input matrix RHS (input1) has been reshaped...
TensorShape compute_rhs_reshaped_shape(const ITensorInfo &a, const GEMMRHSMatrixInfo &rhs_info)
Calculate the Right Hand Side matrix reshaped shape.
bool auto_init_if_empty(ITensorInfo &info, const TensorShape &shape, int num_channels, DataType data_type, QuantizationInfo quantization_info=QuantizationInfo())
Auto initialize the tensor info (shape, number of channels and data type) if the current assignment i...
bool reinterpret_input_as_3d
Flag used to reinterpret the input as 3D.
virtual std::unique_ptr< T > clone() const =0
Provide a clone of the current object of class T.
virtual ITensorInfo * info() const =0
Interface to be implemented by the child class to return the tensor's metadata.
virtual size_t element_size() const =0
Element size in bytes calculated as data_size() * num_channels()
GEMMTypeResult select_default_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
Select gemm type based on default heuristics.
virtual PaddingSize padding() const =0
Padding of tensor.
cl::CommandQueue & queue()
Accessor for the associated CL command queue.
bool reinterpret_input_as_3d() const
Flag which specifies if the input tensor has to be reinterpreted as 3D.
Weights manager interface to handle weights transformations.
bool broadcast_bias() const
Flag which specifies whether to broadcast the shape of the bias tensor.
void enqueue(ICLKernel &kernel, bool flush=true)
Schedule the execution of the passed kernel if possible.
bool has_pad_y
Flag used to indicate if the input/output tensors have internal pad on the y direction.
void allocate() override
Allocate size specified by TensorInfo of OpenCL memory.
ScaleKernelInfo info(interpolation_policy, default_border_mode, PixelValue(), sampling_policy, false)
#define ARM_COMPUTE_RETURN_ERROR_MSG(...)
An error is returned with the given description.
Memory group resources scope handling class.
Interface for OpenCL tensor.
GEMMRHSMatrixInfo rhs_info
RHS matrix information used for reshaping the RHS matrix.
GPUTarget
Available GPU Targets.
void configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info=GEMMInfo())
Initialise the kernel's inputs and output.
Manages all the OpenCL kernels compilation and caching, provides accessors for the OpenCL Context...
std::string to_string(const ICLTensor &arg)
void free() override
Free allocated OpenCL memory.
GEMMTypeResult select_mlgo_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
Select gemm type based on mlgo heuristics.
unsigned int k
Number of LHS columns or RHS rows.
bool interleave
True if the h0 (k0xn0) blocks have to be interleaved in the output row.
bool is_zero(float a, float epsilon=0.00001f)
Checks if the input floating point number is 0.0f checking if the difference is within a range define...
static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info=GEMMInfo())
Static function to check if given info will lead to a valid configuration of CLGEMM.
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
Store the tensor's metadata.
bool reshape_b_only_on_first_run() const
Flag which specifies if the reshape of matrix B should be executed only on the first run.
unsigned int k0
Number of partial accumulations performed by the matrix multiplication.
unsigned int m0
Number of rows processed by the matrix multiplication.
ITensor * run(const ITensor *weights, ITransformWeights *weights_transform)
Run the reshape function.
GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &query)
Select gemm config based on mlgo heuristics.
OpenCL kernel to reshape the LHS matrix when performing the matrix multiplication.
OpenCL kernel to multiply matrices when both the input matrices LHS (input0) and RHS (input1) have be...
void tune_kernel_static(ICLKernel &kernel)
Tunes OpenCL kernel.
DataType
Available data types.
ActivationLayerInfo activation_info() const
Activation layer to apply after the matrix multiplication.
Reshaped GEMM kernel where both lhs and rhs matrices are reshaped.
GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query)
Select gemm config based on default heuristics.
static Status validate(const ITensorInfo *input, const ITensorInfo *output, const GEMMRHSMatrixInfo &rhs_info)
Static function to check if given info will lead to a valid configuration of CLGEMMReshapeRHSMatrixKe...
ITensor * acquire(const ITensor *weights, ITransformWeights *weights_transform)
Acquire the requested reshape tensor of the selected weights.