85 layer->m_BasicParameters.m_InputToForgetWeights =
m_BasicParameters.m_InputToForgetWeights ?
87 layer->m_BasicParameters.m_InputToCellWeights =
m_BasicParameters.m_InputToCellWeights ?
89 layer->m_BasicParameters.m_InputToOutputWeights =
m_BasicParameters.m_InputToOutputWeights ?
91 layer->m_BasicParameters.m_RecurrentToForgetWeights =
m_BasicParameters.m_RecurrentToForgetWeights ?
93 layer->m_BasicParameters.m_RecurrentToCellWeights =
m_BasicParameters.m_RecurrentToCellWeights ?
95 layer->m_BasicParameters.m_RecurrentToOutputWeights =
m_BasicParameters.m_RecurrentToOutputWeights ?
106 layer->m_CifgParameters.m_InputToInputWeights =
m_CifgParameters.m_InputToInputWeights ?
108 layer->m_CifgParameters.m_RecurrentToInputWeights =
m_CifgParameters.m_RecurrentToInputWeights ?
110 layer->m_CifgParameters.m_InputGateBias =
m_CifgParameters.m_InputGateBias ?
114 if (
m_Param.m_ProjectionEnabled)
129 layer->m_PeepholeParameters.m_CellToForgetWeights =
m_PeepholeParameters.m_CellToForgetWeights ?
131 layer->m_PeepholeParameters.m_CellToOutputWeights =
m_PeepholeParameters.m_CellToOutputWeights ?
135 if (
m_Param.m_LayerNormEnabled)
138 layer->m_LayerNormParameters.m_InputLayerNormWeights =
m_LayerNormParameters.m_InputLayerNormWeights ?
142 layer->m_LayerNormParameters.m_ForgetLayerNormWeights =
m_LayerNormParameters.m_ForgetLayerNormWeights ?
146 layer->m_LayerNormParameters.m_OutputLayerNormWeights =
m_LayerNormParameters.m_OutputLayerNormWeights ?
150 return std::move(layer);
155 if (inputShapes.size() != 3)
157 throw armnn::Exception(
"inputShapes' size is \"" + std::to_string(inputShapes.size()) +
158 "\" - should be \"3\".");
162 unsigned int batchSize = inputShapes[0][0];
163 unsigned int outputSize = inputShapes[1][1];
164 unsigned int numUnits = inputShapes[2][1];
166 std::vector<TensorShape> outShapes;
167 outShapes.push_back(
TensorShape({ batchSize, outputSize }));
168 outShapes.push_back(
TensorShape({ batchSize, numUnits }));
169 outShapes.push_back(
TensorShape({ batchSize, outputSize }));
189 if (inferredShapes.size() != 3)
192 + std::to_string(inferredShapes.size()) +
193 " element(s) - should only have 3.");
200 "m_BasicParameters.m_InputToForgetWeights should not be null.");
206 "m_BasicParameters.m_InputToCellWeights should not be null.");
212 "m_BasicParameters.m_InputToOutputWeights should not be null.");
218 "m_BasicParameters.m_RecurrentToForgetWeights should not be null.");
224 "m_BasicParameters.m_RecurrentToCellWeights should not be null.");
230 "m_BasicParameters.m_RecurrentToOutputWeights should not be null.");
236 "m_BasicParameters.m_ForgetGateBias should not be null.");
242 "m_BasicParameters.m_CellBias should not be null.");
248 "m_BasicParameters.m_OutputGateBias should not be null.");
256 "m_CifgParameters.m_InputToInputWeights should not be null.");
262 "m_CifgParameters.m_RecurrentToInputWeights should not be null.");
268 "m_CifgParameters.m_InputGateBias should not be null.");
278 "m_CifgParameters.m_InputToInputWeights "
279 "should not have a value when CIFG is enabled.");
285 "m_CifgParameters.m_RecurrentToInputWeights "
286 "should not have a value when CIFG is enabled.");
292 "m_CifgParameters.m_InputGateBias "
293 "should not have a value when CIFG is enabled.");
299 if (
m_Param.m_ProjectionEnabled)
304 "m_ProjectionParameters.m_ProjectionWeights should not be null.");
314 "m_PeepholeParameters.m_CellToInputWeights should not be null "
315 "when Peephole is enabled and CIFG is disabled.");
322 "m_PeepholeParameters.m_CellToForgetWeights should not be null.");
328 "m_PeepholeParameters.m_CellToOutputWeights should not be null.");
337 if (
m_Param.m_LayerNormEnabled)
344 "should not be null.");
351 "m_LayerNormParameters.m_ForgetLayerNormWeights should not be null.");
357 "m_LayerNormParameters.m_CellLayerNormWeights should not be null.");
363 "m_LayerNormParameters.m_UutputLayerNormWeights should not be null.");
405 std::vector<ConstTensor> constTensors;
440 managedInputToForgetWeights.
Map()));
445 managedInputToCellWeights.
Map()));
450 managedInputToOutputWeights.
Map()));
456 managedRecurrentToForgetWeights.
Map()));
462 managedRecurrentToCellWeights.
Map()));
468 managedRecurrentToOutputWeights.
Map()));
473 managedForgetGateBias.
Map()));
478 managedCellBias.
Map()));
483 managedOutputGateBias.
Map()));
490 managedInputToInputWeights.
Map()));
496 managedRecurrentToInputWeights.
Map()));
501 managedInputGateBias.
Map()));
508 managedCellToInputWeights.
Map()));
513 managedCellToForgetWeights.
Map()));
518 managedCellToOutputWeights.
Map()));
525 managedProjectionWeights.
Map()));
530 managedProjectionBias.
Map()));
537 managedInputLayerNormWeights.
Map()));
542 managedForgetLayerNormWeights.
Map()));
547 managedCellLayerNormWeights.
Map()));
552 managedOutputLayerNormWeights.
Map()));