ArmNN
 25.11
Loading...
Searching...
No Matches
UnidirectionalSequenceLstmLayer.cpp
Go to the documentation of this file.
1//
2// Copyright © 2021-2024 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
6
7#include "LayerCloneBase.hpp"
8
10#include <armnn/TypesUtils.hpp>
13
14namespace armnn
15{
16
21
22std::unique_ptr<IWorkload> UnidirectionalSequenceLstmLayer::CreateWorkload(const IWorkloadFactory& factory) const
23{
25
26 // Basic parameters
27 descriptor.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights.get();
28 descriptor.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights.get();
29 descriptor.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights.get();
30 descriptor.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights.get();
31 descriptor.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights.get();
32 descriptor.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights.get();
33 descriptor.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias.get();
34 descriptor.m_CellBias = m_BasicParameters.m_CellBias.get();
35 descriptor.m_OutputGateBias = m_BasicParameters.m_OutputGateBias.get();
36
37 // Cifg parameters
38 if (!m_Param.m_CifgEnabled)
39 {
40 descriptor.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights.get();
41 descriptor.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights.get();
42 descriptor.m_InputGateBias = m_CifgParameters.m_InputGateBias.get();
43 }
44
45 // Projection parameters
46 if (m_Param.m_ProjectionEnabled)
47 {
48 descriptor.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights.get();
49 descriptor.m_ProjectionBias = m_ProjectionParameters.m_ProjectionBias.get();
50 }
51
52 // Peephole parameters
53 if (m_Param.m_PeepholeEnabled)
54 {
55 if (!m_Param.m_CifgEnabled)
56 {
57 descriptor.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights.get();
58 }
59 descriptor.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights.get();
60 descriptor.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights.get();
61 }
62
63 // Layer normalisation parameters
64 if(m_Param.m_LayerNormEnabled)
65 {
66 if (!m_Param.m_CifgEnabled)
67 {
68 descriptor.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights.get();
69 }
70 descriptor.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights.get();
71 descriptor.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights.get();
72 descriptor.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights.get();
73 }
74
75 SetAdditionalInfo(descriptor);
76
77 return factory.CreateWorkload(LayerType::UnidirectionalSequenceLstm, descriptor, PrepInfoAndDesc(descriptor));
78}
79
81{
83
84 layer->m_BasicParameters.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights ?
85 m_BasicParameters.m_InputToForgetWeights
86 : nullptr;
87 layer->m_BasicParameters.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights ?
88 m_BasicParameters.m_InputToCellWeights : nullptr;
89 layer->m_BasicParameters.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights ?
90 m_BasicParameters.m_InputToOutputWeights : nullptr;
91 layer->m_BasicParameters.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights ?
92 m_BasicParameters.m_RecurrentToForgetWeights : nullptr;
93 layer->m_BasicParameters.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights ?
94 m_BasicParameters.m_RecurrentToCellWeights : nullptr;
95 layer->m_BasicParameters.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights ?
96 m_BasicParameters.m_RecurrentToOutputWeights : nullptr;
97 layer->m_BasicParameters.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias ?
98 m_BasicParameters.m_ForgetGateBias : nullptr;
99 layer->m_BasicParameters.m_CellBias = m_BasicParameters.m_CellBias ?
100 m_BasicParameters.m_CellBias : nullptr;
101 layer->m_BasicParameters.m_OutputGateBias = m_BasicParameters.m_OutputGateBias ?
102 m_BasicParameters.m_OutputGateBias : nullptr;
103
104 if (!m_Param.m_CifgEnabled)
105 {
106 layer->m_CifgParameters.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights ?
107 m_CifgParameters.m_InputToInputWeights : nullptr;
108 layer->m_CifgParameters.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights ?
109 m_CifgParameters.m_RecurrentToInputWeights : nullptr;
110 layer->m_CifgParameters.m_InputGateBias = m_CifgParameters.m_InputGateBias ?
111 m_CifgParameters.m_InputGateBias : nullptr;
112 }
113
114 if (m_Param.m_ProjectionEnabled)
115 {
116 layer->m_ProjectionParameters.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights ?
117 m_ProjectionParameters.m_ProjectionWeights : nullptr;
118 layer->m_ProjectionParameters.m_ProjectionBias = m_ProjectionParameters.m_ProjectionBias ?
119 m_ProjectionParameters.m_ProjectionBias : nullptr;
120 }
121
122 if (m_Param.m_PeepholeEnabled)
123 {
124 if (!m_Param.m_CifgEnabled)
125 {
126 layer->m_PeepholeParameters.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights ?
127 m_PeepholeParameters.m_CellToInputWeights : nullptr;
128 }
129 layer->m_PeepholeParameters.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights ?
130 m_PeepholeParameters.m_CellToForgetWeights : nullptr;
131 layer->m_PeepholeParameters.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights ?
132 m_PeepholeParameters.m_CellToOutputWeights : nullptr;
133 }
134
135 if (m_Param.m_LayerNormEnabled)
136 {
137 layer->m_LayerNormParameters.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights ?
138 m_LayerNormParameters.m_InputLayerNormWeights : nullptr;
139 layer->m_LayerNormParameters.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights ?
140 m_LayerNormParameters.m_ForgetLayerNormWeights : nullptr;
141 layer->m_LayerNormParameters.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights ?
142 m_LayerNormParameters.m_CellLayerNormWeights : nullptr;
143 layer->m_LayerNormParameters.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights ?
144 m_LayerNormParameters.m_OutputLayerNormWeights : nullptr;
145 }
146
147 return std::move(layer);
148}
149
151 const std::vector<TensorShape>& inputShapes) const
152{
153 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(inputShapes.size() == 3,
154 "inputShapes' size is \"" + std::to_string(inputShapes.size()) +
155 "\" - should be \"3\".");
156
157 // Get input values for validation
158 unsigned int outputSize = inputShapes[1][1];
159
160 std::vector<TensorShape> outShapes;
161 if (m_Param.m_TimeMajor)
162 {
163 outShapes.push_back(TensorShape({inputShapes[0][0], inputShapes[0][1], outputSize}));
164 }
165 else
166 {
167 outShapes.push_back(TensorShape({inputShapes[0][0], inputShapes[0][1], outputSize}));
168 }
169 return outShapes;
170}
171
173{
175
176 const TensorShape& outputShape = GetOutputSlot(2).GetTensorInfo().GetShape();
177
179
180 auto inferredShapes = InferOutputShapes( {
184 });
185
186 if (inferredShapes.size() != 1)
187 {
188 throw armnn::LayerValidationException("inferredShapes has "
189 + std::to_string(inferredShapes.size()) +
190 " elements - should only have 1.");
191 }
192
193 // Check if the weights are nullptr
194 if (!m_BasicParameters.m_InputToForgetWeights)
195 {
196 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
197 "m_BasicParameters.m_InputToForgetWeights should not be null.");
198 }
199
200 if (!m_BasicParameters.m_InputToCellWeights)
201 {
202 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
203 "m_BasicParameters.m_InputToCellWeights should not be null.");
204 }
205
206 if (!m_BasicParameters.m_InputToOutputWeights)
207 {
208 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
209 "m_BasicParameters.m_InputToOutputWeights should not be null.");
210 }
211
212 if (!m_BasicParameters.m_RecurrentToForgetWeights)
213 {
214 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
215 "m_BasicParameters.m_RecurrentToForgetWeights should not be null.");
216 }
217
218 if (!m_BasicParameters.m_RecurrentToCellWeights)
219 {
220 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
221 "m_BasicParameters.m_RecurrentToCellWeights should not be null.");
222 }
223
224 if (!m_BasicParameters.m_RecurrentToOutputWeights)
225 {
226 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
227 "m_BasicParameters.m_RecurrentToOutputWeights should not be null.");
228 }
229
230 if (!m_BasicParameters.m_ForgetGateBias)
231 {
232 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
233 "m_BasicParameters.m_ForgetGateBias should not be null.");
234 }
235
236 if (!m_BasicParameters.m_CellBias)
237 {
238 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
239 "m_BasicParameters.m_CellBias should not be null.");
240 }
241
242 if (!m_BasicParameters.m_OutputGateBias)
243 {
244 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
245 "m_BasicParameters.m_OutputGateBias should not be null.");
246 }
247
248 if (!m_Param.m_CifgEnabled)
249 {
250 if (!m_CifgParameters.m_InputToInputWeights)
251 {
252 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
253 "m_CifgParameters.m_InputToInputWeights should not be null.");
254 }
255
256 if (!m_CifgParameters.m_RecurrentToInputWeights)
257 {
258 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
259 "m_CifgParameters.m_RecurrentToInputWeights should not be null.");
260 }
261
262 if (!m_CifgParameters.m_InputGateBias)
263 {
264 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
265 "m_CifgParameters.m_InputGateBias should not be null.");
266 }
267 }
268 else
269 {
270 if (m_CifgParameters.m_InputToInputWeights)
271 {
272 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
273 "m_CifgParameters.m_InputToInputWeights should not have a value "
274 "when CIFG is enabled.");
275 }
276
277 if (m_CifgParameters.m_RecurrentToInputWeights)
278 {
279 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
280 "m_CifgParameters.m_RecurrentToInputWeights should not have a value "
281 "when CIFG is enabled.");
282 }
283
284 if (m_CifgParameters.m_InputGateBias)
285 {
286 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
287 "m_CifgParameters.m_InputGateBias should not have a value "
288 "when CIFG is enabled.");
289 }
290 }
291
292 if (m_Param.m_ProjectionEnabled)
293 {
294 if (!m_ProjectionParameters.m_ProjectionWeights)
295 {
296 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
297 "m_ProjectionParameters.m_ProjectionWeights should not be null.");
298 }
299 }
300
301 if (m_Param.m_PeepholeEnabled)
302 {
303 if (!m_Param.m_CifgEnabled)
304 {
305 if (!m_PeepholeParameters.m_CellToInputWeights)
306 {
307 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
308 "m_PeepholeParameters.m_CellToInputWeights should not be null "
309 "when Peephole is enabled and CIFG is disabled.");
310 }
311 }
312
313 if (!m_PeepholeParameters.m_CellToForgetWeights)
314 {
315 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
316 "m_PeepholeParameters.m_CellToForgetWeights should not be null.");
317 }
318
319 if (!m_PeepholeParameters.m_CellToOutputWeights)
320 {
321 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
322 "m_PeepholeParameters.m_CellToOutputWeights should not be null.");
323 }
324 }
325
326 if (m_Param.m_LayerNormEnabled)
327 {
328 if(!m_Param.m_CifgEnabled)
329 {
330 if (!m_LayerNormParameters.m_InputLayerNormWeights)
331 {
332 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
333 "m_LayerNormParameters.m_inputLayerNormWeights "
334 "should not be null.");
335 }
336 }
337
338 if (!m_LayerNormParameters.m_ForgetLayerNormWeights)
339 {
340 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
341 "m_LayerNormParameters.m_forgetLayerNormWeights "
342 "should not be null.");
343 }
344
345 if (!m_LayerNormParameters.m_CellLayerNormWeights)
346 {
347 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
348 "m_LayerNormParameters.m_cellLayerNormWeights "
349 "should not be null.");
350 }
351
352 if (!m_LayerNormParameters.m_OutputLayerNormWeights)
353 {
354 throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: "
355 "m_LayerNormParameters.m_outputLayerNormWeights "
356 "should not be null.");
357 }
358 }
359
360 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "UnidirectionalSequenceLstmLayer");
361}
362
364{
365 // For API stability DO NOT ALTER order and add new members to the end of vector
366 return {m_BasicParameters.m_InputToForgetWeights,
367 m_BasicParameters.m_InputToCellWeights,
368 m_BasicParameters.m_InputToOutputWeights,
369 m_BasicParameters.m_RecurrentToForgetWeights,
370 m_BasicParameters.m_RecurrentToCellWeights,
371 m_BasicParameters.m_RecurrentToOutputWeights,
372 m_BasicParameters.m_ForgetGateBias,
373 m_BasicParameters.m_CellBias,
374 m_BasicParameters.m_OutputGateBias,
375
376 // Cifg parameters
377 m_CifgParameters.m_InputToInputWeights,
378 m_CifgParameters.m_RecurrentToInputWeights,
379 m_CifgParameters.m_InputGateBias,
380
381 // Projection parameters
382 m_ProjectionParameters.m_ProjectionWeights,
383 m_ProjectionParameters.m_ProjectionBias,
384
385 // Peephole parameters
386 m_PeepholeParameters.m_CellToInputWeights,
387 m_PeepholeParameters.m_CellToForgetWeights,
388 m_PeepholeParameters.m_CellToOutputWeights,
389
390 // Layer normalisation parameters
391 m_LayerNormParameters.m_InputLayerNormWeights,
392 m_LayerNormParameters.m_ForgetLayerNormWeights,
393 m_LayerNormParameters.m_CellLayerNormWeights,
394 m_LayerNormParameters.m_OutputLayerNormWeights};
395}
396
398{
399 std::vector<ConstTensor> constTensors;
400
401 LstmDescriptor descriptor = GetParameters();
402
403 ManagedConstTensorHandle managedInputToForgetWeights(m_BasicParameters.m_InputToForgetWeights);
404 ManagedConstTensorHandle managedInputToCellWeights(m_BasicParameters.m_InputToCellWeights);
405 ManagedConstTensorHandle managedInputToOutputWeights(m_BasicParameters.m_InputToOutputWeights);
406 ManagedConstTensorHandle managedRecurrentToForgetWeights(m_BasicParameters.m_RecurrentToForgetWeights);
407 ManagedConstTensorHandle managedRecurrentToCellWeights(m_BasicParameters.m_RecurrentToCellWeights);
408 ManagedConstTensorHandle managedRecurrentToOutputWeights(m_BasicParameters.m_RecurrentToOutputWeights);
409 ManagedConstTensorHandle managedForgetGateBias(m_BasicParameters.m_ForgetGateBias);
410 ManagedConstTensorHandle managedCellBias(m_BasicParameters.m_CellBias);
411 ManagedConstTensorHandle managedOutputGateBias(m_BasicParameters.m_OutputGateBias);
412
413 // Cifg parameters
414 ManagedConstTensorHandle managedInputToInputWeights(m_CifgParameters.m_InputToInputWeights);
415 ManagedConstTensorHandle managedRecurrentToInputWeights(m_CifgParameters.m_RecurrentToInputWeights);
416 ManagedConstTensorHandle managedInputGateBias(m_CifgParameters.m_InputGateBias);
417
418 // Projection parameters
419 ManagedConstTensorHandle managedProjectionWeights(m_ProjectionParameters.m_ProjectionWeights);
420 ManagedConstTensorHandle managedProjectionBias(m_ProjectionParameters.m_ProjectionBias);
421
422 // Peephole parameters
423 ManagedConstTensorHandle managedCellToInputWeights(m_PeepholeParameters.m_CellToInputWeights);
424 ManagedConstTensorHandle managedCellToForgetWeights(m_PeepholeParameters.m_CellToForgetWeights);
425 ManagedConstTensorHandle managedCellToOutputWeights(m_PeepholeParameters.m_CellToOutputWeights);
426
427 // Layer normalisation parameters
428 ManagedConstTensorHandle managedInputLayerNormWeights(m_LayerNormParameters.m_InputLayerNormWeights);
429 ManagedConstTensorHandle managedForgetLayerNormWeights(m_LayerNormParameters.m_ForgetLayerNormWeights);
430 ManagedConstTensorHandle managedCellLayerNormWeights(m_LayerNormParameters.m_CellLayerNormWeights);
431 ManagedConstTensorHandle managedOutputLayerNormWeights(m_LayerNormParameters.m_OutputLayerNormWeights);
432
433 // First add mandatory/basic parameters
434 if (m_BasicParameters.m_InputToForgetWeights != nullptr)
435 {
436 constTensors.emplace_back(ConstTensor(managedInputToForgetWeights.GetTensorInfo(),
437 managedInputToForgetWeights.Map()));
438 }
439 if (m_BasicParameters.m_InputToCellWeights != nullptr)
440 {
441 constTensors.emplace_back(ConstTensor(managedInputToCellWeights.GetTensorInfo(),
442 managedInputToCellWeights.Map()));
443 }
444 if (m_BasicParameters.m_InputToOutputWeights != nullptr)
445 {
446 constTensors.emplace_back(ConstTensor(managedInputToOutputWeights.GetTensorInfo(),
447 managedInputToOutputWeights.Map()));
448 }
449 if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr)
450 {
451 constTensors.emplace_back(ConstTensor(
452 managedRecurrentToForgetWeights.GetTensorInfo(),
453 managedRecurrentToForgetWeights.Map()));
454 }
455 if (m_BasicParameters.m_RecurrentToCellWeights != nullptr)
456 {
457 constTensors.emplace_back(ConstTensor(
458 managedRecurrentToCellWeights.GetTensorInfo(),
459 managedRecurrentToCellWeights.Map()));
460 }
461 if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr)
462 {
463 constTensors.emplace_back(ConstTensor(
464 managedRecurrentToOutputWeights.GetTensorInfo(),
465 managedRecurrentToOutputWeights.Map()));
466 }
467 if (m_BasicParameters.m_ForgetGateBias != nullptr)
468 {
469 constTensors.emplace_back(ConstTensor(managedForgetGateBias.GetTensorInfo(),
470 managedForgetGateBias.Map()));
471 }
472 if (m_BasicParameters.m_CellBias != nullptr)
473 {
474 constTensors.emplace_back(ConstTensor(managedCellBias.GetTensorInfo(),
475 managedCellBias.Map()));
476 }
477 if (m_BasicParameters.m_OutputGateBias != nullptr)
478 {
479 constTensors.emplace_back(ConstTensor(managedOutputGateBias.GetTensorInfo(),
480 managedOutputGateBias.Map()));
481 }
482
483 // Add cifg parameters
484 if (!descriptor.m_CifgEnabled)
485 {
486 if (m_CifgParameters.m_InputToInputWeights != nullptr)
487 {
488 constTensors.emplace_back(ConstTensor(managedInputToInputWeights.GetTensorInfo(),
489 managedInputToInputWeights.Map()));
490 }
491 if (m_CifgParameters.m_RecurrentToInputWeights != nullptr)
492 {
493 constTensors.emplace_back(ConstTensor(
494 managedRecurrentToInputWeights.GetTensorInfo(),
495 managedRecurrentToInputWeights.Map()));
496 }
497 if (m_CifgParameters.m_InputGateBias != nullptr)
498 {
499 constTensors.emplace_back(ConstTensor(managedInputGateBias.GetTensorInfo(),
500 managedInputGateBias.Map()));
501 }
502 }
503
504 // Add peephole parameters
505 if (descriptor.m_PeepholeEnabled)
506 {
507 if (!descriptor.m_CifgEnabled)
508 {
509 if (m_PeepholeParameters.m_CellToInputWeights != nullptr)
510 {
511 constTensors.emplace_back(ConstTensor(managedCellToInputWeights.GetTensorInfo(),
512 managedCellToInputWeights.Map()));
513 }
514 }
515 if (m_PeepholeParameters.m_CellToForgetWeights != nullptr)
516 {
517 constTensors.emplace_back(ConstTensor(managedCellToForgetWeights.GetTensorInfo(),
518 managedCellToForgetWeights.Map()));
519 }
520 if (m_PeepholeParameters.m_CellToOutputWeights != nullptr)
521 {
522 constTensors.emplace_back(ConstTensor(managedCellToOutputWeights.GetTensorInfo(),
523 managedCellToOutputWeights.Map()));
524 }
525 }
526
527 // Add projection parameters
528 if (descriptor.m_ProjectionEnabled)
529 {
530 if (m_ProjectionParameters.m_ProjectionWeights != nullptr)
531 {
532 constTensors.emplace_back(ConstTensor(managedProjectionWeights.GetTensorInfo(),
533 managedProjectionWeights.Map()));
534 }
535 if (m_ProjectionParameters.m_ProjectionBias != nullptr)
536 {
537 constTensors.emplace_back(ConstTensor(managedProjectionBias.GetTensorInfo(),
538 managedProjectionBias.Map()));
539 }
540 }
541
542 // Add norm parameters
543 if (descriptor.m_LayerNormEnabled)
544 {
545 if (!descriptor.m_CifgEnabled)
546 {
547 if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr)
548 {
549 constTensors.emplace_back(ConstTensor(managedInputLayerNormWeights.GetTensorInfo(),
550 managedInputLayerNormWeights.Map()));
551 }
552 }
553 if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr)
554 {
555 constTensors.emplace_back(ConstTensor(managedForgetLayerNormWeights.GetTensorInfo(),
556 managedForgetLayerNormWeights.Map()));
557 }
558 if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr)
559 {
560 constTensors.emplace_back(ConstTensor(managedCellLayerNormWeights.GetTensorInfo(),
561 managedCellLayerNormWeights.Map()));
562 }
563 if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr)
564 {
565 constTensors.emplace_back(ConstTensor(managedOutputLayerNormWeights.GetTensorInfo(),
566 managedOutputLayerNormWeights.Map()));
567 }
568 }
569
570 strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName());
571}
572
573} // namespace armnn
#define ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(_cond, _str)
#define CHECK_LOCATION()
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition Tensor.hpp:330
std::vector< std::reference_wrapper< const std::shared_ptr< ConstTensorHandle > > > ImmutableConstantTensors
Definition INetwork.hpp:141
virtual void ExecuteStrategy(const IConnectableLayer *layer, const armnn::BaseDescriptor &descriptor, const std::vector< armnn::ConstTensor > &constants, const char *name, const armnn::LayerBindingId id=0)=0
virtual std::unique_ptr< IWorkload > CreateWorkload(LayerType type, const QueueDescriptor &descriptor, const WorkloadInfo &info) const =0
Backends should implement their own CreateWorkload function with a switch statement.
const TensorInfo & GetTensorInfo() const override
Gets the TensorInfo for this InputSlot.
Definition Layer.cpp:614
void VerifyLayerConnections(unsigned int expectedConnections, const CheckLocation &location) const
Definition Layer.cpp:410
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition Layer.hpp:337
void VerifyShapeInferenceType(const TensorShape &outputShape, ShapeInferenceMethod shapeInferenceMethod)
Definition Layer.cpp:526
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
Definition Layer.hpp:339
LayerType * CloneBase(Graph &graph, Params &&... params) const
const char * GetName() const override
Returns the name of the layer.
Definition Layer.hpp:332
void ValidateAndCopyShape(const TensorShape &outputShape, const TensorShape &inferredShape, const ShapeInferenceMethod shapeInferenceMethod, const std::string &layerName, const unsigned int outputSlotIndex=0)
Definition Layer.cpp:457
void SetAdditionalInfo(QueueDescriptor &descriptor) const
Definition Layer.cpp:303
ShapeInferenceMethod m_ShapeInferenceMethod
Definition Layer.hpp:441
LayerWithParameters(unsigned int numInputSlots, unsigned int numOutputSlots, LayerType type, const LstmDescriptor &param, const char *name)
WorkloadInfo PrepInfoAndDesc(QueueDescriptor &descriptor) const
const LstmDescriptor & GetParameters() const override
const void * Map(bool blocking=true)
RAII Managed resource Unmaps MemoryArea once out of scope.
const TensorInfo & GetTensorInfo() const
const TensorInfo & GetTensorInfo() const override
Definition Layer.cpp:100
const TensorShape & GetShape() const
Definition Tensor.hpp:193
Layer::ImmutableConstantTensors GetConstantTensorsByRef() const override
Retrieve the handles to the constant values stored by the layer.
void ExecuteStrategy(IStrategy &strategy) const override
Apply a visitor to this layer.
std::vector< TensorShape > InferOutputShapes(const std::vector< TensorShape > &inputShapes) const override
By default returns inputShapes if the number of inputs is equal to the number of outputs,...
UnidirectionalSequenceLstmLayer * Clone(Graph &graph) const override
Creates a dynamically-allocated copy of this layer.
void ValidateTensorShapesFromInputs() override
Check if the input tensor shape(s) will lead to a valid configuration of UnidirectionalSequenceLstmLayer.
UnidirectionalSequenceLstmLayer(const LstmDescriptor &param, const char *name)
Constructor to create a UnidirectionalSequenceLstmLayer.
virtual std::unique_ptr< IWorkload > CreateWorkload(const IWorkloadFactory &factory) const override
Makes a workload for the UnidirectionalSequence LSTM type.
Copyright (c) 2021 ARM Limited and Contributors.
LayerType
When adding a new layer, adapt also the LastLayer enum value in the enum class LayerType below.
Definition Types.hpp:494
@ UnidirectionalSequenceLstm
Definition Types.hpp:496
An LstmDescriptor for the LstmLayer.
bool m_PeepholeEnabled
Enable/disable peephole.
bool m_LayerNormEnabled
Enable/disable layer normalization.
bool m_ProjectionEnabled
Enable/disable the projection layer.
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).