43 : _memory_group(std::move(memory_manager)),
44 _fully_connected_input_gate(),
46 _subtract_input_gate(),
47 _pixelwise_mul_input_gate(),
48 _activation_input_gate(),
49 _fully_connected_forget_gate(),
50 _accum_forget_gate1(),
51 _pixelwise_mul_forget_gate(),
52 _activation_forget_gate(),
53 _fully_connected_cell_state(),
58 _pixelwise_mul_cell_state1(),
59 _activation_cell_state(),
61 _pixelwise_mul_cell_state2(),
62 _fully_connected_output(),
63 _pixelwise_mul_output_state1(),
66 _activation_output_state(),
67 _pixelwise_mul_output_state2(),
68 _fully_connected_output_state(),
72 _concat_scratch_buffer(),
73 _concat_inputs_forget_gate(),
74 _concat_weights_forget_gate(),
75 _concat_weights_input_gate(),
76 _concat_weights_output(),
78 _mean_std_norm_input_gate(),
79 _pixelwise_mul_input_gate_coeff(),
80 _accum_input_gate_bias(),
81 _mean_std_norm_forget_gate(),
82 _pixelwise_mul_forget_gate_coeff(),
83 _accum_forget_gate_bias(),
84 _mean_std_norm_cell_gate(),
85 _pixelwise_mul_cell_gate_coeff(),
86 _accum_cell_gate_bias(),
87 _mean_std_norm_output_gate(),
88 _pixelwise_mul_output_gate_coeff(),
89 _accum_output_gate_bias(),
109 _cell_state_activation(),
112 _input_layer_norm_out1(),
113 _input_layer_norm_out2(),
114 _forget_layer_norm_out1(),
115 _forget_layer_norm_out2(),
116 _cell_layer_norm_out1(),
117 _cell_layer_norm_out2(),
118 _output_layer_norm_out1(),
119 _output_layer_norm_out2(),
120 _run_peephole_opt(false),
121 _run_cifg_opt(false),
122 _perform_cell_clipping(false),
123 _has_projection_weights(false),
124 _perform_projection_clipping(false),
126 _is_layer_norm_lstm(false)
150 float cell_threshold,
151 float projection_threshold)
156 cell_state_in, scratch_buffer, output_state_out, cell_state_out, output, lstm_params, activation_info,
157 cell_threshold, projection_threshold);
179 float cell_threshold,
180 float projection_threshold)
185 scratch_buffer, output_state_out, cell_state_out, output);
190 scratch_buffer, output_state_out, cell_state_out, output, lstm_params, activation_info,
191 cell_threshold, projection_threshold);
204 cell_state_in->
info(), scratch_buffer->
info(), output_state_out->
info(), cell_state_out->
info(), output->
info(),
205 lstm_params_info, activation_info, cell_threshold, projection_threshold));
216 std::vector<const ICLTensor *> inputs_vector;
217 inputs_vector.emplace_back(
input);
218 inputs_vector.emplace_back(output_state_in);
222 _memory_group.
manage(&_forget_gate_out2);
223 _concat_inputs_forget_gate.
configure(compile_context, inputs_vector, &_forget_gate_out2,
Window::DimX);
225 std::vector<const ICLTensor *> weights_vector;
233 _concat_weights_forget_gate.
configure(compile_context, weights_vector, &_forget_gate_out6,
Window::DimX);
235 _memory_group.
manage(&_forget_gate_out5);
236 _fully_connected_forget_gate.
configure(compile_context, &_forget_gate_out2, &_forget_gate_out6,
238 _memory_group.
manage(&_forget_gate_out1);
239 _memory_group.
manage(&_forget_gate_out3);
242 CLTensor *forget_gate_out = &_forget_gate_out5;
247 _run_peephole_opt =
true;
248 _memory_group.
manage(&_forget_gate_out4);
252 _accum_forget_gate1.
configure(compile_context, &_forget_gate_out5, &_forget_gate_out4, &_forget_gate_out3,
256 forget_gate_out = &_forget_gate_out3;
262 if (_is_layer_norm_lstm)
266 _memory_group.
manage(&_forget_layer_norm_out1);
267 _memory_group.
manage(&_forget_layer_norm_out2);
268 _mean_std_norm_forget_gate.
configure(compile_context, forget_gate_out);
269 _pixelwise_mul_forget_gate_coeff.
configure(compile_context, forget_gate_out,
277 forget_gate_out = &_forget_layer_norm_out2;
279 _activation_forget_gate.
configure(compile_context, forget_gate_out,
nullptr,
288 CLTensor *input_gate_out = &_input_gate_out1;
291 _memory_group.
manage(&_input_gate_out1);
294 _subtract_input_gate.
configure(compile_context, &_ones, forget_gate_out, &_input_gate_out1,
297 _run_cifg_opt =
true;
304 std::vector<const ICLTensor *> lstm_weights;
311 _concat_weights_input_gate.
configure(compile_context, lstm_weights, &_input_gate_out2,
Window::DimX);
313 _memory_group.
manage(&_input_gate_out1);
315 _memory_group.
manage(&_input_gate_out3);
316 _fully_connected_input_gate.
configure(compile_context, &_forget_gate_out2, &_input_gate_out2,
321 input_gate_out = &_input_gate_out3;
322 if (_run_peephole_opt)
324 _memory_group.
manage(&_input_gate_out4);
328 _accum_input_gate1.
configure(compile_context, &_input_gate_out3, &_input_gate_out4, &_input_gate_out1,
332 input_gate_out = &_input_gate_out1;
339 if (_is_layer_norm_lstm)
343 _memory_group.
manage(&_input_layer_norm_out1);
344 _memory_group.
manage(&_input_layer_norm_out2);
345 _mean_std_norm_input_gate.
configure(compile_context, input_gate_out);
346 _pixelwise_mul_input_gate_coeff.
configure(compile_context, input_gate_out,
354 input_gate_out = &_input_layer_norm_out2;
356 _activation_input_gate.
configure(compile_context, input_gate_out,
nullptr,
369 _memory_group.
manage(&_cell_state_out1);
371 (_is_layer_norm_lstm) ?
nullptr : cell_bias, &_cell_state_out1);
372 _memory_group.
manage(&_cell_state_out2);
375 _memory_group.
manage(&_cell_state_out3);
376 _gemm_cell_state1.
configure(compile_context, output_state_in, &_cell_state_out2,
nullptr, &_cell_state_out3, 1.f,
379 _memory_group.
manage(&_cell_state_out4);
380 _accum_cell_state1.
configure(compile_context, &_cell_state_out1, &_cell_state_out3, &_cell_state_out4,
382 CLTensor *cell_state_out_ptr = &_cell_state_out4;
383 if (_is_layer_norm_lstm)
387 _memory_group.
manage(&_cell_layer_norm_out1);
388 _memory_group.
manage(&_cell_layer_norm_out2);
389 _mean_std_norm_cell_gate.
configure(compile_context, cell_state_out_ptr);
390 _pixelwise_mul_cell_gate_coeff.
configure(compile_context, cell_state_out_ptr,
395 _accum_cell_gate_bias.
configure(compile_context, &_cell_layer_norm_out1, cell_bias, &_cell_layer_norm_out2,
398 cell_state_out_ptr = &_cell_layer_norm_out2;
400 _activation_cell_state.
configure(compile_context, cell_state_out_ptr,
nullptr, activation_info);
401 _memory_group.
manage(&_cell_state_out5);
402 _pixelwise_mul_cell_state1.
configure(compile_context, cell_state_out_ptr, input_gate_out, &_cell_state_out5, 1,
405 _pixelwise_mul_cell_state2.
configure(compile_context, forget_gate_out, cell_state_in, &_cell_state_out3, 1,
407 _accum_cell_state2.
configure(compile_context, &_cell_state_out5, &_cell_state_out3, &_cell_state_out1,
412 if (cell_threshold != 0.f)
414 _perform_cell_clipping =
true;
415 _cell_clip.
configure(compile_context, &_cell_state_out1,
nullptr,
417 cell_threshold, -cell_threshold));
426 std::vector<const ICLTensor *> in_out_weights;
435 _memory_group.
manage(&_output1);
436 _memory_group.
manage(&_output4);
438 _fully_connected_output.
configure(compile_context, &_forget_gate_out2, &_output2,
444 CLTensor *output_gate_out = &_output4;
449 _memory_group.
manage(&_output3);
454 output_gate_out = &_output1;
463 if (_is_layer_norm_lstm)
467 _memory_group.
manage(&_output_layer_norm_out1);
468 _memory_group.
manage(&_output_layer_norm_out2);
469 _mean_std_norm_output_gate.
configure(compile_context, output_gate_out);
470 _pixelwise_mul_output_gate_coeff.
configure(compile_context, output_gate_out,
478 output_gate_out = &_output_layer_norm_out2;
480 _activation_output.
configure(compile_context, output_gate_out,
nullptr,
496 _memory_group.
manage(&_cell_state_activation);
497 _activation_output_state.
configure(compile_context, &_cell_state_out1, &_cell_state_activation, activation_info);
498 _pixelwise_mul_output_state2.
configure(compile_context, &_cell_state_activation, output_gate_out,
505 _has_projection_weights =
true;
510 if (projection_threshold != 0.f)
512 _perform_projection_clipping =
true;
513 _projection_clip.
configure(compile_context, output_state_out,
nullptr,
515 -projection_threshold, projection_threshold));
520 _copy_cell_state.
configure(compile_context, &_cell_state_out1, cell_state_out);
521 _copy_output.
configure(compile_context, output_state_out, output);
524 std::vector<const ICLTensor *> scratch_inputs;
527 scratch_inputs.emplace_back(input_gate_out);
529 scratch_inputs.emplace_back(&_cell_state_out1);
530 scratch_inputs.emplace_back(forget_gate_out);
531 scratch_inputs.emplace_back(output_gate_out);
557 float cell_threshold,
558 float projection_threshold)
563 output_state_in, cell_state_in, scratch_buffer, output_state_out, cell_state_out, output);
570 output_state_in, cell_state_in, scratch_buffer, output_state_out, cell_state_out, output);
592 const unsigned int num_batches =
input->dimension(1);
646 std::vector<const ITensorInfo *> inputs_vector;
647 inputs_vector.emplace_back(
input);
648 inputs_vector.emplace_back(output_state_in);
672 &forget_gate, &forget_gate,
ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
683 std::vector<const ITensorInfo *> lstm_weights;
716 &input_gate,
nullptr,
ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
747 if (cell_threshold != 0.f)
752 cell_threshold, -cell_threshold)));
755 std::vector<const ITensorInfo *> in_out_weights;
784 &output_gate_tmp,
nullptr,
ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
795 if (projection_threshold != 0.f)
798 output_state_out, output_state_out,
799 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold,
800 projection_threshold)));
809 std::vector<const ITensorInfo *> inputs_vector_info_raw;
812 inputs_vector_info_raw.push_back(&input_gate);
814 inputs_vector_info_raw.push_back(&cell_state_tmp);
815 inputs_vector_info_raw.push_back(&forget_gate);
816 inputs_vector_info_raw.push_back(&output_gate_tmp);
828 _concat_inputs_forget_gate.
run();
830 _fully_connected_forget_gate.
run();
832 if (_run_peephole_opt)
834 _pixelwise_mul_forget_gate.
run();
835 _accum_forget_gate1.
run();
837 if (_is_layer_norm_lstm)
839 _mean_std_norm_forget_gate.
run();
840 _pixelwise_mul_forget_gate_coeff.
run();
841 _accum_forget_gate_bias.
run();
843 _activation_forget_gate.
run();
848 _subtract_input_gate.
run();
852 _fully_connected_input_gate.
run();
854 if (_run_peephole_opt)
856 _pixelwise_mul_input_gate.
run();
857 _accum_input_gate1.
run();
860 if (_is_layer_norm_lstm)
862 _mean_std_norm_input_gate.
run();
863 _pixelwise_mul_input_gate_coeff.
run();
864 _accum_input_gate_bias.
run();
866 _activation_input_gate.
run();
869 _fully_connected_cell_state.
run();
874 _gemm_cell_state1.
run();
875 _accum_cell_state1.
run();
876 if (_is_layer_norm_lstm)
878 _mean_std_norm_cell_gate.
run();
879 _pixelwise_mul_cell_gate_coeff.
run();
880 _accum_cell_gate_bias.
run();
882 _activation_cell_state.
run();
883 _pixelwise_mul_cell_state1.
run();
884 _pixelwise_mul_cell_state2.
run();
885 _accum_cell_state2.
run();
887 if (_perform_cell_clipping)
892 _fully_connected_output.
run();
894 if (_run_peephole_opt)
896 _pixelwise_mul_output_state1.
run();
897 _accum_output1.
run();
899 if (_is_layer_norm_lstm)
901 _mean_std_norm_output_gate.
run();
902 _pixelwise_mul_output_gate_coeff.
run();
903 _accum_output_gate_bias.
run();
905 _activation_output.
run();
907 _activation_output_state.
run();
908 _pixelwise_mul_output_state2.
run();
910 if (_has_projection_weights)
912 _fully_connected_output_state.
run();
913 if (_perform_projection_clipping)
915 _projection_clip.
run();
919 _copy_cell_state.
run();
922 _concat_scratch_buffer.
run();
929 _concat_weights_forget_gate.
run();
932 _concat_weights_input_gate.
run();
934 _concat_weights_output.
run();