62 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->Map());
63 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, outputs[2]->Map());
64 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, outputs[3]->Map());
66 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, outputs[2]->Map());
67 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, outputs[3]->Map());
69 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, inputs[0]->Map());
70 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, inputs[1]->Map());
71 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, inputs[2]->Map());
73 const uint32_t nBatch = inputShape[0];
74 const uint32_t nInput = inputShape[1];
76 const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
77 const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1];
84 std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
85 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
86 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
87 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
89 std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
90 MakeDecoder<float>(outputInfo, outputs[0]->Map());
91 std::unique_ptr<Decoder<float>> cellScratchDecoder =
92 MakeDecoder<float>(outputInfo, outputs[0]->Map());
93 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
94 MakeDecoder<float>(outputInfo, outputs[0]->Map());
95 std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
96 MakeDecoder<float>(outputInfo, outputs[0]->Map());
100 *cellScratch += (0 * nCell * nBatch);
101 *forgetGateScratch += (1 * nCell * nBatch);
102 *outputGateScratch += (2 * nCell * nBatch);
104 *cellScratchDecoder += (0 * nCell * nBatch);
105 *forgetGateScratchDecoder += (1 * nCell * nBatch);
106 *outputGateScratchDecoder += (2 * nCell * nBatch);
110 *inputGateScratch += (0 * nCell * nBatch);
111 *cellScratch += (1 * nCell * nBatch);
112 *forgetGateScratch += (2 * nCell * nBatch);
113 *outputGateScratch += (3 * nCell * nBatch);
115 *inputGateScratchDecoder += (0 * nCell * nBatch);
116 *cellScratchDecoder += (1 * nCell * nBatch);
117 *forgetGateScratchDecoder += (2 * nCell * nBatch);
118 *outputGateScratchDecoder += (3 * nCell * nBatch);
121 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
122 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
123 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<
void>());
124 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
125 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<
void>());
126 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
127 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<
void>());
129 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
130 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
131 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<
void>());
132 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
133 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<
void>());
134 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
135 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<
void>());
137 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
138 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
139 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<
void>());
140 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
141 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<
void>());
142 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
143 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<
void>());
145 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
146 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
147 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
149 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
150 std::unique_ptr<Decoder<float>> projectionBiasTensor;
152 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
153 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
154 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
155 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
161 inputLayerNormWeights = MakeDecoder<float>(
162 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<
void>());
164 forgetLayerNormWeights = MakeDecoder<float>(
165 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<
void>());
166 cellLayerNormWeights = MakeDecoder<float>(
167 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<
void>());
168 outputLayerNormWeights = MakeDecoder<float>(
169 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<
void>());
174 inputToInputWeightsTensor = MakeDecoder<float>(
175 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<
void>());
176 inputGateBiasTensor = MakeDecoder<float>(
177 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<
void>());
178 recurrentToInputWeightsTensor = MakeDecoder<float>(
179 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<
void>());
184 cellToForgetWeightsTensor = MakeDecoder<float>(
185 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<
void>());
186 cellToOutputWeightsTensor = MakeDecoder<float>(
187 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<
void>());
190 if (!useCifg && usePeephole)
192 cellToInputWeightsTensor = MakeDecoder<float>(
193 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<
void>());
198 projectionWeightsTensor = MakeDecoder<float>(
199 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<
void>());
200 if (m_ProjectionBiasTensor)
202 projectionBiasTensor = MakeDecoder<float>(
203 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<
void>());
213 nCell, nBatch, *inputGateScratch);
216 nCell, nBatch, *forgetGateScratch);
218 nCell, nBatch, *cellScratch);
220 nCell, nBatch, *outputGateScratch);
227 ZeroVector(*inputGateScratch, nCell * nBatch);
229 ZeroVector(*forgetGateScratch, nCell * nBatch);
231 ZeroVector(*outputGateScratch, nCell * nBatch);
238 nCell, nInput, *inputData, nBatch, *inputGateScratch);
241 nCell, nInput, *inputData, nBatch, *forgetGateScratch);
243 nCell, nInput, *inputData, nBatch, *cellScratch);
245 nCell, nInput, *inputData, nBatch, *outputGateScratch);
251 nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch);
254 nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch);
256 nCell, nOutput, *outputStateIn, nBatch, *cellScratch);
258 nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch);
266 nCell, *cellStateIn, nBatch, *inputGateScratch);
271 *inputGateScratch, nCell, nBatch, m_LayerNormEpsilon);
273 nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
275 nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
277 Activation(*inputGateScratchDecoder, *inputGateScratch,
286 *cellStateIn, nBatch, *forgetGateScratch);
291 *forgetGateScratch, nCell, nBatch, m_LayerNormEpsilon);
293 nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
295 nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
297 Activation(*forgetGateScratchDecoder, *forgetGateScratch,
305 *cellScratch, nCell, nBatch, m_LayerNormEpsilon);
307 nCell, *cellScratchDecoder, nBatch, *cellScratch);
309 nCell, *cellScratchDecoder, nBatch, *cellScratch);
323 armnnActivationFunc, a, b);
327 Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch);
329 *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut);
334 *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut);
345 nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
350 *outputGateScratch, nCell, nBatch, m_LayerNormEpsilon);
352 nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
354 nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
356 Activation(*outputGateScratchDecoder, *outputGateScratch,
362 Activation(*cellStateOutDecoder, *cellScratch,
364 armnnActivationFunc, a, b);
372 if (m_ProjectionBiasTensor)
375 nOutput, nBatch, *output);
378 nOutput, nCell, *outputGateScratchDecoder, nBatch, *output);
387 CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output);
390 CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut);
void MeanStddevNormalization(armnn::Decoder< float > &input_vector, armnn::Encoder< float > &output_vector, uint32_t v_size, uint32_t n_batch, float normalization_epsilon)
void VectorBatchVectorAdd(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
bool m_ProjectionEnabled
Enable/disable the projection layer.
const TensorShape & GetShape() const
float m_ClippingThresProj
Clipping threshold value for the projection.
void ClipVector(armnn::Decoder< float > &vector, uint32_t vSize, float absLimit, armnn::Encoder< float > &outResult)
RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
void Sub1Vector(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Encoder< float > &result)
void CopyVector(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Encoder< float > &outResult)
std::unique_ptr< armnn::ScopedTensorHandle > AssignScopedTensorHandle(const armnn::ConstTensorHandle *ptr)
void VectorBatchVectorCwiseProductAccumulate(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
void ZeroVector(armnn::Encoder< float > &vector, uint32_t vSize)
Copyright (c) 2021 ARM Limited and Contributors.
void VectorVectorCwiseProduct(armnn::Decoder< float > &vector1, armnn::Decoder< float > &vector2, uint32_t vSize, armnn::Encoder< float > &outResult)
LayerDescriptor m_Parameters
void VectorBatchVectorCwiseProduct(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
std::vector< ITensorHandle * > m_Inputs
void MatrixBatchVectorMultiplyAccumulate(armnn::Decoder< float > &matrix, uint32_t mRows, uint32_t mCols, armnn::Decoder< float > &vector, uint32_t nBatch, armnn::Encoder< float > &outResult)
LstmQueueDescriptor m_Data
DataType GetDataType() const
void ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor) override
bool m_PeepholeEnabled
Enable/disable peephole.
void VectorVectorCwiseProductAccumulate(armnn::Decoder< float > &vector1, armnn::Decoder< float > &vector2, uint32_t vSize, armnn::Encoder< float > &outResult)
uint32_t m_ActivationFunc
The activation function to use.
void VectorBatchVectorAssign(armnn::Decoder< float > &vector, uint32_t vSize, uint32_t nBatch, armnn::Encoder< float > &outBatchVector)
float m_ClippingThresCell
Clipping threshold value for the cell state.
void Execute() const override
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
std::vector< ITensorHandle * > m_Outputs
std::vector< ITensorHandle * > m_Outputs
bool m_LayerNormEnabled
Enable/disable layer normalization.
Contains information about inputs and outputs to a layer.
std::vector< ITensorHandle * > m_Inputs
void SetActivationParameters(uint32_t activation, armnn::ActivationFunction &outArmnnActivation, float &outA, float &outB)
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers