template <typename TypeInput, typename TypeOutput>
// ... (signature of run_parallel_pretranspose_B_array; earlier parameters elided)
                                       unsigned int num_threads)
{
    // ...
    std::vector<IScheduler::Workload> workloads(num_threads);
    for (unsigned int t = 0; t < num_threads; ++t)
    {
        workloads[t] = [=](const ThreadInfo &info)
        {
            const unsigned int start = (info.thread_id * wsize) / num_threads;
            const unsigned int end   = ((info.thread_id + 1) * wsize) / num_threads;
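            // Each workload owns the half-open slice [start, end) of the pretranspose window
            // wsize (defined in the elided lines above), so the B pretranspose is split evenly
            // across num_threads threads; e.g. wsize = 100 and num_threads = 4 gives thread 1
            // the range [25, 50).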
    // Call operator of the malloc/free deleter used with the std::unique_ptr members below
    // (_indirect_buf, _indirect_arg).
    void operator()(void *x)
Params extract_parameters(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info)
{
    // ...
    // GEMM dimensions: M and N come from the output, K from the input.
    p.M = d->tensor_shape().y();
    p.K = a->tensor_shape().x();
    p.N = d->tensor_shape().x();
    // ...
    // Indirect/convolution case: one section per kernel position.
    p.sections = b->tensor_shape()[2] * b->tensor_shape()[3];
    // ...
    // Plain GEMM case: multis and batches come from the B and D shapes.
    p.multis  = b->tensor_shape().z();
    p.batches = d->tensor_shape().total_size_upper(2) / p.multis;
    // ...
    // Update M and the batch count when the output is reinterpreted as 3D.
    if (info.depth_output_gemm3d != 0)
    {
        p.M       = d->tensor_shape().y() * d->tensor_shape().z();
        p.batches = d->tensor_shape().total_size_upper(3) / p.multis;
    }
    // ...
    return p;
}
// Tail of scheduling_hint_heuristic(), used by Fallback::run() below to decide how the assembly
// kernel window is split across threads.
    const int         granule_threshold = 200;
    IScheduler::Hints scheduling_hint   = IScheduler::Hints(Window::DimX);
    // ... (method- and data-type-specific refinements of the hint elided)
    return scheduling_hint;
}
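// Fallback is the templated implementation behind CpuGemmAssemblyDispatch::IFallback: it owns
// the arm_gemm kernel, the wrapping Neon kernel and the auxiliary tensors (working space,
// pre-pretransposed B, pretransposed B, indirect buffers) needed to run it.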
template <typename TypeInput, typename TypeOutput, class OutputStage = arm_gemm::Nothing>
class Fallback : public CpuGemmAssemblyDispatch::IFallback
{
public:
    // ...
    ~Fallback() = default;

    // ...
    void configure(const ITensorInfo *a,
                   const ITensorInfo *b,
                   const ITensorInfo *c,
                   // ... (d and the arm_gemm::GemmArgs)
                   const AsmGemmInfo &gemm_info,
                   const OutputStage &os = {});

    // ...
    std::tuple<bool, const int32_t *, const int32_t *, const int32_t *>
    set_requantize_data(const std::vector<int32_t> &shifts, const std::vector<int32_t> &multipliers);

    void run(ITensorPack &tensors) override;
    void prepare(ITensorPack &tensors) override;
    bool is_configured() const override;

    bool isVarWeightsKernel() const override
    {
        if (!_gemm_kernel_asm)
            return false;
        // ... (otherwise report whether the configured arm_gemm kernel uses a variable weight format)
    }

private:
    enum AuxTensorIdx
    {
        AsmGemmWorkspace = 0,
        // ... (PrePretransposedB and Pretranspose slots follow; see the offset_int_vec() uses below)
    };

    // ...
    void configure_indirect(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info);
    void prepare_indirect_buffer(ITensorPack &tensors);

    // Operator that transposes B before the assembly kernel's own pretranspose.
    std::unique_ptr<CpuTranspose> _pre_pretranspose_b{nullptr};
    // Assembly GEMM kernel and the Neon kernel wrapping it.
    std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{nullptr};
    std::unique_ptr<INEKernel>                                   _optimised_kernel{nullptr};
    TensorInfo                                                   _workspace_info{};
    TensorInfo                                                   _pre_pretransposed_b_info{};
    TensorInfo                                                   _pretranspose_info{};
    bool                                                         _is_prepared{false};
    AsmGemmInfo                                                  _gemm_info{};
    // Per-channel requantization data handed to arm_gemm (see set_requantize_data()).
    std::vector<int32_t> _shifts{};
    std::vector<int32_t> right_shifts{};
    std::vector<int32_t> left_shifts{};
    std::vector<int32_t> _multipliers{};
    // Indirect-convolution pointer tables (malloc'd, hence the free_delete deleter).
    std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
    std::unique_ptr<const TypeInput *, free_delete>         _indirect_buf{};
    std::vector<TypeInput>                                  _indirect_pad{};
    bool                                                    _B_pretranspose_required{false};
    bool                                                    _is_b_constant{true};
    bool                                                    _is_c_constant{true};
};
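// set_requantize_data(): converts the signed per-channel shifts coming from the output stage
// into the separate left/right shift arrays arm_gemm expects. For a shift s, the left shift is
// max(-s, 0) and the right shift is min(-s, 0); e.g. s = -2 -> left 2, right 0, while s = 3 ->
// left 0, right -3. The returned bool reports whether any left shift is actually needed.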
template <typename TypeInput, typename TypeOutput, class OutputStage>
std::tuple<bool, const int32_t *, const int32_t *, const int32_t *>
Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts,
                                                                  const std::vector<int32_t> &multipliers)
{
    _multipliers   = multipliers;
    _shifts        = shifts;
    bool need_left = false;
    for (const auto s : _shifts)
    {
        left_shifts.push_back(std::max(-s, int32_t(0)));
        right_shifts.push_back(std::min(-s, int32_t(0)));
        if (s < 0 && !need_left)
        {
            need_left = true;
        }
    }
    return std::make_tuple(need_left, left_shifts.data(), right_shifts.data(), _multipliers.data());
}
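// prepare_indirect_buffer(): fills the indirection table used by the indirect convolution
// method. For every (multi, batch, kernel position, output position) it stores a pointer to the
// corresponding input row, or to _indirect_pad when the position falls into the padding region,
// so the assembly kernel can gather its inputs without an explicit im2col.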
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors)
{
    // ... (fetch the input tensor 'a' from the pack)
    const TypeInput *A_ptr = reinterpret_cast<TypeInput *>(a->buffer());
    // ...
    const int    batches        = a->info()->tensor_shape().total_size_upper(3);
    const size_t stride_A       = a->info()->strides_in_bytes().y() / sizeof(TypeInput);
    const size_t batch_stride_A = a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
    const size_t multi_stride_A = a->info()->strides_in_bytes()[4] / sizeof(TypeInput);

    // ... (output_hw, batch_size and the multi count are derived from _cp in the elided lines)
    const size_t batch_stride = batch_size / sizeof(TypeInput);
    const int    multi_size   = batch_size * batches;
    const size_t multi_stride = multi_size / sizeof(TypeInput);

    for (int64_t m = 0; m < multis; m++)
    {
        for (int64_t b = 0; b < batches; b++)
        {
            for (int64_t output_y = 0; output_y < _cp.output_height; output_y++)
            {
                for (int64_t output_x = 0; output_x < _cp.output_width; output_x++)
                {
                    int64_t output_xy = (output_y * _cp.output_width) + output_x;

                    for (int64_t kernel_y = 0; kernel_y < _cp.kernel_height; kernel_y++)
                    {
                        for (int64_t kernel_x = 0; kernel_x < _cp.kernel_width; kernel_x++)
                        {
                            // ... (input_x / input_y derived from the output position, stride and padding)
                            int64_t kernel_xy = (kernel_y * _cp.kernel_width) + kernel_x;
                            int64_t input_xy  = (input_y * _cp.input_width) + input_x;

                            if (input_x < 0 || input_x >= _cp.input_width || input_y < 0 || input_y >= _cp.input_height)
                            {
                                // Out-of-bounds positions point at the zero/offset padding buffer.
                                _indirect_buf
                                    .get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
                                    _indirect_pad.data();
                            }
                            else
                            {
                                _indirect_buf
                                    .get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
                                    A_ptr + (m * multi_stride_A + b * batch_stride_A + input_xy * stride_A);
                            }
                        }
                    }
                }
            }
        }
    }
}
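// configure_indirect(): records the convolution geometry in _cp, allocates the indirection
// buffers (_indirect_buf, _indirect_arg, _indirect_pad) and registers them with the assembly
// kernel via set_indirect_parameters(). prepare_indirect_buffer() later fills _indirect_buf
// with the actual input pointers.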
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a,
                                                                      const ITensorInfo *b,
                                                                      const ITensorInfo *d,
                                                                      const AsmGemmInfo &info)
{
    // The padding value: zero, or the input zero point for quantized inputs.
    TypeInput zeropad = TypeInput(0);
    if (is_data_type_quantized(a->data_type()))
    {
        zeropad = a->quantization_info().uniform().offset;
    }

    const int64_t input_width    = static_cast<int64_t>(a->tensor_shape()[1]);
    const int64_t input_height   = static_cast<int64_t>(a->tensor_shape()[2]);
    const int64_t input_channels = static_cast<int64_t>(a->tensor_shape()[0]);
    const int64_t kernel_width   = static_cast<int64_t>(b->tensor_shape()[2]);
    const int64_t kernel_height  = static_cast<int64_t>(b->tensor_shape()[3]);
    const int64_t output_width   = static_cast<int64_t>(d->tensor_shape()[1]);
    const int64_t output_height  = static_cast<int64_t>(d->tensor_shape()[2]);

    // The geometry above, together with the stride from info.ps_info, is packed into _cp (the
    // arm_gemm convolution parameters); the initialiser is elided except for the stride entries:
    //         info.ps_info.stride().first,
    //         info.ps_info.stride().second,
    // ...
    _gemm_kernel_asm->set_convolution_parameters(_cp);
    // ...
    const unsigned int multis    = 1;
    const unsigned int batches   = a->tensor_shape().total_size_upper(3);
    const unsigned int kernel_hw = _cp.kernel_width * _cp.kernel_height;
    const unsigned int output_hw = _cp.output_width * _cp.output_height;

    using TypeInputPtr        = TypeInput *;
    const int    batch_size   = kernel_hw * output_hw * sizeof(TypeInputPtr);
    const size_t batch_stride = batch_size / sizeof(TypeInputPtr);
    const int    multi_size   = batch_size * batches;
    const size_t multi_stride = multi_size / sizeof(TypeInputPtr);

    _indirect_buf = std::unique_ptr<const TypeInput *, free_delete>(
        reinterpret_cast<const TypeInput **>(malloc(multi_size * multis)));
    _indirect_arg = std::unique_ptr<const TypeInput *const *, free_delete>(
        reinterpret_cast<const TypeInput *const **>(malloc(sizeof(TypeInput **) * kernel_hw * multis * batches)));
    _indirect_pad = std::vector<TypeInput>(_cp.input_channels, TypeInput(zeropad));

    // Set the indirect argument table: one pointer per (multi, batch, kernel position) into
    // _indirect_buf.
    size_t pos = 0;
    for (int64_t m = 0; m < multis; m++)
    {
        for (int64_t b = 0; b < batches; b++)
        {
            for (int64_t kernel_xy = 0; kernel_xy < kernel_hw; kernel_xy++)
            {
                (_indirect_arg.get())[pos++] =
                    _indirect_buf.get() + m * multi_stride + b * batch_stride + kernel_xy * output_hw;
            }
        }
    }

    _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.get());
}
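// Fallback::configure(): creates the arm_gemm kernel for the given GemmArgs/output stage, wraps
// it in a CpuGemmAssemblyWrapperKernel, and declares the auxiliary memory it needs (working
// space, optional pre-pretransposed B, optional pretransposed B). For the convolution methods it
// also sets up the indirection buffers via configure_indirect().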
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a,
                                                             const ITensorInfo *b,
                                                             const ITensorInfo *c,
                                                             // ... (d and the arm_gemm::GemmArgs args)
                                                             const AsmGemmInfo &gemm_info,
                                                             const OutputStage &os)
{
    // ...
    _is_b_constant = b->are_values_constant();
    _is_c_constant = c ? c->are_values_constant() : true;

    _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput, OutputStage>(args, os);
    if (_gemm_kernel_asm == nullptr)
    {
        // Configuration not supported: leave the Fallback unconfigured.
        return;
    }

    // ... (gemm_cfg: the configuration reported by the assembly kernel)
    auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeOutput>>();
    // ...
    acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter);
    const size_t       workspace_size = _gemm_kernel_asm->get_working_size();
    const unsigned int alignment      = 4096;
    _workspace_info                   = TensorInfo(TensorShape(workspace_size), 1, DataType::U8);
    _aux_mem[AsmGemmWorkspace] =
        // ... (MemoryInfo: temporary workspace of workspace_size bytes with 4096-byte alignment)

    // ...
    const unsigned int window_size = _gemm_kernel_asm->get_window_size().total_size();
    if (window_size < static_cast<unsigned int>(args._maxthreads))
    {
        // The window is too small to keep all requested threads busy.
        _gemm_kernel_asm->set_nthreads(window_size);
    }

    _optimised_kernel = std::move(acl_gemm_wrapper);
    _gemm_info        = gemm_info;

    // Optionally transpose B before the assembly kernel's own pretranspose.
    const bool run_pre_pretranspose_b = _gemm_info.transpose_b && !isVarWeightsKernel();
    if (run_pre_pretranspose_b)
    {
        _pre_pretranspose_b = std::make_unique<CpuTranspose>();
        _pre_pretranspose_b->configure(b, &_pre_pretransposed_b_info);

        // Pick the MemoryLifetime of the pre-pretransposed B buffer (selection logic elided):
        if (_gemm_kernel_asm->B_pretranspose_required())
        // ...
            lifetime = MemoryLifetime::Persistent;
        // ...
            lifetime = MemoryLifetime::Temporary;
        // ...
        const unsigned int alignment = 128;
        _aux_mem[PrePretransposedB] =
            // ... (MemoryInfo for the pre-pretransposed B buffer)
    }

    // Check for pre-transpose support by the assembly kernel.
    if (_gemm_kernel_asm->B_pretranspose_required())
    {
        // ...
        const unsigned int alignment           = 128;
        const size_t       B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
        _pretranspose_info                     = TensorInfo(TensorShape(B_pretranspose_size), 1, DataType::U8);
        _aux_mem[Pretranspose] =
            // ... (MemoryInfo for the pretransposed B buffer)
        _B_pretranspose_required = true;
    }

    // Handle indirect GEMM convolution (guard on gemm_info.method elided):
        configure_indirect(a, b, d, gemm_info);
}
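// Fallback::prepare(): one-shot preparation before the first execution. It loads the quantized
// bias into the assembly kernel, runs the optional pre-pretranspose of B, performs the kernel's
// own B pretranspose into the Pretranspose auxiliary buffer, and builds the indirect buffer for
// the convolution methods.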
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
{
    // ... (early-out when already prepared; fetch b and c from the tensor pack)

    // ... (when a matrix C with an S32 bias is provided, hand it to the assembly kernel:)
        _gemm_kernel_asm->set_quantized_bias(
            reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);

    const ITensor *b_to_use = b;

    // Pre-pretranspose B if required.
    const bool          run_pre_pretranspose_b = _gemm_info.transpose_b && !isVarWeightsKernel();
    CpuAuxTensorHandler pre_pretransposed_b(
        offset_int_vec(PrePretransposedB), _pre_pretransposed_b_info, tensors,
        // ...
        !run_pre_pretranspose_b);
    if (run_pre_pretranspose_b)
    {
        // ...
        ITensorPack pre_pretranspose_pack{{ACL_SRC, b_to_use}, {ACL_DST, pre_pretransposed_b.get()}};
        _pre_pretranspose_b->run(pre_pretranspose_pack);
        b_to_use = pre_pretransposed_b.get();
    }

    // Pretranspose B if required by the assembly kernel.
    if (_gemm_kernel_asm->B_pretranspose_required())
    {
        // ...
        const int  ldb     = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
        const auto in1_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() +
                                                                 b_to_use->info()->offset_first_element_in_bytes());
        const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();

        CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
        // ...
        run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(_gemm_kernel_asm.get(), pretranspose.get(),
                                                                 in1_ptr, ldb, multi_stride_b,
                                                                 // ... (thread count)
    }

    // ... (for the indirect convolution method, also build the indirection table:)
        prepare_indirect_buffer(tensors);

    _is_prepared = true;
}
template <typename TypeInput, typename TypeOutput, class OutputStage>
bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const
{
    return _optimised_kernel != nullptr;
}
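// Fallback::run(): computes the leading dimensions and batch/multi strides of A, B and D from
// their tensor infos, re-runs the B transformations when B (or an S32 bias) is not constant,
// sizes the thread count to the kernel window, then hands all pointers to the assembly kernel
// via set_arrays() before the wrapped kernel is finally scheduled (elided here).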
template <typename TypeInput, typename TypeOutput, class OutputStage>
// ... (member definition elided)

template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
{
    // ... (fetch a, b, c, d from the tensor pack)

    int       lda = a->info()->strides_in_bytes().y() / a->info()->element_size();
    int       ldb = 0;
    const int ldd = d->info()->strides_in_bytes().y() / d->info()->element_size();

    const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d != 0 ? 3 : 2;
    const size_t a_multi_idx = a_batch_idx + 1;
    const size_t d_batch_idx = _gemm_info.depth_output_gemm3d != 0 ? 3 : 2;
    const size_t d_multi_idx = d_batch_idx + 1;

    int       batch_stride_a = a->info()->strides_in_bytes()[a_batch_idx] / a->info()->element_size();
    const int batch_stride_d = d->info()->strides_in_bytes()[d_batch_idx] / d->info()->element_size();

    int       multi_stride_a = a->info()->strides_in_bytes()[a_multi_idx] / a->info()->element_size();
    int       multi_stride_b = 0;
    const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / d->info()->element_size();

    auto in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes());
    const TypeInput *in1_ptr = nullptr;
    auto out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes());

    const ITensor *b_to_use = b;

    // Re-run the pre-pretranspose of B when B can change between runs.
    const bool          run_pre_pretranspose_b = _gemm_info.transpose_b && !isVarWeightsKernel();
    CpuAuxTensorHandler pre_pretransposed_b(
        offset_int_vec(PrePretransposedB), _pre_pretransposed_b_info, tensors,
        // ...
        !run_pre_pretranspose_b);
    if (b_to_use && !_is_b_constant && run_pre_pretranspose_b)
    {
        // ...
        ITensorPack pre_pretranspose_pack{{ACL_SRC, b_to_use}, {ACL_DST, pre_pretransposed_b.get()}};
        _pre_pretranspose_b->run(pre_pretranspose_pack);
        b_to_use = pre_pretransposed_b.get();
    }

    // Check if B is pre-transposed by the assembly kernel; if not, dereference it directly.
    if (b_to_use && !_gemm_kernel_asm->B_is_pretransposed())
    {
        ldb            = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
        multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
        in1_ptr =
            reinterpret_cast<const TypeInput *>(b_to_use->buffer() + b_to_use->info()->offset_first_element_in_bytes());
    }

    // If B or the S32 bias are not constant, their assembly-side transformations must be redone.
    if ((b_to_use && !_is_b_constant) || (c && !_is_c_constant && c->info()->data_type() == DataType::S32))
    {
        // ...
        _gemm_kernel_asm->set_quantized_bias(
            reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);
        // ...
    }
    if (b_to_use && _B_pretranspose_required)
    {
        // ...
        const int  ldb   = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
        const auto b_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() +
                                                               b_to_use->info()->offset_first_element_in_bytes());
        const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();

        CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true);
        // ...
        _gemm_kernel_asm->requantize_bias(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b);
        // ...
        run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(_gemm_kernel_asm.get(), pretranspose.get(),
                                                                 b_ptr, ldb, multi_stride_b,
                                                                 // ... (thread count)
    }

    const auto scheduling_hint = scheduling_hint_heuristic(_kernel_info.method, d->info()->data_type());

    // Set the working space if needed and clamp the thread count to the available work.
    CpuAuxTensorHandler workspace(offset_int_vec(AsmGemmWorkspace), _workspace_info, tensors, false);
    if (workspace.get()->buffer() != nullptr)
    {
        _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(workspace.get()->buffer()));
        const unsigned int split_dim   = scheduling_hint.split_dimension();
        const unsigned int window_size = _gemm_kernel_asm->get_window_size().total_size();
        // ... (num_threads initialised from the scheduler)
        if (window_size < num_threads)
        {
            num_threads = window_size;
        }
        // ...
        const unsigned int num_iterations = _optimised_kernel.get()->window().num_iterations(split_dim);
        num_threads                       = std::min(num_iterations, num_threads);
        // ...
        _gemm_kernel_asm->set_nthreads(num_threads);
    }

    // ...
    TypeOutput *bias = nullptr;
    // ... (when c is present and does not hold an S32 quantized bias:)
        bias = reinterpret_cast<TypeOutput *>(c->buffer() + c->info()->offset_first_element_in_bytes());
    // ...
    _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr,
                                 ldd, batch_stride_d, multi_stride_d, bias, 0);
    // ... (finally the wrapped kernel is scheduled using scheduling_hint)
}
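// create_arm_gemm(): helper that builds the arm_gemm::GemmArgs from the tensor shapes and the
// AsmGemmInfo, instantiates the matching Fallback<TypeInput, TypeOutput> and hands it back to
// the dispatcher through the IFallback pointer.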
template <typename TypeInput, typename TypeOutput>
void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
                     const ITensorInfo                                   *a,
                     const ITensorInfo                                   *b,
                     const ITensorInfo                                   *c,
                     // ... (d and the arm_gemm activation)
                     const AsmGemmInfo                                   &info)
{
    Params p = extract_parameters(a, b, d, info);
    // ... (ci, num_threads and the GemmConfig cfg are set up in the elided lines)
    arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads,
                            info.fixed_format, info.fast_mode, &cfg);

    // Create the arm_gemm fallback.
    auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
    fallback->configure(a, b, c, d, args, info);
    arm_gemm = std::move(fallback);
}
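// create_arm_gemm_quant(): quantized variant of create_arm_gemm(). It additionally derives the
// A/B zero-point offsets (negated or not according to info.negated_offsets) and the requantize
// parameters, and configures the Fallback with an arm_gemm::Requantize32 output stage. With
// negated_offsets == true and a uniform offset of 10, for example, a_offset becomes -10.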
template <typename TypeInput, typename TypeOutput>
void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
                           const ITensorInfo                                   *a,
                           const ITensorInfo                                   *b,
                           const ITensorInfo                                   *c,
                           // ... (d and the arm_gemm activation)
                           const AsmGemmInfo                                   &info)
{
    // ...
    Params p = extract_parameters(a, b, d, info);
    // ... (ci, num_threads and the GemmConfig cfg are set up in the elided lines)
    arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads,
                            info.fixed_format, info.fast_mode, &cfg);

    // Create the arm_gemm fallback with a Requantize32 output stage.
    auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();

    // Configure requantization info.
    const int32_t negation = info.negated_offsets ? 1 : -1;
    const int32_t a_offset = -a->quantization_info().uniform().offset * negation;
    const int32_t b_offset = -b->quantization_info().uniform().offset * negation;

    const GEMMLowpOutputStageInfo os_info = info.output_stage;
    arm_gemm::Requantize32        gemm_requant_info{};
    if (os_info.gemmlowp_shifts.size() > 1)
    {
        // Per-channel quantization.
        const auto requantize_data =
            fallback->set_requantize_data(os_info.gemmlowp_shifts, os_info.gemmlowp_multipliers);
        gemm_requant_info = arm_gemm::Requantize32(
            nullptr, 0, a_offset, b_offset, os_info.gemmlowp_offset,
            (std::get<0>(requantize_data)) ? std::get<1>(requantize_data) : nullptr, std::get<2>(requantize_data),
            std::get<3>(requantize_data), os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);
    }
    else
    {
        // Per-tensor quantization.
        gemm_requant_info = arm_gemm::Requantize32(
            // ... (leading arguments elided)
            os_info.gemmlowp_multiplier, os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);
    }

    // Configure the fallback with the requantization output stage.
    fallback->configure(a, b, c, d, args, info, gemm_requant_info);
    arm_gemm = std::move(fallback);
}
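// CpuGemmAssemblyDispatch validation: before configuring, the dispatcher asks arm_gemm whether
// an optimized kernel exists for the requested type combination and weight format
// (arm_gemm::has_opt_gemm), and additionally enforces the supported input/output data-type
// pairings listed further below.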
    // ...
    Params p = extract_parameters(a, b, d, info);
    // ... (ci, act, num_threads and the GemmConfig cfg are set up in the elided lines)
    arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads,
                            info.fixed_format, info.fast_mode, &cfg);

    // One check per supported type combination (the surrounding switch on the data type is elided):
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(
        !(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
        "We could not find an optimized kernel for F32 input");
    // ...
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(
        !(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
        "We could not find an optimized kernel for U8/QASYMM8 input and U32 output");
    // ...
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(
        !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
        "We could not find an optimized kernel for U8 input and U8 output");
    // ...
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(
        !(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
        "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
    // ...
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(
        !(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
        "We could not find an optimized kernel for S8 input and S8 output");
    // ...
#if defined(ARM_COMPUTE_ENABLE_BF16)
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(
        !(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
        "We could not find an optimized kernel for BFLOAT16 input and F32 output");
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
    // ...
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(
        !(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
        "We could not find an optimized kernel for F16 input and F16 output");
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
919 "Assembly kernel will not be executed when reshape_b_only_on_first_run is false");
944 "Only F32 output supported for F32 input");
946 "Only F16 output supported for F16 input");
948 "Only F32 output supported for BFLOAT16 input");
950 "Only U32 output supported for U8 input");
952 "Only S32 output supported for S8 input");
955 "Only QASYMM8/S32 output supported for QASYMM8 input");
964 (expected_weight_format !=
info.weight_format),
965 "The format expected by the kernel does not correspond with the one requested by the user.");
    // ... (switch on a->data_type(), case labels elided)
            create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info);
    // ...
            create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info);
    // ...
            create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
    // ...
            create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info);
    // ...
            create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info);
    // ...
#if defined(ARM_COMPUTE_ENABLE_BF16)
            create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
    // ...
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
            create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info);
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
    // ...

void CpuGemmAssemblyDispatch::prepare(ITensorPack &tensors)
{
    // ...
    _arm_gemm->prepare(tensors);
}

bool CpuGemmAssemblyDispatch::is_configured() const
{
    return _arm_gemm && _arm_gemm->is_configured();
}

void CpuGemmAssemblyDispatch::run(ITensorPack &tensors)
{
    // ...
    _arm_gemm->run(tensors);
}

experimental::MemoryRequirements CpuGemmAssemblyDispatch::workspace() const
{
    return _arm_gemm->workspace();
}