72 if (!constant_weights)
78 if (
bool(gemm_kernel))
80 if (validate_gemm_kernel(gemm_kernel.gemm_type))
83 to_string(gemm_kernel.gemm_type).c_str());
84 return gemm_kernel.gemm_type;
89 to_string(gemm_kernel.gemm_type).c_str());
90 return gemm_kernel.gemm_type;
93 inline bool validate_lhs_rhs_info_reshaped_only_rhs(
const GEMMLHSMatrixInfo &lhs_info,
94 const GEMMRHSMatrixInfo &rhs_info,
98 const ITensorInfo *output,
99 GEMMKernelInfo gemm_kernel_info)
102 TensorInfo tmp_b_info{};
110 gemm_kernel_info.lhs_info = lhs_info;
111 gemm_kernel_info.rhs_info = rhs_info;
112 gemm_kernel_info.has_pad_y =
false;
114 rhs_info, gemm_kernel_info)))
118 gemm_kernel_info.has_pad_y =
true;
120 rhs_info, gemm_kernel_info)))
128 inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
130 GEMMKernelInfo kernel_info,
131 const ITensorInfo *a,
132 const ITensorInfo *
b,
133 const ITensorInfo *c,
134 const ITensorInfo *output)
139 if (validate_lhs_rhs_info_reshaped_only_rhs(config.lhs_info, config.rhs_info, a,
b, c, output, kernel_info))
142 "Use reshaped_only_rhs config from mlgo heuristics: LHS info: %s ; RHS info: %s ",
144 return {config.lhs_info, config.rhs_info};
149 "Use reshaped_only_rhs config from default heuristics: LHS info: %s ; RHS info: %s ",
151 return {config.lhs_info, config.rhs_info};
155 inline bool validate_lhs_rhs_info_reshaped(
const GEMMLHSMatrixInfo &lhs_info,
156 const GEMMRHSMatrixInfo &rhs_info,
157 const ITensorInfo *a,
158 const ITensorInfo *
b,
159 const ITensorInfo *c,
160 const ITensorInfo *output,
161 GEMMKernelInfo gemm_kernel_info,
162 bool reinterpret_input_as_3d)
165 TensorInfo tmp_a_info{};
166 TensorInfo tmp_b_info{};
183 gemm_kernel_info.lhs_info = lhs_info;
184 gemm_kernel_info.rhs_info = rhs_info;
186 rhs_info, gemm_kernel_info)))
194 inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
196 GEMMKernelInfo kernel_info,
197 const ITensorInfo *a,
198 const ITensorInfo *
b,
199 const ITensorInfo *c,
200 const ITensorInfo *output,
201 bool reinterpret_input_as_3d)
206 if (validate_lhs_rhs_info_reshaped(config.lhs_info, config.rhs_info, a,
b, c, output, kernel_info,
207 reinterpret_input_as_3d))
210 "Use reshaped config from mlgo heuristics: LHS info: %s ; RHS info: %s ",
212 return {config.lhs_info, config.rhs_info};
217 "Use reshaped config from default heuristics: LHS info: %s ; RHS info: %s ",
to_string(config.lhs_info).c_str(),
219 return {config.lhs_info, config.rhs_info};
232 _reshape_b_only_on_first_run(false),
235 _aux_mem(AuxTensorIdx::Count)
251 const unsigned int n =
b->dimension(0);
253 const unsigned int batch_size = reinterpret_input_as_3d ? a->
dimension(3) : a->
dimension(2);
268 _mm_native_kernel->set_target(gpu_target);
274 _mm_native_kernel->configure(compile_context, a,
b, c, output, alpha, beta, config.lhs_info, config.rhs_info,
278 void ClGemm::configure_reshaped(
const CLCompileContext &compile_context,
285 const GEMMInfo &gemm_info)
288 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
289 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
290 const unsigned int n =
b->dimension(0);
291 const unsigned int k = a->dimension(0);
292 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
293 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
295 bool broadcast_bias = gemm_info.broadcast_bias();
297 GEMMKernelInfo kernel_info;
301 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
302 kernel_info.reinterpret_input_as_3d =
false;
303 kernel_info.broadcast_bias = broadcast_bias;
304 kernel_info.activation_info = gemm_info.activation_info();
307 _reshape_lhs_kernel->set_target(gpu_target);
308 _mm_reshaped_kernel->set_target(gpu_target);
310 GEMMLHSMatrixInfo lhs_info{};
311 GEMMRHSMatrixInfo rhs_info{};
314 std::tie(lhs_info, rhs_info) =
316 kernel_info, a,
b, c, output, gemm_info.reinterpret_input_as_3d());
318 _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
319 _reshape_rhs_kernel->configure(compile_context,
b, &_tmp_b, rhs_info);
322 _mm_reshaped_kernel->configure(compile_context, &_tmp_a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info,
329 _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.
total_size());
332 void ClGemm::configure_reshaped_only_rhs(
const CLCompileContext &compile_context,
339 const GEMMInfo &gemm_info)
342 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
343 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
344 const unsigned int n =
b->dimension(0);
345 const unsigned int k = a->dimension(0);
346 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
347 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
349 bool broadcast_bias = gemm_info.broadcast_bias();
351 GEMMKernelInfo kernel_info;
355 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
356 kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
357 kernel_info.broadcast_bias = broadcast_bias;
358 kernel_info.activation_info = gemm_info.activation_info();
361 _mm_reshaped_only_rhs_kernel->set_target(gpu_target);
363 GEMMLHSMatrixInfo lhs_info{};
364 GEMMRHSMatrixInfo rhs_info{};
367 std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped_only_rhs(
368 auto_heuristics::CommonQuery{gpu_target,
data_type, m, n, k, batch_size}, kernel_info, a,
b, c, output);
371 _reshape_rhs_kernel->configure(compile_context,
b, &_tmp_b, rhs_info);
378 kernel_info.has_pad_y =
false;
379 _mm_reshaped_only_rhs_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info,
385 _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.
total_size());
388 void ClGemm::configure_reshaped_only_rhs_mmul(
const CLCompileContext &compile_context,
395 const GEMMInfo &gemm_info)
398 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
399 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
400 const unsigned int n =
b->dimension(0);
401 const unsigned int k = a->dimension(0);
402 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
403 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
405 bool broadcast_bias = gemm_info.broadcast_bias();
407 GEMMKernelInfo kernel_info;
411 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
412 kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
413 kernel_info.broadcast_bias = broadcast_bias;
414 kernel_info.activation_info = gemm_info.activation_info();
417 _mm_reshaped_only_rhs_mmul_kernel->set_target(gpu_target);
419 GEMMLHSMatrixInfo lhs_info{};
420 GEMMRHSMatrixInfo rhs_info{};
426 rhs_info = gemm_config.rhs_info;
431 _reshape_rhs_kernel->configure(compile_context,
b, &_tmp_b, rhs_info);
434 kernel_info.has_pad_y =
false;
435 _mm_reshaped_only_rhs_mmul_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info,
436 rhs_info, kernel_info);
441 _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.
total_size());
444 Status ClGemm::validate_native(
const ITensorInfo *a,
445 const ITensorInfo *
b,
446 const ITensorInfo *c,
447 const ITensorInfo *output,
450 const GEMMInfo &gemm_info)
458 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
459 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
460 const unsigned int n =
b->dimension(0);
461 const unsigned int k = a->dimension(0);
462 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
463 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
464 const bool broadcast_bias = gemm_info.broadcast_bias();
466 GEMMKernelInfo kernel_info;
470 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
471 kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
472 kernel_info.broadcast_bias = broadcast_bias;
473 kernel_info.activation_info = gemm_info.activation_info();
480 a,
b, c, output, alpha, beta, config.lhs_info, config.rhs_info, kernel_info));
485 Status ClGemm::validate_reshaped(
const ITensorInfo *a,
486 const ITensorInfo *
b,
487 const ITensorInfo *c,
488 const ITensorInfo *output,
491 const GEMMInfo &gemm_info)
496 TensorInfo tmp_a_info{};
497 TensorInfo tmp_b_info{};
502 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
503 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
504 const unsigned int n =
b->dimension(0);
505 const unsigned int k = a->dimension(0);
506 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
507 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
508 const bool broadcast_bias = gemm_info.broadcast_bias();
510 GEMMKernelInfo kernel_info;
514 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
515 kernel_info.reinterpret_input_as_3d =
false;
516 kernel_info.broadcast_bias = broadcast_bias;
517 kernel_info.activation_info = gemm_info.activation_info();
519 GEMMLHSMatrixInfo lhs_info;
520 GEMMRHSMatrixInfo rhs_info;
524 const auto gemm_config =
527 rhs_info = gemm_config.rhs_info;
539 beta, lhs_info, rhs_info, kernel_info));
544 Status ClGemm::validate_reshaped_only_rhs(
const ITensorInfo *a,
545 const ITensorInfo *
b,
546 const ITensorInfo *c,
547 const ITensorInfo *output,
550 const GEMMInfo &gemm_info)
555 TensorInfo tmp_b_info{};
560 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
561 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
562 const unsigned int n =
b->dimension(0);
563 const unsigned int k = a->dimension(0);
564 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
565 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
566 const bool broadcast_bias = gemm_info.broadcast_bias();
568 GEMMKernelInfo kernel_info;
572 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
573 kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
574 kernel_info.broadcast_bias = broadcast_bias;
575 kernel_info.activation_info = gemm_info.activation_info();
577 GEMMLHSMatrixInfo lhs_info;
578 GEMMRHSMatrixInfo rhs_info;
585 rhs_info = gemm_config.rhs_info;
591 kernel_info.has_pad_y =
false;
593 a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
595 kernel_info.has_pad_y =
true;
597 a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
602 Status ClGemm::validate_reshaped_only_rhs_mmul(
const ITensorInfo *a,
603 const ITensorInfo *
b,
604 const ITensorInfo *c,
605 const ITensorInfo *output,
608 const GEMMInfo &gemm_info)
612 TensorInfo tmp_b_info{};
617 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
618 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
619 const unsigned int n =
b->dimension(0);
620 const unsigned int k = a->dimension(0);
621 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
622 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
623 const bool broadcast_bias = gemm_info.broadcast_bias();
625 GEMMKernelInfo kernel_info;
629 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
630 kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
631 kernel_info.broadcast_bias = broadcast_bias;
632 kernel_info.activation_info = gemm_info.activation_info();
634 GEMMLHSMatrixInfo lhs_info;
635 GEMMRHSMatrixInfo rhs_info;
642 rhs_info = gemm_config.rhs_info;
650 kernel_info.has_pad_y =
false;
652 a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
678 const unsigned int n =
b->dimension(0);
680 const unsigned int batch_size = reinterpret_input_as_3d ? a->
dimension(3) : a->
dimension(2);
683 _gemm_kernel_type = auto_select_gemm_kernel(
685 _reshape_b_only_on_first_run,
b->are_values_constant());
691 switch (_gemm_kernel_type)
695 configure_native(compile_context, a,
b, c_to_use, output, alpha, beta, gemm_info);
700 configure_reshaped(compile_context, a,
b, c_to_use, output, alpha, beta, gemm_info);
705 configure_reshaped_only_rhs(compile_context, a,
b, c_to_use, output, alpha, beta, gemm_info);
710 configure_reshaped_only_rhs_mmul(compile_context, a,
b, c_to_use, output, alpha, beta, gemm_info);
731 const unsigned int n =
b->dimension(0);
733 const unsigned int batch_size = reinterpret_input_as_3d ? a->
dimension(3) : a->
dimension(2);
752 const ITensorInfo *c_to_use = fuse_add_c ? c :
nullptr;
754 switch (gemm_kernel_type)
774 validate_reshaped_only_rhs_mmul(a,
b, c_to_use, output, alpha, beta, gemm_info));
801 switch (_gemm_kernel_type)
814 if (!_reshape_b_only_on_first_run)
833 if (!_reshape_b_only_on_first_run)
842 const unsigned int cross_plane_pad_dst =
dst->info()->padding().top +
dst->info()->padding().bottom;
843 bool has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);
861 if (!_reshape_b_only_on_first_run)
870 const unsigned int cross_plane_pad_dst =
dst->info()->padding().top +
dst->info()->padding().bottom;
871 bool has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);
903 if ((_aux_mem[AuxTensorIdx::RhsReshape].lifetime == MemoryLifetime::Persistent) &&
904 (src1 !=
nullptr && rhs_aux !=
nullptr) && rhs_aux)