24 #ifndef ARM_COMPUTE_NEQLSTMLAYER_H
25 #define ARM_COMPUTE_NEQLSTMLAYER_H
/** Forward declaration. In the visible part of this header the type is only held through std::unique_ptr. */
47 class NEQLSTMLayerNormalizationKernel;
/** Constructor.
 *
 * @param[in] memory_manager Shared pointer to an IMemoryManager. Defaults to nullptr when the argument is omitted.
 *                           NOTE(review): how the manager is used is decided by the implementation, which is not visible here.
 */
73 NEQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager =
nullptr);
/** Scoped enumeration with uint8_t as its underlying type.
 *  Its enumerators are not visible in this excerpt; a Count enumerator is referenced below.
 */
214 enum class LayerNormGate : uint8_t
/** Number of per-gate layer-normalization slots: LayerNormGate::Count converted to uint8_t.
 *  Used as the size of the per-gate std::array members of this class.
 */
222 static constexpr uint8_t _layer_norm_count =
static_cast<uint8_t
>(LayerNormGate::Count);
/** Dimension index constant fixed to 0.
 *  NOTE(review): presumably the tensor dimension holding the output size of the output state - confirm against the implementation.
 */
223 static constexpr uint32_t _out_state_output_size_dimension_idx = 0;
247 float gemmlowp_scale,
/** Helper class used as the type of the *_copy members of this class.
 *  Only its max_dimension_supported constant is visible in this excerpt.
 */
254 class TensorCopyKernel
/** Constant fixed to 2.
 *  NOTE(review): presumably the highest tensor rank the copy supports - confirm against the kernel's validation code.
 */
256 static constexpr uint32_t max_dimension_supported = 2;
// NETranspose function objects held by value.
// NOTE(review): going by the names, these presumably transpose the recurrent-to-forget and
// recurrent-to-output weight matrices - confirm in the implementation.
292 NETranspose _transpose_recurrent_to_forget_weights;
294 NETranspose _transpose_recurrent_to_output_weights;
// Owning pointers (std::unique_ptr) to CpuGemmLowpMatrixAReductionKernel objects.
// No initializer is given here, so each pointer starts out empty (std::unique_ptr default construction)
// unless a constructor not visible here sets it.
// NOTE(review): presumably one reduction kernel per weight matrix named in the member - confirm.
297 std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _input_to_input_reduction;
298 std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _recurrent_to_input_reduction;
299 std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _input_to_forget_reduction;
300 std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _recurrent_to_forget_reduction;
301 std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _input_to_cell_reduction;
302 std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _recurrent_to_cell_reduction;
303 std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _input_to_output_reduction;
304 std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _recurrent_to_output_reduction;
305 std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _projection_reduction;
// TensorCopyKernel objects held by value.
// NOTE(review): the names suggest copies of the projection bias, of the projection output to and
// from an accumulation tensor, and of the hidden state to the output - confirm in the implementation.
353 TensorCopyKernel _projection_bias_copy;
354 TensorCopyKernel _projection_output_to_accumulate_copy;
355 TensorCopyKernel _projection_accumulate_to_output_copy;
356 TensorCopyKernel _hidden_to_output_copy;
// Fixed-size array of _layer_norm_count owning pointers to NEQLSTMLayerNormalizationKernel.
// Slots are addressed by gate through get_layer_norm().
358 std::array<std::unique_ptr<NEQLSTMLayerNormalizationKernel>, _layer_norm_count> _layer_norms;
// Pointers to const ITensor, each initialized to nullptr. They are raw pointers, so this class
// does not manage the pointees' lifetime through these members.
// NOTE(review): presumably set from the tensors handed to configure() - confirm.
363 const ITensor *_input_to_input_weights{
nullptr};
364 const ITensor *_recurrent_to_input_weights{
nullptr};
365 const ITensor *_projection_bias{
nullptr};
366 const ITensor *_input_to_forget_weights{
nullptr};
367 const ITensor *_input_to_cell_weights{
nullptr};
368 const ITensor *_input_to_output_weights{
nullptr};
369 const ITensor *_recurrent_to_forget_weights{
nullptr};
370 const ITensor *_recurrent_to_cell_weights{
nullptr};
371 const ITensor *_recurrent_to_output_weights{
nullptr};
372 const ITensor *_projection_weights{
nullptr};
// Per-gate arrays of const ITensor pointers with _layer_norm_count entries each. The empty braces
// value-initialize every element, i.e. every pointer starts as a null pointer.
// Entries are read and written through the get_/set_layer_norm_weight and get_/set_layer_norm_bias helpers.
373 std::array<const ITensor *, _layer_norm_count> _layer_norm_weights{};
374 std::array<const ITensor *, _layer_norm_count> _layer_norm_bias{};
/** Convert a LayerNormGate value to the index type used to address the per-gate arrays.
 *
 * @param[in] g Gate to convert.
 *
 * @return @p g converted to LayerNormIndexType with static_cast.
 */
377 inline LayerNormIndexType getGateIndex(LayerNormGate g)
379 return static_cast<LayerNormIndexType
>(g);
/** Store a layer-normalization weight tensor pointer for a gate.
 *
 * Writes @p t into _layer_norm_weights at the index returned by getGateIndex(g).
 * Only the pointer is stored; nothing is copied.
 *
 * @param[in] t Tensor pointer to store.
 * @param[in] g Gate whose slot is written.
 */
382 inline void set_layer_norm_weight(
const ITensor *
t, LayerNormGate g)
384 _layer_norm_weights[getGateIndex(g)] =
t;
/** Store a layer-normalization bias tensor pointer for a gate.
 *
 * Writes @p t into _layer_norm_bias at the index returned by getGateIndex(g).
 * Only the pointer is stored; nothing is copied.
 *
 * @param[in] t Tensor pointer to store.
 * @param[in] g Gate whose slot is written.
 */
387 inline void set_layer_norm_bias(
const ITensor *
t, LayerNormGate g)
389 _layer_norm_bias[getGateIndex(g)] =
t;
/** Read the layer-normalization weight tensor pointer stored for a gate.
 *
 * @param[in] g Gate whose slot is read.
 *
 * @return The entry of _layer_norm_weights at getGateIndex(g); a null pointer if the slot still holds its initial value.
 */
392 inline const ITensor *get_layer_norm_weight(LayerNormGate g)
394 return _layer_norm_weights[getGateIndex(g)];
/** Read the layer-normalization bias tensor pointer stored for a gate.
 *
 * @param[in] g Gate whose slot is read.
 *
 * @return The entry of _layer_norm_bias at getGateIndex(g); a null pointer if the slot still holds its initial value.
 */
397 inline const ITensor *get_layer_norm_bias(LayerNormGate g)
399 return _layer_norm_bias[getGateIndex(g)];
/** Access the layer-normalization kernel slot of a gate.
 *
 * @param[in] g Gate whose slot is accessed.
 *
 * @return Reference to the std::unique_ptr stored in _layer_norms at getGateIndex(g).
 *         Because it is a reference, the caller can both read and reassign the slot.
 */
402 inline std::unique_ptr<NEQLSTMLayerNormalizationKernel> &get_layer_norm(LayerNormGate g)
404 return _layer_norms[getGateIndex(g)];
/** Declaration only; the definition is not part of this excerpt.
 *
 * NOTE(review): presumably configures the layer-normalization kernel of gate @p g with @p in as its input - confirm in the implementation.
 *
 * @param[in] g  Gate selector.
 * @param[in] in Pointer to a const ITensor.
 */
407 void configure_layer_norm(LayerNormGate g,
const ITensor *in);
// Tensor objects owned by value; each is constructed with a nullptr argument.
// NOTE(review): the meaning of that argument is defined by Tensor's constructor, which is not visible here.
// NOTE(review): the _f32 / _symm8 suffixes suggest converted copies of the input-to-forget weights
// (see the _convert_input_to_forget_weights_to_qsymm8 flag) - confirm.
411 Tensor _input_to_forget_weights_f32{
nullptr};
412 Tensor _input_to_forget_weights_symm8{
nullptr};
// NOTE(review): presumably hold transposed copies of the correspondingly named weight tensors - confirm.
414 Tensor _input_to_forget_weights_transposed{
nullptr};
415 Tensor _input_to_cell_weights_transposed{
nullptr};
416 Tensor _input_to_output_weights_transposed{
nullptr};
417 Tensor _input_to_input_weights_transposed{
nullptr};
418 Tensor _recurrent_to_forget_weights_transposed{
nullptr};
419 Tensor _recurrent_to_cell_weights_transposed{
nullptr};
420 Tensor _recurrent_to_output_weights_transposed{
nullptr};
421 Tensor _recurrent_to_input_weights_transposed{
nullptr};
422 Tensor _projection_weights_transposed{
nullptr};
// Tensor objects owned by value; each is constructed with a nullptr argument.
// NOTE(review): "eff_bias" presumably means effective bias, likely filled from the matching
// *_reduction kernels - confirm in the implementation.
423 Tensor _input_to_input_eff_bias{
nullptr};
424 Tensor _recurrent_to_input_eff_bias{
nullptr};
425 Tensor _input_to_forget_eff_bias{
nullptr};
426 Tensor _recurrent_to_forget_eff_bias{
nullptr};
427 Tensor _input_to_cell_eff_bias{
nullptr};
428 Tensor _recurrent_to_cell_eff_bias{
nullptr};
429 Tensor _input_to_output_eff_bias{
nullptr};
430 Tensor _recurrent_to_output_eff_bias{
nullptr};
431 Tensor _projection_reduction_res{
nullptr};
432 Tensor _projection_eff_bias{
nullptr};
// Tensor objects owned by value; each is constructed with a nullptr argument.
// NOTE(review): by their names these look like forget-gate intermediates (matrix-multiply results,
// a cell-state multiplication result, output-stage results) plus the gate tensor itself - confirm.
433 Tensor _mm_input_to_forget_res{
nullptr};
434 Tensor _mm_recurrent_to_forget_res{
nullptr};
435 Tensor _mul_cell_to_forget_res{
nullptr};
436 Tensor _input_to_forget_outstage_res{
nullptr};
437 Tensor _cell_to_forget_outstage_res{
nullptr};
438 Tensor _recurrent_to_forget_outstage_res{
nullptr};
439 Tensor _forget_gate{
nullptr};
// Tensor objects owned by value; each is constructed with a nullptr argument.
// NOTE(review): by their names these look like cell-gate intermediates (matrix-multiply and
// output-stage results), the gate tensor, and an input*cell multiplication result - confirm.
440 Tensor _mm_input_to_cell_res{
nullptr};
441 Tensor _input_to_cell_outstage_res{
nullptr};
442 Tensor _mm_recurrent_to_cell_res{
nullptr};
443 Tensor _recurrent_to_cell_outstage_res{
nullptr};
444 Tensor _cell_gate{
nullptr};
445 Tensor _mul_input_cell_res{
nullptr};
// Tensor objects owned by value; each is constructed with a nullptr argument.
// NOTE(review): by their names these look like input-gate intermediates (matrix-multiply results,
// a cell-state multiplication result, output-stage results) plus the gate tensor itself - confirm.
446 Tensor _mm_input_to_input_res{
nullptr};
447 Tensor _input_to_input_outstage_res{
nullptr};
448 Tensor _mm_recurrent_to_input_res{
nullptr};
449 Tensor _mul_cell_to_input_res{
nullptr};
450 Tensor _cell_to_input_outstage_res{
nullptr};
451 Tensor _recurrent_to_input_outstage_res{
nullptr};
452 Tensor _input_gate{
nullptr};
// Tensor objects owned by value; each is constructed with a nullptr argument.
// NOTE(review): by their names these look like output-gate intermediates (matrix-multiply results,
// a cell-state multiplication result, output-stage results) plus the gate tensor itself - confirm.
453 Tensor _mm_input_to_output_res{
nullptr};
454 Tensor _input_to_output_outstage_res{
nullptr};
455 Tensor _mm_recurrent_to_output_res{
nullptr};
456 Tensor _mul_cell_to_output_res{
nullptr};
457 Tensor _cell_to_output_outstage_res{
nullptr};
458 Tensor _recurrent_to_output_outstage_res{
nullptr};
459 Tensor _output_gate{
nullptr};
// Tensor objects owned by value; each is constructed with a nullptr argument.
// NOTE(review): by their names these look like hidden-state and projection intermediates - confirm.
460 Tensor _hidden_mul_res{
nullptr};
461 Tensor _hidden_gate{
nullptr};
462 Tensor _mm_projection_res{
nullptr};
463 Tensor _projection_outstage_res{
nullptr};
464 Tensor _projection_out_res{
nullptr};
465 Tensor _projection_accumulate_res{
nullptr};
// Per-gate array of _layer_norm_count Tensor objects, addressed through get_layer_norm_output().
// The empty braces value-initialize each element (no nullptr argument is passed here, unlike the members above).
467 std::array<Tensor, _layer_norm_count> _layer_norm_output{};
/** Access the layer-normalization output tensor of a gate.
 *
 * @param[in] g Gate whose tensor is accessed.
 *
 * @return Mutable reference to the entry of _layer_norm_output at getGateIndex(g).
 */
469 inline Tensor &get_layer_norm_output(LayerNormGate g)
471 return _layer_norm_output[getGateIndex(g)];
// Boolean state flags, all initialized to false. Where they are set is not visible in this excerpt.
// NOTE(review): _is_prepared presumably records that one-off preparation has already run - confirm.
474 bool _is_prepared{
false};
// NOTE(review): the _has_* flags presumably record which optional features (CIFG, cell clipping,
// projection, projection clipping, peephole, layer normalization) were enabled at configuration - confirm.
475 bool _has_cifg{
false};
476 bool _has_cell_clipping{
false};
477 bool _has_projection{
false};
478 bool _has_projection_clipping{
false};
479 bool _has_peephole{
false};
480 bool _has_layer_norm{
false};
// NOTE(review): presumably selects whether the TensorCopyKernel members are used for projection - confirm.
481 bool _projection_tensor_copy_required{
false};
// NOTE(review): presumably selects whether the input-to-forget weights are converted to QSYMM8
// (see _input_to_forget_weights_f32 / _input_to_forget_weights_symm8) - confirm.
482 bool _convert_input_to_forget_weights_to_qsymm8{
false};