// Validation helper for one matrix-multiply + output-stage pair; returns a Status.
// It works on tensor infos (ITensorInfo / TensorInfo pointers), not on tensors.
// NOTE(review): part of the parameter list and most of the body are outside this
// excerpt, so the complete contract cannot be stated from here.
46 Status validate_mm(GEMMLowpOutputStageInfo &gemmlowp_info,
47 const ITensorInfo *mm_input,
48 const ITensorInfo *mm_weights,
49 const ITensorInfo *
bias,
51 const TensorInfo *mm_res_info,
52 const TensorInfo *outstage_tensor_info)
// Tail of a call (callee not visible here) that receives gemmlowp_scale together
// with the addresses of gemmlowp_info.gemmlowp_multiplier and gemmlowp_info.gemmlowp_shift.
56 gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
// Row size used by the copy below: the smaller of the x extents of the source
// and destination tensor shapes.
// NOTE(review): this value is passed to memcpy as a byte count. That equals the
// x extent only when each element is one byte wide — confirm for all callers.
77 _row_size = std::min(_src->info()->tensor_shape().x(), _dst->info()->tensor_shape().x());
// Source and destination iterators are both built over the same window.
88 Iterator input_iter{_src, _window};
89 Iterator output_iter{_dst, _window};
// Per-position callback: copy _row_size bytes from the current input pointer to
// the current output pointer. The Coordinates argument is unused.
92 _window, [&](
const Coordinates &) { memcpy(output_iter.ptr(), input_iter.ptr(), _row_size); }, input_iter,
// Give every slot of _layer_norms its own CLQLSTMLayerNormalizationKernel instance.
112 for (
auto &norm : _layer_norms)
114 norm = std::make_unique<CLQLSTMLayerNormalizationKernel>();
// Build the memory group from the supplied memory manager, which is moved in.
117 _memory_group =
MemoryGroup(std::move(memory_manager));
// Configures the layer-normalization function that belongs to gate `g`.
//   g  - selects the layer-norm function, its output tensor, weight and bias.
//   in - tensor handed to the layer-norm function as its input.
122 void CLQLSTMLayer::configure_layer_norm(LayerNormGate g,
const ICLTensor *in)
// The gate's output tensor is registered with the memory group before use.
126 CLTensor *out = &get_layer_norm_output(g);
127 _memory_group.
manage(out);
// Wire input, output and the gate's weight and bias into the gate's layer-norm function.
130 get_layer_norm(g).
configure(in, out, get_layer_norm_weight(g), get_layer_norm_bias(g));
// Validation step for a layer normalization, working on tensor infos only; returns a Status.
//   in     - info of the normalization input.
//   weight - info of the normalization weight.
//   bias   - info of the normalization bias.
// NOTE(review): the call that consumes these infos is outside this excerpt.
133 Status CLQLSTMLayer::validate_layer_norm(
const ITensorInfo &in,
const ITensorInfo &weight,
const ITensorInfo &
bias)
// The output info is constructed from the input info.
137 const TensorInfo out{in};
// Configures a low-precision matrix multiply followed by its output stage.
//   compile_context      - forwarded to both configure() calls below.
//   mm                   - the matrix-multiply function to configure.
//   outstage             - the output-stage function to configure.
//   gemmlowp_info        - output-stage info; its gemmlowp_shift is taken by address
//                          below and the whole struct is passed to outstage.configure().
//   mm_input, mm_weights - operands of the multiply.
//   bias                 - given to the output stage, not to the multiply.
//   outstage_res         - destination tensor of the output stage.
//   gemmlowp_scale       - NOTE(review): its use is not visible in this excerpt.
//   mm_res_info          - used to initialise mm_res.
//   outstage_tensor_info - used to initialise outstage_res.
// NOTE(review): the declaration of mm_res is not visible here; presumably it is a
// parameter on an elided line between `bias` and `outstage_res` — confirm.
141 void CLQLSTMLayer::configure_mm(
const CLCompileContext &compile_context,
142 CLGEMMLowpMatrixMultiplyCore &mm,
143 CLGEMMLowpOutputStage &outstage,
144 GEMMLowpOutputStageInfo &gemmlowp_info,
145 const ICLTensor *mm_input,
146 const ICLTensor *mm_weights,
147 const ICLTensor *
bias,
149 CLTensor *outstage_res,
150 float gemmlowp_scale,
151 const TensorInfo &mm_res_info,
152 const TensorInfo &outstage_tensor_info)
// Both result tensors are registered with the memory group.
154 _memory_group.
manage(mm_res);
155 _memory_group.
manage(outstage_res);
// Initialise the tensors' allocators from the supplied infos.
157 mm_res->allocator()->init(mm_res_info);
158 outstage_res->allocator()->init(outstage_tensor_info);
// The multiply gets nullptr as its third tensor argument; `bias` is passed to
// the output stage further down instead.
161 mm.configure(compile_context, mm_input, mm_weights,
nullptr, mm_res);
// Tail of a call (head not visible) that receives the address of gemmlowp_info.gemmlowp_shift.
165 &gemmlowp_info.gemmlowp_shift);
166 outstage.configure(compile_context, mm_res,
bias, outstage_res, gemmlowp_info);
// mm_res is allocated here. NOTE(review): no allocation of outstage_res is visible in this excerpt.
167 mm_res->allocator()->allocate();
// The next three lines are trailing argument lists of calls whose heads are
// outside this excerpt; two of them end with lstm_params, one does not.
190 output_state_in, cell_state_out, output_state_out, output, lstm_params);
214 cell_state_out, output_state_out, output);
219 cell_state_out, output_state_out, output, lstm_params);
// Tail of a call (head not visible) that is given tensor infos via info() plus lstm_params_info.
229 output_state_in->
info(), cell_state_out->
info(), output_state_out->
info(), output->
info(), lstm_params_info));
// Batch size is read from dimension 1 of the input tensor info.
231 const int batch_size =
input->info()->dimension(1);
// Register layer-norm biases for the Cell and Input gates.
258 set_layer_norm_bias(cell_bias, LayerNormGate::Cell);
259 set_layer_norm_bias(lstm_params.
input_gate_bias(), LayerNormGate::Input);
// cell_shift is log2 of qcell_state_in.scale, implicitly converted to int32_t
// (any fractional part is discarded).
// NOTE(review): presumably the scale is expected to be an exact power of two — confirm.
268 const int32_t cell_shift = log2(qcell_state_in.
scale);
// Quantized clip value starts at 0; the assignment that may change it is not visible here.
271 int16_t quantized_cell_clip = 0;
// Cell clipping is enabled only for a strictly positive quantized clip value.
276 _has_cell_clipping = quantized_cell_clip > 0;
// Reduction configuration. Each fully visible call passes the compile context,
// a weights info and an effective-bias info; the remaining arguments, and some
// call heads, are outside this excerpt.
284 _input_to_input_reduction->configure(compile_context, _input_to_input_weights->
info(),
285 _input_to_input_eff_bias.
info(),
287 _recurrent_to_input_reduction->configure(
288 compile_context, _recurrent_to_input_weights->
info(), _recurrent_to_input_eff_bias.
info(),
292 _input_to_forget_eff_bias.
info(),
294 _recurrent_to_forget_reduction->configure(
299 _recurrent_to_cell_reduction->configure(
303 _input_to_output_eff_bias.
info(),
305 _recurrent_to_output_reduction->configure(
310 _projection_reduction->configure(
311 compile_context, _projection_weights->
info(), _projection_eff_bias.
info(),
// The projection bias add is configured only when a projection bias tensor exists.
313 if (_projection_bias !=
nullptr)
315 _projection_bias_add.
configure(compile_context, _projection_bias, &_projection_eff_bias,
// Final arguments of calls whose heads are not visible: each line names a
// *_weights_transposed tensor by address.
322 &_input_to_forget_weights_transposed);
324 &_input_to_cell_weights_transposed);
326 &_input_to_output_weights_transposed);
328 &_recurrent_to_forget_weights_transposed);
330 &_recurrent_to_cell_weights_transposed);
332 &_recurrent_to_output_weights_transposed);
336 &_input_to_input_weights_transposed);
338 &_recurrent_to_input_weights_transposed);
// Projection weights are transposed into _projection_weights_transposed.
342 _transpose_projection_weights.
configure(compile_context, _projection_weights, &_projection_weights_transposed);
// Forget gate, input path: multiply `input` by the transposed input-to-forget
// weights and run the output stage with the input-to-forget effective bias.
357 configure_mm(compile_context, _mm_input_to_forget, _input_to_forget_outstage, gemmlowp_info,
input,
358 &_input_to_forget_weights_transposed, &_input_to_forget_eff_bias, &_mm_input_to_forget_res,
359 &_input_to_forget_outstage_res, input_to_forget_scale, mm_out_info, forget_gate_outstage_info);
// Forget gate, recurrent path: same pattern, fed from output_state_in.
363 configure_mm(compile_context, _mm_recurrent_to_forget, _recurrent_to_forget_outstage, gemmlowp_info,
364 output_state_in, &_recurrent_to_forget_weights_transposed, &_recurrent_to_forget_eff_bias,
365 &_mm_recurrent_to_forget_res, &_recurrent_to_forget_outstage_res, recurrent_to_forget_scale,
366 mm_out_info, forget_gate_outstage_info);
// Combine both paths. _recurrent_to_forget_outstage_res appears as the second
// and third tensor argument, so it presumably serves as operand and destination.
368 _accumulate_input_recurrent_forget.
configure(compile_context, &_input_to_forget_outstage_res,
369 &_recurrent_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
// Cell-to-forget path. NOTE(review): the condition guarding it is not visible here.
376 _memory_group.
manage(&_mul_cell_to_forget_res);
383 _memory_group.
manage(&_cell_to_forget_outstage_res);
// The scale starts from 2^cell_shift; the remaining factors are outside this excerpt.
384 const float cell_to_forget_scale =
385 std::pow(2, cell_shift) *
// Output stage for the cell-to-forget product; nullptr is passed where a bias would go.
390 _cell_to_forget_outstage.
configure(compile_context, &_mul_cell_to_forget_res,
nullptr,
391 &_cell_to_forget_outstage_res, gemmlowp_info);
// Add the cell-to-forget result onto the accumulated recurrent result.
393 _accumulate_cell_forget.
configure(compile_context, &_recurrent_to_forget_outstage_res,
394 &_cell_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
// The activation input defaults to the accumulated result and is redirected to
// the layer-norm output after layer norm is configured for this gate
// (the guarding condition is not visible here).
399 CLTensor *forget_activation_input = &_recurrent_to_forget_outstage_res;
403 configure_layer_norm(LayerNormGate::Forget, &_recurrent_to_forget_outstage_res);
405 forget_activation_input = &get_layer_norm_output(LayerNormGate::Forget);
412 _memory_group.
manage(&_forget_gate);
// The gate activation writes into _forget_gate; the activation parameters are outside this excerpt.
414 _forget_gate_sigmoid.
configure(compile_context, forget_activation_input, &_forget_gate,
// The input-to-cell scale starts from the uniform quantization scale of the
// input-to-cell weights; the remaining factors are outside this excerpt.
421 const float input_to_cell_scale =
input_to_cell_weights->info()->quantization_info().uniform().scale *
// Cell gate, input path: multiply + output stage on `input`.
423 configure_mm(compile_context, _mm_input_to_cell, _input_to_cell_outstage, gemmlowp_info,
input,
424 &_input_to_cell_weights_transposed, &_input_to_cell_eff_bias, &_mm_input_to_cell_res,
425 &_input_to_cell_outstage_res, input_to_cell_scale, mm_out_info, cell_outstage_info);
// Cell gate, recurrent path: multiply + output stage on output_state_in.
429 configure_mm(compile_context, _mm_recurrent_to_cell, _recurrent_to_cell_outstage, gemmlowp_info, output_state_in,
430 &_recurrent_to_cell_weights_transposed, &_recurrent_to_cell_eff_bias, &_mm_recurrent_to_cell_res,
431 &_recurrent_to_cell_outstage_res, recurrent_to_cell_scale, mm_out_info, cell_outstage_info);
// Combine both paths; the recurrent result tensor is passed as the second and third tensor argument.
433 _accumulate_input_recurrent_modulation.
configure(compile_context, &_input_to_cell_outstage_res,
434 &_recurrent_to_cell_outstage_res, &_recurrent_to_cell_outstage_res,
// The activation input defaults to the accumulated result and is redirected to
// the layer-norm output after layer norm is configured (condition not visible here).
438 CLTensor *cell_activation_input = &_recurrent_to_cell_outstage_res;
442 configure_layer_norm(LayerNormGate::Cell, &_recurrent_to_cell_outstage_res);
444 cell_activation_input = &get_layer_norm_output(LayerNormGate::Cell);
448 _memory_group.
manage(&_cell_gate);
// The gate activation writes into _cell_gate; the activation parameters are outside this excerpt.
450 _cell_gate_tanh.
configure(compile_context, cell_activation_input, &_cell_gate,
// Input gate: the gate tensor is registered with the memory group first.
457 _memory_group.
manage(&_input_gate);
// Input path: multiply + output stage on `input`.
470 configure_mm(compile_context, _mm_input_to_input, _input_to_input_outstage, gemmlowp_info,
input,
471 &_input_to_input_weights_transposed, &_input_to_input_eff_bias, &_mm_input_to_input_res,
472 &_input_to_input_outstage_res, input_to_input_scale, mm_out_info, input_outstage_info);
// The initializer of this scale is outside this excerpt.
474 const float recurrent_to_input_scale =
// Recurrent path: multiply + output stage on output_state_in.
477 configure_mm(compile_context, _mm_recurrent_to_input, _recurrent_to_input_outstage, gemmlowp_info,
478 output_state_in, &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias,
479 &_mm_recurrent_to_input_res, &_recurrent_to_input_outstage_res, recurrent_to_input_scale,
480 mm_out_info, input_outstage_info);
// Combine both paths; the trailing arguments are outside this excerpt.
481 _accumulate_input_recurrent_input.
configure(compile_context, &_input_to_input_outstage_res,
482 &_recurrent_to_input_outstage_res,
// Cell-to-input path. NOTE(review): the condition guarding it is not visible here.
490 _memory_group.
manage(&_mul_cell_to_input_res);
// The scale starts from 2^cell_shift; the remaining factors are outside this excerpt.
494 const float cell_to_input_scale =
495 std::pow(2, cell_shift) *
503 _memory_group.
manage(&_cell_to_input_outstage_res);
// Output stage for the cell-to-input product; nullptr is passed where a bias would go.
504 _cell_to_input_outstage.
configure(compile_context, &_mul_cell_to_input_res,
nullptr,
505 &_cell_to_input_outstage_res, gemmlowp_info);
// NOTE(review): unlike the other accumulate configure() calls in this file
// (e.g. _accumulate_cell_forget, _accumulate_cell_to_output), this call does not
// pass compile_context as its first argument — confirm that the overload without
// a compile context is intended here.
507 _accumulate_cell_input.
configure(&_recurrent_to_input_outstage_res, &_cell_to_input_outstage_res,
// The activation input defaults to the accumulated result and is redirected to
// the layer-norm output after layer norm is configured (condition not visible here).
512 CLTensor *input_activation_input = &_recurrent_to_input_outstage_res;
516 configure_layer_norm(LayerNormGate::Input, &_recurrent_to_input_outstage_res);
518 input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
// The gate activation writes into _input_gate; the activation parameters are outside this excerpt.
521 _input_gate_sigmoid.
configure(compile_context, input_activation_input, &_input_gate,
// Multiply _forget_gate by cell_state_in with scale 1.f. _forget_gate is passed
// both as first operand and as third tensor argument, so it is presumably
// overwritten with the product.
527 _pixelwise_mul_forget_cell.
configure(compile_context, &_forget_gate, cell_state_in, &_forget_gate, 1.f,
// Scale for the input-gate * cell-gate product: cell_gate_scale * 2^(15 + cell_shift).
530 const float mul_input_cell_scale = cell_gate_scale * std::pow(2, 15 + cell_shift);
533 _memory_group.
manage(&_mul_input_cell_res);
// Multiply _input_gate by _cell_gate into _mul_input_cell_res with scale 1.f.
535 _pixelwise_mul_input_cell.
configure(compile_context, &_input_gate, &_cell_gate, &_mul_input_cell_res, 1.f,
// Combine _forget_gate and _mul_input_cell_res, with cell_state_out as third tensor argument.
538 _add_forget_cell.
configure(compile_context, &_forget_gate, &_mul_input_cell_res, cell_state_out,
// Optional clipping of cell_state_out, passing -quantized_cell_clip and
// +quantized_cell_clip; nullptr is passed as the second tensor argument.
542 if (_has_cell_clipping)
544 _cell_clip.
configure(compile_context, cell_state_out,
nullptr,
546 -quantized_cell_clip, quantized_cell_clip));
// Output gate, input path: multiply + output stage on `input`.
553 configure_mm(compile_context, _mm_input_to_output, _input_to_output_outstage, gemmlowp_info,
input,
554 &_input_to_output_weights_transposed, &_input_to_output_eff_bias, &_mm_input_to_output_res,
555 &_input_to_output_outstage_res, input_to_output_scale, mm_out_info, output_outstage_info);
// Output gate, recurrent path: multiply + output stage on output_state_in.
559 configure_mm(compile_context, _mm_recurrent_to_output, _recurrent_to_output_outstage, gemmlowp_info,
560 output_state_in, &_recurrent_to_output_weights_transposed, &_recurrent_to_output_eff_bias,
561 &_mm_recurrent_to_output_res, &_recurrent_to_output_outstage_res, recurrent_to_output_scale,
562 mm_out_info, output_outstage_info);
// Combine both paths. Here the recurrent result is the first operand and also
// the third tensor argument; the input result is the second operand.
564 _accumulate_input_recurrent_output.
configure(compile_context, &_recurrent_to_output_outstage_res,
565 &_input_to_output_outstage_res, &_recurrent_to_output_outstage_res,
// Cell-to-output path. NOTE(review): the condition guarding it is not visible here.
574 _memory_group.
manage(&_mul_cell_to_output_res);
// The scale starts from 2^cell_shift; the remaining factors are outside this excerpt.
579 const float cell_to_output_scale =
580 std::pow(2, cell_shift) *
588 _memory_group.
manage(&_cell_to_output_outstage_res);
// Output stage for the cell-to-output product; nullptr is passed where a bias would go.
589 _cell_to_output_outstage.
configure(compile_context, &_mul_cell_to_output_res,
nullptr,
590 &_cell_to_output_outstage_res, gemmlowp_info);
// Add the cell-to-output result onto the accumulated recurrent result.
593 _accumulate_cell_to_output.
configure(compile_context, &_recurrent_to_output_outstage_res,
594 &_cell_to_output_outstage_res, &_recurrent_to_output_outstage_res,
// The activation input defaults to the accumulated result and is redirected to
// the layer-norm output after layer norm is configured (condition not visible here).
599 CLTensor *output_activation_input = &_recurrent_to_output_outstage_res;
603 configure_layer_norm(LayerNormGate::Output, &_recurrent_to_output_outstage_res);
605 output_activation_input = &get_layer_norm_output(LayerNormGate::Output);
609 _memory_group.
manage(&_output_gate);
// The gate activation writes into _output_gate; the activation parameters are outside this excerpt.
611 _output_gate_sigmoid.
configure(compile_context, output_activation_input, &_output_gate,
// Activation over cell_state_out. _input_gate is passed as the second tensor
// argument, so that tensor appears to be reused to hold the result.
616 _hidden_tanh.
configure(compile_context, cell_state_out, &_input_gate,
619 _memory_group.
manage(&_hidden_mul_res);
// Multiply _output_gate by _input_gate (reused above) into _hidden_mul_res with scale 1.f.
622 _pixelwise_mul_hidden.
configure(compile_context, &_output_gate, &_input_gate, &_hidden_mul_res, 1.f,
// Evaluated left to right: (2^-15 / lstm_params.hidden_state_scale()) * 2^-15.
626 const float hidden_state_scale = std::pow(2, -15) / lstm_params.
hidden_state_scale() * std::pow(2, -15);
// A separate copy step is required whenever num_units differs from output_size.
632 _projection_tensor_copy_required = (num_units !=
output_size);
// The hidden output stage writes to output_state_out by default ...
633 ICLTensor *hidden_gate_result = output_state_out;
635 _memory_group.
manage(&_hidden_gate);
// ... and to the intermediate _hidden_gate tensor when a copy is required.
637 if (_projection_tensor_copy_required)
641 hidden_gate_result = &_hidden_gate;
// Output stage for the hidden product; nullptr is passed where a bias would go.
644 _hidden_outstage.
configure(compile_context, &_hidden_mul_res,
nullptr, hidden_gate_result, gemmlowp_info);
// The projection output-stage info is constructed from the info of output_state_out.
650 const TensorInfo projection_outstage_info(*output_state_out->
info());
// Mutable copy of mm_out_info for the projection multiply; later changes to it are not visible here.
658 TensorInfo projection_mm_out_info{mm_out_info};
// Projection: multiply hidden_gate_result by the transposed projection weights,
// then run the output stage with the projection effective bias.
661 configure_mm(compile_context, _mm_projection, _projection_outstage, gemmlowp_info, hidden_gate_result,
662 &_projection_weights_transposed, &_projection_eff_bias, &_mm_projection_res,
663 &_projection_outstage_res, projection_scale, projection_mm_out_info, projection_outstage_info);
// Accumulation targets output_state_out by default.
665 ICLTensor *accumulate_destination = output_state_out;
// When a copy is required, output_state_in is first copied into
// _projection_accumulate_res and accumulation targets that tensor instead.
667 if (_projection_tensor_copy_required)
672 _projection_output_to_accumulate_copy.configure(*output_state_in, _projection_accumulate_res);
673 accumulate_destination = &_projection_accumulate_res;
676 _accumulate_projection.
configure(compile_context, &_projection_outstage_res, accumulate_destination,
// Copy the accumulated result back into output_state_out when the intermediate tensor was used.
680 if (_projection_tensor_copy_required)
682 _projection_accumulate_to_output_copy.configure(_projection_accumulate_res, *output_state_out);
// Quantized projection clip, stored as int8_t and initialised to 0; the
// right-hand side of the assignment below is outside this excerpt.
686 int8_t quantized_projection_clip{0};
689 quantized_projection_clip =
// Projection clipping is configured, and flagged, only for a strictly positive
// clip value, passing -clip and +clip.
693 if (quantized_projection_clip > 0)
695 _projection_clip.
configure(compile_context, output_state_out,
nullptr,
697 -quantized_projection_clip, quantized_projection_clip));
698 _has_projection_clipping =
true;
// Copy _hidden_gate into output_state_out when a copy is required.
// NOTE(review): the enclosing branch structure is not visible here.
703 if (_projection_tensor_copy_required)
705 _hidden_to_output_copy.configure(_hidden_gate, *output_state_out);
// Finally output_state_out is copied into `output`.
711 _copy_output.
configure(compile_context, output_state_out, output);
// Fragments of a validation routine working on tensor infos. Most lines are
// trailing argument lists of calls whose heads are outside this excerpt.
734 cell_state_in, output_state_in, cell_state_out, output_state_out, output);
// Batch size is read from dimension 1 of the input info.
740 const unsigned int batch_size =
input->dimension(1);
// Output size is read from output_state_out at index _out_state_output_size_dimension_idx.
742 const unsigned int output_size = output_state_out->
dimension(_out_state_output_size_dimension_idx);
// cell_shift is log2 of qcell_state_in.scale, implicitly converted to int32_t
// (any fractional part is discarded).
801 const int32_t cell_shift = log2(qcell_state_in.
scale);
// Quantized clip value starts at 0; the assignment that may change it is not visible here.
805 int16_t quantized_cell_clip = 0;
// Forget gate: argument tails passing a scale, &mm_out_info and &forget_outstage_info.
898 input_to_forget_scale, &mm_out_info, &forget_outstage_info));
903 &eff_bias_info, recurrent_to_forget_scale, &mm_out_info,
904 &forget_outstage_info));
// The scale starts from 2^cell_shift; the remaining factors are outside this excerpt.
916 const float cell_to_forget_scale = std::pow(2, cell_shift) *
// Cell gate: argument tails passing a scale, &mm_out_info and &cell_outstage_info.
949 input_to_cell_scale, &mm_out_info, &cell_outstage_info));
954 &eff_bias_info, recurrent_to_cell_scale, &mm_out_info,
955 &cell_outstage_info));
// Message text of an error check whose condition is outside this excerpt.
977 "Input gate bias must not be present when CIFG is used");
// The input-to-input scale starts from the uniform quantization scale of the
// input-to-input weights taken from lstm_params.
996 const float input_to_input_scale = lstm_params.
input_to_input_weights()->quantization_info().uniform().scale *
999 input_to_input_scale, &mm_out_info, &input_outstage_info));
1001 const float recurrent_to_input_scale =
1005 &eff_bias_info, recurrent_to_input_scale, &mm_out_info,
1006 &input_outstage_info));
1016 const float cell_to_input_scale = std::pow(2, cell_shift) *
1035 &input_outstage_info, &input_gate_info,
// Cell clipping is validated only for a strictly positive clip value.
1045 if (quantized_cell_clip > 0)
1050 -quantized_cell_clip, quantized_cell_clip)));
// Output gate: argument tails passing a scale, &mm_out_info and &output_outstage_info.
1059 input_to_output_scale, &mm_out_info, &output_outstage_info));
1064 &eff_bias_info, recurrent_to_output_scale, &mm_out_info,
1065 &output_outstage_info));
// Evaluated left to right: (2^-15 / lstm_params.hidden_state_scale()) * 2^-15.
1106 const float hidden_state_scale = std::pow(2, -15) / lstm_params.
hidden_state_scale() * std::pow(2, -15);
// A separate copy step is required whenever num_units differs from output_size.
1115 const bool projection_tensor_copy_required = num_units !=
output_size;
// The projection output-stage info is constructed from output_state_out.
1133 const TensorInfo projection_outstage_info(*output_state_out);
1138 TensorInfo projection_mm_out_info{mm_out_info};
1142 &projection_eff_bias_info, projection_scale, &projection_mm_out_info,
1143 &projection_outstage_info));
// Copy-related checks; their bodies are outside this excerpt.
1145 if (projection_tensor_copy_required)
1154 if (projection_tensor_copy_required)
// Quantized projection clip, stored as int8_t and initialised to 0.
1160 int8_t quantized_projection_clip{0};
// Projection clipping is validated only for a strictly positive clip value.
1166 if (quantized_projection_clip > 0)
1169 output_state_out,
nullptr,
1171 -quantized_projection_clip, quantized_projection_clip)));
1176 if (projection_tensor_copy_required)
// Execution sequence: each configured function is run in turn. Bodies of the
// `if` statements and several guarding conditions are outside this excerpt.

// Forget gate: input multiply + output stage, recurrent multiply + output stage, accumulation.
1206 _mm_input_to_forget.
run();
1207 _input_to_forget_outstage.
run();
1209 _mm_recurrent_to_forget.
run();
1210 _recurrent_to_forget_outstage.
run();
1211 _accumulate_input_recurrent_forget.
run();
// Cell-to-forget path (guarding condition not visible here).
1215 _pixelwise_mul_cell_to_forget.
run();
1216 _cell_to_forget_outstage.
run();
1217 _accumulate_cell_forget.
run();
// Layer-norm step for this gate; its body is outside this excerpt.
1220 if (_has_layer_norm)
1225 _forget_gate_sigmoid.
run();
// Cell gate: input multiply + output stage, recurrent multiply + output stage, accumulation.
1228 _mm_input_to_cell.
run();
1229 _input_to_cell_outstage.
run();
1231 _mm_recurrent_to_cell.
run();
1232 _recurrent_to_cell_outstage.
run();
1233 _accumulate_input_recurrent_modulation.
run();
1235 if (_has_layer_norm)
1240 _cell_gate_tanh.
run();
// Input gate: _input_gate_sub and the multiply chain below both appear here.
// NOTE(review): the branch that selects between them, if any, is not visible in this excerpt.
1245 _input_gate_sub.
run();
1249 _mm_input_to_input.
run();
1250 _input_to_input_outstage.
run();
1251 _mm_recurrent_to_input.
run();
1252 _recurrent_to_input_outstage.
run();
1253 _accumulate_input_recurrent_input.
run();
// Cell-to-input path (guarding condition not visible here).
1257 _pixelwise_mul_cell_to_input.
run();
1258 _cell_to_input_outstage.
run();
1259 _accumulate_cell_input.
run();
1262 if (_has_layer_norm)
1267 _input_gate_sigmoid.
run();
// Cell state update: two element-wise multiplies followed by the add.
1271 _pixelwise_mul_forget_cell.
run();
1272 _pixelwise_mul_input_cell.
run();
1273 _add_forget_cell.
run();
// Cell clipping; its body is outside this excerpt.
1274 if (_has_cell_clipping)
// Output gate: input multiply + output stage, recurrent multiply + output stage, accumulation.
1280 _mm_input_to_output.
run();
1281 _input_to_output_outstage.
run();
1282 _mm_recurrent_to_output.
run();
1283 _recurrent_to_output_outstage.
run();
1284 _accumulate_input_recurrent_output.
run();
// Cell-to-output path (guarding condition not visible here).
1287 _pixelwise_mul_cell_to_output.
run();
1288 _cell_to_output_outstage.
run();
1289 _accumulate_cell_to_output.
run();
1292 if (_has_layer_norm)
1297 _output_gate_sigmoid.
run();
// Hidden state: element-wise multiply, then output stage.
1301 _pixelwise_mul_hidden.
run();
1302 _hidden_outstage.
run();
// Projection: multiply + output stage, optional copy into the accumulation
// tensor, accumulation, optional copy back, optional clipping.
1305 if (_has_projection)
1307 _mm_projection.
run();
1308 _projection_outstage.
run();
1310 if (_projection_tensor_copy_required)
1312 _projection_output_to_accumulate_copy.run();
1315 _accumulate_projection.
run();
1317 if (_projection_tensor_copy_required)
1319 _projection_accumulate_to_output_copy.run();
1322 if (_has_projection_clipping)
1324 _projection_clip.
run();
// Copy from the hidden tensor to the output state when a copy is required.
// NOTE(review): the enclosing branch structure is not visible here.
1329 if (_projection_tensor_copy_required)
1331 _hidden_to_output_copy.run();
// One-off preparation sequence that ends by setting _is_prepared.
// NOTE(review): the guard that skips it on later calls is not visible in this excerpt.

// Run the weight transposes for the forget, cell and output gates.
1350 _transpose_input_to_forget_weights.
run();
1351 _transpose_input_to_cell_weights.
run();
1352 _transpose_input_to_output_weights.
run();
1353 _transpose_recurrent_to_forget_weights.
run();
1354 _transpose_recurrent_to_cell_weights.
run();
1355 _transpose_recurrent_to_output_weights.
run();
// Fill the _ones buffer, viewed as int16_t elements; the count and value
// arguments are outside this excerpt.
1361 std::fill_n(
reinterpret_cast<int16_t *
>(_ones.
buffer()),
// Closing entries of brace-initialised lists that pair ACL_DST with an
// effective-bias tensor; the list heads and their consumers are not visible here.
1371 {
ACL_DST, &_input_to_input_eff_bias}};
1375 {
ACL_DST, &_recurrent_to_input_eff_bias}};
// Run the weight transposes for the input gate.
1380 _transpose_input_to_input_weights.
run();
1381 _transpose_recurrent_to_input_weights.
run();
1393 {
ACL_DST, &_input_to_forget_eff_bias}};
1397 {
ACL_DST, &_recurrent_to_forget_eff_bias}};
1404 {
ACL_DST, &_recurrent_to_cell_eff_bias}};
1408 {
ACL_DST, &_input_to_output_eff_bias}};
1412 {
ACL_DST, &_recurrent_to_output_eff_bias}};
// Projection-only preparation: optional bias add and the projection weights transpose.
1415 if (_has_projection)
1420 if (_projection_bias !=
nullptr)
1422 _projection_bias_add.
run();
1427 _transpose_projection_weights.
run();
// Branch taken when no projection copy is required; its body is outside this excerpt.
1430 if (!_projection_tensor_copy_required)
// Mark preparation as done.
1446 _is_prepared =
true;