55 float layerNormEpsilon)
63 const uint32_t nBatch = inputShape[0];
64 const uint32_t nInput = inputShape[1];
66 const uint32_t nCell = inputToOutputWeightsShape[0];
67 const uint32_t nOutput = recurrentToOutputWeightsShape[1];
79 nCell, nBatch, *inputGateScratch);
82 nCell, nBatch, *forgetGateScratch);
84 nCell, nBatch, *cellScratch);
86 nCell, nBatch, *outputGateScratch);
95 ZeroVector(*forgetGateScratch, nCell * nBatch);
97 ZeroVector(*outputGateScratch, nCell * nBatch);
104 nCell, nInput, *inputData, nBatch, *inputGateScratch);
107 nCell, nInput, *inputData, nBatch, *forgetGateScratch);
109 nCell, nInput, *inputData, nBatch, *cellScratch);
111 nCell, nInput, *inputData, nBatch, *outputGateScratch);
117 nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch);
120 nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch);
122 nCell, nOutput, *outputStateIn, nBatch, *cellScratch);
124 nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch);
132 nCell, *cellStateIn, nBatch, *inputGateScratch);
137 *inputGateScratch, nCell, nBatch, layerNormEpsilon);
139 nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
141 nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
143 Activation(*inputGateScratchDecoder, *inputGateScratch,
152 *cellStateIn, nBatch, *forgetGateScratch);
157 *forgetGateScratch, nCell, nBatch, layerNormEpsilon);
159 nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
161 nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
163 Activation(*forgetGateScratchDecoder, *forgetGateScratch,
171 *cellScratch, nCell, nBatch, layerNormEpsilon);
173 nCell, *cellScratchDecoder, nBatch, *cellScratch);
175 nCell, *cellScratchDecoder, nBatch, *cellScratch);
189 armnnActivationFunc, a, b);
193 Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch);
195 *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut);
200 *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut);
211 nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
216 *outputGateScratch, nCell, nBatch, layerNormEpsilon);
218 nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
220 nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
222 Activation(*outputGateScratchDecoder, *outputGateScratch,
228 Activation(*cellStateOutDecoder, *cellScratch,
230 armnnActivationFunc, a, b);
238 if (projectionBiasTensor)
241 nOutput, nBatch, *output);
244 nOutput, nCell, *outputGateScratchDecoder, nBatch, *output);
253 CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output);
256 CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut);
void CopyVector(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Encoder< float > &outResult)
void MeanStddevNormalization(armnn::Decoder< float > &input_vector, armnn::Encoder< float > &output_vector, uint32_t v_size, uint32_t n_batch, float normalization_epsilon)
void ClipVector(armnn::Decoder< float > &vector, uint32_t vSize, float absLimit, armnn::Encoder< float > &outResult)
void VectorBatchVectorCwiseProduct(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
void VectorVectorCwiseProductAccumulate(armnn::Decoder< float > &vector1, armnn::Decoder< float > &vector2, uint32_t vSize, armnn::Encoder< float > &outResult)
void VectorBatchVectorAdd(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
void ZeroVector(armnn::Encoder< float > &vector, uint32_t vSize)
void VectorVectorCwiseProduct(armnn::Decoder< float > &vector1, armnn::Decoder< float > &vector2, uint32_t vSize, armnn::Encoder< float > &outResult)
void VectorBatchVectorCwiseProductAccumulate(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
void VectorBatchVectorAssign(armnn::Decoder< float > &vector, uint32_t vSize, uint32_t nBatch, armnn::Encoder< float > &outBatchVector)
void MatrixBatchVectorMultiplyAccumulate(armnn::Decoder< float > &matrix, uint32_t mRows, uint32_t mCols, armnn::Decoder< float > &vector, uint32_t nBatch, armnn::Encoder< float > &outResult)
void Sub1Vector(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Encoder< float > &result)
void SetActivationParameters(uint32_t activation, armnn::ActivationFunction &outArmnnActivation, float &outA, float &outB)
const TensorShape & GetShape() const
DataType GetDataType() const
Copyright (c) 2021 ARM Limited and Contributors.
float Activation(float in, ActivationFunction function, float a, float b)
void LstmImpl(const LstmDescriptor &descriptor, const TensorInfo &inputInfo, const TensorInfo &outputInfo, const TensorShape &inputToOutputWeightsShape, const TensorShape &recurrentToOutputWeightsShape, std::unique_ptr< Decoder< float >> &inputData, std::unique_ptr< Decoder< float >> &outputStateIn, std::unique_ptr< Decoder< float >> &cellStateIn, std::unique_ptr< Encoder< float >> &outputStateOut, std::unique_ptr< Encoder< float >> &cellStateOut, std::unique_ptr< Encoder< float >> &output, std::unique_ptr< Decoder< float >> &cellStateOutDecoder, std::unique_ptr< Decoder< float >> &outputDecoder, std::unique_ptr< Decoder< float >> &inputToInputWeightsTensor, std::unique_ptr< Decoder< float >> &inputToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &inputToCellWeightsTensor, std::unique_ptr< Decoder< float >> &inputToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToInputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToCellWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToInputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &cellToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &inputGateBiasTensor, std::unique_ptr< Decoder< float >> &forgetGateBiasTensor, std::unique_ptr< Decoder< float >> &cellBiasTensor, std::unique_ptr< Decoder< float >> &outputGateBiasTensor, std::unique_ptr< Decoder< float >> &projectionWeightsTensor, std::unique_ptr< Decoder< float >> &projectionBiasTensor, std::unique_ptr< Decoder< float >> &inputLayerNormWeights, std::unique_ptr< Decoder< float >> &forgetLayerNormWeights, std::unique_ptr< Decoder< float >> &cellLayerNormWeights, std::unique_ptr< Decoder< float >> &outputLayerNormWeights, std::unique_ptr< Encoder< float >> &inputGateScratch, std::unique_ptr< Encoder< float >> &cellScratch, std::unique_ptr< Encoder< float >> &forgetGateScratch, std::unique_ptr< Encoder< float >> &outputGateScratch, std::unique_ptr< Decoder< float >> &inputGateScratchDecoder, std::unique_ptr< Decoder< float >> &cellScratchDecoder, std::unique_ptr< Decoder< float >> &forgetGateScratchDecoder, std::unique_ptr< Decoder< float >> &outputGateScratchDecoder, float layerNormEpsilon)
An LstmDescriptor for the LstmLayer.
bool m_PeepholeEnabled
Enable/disable peephole.
bool m_LayerNormEnabled
Enable/disable layer normalization.
float m_ClippingThresCell
Clipping threshold value for the cell state.
bool m_ProjectionEnabled
Enable/disable the projection layer.
float m_ClippingThresProj
Clipping threshold value for the projection.
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
uint32_t m_ActivationFunc
The activation function to use.