if (bool(gemm_kernel))
    if (validate_gemm_kernel(gemm_kernel.gemm_type))
            to_string(gemm_kernel.gemm_type).c_str());
        return gemm_kernel.gemm_type;
            to_string(gemm_kernel.gemm_type).c_str());
        return gemm_kernel.gemm_type;
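// Helper: check whether the given LHS/RHS matrix configuration is valid for the native GEMMLowp kernel.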
inline bool validate_lhs_rhs_info_native(const GEMMLHSMatrixInfo &lhs_info,
                                         const GEMMRHSMatrixInfo &rhs_info,
                                         const ITensorInfo       *a,
                                         const ITensorInfo       *b,
                                         const GEMMReshapeInfo   &reshape_info)
    TensorInfo mm_result_s32_info{};
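// Pick the native-kernel GEMM configuration: use the MLGO heuristic result when it validates,
// otherwise fall back to the default heuristics.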
                                       const ITensorInfo     *a,
                                       const ITensorInfo     *b,
                                       const GEMMReshapeInfo &reshape_info)
    if (validate_lhs_rhs_info_native(config.lhs_info, config.rhs_info, a, b, reshape_info))
            "Use native config from mlgo heuristics: LHS info: %s ; RHS info: %s ",
        return {config.lhs_info, config.rhs_info};
    return {config.lhs_info, config.rhs_info};
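// Helper: check whether an LHS/RHS configuration is valid for the reshaped-only-RHS kernel,
// using temporary TensorInfo objects and a GEMMKernelInfo descriptor built from the query parameters.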
inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info,
                                                    const GEMMRHSMatrixInfo &rhs_info,
                                                    const ITensorInfo       *a,
                                                    const ITensorInfo       *b,
                                                    const ITensorInfo       *output,
                                                    unsigned int             m,
                                                    unsigned int             n,
                                                    unsigned int             k,
                                                    bool                     reinterpret_input_as_3d,
                                                    int                      depth_output_gemm3d)
    TensorInfo tmp_b_info{};
    GEMMKernelInfo gemm_kernel_info;
    gemm_kernel_info.m                       = m;
    gemm_kernel_info.n                       = n;
    gemm_kernel_info.k                       = k;
    gemm_kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    gemm_kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    gemm_kernel_info.lhs_info                = lhs_info;
    gemm_kernel_info.rhs_info                = rhs_info;
    TensorInfo output_info_copy(*output);
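// Same check as above, for the reshaped-only-RHS MMUL kernel variant.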
inline bool validate_lhs_rhs_info_reshaped_only_rhs_mmul(const GEMMLHSMatrixInfo &lhs_info,
                                                         const GEMMRHSMatrixInfo &rhs_info,
                                                         const ITensorInfo       *a,
                                                         const ITensorInfo       *b,
                                                         const ITensorInfo       *output,
                                                         unsigned int             m,
                                                         unsigned int             n,
                                                         unsigned int             k,
                                                         bool                     reinterpret_input_as_3d,
                                                         int                      depth_output_gemm3d)
    TensorInfo tmp_b_info{};
    GEMMKernelInfo gemm_kernel_info;
    gemm_kernel_info.m                       = m;
    gemm_kernel_info.n                       = n;
    gemm_kernel_info.k                       = k;
    gemm_kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    gemm_kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    gemm_kernel_info.lhs_info                = lhs_info;
    gemm_kernel_info.rhs_info                = rhs_info;
    TensorInfo output_info_copy(*output);
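// Pick the reshaped-only-RHS configuration: try the MLGO heuristic first and fall back to the
// default heuristics if the proposed configuration does not validate.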
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
                                          bool               reinterpret_input_as_3d,
                                          int                depth_output_gemm3d,
                                          const ITensorInfo *a,
                                          const ITensorInfo *b,
                                          const ITensorInfo *output)
    if (validate_lhs_rhs_info_reshaped_only_rhs(config.lhs_info, config.rhs_info, a, b, output, query.m, query.n,
                                                query.k, reinterpret_input_as_3d, depth_output_gemm3d))
            "Use reshaped_only_rhs config from mlgo heuristics: LHS info: %s ; RHS info: %s ",
        return {config.lhs_info, config.rhs_info};
        "Use reshaped_only_rhs config from default heuristics: LHS info: %s ; RHS info: %s ",
    return {config.lhs_info, config.rhs_info};
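// MMUL variant: the configuration comes from the default heuristics (see the log message below).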
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
                                          bool               reinterpret_input_as_3d,
                                          int                depth_output_gemm3d,
                                          const ITensorInfo *a,
                                          const ITensorInfo *b,
                                          const ITensorInfo *output)
    validate_lhs_rhs_info_reshaped_only_rhs_mmul(config.lhs_info, config.rhs_info, a, b, output, query.m, query.n,
                                                 query.k, reinterpret_input_as_3d, depth_output_gemm3d);
        "Use reshaped_only_rhs_mmul config from default heuristics: LHS info: %s ; RHS info: %s ",
    return {config.lhs_info, config.rhs_info};
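// Constructor: create the operator's kernels and size the auxiliary-memory descriptor table.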
    : _weights_to_qasymm8(std::make_unique<ClCastKernel>()),
      _aux_mem(AuxTensorIdx::Count)
    _b_offset  = _convert_to_qasymm8 ? -128 : b->quantization_info().uniform().offset;
    _gemm_info = gemm_info;
    _mm_native_kernel->set_target(gpu_target);
    _mm_reshaped_only_rhs_kernel->set_target(gpu_target);
    _mm_reshaped_only_rhs_mmul_kernel->set_target(gpu_target);
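// Derive the GEMM shapes (m, n, k, batch) from the input tensors and auto-select the kernel type
// for this problem size and GPU target.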
    const unsigned int n          = b->dimension(0);
    const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const auto reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d);
    _gemm_kernel_type = auto_select_gemm_kernel(
    if (_convert_to_qasymm8)
        _qasymm8_weights = *b;
    ITensorInfo *matrix_b = _convert_to_qasymm8 ? &_qasymm8_weights : b;
        std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped_only_rhs(
            depth_output_gemm3d, a, _convert_to_qasymm8 ? &_qasymm8_weights : b, output);
        _mtx_b_reshape_kernel->configure(compile_context, _convert_to_qasymm8 ? &_qasymm8_weights : b, &_tmp_b,
        std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped_only_rhs_mmul(
            depth_output_gemm3d, a, _convert_to_qasymm8 ? &_qasymm8_weights : b, output);
        _mtx_b_reshape_kernel->configure(compile_context, _convert_to_qasymm8 ? &_qasymm8_weights : b, &_tmp_b,
        _mtx_b_reduction_kernel->configure(compile_context, _convert_to_qasymm8 ? &_qasymm8_weights : b,
                                           &_vector_sum_col, reduction_info);
        _mtx_a_reduction_kernel->configure(compile_context, a, &_vector_sum_row, reduction_info);
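// Fill the GEMMKernelInfo descriptor shared by the reshaped-only-RHS kernels.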
    gemm_kernel_info.m        = m;
    gemm_kernel_info.n        = n;
    gemm_kernel_info.k        = k;
    gemm_kernel_info.lhs_info = lhs_info;
    gemm_kernel_info.rhs_info = rhs_info;
    gemm_kernel_info.a_offset = _a_offset;
    gemm_kernel_info.b_offset = _b_offset;
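// Fused output stage: the matrix-multiply kernel itself is given the row/column sums, the bias
// and the per-channel output-stage multipliers/shifts.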
        if (num_filters == 1)
            _mm_reshaped_only_rhs_kernel->configure(
                compile_context, a, matrix_b, output, gemm_kernel_info, _a_offset == 0 ? nullptr : &_vector_sum_col,
                _b_offset == 0 ? nullptr : &_vector_sum_row, c != nullptr ? c : nullptr,
                &_gemm_output_stage_multipliers, &_gemm_output_stage_shifts);
            _mm_reshaped_only_rhs_mmul_kernel->configure(
                compile_context, a, matrix_b, output, gemm_kernel_info, _a_offset == 0 ? nullptr : &_vector_sum_col,
                _b_offset == 0 ? nullptr : &_vector_sum_row, c != nullptr ? c : nullptr,
                &_gemm_output_stage_multipliers, &_gemm_output_stage_shifts);
            _run_output_stage = true;
            _mm_reshaped_only_rhs_kernel->configure(compile_context, a, matrix_b, &_mm_result_s32,
            _mm_reshaped_only_rhs_mmul_kernel->configure(compile_context, a, matrix_b, &_mm_result_s32,
            std::tie(lhs_info, rhs_info) = auto_select_gemm_config_native(
                _convert_to_qasymm8 ? &_qasymm8_weights : matrix_b, reshape_info);
            _mm_native_kernel->configure(compile_context, a, matrix_b, &_mm_result_s32, lhs_info, rhs_info,
        _offset_contribution_output_stage_kernel->configure(
            compile_context, &_mm_result_s32, _a_offset == 0 ? nullptr : &_vector_sum_col,
            _b_offset == 0 ? nullptr : &_vector_sum_row, c != nullptr ? c : nullptr, output, a->dimension(0),
            _a_offset, _b_offset, gemmlowp_output_stage, &_gemm_output_stage_multipliers,
            &_gemm_output_stage_shifts);
        _run_offset_contribution = true;
            _mm_reshaped_only_rhs_kernel->configure(compile_context, a, matrix_b, output, gemm_kernel_info);
            _mm_reshaped_only_rhs_mmul_kernel->configure(compile_context, a, matrix_b, output, gemm_kernel_info);
            std::tie(lhs_info, rhs_info) = auto_select_gemm_config_native(
                _convert_to_qasymm8 ? &_qasymm8_weights : b, reshape_info);
            _mm_native_kernel->configure(compile_context, a, matrix_b, output, lhs_info, rhs_info, reshape_info);
        _offset_contribution_kernel->configure(compile_context, output, _a_offset == 0 ? nullptr : &_vector_sum_col,
                                               _b_offset == 0 ? nullptr : &_vector_sum_row, c != nullptr ? c : nullptr,
    _aux_mem[RhsQAsymm8] =
        _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary,
    if (is_gemm_reshaped(_gemm_kernel_type))
        _aux_mem[RhsQAsymm8] =
            _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
    _aux_mem[VecSumCol] =
        _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary,
    _aux_mem[VecSumRow] =
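// Static validation: mirrors the kernel-selection and configuration logic above, operating on
// ITensorInfo objects only.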
    int32_t b_offset = b->quantization_info().uniform().offset;
    const unsigned int n          = b->dimension(0);
    const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    bool reshape_matrix_b = is_gemm_reshaped(
    if (convert_to_qasymm8)
    if (reshape_matrix_b)
        matrix_b_info = &tmp_b_info;
        lhs_info = res.lhs_info;
        rhs_info = res.rhs_info;
    gemm_kernel_info.m        = m;
    gemm_kernel_info.n        = n;
    gemm_kernel_info.k        = k;
    gemm_kernel_info.lhs_info = lhs_info;
    gemm_kernel_info.rhs_info = rhs_info;
    gemm_kernel_info.a_offset = a_offset;
    gemm_kernel_info.b_offset = b_offset;
    const TensorInfo gemm_output_stage_multipliers_shifts_info(
    if (reshape_matrix_b &&
            matrix_a_info, matrix_b_info, output, gemm_kernel_info, a_offset == 0 ? nullptr : &info_vector_sum_col,
            b_offset == 0 ? nullptr : &info_vector_sum_row, c, &gemm_output_stage_multipliers_shifts_info,
            &gemm_output_stage_multipliers_shifts_info));
        if (reshape_matrix_b)
                *matrix_a_info, *matrix_b_info, reshape_info))
                matrix_a_info, matrix_b_info, &mm_result_s32_info, gemm_kernel_info));
                *matrix_a_info, *matrix_b_info, false, reshape_info))
            lhs_info = res.lhs_info;
            rhs_info = res.rhs_info;
                matrix_a_info, matrix_b_info, &mm_result_s32_info, lhs_info, rhs_info, reshape_info));
            &mm_result_s32_info, a_offset == 0 ? nullptr : &info_vector_sum_col,
            b_offset == 0 ? nullptr : &info_vector_sum_row, c, output, a_offset, b_offset, gemmlowp_output_stage,
            &gemm_output_stage_multipliers_shifts_info, &gemm_output_stage_multipliers_shifts_info));
        if (reshape_matrix_b)
                matrix_a_info, matrix_b_info, output, gemm_kernel_info));
            lhs_info = res.lhs_info;
            rhs_info = res.rhs_info;
                matrix_a_info, matrix_b_info, output, lhs_info, rhs_info, reshape_info));
            output, a_offset == 0 ? nullptr : &info_vector_sum_col, b_offset == 0 ? nullptr : &info_vector_sum_row,
            c, a_offset, b_offset));
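// run(): select the (possibly converted and/or reshaped) RHS tensor, then dispatch the kernels
// that were set up in configure().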
    const ITensor *matrix_b = _convert_to_qasymm8 ? rhs_qasymm8.get() : b;
    if (is_gemm_reshaped(_gemm_kernel_type))
        matrix_b = tmp_b.get();
        if (!_reshape_b_only_on_first_run)
    if (_a_offset != 0 && !_reshape_b_only_on_first_run)
    if (is_gemm_reshaped(_gemm_kernel_type))
    if (_run_offset_contribution)
    if (_run_output_stage)
    if (_run_offset_contribution)
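// prepare(): one-off work for a constant B (weight conversion/reshape and the B column-sum
// reduction), done on the first run when _reshape_b_only_on_first_run is set.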
    if (_convert_to_qasymm8)
    if (is_gemm_reshaped(_gemm_kernel_type) && _reshape_b_only_on_first_run)
    if (_a_offset != 0 && _reshape_b_only_on_first_run)
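// Copy the per-channel output-stage multipliers (and, likewise, the shifts) into their tensors
// when present: num_filters int32 values each.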
    if (multiplier_tensor != nullptr && multiplier_tensor->info()->total_size() > 0)
               num_filters * sizeof(int32_t));