52 std::vector<ITensorHandle*> outputs)
const
64 auto inputTensor =
reinterpret_cast<float*
>(inputs[0]->Map());
66 if (!
m_Data.m_Parameters.m_TimeMajor)
70 std::vector<float> inputValue(inputTensor, inputTensor + inputInfo.
GetNumElements());
79 unsigned int maxTime = inputShape[0];
80 unsigned int batchSize = inputShape[1];
81 unsigned int outputSize = outputShape[2];
82 unsigned int inputSize = inputShape[2];
84 TensorInfo scratchInfo = outputInfo;
87 std::vector<float> inputGateScratchBuffer;
88 std::vector<float> cellScratchBuffer(scratchInfo.GetNumElements(), 0.);
89 std::vector<float> forgetGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
90 std::vector<float> outputGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
92 std::vector<float> outputStateOutBuffer(outputStateInfo.
GetNumElements(), 0.);
93 std::vector<float> cellStateOutBuffer(cellStateInfo.
GetNumElements(), 0.);
95 void* outputStateOutData = outputStateOutBuffer.data();
96 void* cellStateOutData = cellStateOutBuffer.data();
98 std::unique_ptr<Encoder<float>> inputGateScratch;
99 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(scratchInfo, cellScratchBuffer.data());
100 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(scratchInfo, forgetGateScratchBuffer.data());
101 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(scratchInfo, outputGateScratchBuffer.data());
103 std::unique_ptr<Decoder<float>> inputGateScratchDecoder;
104 std::unique_ptr<Decoder<float>> cellScratchDecoder = MakeDecoder<float>(scratchInfo, cellScratchBuffer.data());
105 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = MakeDecoder<float>(scratchInfo,
106 forgetGateScratchBuffer.data());
107 std::unique_ptr<Decoder<float>> outputGateScratchDecoder = MakeDecoder<float>(scratchInfo,
108 outputGateScratchBuffer.data());
110 const bool useCifg =
m_Data.m_Parameters.m_CifgEnabled;
111 const bool usePeephole =
m_Data.m_Parameters.m_PeepholeEnabled;
112 const bool useLayerNorm =
m_Data.m_Parameters.m_LayerNormEnabled;
116 inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.);
117 inputGateScratch = MakeEncoder<float>(scratchInfo, inputGateScratchBuffer.data());
118 inputGateScratchDecoder = MakeDecoder<float>(scratchInfo, inputGateScratchBuffer.data());
121 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputStateInfo, outputStateOutData);
122 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(cellStateInfo, cellStateOutData);
123 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(cellStateInfo, cellStateOutData);
125 TensorInfo lstmInputInfo = inputInfo;
126 TensorShape batchInputShape = TensorShape({batchSize, inputSize});
127 lstmInputInfo.
SetShape(batchInputShape);
129 TensorInfo lstmOutputInfo = outputInfo;
130 lstmOutputInfo.
SetShape({batchSize, outputSize});
132 const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
133 const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
134 unsigned int nOutput = recurrentToOutputWeightsShape[1];
135 auto outputStateInData = inputs[1]->Map();
136 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateInData);
138 auto cellStateInData = inputs[2]->Map();
139 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateInData);
141 auto currentInputData =
reinterpret_cast<float*
>(inputs[0]->Map());
142 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
143 auto currentOutputData =
reinterpret_cast<float*
>(outputs[2]->Map());
144 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
145 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
147 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
148 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
149 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<
void>());
150 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
151 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<
void>());
152 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
153 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<
void>());
155 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
156 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
157 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<
void>());
158 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
159 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<
void>());
160 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
161 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<
void>());
163 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
164 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
165 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<
void>());
166 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
167 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<
void>());
168 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
169 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<
void>());
171 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
172 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
173 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
175 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
176 std::unique_ptr<Decoder<float>> projectionBiasTensor;
178 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
179 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
180 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
181 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
187 inputLayerNormWeights = MakeDecoder<float>(
188 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<
void>());
190 forgetLayerNormWeights = MakeDecoder<float>(
191 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<
void>());
192 cellLayerNormWeights = MakeDecoder<float>(
193 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<
void>());
194 outputLayerNormWeights = MakeDecoder<float>(
195 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<
void>());
200 inputToInputWeightsTensor = MakeDecoder<float>(
201 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<
void>());
202 inputGateBiasTensor = MakeDecoder<float>(
203 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<
void>());
204 recurrentToInputWeightsTensor = MakeDecoder<float>(
205 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<
void>());
210 cellToForgetWeightsTensor = MakeDecoder<float>(
211 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<
void>());
212 cellToOutputWeightsTensor = MakeDecoder<float>(
213 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<
void>());
216 if (!useCifg && usePeephole)
218 cellToInputWeightsTensor = MakeDecoder<float>(
219 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<
void>());
222 if (
m_Data.m_Parameters.m_ProjectionEnabled)
224 projectionWeightsTensor = MakeDecoder<float>(
225 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<
void>());
226 if (m_ProjectionBiasTensor)
228 projectionBiasTensor = MakeDecoder<float>(
229 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<
void>());
233 unsigned int batchInputSize = batchSize * inputSize;
234 unsigned int batchOutputSize = batchSize * nOutput;
236 for (
unsigned int t = 0; t < maxTime; ++t)
241 inputToOutputWeightsShape,
242 recurrentToOutputWeightsShape,
251 inputToInputWeightsTensor,
252 inputToForgetWeightsTensor,
253 inputToCellWeightsTensor,
254 inputToOutputWeightsTensor,
255 recurrentToInputWeightsTensor,
256 recurrentToForgetWeightsTensor,
257 recurrentToCellWeightsTensor,
258 recurrentToOutputWeightsTensor,
259 cellToInputWeightsTensor,
260 cellToForgetWeightsTensor,
261 cellToOutputWeightsTensor,
263 forgetGateBiasTensor,
265 outputGateBiasTensor,
266 projectionWeightsTensor,
267 projectionBiasTensor,
268 inputLayerNormWeights,
269 forgetLayerNormWeights,
270 cellLayerNormWeights,
271 outputLayerNormWeights,
276 inputGateScratchDecoder,
278 forgetGateScratchDecoder,
279 outputGateScratchDecoder,
282 currentInputData += batchInputSize;
283 inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
284 currentOutputData += batchOutputSize;
285 output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
286 outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
289 outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateOutData);
292 cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateOutData);
295 if (!
m_Data.m_Parameters.m_TimeMajor)
298 const PermutationVector& mappings = {1U, 0U, 2U};
299 auto outputData =
reinterpret_cast<float*
>(outputs[2]->Map());
300 std::vector<float> outputValue(outputData, outputData + outputInfo.
GetNumElements());
std::unique_ptr< armnn::ScopedTensorHandle > AssignScopedTensorHandle(const armnn::ConstTensorHandle *ptr)
#define ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
RefUnidirectionalSequenceLstmWorkload(const UnidirectionalSequenceLstmQueueDescriptor &descriptor, const WorkloadInfo &info)
void Execute() const override
unsigned int GetNumElements() const
const TensorShape & GetShape() const
void SetShape(const TensorShape &newShape)
Copyright (c) 2021 ARM Limited and Contributors.
void LstmImpl(const LstmDescriptor &descriptor, const TensorInfo &inputInfo, const TensorInfo &outputInfo, const TensorShape &inputToOutputWeightsShape, const TensorShape &recurrentToOutputWeightsShape, std::unique_ptr< Decoder< float >> &inputData, std::unique_ptr< Decoder< float >> &outputStateIn, std::unique_ptr< Decoder< float >> &cellStateIn, std::unique_ptr< Encoder< float >> &outputStateOut, std::unique_ptr< Encoder< float >> &cellStateOut, std::unique_ptr< Encoder< float >> &output, std::unique_ptr< Decoder< float >> &cellStateOutDecoder, std::unique_ptr< Decoder< float >> &outputDecoder, std::unique_ptr< Decoder< float >> &inputToInputWeightsTensor, std::unique_ptr< Decoder< float >> &inputToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &inputToCellWeightsTensor, std::unique_ptr< Decoder< float >> &inputToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToInputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToCellWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToInputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &cellToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &inputGateBiasTensor, std::unique_ptr< Decoder< float >> &forgetGateBiasTensor, std::unique_ptr< Decoder< float >> &cellBiasTensor, std::unique_ptr< Decoder< float >> &outputGateBiasTensor, std::unique_ptr< Decoder< float >> &projectionWeightsTensor, std::unique_ptr< Decoder< float >> &projectionBiasTensor, std::unique_ptr< Decoder< float >> &inputLayerNormWeights, std::unique_ptr< Decoder< float >> &forgetLayerNormWeights, std::unique_ptr< Decoder< float >> &cellLayerNormWeights, std::unique_ptr< Decoder< float >> &outputLayerNormWeights, std::unique_ptr< Encoder< float >> &inputGateScratch, std::unique_ptr< Encoder< float >> &cellScratch, std::unique_ptr< Encoder< float >> &forgetGateScratch, std::unique_ptr< Encoder< float >> &outputGateScratch, std::unique_ptr< Decoder< float >> &inputGateScratchDecoder, std::unique_ptr< Decoder< float >> &cellScratchDecoder, std::unique_ptr< Decoder< float >> &forgetGateScratchDecoder, std::unique_ptr< Decoder< float >> &outputGateScratchDecoder, float layerNormEpsilon)
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers
armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings)
void Permute(const armnn::TensorShape &dstShape, const armnn::PermutationVector &mappings, const void *src, void *dst, size_t dataTypeSize)
std::vector< ITensorHandle * > m_Inputs
std::vector< ITensorHandle * > m_Outputs
Contains information about TensorInfos of a layer.