ArmNN
 24.08
RefLstmWorkload.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2019,2021-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "RefLstmWorkload.hpp"
7 #include "Activation.hpp"
8 #include "Encoders.hpp"
9 #include "Decoders.hpp"
10 #include "Lstm.hpp"
11 #include "LstmUtils.hpp"
12 #include "RefWorkloadUtils.hpp"
13 
14 namespace armnn
15 {
16 
19  , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
20  , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
21  , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
22  , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
23  , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
24  , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
25  , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
26  , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
27  , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
28  , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
29  , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
30  , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias))
31  , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
32  , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias))
33  , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
34  , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
35  , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
36  , m_InputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
37  , m_ForgetLayerNormWeights (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
38  , m_CellLayerNormWeights (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
39  , m_OutputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
40 {}
41 
43 {
45 }
46 
48 {
49  WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
50  Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
51 }
52 
53 void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
54 {
55  ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefLstmWorkload_Execute");
56 
57  // This is a porting of the LSTM::Eval() method in the Android code base
58  // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
59 
60  const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
61  const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
62 
63  const TensorShape& inputShape = inputInfo.GetShape();
64 
65  std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->Map());
66  std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, outputs[2]->Map());
67  std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, outputs[3]->Map());
68 
69  std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, outputs[2]->Map());
70  std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, outputs[3]->Map());
71 
72  std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, inputs[0]->Map());
73  std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, inputs[1]->Map());
74  std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, inputs[2]->Map());
75 
76  const uint32_t nBatch = inputShape[0];
77  const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
78 
79  const bool useCifg = m_Data.m_Parameters.m_CifgEnabled;
80  const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
81  const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;
82 
83  // Index the scratch buffers pointers to the global scratch buffer.
84  std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
85  std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
86  std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
87  std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
88 
89  std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
90  MakeDecoder<float>(outputInfo, outputs[0]->Map());
91  std::unique_ptr<Decoder<float>> cellScratchDecoder =
92  MakeDecoder<float>(outputInfo, outputs[0]->Map());
93  std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
94  MakeDecoder<float>(outputInfo, outputs[0]->Map());
95  std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
96  MakeDecoder<float>(outputInfo, outputs[0]->Map());
97 
98  if (useCifg)
99  {
100  *cellScratch += (0 * nCell * nBatch);
101  *forgetGateScratch += (1 * nCell * nBatch);
102  *outputGateScratch += (2 * nCell * nBatch);
103 
104  *cellScratchDecoder += (0 * nCell * nBatch);
105  *forgetGateScratchDecoder += (1 * nCell * nBatch);
106  *outputGateScratchDecoder += (2 * nCell * nBatch);
107  }
108  else
109  {
110  *inputGateScratch += (0 * nCell * nBatch);
111  *cellScratch += (1 * nCell * nBatch);
112  *forgetGateScratch += (2 * nCell * nBatch);
113  *outputGateScratch += (3 * nCell * nBatch);
114 
115  *inputGateScratchDecoder += (0 * nCell * nBatch);
116  *cellScratchDecoder += (1 * nCell * nBatch);
117  *forgetGateScratchDecoder += (2 * nCell * nBatch);
118  *outputGateScratchDecoder += (3 * nCell * nBatch);
119  }
120 
121  std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
122  std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
123  m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
124  std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
125  m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
126  std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
127  m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());
128 
129  std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
130  std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
131  m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
132  std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
133  m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
134  std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
135  m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());
136 
137  std::unique_ptr<Decoder<float>> inputGateBiasTensor;
138  std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
139  m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
140  std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
141  m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
142  std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
143  m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());
144 
145  std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
146  std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
147  std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
148 
149  std::unique_ptr<Decoder<float>> projectionWeightsTensor;
150  std::unique_ptr<Decoder<float>> projectionBiasTensor;
151 
152  std::unique_ptr<Decoder<float>> inputLayerNormWeights;
153  std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
154  std::unique_ptr<Decoder<float>> cellLayerNormWeights;
155  std::unique_ptr<Decoder<float>> outputLayerNormWeights;
156 
157  const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
158  const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
159 
160  if (useLayerNorm)
161  {
162  if (!useCifg)
163  {
164  inputLayerNormWeights = MakeDecoder<float>(
165  m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
166  }
167  forgetLayerNormWeights = MakeDecoder<float>(
168  m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
169  cellLayerNormWeights = MakeDecoder<float>(
170  m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
171  outputLayerNormWeights = MakeDecoder<float>(
172  m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
173  }
174 
175  if (!useCifg)
176  {
177  inputToInputWeightsTensor = MakeDecoder<float>(
178  m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
179  inputGateBiasTensor = MakeDecoder<float>(
180  m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
181  recurrentToInputWeightsTensor = MakeDecoder<float>(
182  m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
183  }
184 
185  if (usePeephole)
186  {
187  cellToForgetWeightsTensor = MakeDecoder<float>(
188  m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
189  cellToOutputWeightsTensor = MakeDecoder<float>(
190  m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
191  }
192 
193  if (!useCifg && usePeephole)
194  {
195  cellToInputWeightsTensor = MakeDecoder<float>(
196  m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
197  }
198 
200  {
201  projectionWeightsTensor = MakeDecoder<float>(
202  m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
203  if (m_ProjectionBiasTensor)
204  {
205  projectionBiasTensor = MakeDecoder<float>(
206  m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
207  }
208  }
209 
211  inputInfo,
212  outputInfo,
213  inputToOutputWeightsShape,
214  recurrentToOutputWeightsShape,
215  inputData,
216  outputStateIn,
217  cellStateIn,
218  outputStateOut,
219  cellStateOut,
220  output,
221  cellStateOutDecoder,
222  outputDecoder,
223  inputToInputWeightsTensor,
224  inputToForgetWeightsTensor,
225  inputToCellWeightsTensor,
226  inputToOutputWeightsTensor,
227  recurrentToInputWeightsTensor,
228  recurrentToForgetWeightsTensor,
229  recurrentToCellWeightsTensor,
230  recurrentToOutputWeightsTensor,
231  cellToInputWeightsTensor,
232  cellToForgetWeightsTensor,
233  cellToOutputWeightsTensor,
234  inputGateBiasTensor,
235  forgetGateBiasTensor,
236  cellBiasTensor,
237  outputGateBiasTensor,
238  projectionWeightsTensor,
239  projectionBiasTensor,
240  inputLayerNormWeights,
241  forgetLayerNormWeights,
242  cellLayerNormWeights,
243  outputLayerNormWeights,
244  inputGateScratch,
245  cellScratch,
246  forgetGateScratch,
247  outputGateScratch,
248  inputGateScratchDecoder,
249  cellScratchDecoder,
250  forgetGateScratchDecoder,
251  outputGateScratchDecoder,
252  m_LayerNormEpsilon);
253 }
254 
255 } //namespace armnn
Lstm.hpp
armnn::experimental::ExecutionData::m_Data
void * m_Data
Definition: ExecutionData.hpp:16
armnn::TensorInfo
Definition: Tensor.hpp:152
AssignScopedTensorHandle
std::unique_ptr< armnn::ScopedTensorHandle > AssignScopedTensorHandle(const armnn::ConstTensorHandle *ptr)
Definition: LstmUtils.cpp:299
ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID
#define ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
Definition: RefWorkloadUtils.hpp:22
armnn::RefLstmWorkload::ExecuteAsync
void ExecuteAsync(ExecutionData &executionData) override
Definition: RefLstmWorkload.cpp:47
armnn::LstmDescriptor::m_PeepholeEnabled
bool m_PeepholeEnabled
Enable/disable peephole.
Definition: Descriptors.hpp:1148
armnn::TensorShape
Definition: Tensor.hpp:20
armnn::LstmImpl
void LstmImpl(const LstmDescriptor &descriptor, const TensorInfo &inputInfo, const TensorInfo &outputInfo, const TensorShape &inputToOutputWeightsShape, const TensorShape &recurrentToOutputWeightsShape, std::unique_ptr< Decoder< float >> &inputData, std::unique_ptr< Decoder< float >> &outputStateIn, std::unique_ptr< Decoder< float >> &cellStateIn, std::unique_ptr< Encoder< float >> &outputStateOut, std::unique_ptr< Encoder< float >> &cellStateOut, std::unique_ptr< Encoder< float >> &output, std::unique_ptr< Decoder< float >> &cellStateOutDecoder, std::unique_ptr< Decoder< float >> &outputDecoder, std::unique_ptr< Decoder< float >> &inputToInputWeightsTensor, std::unique_ptr< Decoder< float >> &inputToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &inputToCellWeightsTensor, std::unique_ptr< Decoder< float >> &inputToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToInputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToCellWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToInputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &cellToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &inputGateBiasTensor, std::unique_ptr< Decoder< float >> &forgetGateBiasTensor, std::unique_ptr< Decoder< float >> &cellBiasTensor, std::unique_ptr< Decoder< float >> &outputGateBiasTensor, std::unique_ptr< Decoder< float >> &projectionWeightsTensor, std::unique_ptr< Decoder< float >> &projectionBiasTensor, std::unique_ptr< Decoder< float >> &inputLayerNormWeights, std::unique_ptr< Decoder< float >> &forgetLayerNormWeights, std::unique_ptr< Decoder< float >> &cellLayerNormWeights, std::unique_ptr< Decoder< float >> &outputLayerNormWeights, std::unique_ptr< Encoder< float >> &inputGateScratch, std::unique_ptr< Encoder< float >> &cellScratch, std::unique_ptr< Encoder< float >> &forgetGateScratch, std::unique_ptr< Encoder< float >> &outputGateScratch, std::unique_ptr< Decoder< float >> &inputGateScratchDecoder, std::unique_ptr< Decoder< float >> &cellScratchDecoder, std::unique_ptr< Decoder< float >> &forgetGateScratchDecoder, std::unique_ptr< Decoder< float >> &outputGateScratchDecoder, float layerNormEpsilon)
Definition: Lstm.cpp:13
armnn::RefLstmWorkload::RefLstmWorkload
RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
Definition: RefLstmWorkload.cpp:17
armnn::QueueDescriptorWithParameters::m_Parameters
LayerDescriptor m_Parameters
Definition: WorkloadData.hpp:66
armnn::WorkloadInfo
Contains information about TensorInfos of a layer.
Definition: WorkloadInfo.hpp:16
armnn::GetTensorInfo
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers
Definition: RefWorkloadUtils.hpp:33
Activation.hpp
armnn::BoostLogSeverityMapping::info
@ info
armnn::QueueDescriptor::m_Outputs
std::vector< ITensorHandle * > m_Outputs
Definition: WorkloadData.hpp:27
armnn::LstmQueueDescriptor
Definition: WorkloadData.hpp:400
RefWorkloadUtils.hpp
armnn::LstmDescriptor::m_CifgEnabled
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
Definition: Descriptors.hpp:1146
armnn::BaseWorkload< LstmQueueDescriptor >::m_Data
LstmQueueDescriptor m_Data
Definition: Workload.hpp:89
armnn::TensorInfo::GetShape
const TensorShape & GetShape() const
Definition: Tensor.hpp:193
armnn::LstmDescriptor::m_LayerNormEnabled
bool m_LayerNormEnabled
Enable/disable layer normalization.
Definition: Descriptors.hpp:1152
Decoders.hpp
armnn::LayerType::Map
@ Map
armnn::experimental::WorkingMemDescriptor::m_Inputs
std::vector< ITensorHandle * > m_Inputs
Definition: WorkingMemDescriptor.hpp:20
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::experimental::WorkingMemDescriptor
Definition: WorkingMemDescriptor.hpp:18
armnn::LstmDescriptor::m_ProjectionEnabled
bool m_ProjectionEnabled
Enable/disable the projection layer.
Definition: Descriptors.hpp:1150
LstmUtils.hpp
RefLstmWorkload.hpp
Encoders.hpp
armnn::RefBaseWorkload
Definition: RefBaseWorkload.hpp:13
armnn::RefLstmWorkload::Execute
void Execute() const override
Definition: RefLstmWorkload.cpp:42
armnn::experimental::WorkingMemDescriptor::m_Outputs
std::vector< ITensorHandle * > m_Outputs
Definition: WorkingMemDescriptor.hpp:21
armnn::QueueDescriptor::m_Inputs
std::vector< ITensorHandle * > m_Inputs
Definition: WorkloadData.hpp:26
armnn::experimental::ExecutionData
Definition: ExecutionData.hpp:14