using namespace misc::shape_calculator;

using ElementsProcessed = Steps;

Status validate_arguments(const ITensorInfo    *src0,
                          const ITensorInfo    *src1,
                          const ITensorInfo    *dst,
                          const GEMMKernelInfo &gemm_info,
                          const ITensorInfo    *vector_sum_col,
                          const ITensorInfo    *vector_sum_row,
                          const ITensorInfo    *bias,
                          const ITensorInfo    *output_multipliers,
                          const ITensorInfo    *output_shifts)
{
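    // Static checks for the gemmlowp "reshaped only RHS" kernel: src0 is the quantized LHS,
    // src1 the RHS already reshaped according to gemm_info.rhs_info, and vector_sum_col /
    // vector_sum_row hold the precomputed column/row sums used to fold the quantization
    // offsets into the result when the output stage is fused.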
78 "The number of dimensions for the LHS matrix must be <= 4");
80 "The number of dimensions for the RHS matrix must be <= 3");
82 const GEMMRHSMatrixInfo rhs_info = gemm_info.rhs_info;
83 const GEMMLHSMatrixInfo lhs_info = gemm_info.lhs_info;
84 const GEMMLowpOutputStageInfo
output_stage = gemm_info.output_stage;
87 "Only 2,3,4,8,16 are supported for k0");
90 "Only 2,3,4,8,16 are supported for n0");
93 const int m = gemm_info.m;
94 const int n = gemm_info.n;
95 const int k = gemm_info.k;
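    // Rebuild the pre-reshape RHS shape (n x k) and check that src1 matches the shape
    // compute_rhs_reshaped_shape() produces for it with the given rhs_info.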
    TensorShape tensor_shape1{src1->tensor_shape()};
    tensor_shape1.set(0, n);
    tensor_shape1.set(1, k);

    const TensorInfo tensor_info1          = src1->clone()->set_tensor_shape(tensor_shape1);
    const TensorInfo tensor_info_reshaped1 =
        src1->clone()->set_tensor_shape(compute_rhs_reshaped_shape(tensor_info1, rhs_info));
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(src1, &tensor_info_reshaped1);
    // The LHS carries M either in dim 1 or collapsed across dims 1 and 2 (3D input)
    if (gemm_info.reinterpret_input_as_3d)
    {
        ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(1) * src0->dimension(2) != static_cast<unsigned int>(m));
    }
    else
    {
        ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(1) != static_cast<unsigned int>(m));
    }
    const TensorShape expected_dst_shape = compute_mm_shape(*src0, *src1, gemm_info);
    if (dst->total_size() != 0)
    {
        const TensorInfo tensor_info_dst = dst->clone()->set_tensor_shape(expected_dst_shape);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(dst, &tensor_info_dst);
    }

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(output_stage.type != GEMMLowpOutputStageType::NONE &&
                                        output_stage.type != GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT,
                                    "Only GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT is supported");
    if (gemm_info.a_offset != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_col, 1, DataType::S32);
        ARM_COMPUTE_RETURN_ERROR_ON(vector_sum_col->dimension(0) != expected_dst_shape[0]);
    }
    if (gemm_info.b_offset != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_row, 1, DataType::S32);

        // Check if the mm result is a 3D reinterpretation
        const bool reinterpret_as_3d =
            expected_dst_shape.num_dimensions() > 1 && expected_dst_shape.y() != vector_sum_row->tensor_shape().x();

        ARM_COMPUTE_RETURN_ERROR_ON(reinterpret_as_3d &&
                                    vector_sum_row->dimension(0) != (expected_dst_shape[1] * expected_dst_shape[2]));
        ARM_COMPUTE_RETURN_ERROR_ON(!reinterpret_as_3d && vector_sum_row->dimension(0) != expected_dst_shape[1]);

        if (expected_dst_shape.num_dimensions() > 1)
        {
            const unsigned int dst_batch_idx = reinterpret_as_3d ? 3 : 2;

            TensorShape vector_sum_row_shape = vector_sum_row->tensor_shape();
            vector_sum_row_shape.collapse_from(1);
            TensorShape collapsed_dst_shape(expected_dst_shape);
            collapsed_dst_shape.collapse_from(dst_batch_idx);

            ARM_COMPUTE_RETURN_ERROR_ON_MSG(vector_sum_row_shape[1] != collapsed_dst_shape[dst_batch_idx],
                                            "vector_sum_row must have the same number of batches as the dst tensor");

            if (gemm_info.a_offset != 0)
            {
                TensorShape vector_sum_col_shape = vector_sum_col->tensor_shape();
                vector_sum_col_shape.collapse_from(1);

                ARM_COMPUTE_RETURN_ERROR_ON_MSG(vector_sum_col_shape[1] != 1 &&
                                                    vector_sum_col_shape[1] != vector_sum_row_shape[1],
                                                "vector_sum_col must have the same number of batches as "
                                                "vector_sum_row, or its number of batches must be 1");
            }
        }
    }
    if (dst->total_size() != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON(output_stage.output_data_type != dst->data_type());
    }

    if (output_multipliers != nullptr && output_shifts != nullptr)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output_multipliers, 1, DataType::S32);
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output_shifts, 1, DataType::S32);
        if (output_stage.is_quantized_per_channel)
        {
            ARM_COMPUTE_RETURN_ERROR_ON(expected_dst_shape[0] != output_multipliers->dimension(0));
            ARM_COMPUTE_RETURN_ERROR_ON(expected_dst_shape[0] != output_shifts->dimension(0));
        }
    }

    return Status{};
}
std::pair<Status, Window> validate_and_configure_window(const ITensorInfo    *src0,
                                                        const ITensorInfo    *src1,
                                                        ITensorInfo          *dst,
                                                        const GEMMKernelInfo &gemm_info,
                                                        ITensorInfo          *vector_sum_col,
                                                        const ITensorInfo    *vector_sum_row,
                                                        ITensorInfo          *bias,
                                                        ITensorInfo          *output_multipliers,
                                                        ITensorInfo          *output_shifts,
                                                        ElementsProcessed    &num_elements_processed)
{
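    // Derive the execution window from the dst shape (auto-initialising dst if needed) and,
    // for the fused output stage, attach access windows so the padding requirements of the
    // helper tensors are folded into the window.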
    const GEMMLowpOutputStageInfo output_stage = gemm_info.output_stage;

    unsigned int &num_elems_processed_per_iteration_x = num_elements_processed[0];
    unsigned int &num_elems_processed_per_iteration_y = num_elements_processed[1];
    bool          reinterpret_input_as_3d             = gemm_info.reinterpret_input_as_3d;
    bool          reinterpret_output_as_3d            = (gemm_info.depth_output_gemm3d != 0);
    bool          window_changed                      = false;

    // If both the input and the dst have to be reinterpreted as 3D tensors,
    // force reinterpret_output_as_3d to be false
    if (reinterpret_input_as_3d == reinterpret_output_as_3d)
    {
        reinterpret_output_as_3d = false;
    }
    // dst tensor auto-initialisation if not yet initialised
    const TensorShape expected_dst_shape = compute_mm_shape(*src0, *src1, gemm_info);
    auto_init_if_empty(
        *dst, src0->clone()->set_tensor_shape(expected_dst_shape).set_data_type(output_stage.output_data_type));
    TensorInfo tmp_info(*dst);

    if (reinterpret_output_as_3d)
    {
        // The execution window is based on a 2D GEMM, so when dst is reinterpreted as 3D
        // the window has to be built on the 2D-collapsed version of the tensor
        TensorShape tmp_shape(dst->tensor_shape());
        tmp_shape.collapse(2U, 1U);
        tmp_info.set_tensor_shape(tmp_shape);
    }
    // Configure kernel window: each work-item computes an n0 x m0 tile
    num_elems_processed_per_iteration_x = gemm_info.rhs_info.n0;
    num_elems_processed_per_iteration_y = gemm_info.lhs_info.m0;

    Window win =
        calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
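    // The helper tensors below are read per output column; if updating their access windows
    // changes the padding requirements, the window is flagged and an error is returned at the end.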
    if (gemm_info.a_offset != 0)
    {
        AccessWindowHorizontal vector_sum_col_access(vector_sum_col, 0, num_elems_processed_per_iteration_x);
        window_changed = window_changed || update_window_and_padding(win, vector_sum_col_access);
    }
    if (bias != nullptr)
    {
        AccessWindowHorizontal bias_access(bias, 0, num_elems_processed_per_iteration_x);
        window_changed = window_changed || update_window_and_padding(win, bias_access);
    }
    if (output_multipliers != nullptr && output_stage.is_quantized_per_channel)
    {
        AccessWindowHorizontal output_multipliers_access(output_multipliers, 0,
                                                         num_elems_processed_per_iteration_x);
        AccessWindowHorizontal output_shifts_access(output_shifts, 0, num_elems_processed_per_iteration_x);
        window_changed =
            window_changed || update_window_and_padding(win, output_multipliers_access, output_shifts_access);
    }
    // Collapse along the Z direction; this has to happen here so that the
    // Z dimension of the local workgroup size can be tuned
    Window             collapsed             = win;
    const unsigned int dimension_to_collapse = std::min(static_cast<unsigned int>(dst->num_dimensions()), 2u);
    collapsed                                = win.collapse(win, dimension_to_collapse);

    Status err =
        window_changed ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, collapsed);
}
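// configure() turns a validated GEMM description into an OpenCL program: it selects the kernel
// variant, bakes the tile geometry and output-stage parameters into -D build options, and
// derives a config_id used as the key for local-workgroup-size tuning.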
void ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel::configure(const CLCompileContext &compile_context,
                                                              const ITensorInfo      *src0,
                                                              const ITensorInfo      *src1,
                                                              ITensorInfo            *dst,
                                                              const GEMMKernelInfo   &gemm_info,
                                                              ITensorInfo            *vector_sum_col,
                                                              const ITensorInfo      *vector_sum_row,
                                                              ITensorInfo            *bias,
                                                              ITensorInfo            *output_multipliers,
                                                              ITensorInfo            *output_shifts)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, dst, gemm_info, vector_sum_col, vector_sum_row, bias,
                                                  output_multipliers, output_shifts));

    const GEMMRHSMatrixInfo       rhs_info     = gemm_info.rhs_info;
    const GEMMLHSMatrixInfo      lhs_info     = gemm_info.lhs_info;
    const GEMMLowpOutputStageInfo output_stage = gemm_info.output_stage;
    const int32_t a_offset = gemm_info.a_offset;
    const int32_t b_offset = gemm_info.b_offset;

    _reinterpret_input_as_3d  = gemm_info.reinterpret_input_as_3d;
    _reinterpret_output_as_3d = (gemm_info.depth_output_gemm3d != 0);
    _is_quantized_per_channel = output_stage.is_quantized_per_channel;
    // If both input and dst have to be reinterpreted as 3D tensors,
    // force _reinterpret_input_as_3d and _reinterpret_output_as_3d to be false
    if (_reinterpret_input_as_3d == _reinterpret_output_as_3d)
    {
        _reinterpret_input_as_3d  = false;
        _reinterpret_output_as_3d = false;
    }

    // Check if we need to slide the matrix B
    const unsigned int num_dimensions_src0 = src0->num_dimensions();
    _slide_matrix_b                        = (src1->num_dimensions() >= num_dimensions_src0);
    ElementsProcessed num_elements_processed{};

    // Configure kernel window
    auto win_config = validate_and_configure_window(src0, src1, dst, gemm_info, vector_sum_col, vector_sum_row, bias,
                                                    output_multipliers, output_shifts, num_elements_processed);
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
    ICLKernel::configure_internal(win_config.second);
    // When both reinterpret flags were requested, a batched GEMM is dispatched instead, so the
    // effective M seen by the kernel comes from dst->dimension(1) rather than gemm_info.m
    const unsigned int internal_m = _reinterpret_output_as_3d ? gemm_info.m : dst->dimension(1);

    // Shrink M0 to be always <= M (internal_m) to prevent out-of-bounds reads
    const unsigned int internal_m0 = std::min(internal_m, lhs_info.m0);

    // Partial blocks at the end of a row/column are stored with these leftover sizes,
    // which avoids requiring padding on the tensors
    const unsigned int partial_store_m0 = internal_m % internal_m0;
    const unsigned int partial_store_n0 = gemm_info.n % rhs_info.n0;
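    // All geometry decisions above are baked into the OpenCL program as -D defines, so each
    // distinct configuration compiles (and caches) its own specialised program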
    // Create build options
    CLBuildOptions build_opts;
    build_opts.add_option_if(_reinterpret_input_as_3d, "-DREINTERPRET_INPUT_AS_3D");
    build_opts.add_option_if(_reinterpret_output_as_3d, "-DREINTERPRET_OUTPUT_AS_3D");
    build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d,
                             "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(dst->dimension(1)));
    build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d,
                             "-DDEPTH_GEMM3D=" + support::cpp11::to_string(dst->dimension(2)));
    build_opts.add_option_if(rhs_info.interleave, "-DRHS_INTERLEAVE");
    build_opts.add_option_if(_use_dummy_work_items, "-DDUMMY_WORK_ITEMS");
    build_opts.add_option("-DM=" + support::cpp11::to_string(internal_m));
    build_opts.add_option("-DN=" + support::cpp11::to_string(gemm_info.n));
    build_opts.add_option("-DK=" + support::cpp11::to_string(gemm_info.k));
    build_opts.add_option("-DM0=" + support::cpp11::to_string(internal_m0));
    build_opts.add_option("-DN0=" + support::cpp11::to_string(rhs_info.n0));
    build_opts.add_option("-DK0=" + support::cpp11::to_string(rhs_info.k0));
    build_opts.add_option("-DH0=" + support::cpp11::to_string(rhs_info.h0));
    build_opts.add_option("-DPARTIAL_STORE_M0=" + support::cpp11::to_string(partial_store_m0));
    build_opts.add_option("-DPARTIAL_STORE_N0=" + support::cpp11::to_string(partial_store_n0));
    std::string kernel_name("gemmlowp_mm_reshaped_only_rhs_");
    kernel_name += rhs_info.transpose ? "t" : "nt";

    if (output_stage.type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
    {
        kernel_name += "_fused_output_stage_fixedpoint";
        _fuse_output_stage = true;
        // If a_offset == 0, vector_sum_col can be a nullptr
        if (a_offset != 0 && vector_sum_col != nullptr)
        {
            build_opts.add_option("-DA_OFFSET=" + support::cpp11::to_string(a_offset));
            build_opts.add_option_if(vector_sum_col->tensor_shape().num_dimensions() > 1, "-DSUM_COL_HAS_BATCHES");
        }
        if (!_is_quantized_per_channel)
        {
            // Per-tensor quantization: bake the single multiplier/shift pair in at compile time
            build_opts.add_option("-DRESULT_MULTIPLIER=" +
                                  support::cpp11::to_string(output_stage.gemmlowp_multipliers[0]));
            build_opts.add_option("-DRESULT_SHIFT=" + support::cpp11::to_string(output_stage.gemmlowp_shifts[0]));
        }
        else
        {
            build_opts.add_option("-DRESULT_MULTIPLIER=0");
            build_opts.add_option("-DRESULT_SHIFT=0");
        }
        build_opts.add_option_if(_is_quantized_per_channel, "-DPER_CHANNEL_QUANTIZATION");
    }
    // Create kernel
    _kernel = create_kernel(compile_context, kernel_name, build_opts.options());

    // Set config_id for enabling LWS tuning: every shape/layout variant gets its own key
    _config_id = kernel_name;
    _config_id += "_";
    _config_id += (_reinterpret_input_as_3d ? "3di_" : "");
    _config_id += (_reinterpret_output_as_3d ? "3do_" : "");
}
Status ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel::validate(const ITensorInfo    *src0,
                                                               const ITensorInfo    *src1,
                                                               const ITensorInfo    *dst,
                                                               const GEMMKernelInfo &gemm_info,
                                                               const ITensorInfo    *vector_sum_col,
                                                               const ITensorInfo    *vector_sum_row,
                                                               const ITensorInfo    *bias,
                                                               const ITensorInfo    *output_multipliers,
                                                               const ITensorInfo    *output_shifts)
{
    ElementsProcessed num_elements_processed{};
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src0, src1, dst, gemm_info, vector_sum_col, vector_sum_row, bias,
                                                   output_multipliers, output_shifts));
    ARM_COMPUTE_RETURN_ON_ERROR(
        validate_and_configure_window(src0->clone().get(), src1->clone().get(), dst->clone().get(), gemm_info,
                                      vector_sum_col != nullptr ? vector_sum_col->clone().get() : nullptr,
                                      vector_sum_row != nullptr ? vector_sum_row->clone().get() : nullptr,
                                      bias != nullptr ? bias->clone().get() : nullptr,
                                      output_multipliers != nullptr ? output_multipliers->clone().get() : nullptr,
                                      output_shifts != nullptr ? output_shifts->clone().get() : nullptr,
                                      num_elements_processed)
            .first);

    return Status{};
}
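// Typical call sequence (a sketch; the tensor setup and exact pack keys are assumed from the
// surrounding code rather than a documented public API):
//
//   ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel kernel;
//   ARM_COMPUTE_ERROR_THROW_ON(ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel::validate(
//       src0, src1, dst, gemm_info, vector_sum_col, vector_sum_row, bias, output_multipliers, output_shifts));
//   kernel.configure(compile_context, src0, src1, dst, gemm_info, vector_sum_col, vector_sum_row,
//                    bias, output_multipliers, output_shifts);
//   // At run time the tensors travel in an ITensorPack and the kernel is enqueued per 3D slice:
//   kernel.run_op(tensors, kernel.window(), queue);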
void ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel::run_op(ITensorPack &tensors, const Window &window,
                                                           cl::CommandQueue &queue)
{
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window);

    const auto src0 =
        utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_0));
    const auto src1 =
        utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_1));
    const auto bias =
        utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_BIAS));
    const auto vector_sum_col =
        utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_VEC_COL_SUM));
    const auto vector_sum_row =
        utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_VEC_ROW_SUM));
    const auto output_shifts =
        utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SHIFTS));
    const auto output_multipliers =
        utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_MULTIPLIERS));
    auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST));
    if (src1->info()->num_dimensions() < 3)
    {
        // The stride_z for matrix B must be zero when we do not slice it
        ARM_COMPUTE_ERROR_ON(src1->info()->strides_in_bytes()[3] != 0);
    }

    Window slice          = window.first_slice_window_3D();
    Window slice_matrix_b = slice;

    slice_matrix_b.set(Window::DimX, Window::Dimension(0, 1, 1));
    slice_matrix_b.set(Window::DimY, Window::Dimension(0, 1, 1));
    if (_reinterpret_input_as_3d)
    {
        // Pass the bottom padding to the kernel when the input is reinterpreted as a 3D tensor
        const unsigned int idx0                  = 3 * num_arguments_per_2D_tensor() + 3;
        const unsigned int total_cross_plane_pad = src0->info()->padding().top + src0->info()->padding().bottom;
        _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
    }
    if (_reinterpret_output_as_3d)
    {
        // Pass the bottom padding to the kernel when the dst is reinterpreted as a 3D tensor
        const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0);
        const unsigned int total_cross_plane_pad = dst->info()->padding().top + dst->info()->padding().bottom;
        _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
    }

    // Windows for the fused output-stage tensors: the sums and biases are read per
    // column/row, so their windows do not advance along the unused dimensions
    Window win_vector_sum_col = slice;
    win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0));
    win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0));

    Window win_vector_sum_row = slice;
    win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0));

    Window biases_slice = slice;
    biases_slice.set(Window::DimY, Window::Dimension(0, 1, 1));
    biases_slice.set(Window::DimZ, Window::Dimension(0, 1, 1));
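    // Each 3D slice of the execution window is enqueued separately. Matrix B is not slid along Z
    // when it has fewer dimensions than matrix A (e.g. when the GEMM implements a convolution),
    // in which case the same 2D slice of B is reused for every slice of A.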
    do
    {
        Window slice_b = slice;
        if (!_slide_matrix_b)
        {
            slice_b = slice_matrix_b;
        }
        unsigned int idx = 0;
        add_2D_tensor_argument(idx, src0, slice);
        add_2D_tensor_argument(idx, src1, slice_b);
        add_2D_tensor_argument(idx, dst, slice);
        _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(src0->info()->strides_in_bytes()[2]));
        _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(src1->info()->strides_in_bytes()[2]));
        _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(dst->info()->strides_in_bytes()[2]));
        if (_reinterpret_input_as_3d)
        {
            // Skip the argument slot already set for the input cross-plane padding
            idx++;
        }

        if (_reinterpret_output_as_3d)
        {
            // Skip the argument slot already set for the dst cross-plane padding
            idx++;
        }
        if (_fuse_output_stage)
        {
            add_2D_tensor_argument_if((vector_sum_col != nullptr), idx, vector_sum_col, win_vector_sum_col);
            add_2D_tensor_argument_if((vector_sum_row != nullptr), idx, vector_sum_row, win_vector_sum_row);
            add_1D_tensor_argument_if((bias != nullptr), idx, bias, biases_slice);
            add_1D_tensor_argument_if(_is_quantized_per_channel, idx, output_multipliers, biases_slice);
            add_1D_tensor_argument_if(_is_quantized_per_channel, idx, output_shifts, biases_slice);
        }
        enqueue(queue, *this, slice, lws_hint(), _use_dummy_work_items);
    } while (window.slide_window_slice_3D(slice));
}