ArmNN
 25.11
Loading...
Searching...
No Matches
RefLstmWorkload.cpp
Go to the documentation of this file.
1//
2// Copyright © 2019,2021-2024 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "RefLstmWorkload.hpp"
7#include "Activation.hpp"
8#include "Encoders.hpp"
9#include "Decoders.hpp"
10#include "Lstm.hpp"
11#include "LstmUtils.hpp"
12#include "RefWorkloadUtils.hpp"
13
14namespace armnn
15{
16
19 , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
20 , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
21 , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
22 , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
23 , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
24 , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
25 , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
26 , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
27 , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
28 , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
29 , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
30 , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias))
31 , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
32 , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias))
33 , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
34 , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
35 , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
36 , m_InputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
37 , m_ForgetLayerNormWeights (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
38 , m_CellLayerNormWeights (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
39 , m_OutputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
40{}
41
43{
44 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
45}
46
47void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
48{
49 ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefLstmWorkload_Execute");
50
51 // This is a porting of the LSTM::Eval() method in the Android code base
52 // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
53
54 const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
55 const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
56
57 const TensorShape& inputShape = inputInfo.GetShape();
58
59 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->Map());
60 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, outputs[2]->Map());
61 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, outputs[3]->Map());
62
63 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, outputs[2]->Map());
64 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, outputs[3]->Map());
65
66 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, inputs[0]->Map());
67 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, inputs[1]->Map());
68 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, inputs[2]->Map());
69
70 const uint32_t nBatch = inputShape[0];
71 const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
72
73 const bool useCifg = m_Data.m_Parameters.m_CifgEnabled;
74 const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
75 const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;
76
77 // Index the scratch buffers pointers to the global scratch buffer.
78 std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
79 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
80 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
81 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
82
83 std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
84 MakeDecoder<float>(outputInfo, outputs[0]->Map());
85 std::unique_ptr<Decoder<float>> cellScratchDecoder =
86 MakeDecoder<float>(outputInfo, outputs[0]->Map());
87 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
88 MakeDecoder<float>(outputInfo, outputs[0]->Map());
89 std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
90 MakeDecoder<float>(outputInfo, outputs[0]->Map());
91
92 if (useCifg)
93 {
94 *cellScratch += (0 * nCell * nBatch);
95 *forgetGateScratch += (1 * nCell * nBatch);
96 *outputGateScratch += (2 * nCell * nBatch);
97
98 *cellScratchDecoder += (0 * nCell * nBatch);
99 *forgetGateScratchDecoder += (1 * nCell * nBatch);
100 *outputGateScratchDecoder += (2 * nCell * nBatch);
101 }
102 else
103 {
104 *inputGateScratch += (0 * nCell * nBatch);
105 *cellScratch += (1 * nCell * nBatch);
106 *forgetGateScratch += (2 * nCell * nBatch);
107 *outputGateScratch += (3 * nCell * nBatch);
108
109 *inputGateScratchDecoder += (0 * nCell * nBatch);
110 *cellScratchDecoder += (1 * nCell * nBatch);
111 *forgetGateScratchDecoder += (2 * nCell * nBatch);
112 *outputGateScratchDecoder += (3 * nCell * nBatch);
113 }
114
115 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
116 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
117 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
118 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
119 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
120 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
121 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());
122
123 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
124 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
125 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
126 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
127 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
128 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
129 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());
130
131 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
132 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
133 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
134 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
135 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
136 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
137 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());
138
139 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
140 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
141 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
142
143 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
144 std::unique_ptr<Decoder<float>> projectionBiasTensor;
145
146 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
147 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
148 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
149 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
150
151 const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
152 const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
153
154 if (useLayerNorm)
155 {
156 if (!useCifg)
157 {
158 inputLayerNormWeights = MakeDecoder<float>(
159 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
160 }
161 forgetLayerNormWeights = MakeDecoder<float>(
162 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
163 cellLayerNormWeights = MakeDecoder<float>(
164 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
165 outputLayerNormWeights = MakeDecoder<float>(
166 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
167 }
168
169 if (!useCifg)
170 {
171 inputToInputWeightsTensor = MakeDecoder<float>(
172 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
173 inputGateBiasTensor = MakeDecoder<float>(
174 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
175 recurrentToInputWeightsTensor = MakeDecoder<float>(
176 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
177 }
178
179 if (usePeephole)
180 {
181 cellToForgetWeightsTensor = MakeDecoder<float>(
182 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
183 cellToOutputWeightsTensor = MakeDecoder<float>(
184 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
185 }
186
187 if (!useCifg && usePeephole)
188 {
189 cellToInputWeightsTensor = MakeDecoder<float>(
190 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
191 }
192
193 if (m_Data.m_Parameters.m_ProjectionEnabled)
194 {
195 projectionWeightsTensor = MakeDecoder<float>(
196 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
197 if (m_ProjectionBiasTensor)
198 {
199 projectionBiasTensor = MakeDecoder<float>(
200 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
201 }
202 }
203
204 LstmImpl(m_Data.m_Parameters,
205 inputInfo,
206 outputInfo,
207 inputToOutputWeightsShape,
208 recurrentToOutputWeightsShape,
209 inputData,
210 outputStateIn,
211 cellStateIn,
212 outputStateOut,
213 cellStateOut,
214 output,
215 cellStateOutDecoder,
216 outputDecoder,
217 inputToInputWeightsTensor,
218 inputToForgetWeightsTensor,
219 inputToCellWeightsTensor,
220 inputToOutputWeightsTensor,
221 recurrentToInputWeightsTensor,
222 recurrentToForgetWeightsTensor,
223 recurrentToCellWeightsTensor,
224 recurrentToOutputWeightsTensor,
225 cellToInputWeightsTensor,
226 cellToForgetWeightsTensor,
227 cellToOutputWeightsTensor,
228 inputGateBiasTensor,
229 forgetGateBiasTensor,
230 cellBiasTensor,
231 outputGateBiasTensor,
232 projectionWeightsTensor,
233 projectionBiasTensor,
234 inputLayerNormWeights,
235 forgetLayerNormWeights,
236 cellLayerNormWeights,
237 outputLayerNormWeights,
238 inputGateScratch,
239 cellScratch,
240 forgetGateScratch,
241 outputGateScratch,
242 inputGateScratchDecoder,
243 cellScratchDecoder,
244 forgetGateScratchDecoder,
245 outputGateScratchDecoder,
246 m_LayerNormEpsilon);
247}
248
249} //namespace armnn
std::unique_ptr< armnn::ScopedTensorHandle > AssignScopedTensorHandle(const armnn::ConstTensorHandle *ptr)
#define ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
RefBaseWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
void Execute() const override
const TensorShape & GetShape() const
Definition Tensor.hpp:193
Copyright (c) 2021 ARM Limited and Contributors.
void LstmImpl(const LstmDescriptor &descriptor, const TensorInfo &inputInfo, const TensorInfo &outputInfo, const TensorShape &inputToOutputWeightsShape, const TensorShape &recurrentToOutputWeightsShape, std::unique_ptr< Decoder< float > > &inputData, std::unique_ptr< Decoder< float > > &outputStateIn, std::unique_ptr< Decoder< float > > &cellStateIn, std::unique_ptr< Encoder< float > > &outputStateOut, std::unique_ptr< Encoder< float > > &cellStateOut, std::unique_ptr< Encoder< float > > &output, std::unique_ptr< Decoder< float > > &cellStateOutDecoder, std::unique_ptr< Decoder< float > > &outputDecoder, std::unique_ptr< Decoder< float > > &inputToInputWeightsTensor, std::unique_ptr< Decoder< float > > &inputToForgetWeightsTensor, std::unique_ptr< Decoder< float > > &inputToCellWeightsTensor, std::unique_ptr< Decoder< float > > &inputToOutputWeightsTensor, std::unique_ptr< Decoder< float > > &recurrentToInputWeightsTensor, std::unique_ptr< Decoder< float > > &recurrentToForgetWeightsTensor, std::unique_ptr< Decoder< float > > &recurrentToCellWeightsTensor, std::unique_ptr< Decoder< float > > &recurrentToOutputWeightsTensor, std::unique_ptr< Decoder< float > > &cellToInputWeightsTensor, std::unique_ptr< Decoder< float > > &cellToForgetWeightsTensor, std::unique_ptr< Decoder< float > > &cellToOutputWeightsTensor, std::unique_ptr< Decoder< float > > &inputGateBiasTensor, std::unique_ptr< Decoder< float > > &forgetGateBiasTensor, std::unique_ptr< Decoder< float > > &cellBiasTensor, std::unique_ptr< Decoder< float > > &outputGateBiasTensor, std::unique_ptr< Decoder< float > > &projectionWeightsTensor, std::unique_ptr< Decoder< float > > &projectionBiasTensor, std::unique_ptr< Decoder< float > > &inputLayerNormWeights, std::unique_ptr< Decoder< float > > &forgetLayerNormWeights, std::unique_ptr< Decoder< float > > &cellLayerNormWeights, std::unique_ptr< Decoder< float > > &outputLayerNormWeights, std::unique_ptr< Encoder< float > > &inputGateScratch, std::unique_ptr< Encoder< float > > &cellScratch, std::unique_ptr< Encoder< float > > &forgetGateScratch, std::unique_ptr< Encoder< float > > &outputGateScratch, std::unique_ptr< Decoder< float > > &inputGateScratchDecoder, std::unique_ptr< Decoder< float > > &cellScratchDecoder, std::unique_ptr< Decoder< float > > &forgetGateScratchDecoder, std::unique_ptr< Decoder< float > > &outputGateScratchDecoder, float layerNormEpsilon)
Definition Lstm.cpp:13
std::unique_ptr< Decoder< T > > MakeDecoder(const TensorInfo &info, const void *data=nullptr)
std::unique_ptr< Encoder< T > > MakeEncoder(const TensorInfo &info, void *data=nullptr)
armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches, unsigned int numberOfChannels, unsigned int height, unsigned int width, const armnn::DataLayout dataLayout, const armnn::DataType dataType)
bool m_PeepholeEnabled
Enable/disable peephole.
bool m_LayerNormEnabled
Enable/disable layer normalization.
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
Contains information about TensorInfos of a layer.