43 : _memory_group(std::move(memory_manager)),
44 _fully_connected_input_gate(),
46 _subtract_input_gate(),
47 _pixelwise_mul_input_gate(),
48 _activation_input_gate(),
49 _fully_connected_forget_gate(),
50 _accum_forget_gate1(),
51 _pixelwise_mul_forget_gate(),
52 _activation_forget_gate(),
53 _fully_connected_cell_state(),
55 _transpose_cell_state(),
58 _pixelwise_mul_cell_state1(),
59 _activation_cell_state(),
61 _pixelwise_mul_cell_state2(),
62 _fully_connected_output(),
63 _pixelwise_mul_output_state1(),
66 _activation_output_state(),
67 _pixelwise_mul_output_state2(),
68 _fully_connected_output_state(),
72 _concat_scratch_buffer(),
73 _concat_inputs_forget_gate(),
74 _concat_weights_forget_gate(),
75 _concat_weights_input_gate(),
76 _concat_weights_output(),
77 _mean_std_norm_input_gate(),
78 _pixelwise_mul_input_gate_coeff(),
79 _accum_input_gate_bias(),
80 _mean_std_norm_forget_gate(),
81 _pixelwise_mul_forget_gate_coeff(),
82 _accum_forget_gate_bias(),
83 _mean_std_norm_cell_gate(),
84 _pixelwise_mul_cell_gate_coeff(),
85 _accum_cell_gate_bias(),
86 _mean_std_norm_output_gate(),
87 _pixelwise_mul_output_gate_coeff(),
88 _accum_output_gate_bias(),
108 _cell_state_activation(),
111 _input_layer_norm_out1(),
112 _input_layer_norm_out2(),
113 _forget_layer_norm_out1(),
114 _forget_layer_norm_out2(),
115 _cell_layer_norm_out1(),
116 _cell_layer_norm_out2(),
117 _output_layer_norm_out1(),
118 _output_layer_norm_out2(),
119 _run_peephole_opt(false),
120 _run_cifg_opt(false),
121 _perform_cell_clipping(false),
122 _has_projection_weights(false),
123 _perform_projection_clipping(false),
125 _is_layer_norm_lstm(false)
139 const ITensor *output_state_in,
147 float cell_threshold,
148 float projection_threshold)
153 scratch_buffer, output_state_out, cell_state_out, output);
157 scratch_buffer, output_state_out, cell_state_out, output, lstm_params, activation_info,
158 cell_threshold, projection_threshold);
171 cell_state_in->
info(), scratch_buffer->
info(), output_state_out->
info(), cell_state_out->
info(), output->
info(),
172 lstm_params_info, activation_info, cell_threshold, projection_threshold));
184 std::vector<const ITensor *> inputs_vector;
185 inputs_vector.emplace_back(
input);
186 inputs_vector.emplace_back(output_state_in);
188 _memory_group.
manage(&_forget_gate_out2);
191 std::vector<const ITensor *> weights_vector;
198 _memory_group.
manage(&_forget_gate_out5);
199 _fully_connected_forget_gate.
configure(&_forget_gate_out2, &_forget_gate_out6,
201 _memory_group.
manage(&_forget_gate_out1);
202 _memory_group.
manage(&_forget_gate_out3);
205 Tensor *forget_gate_out = &_forget_gate_out5;
210 _run_peephole_opt =
true;
211 _memory_group.
manage(&_forget_gate_out4);
214 _accum_forget_gate1.
configure(&_forget_gate_out5, &_forget_gate_out4, &_forget_gate_out3,
218 forget_gate_out = &_forget_gate_out3;
224 if (_is_layer_norm_lstm)
228 _memory_group.
manage(&_forget_layer_norm_out1);
229 _memory_group.
manage(&_forget_layer_norm_out2);
230 _mean_std_norm_forget_gate.
configure(forget_gate_out);
239 forget_gate_out = &_forget_layer_norm_out2;
241 _activation_forget_gate.
configure(forget_gate_out,
nullptr,
250 Tensor *input_gate_out = &_input_gate_out1;
253 _memory_group.
manage(&_input_gate_out1);
257 _run_cifg_opt =
true;
264 std::vector<const ITensor *> lstm_weights;
270 _memory_group.
manage(&_input_gate_out1);
271 _memory_group.
manage(&_input_gate_out4);
273 _fully_connected_input_gate.
configure(&_forget_gate_out2, &_input_gate_out2,
277 input_gate_out = &_input_gate_out3;
279 if (_run_peephole_opt)
281 _memory_group.
manage(&_input_gate_out4);
284 _accum_input_gate1.
configure(&_input_gate_out3, &_input_gate_out4, &_input_gate_out1,
288 input_gate_out = &_input_gate_out1;
295 if (_is_layer_norm_lstm)
299 _memory_group.
manage(&_input_layer_norm_out1);
300 _memory_group.
manage(&_input_layer_norm_out2);
301 _mean_std_norm_input_gate.
configure(input_gate_out);
310 input_gate_out = &_input_layer_norm_out2;
312 _activation_input_gate.
configure(input_gate_out,
nullptr,
325 _memory_group.
manage(&_cell_state_out1);
328 _memory_group.
manage(&_cell_state_out2);
330 _memory_group.
manage(&_cell_state_out3);
331 _gemm_cell_state1.
configure(output_state_in, &_cell_state_out2,
nullptr, &_cell_state_out3, 1.f, 0.f);
333 _memory_group.
manage(&_cell_state_out4);
335 Tensor *cell_state_out_ptr = &_cell_state_out4;
336 if (_is_layer_norm_lstm)
340 _memory_group.
manage(&_cell_layer_norm_out1);
341 _memory_group.
manage(&_cell_layer_norm_out2);
342 _mean_std_norm_cell_gate.
configure(cell_state_out_ptr);
348 _accum_cell_gate_bias.
configure(&_cell_layer_norm_out1, cell_bias, &_cell_layer_norm_out2,
351 cell_state_out_ptr = &_cell_layer_norm_out2;
353 _activation_cell_state.
configure(cell_state_out_ptr,
nullptr, activation_info);
354 _memory_group.
manage(&_cell_state_out5);
355 _pixelwise_mul_cell_state1.
configure(cell_state_out_ptr, input_gate_out, &_cell_state_out5, 1,
364 if (cell_threshold != 0.f)
366 _perform_cell_clipping =
true;
367 _cell_clip.
configure(&_cell_state_out1,
nullptr,
369 cell_threshold, -cell_threshold));
379 std::vector<const ITensor *> in_out_weights;
384 _memory_group.
manage(&_output1);
385 _memory_group.
manage(&_output4);
393 Tensor *output_gate_out = &_output4;
398 _memory_group.
manage(&_output3);
403 output_gate_out = &_output1;
412 if (_is_layer_norm_lstm)
416 _memory_group.
manage(&_output_layer_norm_out1);
417 _memory_group.
manage(&_output_layer_norm_out2);
418 _mean_std_norm_output_gate.
configure(output_gate_out);
427 output_gate_out = &_output_layer_norm_out2;
429 _activation_output.
configure(output_gate_out,
nullptr,
445 _memory_group.
manage(&_cell_state_activation);
446 _activation_output_state.
configure(&_cell_state_out1, &_cell_state_activation, activation_info);
447 _pixelwise_mul_output_state2.
configure(&_cell_state_activation, output_gate_out, output_state_out_tmp, 1,
454 _has_projection_weights =
true;
459 if (projection_threshold != 0.f)
461 _perform_projection_clipping =
true;
462 _projection_clip.
configure(output_state_out,
nullptr,
464 -projection_threshold, projection_threshold));
469 _copy_cell_state.
configure(&_cell_state_out1, cell_state_out);
470 _copy_output.
configure(output_state_out, output);
473 std::vector<const ITensor *> scratch_inputs;
476 scratch_inputs.emplace_back(input_gate_out);
478 scratch_inputs.emplace_back(&_cell_state_out1);
479 scratch_inputs.emplace_back(forget_gate_out);
480 scratch_inputs.emplace_back(output_gate_out);
506 float cell_threshold,
507 float projection_threshold)
512 output_state_in, cell_state_in, scratch_buffer, output_state_out, cell_state_out, output);
519 output_state_in, cell_state_in, scratch_buffer, output_state_out, cell_state_out, output);
541 const unsigned int num_batches =
input->dimension(1);
591 std::vector<const ITensorInfo *> inputs_vector;
592 inputs_vector.emplace_back(
input);
593 inputs_vector.emplace_back(output_state_in);
620 &forget_gate, &forget_gate,
ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
631 std::vector<const ITensorInfo *> lstm_weights;
663 &input_gate,
nullptr,
ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
694 if (cell_threshold != 0.f)
699 cell_threshold, -cell_threshold)));
703 std::vector<const ITensorInfo *> in_out_weights;
732 &output_gate_tmp,
nullptr,
ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
742 if (projection_threshold != 0.f)
745 output_state_out, output_state_out,
746 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold,
747 projection_threshold)));
756 std::vector<const ITensorInfo *> inputs_vector_info_raw;
759 inputs_vector_info_raw.push_back(&input_gate);
761 inputs_vector_info_raw.push_back(&cell_state_tmp);
762 inputs_vector_info_raw.push_back(&forget_gate);
763 inputs_vector_info_raw.push_back(&output_gate_tmp);
775 _concat_inputs_forget_gate.
run();
776 _fully_connected_forget_gate.
run();
778 if (_run_peephole_opt)
780 _pixelwise_mul_forget_gate.
run();
781 _accum_forget_gate1.
run();
783 if (_is_layer_norm_lstm)
785 _mean_std_norm_forget_gate.
run();
786 _pixelwise_mul_forget_gate_coeff.
run();
787 _accum_forget_gate_bias.
run();
789 _activation_forget_gate.
run();
795 std::fill_n(
reinterpret_cast<half *
>(_ones.
buffer()),
800 std::fill_n(
reinterpret_cast<float *
>(_ones.
buffer()),
803 _subtract_input_gate.
run();
807 _fully_connected_input_gate.
run();
809 if (_run_peephole_opt)
811 _pixelwise_mul_input_gate.
run();
812 _accum_input_gate1.
run();
815 if (_is_layer_norm_lstm)
817 _mean_std_norm_input_gate.
run();
818 _pixelwise_mul_input_gate_coeff.
run();
819 _accum_input_gate_bias.
run();
821 _activation_input_gate.
run();
824 _fully_connected_cell_state.
run();
825 _transpose_cell_state.
run();
826 _gemm_cell_state1.
run();
827 _accum_cell_state1.
run();
828 if (_is_layer_norm_lstm)
830 _mean_std_norm_cell_gate.
run();
831 _pixelwise_mul_cell_gate_coeff.
run();
832 _accum_cell_gate_bias.
run();
835 _activation_cell_state.
run();
836 _pixelwise_mul_cell_state1.
run();
837 _pixelwise_mul_cell_state2.
run();
838 _accum_cell_state2.
run();
840 if (_perform_cell_clipping)
845 _fully_connected_output.
run();
846 if (_run_peephole_opt)
848 _pixelwise_mul_output_state1.
run();
849 _accum_output1.
run();
851 if (_is_layer_norm_lstm)
853 _mean_std_norm_output_gate.
run();
854 _pixelwise_mul_output_gate_coeff.
run();
855 _accum_output_gate_bias.
run();
857 _activation_output.
run();
859 _activation_output_state.
run();
860 _pixelwise_mul_output_state2.
run();
862 if (_has_projection_weights)
864 _fully_connected_output_state.
run();
865 if (_perform_projection_clipping)
867 _projection_clip.
run();
871 _copy_cell_state.
run();
874 _concat_scratch_buffer.
run();
881 _concat_weights_forget_gate.
run();
884 _concat_weights_input_gate.
run();
886 _concat_weights_output.
run();