55 float layerNormEpsilon)
63 const uint32_t nBatch = inputShape[0];
64 const uint32_t nInput = inputShape[1];
66 const uint32_t nCell = inputToOutputWeightsShape[0];
67 const uint32_t nOutput = recurrentToOutputWeightsShape[1];
79 nCell, nBatch, *inputGateScratch);
82 nCell, nBatch, *forgetGateScratch);
84 nCell, nBatch, *cellScratch);
86 nCell, nBatch, *outputGateScratch);
95 ZeroVector(*forgetGateScratch, nCell * nBatch);
97 ZeroVector(*outputGateScratch, nCell * nBatch);
104 nCell, nInput, *inputData, nBatch, *inputGateScratch);
107 nCell, nInput, *inputData, nBatch, *forgetGateScratch);
109 nCell, nInput, *inputData, nBatch, *cellScratch);
111 nCell, nInput, *inputData, nBatch, *outputGateScratch);
117 nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch);
120 nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch);
122 nCell, nOutput, *outputStateIn, nBatch, *cellScratch);
124 nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch);
132 nCell, *cellStateIn, nBatch, *inputGateScratch);
137 *inputGateScratch, nCell, nBatch, layerNormEpsilon);
139 nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
141 nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
143 Activation(*inputGateScratchDecoder, *inputGateScratch,
152 *cellStateIn, nBatch, *forgetGateScratch);
157 *forgetGateScratch, nCell, nBatch, layerNormEpsilon);
159 nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
161 nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
163 Activation(*forgetGateScratchDecoder, *forgetGateScratch,
171 *cellScratch, nCell, nBatch, layerNormEpsilon);
173 nCell, *cellScratchDecoder, nBatch, *cellScratch);
175 nCell, *cellScratchDecoder, nBatch, *cellScratch);
189 armnnActivationFunc, a, b);
193 Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch);
195 *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut);
200 *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut);
211 nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
216 *outputGateScratch, nCell, nBatch, layerNormEpsilon);
218 nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
220 nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
222 Activation(*outputGateScratchDecoder, *outputGateScratch,
228 Activation(*cellStateOutDecoder, *cellScratch,
230 armnnActivationFunc, a, b);
238 if (projectionBiasTensor)
241 nOutput, nBatch, *output);
244 nOutput, nCell, *outputGateScratchDecoder, nBatch, *output);
253 CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output);
256 CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut);