81 const uint32_t numBatches = inputShape[0];
82 const uint32_t inputSize = inputShape[1];
83 const uint32_t outputSize = outputStateInShape[1];
84 const uint32_t numUnits = cellStateInShape[1];
93 std::unique_ptr<Decoder<float>> inputDecoder =
94 MakeDecoder<float>(inputInfo, inputs[0]->
Map());
95 std::unique_ptr<Decoder<float>> outputStateInDecoder =
96 MakeDecoder<float>(outputStateInInfo, inputs[1]->
Map());
97 std::unique_ptr<Decoder<float>> cellStateInDecoder =
98 MakeDecoder<float>(cellStateInInfo, inputs[2]->
Map());
101 std::unique_ptr<Decoder<float>> outputStateOutDecoder =
102 MakeDecoder<float>(outputStateOutInfo, outputs[0]->
Map());
103 std::unique_ptr<Decoder<float>> cellStateOutDecoder =
104 MakeDecoder<float>(cellStateOutInfo, outputs[1]->
Map());
105 std::unique_ptr<Decoder<float>> outputDecoder =
106 MakeDecoder<float>(outputInfo, outputs[2]->
Map());
109 std::unique_ptr<Encoder<float>> outputStateOutEncoder =
110 MakeEncoder<float>(outputStateOutInfo, outputs[0]->
Map());
111 std::unique_ptr<Encoder<float>> cellStateOutEncoder =
112 MakeEncoder<float>(cellStateOutInfo, outputs[1]->
Map());
113 std::unique_ptr<Encoder<float>> outputEncoder =
114 MakeEncoder<float>(outputInfo, outputs[2]->
Map());
117 std::unique_ptr<Decoder<float>> inputToForgetWeightsDecoder = MakeDecoder<float>(
118 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<
void>());
119 std::unique_ptr<Decoder<float>> inputToCellWeightsDecoder = MakeDecoder<float>(
120 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<
void>());
121 std::unique_ptr<Decoder<float>> inputToOutputWeightsDecoder = MakeDecoder<float>(
122 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<
void>());
124 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsDecoder = MakeDecoder<float>(
125 m_RecurrentToForgetWeightsTensor->GetTensorInfo(),
126 m_RecurrentToForgetWeightsTensor->GetConstTensor<
void>());
127 std::unique_ptr<Decoder<float>> recurrentToCellWeightsDecoder = MakeDecoder<float>(
128 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<
void>());
129 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsDecoder = MakeDecoder<float>(
130 m_RecurrentToOutputWeightsTensor->GetTensorInfo(),
131 m_RecurrentToOutputWeightsTensor->GetConstTensor<
void>());
134 std::unique_ptr<Decoder<float>> inputToInputWeightsDecoder;
135 std::unique_ptr<Decoder<float>> recurrentToInputWeightsDecoder;
136 std::unique_ptr<Decoder<float>> inputGateBiasDecoder;
139 std::unique_ptr<Decoder<float>> cellToInputWeightsDecoder;
140 std::unique_ptr<Decoder<float>> cellToForgetWeightsDecoder;
141 std::unique_ptr<Decoder<float>> cellToOutputWeightsDecoder;
144 std::unique_ptr<Decoder<float>> projectionWeightsDecoder;
145 std::unique_ptr<Decoder<float>> projectionBiasDecoder;
148 std::unique_ptr<Decoder<float>> inputLayerNormWeightsDecoder;
149 std::unique_ptr<Decoder<float>> forgetLayerNormWeightsDecoder;
150 std::unique_ptr<Decoder<float>> cellLayerNormWeightsDecoder;
151 std::unique_ptr<Decoder<float>> outputLayerNormWeightsDecoder;
154 std::unique_ptr<Decoder<float>> forgetGateBiasDecoder;
155 std::unique_ptr<Decoder<float>> cellGateBiasDecoder;
156 std::unique_ptr<Decoder<float>> outputGateBiasDecoder;
159 const uint32_t stateTensorSize = numBatches * numUnits;
160 std::vector<int16_t> inputGateData(stateTensorSize);
161 std::vector<int16_t> cellGateData(stateTensorSize);
162 std::vector<int16_t> forgetGateData(stateTensorSize);
163 std::vector<int16_t> outputGateData(stateTensorSize);
164 std::vector<int32_t> hiddenStateData(stateTensorSize);
165 std::vector<int16_t> outputInt16Data(numBatches * outputSize);
185 std::unique_ptr<Decoder<float>> inputGateDecoder =
186 MakeDecoder<float>(inputGateInfo, inputGateData.data());
187 std::unique_ptr<Decoder<float>> cellGateDecoder =
188 MakeDecoder<float>(cellGateInfo, cellGateData.data());
189 std::unique_ptr<Decoder<float>> forgetGateDecoder =
190 MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
191 std::unique_ptr<Decoder<float>> outputGateDecoder =
192 MakeDecoder<float>(outputGateInfo, outputGateData.data());
193 std::unique_ptr<Decoder<float>> hiddenStateDecoder =
194 MakeDecoder<float>(hiddenStateInfo, hiddenStateData.data());
196 std::unique_ptr<Encoder<float>> inputGateEncoder =
197 MakeEncoder<float>(inputGateInfo, inputGateData.data());
198 std::unique_ptr<Encoder<float>> cellGateEncoder =
199 MakeEncoder<float>(cellGateInfo, cellGateData.data());
200 std::unique_ptr<Encoder<float>> forgetGateEncoder =
201 MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
202 std::unique_ptr<Encoder<float>> outputGateEncoder =
203 MakeEncoder<float>(outputGateInfo, outputGateData.data());
204 std::unique_ptr<Encoder<float>> hiddenStateEncoder =
205 MakeEncoder<float>(hiddenStateInfo, hiddenStateData.data());
208 std::unique_ptr<Decoder<float>> outputInt16Decoder =
209 MakeDecoder<float>(outputInt16Info, outputInt16Data.data());
210 std::unique_ptr<Encoder<float>> outputInt16Encoder =
211 MakeEncoder<float>(outputInt16Info, outputInt16Data.data());
216 inputToInputWeightsDecoder = MakeDecoder<float>(
217 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<
void>());
218 recurrentToInputWeightsDecoder = MakeDecoder<float>(m_RecurrentToInputWeightsTensor->GetTensorInfo(),
219 m_RecurrentToInputWeightsTensor->GetConstTensor<
void>());
226 cellToInputWeightsDecoder = MakeDecoder<float>(
227 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<
void>());
229 cellToForgetWeightsDecoder = MakeDecoder<float>(
230 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<
void>());
231 cellToOutputWeightsDecoder = MakeDecoder<float>(
232 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<
void>());
235 if (projectionEnabled)
237 projectionWeightsDecoder = MakeDecoder<float>(
238 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<
void>());
239 if (m_ProjectionBiasTensor)
241 projectionBiasDecoder = MakeDecoder<float>(
242 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<
void>());
246 if (layerNormEnabled)
250 inputLayerNormWeightsDecoder = MakeDecoder<float>(m_InputLayerNormWeightsTensor->GetTensorInfo(),
251 m_InputLayerNormWeightsTensor->GetConstTensor<
void>());
255 m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
256 inputGateBiasDecoder = MakeDecoder<float>(
257 inputGateBiasTensorInfo, m_InputGateBiasTensor->GetConstTensor<
void>());
260 forgetLayerNormWeightsDecoder = MakeDecoder<float>(
261 m_ForgetLayerNormWeightsTensor->GetTensorInfo(),
262 m_ForgetLayerNormWeightsTensor->GetConstTensor<
void>());
263 cellLayerNormWeightsDecoder = MakeDecoder<float>(
264 m_CellLayerNormWeightsTensor->GetTensorInfo(), m_CellLayerNormWeightsTensor->GetConstTensor<
void>());
265 outputLayerNormWeightsDecoder = MakeDecoder<float>(
266 m_OutputLayerNormWeightsTensor->GetTensorInfo(),
267 m_OutputLayerNormWeightsTensor->GetConstTensor<
void>());
271 m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
272 forgetGateBiasDecoder = MakeDecoder<float>(
273 forgetGateBiasTensorInfo, m_ForgetGateBiasTensor->GetConstTensor<
void>());
276 m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
277 cellGateBiasDecoder = MakeDecoder<float>(
278 cellGateBiasTensorInfo, m_CellBiasTensor->GetConstTensor<
void>());
281 m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
282 outputGateBiasDecoder = MakeDecoder<float>(
283 outputGateBiasTensorInfo, m_OutputGateBiasTensor->GetConstTensor<
void>());
289 ZeroVector(*inputGateEncoder, stateTensorSize);
291 ZeroVector(*forgetGateEncoder, stateTensorSize);
292 ZeroVector(*cellGateEncoder, stateTensorSize);
293 ZeroVector(*outputGateEncoder, stateTensorSize);
294 ZeroVector(*hiddenStateEncoder, stateTensorSize);
300 numUnits, inputSize, *inputDecoder, numBatches, *inputGateEncoder);
304 numUnits, inputSize, *inputDecoder, numBatches, *forgetGateEncoder);
307 numUnits, inputSize, *inputDecoder, numBatches, *cellGateEncoder);
310 numUnits, inputSize, *inputDecoder, numBatches, *outputGateEncoder);
316 numUnits, outputSize, *outputStateInDecoder, numBatches, *inputGateEncoder);
320 numUnits, outputSize, *outputStateInDecoder, numBatches, *forgetGateEncoder);
323 numUnits, outputSize, *outputStateInDecoder, numBatches, *cellGateEncoder);
326 numUnits, outputSize, *outputStateInDecoder, numBatches, *outputGateEncoder);
334 numUnits, *cellStateInDecoder, numBatches, *inputGateEncoder);
337 if (layerNormEnabled)
340 m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
342 inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());
345 *inputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
347 inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
350 numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);
352 inputGateInfo.SetQuantizationScale(1.f / 4096);
353 inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());
356 numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);
358 inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
362 inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());
365 Activation(*inputGateDecoder, *inputGateEncoder,
366 TensorInfo({numUnits, numBatches}, internalType),
369 inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
376 *cellStateInDecoder, numBatches, *forgetGateEncoder);
379 if (layerNormEnabled)
383 m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
385 forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
390 *forgetGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
393 forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
396 numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);
400 forgetGateInfo.SetQuantizationScale(1.f / 4096);
401 forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
404 numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);
407 forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
411 forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
414 Activation(*forgetGateDecoder, *forgetGateEncoder,
415 TensorInfo({numUnits, numBatches}, internalType),
418 forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
421 if (layerNormEnabled)
424 m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
426 cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());
430 cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
433 numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);
435 cellGateInfo.SetQuantizationScale(1.f / 4096);
436 cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());
439 numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);
441 cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
445 cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());
448 Activation(*cellGateDecoder, *cellGateEncoder,
449 TensorInfo({numUnits, numBatches}, internalType),
452 cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
458 Sub1Vector(*forgetGateDecoder, stateTensorSize, *forgetGateEncoder);
460 *cellGateDecoder, *forgetGateDecoder, stateTensorSize, *cellStateOutEncoder);
465 *cellGateDecoder, *inputGateDecoder, stateTensorSize, *cellStateOutEncoder);
478 numUnits, *cellStateOutDecoder, numBatches, *outputGateEncoder);
481 if (layerNormEnabled)
484 m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
486 outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
490 outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
493 numBatches, *outputGateEncoder);
495 outputGateInfo.SetQuantizationScale(1.f / 4096);
496 outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
498 VectorBatchVectorAdd(*outputGateBiasDecoder, numUnits, *outputGateDecoder, numBatches, *outputGateEncoder);
500 outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
504 outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
507 Activation(*outputGateDecoder, *outputGateEncoder,
508 TensorInfo({numUnits, numBatches}, internalType),
511 outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
514 Activation(*cellStateOutDecoder, *cellGateEncoder,
515 TensorInfo({numUnits, numBatches}, internalType),
524 if (m_ProjectionBiasTensor)
530 numBatches, *outputInt16Encoder);
532 CopyVector(*outputInt16Decoder, numBatches * outputSize, *outputEncoder);
542 CopyVector(*hiddenStateDecoder, numBatches * outputSize, *outputEncoder);
546 CopyVector(*outputDecoder, numBatches * outputSize, *outputStateOutEncoder);