using namespace experimental;
using namespace misc::shape_calculator;
using namespace utils::cast;
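// Default-constructed state: every kernel/function pointer starts unset and _aux_mem reserves one slot per auxiliary tensor (AuxTensorIdx::Count).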
: _weights_reshape_kernel(nullptr), _im2col_kernel(nullptr), _mm_gemm(nullptr), _mm_gemmlowp(nullptr), _col2im_kernel(nullptr), _activation_kernel(nullptr), _im2col_output(), _weights_reshaped(),
  _gemm_output(), _skip_im2col(false), _skip_col2im(false), _is_quantized(false), _fuse_activation(true), _append_bias(false), _is_prepared(false), _use_post_ops(false), _aux_mem(AuxTensorIdx::Count)
gemmlowp_output_stage,
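// Quantized path: the convolution is carried out by ClGemmLowpMatrixMultiplyCore.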
_mm_gemmlowp = std::make_unique<ClGemmLowpMatrixMultiplyCore>();
_mm_gemmlowp->configure(compile_context, &tmp_src, weights, biases, dst, gemm_info);
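// Expose the GEMMLowp function's workspace requirements as this operator's auxiliary tensors.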
auto mm_mem_req = _mm_gemmlowp->workspace();
for(unsigned int cont = 0; cont < mm_mem_req.size(); ++cont)
{
    _aux_mem[cont] = mm_mem_req[cont];
}
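// Floating-point path: a regular ClGemm with alpha = beta = 1.0f.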
_mm_gemm = std::make_unique<ClGemm>();
_mm_gemm->configure(compile_context, &tmp_src, weights, biases, dst, 1.0f, 1.0f, gemm_info);
auto mm_mem_req = _mm_gemm->workspace();
for(unsigned int cont = 0; cont < mm_mem_req.size(); ++cont)
{
    _aux_mem[cont] = mm_mem_req[cont];
}
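// Static validation of the matrix multiply stage; mirrors configure_mm() without creating any kernels.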
Status ClGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
                                 const GEMMLowpOutputStageInfo &gemmlowp_output_stage, int gemm_3d_depth, bool skip_im2col, const ActivationLayerInfo &act_info,
                                 const experimental::PostOpList<ITensorInfo *> &post_ops)
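// Assemble the GEMMInfo that forwards the 3D reinterpretation depth, im2col skip, output stage, fused activation and post ops to the GEMM validation.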
const GEMMInfo &gemm_info = GEMMInfo(false,
gemmlowp_output_stage,
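// GEMMLowp requires negated quantization offsets, so validate against clones of the source/weights infos with the offsets' signs flipped.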
const QuantizationInfo input_quantization_info   = src->quantization_info();
const QuantizationInfo weights_quantization_info = weights->quantization_info();
std::unique_ptr<ITensorInfo> src_qa     = src->clone();
std::unique_ptr<ITensorInfo> weights_qa = weights->clone();
src_qa->set_quantization_info(QuantizationInfo(input_quantization_info.uniform().scale, -input_quantization_info.uniform().offset));
weights_qa->set_quantization_info(QuantizationInfo(weights_quantization_info.uniform().scale, -weights_quantization_info.uniform().offset));
const unsigned int num_kernels = weights->dimension(idx_kernels);
_fuse_activation = true;
_use_post_ops    = conv2d_info.post_ops.size() > 0;
unsigned int stride_x = 0;
unsigned int stride_y = 0;
unsigned int conv_w = 0;
unsigned int conv_h = 0;
unsigned int mat_weights_cols = num_kernels / conv2d_info.num_groups;
_append_bias = false;
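// Reshape the weights into a 2D matrix. For grouped convolution with a bias, the bias is folded into the reshaped weights and not passed to the GEMM separately.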
_weights_reshape_kernel = std::make_unique<kernels::ClWeightsReshapeKernel>();
if(conv2d_info.num_groups != 1 && biases != nullptr)
biases_to_use = nullptr;
_weights_reshape_kernel->configure(compile_context, weights, biases, &_weights_reshaped, conv2d_info.num_groups);
_weights_reshape_kernel->configure(compile_context, weights, nullptr, &_weights_reshaped, conv2d_info.num_groups);
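// im2col: unroll kernel-sized input patches into rows so the convolution becomes a matrix multiplication.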
_im2col_kernel = std::make_unique<opencl::kernels::ClIm2ColKernel>();
_im2col_kernel->configure(compile_context, src, &_im2col_output, Size2D(kernel_width, kernel_height), conv2d_info.conv_info, _append_bias, conv2d_info.dilation, conv2d_info.num_groups);
gemm_input_to_use = &_im2col_output;
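// GEMM output shape: dimension 0 holds the kernels per group (mat_weights_cols), dimension 1 the number of output pixels (conv_w * conv_h).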
shape_gemm.set(0, mat_weights_cols);
shape_gemm.set(1, conv_w * conv_h);
gemm_output_to_use = &_gemm_output;
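// If the destination info has not been initialised yet, fall back to the input quantization info for the output stage.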
const auto output_quant_info = (dst->total_size() == 0) ? iq_info : oq_info;
const unsigned int num_filters = (is_quantized_per_channel) ? num_kernels : 1;
auto min_activation = min_val.get<int32_t>();
auto max_activation = max_val.get<int32_t>();
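// Only ReLU-family activations can be fused into the quantized output stage, by clamping the quantized output to [min_activation, max_activation].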
const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = { ActivationLayerInfo::ActivationFunction::RELU,
                                                                            ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
                                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU
_fuse_activation = false;
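// Configure the matrix multiply (ClGemm or ClGemmLowpMatrixMultiplyCore) on the (possibly im2col'd) input and the reshaped weights.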
configure_mm(compile_context, gemm_input_to_use, &_weights_reshaped, biases_to_use, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, conv2d_info.act_info, conv2d_info.post_ops);
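// col2im: map the 2D GEMM result back onto the spatial [conv_w x conv_h] destination layout.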
_col2im_kernel = std::make_unique<opencl::kernels::ClCol2ImKernel>();
_col2im_kernel->configure(compile_context, gemm_output_to_use, dst, Size2D(conv_w, conv_h), conv2d_info.num_groups);
335 "Output shape does not match the expected one");
if(!_fuse_activation && !_use_post_ops)
_activation_kernel = std::make_unique<opencl::kernels::ClActivationKernel>();
_activation_kernel->configure(compile_context, dst, nullptr, conv2d_info.act_info);
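// validate(): static checks mirroring configure(), operating on ITensorInfo objects only.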
if(!is_quantized_per_channel)
const unsigned int num_kernels = weights->dimension(idx_kernels);
bool fuse_activation = true;
bool use_post_ops    = conv2d_info.post_ops.size() > 0;
394 "ClGemmConv2d does not support post ops with col2im or im2col operation");
if(biases != nullptr)
unsigned int conv_w = 0;
unsigned int conv_h = 0;
unsigned int mat_weights_cols = num_kernels / conv2d_info.num_groups;
bool append_bias = false;
if(conv2d_info.num_groups != 1 && biases != nullptr)
biases_to_use = nullptr;
weights_to_use = &weights_reshaped_info;
const Size2D kernel_dims(kernel_width, kernel_height);
gemm_input_to_use = &im2col_reshaped_info;
shape_gemm.set(0, mat_weights_cols);
shape_gemm.set(1, conv_w * conv_h);
info_gemm.set_quantization_info(dst->quantization_info()).set_data_layout(src->data_layout());
gemm_output_to_use = &info_gemm;
const auto output_quant_info = (dst->total_size() == 0) ? iq_info : oq_info;
const unsigned int num_filters = (is_quantized_per_channel) ? num_kernels : 1;
int min_activation = 0;
int max_activation = 0;
const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = { ActivationLayerInfo::ActivationFunction::RELU,
                                                                            ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
                                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU
fuse_activation = false;
ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases_to_use, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, skip_im2col, conv2d_info.act_info,
if(!fuse_activation && !use_post_ops)
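// At run time, default to the user-provided tensors and switch to the intermediate buffers only when im2col/col2im are actually performed.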
auto gemm_input_to_use  = src;
auto gemm_output_to_use = dst;
gemm_input_to_use = im2col_output.get();
gemm_output_to_use = gemm_output.get();
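// Dispatch to whichever matrix multiply function was configured (quantized or floating-point).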
_mm_gemmlowp->run(pack_mm);
_mm_gemm->run(pack_mm);
if(!_fuse_activation && !_use_post_ops)
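// prepare(): let the configured GEMM function prepare its constant data (e.g. reshaped weights) once.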
_is_quantized ? _mm_gemmlowp->prepare(tensors) : _mm_gemm->prepare(tensors);