ArmNN
 25.02
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
RefLstmWorkload.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2019,2021-2024 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "RefLstmWorkload.hpp"
7 #include "Activation.hpp"
8 #include "Encoders.hpp"
9 #include "Decoders.hpp"
10 #include "Lstm.hpp"
11 #include "LstmUtils.hpp"
12 #include "RefWorkloadUtils.hpp"
13 
14 namespace armnn
15 {
16 
19  , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
20  , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
21  , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
22  , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
23  , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
24  , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
25  , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
26  , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
27  , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
28  , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
29  , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
30  , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias))
31  , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
32  , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias))
33  , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
34  , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
35  , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
36  , m_InputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
37  , m_ForgetLayerNormWeights (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
38  , m_CellLayerNormWeights (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
39  , m_OutputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
40 {}
41 
43 {
45 }
46 
47 void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
48 {
49  ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefLstmWorkload_Execute");
50 
51  // This is a porting of the LSTM::Eval() method in the Android code base
52  // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
53 
54  const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
55  const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
56 
57  const TensorShape& inputShape = inputInfo.GetShape();
58 
59  std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->Map());
60  std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, outputs[2]->Map());
61  std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, outputs[3]->Map());
62 
63  std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, outputs[2]->Map());
64  std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, outputs[3]->Map());
65 
66  std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, inputs[0]->Map());
67  std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, inputs[1]->Map());
68  std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, inputs[2]->Map());
69 
70  const uint32_t nBatch = inputShape[0];
71  const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
72 
73  const bool useCifg = m_Data.m_Parameters.m_CifgEnabled;
74  const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
75  const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;
76 
77  // Index the scratch buffers pointers to the global scratch buffer.
78  std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
79  std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
80  std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
81  std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
82 
83  std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
84  MakeDecoder<float>(outputInfo, outputs[0]->Map());
85  std::unique_ptr<Decoder<float>> cellScratchDecoder =
86  MakeDecoder<float>(outputInfo, outputs[0]->Map());
87  std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
88  MakeDecoder<float>(outputInfo, outputs[0]->Map());
89  std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
90  MakeDecoder<float>(outputInfo, outputs[0]->Map());
91 
92  if (useCifg)
93  {
94  *cellScratch += (0 * nCell * nBatch);
95  *forgetGateScratch += (1 * nCell * nBatch);
96  *outputGateScratch += (2 * nCell * nBatch);
97 
98  *cellScratchDecoder += (0 * nCell * nBatch);
99  *forgetGateScratchDecoder += (1 * nCell * nBatch);
100  *outputGateScratchDecoder += (2 * nCell * nBatch);
101  }
102  else
103  {
104  *inputGateScratch += (0 * nCell * nBatch);
105  *cellScratch += (1 * nCell * nBatch);
106  *forgetGateScratch += (2 * nCell * nBatch);
107  *outputGateScratch += (3 * nCell * nBatch);
108 
109  *inputGateScratchDecoder += (0 * nCell * nBatch);
110  *cellScratchDecoder += (1 * nCell * nBatch);
111  *forgetGateScratchDecoder += (2 * nCell * nBatch);
112  *outputGateScratchDecoder += (3 * nCell * nBatch);
113  }
114 
115  std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
116  std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
117  m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
118  std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
119  m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
120  std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
121  m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());
122 
123  std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
124  std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
125  m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
126  std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
127  m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
128  std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
129  m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());
130 
131  std::unique_ptr<Decoder<float>> inputGateBiasTensor;
132  std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
133  m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
134  std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
135  m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
136  std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
137  m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());
138 
139  std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
140  std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
141  std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
142 
143  std::unique_ptr<Decoder<float>> projectionWeightsTensor;
144  std::unique_ptr<Decoder<float>> projectionBiasTensor;
145 
146  std::unique_ptr<Decoder<float>> inputLayerNormWeights;
147  std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
148  std::unique_ptr<Decoder<float>> cellLayerNormWeights;
149  std::unique_ptr<Decoder<float>> outputLayerNormWeights;
150 
151  const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
152  const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
153 
154  if (useLayerNorm)
155  {
156  if (!useCifg)
157  {
158  inputLayerNormWeights = MakeDecoder<float>(
159  m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
160  }
161  forgetLayerNormWeights = MakeDecoder<float>(
162  m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
163  cellLayerNormWeights = MakeDecoder<float>(
164  m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
165  outputLayerNormWeights = MakeDecoder<float>(
166  m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
167  }
168 
169  if (!useCifg)
170  {
171  inputToInputWeightsTensor = MakeDecoder<float>(
172  m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
173  inputGateBiasTensor = MakeDecoder<float>(
174  m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
175  recurrentToInputWeightsTensor = MakeDecoder<float>(
176  m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
177  }
178 
179  if (usePeephole)
180  {
181  cellToForgetWeightsTensor = MakeDecoder<float>(
182  m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
183  cellToOutputWeightsTensor = MakeDecoder<float>(
184  m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
185  }
186 
187  if (!useCifg && usePeephole)
188  {
189  cellToInputWeightsTensor = MakeDecoder<float>(
190  m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
191  }
192 
193  if (m_Data.m_Parameters.m_ProjectionEnabled)
194  {
195  projectionWeightsTensor = MakeDecoder<float>(
196  m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
197  if (m_ProjectionBiasTensor)
198  {
199  projectionBiasTensor = MakeDecoder<float>(
200  m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
201  }
202  }
203 
204  LstmImpl(m_Data.m_Parameters,
205  inputInfo,
206  outputInfo,
207  inputToOutputWeightsShape,
208  recurrentToOutputWeightsShape,
209  inputData,
210  outputStateIn,
211  cellStateIn,
212  outputStateOut,
213  cellStateOut,
214  output,
215  cellStateOutDecoder,
216  outputDecoder,
217  inputToInputWeightsTensor,
218  inputToForgetWeightsTensor,
219  inputToCellWeightsTensor,
220  inputToOutputWeightsTensor,
221  recurrentToInputWeightsTensor,
222  recurrentToForgetWeightsTensor,
223  recurrentToCellWeightsTensor,
224  recurrentToOutputWeightsTensor,
225  cellToInputWeightsTensor,
226  cellToForgetWeightsTensor,
227  cellToOutputWeightsTensor,
228  inputGateBiasTensor,
229  forgetGateBiasTensor,
230  cellBiasTensor,
231  outputGateBiasTensor,
232  projectionWeightsTensor,
233  projectionBiasTensor,
234  inputLayerNormWeights,
235  forgetLayerNormWeights,
236  cellLayerNormWeights,
237  outputLayerNormWeights,
238  inputGateScratch,
239  cellScratch,
240  forgetGateScratch,
241  outputGateScratch,
242  inputGateScratchDecoder,
243  cellScratchDecoder,
244  forgetGateScratchDecoder,
245  outputGateScratchDecoder,
246  m_LayerNormEpsilon);
247 }
248 
249 } //namespace armnn
std::unique_ptr< armnn::ScopedTensorHandle > AssignScopedTensorHandle(const armnn::ConstTensorHandle *ptr)
Definition: LstmUtils.cpp:299
#define ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
QueueDescriptor m_Data
Definition: Workload.hpp:74
RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
void Execute() const override
const TensorShape & GetShape() const
Definition: Tensor.hpp:193
Copyright (c) 2021 ARM Limited and Contributors.
void LstmImpl(const LstmDescriptor &descriptor, const TensorInfo &inputInfo, const TensorInfo &outputInfo, const TensorShape &inputToOutputWeightsShape, const TensorShape &recurrentToOutputWeightsShape, std::unique_ptr< Decoder< float >> &inputData, std::unique_ptr< Decoder< float >> &outputStateIn, std::unique_ptr< Decoder< float >> &cellStateIn, std::unique_ptr< Encoder< float >> &outputStateOut, std::unique_ptr< Encoder< float >> &cellStateOut, std::unique_ptr< Encoder< float >> &output, std::unique_ptr< Decoder< float >> &cellStateOutDecoder, std::unique_ptr< Decoder< float >> &outputDecoder, std::unique_ptr< Decoder< float >> &inputToInputWeightsTensor, std::unique_ptr< Decoder< float >> &inputToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &inputToCellWeightsTensor, std::unique_ptr< Decoder< float >> &inputToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToInputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToCellWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToInputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &cellToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &inputGateBiasTensor, std::unique_ptr< Decoder< float >> &forgetGateBiasTensor, std::unique_ptr< Decoder< float >> &cellBiasTensor, std::unique_ptr< Decoder< float >> &outputGateBiasTensor, std::unique_ptr< Decoder< float >> &projectionWeightsTensor, std::unique_ptr< Decoder< float >> &projectionBiasTensor, std::unique_ptr< Decoder< float >> &inputLayerNormWeights, std::unique_ptr< Decoder< float >> &forgetLayerNormWeights, std::unique_ptr< Decoder< float >> &cellLayerNormWeights, std::unique_ptr< Decoder< float >> &outputLayerNormWeights, std::unique_ptr< Encoder< float >> &inputGateScratch, std::unique_ptr< Encoder< float >> &cellScratch, std::unique_ptr< Encoder< float >> &forgetGateScratch, std::unique_ptr< Encoder< float >> &outputGateScratch, std::unique_ptr< Decoder< float >> &inputGateScratchDecoder, std::unique_ptr< Decoder< float >> &cellScratchDecoder, std::unique_ptr< Decoder< float >> &forgetGateScratchDecoder, std::unique_ptr< Decoder< float >> &outputGateScratchDecoder, float layerNormEpsilon)
Definition: Lstm.cpp:13
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers
std::vector< ITensorHandle * > m_Inputs
std::vector< ITensorHandle * > m_Outputs
Contains information about TensorInfos of a layer.