// Computes the tensor info of the weights once reshaped for the GEMM:
// the kernel tensor's leading dimensions are collapsed into one so a 4D
// weight tensor can be consumed as a 2D matrix.
// NOTE(review): excerpt is elided (original lines between the visible ones
// are missing); behavior notes below are based only on the visible calls.
140 void initialize_reshaped_weight_info(
const ITensorInfo &weights, ITensorInfo &reshaped_weights)
// Collapse the first three dimensions (presumably kernel W, H and input
// channels) into a single dimension — TODO confirm TensorShape::collapse(3)
// folds dims starting at index 0.
150 TensorShape collapsed_weights = weights.tensor_shape();
151 collapsed_weights.collapse(3);
152 reshaped_weights.set_tensor_shape(collapsed_weights);
// Selects how the weight tensor will be transformed before being handed to
// the GEMM. The condition guarding the first return is elided from this
// excerpt — NOTE(review): confirm it against the full source.
157 CpuGemmConv2d::WeightTransformMethod CpuGemmConv2d::get_wt_method(
const ITensorInfo &weights)
// Elided branch: a single fused reshape+transpose kernel is used.
162 return WeightTransformMethod::FusedReshapeAndTranspose;
// Fallback: weights with padding "holes" (per has_holes) need an explicit
// reshape before transposing; hole-free weights can be reinterpreted in
// place and only transposed.
164 return has_holes(weights) ? WeightTransformMethod::ReshapeThenTranspose
165 : WeightTransformMethod::ReinterpretThenTranspose;
// Decides whether the im2col (input lowering) and col2im (output raising)
// stages can be skipped for this convolution configuration. Returns a
// SkipInfo pair {skip_im2col, skip_col2im}. Large parts of the control flow
// (the branch conditions around each return) are elided from this excerpt.
168 CpuGemmConv2d::SkipInfo CpuGemmConv2d::skip_im_col_info(
const ITensorInfo *
src,
169 const ITensorInfo *weights,
171 const Size2D &dilation,
172 const ActivationLayerInfo &
act_info)
// Kernel spatial extents, read from the weight tensor's W/H dimensions.
177 const unsigned int kernel_width = weights->dimension(
idx_width);
178 const unsigned int kernel_height = weights->dimension(
idx_height);
// Convolved output width/height — filled by elided code (presumably
// scaled_dimensions(); TODO confirm).
179 unsigned int conv_w = 0;
180 unsigned int conv_h = 0;
// Branch 1 (condition elided): probe whether the GEMM supports the 3D
// output reinterpretation with skip_im2col = true (last argument).
188 const bool skip_col2im =
190 (bool(CpuGemmConv2d::validate_gemm3d(
src, weights,
act_info, conv_h,
true))));
// Branch 2 (condition elided): same probe with skip_im2col = false.
198 const bool skip_col2im =
200 (bool(CpuGemmConv2d::validate_gemm3d(
src, weights,
act_info, conv_h,
false))));
// Elided-branch result: im2col required, col2im skippable.
203 return {
false,
true};
// Default result: neither stage can be skipped.
208 return {
false,
false};
// Default constructor: null-initializes the lazily-created kernels/operators
// (they are instantiated in configure()/prepare()), defaults the weight
// transform method to ReshapeThenTranspose, and sizes the auxiliary-memory
// table to hold one entry per AuxTensorIdx slot. Several member initializers
// are elided from this excerpt.
211 CpuGemmConv2d::CpuGemmConv2d()
212 : _weights_reshape(nullptr),
213 _weights_reshape_and_transpose_kernel(nullptr),
226 _is_quantized(false),
228 _wt_method(WeightTransformMethod::ReshapeThenTranspose),
230 _aux_mem(AuxTensorIdx::Count)
// Fragment of the matrix-multiply configuration helper (the signature and
// several statements are elided from this excerpt). Configures either the
// quantized GEMM core (CpuGemmLowpMatrixMultiplyCore) or the float GEMM
// (CpuGemm), then records the chosen backend's workspace requirements in
// _aux_mem.
240 bool enable_fast_math,
247 _skip_im2col, fixed_format, weight_format));
// Activations that can be fused by clamping the quantized output range.
250 const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = {
251 ActivationLayerInfo::ActivationFunction::RELU, ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
252 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU};
257 TensorInfo tmp_weights{*weights};
// Input quantization; output falls back to input qinfo when dst is not yet
// allocated (total_size() == 0).
260 const QuantizationInfo iqinfo =
src->quantization_info();
262 const QuantizationInfo oqinfo = (
dst->total_size() == 0) ? iqinfo :
dst->quantization_info();
263 const UniformQuantizationInfo uiqinfo = iqinfo.uniform();
264 const UniformQuantizationInfo uoqinfo = oqinfo.uniform();
// The offsets are negated on the local copies handed to the lowp GEMM —
// NOTE(review): presumably the gemmlowp kernels expect negated zero points;
// confirm against CpuGemmLowpMatrixMultiplyCore's contract.
267 tmp_src.set_quantization_info(QuantizationInfo(uiqinfo.scale, -uiqinfo.offset));
270 const UniformQuantizationInfo uwqinfo = wqinfo.uniform();
271 tmp_weights.set_quantization_info(QuantizationInfo(uwqinfo.scale, -uwqinfo.offset));
// Clamp bounds for the fused activation, in quantized int32 space.
275 PixelValue type_min{};
276 PixelValue type_max{};
278 int32_t min_activation = type_min.get<int32_t>();
279 int32_t max_activation = type_max.get<int32_t>();
// Tighten the clamp only for the supported activation set above.
281 if (supported_acts.count(
act_info.activation()) != 0)
// Quantized path: build and configure the lowp GEMM core.
294 _mm_gemmlowp = std::make_unique<CpuGemmLowpMatrixMultiplyCore>();
295 _mm_gemmlowp->configure(&tmp_src, &tmp_weights, biases,
dst,
296 GEMMInfo(
false,
false,
true, gemm_3d_depth, _skip_im2col,
false,
output_info,
false,
297 enable_fast_math,
false,
act_info, fixed_format, weight_format,
// Copy the lowp GEMM's workspace requirements into our aux-memory table.
300 auto mm_mem_req = _mm_gemmlowp->workspace();
301 for (
unsigned int cont = 0; cont < mm_mem_req.size(); ++cont)
303 _aux_mem[cont] = mm_mem_req[cont];
// Float path (elided else-branch): configure the regular GEMM with
// alpha = beta = 1.0f and no lowp output stage.
309 const GEMMInfo &gemm_info =
310 GEMMInfo(
false,
false,
true , gemm_3d_depth,
311 _skip_im2col ,
false,
312 GEMMLowpOutputStageInfo(),
false, enable_fast_math,
false,
act_info, fixed_format, weight_format,
316 _mm_gemm = std::make_unique<CpuGemm>();
317 _mm_gemm->configure(
src, weights, biases,
dst, 1.0f, 1.0f, gemm_info);
// Same workspace bookkeeping for the float backend.
318 auto mm_mem_req = _mm_gemm->workspace();
319 for (
unsigned int cont = 0; cont < mm_mem_req.size(); ++cont)
321 _aux_mem[cont] = mm_mem_req[cont];
// Static validation counterpart of configure_mm: checks (without allocating
// or configuring anything) whether the GEMM backend accepts this
// src/weights/biases/dst combination. Mirrors configure_mm's quantization
// handling; several statements and the final validate calls are elided.
326 Status CpuGemmConv2d::validate_mm(
const ITensorInfo *
src,
327 const ITensorInfo *weights,
328 const ITensorInfo *biases,
329 const ITensorInfo *
dst,
330 const ActivationLayerInfo &
act_info,
331 bool enable_fast_math,
339 const bool is_activation_enabled =
act_info.enabled();
// Quantization infos; output falls back to input qinfo when dst is not yet
// allocated (total_size() == 0) — same convention as configure_mm.
345 const QuantizationInfo &iqinfo =
src->quantization_info();
346 const QuantizationInfo &wqinfo = weights->quantization_info();
347 const QuantizationInfo &oqinfo = (
dst->total_size() == 0) ? iqinfo :
dst->quantization_info();
348 const UniformQuantizationInfo uoqinfo = oqinfo.uniform();
// Quantized clamp bounds for a fused activation.
351 PixelValue type_min{};
352 PixelValue type_max{};
354 int32_t min_activation = type_min.get<int32_t>();
355 int32_t max_activation = type_max.get<int32_t>();
// Activations fusable by clamping — keep in sync with configure_mm.
357 const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = {
358 ActivationLayerInfo::ActivationFunction::RELU, ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
359 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU};
360 if (is_activation_enabled && supported_acts.count(
act_info.activation()) != 0)
// Clone src/weights and negate the quantization offsets, matching the
// sign flip applied in configure_mm.
374 std::unique_ptr<ITensorInfo> input_qa =
src->clone();
375 std::unique_ptr<ITensorInfo> weights_qa = weights->clone();
376 input_qa->set_quantization_info(QuantizationInfo(iqinfo.uniform().scale, -iqinfo.uniform().offset));
377 weights_qa->set_quantization_info(QuantizationInfo(wqinfo.uniform().scale, -wqinfo.uniform().offset));
// Quantized-path GEMMInfo (validate call itself elided).
380 GEMMInfo(
false,
false,
true, gemm_3d_depth, skip_im2col,
false,
// Float-path GEMMInfo, mirroring configure_mm's non-quantized branch.
387 const GEMMInfo gemm_info =
388 GEMMInfo(
false,
false,
true , gemm_3d_depth,
390 GEMMLowpOutputStageInfo(),
false, enable_fast_math,
false,
act_info, fixed_format, weight_format,
// Probes whether the GEMM backend can reinterpret its output as a 3D tensor
// of the given depth, by validating tiny 4x4 dummy tensors instead of the
// real ones. The input's Y/Z extents are scaled so the dummy matches the
// layout that skip_im2col would produce. Parts of the signature are elided.
399 Status CpuGemmConv2d::validate_gemm3d(
const ITensorInfo *
input_info,
401 const ActivationLayerInfo &
act_info,
// When im2col is skipped the depth lands in Z (mult_z); otherwise it is
// folded into Y (mult_y).
406 const unsigned int mult_y = skip_im2col ? 1
U : gemm_3d_depth;
407 const unsigned int mult_z = skip_im2col ? gemm_3d_depth : 1
U;
// Minimal stand-in tensors carrying the real data type and quantization.
410 const TensorInfo dummy_input_info(TensorShape(4U, 4U * mult_y, 1U * mult_z), 1,
data_type,
412 const TensorInfo dummy_weights_info(TensorShape(4U, 4U), 1,
data_type,
weights_info->quantization_info());
413 const TensorInfo dummy_output_info(TensorShape(4U, 4U, gemm_3d_depth), 1,
data_type,
// Delegate to validate_mm with no biases and fast-math disabled.
416 return validate_mm(&dummy_input_info, &dummy_weights_info,
nullptr, &dummy_output_info,
act_info,
false,
417 gemm_3d_depth, skip_im2col);
// Fragment of the main configure() entry point (signature and many
// statements elided). Visible flow: compute output dims and skip decisions,
// derive the reshaped-weights info, optionally set up im2col, shape the GEMM
// output, configure the matrix multiply, then set up col2im or a final
// reshape, and record auxiliary-memory requirements.
428 bool enable_fast_math,
458 unsigned int conv_w = 0;
459 unsigned int conv_h = 0;
// Error message for an (elided) output-shape assertion.
464 "Output shape does not match the expected one");
// Cache the im2col/col2im skip decisions for run().
467 const CpuGemmConv2d::SkipInfo skip_info =
469 _skip_im2col = skip_info.skip_im2col;
470 _skip_col2im = skip_info.skip_col2im;
473 unsigned int stride_x = 0;
474 unsigned int stride_y = 0;
475 std::tie(stride_x, stride_y) =
conv_info.stride();
// Fill _weights_reshaped with the collapsed weight shape (see helper above).
478 initialize_reshaped_weight_info(*weights, _weights_reshaped);
484 unsigned int input_pad_right = 0;
// Elided branch (!_skip_im2col): lower the input with an im2col kernel and
// feed its output to the GEMM instead of src.
491 _im2col_kernel = std::make_unique<kernels::CpuIm2ColKernel>();
492 _im2col_kernel->configure(
src, &_im2col_output,
Size2D(kernel_width, kernel_height),
conv_info,
false, dilation,
496 gemm_input_to_use = &_im2col_output;
// GEMM output: one column per kernel, one row per output pixel.
499 const unsigned int mat_weights_cols = weights->
dimension(idx_kernels);
509 shape_gemm.
set(0, mat_weights_cols);
510 shape_gemm.
set(1, conv_w * conv_h);
512 _gemm_output =
TensorInfo(shape_gemm, 1, output_data_type);
517 gemm_output_to_use = &_gemm_output;
// Elided alternative: a 3D-shaped GEMM output when col2im is skipped.
526 gemm_output_to_use = &_gemm_output_3d;
// Depth 0 disables 3D output reinterpretation.
531 const unsigned int gemm_3d_depth = _skip_col2im ? conv_h : 0;
559 configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use,
act_info, enable_fast_math,
560 gemm_3d_depth, fixed_format,
weights_info.weight_format());
// Weight transform is only run by this operator when the GEMM is not a
// variable-weights (fixed-format) kernel that handles it itself.
563 _run_wt = !isVarWeightsKernel();
// Elided branch: raise the GEMM output back to image layout via col2im.
568 _col2im_kernel = std::make_unique<kernels::CpuCol2ImKernel>();
569 _col2im_kernel->configure(gemm_output_to_use,
dst,
Size2D(conv_w, conv_h));
// Elided alternative: a plain reshape suffices when col2im is skipped.
574 _reshape = std::make_unique<CpuReshape>();
575 _reshape->configure(gemm_output_to_use,
dst);
579 _aux_mem[Im2ColOutput] =
// Detect whether the configured GEMM backend already transposes the
// weights internally (checked via its workspace slot sizes), so the
// operator can avoid a redundant transpose of its own.
591 bool gemm_trans_wei = _aux_mem[GemmAsmPretransposedRHS].size > 0;
592 gemm_trans_wei = _mm_gemm !=
nullptr ? _aux_mem[GemmTransposed1xWRHS].size > 0 : gemm_trans_wei;
593 gemm_trans_wei = _mm_gemmlowp !=
nullptr ? _aux_mem[GemmLowpTransposed1xWRHS].size > 0 : gemm_trans_wei;
// Fragment of a static query/validation entry point (name and most of the
// signature elided — NOTE(review): position in the file suggests
// has_opt_impl; confirm against the full source). It recomputes the same
// skip decisions and 3D depth that configure()/validate() use.
611 const bool enable_fast_math)
618 unsigned int conv_w = 0;
619 unsigned int conv_h = 0;
623 const CpuGemmConv2d::SkipInfo skip_info =
626 const bool skip_im2col = skip_info.skip_im2col;
627 const bool skip_col2im = skip_info.skip_col2im;
// Depth 0 disables 3D output reinterpretation, as in configure().
628 const unsigned int gemm_3d_depth = skip_col2im ? conv_h : 0;
// Fragment of the static validate() entry point (signature and many checks
// elided). Builds TensorInfo stand-ins mirroring every intermediate that
// configure() would create (reshaped weights, im2col output, GEMM output)
// and validates the pipeline without configuring anything.
654 bool enable_fast_math,
// Bias is always passed to the GEMM separately, never appended to weights.
690 const bool append_bias =
false;
695 unsigned int conv_w = 0;
696 unsigned int conv_h = 0;
702 const CpuGemmConv2d::SkipInfo skip_info =
704 const bool skip_im2col = skip_info.skip_im2col, skip_col2im = skip_info.skip_col2im;
// Elided bias checks (data type / shape) guarded on biases being provided.
710 if (biases !=
nullptr)
// GEMM matrix extents: columns = number of kernels, rows = lowered patch
// size (computation of mat_weights_rows partially elided).
728 unsigned int mat_weights_cols = weights->
dimension(idx_kernels);
729 unsigned int mat_weights_rows =
// Stand-in for the reshaped weights, same helper configure() uses.
733 initialize_reshaped_weight_info(*weights, weights_reshaped_info);
736 weights_to_use = &weights_reshaped_info;
741 int input_pad_right = 0;
747 (weights->
dimension(idx_channel) + input_pad_right);
// Stand-in for the im2col output: (patch_size, out_pixels, 1).
753 shape_im2col.
set(0, mat_weights_rows);
754 shape_im2col.
set(1, conv_w * conv_h);
755 shape_im2col.
set(2, 1);
758 im2col_reshaped_info.set_quantization_info(
src->quantization_info());
762 gemm_input_to_use = &im2col_reshaped_info;
// Stand-in for the GEMM output when col2im will run afterwards...
770 shape_gemm.
set(0, mat_weights_cols);
771 shape_gemm.
set(1, conv_w * conv_h);
772 info_gemm =
TensorInfo(shape_gemm, 1, output_data_type);
// ...or (elided else-branch) the GEMM writes dst's shape directly.
776 info_gemm =
TensorInfo(
dst->tensor_shape(), 1, output_data_type);
778 info_gemm.set_quantization_info(
dst->quantization_info()).set_data_layout(
src->data_layout());
779 gemm_output_to_use = &info_gemm;
// Tail of the (elided) validate_mm call: 3D depth and skip flags mirror
// configure()'s gemm_3d_depth = skip_col2im ? conv_h : 0.
784 enable_fast_math, skip_col2im ? conv_h : 0, skip_im2col, fixed_format,
// Fragment of run(): executes (elided) im2col, the GEMM (quantized or
// float), and (elided) col2im/reshape, wiring tensor-pack entries according
// to the skip flags and the weight-transform method chosen in prepare().
803 auto gemm_input_to_use =
src;
// dst can only be written directly by the GEMM if it has no vertical
// padding; otherwise an intermediate buffer is needed.
808 bool out_has_padding = _skip_col2im && (
dst->info()->padding().bottom != 0 ||
dst->info()->padding().top != 0);
// Elided branch (!_skip_im2col): the GEMM reads the im2col buffer instead.
815 gemm_input_to_use = im2col_output.
get();
819 const ITensor *out_to_use = out_has_padding ? gemm_output.
get() :
dst;
824 auto gemm_output_to_use = gemm_output.
get();
// Elided branch: use the 3D-shaped alias of the output buffer.
828 gemm_output_to_use = &gemm3d;
// When col2im is skipped and dst is padding-free, the GEMM writes dst
// directly with no intermediate copy.
830 if (_skip_col2im && !out_has_padding)
832 gemm_output_to_use =
dst;
// Select which weight tensor the GEMM consumes, based on the transform
// method prepare() picked; the pack-building lines between are elided.
842 const bool use_reinterpreted_wei = (_run_wt && _wt_method == WeightTransformMethod::ReinterpretThenTranspose);
844 _weights_reshaped, *weights,
846 !use_reinterpreted_wei);
848 const bool use_reshaped_wei = (_run_wt && (_wt_method == WeightTransformMethod::ReshapeThenTranspose ||
849 _wt_method == WeightTransformMethod::FusedReshapeAndTranspose));
851 false , !use_reshaped_wei ,
855 if (use_reinterpreted_wei)
859 else if (use_reshaped_wei)
// Dispatch to the quantized or float GEMM backend.
865 _is_quantized ? _mm_gemmlowp->run(gemm_pack) : _mm_gemm->run(gemm_pack);
// Elided epilogue: copy/reshape into dst when it has padding.
881 else if (out_has_padding)
// Fragment of prepare(): picks the weight-transform method, then (in two
// elided switch statements) first configures and later runs the matching
// transform, marking the original weights unused once transformed, and
// finally lets the GEMM backend do its own one-time preparation.
898 _wt_method = get_wt_method(*(weights->info()));
// First (configuration) switch over the chosen method:
901 case (WeightTransformMethod::FusedReshapeAndTranspose):
// Single kernel reshapes and transposes in one pass (no bias appended).
904 _weights_reshape_and_transpose_kernel = std::make_unique<kernels::CpuWeightsReshapeKernel>();
905 _weights_reshape_and_transpose_kernel->configure(weights->info(),
nullptr, &_weights_reshaped);
908 case (WeightTransformMethod::ReshapeThenTranspose):
// Plain reshape into _weights_reshaped; the transpose is done elsewhere
// (elided — NOTE(review): presumably by the GEMM backend; confirm).
911 _weights_reshape = std::make_unique<CpuReshape>();
912 _weights_reshape->configure(weights->info(), &_weights_reshaped);
915 case (WeightTransformMethod::ReinterpretThenTranspose):
// Second (execution) switch — runs the transform configured above:
942 case (WeightTransformMethod::FusedReshapeAndTranspose):
946 _weights_reshape_and_transpose_kernel->window(),
pack);
// Original weights no longer needed once the transformed copy exists.
947 weights->mark_as_unused();
951 case (WeightTransformMethod::ReshapeThenTranspose):
954 _weights_reshape->run(
pack);
955 weights->mark_as_unused();
959 case (WeightTransformMethod::ReinterpretThenTranspose):
// One-time backend preparation (e.g. weight pretranspose) for whichever
// GEMM was configured.
971 _is_quantized ? _mm_gemmlowp->prepare(gemm_pack) : _mm_gemm->prepare(gemm_pack);
// True when the float GEMM backend exists and was configured with a
// variable-weights (fixed-format) kernel. Returns false when _mm_gemm is
// null — e.g. when the quantized (_mm_gemmlowp) path was taken.
980 bool CpuGemmConv2d::isVarWeightsKernel()
const
982 return _mm_gemm && _mm_gemm->isVarWeightsKernel();