// NOTE(review): extraction artifact — the bare numeric tokens at statement starts
// ("45", "47", ...) are residue of an external line-numbering pass, not C++; the
// embedded numbering also jumps 54 -> 67 below, so this function's tail (including
// its return statement and closing brace) is missing from this fragment.
//
// Translates the high-level GEMMInfo descriptor into the lighter-weight
// cpu::AsmGemmInfo consumed by the assembly-kernel dispatch
// (CpuGemmAssemblyDispatch::configure, used further down in this file).
45 cpu::AsmGemmInfo init_assembly_metadata(
const GEMMInfo &
info)
47 cpu::AsmGemmInfo asm_info;
// Field-by-field copy of the flags the assembly backend needs: 3D input
// reinterpretation, 3D output depth, fused activation, fast-math mode, and the
// fixed-format/weight-format pair used for pre-arranged weight layouts.
49 asm_info.reinterpret_input_as_3d =
info.reinterpret_input_as_3d();
50 asm_info.depth_output_gemm3d =
info.depth_output_gemm3d();
51 asm_info.activation_info =
info.activation_info();
52 asm_info.fast_mode =
info.fast_math();
53 asm_info.fixed_format =
info.fixed_format();
54 asm_info.weight_format =
info.weight_format();
// NOTE(review): fragment of CpuGemm::configure (original lines 67-96, with gaps) —
// the function's opening, validation macros, and several statements are outside
// this view; numeric tokens at statement starts are line-numbering residue.
//
// c is folded in as a *bias* (third GEMM input) only when beta == 1, i.e. when
// adding it unscaled is exact; any other non-trivial beta is handled by a
// separate addition pass (_run_addition below).
67 const bool is_c_bias = beta == 1 && c !=
nullptr;
// Fragment of the assembly-path eligibility condition: the asm kernels are usable
// only when c is absent or beta is 0/1, and not when b is non-constant with a
// batched (z > 1) shape. The left-hand side of this expression is not visible here.
69 (c ==
nullptr || beta == 0.f || beta == 1.f) &&
70 !(!
b->are_values_constant() &&
b->tensor_shape().z() > 1);
// Cache the per-run decisions as member flags:
// - reshape b only once when its values are constant (weights),
// - vector*matrix fast path when a has fewer than 2 rows,
// - alpha scaling pass only when alpha != 1,
// - bias addition when c qualifies as a bias (beta == 1),
// - explicit beta*C addition when beta is neither 0 nor 1 and c is present.
74 _reshape_b_only_on_first_run =
b->are_values_constant();
75 _run_vector_matrix_multiplication = a->
dimension(1) < 2;
76 _run_alpha_scale = alpha != 1.f;
77 _run_bias_addition = is_c_bias;
78 _run_addition = beta != 0 && beta != 1 && c !=
nullptr;
// Assembly-dispatch path: pass c only when it is usable as a bias.
83 const ITensorInfo *c_to_use = is_c_bias ? c :
nullptr;
84 _asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
85 _asm_glue->configure(a,
b, c_to_use, d, asm_info);
// Forward the assembly glue's auxiliary-memory requirements into this
// operator's workspace table.
// NOTE(review): "Pretraspose" looks like a typo for "Pretranspose", but it is an
// enum member declared outside this fragment — fixing the spelling here alone
// would not compile; rename at the declaration site if confirmed.
88 auto asm_mem_req = _asm_glue->workspace();
89 _aux_mem[AsmGemmWorkspace] = asm_mem_req[AsmGemmWorkspace];
90 _aux_mem[Pretraspose] = asm_mem_req[Pretraspose];
// Scale the output by alpha in-place (d -> d) using a LINEAR activation
// f(x) = a*x + b with a = alpha, b = 0.
95 _alpha_scale_func = std::make_unique<cpu::CpuActivation>();
96 _alpha_scale_func->configure(d,
nullptr,
ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LINEAR, alpha, 0.f));
// NOTE(review): continuation of CpuGemm::configure (original lines 102-151, with
// gaps) — this is the non-assembly kernel path; several statements (e.g. the
// computation of m/k, tensor-info initialisation for _tmp_a/_tmp_b/_tmp_d) are
// outside this view. Numeric tokens at statement starts are line-numbering residue.
//
// When a separate bias-addition pass runs, the matrix multiply writes into a
// temporary (_tmp_d) so the bias can then be added into the real output d.
102 ITensorInfo *gemm_output_to_use = (_run_bias_addition) ? &_tmp_d : d;
104 _mm_kernel = std::make_unique<cpu::kernels::CpuGemmMatrixMultiplyKernel>();
// Vector*matrix fast path: multiply a and b directly, no interleave/transpose
// reshape (is_interleaved = false).
107 if(_run_vector_matrix_multiplication)
110 _mm_kernel->configure(a,
b, gemm_output_to_use, alpha,
false);
// General path: n is b's width; m and k are presumably computed in the missing
// lines above — TODO confirm against the full source.
115 const int n =
b->dimension(0);
// Reshape a into 4x4-interleaved form and b into transposed-1xW form so the
// matrix-multiply kernel can consume blocked operands.
119 _interleave_kernel = std::make_unique<cpu::kernels::CpuGemmInterleave4x4Kernel>();
120 _interleave_kernel->configure(a, &_tmp_a);
124 _transpose_kernel = std::make_unique<cpu::kernels::CpuGemmTranspose1xWKernel>();
125 _transpose_kernel->configure(
b, &_tmp_b);
// Multiply the reshaped operands (is_interleaved = true), passing the original
// GEMM dimensions so the kernel can undo the blocking.
129 _mm_kernel->configure(&_tmp_a, &_tmp_b, gemm_output_to_use, alpha,
true,
GEMMReshapeInfo(
m,
n,
k));
// Optional bias pass: d = _tmp_d + c (configure call for _add_bias is in the
// missing lines after 134 — TODO confirm).
132 if(_run_bias_addition)
134 _add_bias = std::make_unique<cpu::CpuAdd>();
// Optional beta pass for beta not in {0, 1}: d += beta * c via the
// matrix-addition kernel (guarded by _run_addition in the missing lines).
143 _ma_kernel = std::make_unique<cpu::kernels::CpuGemmMatrixAdditionKernel>();
144 _ma_kernel->configure(c, d, beta);
// Fused activation applied in-place on d when requested by gemm_info.
150 _activation_func = std::make_unique<cpu::CpuActivation>();
151 _activation_func->configure(d,
nullptr,
gemm_info.activation_info());
// NOTE(review): fragment of CpuGemm::validate (original lines 158-269, with large
// gaps) — it mirrors configure's decision logic without allocating anything.
// Numeric tokens at statement starts are line-numbering residue.
//
// Same bias/addition split as configure: c acts as a bias only when beta == 1;
// any other non-trivial beta requires the separate addition pass.
158 const bool is_c_bias = beta == 1 && c !=
nullptr;
159 const bool run_addition = c !=
nullptr && beta != 0 && beta != 1;
// Fixed-format weight checks: dim0_sz and block_by come from the missing lines
// above (presumably the fixed-format weight descriptor — TODO confirm).
185 const size_t input_pad_right = (dim0_sz -
b->dimension(1)) %
block_by;
// NOTE(review): division by input_pad_right — if (dim0_sz - b->dimension(1)) is
// an exact multiple of block_by this is a divide-by-zero; confirm a guard exists
// in the lines not visible here.
186 const size_t kernel_area = (dim0_sz -
b->dimension(1)) / input_pad_right;
187 ARM_COMPUTE_RETURN_ERROR_ON_MSG((dim0_sz - kernel_area * input_pad_right) !=
b->dimension(1),
"The product AB is defined only if A number of columns and B number of rows are related");
// Fragment of the same assembly-path eligibility condition validated in
// configure (left-hand side not visible here).
235 (c ==
nullptr || beta == 0.f || beta == 1.f) &&
236 !(!
b->are_values_constant() &&
b->tensor_shape().z() > 1);
// Recompute the path-selection flags locally (validate has no member state).
244 const bool run_vector_matrix_multiplication = a->
dimension(1) < 2;
246 const bool run_interleave_transpose = !run_vector_matrix_multiplication && !
b->are_values_constant();
252 const int n =
b->dimension(0);
// Reshape multipliers for the interleave/transpose layouts; fixed at 1 here.
254 int mult_transpose1xW_width = 1;
255 int mult_interleave4x4_height = 1;
// When the reshape path is taken, validation runs against the reshaped
// tensor infos (tmp_a_info/tmp_b_info, initialised in missing lines).
266 if(run_interleave_transpose)
268 matrix_a_info = &tmp_a_info;
269 matrix_b_info = &tmp_b_info;
// NOTE(review): fragment of CpuGemm::run (original lines 315-373, with gaps) —
// tensor-pack construction (asm_pack, pack, mm_pack) happens in the missing
// lines. Numeric tokens at statement starts are line-numbering residue.
//
// Assembly path: delegate the whole GEMM to the dispatcher, then apply the
// optional alpha-scale pass configured earlier.
315 if(_asm_glue && _asm_glue->is_configured())
320 _asm_glue->run(asm_pack);
324 _alpha_scale_func->run(
pack);
// Kernel path: when not on the vector*matrix fast path, (re)run the
// interleave/transpose reshapes — every run unless b is constant
// (_reshape_b_only_on_first_run, in which case prepare() did it once).
334 if(!_run_vector_matrix_multiplication)
340 if(!_reshape_b_only_on_first_run)
// Feed the reshaped operands to the matrix-multiply kernel.
348 mm_pack.add_const_tensor(
ACL_SRC_0, interleaved_a.
get());
349 mm_pack.add_const_tensor(
ACL_SRC_1, transposed_b.
get());
// Optional bias pass (d = _tmp_d + c) as configured.
355 if(_run_bias_addition)
358 _add_bias->run(
pack);
// Final fused activation on d.
373 _activation_func->run(
pack);
// NOTE(review): fragment of CpuGemm::prepare (original lines 381-385+) — the
// body of the else-if branch (one-time b reshape) is outside this view.
// One-time preparation: let the assembly dispatcher pre-process weights, or —
// on the kernel path with constant b — perform the single up-front b reshape.
381 if(_asm_glue && _asm_glue->is_configured())
383 _asm_glue->prepare(tensors);
385 else if(_reshape_b_only_on_first_run && !_run_vector_matrix_multiplication)
// NOTE(review): tail fragment of a has_opt_impl-style query (original line 409);
// the enclosing signature and asm_info construction are outside this view.
// Forwards the optimized-implementation query to the assembly dispatcher.
409 return CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a,
b, c, d, asm_info);
// Reports whether the configured assembly path uses a variable-weights
// (fixed-format) kernel; false when no assembly glue exists.
// NOTE(review): closing brace is outside this fragment; leading numeric tokens
// are line-numbering residue from an extraction pass.
412 bool CpuGemm::isVarWeightsKernel()
const
414 return _asm_glue && _asm_glue->isVarWeightsKernel();