ArmNN
 25.11
Loading...
Searching...
No Matches
RefQLstmWorkload.cpp
Go to the documentation of this file.
//
// Copyright © 2020-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefQLstmWorkload.hpp"

#include "Activation.hpp"
#include "Encoders.hpp"
#include "Decoders.hpp"
#include "LstmUtils.hpp"
#include "RefWorkloadUtils.hpp"
12
13namespace armnn
14{
15
18 , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
19 , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
20 , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
21 , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
22
23 , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
24 , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
25 , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
26 , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
27
28 , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
29 , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
30 , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
31
32 , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias))
33 , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
34 , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias))
35 , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
36
37 , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
38 , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
39
40 , m_InputLayerNormWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
41 , m_ForgetLayerNormWeightsTensor (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
42 , m_CellLayerNormWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
43 , m_OutputLayerNormWeightsTensor (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
44{}
45
47{
48 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
49}
50
51void RefQLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
52{
53 ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefQLstmWorkload_Execute");
54
55 // This is a porting of the QLSTM::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs)
56 // method in the Android code base
57 // Note: this implementation wraps the arithmetic functions of the LSTM cell in Quantize/Dequantize ops, so all
58 // computation is done in the floating point domain. Arithmetic functions are found in LstmUtils.cpp.
59 // Refer to: android/frameworks/ml/nn/common/operations/QLSTM.cpp
60 const DataType& internalType = armnn::DataType::QSymmS16;
61
62 const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
63 const TensorInfo& outputStateInInfo = GetTensorInfo(inputs[1]);
64 const TensorInfo& cellStateInInfo = GetTensorInfo(inputs[2]);
65
66 const TensorInfo& outputStateOutInfo = GetTensorInfo(outputs[0]);
67 const TensorInfo& cellStateOutInfo = GetTensorInfo(outputs[1]);
68 const TensorInfo& outputInfo = GetTensorInfo(outputs[2]);
69
70 const TensorShape& inputShape = inputInfo.GetShape();
71 const TensorShape& outputStateInShape = outputStateInInfo.GetShape();
72 const TensorShape& cellStateInShape = cellStateInInfo.GetShape();
73
74 // Infer numBatches, inputSize, outputSize and numUnits
75 const uint32_t numBatches = inputShape[0];
76 const uint32_t inputSize = inputShape[1];
77 const uint32_t outputSize = outputStateInShape[1];
78 const uint32_t numUnits = cellStateInShape[1];
79
80 // Optional param settings
81 const bool cifgEnabled = m_Data.m_Parameters.m_CifgEnabled;
82 const bool peepholeEnabled = m_Data.m_Parameters.m_PeepholeEnabled;
83 const bool projectionEnabled = m_Data.m_Parameters.m_ProjectionEnabled;
84 const bool layerNormEnabled = m_Data.m_Parameters.m_LayerNormEnabled;
85
86 // Input decoders
87 std::unique_ptr<Decoder<float>> inputDecoder =
88 MakeDecoder<float>(inputInfo, inputs[0]->Map());
89 std::unique_ptr<Decoder<float>> outputStateInDecoder =
90 MakeDecoder<float>(outputStateInInfo, inputs[1]->Map());
91 std::unique_ptr<Decoder<float>> cellStateInDecoder =
92 MakeDecoder<float>(cellStateInInfo, inputs[2]->Map());
93
94 // Output decoders
95 std::unique_ptr<Decoder<float>> outputStateOutDecoder =
96 MakeDecoder<float>(outputStateOutInfo, outputs[0]->Map());
97 std::unique_ptr<Decoder<float>> cellStateOutDecoder =
98 MakeDecoder<float>(cellStateOutInfo, outputs[1]->Map());
99 std::unique_ptr<Decoder<float>> outputDecoder =
100 MakeDecoder<float>(outputInfo, outputs[2]->Map());
101
102 // Output encoders
103 std::unique_ptr<Encoder<float>> outputStateOutEncoder =
104 MakeEncoder<float>(outputStateOutInfo, outputs[0]->Map());
105 std::unique_ptr<Encoder<float>> cellStateOutEncoder =
106 MakeEncoder<float>(cellStateOutInfo, outputs[1]->Map());
107 std::unique_ptr<Encoder<float>> outputEncoder =
108 MakeEncoder<float>(outputInfo, outputs[2]->Map());
109
110 // Weights decoders
111 std::unique_ptr<Decoder<float>> inputToForgetWeightsDecoder = MakeDecoder<float>(
112 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
113 std::unique_ptr<Decoder<float>> inputToCellWeightsDecoder = MakeDecoder<float>(
114 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
115 std::unique_ptr<Decoder<float>> inputToOutputWeightsDecoder = MakeDecoder<float>(
116 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());
117
118 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsDecoder = MakeDecoder<float>(
119 m_RecurrentToForgetWeightsTensor->GetTensorInfo(),
120 m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
121 std::unique_ptr<Decoder<float>> recurrentToCellWeightsDecoder = MakeDecoder<float>(
122 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
123 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsDecoder = MakeDecoder<float>(
124 m_RecurrentToOutputWeightsTensor->GetTensorInfo(),
125 m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());
126
127 // Optional CIFG params
128 std::unique_ptr<Decoder<float>> inputToInputWeightsDecoder;
129 std::unique_ptr<Decoder<float>> recurrentToInputWeightsDecoder;
130 std::unique_ptr<Decoder<float>> inputGateBiasDecoder;
131
132 // Optional Peephole params
133 std::unique_ptr<Decoder<float>> cellToInputWeightsDecoder;
134 std::unique_ptr<Decoder<float>> cellToForgetWeightsDecoder;
135 std::unique_ptr<Decoder<float>> cellToOutputWeightsDecoder;
136
137 // Optional Projection params
138 std::unique_ptr<Decoder<float>> projectionWeightsDecoder;
139 std::unique_ptr<Decoder<float>> projectionBiasDecoder;
140
141 // Optional Layer Norm params
142 std::unique_ptr<Decoder<float>> inputLayerNormWeightsDecoder;
143 std::unique_ptr<Decoder<float>> forgetLayerNormWeightsDecoder;
144 std::unique_ptr<Decoder<float>> cellLayerNormWeightsDecoder;
145 std::unique_ptr<Decoder<float>> outputLayerNormWeightsDecoder;
146
147 // Biases are only used when Layer Norm is enabled. Scale is defined as (XLayerNormWeights Scale / 1024)
148 std::unique_ptr<Decoder<float>> forgetGateBiasDecoder;
149 std::unique_ptr<Decoder<float>> cellGateBiasDecoder;
150 std::unique_ptr<Decoder<float>> outputGateBiasDecoder;
151
152 // Int16 vectors for internal state data (to be decoded/encoded)
153 const uint32_t stateTensorSize = numBatches * numUnits;
154 std::vector<int16_t> inputGateData(stateTensorSize);
155 std::vector<int16_t> cellGateData(stateTensorSize);
156 std::vector<int16_t> forgetGateData(stateTensorSize);
157 std::vector<int16_t> outputGateData(stateTensorSize);
158 std::vector<int32_t> hiddenStateData(stateTensorSize);
159 std::vector<int16_t> outputInt16Data(numBatches * outputSize);
160
161 armnn::TensorInfo inputGateInfo(
163 armnn::TensorInfo cellGateInfo(
164 {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_CellIntermediateScale, 0);
165 armnn::TensorInfo forgetGateInfo(
166 {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_ForgetIntermediateScale, 0);
167 armnn::TensorInfo outputGateInfo(
168 {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_OutputIntermediateScale, 0);
169 armnn::TensorInfo hiddenStateInfo({numBatches, numUnits},
171 m_Data.m_Parameters.m_HiddenStateScale,
172 m_Data.m_Parameters.m_HiddenStateZeroPoint);
173 armnn::TensorInfo outputInt16Info({numBatches , outputSize},
175 outputInfo.GetQuantizationScale(),
176 outputInfo.GetQuantizationOffset());
177
178 // Decoders/Encoders for internal states
179 std::unique_ptr<Decoder<float>> inputGateDecoder =
180 MakeDecoder<float>(inputGateInfo, inputGateData.data());
181 std::unique_ptr<Decoder<float>> cellGateDecoder =
182 MakeDecoder<float>(cellGateInfo, cellGateData.data());
183 std::unique_ptr<Decoder<float>> forgetGateDecoder =
184 MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
185 std::unique_ptr<Decoder<float>> outputGateDecoder =
186 MakeDecoder<float>(outputGateInfo, outputGateData.data());
187 std::unique_ptr<Decoder<float>> hiddenStateDecoder =
188 MakeDecoder<float>(hiddenStateInfo, hiddenStateData.data());
189
190 std::unique_ptr<Encoder<float>> inputGateEncoder =
191 MakeEncoder<float>(inputGateInfo, inputGateData.data());
192 std::unique_ptr<Encoder<float>> cellGateEncoder =
193 MakeEncoder<float>(cellGateInfo, cellGateData.data());
194 std::unique_ptr<Encoder<float>> forgetGateEncoder =
195 MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
196 std::unique_ptr<Encoder<float>> outputGateEncoder =
197 MakeEncoder<float>(outputGateInfo, outputGateData.data());
198 std::unique_ptr<Encoder<float>> hiddenStateEncoder =
199 MakeEncoder<float>(hiddenStateInfo, hiddenStateData.data());
200
201 // Int16 used to accumulate output to prevent overflowing (after Projection MatMul)
202 std::unique_ptr<Decoder<float>> outputInt16Decoder =
203 MakeDecoder<float>(outputInt16Info, outputInt16Data.data());
204 std::unique_ptr<Encoder<float>> outputInt16Encoder =
205 MakeEncoder<float>(outputInt16Info, outputInt16Data.data());
206
207 // Create decoders for optional params if they are enabled
208 if (!cifgEnabled)
209 {
210 inputToInputWeightsDecoder = MakeDecoder<float>(
211 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
212 recurrentToInputWeightsDecoder = MakeDecoder<float>(m_RecurrentToInputWeightsTensor->GetTensorInfo(),
213 m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
214 }
215
216 if (peepholeEnabled)
217 {
218 if (!cifgEnabled)
219 {
220 cellToInputWeightsDecoder = MakeDecoder<float>(
221 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
222 }
223 cellToForgetWeightsDecoder = MakeDecoder<float>(
224 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
225 cellToOutputWeightsDecoder = MakeDecoder<float>(
226 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
227 }
228
229 if (projectionEnabled)
230 {
231 projectionWeightsDecoder = MakeDecoder<float>(
232 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
233 if (m_ProjectionBiasTensor)
234 {
235 projectionBiasDecoder = MakeDecoder<float>(
236 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
237 }
238 }
239
240 if (layerNormEnabled)
241 {
242 if (!cifgEnabled)
243 {
244 inputLayerNormWeightsDecoder = MakeDecoder<float>(m_InputLayerNormWeightsTensor->GetTensorInfo(),
245 m_InputLayerNormWeightsTensor->GetConstTensor<void>());
246
247 // Bias only used if layer norm enabled
248 armnn::TensorInfo inputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
249 m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
250 inputGateBiasDecoder = MakeDecoder<float>(
251 inputGateBiasTensorInfo, m_InputGateBiasTensor->GetConstTensor<void>());
252 }
253
254 forgetLayerNormWeightsDecoder = MakeDecoder<float>(
255 m_ForgetLayerNormWeightsTensor->GetTensorInfo(),
256 m_ForgetLayerNormWeightsTensor->GetConstTensor<void>());
257 cellLayerNormWeightsDecoder = MakeDecoder<float>(
258 m_CellLayerNormWeightsTensor->GetTensorInfo(), m_CellLayerNormWeightsTensor->GetConstTensor<void>());
259 outputLayerNormWeightsDecoder = MakeDecoder<float>(
260 m_OutputLayerNormWeightsTensor->GetTensorInfo(),
261 m_OutputLayerNormWeightsTensor->GetConstTensor<void>());
262
263 // Bias only used if layer norm enabled
264 armnn::TensorInfo forgetGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
265 m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
266 forgetGateBiasDecoder = MakeDecoder<float>(
267 forgetGateBiasTensorInfo, m_ForgetGateBiasTensor->GetConstTensor<void>());
268
269 armnn::TensorInfo cellGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
270 m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
271 cellGateBiasDecoder = MakeDecoder<float>(
272 cellGateBiasTensorInfo, m_CellBiasTensor->GetConstTensor<void>());
273
274 armnn::TensorInfo outputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
275 m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
276 outputGateBiasDecoder = MakeDecoder<float>(
277 outputGateBiasTensorInfo, m_OutputGateBiasTensor->GetConstTensor<void>());
278 }
279
280 // Initialize internal state tensors with zeroes.
281 if (!cifgEnabled)
282 {
283 ZeroVector(*inputGateEncoder, stateTensorSize);
284 }
285 ZeroVector(*forgetGateEncoder, stateTensorSize);
286 ZeroVector(*cellGateEncoder, stateTensorSize);
287 ZeroVector(*outputGateEncoder, stateTensorSize);
288 ZeroVector(*hiddenStateEncoder, stateTensorSize);
289
290 // Input weights * Input
291 if (!cifgEnabled)
292 {
293 MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsDecoder,
294 numUnits, inputSize, *inputDecoder, numBatches, *inputGateEncoder);
295 }
296
297 MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsDecoder,
298 numUnits, inputSize, *inputDecoder, numBatches, *forgetGateEncoder);
299
300 MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsDecoder,
301 numUnits, inputSize, *inputDecoder, numBatches, *cellGateEncoder);
302
303 MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsDecoder,
304 numUnits, inputSize, *inputDecoder, numBatches, *outputGateEncoder);
305
306 // Recurrent weights * OutputStateIn
307 if (!cifgEnabled)
308 {
309 MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsDecoder,
310 numUnits, outputSize, *outputStateInDecoder, numBatches, *inputGateEncoder);
311 }
312
313 MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsDecoder,
314 numUnits, outputSize, *outputStateInDecoder, numBatches, *forgetGateEncoder);
315
316 MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsDecoder,
317 numUnits, outputSize, *outputStateInDecoder, numBatches, *cellGateEncoder);
318
319 MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsDecoder,
320 numUnits, outputSize, *outputStateInDecoder, numBatches, *outputGateEncoder);
321
322 // Input gate.
323 if (!cifgEnabled)
324 {
325 if (peepholeEnabled)
326 {
327 VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsDecoder,
328 numUnits, *cellStateInDecoder, numBatches, *inputGateEncoder);
329 }
330
331 if (layerNormEnabled)
332 {
333 inputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
334 m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
335 1024);
336 inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());
337
338 MeanStddevNormalization(*inputGateDecoder,
339 *inputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
340
341 inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
342
343 VectorBatchVectorCwiseProduct(*inputLayerNormWeightsDecoder,
344 numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);
345
346 inputGateInfo.SetQuantizationScale(1.f / 4096);
347 inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());
348
349 VectorBatchVectorAdd(*inputGateBiasDecoder,
350 numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);
351
352 inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
353 }
354
355 inputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
356 inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());
357
358 // Input gate sigmoid
359 Activation(*inputGateDecoder, *inputGateEncoder,
360 TensorInfo({numUnits, numBatches}, internalType),
362
363 inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
364 }
365
366 // Forget gate
367 if (peepholeEnabled)
368 {
369 VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsDecoder, numUnits,
370 *cellStateInDecoder, numBatches, *forgetGateEncoder);
371 }
372
373 if (layerNormEnabled)
374 {
375 // Quantize layer norm output to Input Scale * m_ForgetLayerNormWeightsTensor * 1024
376 forgetGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
377 m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
378 1024);
379 forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
380
381
382
383 MeanStddevNormalization(*forgetGateDecoder,
384 *forgetGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
385
386
387 forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
388
389 VectorBatchVectorCwiseProduct(*forgetLayerNormWeightsDecoder,
390 numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);
391
392
393 // Dequantize layer norm output to (1 / 4096)
394 forgetGateInfo.SetQuantizationScale(1.f / 4096);
395 forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
396
397 VectorBatchVectorAdd(*forgetGateBiasDecoder,
398 numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);
399
400
401 forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
402 }
403
404 forgetGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
405 forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
406
407 // Forget gate sigmoid
408 Activation(*forgetGateDecoder, *forgetGateEncoder,
409 TensorInfo({numUnits, numBatches}, internalType),
411
412 forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
413
414 // Cell (Modulation) gate
415 if (layerNormEnabled)
416 {
417 cellGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
418 m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
419 1024);
420 cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());
421
422 MeanStddevNormalization(*cellGateDecoder, *cellGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
423
424 cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
425
426 VectorBatchVectorCwiseProduct(*cellLayerNormWeightsDecoder,
427 numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);
428
429 cellGateInfo.SetQuantizationScale(1.f / 4096);
430 cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());
431
432 VectorBatchVectorAdd(*cellGateBiasDecoder,
433 numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);
434
435 cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
436 }
437
438 cellGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
439 cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());
440
441 // Cell (Modulation) gate tanH
442 Activation(*cellGateDecoder, *cellGateEncoder,
443 TensorInfo({numUnits, numBatches}, internalType),
444 ActivationFunction::TanH, 1.0f, 1.0f);
445
446 cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
447
448 VectorVectorCwiseProduct(*forgetGateDecoder, *cellStateInDecoder, stateTensorSize, *cellStateOutEncoder);
449
450 if (cifgEnabled)
451 {
452 Sub1Vector(*forgetGateDecoder, stateTensorSize, *forgetGateEncoder);
454 *cellGateDecoder, *forgetGateDecoder, stateTensorSize, *cellStateOutEncoder);
455 }
456 else
457 {
459 *cellGateDecoder, *inputGateDecoder, stateTensorSize, *cellStateOutEncoder);
460 }
461
462 // Final cell state out calculated here
463 if (m_Data.m_Parameters.m_CellClip > 0.0)
464 {
465 ClipVector(*cellStateOutDecoder, stateTensorSize, m_Data.m_Parameters.m_CellClip, *cellStateOutEncoder);
466 }
467
468 // Output gate.
469 if (peepholeEnabled)
470 {
471 VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsDecoder,
472 numUnits, *cellStateOutDecoder, numBatches, *outputGateEncoder);
473 }
474
475 if (layerNormEnabled)
476 {
477 outputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
478 m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
479 1024);
480 outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
481
482 MeanStddevNormalization(*outputGateDecoder, *outputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
483
484 outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
485
486 VectorBatchVectorCwiseProduct(*outputLayerNormWeightsDecoder, numUnits, *outputGateDecoder,
487 numBatches, *outputGateEncoder);
488
489 outputGateInfo.SetQuantizationScale(1.f / 4096);
490 outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
491
492 VectorBatchVectorAdd(*outputGateBiasDecoder, numUnits, *outputGateDecoder, numBatches, *outputGateEncoder);
493
494 outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
495 }
496
497 outputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
498 outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
499
500 // Output gate sigmoid
501 Activation(*outputGateDecoder, *outputGateEncoder,
502 TensorInfo({numUnits, numBatches}, internalType),
504
505 outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
506
507 // Hidden state tanH
508 Activation(*cellStateOutDecoder, *cellGateEncoder,
509 TensorInfo({numUnits, numBatches}, internalType),
510 ActivationFunction::TanH, 1.0f, 1.0f);
511
512 // Final hidden state output
513 VectorVectorCwiseProduct(*outputGateDecoder, *cellGateDecoder, stateTensorSize, *hiddenStateEncoder);
514
515 // Projection
516 if (m_Data.m_Parameters.m_ProjectionEnabled)
517 {
518 if (m_ProjectionBiasTensor)
519 {
520 VectorBatchVectorAssign(*projectionBiasDecoder, outputSize, numBatches, *outputInt16Encoder);
521 }
522
523 MatrixBatchVectorMultiplyAccumulate(*projectionWeightsDecoder, outputSize, numUnits, *hiddenStateDecoder,
524 numBatches, *outputInt16Encoder);
525
526 CopyVector(*outputInt16Decoder, numBatches * outputSize, *outputEncoder);
527
528 if (m_Data.m_Parameters.m_ProjectionClip > 0.0)
529 {
530 ClipVector(*outputDecoder, numBatches * outputSize, m_Data.m_Parameters.m_ProjectionClip, *outputEncoder);
531 }
532 }
533 else
534 {
535 // Output has same quantization scale as hidden state if projection is disabled
536 CopyVector(*hiddenStateDecoder, numBatches * outputSize, *outputEncoder);
537 }
538
539 // output == outputStateOut
540 CopyVector(*outputDecoder, numBatches * outputSize, *outputStateOutEncoder);
541}
542
543} //namespace armnn
void CopyVector(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Encoder< float > &outResult)
void MeanStddevNormalization(armnn::Decoder< float > &input_vector, armnn::Encoder< float > &output_vector, uint32_t v_size, uint32_t n_batch, float normalization_epsilon)
Definition LstmUtils.cpp:40
void ClipVector(armnn::Decoder< float > &vector, uint32_t vSize, float absLimit, armnn::Encoder< float > &outResult)
std::unique_ptr< armnn::ScopedTensorHandle > AssignScopedTensorHandle(const armnn::ConstTensorHandle *ptr)
void VectorBatchVectorCwiseProduct(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
void VectorVectorCwiseProductAccumulate(armnn::Decoder< float > &vector1, armnn::Decoder< float > &vector2, uint32_t vSize, armnn::Encoder< float > &outResult)
void VectorBatchVectorAdd(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
Definition LstmUtils.cpp:16
void ZeroVector(armnn::Encoder< float > &vector, uint32_t vSize)
Definition LstmUtils.cpp:76
void VectorVectorCwiseProduct(armnn::Decoder< float > &vector1, armnn::Decoder< float > &vector2, uint32_t vSize, armnn::Encoder< float > &outResult)
void VectorBatchVectorCwiseProductAccumulate(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Decoder< float > &batchVector, uint32_t nBatch, armnn::Encoder< float > &outResult)
void VectorBatchVectorAssign(armnn::Decoder< float > &vector, uint32_t vSize, uint32_t nBatch, armnn::Encoder< float > &outBatchVector)
void MatrixBatchVectorMultiplyAccumulate(armnn::Decoder< float > &matrix, uint32_t mRows, uint32_t mCols, armnn::Decoder< float > &vector, uint32_t nBatch, armnn::Encoder< float > &outResult)
Definition LstmUtils.cpp:87
void Sub1Vector(armnn::Decoder< float > &vector, uint32_t vSize, armnn::Encoder< float > &result)
#define ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
RefBaseWorkload(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info)
RefQLstmWorkload(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info)
void Execute() const override
float GetQuantizationScale() const
Definition Tensor.cpp:461
const TensorShape & GetShape() const
Definition Tensor.hpp:193
int32_t GetQuantizationOffset() const
Definition Tensor.cpp:482
Copyright (c) 2021 ARM Limited and Contributors.
std::unique_ptr< Decoder< T > > MakeDecoder(const TensorInfo &info, const void *data=nullptr)
std::unique_ptr< Encoder< T > > MakeEncoder(const TensorInfo &info, void *data=nullptr)
DataType
Definition Types.hpp:49
armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches, unsigned int numberOfChannels, unsigned int height, unsigned int width, const armnn::DataLayout dataLayout, const armnn::DataType dataType)
float m_InputIntermediateScale
Input intermediate quantization scale.
bool m_PeepholeEnabled
Enable/disable peephole.
bool m_LayerNormEnabled
Enable/disable layer normalization.
bool m_ProjectionEnabled
Enable/disable the projection layer.
bool m_CifgEnabled
Enable/disable CIFG (coupled input & forget gate).
Contains information about TensorInfos of a layer.