ArmNN
 25.11
Loading...
Searching...
No Matches
NeonUnidirectionalSequenceLstmWorkload.hpp
Go to the documentation of this file.
1//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
12#include "NeonBaseWorkload.hpp"
13
14#include "arm_compute/runtime/NEON/functions/NEQLSTMLayer.h"
15#include "arm_compute/runtime/NEON/functions/NEPermute.h"
16#include "arm_compute/runtime/NEON/functions/NESplit.h"
17#include "arm_compute/runtime/NEON/functions/NEConcatenateLayer.h"
18
19namespace armnn
20{
21
22class NeonUnidirectionalSequenceLstmWorkload : public NeonBaseWorkload<UnidirectionalSequenceLstmQueueDescriptor>
23{
24public:
26 const WorkloadInfo& info);
27 virtual void Execute() const override;
28
29private:
30
31 //
32 // ACL layers required to fully form a Unidirectional Sequence LSTM layer.
33 //
34 mutable std::unique_ptr<arm_compute::NEPermute> m_Permute1;
35 mutable std::unique_ptr<arm_compute::IFunction> m_Splitter;
36 mutable std::vector<std::unique_ptr<arm_compute::NEQLSTMLayer>> m_Layers;
37 mutable std::unique_ptr<arm_compute::NEConcatenateLayer> m_Concat;
38 mutable std::unique_ptr<arm_compute::NEPermute> m_Permute2;
39
40 //
41 // ACL LSTM arm_compute::Tensors.
42 //
43 std::unique_ptr<arm_compute::Tensor> m_InputToInputWeightsTensor;
44 std::unique_ptr<arm_compute::Tensor> m_InputToForgetWeightsTensor;
45 std::unique_ptr<arm_compute::Tensor> m_InputToCellWeightsTensor;
46 std::unique_ptr<arm_compute::Tensor> m_InputToOutputWeightsTensor;
47 std::unique_ptr<arm_compute::Tensor> m_RecurrentToInputWeightsTensor;
48 std::unique_ptr<arm_compute::Tensor> m_RecurrentToForgetWeightsTensor;
49 std::unique_ptr<arm_compute::Tensor> m_RecurrentToCellWeightsTensor;
50 std::unique_ptr<arm_compute::Tensor> m_RecurrentToOutputWeightsTensor;
51 std::unique_ptr<arm_compute::Tensor> m_CellToInputWeightsTensor;
52 std::unique_ptr<arm_compute::Tensor> m_CellToForgetWeightsTensor;
53 std::unique_ptr<arm_compute::Tensor> m_CellToOutputWeightsTensor;
54 std::unique_ptr<arm_compute::Tensor> m_InputGateBiasTensor;
55 std::unique_ptr<arm_compute::Tensor> m_ForgetGateBiasTensor;
56 std::unique_ptr<arm_compute::Tensor> m_CellBiasTensor;
57 std::unique_ptr<arm_compute::Tensor> m_OutputGateBiasTensor;
58 std::unique_ptr<arm_compute::Tensor> m_ProjectionWeightsTensor;
59 std::unique_ptr<arm_compute::Tensor> m_ProjectionBiasTensor;
60
61 std::unique_ptr<arm_compute::Tensor> m_InputLayerNormWeightsTensor;
62 std::unique_ptr<arm_compute::Tensor> m_ForgetLayerNormWeightsTensor;
63 std::unique_ptr<arm_compute::Tensor> m_CellLayerNormWeightsTensor;
64 std::unique_ptr<arm_compute::Tensor> m_OutputLayerNormWeightsTensor;
65
66 //
67 // Additional ACL arm_compute::Tensors and std::vector<arm_compute::Tensor>.
68 // Required to perform splitting, concatenation and permutations.
69 //
70 arm_compute::Tensor m_PermuteFirstOut;
71 std::vector<arm_compute::Tensor> m_SplitterOutputsTensors;
72 std::vector<arm_compute::Tensor> m_ConcatInputsTensors;
73 std::vector<arm_compute::ITensor*> m_SplitterOutputs;
74 std::vector<const arm_compute::ITensor*> m_ConcatInputs;
75 arm_compute::Tensor concat_out;
76
77 void FreeUnusedTensors();
78};
79
80arm_compute::Status
82 const TensorInfo& outputStateIn,
83 const TensorInfo& cellStateIn,
84 const TensorInfo& outputStateOut,
85 const TensorInfo& cellStateOut,
86 const TensorInfo& output,
88 const LstmInputParamsInfo& paramsInfo);
89
90} //namespace armnn
NeonBaseWorkload(const UnidirectionalSequenceLstmQueueDescriptor &descriptor, const WorkloadInfo &info)
NeonUnidirectionalSequenceLstmWorkload(const UnidirectionalSequenceLstmQueueDescriptor &descriptor, const WorkloadInfo &info)
Copyright (c) 2021 ARM Limited and Contributors.
arm_compute::Status NeonUnidirectionalSequenceLstmWorkloadValidate(const TensorInfo &input, const TensorInfo &outputStateIn, const TensorInfo &cellStateIn, const TensorInfo &outputStateOut, const TensorInfo &cellStateOut, const TensorInfo &output, const UnidirectionalSequenceLstmDescriptor &descriptor, const LstmInputParamsInfo &paramsInfo)
LstmDescriptor UnidirectionalSequenceLstmDescriptor
Contains information about TensorInfos of a layer.