// Builds the GEMMLowp output-stage descriptor for a quantized-asymmetric
// fully-connected layer: the float requantization multiplier
// (src_scale * weights_scale / dst_scale) is turned into a fixed-point
// multiplier/shift pair, and offset/clamping bounds are filled in.
// NOTE(review): this chunk is a garbled extraction — interior lines are missing
// (e.g. the declaration of `output_shift` and the calls that compute the
// fixed-point multiplier/shift and the `type_min`/`type_max` bounds).
// Confirm against the full source before relying on the comments below.
53 Status get_gemmlowp_output_stage_info(
const ITensorInfo *
src,
54 const ITensorInfo *weights,
55 const ITensorInfo *
dst,
56 const ActivationLayerInfo &act,
57 GEMMLowpOutputStageInfo &gemmlowp_output_stage_info)
// Uniform (per-tensor) quantization parameters of src, weights and dst.
60 const QuantizationInfo oq_info =
dst->quantization_info();
61 const UniformQuantizationInfo iq_unif =
src->quantization_info().uniform();
62 const UniformQuantizationInfo wq_unif = weights->quantization_info().uniform();
63 const UniformQuantizationInfo oq_unif = oq_info.uniform();
// Effective requantization scale applied to the int32 accumulator.
65 float multiplier = (iq_unif.scale * wq_unif.scale) / oq_unif.scale;
66 int32_t output_multiplier;
// `output_multiplier`/`output_shift` are computed from `multiplier` on lines
// not visible in this chunk.
76 gemmlowp_output_stage_info.gemmlowp_multiplier = output_multiplier;
77 gemmlowp_output_stage_info.gemmlowp_shift = output_shift;
78 gemmlowp_output_stage_info.gemmlowp_offset = oq_unif.offset;
// Saturation bounds; `type_min`/`type_max` are derived (on missing lines),
// presumably from the dst data type and the activation — TODO confirm.
80 gemmlowp_output_stage_info.gemmlowp_min_bound = type_min;
81 gemmlowp_output_stage_info.gemmlowp_max_bound = type_max;
// Static validation of the matrix-multiply stage of the fully-connected layer.
// NOTE(review): garbled extraction — interior lines are missing (e.g. the
// quantized/non-quantized branch structure and the final CpuGemm /
// CpuGemmLowpMatrixMultiplyCore::validate calls are not visible here).
86 Status validate_mm(
const ITensorInfo *
src,
87 const ITensorInfo *weights,
88 const ITensorInfo *biases,
89 const ITensorInfo *
dst,
90 const ActivationLayerInfo &act,
91 bool enable_fast_math,
// Quantized path: build src/weights QuantizationInfo with NEGATED offsets —
// presumably the sign convention expected by the GEMMLowp core; confirm
// against CpuGemmLowpMatrixMultiplyCore's documentation.
98 const QuantizationInfo src_quantization_info(
src->quantization_info().uniform().scale,
99 -
src->quantization_info().uniform().offset);
100 const QuantizationInfo weights_quantization_info(weights->quantization_info().uniform().scale,
101 -weights->quantization_info().uniform().offset);
// Output-stage info (filled on missing lines, likely via
// get_gemmlowp_output_stage_info — TODO confirm) is attached to gemm_info.
103 GEMMLowpOutputStageInfo gemmlowp_output_stage_info;
107 gemm_info.set_gemmlowp_output_stage(gemmlowp_output_stage_info);
108 gemm_info.set_fast_math(enable_fast_math);
// Clone src/weights so validation sees the negated-offset quantization info
// without mutating the caller's tensors.
111 TensorInfo
src_info =
src->clone()->set_quantization_info(src_quantization_info);
112 TensorInfo
weights_info = weights->clone()->set_quantization_info(weights_quantization_info);
// Non-quantized path (branch boundary not visible in this chunk).
119 gemm_info.set_weight_format(weight_format);
121 gemm_info.set_fast_math(enable_fast_math);
// Constructor member-initializer list (the signature itself is on lines not
// visible in this chunk). All operator stages start unset and all state flags
// default to false; aux-tensor indices default to AuxTensorIdx::Count,
// presumably meaning "no slot assigned" — TODO confirm against the enum.
131 _convert_weights(nullptr),
132 _transpose_weights(nullptr),
134 _mm_gemmlowp(nullptr),
136 _converted_weights(),
139 _trans_weights_idx(AuxTensorIdx::Count),
141 _needs_weights_conversion(false),
142 _needs_weights_reshape(false),
143 _is_fc_after_conv(false),
144 _is_quantized_asymmetric(false),
146 _enable_fast_math(false),
147 _fixed_format(false),
149 _dynamic_weights(false)
// configure_mm fragment (function header not visible in this chunk): selects
// and configures the matrix-multiply engine. Quantized-asymmetric inputs use
// CpuGemmLowpMatrixMultiplyCore; otherwise CpuGemm with alpha=1, beta=1.
// NOTE(review): interior lines are missing — the gemmlowp configure call and
// the quantization-info setup are mostly absent from view.
161 if (_is_quantized_asymmetric)
// Negated src offset (part of a QuantizationInfo constructed on missing
// lines — same negation convention as validate_mm).
166 -
src->quantization_info().uniform().offset);
183 _mm_gemmlowp = std::make_unique<CpuGemmLowpMatrixMultiplyCore>();
// Float path: plain GEMM, alpha = 1.f, beta = 1.0f.
194 _mm_gemm = std::make_unique<CpuGemm>();
195 _mm_gemm->configure(
src, weights, biases,
dst, 1.f, 1.0f, gemm_info);
// Configures the FC layer for input coming from a convolution: the src tensor
// is first flattened into _flattened_src, then the matrix multiply is
// configured on the flattened tensor. (Interior lines, e.g. the
// _flattened_src shape setup, are missing from this chunk.)
199 void CpuFullyConnected::configure_conv_fc(
const ITensorInfo *
src,
200 const ITensorInfo *weights,
201 const ITensorInfo *biases,
203 const ActivationLayerInfo &act)
// Flatten stage: collapses the conv output into a 2D matrix for the GEMM.
212 _flatten = std::make_unique<CpuFlatten>();
213 _flatten->configure(
src, &_flattened_src);
216 configure_mm(&_flattened_src, weights, biases,
dst, act);
// Configures the FC layer for input that is already 2D (output of another FC
// layer): no flattening needed, the matrix multiply is configured directly
// on src.
219 void CpuFullyConnected::configure_fc_fc(
const ITensorInfo *
src,
220 const ITensorInfo *weights,
221 const ITensorInfo *biases,
223 const ActivationLayerInfo &act)
228 configure_mm(
src, weights, biases,
dst, act);
// CpuFullyConnected::configure fragment (function header and several interior
// runs are missing from this chunk). Visible responsibilities: reset state
// flags, decide whether src comes from a conv layer, optionally transpose
// (reshape) and layout-convert the weights, then set up auxiliary workspace
// memory requirements.
244 _needs_weights_conversion =
false;
247 _is_fc_after_conv =
true;
249 _is_prepared =
false;
250 _trans_weights_idx = AuxTensorIdx::Count;
// Heuristic: with batched output (dst dim 1 > 1), src is treated as
// FC-after-conv only when the trailing src dims (from dim 3 up) match the
// dst dims from dim 1 up — TODO confirm intent against full source.
265 const bool is_batched_fc_layer =
dst->dimension(1) > 1;
266 if (is_batched_fc_layer)
269 (std::equal(
src->tensor_shape().cbegin() + 3,
src->tensor_shape().cend(),
270 dst->tensor_shape().cbegin() + 1));
// Non-batched case: any multi-dimensional src is assumed to follow a conv.
274 _is_fc_after_conv =
src->num_dimensions() > 1;
// Optional weights transpose into the TransposedWeights aux slot.
278 if (_needs_weights_reshape)
281 _transpose_weights = std::make_unique<kernels::CpuTransposeKernel>();
282 _transpose_weights->configure(weights, &_reshaped_weights);
285 weights_to_use = &_reshaped_weights;
286 _trans_weights_idx = AuxTensorIdx::TransposedWeights;
// Optional weights layout conversion (e.g. between data layouts — the
// condition guarding this is on missing lines) into ConvertedWeights.
293 _convert_weights = std::make_unique<CpuConvertFullyConnectedWeights>();
294 _convert_weights->configure(weights_to_use, &_converted_weights,
src->tensor_shape(),
298 weights_to_use = &_converted_weights;
299 _needs_weights_conversion =
true;
300 _trans_weights_idx = AuxTensorIdx::ConvertedWeights;
// Dispatch to conv-FC or FC-FC configuration (call sites on missing lines).
303 if (_is_fc_after_conv)
// Keep a TensorInfo snapshot of the final weights form for prepare().
315 if (_needs_weights_reshape || _needs_weights_conversion)
317 _trans_weights = *weights_to_use;
// Aggregate workspace: start from the chosen GEMM engine's requirements,
// then adjust lifetimes of the weight-related aux tensors.
321 auto gemm_mem_req = (_is_quantized_asymmetric) ? _mm_gemmlowp->workspace() : _mm_gemm->workspace();
322 for (
unsigned int i = 0; i < gemm_mem_req.size(); ++i)
324 _aux_mem[i] = gemm_mem_req[i];
// Lifetime choice: dynamic weights => Temporary (re-derived every run);
// otherwise Persistent. The quantized/non-constant-bias case is special —
// the exact lifetime chosen on the elided lines is not visible here.
327 if (_aux_mem[Pretranspose].size > 0)
334 _dynamic_weights ? MemoryLifetime::Temporary
335 : (_is_quantized_asymmetric && biases && !(biases->
are_values_constant())) ? MemoryLifetime::Persistent
346 _dynamic_weights ? MemoryLifetime::Temporary
348 : MemoryLifetime::Persistent,
352 offset_int_vec(ConvertedWeights), _dynamic_weights ? MemoryLifetime::Temporary : MemoryLifetime::Persistent,
// Flattened-src buffer requirement (RHS of this assignment is on missing
// lines).
355 _aux_mem[FlattenedSrc] =
// CpuFullyConnected::validate fragment (header and interior runs missing).
// Mirrors configure() statically: same conv-vs-fc detection, the same
// weights reshape/conversion decisions, and shape checks — but on TensorInfo
// only, with no state mutation.
407 bool is_fc_after_conv =
true;
413 const ITensorInfo &converted_weights = weights_reshaped
// Same batched-FC heuristic as in configure().
427 const bool is_batched_fc_layer =
dst->dimension(1) > 1;
429 if (biases !=
nullptr)
442 if (is_batched_fc_layer)
445 (std::equal(
src->tensor_shape().cbegin() + 3,
src->tensor_shape().cend(),
446 dst->tensor_shape().cbegin() + 1));
450 is_fc_after_conv =
src->num_dimensions() > 1;
// Validate the weights-transpose step when the caller has not already
// reshaped the weights.
453 if (!weights_reshaped)
457 weights_to_use = &reshaped_weights;
465 weights_to_use = &converted_weights;
// Conv-FC path: the weights' K dimension must match the flattened src
// volume (dim0 * dim1 * dim2).
468 if (is_fc_after_conv)
472 (weights_to_use->
dimension(1) != (
src->dimension(0) *
src->dimension(1) *
src->dimension(2))));
476 src_to_use = &flatten_src;
// CpuFullyConnected::run fragment (header and tensor-pack setup missing).
// Flow: optional flatten (conv input), optional use of pre-transformed
// weights, then dispatch to the quantized or float GEMM engine.
494 #ifdef ARM_COMPUTE_ASSERTS_ENABLED
497 #endif // ARM_COMPUTE_ASSERTS_ENABLED
// Conv input: flatten src into the workspace buffer first.
505 if (_is_fc_after_conv)
508 _flatten->run(flatten_pack)
513 if (_needs_weights_reshape || _needs_weights_conversion)
// Quantized path runs GEMMLowp; float path runs plain GEMM.
519 if (_is_quantized_asymmetric)
521 _mm_gemmlowp->run(gemm_pack);
525 _mm_gemm->run(gemm_pack);
// CpuFullyConnected::prepare fragment (header and pack setup missing).
// One-shot (or per-run, when weights are dynamic) weight transformation:
// transpose then layout-convert the weights, tracking the current weights
// tensor in `cur_weights`, then delegate to the GEMM engine's prepare().
531 if (!_is_prepared || _dynamic_weights)
// Debug-only counter of how many times prepare() actually ran.
533 #ifdef ARM_COMPUTE_ASSERTS_ENABLED
534 ++_asrt_prepare_count;
536 #endif // ARM_COMPUTE_ASSERTS_ENABLED
544 const ITensor *cur_weights = weights;
// Step 1: transpose into the reshaped-weights aux tensor.
547 if (_needs_weights_reshape)
555 cur_weights = reshaped_weights.
get();
// Step 2: layout conversion into the converted-weights aux tensor.
559 if (_needs_weights_conversion)
562 _convert_weights->run(convert_pack);
565 cur_weights = converted_weights.
get();
// Step 3: let the selected GEMM engine do its own weight preparation.
572 if (!_is_quantized_asymmetric)
574 _mm_gemm->prepare(gemm_pack);
578 _mm_gemmlowp->prepare(gemm_pack);