ArmNN
 26.01
Loading...
Searching...
No Matches
LstmLayer.cpp
Go to the documentation of this file.
1//
2// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#include "LstmLayer.hpp"
6
7#include "LayerCloneBase.hpp"
8
10#include <armnn/TypesUtils.hpp>
13
14namespace armnn
15{
16
17LstmLayer::LstmLayer(const LstmDescriptor& param, const char* name)
18 : LayerWithParameters(3, 4, LayerType::Lstm, param, name)
19{
20}
21
22std::unique_ptr<IWorkload> LstmLayer::CreateWorkload(const IWorkloadFactory& factory) const
23{
24 LstmQueueDescriptor descriptor;
25
26 // Basic parameters
34 descriptor.m_CellBias = m_BasicParameters.m_CellBias.get();
36
37 // Cifg parameters
39 {
43 }
44
45 // Projection parameters
47 {
50 }
51
52 // Peephole parameters
54 {
56 {
58 }
61 }
62
63 // Layer normalisation parameters
65 {
67 {
69 }
73 }
74
75 SetAdditionalInfo(descriptor);
76
77 return factory.CreateWorkload(LayerType::Lstm, descriptor, PrepInfoAndDesc(descriptor));
78}
79
81{
82 auto layer = CloneBase<LstmLayer>(graph, m_Param, GetName());
83
84 layer->m_BasicParameters.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights ?
86 : nullptr;
87 layer->m_BasicParameters.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights ?
89 layer->m_BasicParameters.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights ?
91 layer->m_BasicParameters.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights ?
93 layer->m_BasicParameters.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights ?
95 layer->m_BasicParameters.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights ?
97 layer->m_BasicParameters.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias ?
99 layer->m_BasicParameters.m_CellBias = m_BasicParameters.m_CellBias ?
101 layer->m_BasicParameters.m_OutputGateBias = m_BasicParameters.m_OutputGateBias ?
103
105 {
106 layer->m_CifgParameters.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights ?
108 layer->m_CifgParameters.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights ?
110 layer->m_CifgParameters.m_InputGateBias = m_CifgParameters.m_InputGateBias ?
112 }
113
115 {
116 layer->m_ProjectionParameters.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights ?
118 layer->m_ProjectionParameters.m_ProjectionBias = m_ProjectionParameters.m_ProjectionBias ?
120 }
121
123 {
125 {
126 layer->m_PeepholeParameters.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights ?
128 }
129 layer->m_PeepholeParameters.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights ?
131 layer->m_PeepholeParameters.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights ?
133 }
134
136 {
137 layer->m_LayerNormParameters.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights ?
139 layer->m_LayerNormParameters.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights ?
141 layer->m_LayerNormParameters.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights ?
143 layer->m_LayerNormParameters.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights ?
145 }
146
147 return std::move(layer);
148}
149
150std::vector<TensorShape> LstmLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
151{
152 if (inputShapes.size() != 3)
153 {
154 throw armnn::Exception("inputShapes' size is \"" + std::to_string(inputShapes.size()) +
155 "\" - should be \"3\".");
156 }
157
158 // Get input values for validation
159 unsigned int batchSize = inputShapes[0][0];
160 unsigned int outputSize = inputShapes[1][1];
161 unsigned int numUnits = inputShapes[2][1];
162
163 std::vector<TensorShape> outShapes;
164 outShapes.push_back(TensorShape({batchSize, numUnits * (m_Param.m_CifgEnabled ? 3 : 4)}));
165 outShapes.push_back(TensorShape({batchSize, outputSize}));
166 outShapes.push_back(TensorShape({batchSize, numUnits}));
167 outShapes.push_back(TensorShape({batchSize, outputSize}));
168
169 return outShapes;
170}
171
173{
175
176 const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
177
179
180 auto inferredShapes = InferOutputShapes( {
184 });
185
186 if (inferredShapes.size() != 4)
187 {
188 throw armnn::Exception("inferredShapes has "
189 + std::to_string(inferredShapes.size()) +
190 " element(s) - should only have 4.");
191 }
192
193 // Check if the weights are nullptr
195 {
196 throw armnn::NullPointerException("LstmLayer: "
197 "m_BasicParameters.m_InputToForgetWeights should not be null.");
198 }
199
201 {
202 throw armnn::NullPointerException("LstmLayer: "
203 "m_BasicParameters.m_InputToCellWeights should not be null.");
204 }
205
207 {
208 throw armnn::NullPointerException("LstmLayer: "
209 "m_BasicParameters.m_InputToOutputWeights should not be null.");
210 }
211
213 {
214 throw armnn::NullPointerException("LstmLayer: "
215 "m_BasicParameters.m_RecurrentToForgetWeights should not be null.");
216 }
217
219 {
220 throw armnn::NullPointerException("LstmLayer: "
221 "m_BasicParameters.m_RecurrentToCellWeights should not be null.");
222 }
223
225 {
226 throw armnn::NullPointerException("LstmLayer: "
227 "m_BasicParameters.m_RecurrentToOutputWeights should not be null.");
228 }
229
231 {
232 throw armnn::NullPointerException("LstmLayer: "
233 "m_BasicParameters.m_ForgetGateBias should not be null.");
234 }
235
237 {
238 throw armnn::NullPointerException("LstmLayer: "
239 "m_BasicParameters.m_CellBias should not be null.");
240 }
241
243 {
244 throw armnn::NullPointerException("LstmLayer: "
245 "m_BasicParameters.m_OutputGateBias should not be null.");
246 }
247
249 {
251 {
252 throw armnn::NullPointerException("LstmLayer: "
253 "m_CifgParameters.m_InputToInputWeights should not be null.");
254 }
255
257 {
258 throw armnn::NullPointerException("LstmLayer: "
259 "m_CifgParameters.m_RecurrentToInputWeights should not be null.");
260 }
261
263 {
264 throw armnn::NullPointerException("LstmLayer: "
265 "m_CifgParameters.m_InputGateBias should not be null.");
266 }
267
268 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "LstmLayer");
269 }
270 else
271 {
273 {
274 throw armnn::Exception("LstmLayer: "
275 "m_CifgParameters.m_InputToInputWeights should not have a value "
276 "when CIFG is enabled.");
277 }
278
280 {
281 throw armnn::Exception("LstmLayer: "
282 "m_CifgParameters.m_RecurrentToInputWeights should not have a value "
283 "when CIFG is enabled.");
284 }
285
287 {
288 throw armnn::Exception("LstmLayer: "
289 "m_CifgParameters.m_InputGateBias should not have a value "
290 "when CIFG is enabled.");
291 }
292
293 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "LstmLayer");
294 }
295
297 {
299 {
300 throw armnn::NullPointerException("LstmLayer: "
301 "m_ProjectionParameters.m_ProjectionWeights should not be null.");
302 }
303 }
304
306 {
308 {
310 {
311 throw armnn::NullPointerException("LstmLayer: "
312 "m_PeepholeParameters.m_CellToInputWeights should not be null "
313 "when Peephole is enabled and CIFG is disabled.");
314 }
315 }
316
318 {
319 throw armnn::NullPointerException("LstmLayer: "
320 "m_PeepholeParameters.m_CellToForgetWeights should not be null.");
321 }
322
324 {
325 throw armnn::NullPointerException("LstmLayer: "
326 "m_PeepholeParameters.m_CellToOutputWeights should not be null.");
327 }
328 }
329
331 GetOutputSlot(1).GetTensorInfo().GetShape(), inferredShapes[1], m_ShapeInferenceMethod, "LstmLayer", 1);
333 GetOutputSlot(2).GetTensorInfo().GetShape(), inferredShapes[2], m_ShapeInferenceMethod, "LstmLayer", 2);
335 GetOutputSlot(3).GetTensorInfo().GetShape(), inferredShapes[3], m_ShapeInferenceMethod, "LstmLayer", 3);
336
338 {
340 {
342 {
343 throw armnn::NullPointerException("LstmLayer: "
344 "m_LayerNormParameters.m_inputLayerNormWeights should not be null.");
345 }
346 }
347
349 {
350 throw armnn::NullPointerException("LstmLayer: "
351 "m_LayerNormParameters.m_forgetLayerNormWeights should not be null.");
352 }
353
355 {
356 throw armnn::NullPointerException("LstmLayer: "
357 "m_LayerNormParameters.m_cellLayerNormWeights should not be null.");
358 }
359
361 {
362 throw armnn::NullPointerException("LstmLayer: "
363 "m_LayerNormParameters.m_outputLayerNormWeights should not be null.");
364 }
365 }
366}
367
401
403{
404 std::vector<ConstTensor> constTensors;
405
406 LstmDescriptor descriptor = GetParameters();
407
417
418 // Cifg parameters
422
423 // Projection parameters
426
427 // Peephole parameters
431
432 // Layer normalisation parameters
437
438 // First add mandatory/basic parameters
440 {
441 constTensors.emplace_back(ConstTensor(managedInputToForgetWeights.GetTensorInfo(),
442 managedInputToForgetWeights.Map()));
443 }
445 {
446 constTensors.emplace_back(ConstTensor(managedInputToCellWeights.GetTensorInfo(),
447 managedInputToCellWeights.Map()));
448 }
450 {
451 constTensors.emplace_back(ConstTensor(managedInputToOutputWeights.GetTensorInfo(),
452 managedInputToOutputWeights.Map()));
453 }
455 {
456 constTensors.emplace_back(ConstTensor(
457 managedRecurrentToForgetWeights.GetTensorInfo(),
458 managedRecurrentToForgetWeights.Map()));
459 }
461 {
462 constTensors.emplace_back(ConstTensor(
463 managedRecurrentToCellWeights.GetTensorInfo(),
464 managedRecurrentToCellWeights.Map()));
465 }
467 {
468 constTensors.emplace_back(ConstTensor(
469 managedRecurrentToOutputWeights.GetTensorInfo(),
470 managedRecurrentToOutputWeights.Map()));
471 }
472 if (m_BasicParameters.m_ForgetGateBias != nullptr)
473 {
474 constTensors.emplace_back(ConstTensor(managedForgetGateBias.GetTensorInfo(),
475 managedForgetGateBias.Map()));
476 }
477 if (m_BasicParameters.m_CellBias != nullptr)
478 {
479 constTensors.emplace_back(ConstTensor(managedCellBias.GetTensorInfo(),
480 managedCellBias.Map()));
481 }
482 if (m_BasicParameters.m_OutputGateBias != nullptr)
483 {
484 constTensors.emplace_back(ConstTensor(managedOutputGateBias.GetTensorInfo(),
485 managedOutputGateBias.Map()));
486 }
487
488 // Add cifg parameters
489 if (!descriptor.m_CifgEnabled)
490 {
492 {
493 constTensors.emplace_back(ConstTensor(managedInputToInputWeights.GetTensorInfo(),
494 managedInputToInputWeights.Map()));
495 }
497 {
498 constTensors.emplace_back(ConstTensor(
499 managedRecurrentToInputWeights.GetTensorInfo(),
500 managedRecurrentToInputWeights.Map()));
501 }
502 if (m_CifgParameters.m_InputGateBias != nullptr)
503 {
504 constTensors.emplace_back(ConstTensor(managedInputGateBias.GetTensorInfo(),
505 managedInputGateBias.Map()));
506 }
507 }
508
509 // Add peephole parameters
510 if (descriptor.m_PeepholeEnabled)
511 {
512 if (!descriptor.m_CifgEnabled)
513 {
515 {
516 constTensors.emplace_back(ConstTensor(managedCellToInputWeights.GetTensorInfo(),
517 managedCellToInputWeights.Map()));
518 }
519 }
521 {
522 constTensors.emplace_back(ConstTensor(managedCellToForgetWeights.GetTensorInfo(),
523 managedCellToForgetWeights.Map()));
524 }
526 {
527 constTensors.emplace_back(ConstTensor(managedCellToOutputWeights.GetTensorInfo(),
528 managedCellToOutputWeights.Map()));
529 }
530 }
531
532 // Add projection parameters
533 if (descriptor.m_ProjectionEnabled)
534 {
536 {
537 constTensors.emplace_back(ConstTensor(managedProjectionWeights.GetTensorInfo(),
538 managedProjectionWeights.Map()));
539 }
541 {
542 constTensors.emplace_back(ConstTensor(managedProjectionBias.GetTensorInfo(),
543 managedProjectionBias.Map()));
544 }
545 }
546
547 // Add norm parameters
548 if (descriptor.m_LayerNormEnabled)
549 {
550 if (!descriptor.m_CifgEnabled)
551 {
553 {
554 constTensors.emplace_back(ConstTensor(managedInputLayerNormWeights.GetTensorInfo(),
555 managedInputLayerNormWeights.Map()));
556 }
557 }
559 {
560 constTensors.emplace_back(ConstTensor(managedForgetLayerNormWeights.GetTensorInfo(),
561 managedForgetLayerNormWeights.Map()));
562 }
564 {
565 constTensors.emplace_back(ConstTensor(managedCellLayerNormWeights.GetTensorInfo(),
566 managedCellLayerNormWeights.Map()));
567 }
569 {
570 constTensors.emplace_back(ConstTensor(managedOutputLayerNormWeights.GetTensorInfo(),
571 managedOutputLayerNormWeights.Map()));
572 }
573 }
574
575 strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName());
576}
577
578} // namespace armnn
#define CHECK_LOCATION()
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition Tensor.hpp:330
Base class for all ArmNN exceptions so that users can filter to just those.
std::vector< std::reference_wrapper< const std::shared_ptr< ConstTensorHandle > > > ImmutableConstantTensors
Definition INetwork.hpp:141
virtual void ExecuteStrategy(const IConnectableLayer *layer, const armnn::BaseDescriptor &descriptor, const std::vector< armnn::ConstTensor > &constants, const char *name, const armnn::LayerBindingId id=0)=0
virtual std::unique_ptr< IWorkload > CreateWorkload(LayerType type, const QueueDescriptor &descriptor, const WorkloadInfo &info) const =0
Backends should implement their own CreateWorkload function with a switch statement.
const TensorInfo & GetTensorInfo() const override
Gets the TensorInfo for this InputSlot.
Definition Layer.cpp:614
void VerifyLayerConnections(unsigned int expectedConnections, const CheckLocation &location) const
Definition Layer.cpp:410
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition Layer.hpp:337
void VerifyShapeInferenceType(const TensorShape &outputShape, ShapeInferenceMethod shapeInferenceMethod)
Definition Layer.cpp:526
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
Definition Layer.hpp:339
const char * GetName() const override
Returns the name of the layer.
Definition Layer.hpp:332
void ValidateAndCopyShape(const TensorShape &outputShape, const TensorShape &inferredShape, const ShapeInferenceMethod shapeInferenceMethod, const std::string &layerName, const unsigned int outputSlotIndex=0)
Definition Layer.cpp:457
void SetAdditionalInfo(QueueDescriptor &descriptor) const
Definition Layer.cpp:303
ShapeInferenceMethod m_ShapeInferenceMethod
Definition Layer.hpp:441
WorkloadInfo PrepInfoAndDesc(QueueDescriptor &descriptor) const
Helper function to reduce duplication in *LayerCreateWorkload.
const LstmDescriptor & GetParameters() const override
LstmDescriptor m_Param
The parameters for the layer (not including tensor-valued weights etc.).
This layer represents an LSTM operation.
Definition LstmLayer.hpp:17
LstmOptCifgParameters m_CifgParameters
Definition LstmLayer.hpp:21
Layer::ImmutableConstantTensors GetConstantTensorsByRef() const override
Retrieve the handles to the constant values stored by the layer.
LstmOptProjectionParameters m_ProjectionParameters
Definition LstmLayer.hpp:22
LstmOptLayerNormParameters m_LayerNormParameters
Definition LstmLayer.hpp:24
void ExecuteStrategy(IStrategy &strategy) const override
Apply a visitor to this layer.
LstmOptPeepholeParameters m_PeepholeParameters
Definition LstmLayer.hpp:23
std::vector< TensorShape > InferOutputShapes(const std::vector< TensorShape > &inputShapes) const override
By default returns inputShapes if the number of inputs are equal to number of outputs,...
LstmBasicParameters m_BasicParameters
Definition LstmLayer.hpp:20
void ValidateTensorShapesFromInputs() override
Check if the input tensor shape(s) will lead to a valid configuration of LstmLayer.
LstmLayer * Clone(Graph &graph) const override
Creates a dynamically-allocated copy of this layer.
Definition LstmLayer.cpp:80
LstmLayer(const LstmDescriptor &param, const char *name)
Constructor to create a LstmLayer.
Definition LstmLayer.cpp:17
virtual std::unique_ptr< IWorkload > CreateWorkload(const IWorkloadFactory &factory) const override
Makes a workload for the LSTM type.
Definition LstmLayer.cpp:22
const void * Map(bool blocking=true)
RAII Managed resource Unmaps MemoryArea once out of scope.
const TensorInfo & GetTensorInfo() const
const TensorInfo & GetTensorInfo() const override
Definition Layer.cpp:100
const TensorShape & GetShape() const
Definition Tensor.hpp:193
Copyright (c) 2021 ARM Limited and Contributors.
LayerType
When adding a new layer, adapt also the LastLayer enum value in the enum class LayerType below.
Definition Types.hpp:494
armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches, unsigned int numberOfChannels, unsigned int height, unsigned int width, const armnn::DataLayout dataLayout, const armnn::DataType dataType)
std::shared_ptr< ConstTensorHandle > m_RecurrentToForgetWeights
A shared pointer to a 2D weights tensor with dimensions [output_size, num_units].
std::shared_ptr< ConstTensorHandle > m_CellBias
A shared pointer to a 1D weights tensor with dimensions [num_units].
std::shared_ptr< ConstTensorHandle > m_InputToOutputWeights
A shared pointer to a 2D weights tensor with dimensions [input_size, num_units].
std::shared_ptr< ConstTensorHandle > m_RecurrentToCellWeights
A shared pointer to a 2D weights tensor with dimensions [output_size, num_units].
std::shared_ptr< ConstTensorHandle > m_OutputGateBias
A shared pointer to a 1D weights tensor with dimensions [num_units].
std::shared_ptr< ConstTensorHandle > m_InputToForgetWeights
A shared pointer to a 2D weights tensor with dimensions [input_size, num_units].
std::shared_ptr< ConstTensorHandle > m_InputToCellWeights
A shared pointer to a 2D weights tensor with dimensions [input_size, num_units].
std::shared_ptr< ConstTensorHandle > m_RecurrentToOutputWeights
A shared pointer to a 2D weights tensor with dimensions [output_size, num_units].
std::shared_ptr< ConstTensorHandle > m_ForgetGateBias
A shared pointer to a 1D weights tensor with dimensions [num_units].
An LstmDescriptor for the LstmLayer.
bool m_PeepholeEnabled
Enable/disable peephole.
bool m_LayerNormEnabled
Enable/disable layer normalization.
bool m_ProjectionEnabled
Enable/disable the projection layer.
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
std::shared_ptr< ConstTensorHandle > m_InputToInputWeights
A shared pointer to a 2D weights tensor with dimensions [input_size, num_units].
std::shared_ptr< ConstTensorHandle > m_InputGateBias
A shared pointer to a 1D weights tensor with dimensions [num_units].
std::shared_ptr< ConstTensorHandle > m_RecurrentToInputWeights
A shared pointer to a 2D weights tensor with dimensions [output_size, num_units].
std::shared_ptr< ConstTensorHandle > m_CellLayerNormWeights
A shared pointer to a 1D weights tensor with dimensions [num_units].
std::shared_ptr< ConstTensorHandle > m_InputLayerNormWeights
A shared pointer to a 1D weights tensor with dimensions [num_units].
std::shared_ptr< ConstTensorHandle > m_OutputLayerNormWeights
A shared pointer to a 1D weights tensor with dimensions [num_units].
std::shared_ptr< ConstTensorHandle > m_ForgetLayerNormWeights
A shared pointer to a 1D weights tensor with dimensions [num_units].
std::shared_ptr< ConstTensorHandle > m_CellToForgetWeights
A shared pointer to a 1D weights tensor with dimensions [num_units].
std::shared_ptr< ConstTensorHandle > m_CellToInputWeights
A shared pointer to a 1D weights tensor with dimensions [num_units].
std::shared_ptr< ConstTensorHandle > m_CellToOutputWeights
A shared pointer to a 1D weights tensor with dimensions [num_units].
std::shared_ptr< ConstTensorHandle > m_ProjectionBias
A shared pointer to a 1D weights tensor with dimensions [output_size].
std::shared_ptr< ConstTensorHandle > m_ProjectionWeights
A shared pointer to a 2D weights tensor with dimensions [output_size, num_units].
const ConstTensorHandle * m_OutputLayerNormWeights
const ConstTensorHandle * m_InputToOutputWeights
const ConstTensorHandle * m_InputLayerNormWeights
const ConstTensorHandle * m_CellToForgetWeights
const ConstTensorHandle * m_RecurrentToInputWeights
const ConstTensorHandle * m_ForgetGateBias
const ConstTensorHandle * m_ProjectionWeights
const ConstTensorHandle * m_InputGateBias
const ConstTensorHandle * m_RecurrentToOutputWeights
const ConstTensorHandle * m_OutputGateBias
const ConstTensorHandle * m_CellBias
const ConstTensorHandle * m_InputToCellWeights
const ConstTensorHandle * m_CellToInputWeights
const ConstTensorHandle * m_CellToOutputWeights
const ConstTensorHandle * m_InputToForgetWeights
const ConstTensorHandle * m_InputToInputWeights
const ConstTensorHandle * m_RecurrentToCellWeights
const ConstTensorHandle * m_ProjectionBias
const ConstTensorHandle * m_ForgetLayerNormWeights
const ConstTensorHandle * m_RecurrentToForgetWeights
const ConstTensorHandle * m_CellLayerNormWeights