13 #include "arm_compute/runtime/NEON/functions/NELSTMLayer.h"
14 #include "arm_compute/runtime/NEON/functions/NEPermute.h"
15 #include "arm_compute/runtime/NEON/functions/NESplit.h"
16 #include "arm_compute/runtime/NEON/functions/NEConcatenateLayer.h"
// Execute() is declared as a const member function that overrides a virtual
// method of a base class. Neither the base class nor the definition is part
// of this excerpt.
// NOTE(review): `virtual` is redundant alongside `override`.
26 virtual void Execute()
const override;
// arm_compute function objects held through std::unique_ptr. Every member in
// this group is `mutable`, so it can be modified from const member functions
// (presumably the const Execute() -- confirm against the implementation).
// m_Splitter is held through the arm_compute::IFunction interface, while the
// other members use concrete NEON function types.
// NOTE(review): the names suggest a permute -> split -> per-step LSTM ->
// concatenate -> permute pipeline; confirm the actual order in the
// implementation file.
35 mutable std::unique_ptr<arm_compute::NEPermute> m_Permute1;
36 mutable std::unique_ptr<arm_compute::IFunction> m_Splitter;
37 mutable std::vector<std::unique_ptr<arm_compute::NELSTMLayer>> m_Layers;
38 mutable std::unique_ptr<arm_compute::NEConcatenateLayer> m_Concat;
40 mutable std::unique_ptr<arm_compute::NEPermute> m_Permute2;
// arm_compute::Tensor objects held through std::unique_ptr, so any of them
// may be null. The roles below are inferred from the member names only --
// confirm against the code that creates and configures them.

// Presumably the input-to-gate weights (input, forget, cell, output).
45 std::unique_ptr<arm_compute::Tensor> m_InputToInputWeightsTensor;
46 std::unique_ptr<arm_compute::Tensor> m_InputToForgetWeightsTensor;
47 std::unique_ptr<arm_compute::Tensor> m_InputToCellWeightsTensor;
48 std::unique_ptr<arm_compute::Tensor> m_InputToOutputWeightsTensor;
// Presumably the recurrent-to-gate weights (input, forget, cell, output).
49 std::unique_ptr<arm_compute::Tensor> m_RecurrentToInputWeightsTensor;
50 std::unique_ptr<arm_compute::Tensor> m_RecurrentToForgetWeightsTensor;
51 std::unique_ptr<arm_compute::Tensor> m_RecurrentToCellWeightsTensor;
52 std::unique_ptr<arm_compute::Tensor> m_RecurrentToOutputWeightsTensor;
// Presumably the cell-to-gate weights; there is no cell-to-cell member.
53 std::unique_ptr<arm_compute::Tensor> m_CellToInputWeightsTensor;
54 std::unique_ptr<arm_compute::Tensor> m_CellToForgetWeightsTensor;
55 std::unique_ptr<arm_compute::Tensor> m_CellToOutputWeightsTensor;
// Presumably the per-gate bias tensors.
56 std::unique_ptr<arm_compute::Tensor> m_InputGateBiasTensor;
57 std::unique_ptr<arm_compute::Tensor> m_ForgetGateBiasTensor;
58 std::unique_ptr<arm_compute::Tensor> m_CellBiasTensor;
59 std::unique_ptr<arm_compute::Tensor> m_OutputGateBiasTensor;
// Presumably the projection weights and bias.
60 std::unique_ptr<arm_compute::Tensor> m_ProjectionWeightsTensor;
61 std::unique_ptr<arm_compute::Tensor> m_ProjectionBiasTensor;
// Presumably temporary working storage for the LSTM functions -- TODO confirm.
63 std::unique_ptr<arm_compute::Tensor> m_ScratchBuffer;
// Presumably the per-gate layer-normalisation weights.
65 std::unique_ptr<arm_compute::Tensor> m_InputLayerNormWeightsTensor;
66 std::unique_ptr<arm_compute::Tensor> m_ForgetLayerNormWeightsTensor;
67 std::unique_ptr<arm_compute::Tensor> m_CellLayerNormWeightsTensor;
68 std::unique_ptr<arm_compute::Tensor> m_OutputLayerNormWeightsTensor;
// Tensors stored by value in this object, plus pointer vectors over tensors.
// m_SplitterOutputs and m_ConcatInputs hold raw ITensor pointers and so do
// not own what they point to.
// NOTE(review): presumably these pointers refer to the elements of
// m_SplitterOutputsTensors and m_ConcatInputsTensors. If so, resizing either
// tensor vector after the pointers are taken would invalidate them -- confirm
// in the implementation.
// NOTE(review): concat_out lacks the m_ prefix used by every other member
// here. Renaming it requires updating its uses, which are outside this
// excerpt.
74 arm_compute::Tensor m_PermuteFirstOut;
75 std::vector<arm_compute::Tensor> m_SplitterOutputsTensors;
76 std::vector<arm_compute::Tensor> m_ConcatInputsTensors;
77 std::vector<arm_compute::ITensor*> m_SplitterOutputs;
78 std::vector<const arm_compute::ITensor*> m_ConcatInputs;
79 arm_compute::Tensor concat_out;
// Declaration only; the definition is not part of this excerpt.
// NOTE(review): presumably releases tensor members that are no longer needed
// once the arm_compute functions are configured -- confirm in the
// implementation file.
81 void FreeUnusedTensors();