ArmNN
 26.01
Loading...
Searching...
No Matches
QLstmLayer.cpp
Go to the documentation of this file.
1//
2// Copyright © 2020-2024 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#include "QLstmLayer.hpp"
6
7#include "LayerCloneBase.hpp"
8
10#include <armnn/TypesUtils.hpp>
13
14namespace armnn
15{
16
// Constructs a QLstm layer from its descriptor and name, forwarding both to
// LayerWithParameters together with LayerType::QLstm.
// NOTE(review): the two literal 3s are presumably the input-slot and output-slot counts
// (this file reads input slots 0-2 and output slots 0-2) - confirm against the
// LayerWithParameters constructor.
17QLstmLayer::QLstmLayer(const QLstmDescriptor& param, const char* name)
18 : LayerWithParameters(3, 3, LayerType::QLstm, param, name)
19{
20}
21
22std::unique_ptr<IWorkload> QLstmLayer::CreateWorkload(const IWorkloadFactory& factory) const
23{
24 QLstmQueueDescriptor descriptor;
25
26 // Basic parameters
34 descriptor.m_CellBias = m_BasicParameters.m_CellBias.get();
36
37 // CIFG parameters
39 {
43 }
44
45 // Projection parameters
47 {
50 }
51
52 // Peephole parameters
54 {
56 {
58 }
59
62 }
63
64 // Layer normalisation parameters
66 {
68 {
70 }
74 }
75
76 SetAdditionalInfo(descriptor);
77
78 return factory.CreateWorkload(LayerType::QLstm, descriptor, PrepInfoAndDesc(descriptor));
79}
80
82{
83 auto layer = CloneBase<QLstmLayer>(graph, m_Param, GetName());
84
85 layer->m_BasicParameters.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights ?
87 layer->m_BasicParameters.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights ?
89 layer->m_BasicParameters.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights ?
91 layer->m_BasicParameters.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights ?
93 layer->m_BasicParameters.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights ?
95 layer->m_BasicParameters.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights ?
97 layer->m_BasicParameters.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias ?
99 layer->m_BasicParameters.m_CellBias = m_BasicParameters.m_CellBias ?
101 layer->m_BasicParameters.m_OutputGateBias = m_BasicParameters.m_OutputGateBias ?
103
105 {
106 layer->m_CifgParameters.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights ?
108 layer->m_CifgParameters.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights ?
110 layer->m_CifgParameters.m_InputGateBias = m_CifgParameters.m_InputGateBias ?
112 }
113
115 {
116 layer->m_ProjectionParameters.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights ?
118 layer->m_ProjectionParameters.m_ProjectionBias = m_ProjectionParameters.m_ProjectionBias ?
120 }
121
123 {
124 if (!m_Param.m_CifgEnabled) {
125 layer->m_PeepholeParameters.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights ?
127 }
128
129 layer->m_PeepholeParameters.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights ?
131 layer->m_PeepholeParameters.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights ?
133 }
134
136 {
137 if (!m_Param.m_CifgEnabled) {
138 layer->m_LayerNormParameters.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights ?
140 }
141
142 layer->m_LayerNormParameters.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights ?
144 layer->m_LayerNormParameters.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights ?
146 layer->m_LayerNormParameters.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights ?
148 }
149
150 return std::move(layer);
151}
152
// Derives the three output shapes of the layer from its three input shapes.
//
// inputShapes must hold exactly three shapes; otherwise an armnn::Exception is thrown.
// The caller in this file passes them in the order: input, previousOutputIn, previousCellStateIn.
//
// Returns three 2D shapes, in this order:
//   [0] outputStateOut : { batchSize, outputSize }
//   [1] cellStateOut   : { batchSize, numUnits }
//   [2] output         : { batchSize, outputSize }
//
// NOTE(review): the rank of each input shape is not checked here before indexing [0] / [1];
// this presumably relies on TensorShape::operator[] rejecting an out-of-range index - confirm.
153std::vector<TensorShape> QLstmLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
154{
155 if (inputShapes.size() != 3)
156 {
157 throw armnn::Exception("inputShapes' size is \"" + std::to_string(inputShapes.size()) +
158 "\" - should be \"3\".");
159 }
160
161 // Read the sizes the output shapes are built from: dim 0 of input 0, dim 1 of input 1, dim 1 of input 2.
162 unsigned int batchSize = inputShapes[0][0];
163 unsigned int outputSize = inputShapes[1][1];
164 unsigned int numUnits = inputShapes[2][1];
165
166 std::vector<TensorShape> outShapes;
167 outShapes.push_back(TensorShape({ batchSize, outputSize })); // outputStateOut
168 outShapes.push_back(TensorShape({ batchSize, numUnits })); // cellStateOut
169 outShapes.push_back(TensorShape({ batchSize, outputSize })); // output
170
171 return outShapes;
172}
173
175{
177
178 const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
179
181
182 auto inferredShapes = InferOutputShapes(
183 {
184 GetInputSlot(0).GetTensorInfo().GetShape(), // input
185 GetInputSlot(1).GetTensorInfo().GetShape(), // previousOutputIn
186 GetInputSlot(2).GetTensorInfo().GetShape() // previousCellStateIn
187 });
188
189 if (inferredShapes.size() != 3)
190 {
191 throw armnn::LayerValidationException("inferredShapes has "
192 + std::to_string(inferredShapes.size()) +
193 " element(s) - should only have 3.");
194 }
195
196 // Check if the weights are nullptr for basic params
198 {
199 throw armnn::LayerValidationException("QLstmLayer: "
200 "m_BasicParameters.m_InputToForgetWeights should not be null.");
201 }
202
204 {
205 throw armnn::LayerValidationException("QLstmLayer: "
206 "m_BasicParameters.m_InputToCellWeights should not be null.");
207 }
208
210 {
211 throw armnn::LayerValidationException("QLstmLayer: "
212 "m_BasicParameters.m_InputToOutputWeights should not be null.");
213 }
214
216 {
217 throw armnn::LayerValidationException("QLstmLayer: "
218 "m_BasicParameters.m_RecurrentToForgetWeights should not be null.");
219 }
220
222 {
223 throw armnn::LayerValidationException("QLstmLayer: "
224 "m_BasicParameters.m_RecurrentToCellWeights should not be null.");
225 }
226
228 {
229 throw armnn::LayerValidationException("QLstmLayer: "
230 "m_BasicParameters.m_RecurrentToOutputWeights should not be null.");
231 }
232
234 {
235 throw armnn::LayerValidationException("QLstmLayer: "
236 "m_BasicParameters.m_ForgetGateBias should not be null.");
237 }
238
240 {
241 throw armnn::LayerValidationException("QLstmLayer: "
242 "m_BasicParameters.m_CellBias should not be null.");
243 }
244
246 {
247 throw armnn::LayerValidationException("QLstmLayer: "
248 "m_BasicParameters.m_OutputGateBias should not be null.");
249 }
250
252 {
254 {
255 throw armnn::LayerValidationException("QLstmLayer: "
256 "m_CifgParameters.m_InputToInputWeights should not be null.");
257 }
258
260 {
261 throw armnn::LayerValidationException("QLstmLayer: "
262 "m_CifgParameters.m_RecurrentToInputWeights should not be null.");
263 }
264
266 {
267 throw armnn::LayerValidationException("QLstmLayer: "
268 "m_CifgParameters.m_InputGateBias should not be null.");
269 }
270
271 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "QLstmLayer");
272 }
273 else
274 {
276 {
277 throw armnn::LayerValidationException("QLstmLayer: "
278 "m_CifgParameters.m_InputToInputWeights "
279 "should not have a value when CIFG is enabled.");
280 }
281
283 {
284 throw armnn::LayerValidationException("QLstmLayer: "
285 "m_CifgParameters.m_RecurrentToInputWeights "
286 "should not have a value when CIFG is enabled.");
287 }
288
290 {
291 throw armnn::LayerValidationException("QLstmLayer: "
292 "m_CifgParameters.m_InputGateBias "
293 "should not have a value when CIFG is enabled.");
294 }
295
296 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "QLstmLayer");
297 }
298
300 {
302 {
303 throw armnn::LayerValidationException("QLstmLayer: "
304 "m_ProjectionParameters.m_ProjectionWeights should not be null.");
305 }
306 }
307
309 {
310 if (!m_Param.m_CifgEnabled) {
312 {
313 throw armnn::LayerValidationException("QLstmLayer: "
314 "m_PeepholeParameters.m_CellToInputWeights should not be null "
315 "when Peephole is enabled and CIFG is disabled.");
316 }
317 }
318
320 {
321 throw armnn::LayerValidationException("QLstmLayer: "
322 "m_PeepholeParameters.m_CellToForgetWeights should not be null.");
323 }
324
326 {
327 throw armnn::LayerValidationException("QLstmLayer: "
328 "m_PeepholeParameters.m_CellToOutputWeights should not be null.");
329 }
330 }
331
333 GetOutputSlot(1).GetTensorInfo().GetShape(), inferredShapes[1], m_ShapeInferenceMethod, "QLstmLayer", 1);
335 GetOutputSlot(2).GetTensorInfo().GetShape(), inferredShapes[2], m_ShapeInferenceMethod, "QLstmLayer", 2);
336
338 {
340 {
342 {
343 throw armnn::LayerValidationException("QLstmLayer: m_LayerNormParameters.m_InputLayerNormWeights "
344 "should not be null.");
345 }
346 }
347
349 {
350 throw armnn::LayerValidationException("QLstmLayer: "
351 "m_LayerNormParameters.m_ForgetLayerNormWeights should not be null.");
352 }
353
355 {
356 throw armnn::LayerValidationException("QLstmLayer: "
357 "m_LayerNormParameters.m_CellLayerNormWeights should not be null.");
358 }
359
361 {
362 throw armnn::LayerValidationException("QLstmLayer: "
363 "m_LayerNormParameters.m_OutputLayerNormWeights should not be null.");
364 }
365 }
366}
367
401
402
404{
405 std::vector<ConstTensor> constTensors;
415
416 // Cifg parameters
420
421 // Projection parameters
424
425 // Peephole parameters
429
430 // Layer normalisation parameters
435
436 // First add mandatory/basic parameters
438 {
439 constTensors.emplace_back(ConstTensor(managedInputToForgetWeights.GetTensorInfo(),
440 managedInputToForgetWeights.Map()));
441 }
443 {
444 constTensors.emplace_back(ConstTensor(managedInputToCellWeights.GetTensorInfo(),
445 managedInputToCellWeights.Map()));
446 }
448 {
449 constTensors.emplace_back(ConstTensor(managedInputToOutputWeights.GetTensorInfo(),
450 managedInputToOutputWeights.Map()));
451 }
453 {
454 constTensors.emplace_back(ConstTensor(
455 managedRecurrentToForgetWeights.GetTensorInfo(),
456 managedRecurrentToForgetWeights.Map()));
457 }
459 {
460 constTensors.emplace_back(ConstTensor(
461 managedRecurrentToCellWeights.GetTensorInfo(),
462 managedRecurrentToCellWeights.Map()));
463 }
465 {
466 constTensors.emplace_back(ConstTensor(
467 managedRecurrentToOutputWeights.GetTensorInfo(),
468 managedRecurrentToOutputWeights.Map()));
469 }
470 if (m_BasicParameters.m_ForgetGateBias != nullptr)
471 {
472 constTensors.emplace_back(ConstTensor(managedForgetGateBias.GetTensorInfo(),
473 managedForgetGateBias.Map()));
474 }
475 if (m_BasicParameters.m_CellBias != nullptr)
476 {
477 constTensors.emplace_back(ConstTensor(managedCellBias.GetTensorInfo(),
478 managedCellBias.Map()));
479 }
480 if (m_BasicParameters.m_OutputGateBias != nullptr)
481 {
482 constTensors.emplace_back(ConstTensor(managedOutputGateBias.GetTensorInfo(),
483 managedOutputGateBias.Map()));
484 }
485
486 // Add CIFG parameters
488 {
489 constTensors.emplace_back(ConstTensor(managedInputToInputWeights.GetTensorInfo(),
490 managedInputToInputWeights.Map()));
491 }
493 {
494 constTensors.emplace_back(ConstTensor(
495 managedRecurrentToInputWeights.GetTensorInfo(),
496 managedRecurrentToInputWeights.Map()));
497 }
498 if (m_CifgParameters.m_InputGateBias != nullptr)
499 {
500 constTensors.emplace_back(ConstTensor(managedInputGateBias.GetTensorInfo(),
501 managedInputGateBias.Map()));
502 }
503
504 // Add peephole parameters
506 {
507 constTensors.emplace_back(ConstTensor(managedCellToInputWeights.GetTensorInfo(),
508 managedCellToInputWeights.Map()));
509 }
511 {
512 constTensors.emplace_back(ConstTensor(managedCellToForgetWeights.GetTensorInfo(),
513 managedCellToForgetWeights.Map()));
514 }
516 {
517 constTensors.emplace_back(ConstTensor(managedCellToOutputWeights.GetTensorInfo(),
518 managedCellToOutputWeights.Map()));
519 }
520
521 // Add projection parameters
523 {
524 constTensors.emplace_back(ConstTensor(managedProjectionWeights.GetTensorInfo(),
525 managedProjectionWeights.Map()));
526 }
528 {
529 constTensors.emplace_back(ConstTensor(managedProjectionBias.GetTensorInfo(),
530 managedProjectionBias.Map()));
531 }
532
533 // Add norm parameters
535 {
536 constTensors.emplace_back(ConstTensor(managedInputLayerNormWeights.GetTensorInfo(),
537 managedInputLayerNormWeights.Map()));
538 }
540 {
541 constTensors.emplace_back(ConstTensor(managedForgetLayerNormWeights.GetTensorInfo(),
542 managedForgetLayerNormWeights.Map()));
543 }
545 {
546 constTensors.emplace_back(ConstTensor(managedCellLayerNormWeights.GetTensorInfo(),
547 managedCellLayerNormWeights.Map()));
548 }
550 {
551 constTensors.emplace_back(ConstTensor(managedOutputLayerNormWeights.GetTensorInfo(),
552 managedOutputLayerNormWeights.Map()));
553 }
554 strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName());
555}
556
557} // namespace armnn
#define CHECK_LOCATION()
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition Tensor.hpp:330
Base class for all ArmNN exceptions so that users can filter to just those.
std::vector< std::reference_wrapper< const std::shared_ptr< ConstTensorHandle > > > ImmutableConstantTensors
Definition INetwork.hpp:141
virtual void ExecuteStrategy(const IConnectableLayer *layer, const armnn::BaseDescriptor &descriptor, const std::vector< armnn::ConstTensor > &constants, const char *name, const armnn::LayerBindingId id=0)=0
virtual std::unique_ptr< IWorkload > CreateWorkload(LayerType type, const QueueDescriptor &descriptor, const WorkloadInfo &info) const =0
Backends should implement their own CreateWorkload function with a switch statement.
const TensorInfo & GetTensorInfo() const override
Gets the TensorInfo for this InputSlot.
Definition Layer.cpp:614
void VerifyLayerConnections(unsigned int expectedConnections, const CheckLocation &location) const
Definition Layer.cpp:410
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition Layer.hpp:337
void VerifyShapeInferenceType(const TensorShape &outputShape, ShapeInferenceMethod shapeInferenceMethod)
Definition Layer.cpp:526
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
Definition Layer.hpp:339
const char * GetName() const override
Returns the name of the layer.
Definition Layer.hpp:332
void ValidateAndCopyShape(const TensorShape &outputShape, const TensorShape &inferredShape, const ShapeInferenceMethod shapeInferenceMethod, const std::string &layerName, const unsigned int outputSlotIndex=0)
Definition Layer.cpp:457
void SetAdditionalInfo(QueueDescriptor &descriptor) const
Definition Layer.cpp:303
ShapeInferenceMethod m_ShapeInferenceMethod
Definition Layer.hpp:441
WorkloadInfo PrepInfoAndDesc(QueueDescriptor &descriptor) const
Helper function to reduce duplication in *LayerCreateWorkload.
const QLstmDescriptor & GetParameters() const override
QLstmDescriptor m_Param
The parameters for the layer (not including tensor-valued weights etc.).
const void * Map(bool blocking=true)
RAII Managed resource Unmaps MemoryArea once out of scope.
const TensorInfo & GetTensorInfo() const
const TensorInfo & GetTensorInfo() const override
Definition Layer.cpp:100
This layer represents a QLstm operation.
Layer::ImmutableConstantTensors GetConstantTensorsByRef() const override
Retrieve the handles to the constant values stored by the layer.
QLstmLayer(const QLstmDescriptor &param, const char *name)
Constructor to create a QLstmLayer.
QLstmOptProjectionParameters m_ProjectionParameters
void ExecuteStrategy(IStrategy &strategy) const override
Apply a visitor to this layer.
QLstmOptPeepholeParameters m_PeepholeParameters
std::vector< TensorShape > InferOutputShapes(const std::vector< TensorShape > &inputShapes) const override
By default returns inputShapes if the number of inputs are equal to number of outputs,...
void ValidateTensorShapesFromInputs() override
Check if the input tensor shape(s) will lead to a valid configuration of QLstmLayer.
QLstmBasicParameters m_BasicParameters
QLstmOptLayerNormParameters m_LayerNormParameters
QLstmLayer * Clone(Graph &graph) const override
Creates a dynamically-allocated copy of this layer.
virtual std::unique_ptr< IWorkload > CreateWorkload(const IWorkloadFactory &factory) const override
Makes a workload for the QLstm type.
QLstmOptCifgParameters m_CifgParameters
const TensorShape & GetShape() const
Definition Tensor.hpp:193
Copyright (c) 2021 ARM Limited and Contributors.
LayerType
When adding a new layer, adapt also the LastLayer enum value in the enum class LayerType below.
Definition Types.hpp:494
armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches, unsigned int numberOfChannels, unsigned int height, unsigned int width, const armnn::DataLayout dataLayout, const armnn::DataType dataType)
std::shared_ptr< ConstTensorHandle > m_RecurrentToForgetWeights
A shared pointer to a 2D weights tensor with dimensions [num_units, outputSize] (QSymmS8).
std::shared_ptr< ConstTensorHandle > m_CellBias
A shared pointer to a 1D bias tensor with dimensions [num_units] (int32).
std::shared_ptr< ConstTensorHandle > m_InputToOutputWeights
A shared pointer to a 2D weights tensor with dimensions [num_units, inputSize] (QSymmS8).
std::shared_ptr< ConstTensorHandle > m_RecurrentToCellWeights
A shared pointer to a 2D weights tensor with dimensions [num_units, outputSize] (QSymmS8).
std::shared_ptr< ConstTensorHandle > m_OutputGateBias
A shared pointer to a 1D bias tensor with dimensions [num_units] (int32).
std::shared_ptr< ConstTensorHandle > m_InputToForgetWeights
A shared pointer to a 2D weights tensor with dimensions [num_units, inputSize] (QSymmS8).
std::shared_ptr< ConstTensorHandle > m_InputToCellWeights
A shared pointer to a 2D weights tensor with dimensions [num_units, inputSize] (QSymmS8).
std::shared_ptr< ConstTensorHandle > m_RecurrentToOutputWeights
A shared pointer to a 2D weights tensor with dimensions [num_units, outputSize] (QSymmS8).
std::shared_ptr< ConstTensorHandle > m_ForgetGateBias
A shared pointer to a 1D bias tensor with dimensions [num_units] (int32).
A QLstmDescriptor for the QLstmLayer.
bool m_PeepholeEnabled
Enable/disable peephole.
bool m_LayerNormEnabled
Enable/disable layer normalization.
bool m_ProjectionEnabled
Enable/disable the projection layer.
bool m_CifgEnabled
Enable/disable CIFG (coupled input & forget gate).
std::shared_ptr< ConstTensorHandle > m_InputToInputWeights
A shared pointer to a 2D weights tensor with dimensions [input_size, num_units] (QSymmS8).
std::shared_ptr< ConstTensorHandle > m_InputGateBias
A shared pointer to a 1D bias tensor with dimensions [num_units] (int32).
std::shared_ptr< ConstTensorHandle > m_RecurrentToInputWeights
A shared pointer to a 2D weights tensor with dimensions [input_size, num_units] (QSymmS8).
std::shared_ptr< ConstTensorHandle > m_CellLayerNormWeights
A shared pointer to a 1D weights tensor with dimensions [num_units] (QSymmS16).
std::shared_ptr< ConstTensorHandle > m_InputLayerNormWeights
A shared pointer to a 1D weights tensor with dimensions [num_units] (QSymmS16).
std::shared_ptr< ConstTensorHandle > m_OutputLayerNormWeights
A shared pointer to a 1D weights tensor with dimensions [num_units] (QSymmS16).
std::shared_ptr< ConstTensorHandle > m_ForgetLayerNormWeights
A shared pointer to a 1D weights tensor with dimensions [num_units] (QSymmS16).
std::shared_ptr< ConstTensorHandle > m_CellToForgetWeights
A shared pointer to a 1D weights tensor with dimensions [num_units] (QSymmS16).
std::shared_ptr< ConstTensorHandle > m_CellToInputWeights
A shared pointer to a 1D weights tensor with dimensions [num_units] (QSymmS16).
std::shared_ptr< ConstTensorHandle > m_CellToOutputWeights
A shared pointer to a 1D weights tensor with dimensions [num_units] (QSymmS16).
std::shared_ptr< ConstTensorHandle > m_ProjectionBias
A shared pointer to a 1D bias tensor with dimensions [output_size] (int32).
std::shared_ptr< ConstTensorHandle > m_ProjectionWeights
A shared pointer to a 2D weights tensor with dimensions [output_size, num_units] (QSymmS8).
const ConstTensorHandle * m_OutputLayerNormWeights
const ConstTensorHandle * m_InputToOutputWeights
const ConstTensorHandle * m_InputLayerNormWeights
const ConstTensorHandle * m_CellToForgetWeights
const ConstTensorHandle * m_RecurrentToInputWeights
const ConstTensorHandle * m_ForgetGateBias
const ConstTensorHandle * m_ProjectionWeights
const ConstTensorHandle * m_InputGateBias
const ConstTensorHandle * m_RecurrentToOutputWeights
const ConstTensorHandle * m_OutputGateBias
const ConstTensorHandle * m_CellBias
const ConstTensorHandle * m_InputToCellWeights
const ConstTensorHandle * m_CellToInputWeights
const ConstTensorHandle * m_CellToOutputWeights
const ConstTensorHandle * m_InputToForgetWeights
const ConstTensorHandle * m_InputToInputWeights
const ConstTensorHandle * m_RecurrentToCellWeights
const ConstTensorHandle * m_ProjectionBias
const ConstTensorHandle * m_ForgetLayerNormWeights
const ConstTensorHandle * m_RecurrentToForgetWeights
const ConstTensorHandle * m_CellLayerNormWeights