ArmNN
 24.08
Lstm.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "Activation.hpp"
7 #include "Lstm.hpp"
8 #include "LstmUtils.hpp"
9 
10 namespace armnn
11 {
12 
13 void LstmImpl(const LstmDescriptor& descriptor,
14  const TensorInfo& inputInfo,
15  const TensorInfo& outputInfo,
16  const TensorShape& inputToOutputWeightsShape,
17  const TensorShape& recurrentToOutputWeightsShape,
18  std::unique_ptr<Decoder<float>>& inputData,
19  std::unique_ptr<Decoder<float>>& outputStateIn,
20  std::unique_ptr<Decoder<float>>& cellStateIn,
21  std::unique_ptr<Encoder<float>>& outputStateOut,
22  std::unique_ptr<Encoder<float>>& cellStateOut,
23  std::unique_ptr<Encoder<float>>& output,
24  std::unique_ptr<Decoder<float>>& cellStateOutDecoder,
25  std::unique_ptr<Decoder<float>>& outputDecoder,
26  std::unique_ptr<Decoder<float>>& inputToInputWeightsTensor,
27  std::unique_ptr<Decoder<float>>& inputToForgetWeightsTensor,
28  std::unique_ptr<Decoder<float>>& inputToCellWeightsTensor,
29  std::unique_ptr<Decoder<float>>& inputToOutputWeightsTensor,
30  std::unique_ptr<Decoder<float>>& recurrentToInputWeightsTensor,
31  std::unique_ptr<Decoder<float>>& recurrentToForgetWeightsTensor,
32  std::unique_ptr<Decoder<float>>& recurrentToCellWeightsTensor,
33  std::unique_ptr<Decoder<float>>& recurrentToOutputWeightsTensor,
34  std::unique_ptr<Decoder<float>>& cellToInputWeightsTensor,
35  std::unique_ptr<Decoder<float>>& cellToForgetWeightsTensor,
36  std::unique_ptr<Decoder<float>>& cellToOutputWeightsTensor,
37  std::unique_ptr<Decoder<float>>& inputGateBiasTensor,
38  std::unique_ptr<Decoder<float>>& forgetGateBiasTensor,
39  std::unique_ptr<Decoder<float>>& cellBiasTensor,
40  std::unique_ptr<Decoder<float>>& outputGateBiasTensor,
41  std::unique_ptr<Decoder<float>>& projectionWeightsTensor,
42  std::unique_ptr<Decoder<float>>& projectionBiasTensor,
43  std::unique_ptr<Decoder<float>>& inputLayerNormWeights,
44  std::unique_ptr<Decoder<float>>& forgetLayerNormWeights,
45  std::unique_ptr<Decoder<float>>& cellLayerNormWeights,
46  std::unique_ptr<Decoder<float>>& outputLayerNormWeights,
47  std::unique_ptr<Encoder<float>>& inputGateScratch,
48  std::unique_ptr<Encoder<float>>& cellScratch,
49  std::unique_ptr<Encoder<float>>& forgetGateScratch,
50  std::unique_ptr<Encoder<float>>& outputGateScratch,
51  std::unique_ptr<Decoder<float>>& inputGateScratchDecoder,
52  std::unique_ptr<Decoder<float>>& cellScratchDecoder,
53  std::unique_ptr<Decoder<float>>& forgetGateScratchDecoder,
54  std::unique_ptr<Decoder<float>>& outputGateScratchDecoder,
55  float layerNormEpsilon)
56 {
57  // This is a porting of the LSTM::Eval() method in the Android code base
58  // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
59 
60  const TensorShape& inputShape = inputInfo.GetShape();
61  const DataType& outputType = outputInfo.GetDataType();
62 
63  const uint32_t nBatch = inputShape[0];
64  const uint32_t nInput = inputShape[1];
65 
66  const uint32_t nCell = inputToOutputWeightsShape[0];
67  const uint32_t nOutput = recurrentToOutputWeightsShape[1];
68 
69  const bool useCifg = descriptor.m_CifgEnabled;
70  const bool usePeephole = descriptor.m_PeepholeEnabled;
71  const bool useLayerNorm = descriptor.m_LayerNormEnabled;
72 
73  if (!useLayerNorm)
74  {
75  // Initialize scratch buffers with bias.
76  if (!useCifg)
77  {
78  VectorBatchVectorAssign(*inputGateBiasTensor,
79  nCell, nBatch, *inputGateScratch);
80  }
81  VectorBatchVectorAssign(*forgetGateBiasTensor,
82  nCell, nBatch, *forgetGateScratch);
83  VectorBatchVectorAssign(*cellBiasTensor,
84  nCell, nBatch, *cellScratch);
85  VectorBatchVectorAssign(*outputGateBiasTensor,
86  nCell, nBatch, *outputGateScratch);
87  }
88  else
89  {
90  // Initialize scratch buffers with zeroes.
91  if (!useCifg)
92  {
93  ZeroVector(*inputGateScratch, nCell * nBatch);
94  }
95  ZeroVector(*forgetGateScratch, nCell * nBatch);
96  ZeroVector(*cellScratch , nCell * nBatch);
97  ZeroVector(*outputGateScratch, nCell * nBatch);
98  }
99 
100  // For each batch and cell: compute input_weight * input.
101  if (!useCifg)
102  {
103  MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor,
104  nCell, nInput, *inputData, nBatch, *inputGateScratch);
105  }
106  MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor,
107  nCell, nInput, *inputData, nBatch, *forgetGateScratch);
108  MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor,
109  nCell, nInput, *inputData, nBatch, *cellScratch);
110  MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor,
111  nCell, nInput, *inputData, nBatch, *outputGateScratch);
112 
113  // For each batch and cell: compute recurrent_weight * output_state.
114  if (!useCifg)
115  {
116  MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor,
117  nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch);
118  }
119  MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor,
120  nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch);
121  MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor,
122  nCell, nOutput, *outputStateIn, nBatch, *cellScratch);
123  MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor,
124  nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch);
125 
126  // For each batch and cell: update input gate.
127  if (!useCifg)
128  {
129  if (usePeephole)
130  {
131  VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor,
132  nCell, *cellStateIn, nBatch, *inputGateScratch);
133  }
134  if (useLayerNorm)
135  {
136  MeanStddevNormalization(*inputGateScratchDecoder,
137  *inputGateScratch, nCell, nBatch, layerNormEpsilon);
138  VectorBatchVectorCwiseProduct(*inputLayerNormWeights,
139  nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
140  VectorBatchVectorAdd(*inputGateBiasTensor,
141  nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
142  }
143  Activation(*inputGateScratchDecoder, *inputGateScratch,
144  TensorInfo({nCell, nBatch}, outputType),
146  }
147 
148  // For each batch and cell: update forget gate.
149  if (usePeephole)
150  {
151  VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell,
152  *cellStateIn, nBatch, *forgetGateScratch);
153  }
154  if (useLayerNorm)
155  {
156  MeanStddevNormalization(*forgetGateScratchDecoder,
157  *forgetGateScratch, nCell, nBatch, layerNormEpsilon);
158  VectorBatchVectorCwiseProduct(*forgetLayerNormWeights,
159  nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
160  VectorBatchVectorAdd(*forgetGateBiasTensor,
161  nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
162  }
163  Activation(*forgetGateScratchDecoder, *forgetGateScratch,
164  TensorInfo({nCell, nBatch}, outputType),
166 
167  // For each batch and cell: update the cell.
168  if (useLayerNorm)
169  {
170  MeanStddevNormalization(*cellScratchDecoder,
171  *cellScratch, nCell, nBatch, layerNormEpsilon);
172  VectorBatchVectorCwiseProduct(*cellLayerNormWeights,
173  nCell, *cellScratchDecoder, nBatch, *cellScratch);
174  VectorBatchVectorAdd(*cellBiasTensor,
175  nCell, *cellScratchDecoder, nBatch, *cellScratch);
176  }
177 
178  VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut);
179 
180  ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
181  float a = 0;
182  float b = 0;
183  SetActivationParameters(descriptor.m_ActivationFunc, armnnActivationFunc, a, b);
184 
185  if (descriptor.m_ActivationFunc > 0)
186  {
187  Activation(*cellScratchDecoder, *cellScratch,
188  TensorInfo({nCell, nBatch}, outputType),
189  armnnActivationFunc, a, b);
190  }
191  if (useCifg)
192  {
193  Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch);
195  *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut);
196  }
197  else
198  {
200  *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut);
201  }
202  if (descriptor.m_ClippingThresCell > 0.0)
203  {
204  ClipVector(*cellStateOutDecoder, nBatch * nCell, descriptor.m_ClippingThresCell, *cellStateOut);
205  }
206 
207  // For each batch and cell: update the output gate.
208  if (usePeephole)
209  {
210  VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor,
211  nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
212  }
213  if (useLayerNorm)
214  {
215  MeanStddevNormalization(*outputGateScratchDecoder,
216  *outputGateScratch, nCell, nBatch, layerNormEpsilon);
217  VectorBatchVectorCwiseProduct(*outputLayerNormWeights,
218  nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
219  VectorBatchVectorAdd(*outputGateBiasTensor,
220  nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
221  }
222  Activation(*outputGateScratchDecoder, *outputGateScratch,
223  TensorInfo({nCell, nBatch}, outputType),
225 
226  if (descriptor.m_ActivationFunc > 0)
227  {
228  Activation(*cellStateOutDecoder, *cellScratch,
229  TensorInfo({nCell, nBatch}, outputType),
230  armnnActivationFunc, a, b);
231  }
232 
233  VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch);
234 
235  // For each batch: update the projection and output_state.
236  if (descriptor.m_ProjectionEnabled)
237  {
238  if (projectionBiasTensor)
239  {
240  VectorBatchVectorAssign(*projectionBiasTensor,
241  nOutput, nBatch, *output);
242  }
243  MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor,
244  nOutput, nCell, *outputGateScratchDecoder, nBatch, *output);
245 
246  if (descriptor.m_ClippingThresProj > 0.0)
247  {
248  ClipVector(*outputDecoder, nBatch * nOutput, descriptor.m_ClippingThresProj, *output);
249  }
250  }
251  else
252  {
253  CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output);
254  }
255 
256  CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut);
257 }
258 
259 } //namespace armnn
VectorVectorCwiseProduct
void VectorVectorCwiseProduct(armnn::Decoder< float > &vector1, armnn::Decoder< float > &vector2, uint32_t vSize, armnn::Encoder< float > &outResult)
Definition: LstmUtils.cpp:187
armnn::Decoder< float >
MeanStddevNormalization
void MeanStddevNormalization(armnn::Decoder< float > &input_vector, armnn::Encoder< float > &output_vector, uint32_t v_size, uint32_t n_batch, float normalization_epsilon)
Definition: LstmUtils.cpp:40
VectorBatchVectorCwiseProductAccumulate
void VectorBatchVectorCwiseProductAccumulate(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
Definition: LstmUtils.cpp:131
Lstm.hpp
VectorBatchVectorAdd
void VectorBatchVectorAdd(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
Definition: LstmUtils.cpp:16
armnn::TensorInfo
Definition: Tensor.hpp:152
VectorBatchVectorCwiseProduct
void VectorBatchVectorCwiseProduct(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
Definition: LstmUtils.cpp:152
ClipVector
void ClipVector(armnn::Decoder< float > &vector, uint32_t vSize, float absLimit, armnn::Encoder< float > &outResult)
Definition: LstmUtils.cpp:229
MatrixBatchVectorMultiplyAccumulate
void MatrixBatchVectorMultiplyAccumulate(armnn::Decoder< float > &matrix, uint32_t mRows, uint32_t mCols, armnn::Decoder< float > &vector, uint32_t nBatch, armnn::Encoder< float > &outResult)
Definition: LstmUtils.cpp:87
CopyVector
void CopyVector(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Encoder< float > &outResult)
Definition: LstmUtils.cpp:244
Sub1Vector
void Sub1Vector(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Encoder< float > &result)
Definition: LstmUtils.cpp:173
armnn::LstmDescriptor::m_PeepholeEnabled
bool m_PeepholeEnabled
Enable/disable peephole.
Definition: Descriptors.hpp:1148
armnn::TensorShape
Definition: Tensor.hpp:20
armnn::Encoder< float >
armnn::LstmDescriptor::m_ClippingThresProj
float m_ClippingThresProj
Clipping threshold value for the projection.
Definition: Descriptors.hpp:1144
armnn::LstmImpl
void LstmImpl(const LstmDescriptor &descriptor, const TensorInfo &inputInfo, const TensorInfo &outputInfo, const TensorShape &inputToOutputWeightsShape, const TensorShape &recurrentToOutputWeightsShape, std::unique_ptr< Decoder< float >> &inputData, std::unique_ptr< Decoder< float >> &outputStateIn, std::unique_ptr< Decoder< float >> &cellStateIn, std::unique_ptr< Encoder< float >> &outputStateOut, std::unique_ptr< Encoder< float >> &cellStateOut, std::unique_ptr< Encoder< float >> &output, std::unique_ptr< Decoder< float >> &cellStateOutDecoder, std::unique_ptr< Decoder< float >> &outputDecoder, std::unique_ptr< Decoder< float >> &inputToInputWeightsTensor, std::unique_ptr< Decoder< float >> &inputToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &inputToCellWeightsTensor, std::unique_ptr< Decoder< float >> &inputToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToInputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToCellWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToInputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &cellToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &inputGateBiasTensor, std::unique_ptr< Decoder< float >> &forgetGateBiasTensor, std::unique_ptr< Decoder< float >> &cellBiasTensor, std::unique_ptr< Decoder< float >> &outputGateBiasTensor, std::unique_ptr< Decoder< float >> &projectionWeightsTensor, std::unique_ptr< Decoder< float >> &projectionBiasTensor, std::unique_ptr< Decoder< float >> &inputLayerNormWeights, std::unique_ptr< Decoder< float >> &forgetLayerNormWeights, std::unique_ptr< Decoder< float >> &cellLayerNormWeights, std::unique_ptr< Decoder< float >> &outputLayerNormWeights, std::unique_ptr< Encoder< float >> &inputGateScratch, std::unique_ptr< Encoder< float >> &cellScratch, std::unique_ptr< Encoder< float >> &forgetGateScratch, std::unique_ptr< Encoder< float >> &outputGateScratch, std::unique_ptr< Decoder< float >> &inputGateScratchDecoder, std::unique_ptr< Decoder< float >> &cellScratchDecoder, std::unique_ptr< Decoder< float >> &forgetGateScratchDecoder, std::unique_ptr< Decoder< float >> &outputGateScratchDecoder, float layerNormEpsilon)
Definition: Lstm.cpp:13
armnn::DataType
DataType
Definition: Types.hpp:48
Activation.hpp
armnn::ActivationFunction
ActivationFunction
Definition: Types.hpp:86
armnn::TensorInfo::GetDataType
DataType GetDataType() const
Definition: Tensor.hpp:200
ZeroVector
void ZeroVector(armnn::Encoder< float > &vector, uint32_t vSize)
Definition: LstmUtils.cpp:76
armnn::LstmDescriptor
An LstmDescriptor for the LstmLayer.
Definition: Descriptors.hpp:1102
armnn::LstmDescriptor::m_CifgEnabled
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
Definition: Descriptors.hpp:1146
armnn::TensorInfo::GetShape
const TensorShape & GetShape() const
Definition: Tensor.hpp:193
armnn::LstmDescriptor::m_LayerNormEnabled
bool m_LayerNormEnabled
Enable/disable layer normalization.
Definition: Descriptors.hpp:1152
VectorVectorCwiseProductAccumulate
void VectorVectorCwiseProductAccumulate(armnn::Decoder< float > &vector1, armnn::Decoder< float > &vector2, uint32_t vSize, armnn::Encoder< float > &outResult)
Definition: LstmUtils.cpp:204
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::LstmDescriptor::m_ProjectionEnabled
bool m_ProjectionEnabled
Enable/disable the projection layer.
Definition: Descriptors.hpp:1150
SetActivationParameters
void SetActivationParameters(uint32_t activation, armnn::ActivationFunction &outArmnnActivation, float &outA, float &outB)
Definition: LstmUtils.cpp:258
armnn::Activation
float Activation(float in, ActivationFunction function, float a, float b)
Definition: Activation.cpp:13
LstmUtils.hpp
armnn::LstmDescriptor::m_ActivationFunc
uint32_t m_ActivationFunc
The activation function to use.
Definition: Descriptors.hpp:1140
armnn::LstmDescriptor::m_ClippingThresCell
float m_ClippingThresCell
Clipping threshold value for the cell state.
Definition: Descriptors.hpp:1142
armnn::ActivationFunction::Sigmoid
@ Sigmoid
VectorBatchVectorAssign
void VectorBatchVectorAssign(armnn::Decoder< float > &vector, uint32_t vSize, uint32_t nBatch, armnn::Encoder< float > &outBatchVector)
Definition: LstmUtils.cpp:113