Status validate_mm(GEMMLowpOutputStageInfo &gemmlowp_info,
                   const ITensorInfo       *mm_input,
                   const ITensorInfo       *mm_weights,
                   const ITensorInfo       *bias,
                   float                    gemmlowp_scale,
                   const TensorInfo        *mm_res_info,
                   const TensorInfo        *outstage_tensor_info)
{
    // ...
    ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(
        gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
    // ...
}
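// How these requantization parameters are used, in outline: each gate matmul
// produces int32 GEMMLowp accumulators, and the output stage rescales them to
// the gate's 16-bit intermediate format. calculate_quantized_multiplier()
// decomposes the floating-point rescale factor into an integer multiplier and
// a power-of-two shift (conceptually gemmlowp_scale ~= multiplier * 2^shift,
// with the multiplier normalized to a fixed 31-bit range), so the output
// stage runs entirely in integer arithmetic.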
Status NEQLSTMLayer::validate_layer_norm(const ITensorInfo &in, const ITensorInfo &weight, const ITensorInfo &bias)
{
    // The output descriptor matches the input's shape and quantization.
    const TensorInfo out{in};
    return NEQLSTMLayerNormalizationKernel::validate(&in, &out, &weight, &bias);
}
void NEQLSTMLayer::configure_layer_norm(NEQLSTMLayer::LayerNormGate g, const ITensor *in)
{
    Tensor &out = get_layer_norm_output(g);
    _memory_group.manage(&out);
    out.allocator()->init(*(in->info()));

    get_layer_norm(g) = std::make_unique<NEQLSTMLayerNormalizationKernel>();
    get_layer_norm(g)->configure(in, &out, get_layer_norm_weight(g), get_layer_norm_bias(g));
}
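// Each normalized gate gets its own NEQLSTMLayerNormalizationKernel, looked up
// through get_layer_norm(g). Registering the output with _memory_group before
// init() follows the usual Compute Library pattern for intermediates: the
// backing memory is only acquired while run() executes.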
NEQLSTMLayer::TensorCopyKernel::~TensorCopyKernel() = default;

// ... (TensorCopyKernel::configure, abridged)
    _row_size = std::min(_src->info()->tensor_shape().x(), _dst->info()->tensor_shape().x());
// ...

void NEQLSTMLayer::TensorCopyKernel::run()
{
    Iterator input_iter{_src, _window};
    Iterator output_iter{_dst, _window};

    execute_window_loop(
        _window, [&](const Coordinates &) { memcpy(output_iter.ptr(), input_iter.ptr(), _row_size); }, input_iter,
        output_iter);
}
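// TensorCopyKernel bridges the width mismatch between the projection result
// (output_size columns) and the state tensors (num_units columns): per row it
// copies only the overlapping min(src_width, dst_width) leading elements. The
// raw memcpy is byte-based, which matches the element count here because the
// tensors it bridges hold one-byte quantized values.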
NEQLSTMLayer::NEQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
    : _dequantize_input_to_forget_weights(),
      _quantize_input_to_forget_weights(),
      _transpose_input_to_forget_weights(),
      _transpose_input_to_cell_weights(),
      _transpose_input_to_output_weights(),
      _transpose_input_to_input_weights(),
      _transpose_recurrent_to_forget_weights(),
      _transpose_recurrent_to_cell_weights(),
      _transpose_recurrent_to_output_weights(),
      _transpose_recurrent_to_input_weights(),
      _transpose_projection_weights(),
      _input_to_input_reduction(),
      _recurrent_to_input_reduction(),
      _input_to_forget_reduction(),
      _recurrent_to_forget_reduction(),
      _input_to_cell_reduction(),
      _recurrent_to_cell_reduction(),
      _input_to_output_reduction(),
      _recurrent_to_output_reduction(),
      _projection_reduction(),
      _projection_bias_add(),
      _mm_input_to_forget(),
      _mm_recurrent_to_forget(),
      _pixelwise_mul_cell_to_forget(),
      _input_to_forget_outstage(),
      _recurrent_to_forget_outstage(),
      _cell_to_forget_outstage(),
      _accumulate_input_recurrent_forget(),
      _accumulate_cell_forget(),
      _forget_gate_sigmoid(),
      // ...
      _input_to_cell_outstage(),
      _mm_recurrent_to_cell(),
      _recurrent_to_cell_outstage(),
      _accumulate_input_recurrent_modulation(),
      // ...
      _mm_input_to_input(),
      _input_to_input_outstage(),
      _mm_recurrent_to_input(),
      _recurrent_to_input_outstage(),
      _accumulate_input_recurrent_input(),
      _pixelwise_mul_cell_to_input(),
      _cell_to_input_outstage(),
      _accumulate_cell_input(),
      _input_gate_sigmoid(),
      _pixelwise_mul_forget_cell(),
      _pixelwise_mul_input_cell(),
      // ...
      _mm_input_to_output(),
      _input_to_output_outstage(),
      _mm_recurrent_to_output(),
      _recurrent_to_output_outstage(),
      _accumulate_input_recurrent_output(),
      _pixelwise_mul_cell_to_output(),
      _cell_to_output_outstage(),
      _accumulate_cell_to_output(),
      _output_gate_sigmoid(),
      // ...
      _pixelwise_mul_hidden(),
      // ...
      _projection_outstage(),
      _accumulate_projection(),
      // ...
      _projection_bias_copy(),
      _projection_output_to_accumulate_copy(),
      _projection_accumulate_to_output_copy(),
      _hidden_to_output_copy(),
      // ...
      _layer_norm_weights(),
      // ... (remaining members default-initialized)
{
    _memory_group = MemoryGroup(std::move(memory_manager));
}
void NEQLSTMLayer::configure_mm(NEGEMMLowpMatrixMultiplyCore &mm,
                                NEGEMMLowpOutputStage        &outstage,
                                GEMMLowpOutputStageInfo      &gemmlowp_info,
                                const ITensor                *mm_input,
                                const ITensor                *mm_weights,
                                const ITensor                *bias,
                                Tensor                       *mm_res,
                                Tensor                       *outstage_res,
                                float                         gemmlowp_scale,
                                const TensorInfo             &mm_res_info,
                                const TensorInfo             &outstage_tensor_info)
{
    _memory_group.manage(mm_res);
    _memory_group.manage(outstage_res);

    // ...
    // Configure matrix-multiply function.
    mm.configure(mm_input, mm_weights, nullptr, mm_res);

    // ...
    // Configure output stage.
    outstage.configure(mm_res, bias, outstage_res, gemmlowp_info);
}
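// configure_mm() wires up one (matmul -> output stage) pair: the GEMMLowp core
// writes int32 accumulators into mm_res, and the output stage requantizes them
// into outstage_res with the multiplier/shift derived from gemmlowp_scale.
// Every gate below instantiates this pattern twice, once for the input
// contribution and once for the recurrent contribution, and then accumulates
// the two requantized results.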
void NEQLSTMLayer::configure(/* ... */)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(/* ... */
                                 cell_state_out, output_state_out);
    ARM_COMPUTE_LOG_PARAMS(/* ... */
                           cell_state_out, output_state_out);
    // ...
    _convert_input_to_forget_weights_to_qsymm8 = true;
    // ...
    _quantize_input_to_forget_weights.configure(&_input_to_forget_weights_f32, &_input_to_forget_weights_symm8);
    // ...
    ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::validate(
        /* ... */,
        cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(),
        output->info(), lstm_params_info));
    // ...
    ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::validate(
        /* ... */,
        cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(),
        output->info(), lstm_params_info));
    // ...
    const int batch_size = input->info()->dimension(1);
    // ...
    _input_to_forget_weights = /* ... */
                                   ? &_input_to_forget_weights_symm8
                                   : input_to_forget_weights;
    // ...
    set_layer_norm_bias(cell_bias, LayerNormGate::Cell);
    set_layer_norm_bias(lstm_params.input_gate_bias(), LayerNormGate::Input);
    // ...
    const int32_t cell_shift = log2(qcell_state_in.scale);
    // ...
    int16_t quantized_cell_clip = 0;
    // ...
    _has_cell_clipping = quantized_cell_clip > 0;
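// The cell state is expected to use symmetric 16-bit quantization with a
// power-of-two scale, so log2(scale) is an exact integer shift: e.g. a
// cell-state scale of 2^-11 gives cell_shift == -11, letting later products
// involving the cell state be rescaled with plain shifts. quantized_cell_clip
// remains 0 when no clipping was requested, which is what _has_cell_clipping
// tests above.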
        _input_to_input_reduction     = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
        _recurrent_to_input_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
        _input_to_input_reduction->configure(_input_to_input_weights->info(), _input_to_input_eff_bias.info(),
                                             /* ... */);
        _recurrent_to_input_reduction->configure(
            _recurrent_to_input_weights->info(), _recurrent_to_input_eff_bias.info(),
            /* ... */);
    _input_to_forget_reduction     = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
    _recurrent_to_forget_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
    _input_to_cell_reduction       = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
    _recurrent_to_cell_reduction   = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
    _input_to_output_reduction     = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
    _recurrent_to_output_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();

    // ...
    _recurrent_to_forget_reduction->configure(/* ... */);
    // ...
    _recurrent_to_cell_reduction->configure(/* ... */);
    // ...
    _recurrent_to_output_reduction->configure(/* ... */);
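// Why these reduction kernels exist, in outline: with an asymmetric input
// quantization, sum_j (a_j - z_a) * w_ij = sum_j a_j * w_ij - z_a * sum_j w_ij,
// so the per-row weight sums scaled by the negated input offset can be folded
// into an "effective bias" ahead of time. That is what the
// CpuGemmLowpMatrixAReductionKernel instances compute into the *_eff_bias
// tensors; the elided GEMMLowpReductionKernelInfo arguments carry the offsets.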
        _projection_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
        _projection_reduction->configure(
            _projection_weights->info(), _projection_eff_bias.info(),
            /* ... */);
        if (_projection_bias != nullptr)
        {
            _projection_bias_add.configure(_projection_bias, &_projection_eff_bias, &_projection_eff_bias,
                                           /* ... */);
        }
    // ...
    _transpose_recurrent_to_forget_weights.configure(recurrent_to_forget_weights,
                                                     &_recurrent_to_forget_weights_transposed);
    // ...
    _transpose_recurrent_to_output_weights.configure(recurrent_to_output_weights,
                                                     &_recurrent_to_output_weights_transposed);
    // ...
    _transpose_input_to_input_weights.configure(lstm_params.input_to_input_weights(),
                                                &_input_to_input_weights_transposed);
    _transpose_recurrent_to_input_weights.configure(lstm_params.recurrent_to_input_weights(),
                                                    &_recurrent_to_input_weights_transposed);
    // ...
    _transpose_projection_weights.configure(_projection_weights, &_projection_weights_transposed);
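// The weight matrices are only wired to their transpose functions here; the
// actual transposition runs once in prepare(), producing the *_transposed
// tensors that the GEMM functions below consume in their expected layout.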
    // Forget gate.
    // ...
    configure_mm(_mm_input_to_forget, _input_to_forget_outstage, gemmlowp_info, input,
                 &_input_to_forget_weights_transposed, &_input_to_forget_eff_bias, &_mm_input_to_forget_res,
                 &_input_to_forget_outstage_res, input_to_forget_scale, mm_out_info, forget_gate_outstage_info);

    // ...
    configure_mm(_mm_recurrent_to_forget, _recurrent_to_forget_outstage, gemmlowp_info, output_state_in,
                 &_recurrent_to_forget_weights_transposed, &_recurrent_to_forget_eff_bias, &_mm_recurrent_to_forget_res,
                 &_recurrent_to_forget_outstage_res, recurrent_to_forget_scale, mm_out_info, forget_gate_outstage_info);
    _accumulate_input_recurrent_forget.configure(&_input_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
                                                 /* ... */);
    if (_has_peephole)
    {
        // ...
        _memory_group.manage(&_mul_cell_to_forget_res);
        // ...
        _memory_group.manage(&_cell_to_forget_outstage_res);
        const float cell_to_forget_scale =
            std::pow(2, cell_shift) *
            /* ... */;
        // ...
        _cell_to_forget_outstage.configure(&_mul_cell_to_forget_res, nullptr, &_cell_to_forget_outstage_res,
                                           /* ... */);
        _accumulate_cell_forget.configure(&_recurrent_to_forget_outstage_res, &_cell_to_forget_outstage_res,
                                          /* ... */);
    }
    Tensor *forget_activation_input = &_recurrent_to_forget_outstage_res;

    if (_has_layer_norm)
    {
        configure_layer_norm(LayerNormGate::Forget, forget_activation_input);
        // ...
        forget_activation_input = &get_layer_norm_output(LayerNormGate::Forget);
    }

    // ...
    _memory_group.manage(&_forget_gate);
    // ...
    _forget_gate_sigmoid.configure(forget_activation_input, &_forget_gate,
                                   /* ... */);
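// Forget gate, in effect:
//   f = sigmoid(LN(W_f * x + R_f * h_prev [+ w_cf .* c_prev]))
// where LN is the optional per-gate layer normalization and the bracketed term
// is the optional peephole contribution. All of the additions above operate on
// the 16-bit gate intermediates produced by the output stages.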
    // Modulation (cell) gate.
    const float input_to_cell_scale = input_to_cell_weights->info()->quantization_info().uniform().scale *
                                      /* ... */;
    // ...
    configure_mm(_mm_input_to_cell, _input_to_cell_outstage, gemmlowp_info, input, &_input_to_cell_weights_transposed,
                 &_input_to_cell_eff_bias, &_mm_input_to_cell_res, &_input_to_cell_outstage_res, input_to_cell_scale,
                 mm_out_info, cell_outstage_info);
    // ...
    configure_mm(_mm_recurrent_to_cell, _recurrent_to_cell_outstage, gemmlowp_info, output_state_in,
                 &_recurrent_to_cell_weights_transposed, &_recurrent_to_cell_eff_bias, &_mm_recurrent_to_cell_res,
                 &_recurrent_to_cell_outstage_res, recurrent_to_cell_scale, mm_out_info, cell_outstage_info);

    _accumulate_input_recurrent_modulation.configure(&_input_to_cell_outstage_res, &_recurrent_to_cell_outstage_res,
                                                     /* ... */);
    Tensor *cell_activation_input = &_recurrent_to_cell_outstage_res;

    if (_has_layer_norm)
    {
        configure_layer_norm(LayerNormGate::Cell, cell_activation_input);
        // ...
        cell_activation_input = &get_layer_norm_output(LayerNormGate::Cell);
    }

    // ...
    _memory_group.manage(&_cell_gate);
    // ...
    _cell_gate_tanh.configure(cell_activation_input, &_cell_gate,
                              /* ... */);
    // Input gate.
    _memory_group.manage(&_input_gate);
    // ... (CIFG branch elided; otherwise the gate is computed from weights:)
    configure_mm(_mm_input_to_input, _input_to_input_outstage, gemmlowp_info, input,
                 &_input_to_input_weights_transposed, &_input_to_input_eff_bias, &_mm_input_to_input_res,
                 &_input_to_input_outstage_res, input_to_input_scale, mm_out_info, input_outstage_info);
    const float recurrent_to_input_scale =
        /* ... */;
    // ...
    configure_mm(_mm_recurrent_to_input, _recurrent_to_input_outstage, gemmlowp_info, output_state_in,
                 &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias,
                 &_mm_recurrent_to_input_res, &_recurrent_to_input_outstage_res, recurrent_to_input_scale,
                 mm_out_info, input_outstage_info);
    _accumulate_input_recurrent_input.configure(&_input_to_input_outstage_res, &_recurrent_to_input_outstage_res,
                                                /* ... */);
    if (_has_peephole)
    {
        // ...
        _memory_group.manage(&_mul_cell_to_input_res);
        // ...
        const float cell_to_input_scale =
            std::pow(2, cell_shift) *
            /* ... */;
        // ...
        _memory_group.manage(&_cell_to_input_outstage_res);
        _cell_to_input_outstage.configure(&_mul_cell_to_input_res, nullptr, &_cell_to_input_outstage_res,
                                          /* ... */);
        _accumulate_cell_input.configure(&_recurrent_to_input_outstage_res, &_cell_to_input_outstage_res,
                                         /* ... */);
    }
    Tensor *input_activation_input = &_recurrent_to_input_outstage_res;

    if (_has_layer_norm)
    {
        configure_layer_norm(LayerNormGate::Input, input_activation_input);
        // ...
        input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
    }

    _input_gate_sigmoid.configure(input_activation_input, &_input_gate,
                                  /* ... */);
    // Cell state update.
    const float mul_input_cell_scale = cell_gate_scale * std::pow(2, 15 + cell_shift);
    // ...
    _memory_group.manage(&_mul_input_cell_res);
    // ...
    if (_has_cell_clipping)
    {
        _cell_clip.configure(cell_state_out, nullptr,
                             ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                                 -quantized_cell_clip, quantized_cell_clip));
    }
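// Cell-state update, in effect: c = clip(f .* c_prev + i .* g), with the two
// element-wise products computed by _pixelwise_mul_forget_cell and
// _pixelwise_mul_input_cell and summed by _add_forget_cell (see run()). The
// clip is only instantiated when a positive quantized_cell_clip was derived
// from the requested cell clip value.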
    // Output gate.
    // ...
    configure_mm(_mm_input_to_output, _input_to_output_outstage, gemmlowp_info, input,
                 &_input_to_output_weights_transposed, &_input_to_output_eff_bias, &_mm_input_to_output_res,
                 &_input_to_output_outstage_res, input_to_output_scale, mm_out_info, output_outstage_info);

    // ...
    configure_mm(_mm_recurrent_to_output, _recurrent_to_output_outstage, gemmlowp_info, output_state_in,
                 &_recurrent_to_output_weights_transposed, &_recurrent_to_output_eff_bias, &_mm_recurrent_to_output_res,
                 &_recurrent_to_output_outstage_res, recurrent_to_output_scale, mm_out_info, output_outstage_info);
    _accumulate_input_recurrent_output.configure(&_recurrent_to_output_outstage_res, &_input_to_output_outstage_res,
                                                 /* ... */);
    if (_has_peephole)
    {
        // ...
        _memory_group.manage(&_mul_cell_to_output_res);
        // ...
        const float cell_to_output_scale =
            std::pow(2, cell_shift) *
            /* ... */;
        // ...
        _memory_group.manage(&_cell_to_output_outstage_res);
        _cell_to_output_outstage.configure(&_mul_cell_to_output_res, nullptr, &_cell_to_output_outstage_res,
                                           /* ... */);
        _accumulate_cell_to_output.configure(&_recurrent_to_output_outstage_res, &_cell_to_output_outstage_res,
                                             /* ... */);
    }
    Tensor *output_activation_input = &_recurrent_to_output_outstage_res;

    if (_has_layer_norm)
    {
        configure_layer_norm(LayerNormGate::Output, output_activation_input);
        // ...
        output_activation_input = &get_layer_norm_output(LayerNormGate::Output);
    }

    // ...
    _memory_group.manage(&_output_gate);
    // ...
    _output_gate_sigmoid.configure(output_activation_input, &_output_gate,
                                   /* ... */);
    // Hidden state.
    _hidden_tanh.configure(cell_state_out, &_input_gate,
                           /* ... */);
    // ...
    _memory_group.manage(&_hidden_mul_res);
    // ...
    const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
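// The rescale factor above, step by step: the tanh and sigmoid outputs are
// both Q0.15, so their element-wise product has a fixed scale of
// 2^-15 * 2^-15 = 2^-30, and requantizing to the user-requested hidden-state
// quantization therefore needs a factor of 2^-30 / hidden_state_scale(),
// which is exactly what the expression computes.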
    _projection_tensor_copy_required = (num_units != output_size);
    ITensor *hidden_gate_result      = output_state_out;
    // ...
    _memory_group.manage(&_hidden_gate);

    if (_projection_tensor_copy_required)
    {
        // ...
        hidden_gate_result = &_hidden_gate;
    }

    _hidden_outstage.configure(&_hidden_mul_res, nullptr, hidden_gate_result, gemmlowp_info);
    // Projection.
    const TensorInfo projection_outstage_info(*output_state_out->info());
    // ...
    TensorInfo projection_mm_out_info{mm_out_info};
    // ...
    configure_mm(_mm_projection, _projection_outstage, gemmlowp_info, hidden_gate_result,
                 &_projection_weights_transposed, &_projection_eff_bias, &_mm_projection_res,
                 &_projection_outstage_res, projection_scale, projection_mm_out_info, projection_outstage_info);
    ITensor *accumulate_destination = output_state_out;

    if (_projection_tensor_copy_required)
    {
        // ...
        _projection_output_to_accumulate_copy.configure(*output_state_in, _projection_accumulate_res);
        accumulate_destination = &_projection_accumulate_res;
    }

    _accumulate_projection.configure(&_projection_outstage_res, accumulate_destination, accumulate_destination,
                                     /* ... */);

    if (_projection_tensor_copy_required)
    {
        _projection_accumulate_to_output_copy.configure(_projection_accumulate_res, *output_state_out);
    }
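// When num_units != output_size the projection result and the recurrent state
// have different row widths, so the accumulation is staged through
// _projection_accumulate_res: the previous output state is copied in, the
// projection is accumulated there, and the result is copied back out, all via
// the row-wise TensorCopyKernel defined at the top of this file.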
    int8_t quantized_projection_clip{0};
    // ...
    quantized_projection_clip = /* ... */;

    if (quantized_projection_clip > 0)
    {
        _projection_clip.configure(output_state_out, nullptr,
                                   ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                                       -quantized_projection_clip, quantized_projection_clip));
        _has_projection_clipping = true;
    }
    // ...
    if (_projection_tensor_copy_required)
    {
        _hidden_to_output_copy.configure(_hidden_gate, *output_state_out);
    }

    // Copy output_state_out to output.
    _copy_output.configure(output_state_out, output);
    // ...
}
Status NEQLSTMLayer::validate(/* ... */)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(/* ... */
                                        cell_state_in, output_state_in, cell_state_out, output_state_out, output);
    // ...
    const unsigned int batch_size  = input->dimension(1);
    // ...
    const unsigned int output_size = output_state_out->dimension(_out_state_output_size_dimension_idx);
    // ...
    const int32_t cell_shift = log2(qcell_state_in.scale);
    // ...
    int16_t quantized_cell_clip = 0;
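// validate() mirrors configure() statelessly: it rebuilds the same chain of
// TensorInfo descriptors and scale computations on ITensorInfo objects and
// forwards them to the validate() entry points of the kernels and functions
// that configure() would instantiate, so an unsupported configuration can be
// rejected without allocating or configuring anything.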
    // ...
    const TensorInfo recurrent_to_input_weights_transposed(/* ... */);
    // Forget gate.
    // ...
    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(/* ... */,
                                            input_to_forget_scale, &mm_out_info, &forget_outstage_info));
    // ...
    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(/* ... */,
                                            &eff_bias_info, recurrent_to_forget_scale, &mm_out_info,
                                            &forget_outstage_info));
    // ...
    const float cell_to_forget_scale = std::pow(2, cell_shift) *
                                       /* ... */;
    // ...
    // Modulation (cell) gate.
    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(/* ... */,
                                            input_to_cell_scale, &mm_out_info, &cell_outstage_info));
    // ...
    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(/* ... */,
                                            &eff_bias_info, recurrent_to_cell_scale, &mm_out_info,
                                            &cell_outstage_info));
1075 "Input gate bias must not be present when CIFG is used");
1105 const float input_to_input_scale = lstm_params.
input_to_input_weights()->quantization_info().uniform().scale *
1108 input_to_input_scale, &mm_out_info, &input_outstage_info));
1110 const float recurrent_to_input_scale =
1114 &eff_bias_info, recurrent_to_input_scale, &mm_out_info,
1115 &input_outstage_info));
    // ...
    const float cell_to_input_scale = std::pow(2, cell_shift) *
                                      /* ... */;
    // ...
    if (quantized_cell_clip > 0)
    {
        ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(
            cell_state_out, nullptr,
            ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                -quantized_cell_clip, quantized_cell_clip)));
    }
    // Output gate.
    // ...
    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(/* ... */,
                                            input_to_output_scale, &mm_out_info, &output_outstage_info));
    // ...
    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(/* ... */,
                                            &eff_bias_info, recurrent_to_output_scale, &mm_out_info,
                                            &output_outstage_info));
    // Hidden state.
    // ...
    const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
    // ...
    // Projection.
    const bool projection_tensor_copy_required = num_units != output_size;
    // ...
    const TensorInfo projection_outstage_info(*output_state_out);
    // ...
    TensorInfo projection_mm_out_info{mm_out_info};
    // ...
    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(/* ... */,
                                            &projection_eff_bias_info, projection_scale, &projection_mm_out_info,
                                            &projection_outstage_info));
    if (projection_tensor_copy_required)
    {
        // ...
    }
    // ...
    if (projection_tensor_copy_required)
    {
        // ...
    }

    int8_t quantized_projection_clip{0};
    // ...
    if (quantized_projection_clip > 0)
    {
        ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(
            output_state_out, nullptr,
            ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                -quantized_projection_clip, quantized_projection_clip)));
    }
    // ...
    if (projection_tensor_copy_required)
    {
        // ...
    }
    // ...
    return Status{};
}
void NEQLSTMLayer::run()
{
    prepare();
    // ...

    // Forget gate.
    _mm_input_to_forget.run();
    _input_to_forget_outstage.run();

    _mm_recurrent_to_forget.run();
    _recurrent_to_forget_outstage.run();
    _accumulate_input_recurrent_forget.run();

    if (_has_peephole)
    {
        _pixelwise_mul_cell_to_forget.run();
        _cell_to_forget_outstage.run();
        _accumulate_cell_forget.run();
    }

    if (_has_layer_norm)
    {
        // ...
    }

    _forget_gate_sigmoid.run();
    // Modulation (cell) gate.
    _mm_input_to_cell.run();
    _input_to_cell_outstage.run();

    _mm_recurrent_to_cell.run();
    _recurrent_to_cell_outstage.run();
    _accumulate_input_recurrent_modulation.run();

    if (_has_layer_norm)
    {
        // ...
    }

    _cell_gate_tanh.run();
    // Input gate.
    if (_has_cifg)
    {
        // ...
        _input_gate_sub.run();
    }
    else
    {
        _mm_input_to_input.run();
        _input_to_input_outstage.run();
        _mm_recurrent_to_input.run();
        _recurrent_to_input_outstage.run();
        _accumulate_input_recurrent_input.run();

        if (_has_peephole)
        {
            _pixelwise_mul_cell_to_input.run();
            _cell_to_input_outstage.run();
            _accumulate_cell_input.run();
        }

        if (_has_layer_norm)
        {
            // ...
        }

        _input_gate_sigmoid.run();
    }
    // Cell state update.
    _pixelwise_mul_forget_cell.run();
    _pixelwise_mul_input_cell.run();
    _add_forget_cell.run();

    if (_has_cell_clipping)
    {
        _cell_clip.run();
    }
    // Output gate.
    _mm_input_to_output.run();
    _input_to_output_outstage.run();
    _mm_recurrent_to_output.run();
    _recurrent_to_output_outstage.run();
    _accumulate_input_recurrent_output.run();

    if (_has_peephole)
    {
        _pixelwise_mul_cell_to_output.run();
        _cell_to_output_outstage.run();
        _accumulate_cell_to_output.run();
    }

    if (_has_layer_norm)
    {
        // ...
    }

    _output_gate_sigmoid.run();
    // Hidden state.
    // ...
    _pixelwise_mul_hidden.run();
    _hidden_outstage.run();
    // Projection.
    if (_has_projection)
    {
        _mm_projection.run();
        _projection_outstage.run();

        if (_projection_tensor_copy_required)
        {
            _projection_output_to_accumulate_copy.run();
        }

        _accumulate_projection.run();

        if (_projection_tensor_copy_required)
        {
            _projection_accumulate_to_output_copy.run();
        }

        if (_has_projection_clipping)
        {
            _projection_clip.run();
        }
    }
    else
    {
        if (_projection_tensor_copy_required)
        {
            _hidden_to_output_copy.run();
        }
    }
    // ...
}
void NEQLSTMLayer::prepare()
{
    if (!_is_prepared)
    {
        if (_convert_input_to_forget_weights_to_qsymm8)
        {
            // ...
            _dequantize_input_to_forget_weights.run();
            _quantize_input_to_forget_weights.run();
            // ...
        }
        // Run the weight transposes once.
        // ...
        _transpose_input_to_forget_weights.run();
        _transpose_input_to_cell_weights.run();
        _transpose_input_to_output_weights.run();
        _transpose_recurrent_to_forget_weights.run();
        _transpose_recurrent_to_cell_weights.run();
        _transpose_recurrent_to_output_weights.run();
        // ...
        if (_has_cifg)
        {
            std::fill_n(reinterpret_cast<int16_t *>(_ones.buffer()),
                        /* ... */);
        }
        else
        {
            // ...
            NEScheduler::get().schedule_op(_input_to_input_reduction.get(), Window::DimY,
                                           _input_to_input_reduction->window(), packII);
            // ...
            NEScheduler::get().schedule_op(_recurrent_to_input_reduction.get(), Window::DimY,
                                           _recurrent_to_input_reduction->window(), packRI);
            _transpose_input_to_input_weights.run();
            _transpose_recurrent_to_input_weights.run();
            // ...
        }
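        // Under CIFG the input gate is not computed from weights but derived
        // as 1 - forget_gate: _ones holds the constant one in the gates'
        // Q0.15 format (fill value elided above) and feeds the
        // _input_gate_sub subtraction in run(). Without CIFG, the input-gate
        // effective biases are reduced here just like those of the other
        // gates.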
        // ...
        NEScheduler::get().schedule_op(_input_to_forget_reduction.get(), Window::DimY,
                                       _input_to_forget_reduction->window(), packIF);
        // ...
        NEScheduler::get().schedule_op(_recurrent_to_forget_reduction.get(), Window::DimY,
                                       _recurrent_to_forget_reduction->window(), packRF);
        // ...
        NEScheduler::get().schedule_op(_recurrent_to_cell_reduction.get(), Window::DimY,
                                       _recurrent_to_cell_reduction->window(), packRC);
        // ...
        NEScheduler::get().schedule_op(_input_to_output_reduction.get(), Window::DimY,
                                       _input_to_output_reduction->window(), packIO);
        // ...
        NEScheduler::get().schedule_op(_recurrent_to_output_reduction.get(), Window::DimY,
                                       _recurrent_to_output_reduction->window(), packRO);
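        // The reductions are raw kernels rather than functions, so they are
        // dispatched directly through the scheduler: each ITensorPack
        // (packIF, packRF, ..., built in the elided lines) binds a weight
        // tensor to its effective-bias output, and Window::DimY splits the
        // row sums across threads. This happens once in prepare(); run()
        // then reuses the precomputed biases every iteration.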
        if (_has_projection)
        {
            // ...
            if (_projection_bias != nullptr)
            {
                _projection_bias_add.run();
                // ...
            }

            _transpose_projection_weights.run();
            if (!_projection_tensor_copy_required)
            {
                // ...
            }
        }

        // ...
        _is_prepared = true;
    }
}