24 #ifndef ARM_COMPUTE_CLQLSTMLAYER_H
25 #define ARM_COMPUTE_CLQLSTMLAYER_H
// Forward declarations: these only introduce the class names; the definitions are not part of this excerpt.
40 class CLCompileContext;
42 class CLQLSTMLayerNormalizationKernel;
// Referred to further down as opencl::kernels::ClGemmLowpMatrixAReductionKernel, so this declaration
// presumably sits inside those namespaces (the namespace lines are not visible here - TODO confirm).
48 class ClGemmLowpMatrixAReductionKernel;
// Constructor. The memory manager argument is optional and defaults to nullptr when omitted.
// NOTE(review): how a null manager is handled is decided in the definition, which is not visible here.
69 CLQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager =
nullptr);
270 enum class LayerNormGate : uint8_t
// Number of layer-normalisation slots: the value of LayerNormGate::Count converted to uint8_t.
// It is used as the extent of every per-gate std::array member declared in this class.
278 static constexpr uint8_t _layer_norm_count =
static_cast<uint8_t
>(LayerNormGate::Count);
// Fixed dimension index (0). Going by the name this is presumably the dimension of the output-state
// tensor that holds the output size - its use is not visible in this excerpt, TODO confirm.
279 static constexpr uint32_t _out_state_output_size_dimension_idx = 0;
305 float gemmlowp_scale,
312 class TensorCopyKernel
314 static constexpr uint32_t max_dimension_supported = 2;
// Four value-initialised CLTranspose members, named after the recurrent-to-{forget, cell, output, input}
// weights. NOTE(review): presumably each one transposes the matching weight tensor into its
// *_transposed counterpart; the configure/run calls are not visible here - confirm.
345 CLTranspose _transpose_recurrent_to_forget_weights{};
346 CLTranspose _transpose_recurrent_to_cell_weights{};
347 CLTranspose _transpose_recurrent_to_output_weights{};
348 CLTranspose _transpose_recurrent_to_input_weights{};
// Owning pointers to ClGemmLowpMatrixAReductionKernel objects. No initialiser is given, so each
// std::unique_ptr starts out null; where the kernels are allocated is not visible in this excerpt.
// There is one member per input-to-gate and recurrent-to-gate pairing (input, forget, cell, output)
// plus one for the projection.
350 std::unique_ptr<opencl::kernels::ClGemmLowpMatrixAReductionKernel> _input_to_input_reduction;
351 std::unique_ptr<opencl::kernels::ClGemmLowpMatrixAReductionKernel> _recurrent_to_input_reduction;
352 std::unique_ptr<opencl::kernels::ClGemmLowpMatrixAReductionKernel> _input_to_forget_reduction;
353 std::unique_ptr<opencl::kernels::ClGemmLowpMatrixAReductionKernel> _recurrent_to_forget_reduction;
354 std::unique_ptr<opencl::kernels::ClGemmLowpMatrixAReductionKernel> _input_to_cell_reduction;
355 std::unique_ptr<opencl::kernels::ClGemmLowpMatrixAReductionKernel> _recurrent_to_cell_reduction;
356 std::unique_ptr<opencl::kernels::ClGemmLowpMatrixAReductionKernel> _input_to_output_reduction;
357 std::unique_ptr<opencl::kernels::ClGemmLowpMatrixAReductionKernel> _recurrent_to_output_reduction;
358 std::unique_ptr<opencl::kernels::ClGemmLowpMatrixAReductionKernel> _projection_reduction;
// Layer-normalisation kernels, _layer_norm_count slots indexed through getGateIndex().
// The std::unique_ptr elements are default-constructed (null) until something assigns them.
405 std::array<std::unique_ptr<CLQLSTMLayerNormalizationKernel>, _layer_norm_count> _layer_norms;
// Value-initialised TensorCopyKernel helpers (the nested class declared earlier in this header).
// NOTE(review): the names suggest copies of the projection bias, between the projection output and
// an accumulation buffer in both directions, and from the hidden state to the output - the call
// sites are not visible here, confirm.
408 TensorCopyKernel _projection_bias_copy{};
409 TensorCopyKernel _projection_output_to_accumulate_copy{};
410 TensorCopyKernel _projection_accumulate_to_output_copy{};
411 TensorCopyKernel _hidden_to_output_copy{};
// Raw pointers to const ICLTensor objects, all initialised to nullptr.
// NOTE(review): presumably non-owning references to caller-provided weight/bias tensors that are
// stored during configuration - the assignments are not visible in this excerpt, confirm.
414 const ICLTensor *_input_to_input_weights{
nullptr};
415 const ICLTensor *_recurrent_to_input_weights{
nullptr};
416 const ICLTensor *_projection_bias{
nullptr};
417 const ICLTensor *_input_to_forget_weights{
nullptr};
418 const ICLTensor *_input_to_cell_weights{
nullptr};
419 const ICLTensor *_input_to_output_weights{
nullptr};
420 const ICLTensor *_recurrent_to_forget_weights{
nullptr};
421 const ICLTensor *_recurrent_to_cell_weights{
nullptr};
422 const ICLTensor *_recurrent_to_output_weights{
nullptr};
423 const ICLTensor *_projection_weights{
nullptr};
// Per-gate layer-norm weight and bias pointers, _layer_norm_count entries each.
// The {{}} initialiser value-initialises every element, i.e. every pointer starts as nullptr.
424 std::array<const ICLTensor *, _layer_norm_count> _layer_norm_weights{{}};
425 std::array<const ICLTensor *, _layer_norm_count> _layer_norm_bias{{}};
// Converts a LayerNormGate enumerator to its numeric value as LayerNormIndexType.
// The result is used as the index into the per-gate std::array members.
428 inline LayerNormIndexType getGateIndex(LayerNormGate g)
430 return static_cast<LayerNormIndexType
>(g);
// Stores the layer-norm weight tensor pointer t in the slot belonging to gate g.
// Uses unchecked operator[]: g must map to an index below _layer_norm_count.
433 inline void set_layer_norm_weight(
const ICLTensor *
t, LayerNormGate g)
435 _layer_norm_weights[getGateIndex(g)] =
t;
// Stores the layer-norm bias tensor pointer t in the slot belonging to gate g.
// Uses unchecked operator[]: g must map to an index below _layer_norm_count.
438 inline void set_layer_norm_bias(
const ICLTensor *
t, LayerNormGate g)
440 _layer_norm_bias[getGateIndex(g)] =
t;
// Returns the layer-norm weight tensor pointer stored for gate g.
// The slot holds nullptr if nothing has been stored for that gate yet.
443 inline const ICLTensor *get_layer_norm_weight(LayerNormGate g)
445 return _layer_norm_weights[getGateIndex(g)];
// Returns the layer-norm bias tensor pointer stored for gate g.
// The slot holds nullptr if nothing has been stored for that gate yet.
448 inline const ICLTensor *get_layer_norm_bias(LayerNormGate g)
450 return _layer_norm_bias[getGateIndex(g)];
455 return *_layer_norms[getGateIndex(g)];
// Declaration only; the definition is not visible in this excerpt.
// Takes the gate to configure and a pointer to a const input tensor.
// NOTE(review): presumably sets up the layer-norm kernel for gate g with `in` as its input - confirm against the definition.
458 inline void configure_layer_norm(LayerNormGate g,
const ICLTensor *in);
// CLTensor members owned by this object, named as the transposed versions of the input-to-gate,
// recurrent-to-gate and projection weights. Each is brace-initialised with nullptr.
// NOTE(review): the nullptr is presumably a pointer argument of the CLTensor constructor (not visible here) - confirm.
462 CLTensor _input_to_forget_weights_transposed{
nullptr};
463 CLTensor _input_to_cell_weights_transposed{
nullptr};
464 CLTensor _input_to_output_weights_transposed{
nullptr};
465 CLTensor _input_to_input_weights_transposed{
nullptr};
466 CLTensor _recurrent_to_forget_weights_transposed{
nullptr};
467 CLTensor _recurrent_to_cell_weights_transposed{
nullptr};
468 CLTensor _recurrent_to_output_weights_transposed{
nullptr};
469 CLTensor _recurrent_to_input_weights_transposed{
nullptr};
470 CLTensor _projection_weights_transposed{
nullptr};
// CLTensor members named "*_eff_bias" for each input-to-gate / recurrent-to-gate pairing and for the
// projection, plus the projection reduction result. Each is brace-initialised with nullptr.
// NOTE(review): "eff_bias" presumably means an effective bias derived from the reduction kernels
// declared in this class; the computation is not visible here - confirm.
471 CLTensor _input_to_input_eff_bias{
nullptr};
472 CLTensor _recurrent_to_input_eff_bias{
nullptr};
473 CLTensor _input_to_forget_eff_bias{
nullptr};
474 CLTensor _recurrent_to_forget_eff_bias{
nullptr};
475 CLTensor _input_to_cell_eff_bias{
nullptr};
476 CLTensor _recurrent_to_cell_eff_bias{
nullptr};
477 CLTensor _input_to_output_eff_bias{
nullptr};
478 CLTensor _recurrent_to_output_eff_bias{
nullptr};
479 CLTensor _projection_reduction_res{
nullptr};
480 CLTensor _projection_eff_bias{
nullptr};
// Intermediate-result tensors whose names refer to the forget gate. Each is brace-initialised with nullptr.
// NOTE(review): the "_mm_", "_mul_" and "_outstage_" name parts presumably denote matrix-multiply,
// element-wise multiply and output-stage results respectively - the producers are not visible here, confirm.
481 CLTensor _mm_input_to_forget_res{
nullptr};
482 CLTensor _mm_recurrent_to_forget_res{
nullptr};
483 CLTensor _mul_cell_to_forget_res{
nullptr};
484 CLTensor _input_to_forget_outstage_res{
nullptr};
485 CLTensor _cell_to_forget_outstage_res{
nullptr};
486 CLTensor _recurrent_to_forget_outstage_res{
nullptr};
// Intermediate-result tensors whose names refer to the cell gate. Each is brace-initialised with nullptr.
// NOTE(review): which operations write these tensors is not visible in this excerpt - confirm.
488 CLTensor _mm_input_to_cell_res{
nullptr};
489 CLTensor _input_to_cell_outstage_res{
nullptr};
490 CLTensor _mm_recurrent_to_cell_res{
nullptr};
491 CLTensor _recurrent_to_cell_outstage_res{
nullptr};
// Intermediate-result tensors whose names refer to the input gate, plus the input/cell product.
// Each is brace-initialised with nullptr.
// NOTE(review): which operations write these tensors is not visible in this excerpt - confirm.
493 CLTensor _mul_input_cell_res{
nullptr};
494 CLTensor _mm_input_to_input_res{
nullptr};
495 CLTensor _input_to_input_outstage_res{
nullptr};
496 CLTensor _mm_recurrent_to_input_res{
nullptr};
497 CLTensor _mul_cell_to_input_res{
nullptr};
498 CLTensor _cell_to_input_outstage_res{
nullptr};
499 CLTensor _recurrent_to_input_outstage_res{
nullptr};
// Intermediate-result tensors whose names refer to the output gate. Each is brace-initialised with nullptr.
// NOTE(review): which operations write these tensors is not visible in this excerpt - confirm.
501 CLTensor _mm_input_to_output_res{
nullptr};
502 CLTensor _input_to_output_outstage_res{
nullptr};
503 CLTensor _mm_recurrent_to_output_res{
nullptr};
504 CLTensor _mul_cell_to_output_res{
nullptr};
505 CLTensor _cell_to_output_outstage_res{
nullptr};
506 CLTensor _recurrent_to_output_outstage_res{
nullptr};
// Intermediate-result tensors whose names refer to the projection step. Each is brace-initialised with nullptr.
// NOTE(review): presumably only relevant when a projection is configured (see _has_projection) - confirm.
510 CLTensor _mm_projection_res{
nullptr};
511 CLTensor _projection_outstage_res{
nullptr};
512 CLTensor _projection_out_res{
nullptr};
513 CLTensor _projection_accumulate_res{
nullptr};
// Per-gate layer-norm output tensors, _layer_norm_count value-initialised CLTensor elements.
515 std::array<CLTensor, _layer_norm_count> _layer_norm_output{{}};
// Returns a mutable reference to the layer-norm output tensor for gate g.
// Uses unchecked operator[]: g must map to an index below _layer_norm_count.
517 inline CLTensor &get_layer_norm_output(LayerNormGate g)
519 return _layer_norm_output[getGateIndex(g)];
// State flags, all false by default.
// NOTE(review): going by the names, _is_prepared presumably records that one-off preparation has run,
// and the _has_* flags record which optional features (CIFG, cell clipping, projection, projection
// clipping, peephole, layer norm) were configured. The code that sets and reads them is not visible
// in this excerpt - confirm.
522 bool _is_prepared{
false};
523 bool _has_cifg{
false};
524 bool _has_cell_clipping{
false};
525 bool _has_projection{
false};
526 bool _has_projection_clipping{
false};
527 bool _has_peephole{
false};
528 bool _has_layer_norm{
false};
// NOTE(review): presumably set when the projection path has to go through the TensorCopyKernel
// helpers declared in this class - confirm.
529 bool _projection_tensor_copy_required{
false};