84 layer->m_BasicParameters.m_InputToForgetWeights =
m_BasicParameters.m_InputToForgetWeights ?
87 layer->m_BasicParameters.m_InputToCellWeights =
m_BasicParameters.m_InputToCellWeights ?
89 layer->m_BasicParameters.m_InputToOutputWeights =
m_BasicParameters.m_InputToOutputWeights ?
91 layer->m_BasicParameters.m_RecurrentToForgetWeights =
m_BasicParameters.m_RecurrentToForgetWeights ?
93 layer->m_BasicParameters.m_RecurrentToCellWeights =
m_BasicParameters.m_RecurrentToCellWeights ?
95 layer->m_BasicParameters.m_RecurrentToOutputWeights =
m_BasicParameters.m_RecurrentToOutputWeights ?
106 layer->m_CifgParameters.m_InputToInputWeights =
m_CifgParameters.m_InputToInputWeights ?
108 layer->m_CifgParameters.m_RecurrentToInputWeights =
m_CifgParameters.m_RecurrentToInputWeights ?
110 layer->m_CifgParameters.m_InputGateBias =
m_CifgParameters.m_InputGateBias ?
114 if (
m_Param.m_ProjectionEnabled)
129 layer->m_PeepholeParameters.m_CellToForgetWeights =
m_PeepholeParameters.m_CellToForgetWeights ?
131 layer->m_PeepholeParameters.m_CellToOutputWeights =
m_PeepholeParameters.m_CellToOutputWeights ?
135 if (
m_Param.m_LayerNormEnabled)
137 layer->m_LayerNormParameters.m_InputLayerNormWeights =
m_LayerNormParameters.m_InputLayerNormWeights ?
139 layer->m_LayerNormParameters.m_ForgetLayerNormWeights =
m_LayerNormParameters.m_ForgetLayerNormWeights ?
143 layer->m_LayerNormParameters.m_OutputLayerNormWeights =
m_LayerNormParameters.m_OutputLayerNormWeights ?
147 return std::move(layer);
152 if (inputShapes.size() != 3)
154 throw armnn::Exception(
"inputShapes' size is \"" + std::to_string(inputShapes.size()) +
155 "\" - should be \"3\".");
159 unsigned int batchSize = inputShapes[0][0];
160 unsigned int outputSize = inputShapes[1][1];
161 unsigned int numUnits = inputShapes[2][1];
163 std::vector<TensorShape> outShapes;
164 outShapes.push_back(
TensorShape({batchSize, numUnits * (
m_Param.m_CifgEnabled ? 3 : 4)}));
165 outShapes.push_back(
TensorShape({batchSize, outputSize}));
166 outShapes.push_back(
TensorShape({batchSize, numUnits}));
167 outShapes.push_back(
TensorShape({batchSize, outputSize}));
186 if (inferredShapes.size() != 4)
189 + std::to_string(inferredShapes.size()) +
190 " element(s) - should only have 4.");
197 "m_BasicParameters.m_InputToForgetWeights should not be null.");
203 "m_BasicParameters.m_InputToCellWeights should not be null.");
209 "m_BasicParameters.m_InputToOutputWeights should not be null.");
215 "m_BasicParameters.m_RecurrentToForgetWeights should not be null.");
221 "m_BasicParameters.m_RecurrentToCellWeights should not be null.");
227 "m_BasicParameters.m_RecurrentToOutputWeights should not be null.");
233 "m_BasicParameters.m_ForgetGateBias should not be null.");
239 "m_BasicParameters.m_CellBias should not be null.");
245 "m_BasicParameters.m_OutputGateBias should not be null.");
253 "m_CifgParameters.m_InputToInputWeights should not be null.");
259 "m_CifgParameters.m_RecurrentToInputWeights should not be null.");
265 "m_CifgParameters.m_InputGateBias should not be null.");
275 "m_CifgParameters.m_InputToInputWeights should not have a value "
276 "when CIFG is enabled.");
282 "m_CifgParameters.m_RecurrentToInputWeights should not have a value "
283 "when CIFG is enabled.");
289 "m_CifgParameters.m_InputGateBias should not have a value "
290 "when CIFG is enabled.");
296 if (
m_Param.m_ProjectionEnabled)
301 "m_ProjectionParameters.m_ProjectionWeights should not be null.");
312 "m_PeepholeParameters.m_CellToInputWeights should not be null "
313 "when Peephole is enabled and CIFG is disabled.");
320 "m_PeepholeParameters.m_CellToForgetWeights should not be null.");
326 "m_PeepholeParameters.m_CellToOutputWeights should not be null.");
337 if (
m_Param.m_LayerNormEnabled)
344 "m_LayerNormParameters.m_inputLayerNormWeights should not be null.");
351 "m_LayerNormParameters.m_forgetLayerNormWeights should not be null.");
357 "m_LayerNormParameters.m_cellLayerNormWeights should not be null.");
363 "m_LayerNormParameters.m_outputLayerNormWeights should not be null.");
404 std::vector<ConstTensor> constTensors;
442 managedInputToForgetWeights.
Map()));
447 managedInputToCellWeights.
Map()));
452 managedInputToOutputWeights.
Map()));
458 managedRecurrentToForgetWeights.
Map()));
464 managedRecurrentToCellWeights.
Map()));
470 managedRecurrentToOutputWeights.
Map()));
475 managedForgetGateBias.
Map()));
480 managedCellBias.
Map()));
485 managedOutputGateBias.
Map()));
494 managedInputToInputWeights.
Map()));
500 managedRecurrentToInputWeights.
Map()));
505 managedInputGateBias.
Map()));
517 managedCellToInputWeights.
Map()));
523 managedCellToForgetWeights.
Map()));
528 managedCellToOutputWeights.
Map()));
538 managedProjectionWeights.
Map()));
543 managedProjectionBias.
Map()));
555 managedInputLayerNormWeights.
Map()));
561 managedForgetLayerNormWeights.
Map()));
566 managedCellLayerNormWeights.
Map()));
571 managedOutputLayerNormWeights.
Map()));