58 std::vector<ITensorHandle*> outputs)
const
70 auto inputTensor =
reinterpret_cast<float*
>(inputs[0]->Map());
76 std::vector<float> inputValue(inputTensor, inputTensor + inputInfo.
GetNumElements());
85 unsigned int maxTime = inputShape[0];
86 unsigned int batchSize = inputShape[1];
87 unsigned int outputSize = outputShape[2];
88 unsigned int inputSize = inputShape[2];
90 TensorInfo scratchInfo = outputInfo;
93 std::vector<float> inputGateScratchBuffer;
94 std::vector<float> cellScratchBuffer(scratchInfo.GetNumElements(), 0.);
95 std::vector<float> forgetGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
96 std::vector<float> outputGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
98 std::vector<float> outputStateOutBuffer(outputStateInfo.
GetNumElements(), 0.);
99 std::vector<float> cellStateOutBuffer(cellStateInfo.
GetNumElements(), 0.);
101 void* outputStateOutData = outputStateOutBuffer.data();
102 void* cellStateOutData = cellStateOutBuffer.data();
104 std::unique_ptr<Encoder<float>> inputGateScratch;
105 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(scratchInfo, cellScratchBuffer.data());
106 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(scratchInfo, forgetGateScratchBuffer.data());
107 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(scratchInfo, outputGateScratchBuffer.data());
109 std::unique_ptr<Decoder<float>> inputGateScratchDecoder;
110 std::unique_ptr<Decoder<float>> cellScratchDecoder = MakeDecoder<float>(scratchInfo, cellScratchBuffer.data());
111 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = MakeDecoder<float>(scratchInfo,
112 forgetGateScratchBuffer.data());
113 std::unique_ptr<Decoder<float>> outputGateScratchDecoder = MakeDecoder<float>(scratchInfo,
114 outputGateScratchBuffer.data());
122 inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.);
123 inputGateScratch = MakeEncoder<float>(scratchInfo, inputGateScratchBuffer.data());
124 inputGateScratchDecoder = MakeDecoder<float>(scratchInfo, inputGateScratchBuffer.data());
127 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputStateInfo, outputStateOutData);
128 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(cellStateInfo, cellStateOutData);
129 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(cellStateInfo, cellStateOutData);
131 TensorInfo lstmInputInfo = inputInfo;
132 TensorShape batchInputShape = TensorShape({batchSize, inputSize});
133 lstmInputInfo.
SetShape(batchInputShape);
135 TensorInfo lstmOutputInfo = outputInfo;
136 lstmOutputInfo.
SetShape({batchSize, outputSize});
138 const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
139 const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
140 unsigned int nOutput = recurrentToOutputWeightsShape[1];
141 auto outputStateInData = inputs[1]->Map();
142 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateInData);
144 auto cellStateInData = inputs[2]->Map();
145 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateInData);
147 auto currentInputData =
reinterpret_cast<float*
>(inputs[0]->Map());
148 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
149 auto currentOutputData =
reinterpret_cast<float*
>(outputs[2]->Map());
150 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
151 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
153 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
154 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
155 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<
void>());
156 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
157 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<
void>());
158 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
159 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<
void>());
161 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
162 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
163 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<
void>());
164 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
165 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<
void>());
166 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
167 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<
void>());
169 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
170 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
171 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<
void>());
172 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
173 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<
void>());
174 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
175 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<
void>());
177 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
178 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
179 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
181 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
182 std::unique_ptr<Decoder<float>> projectionBiasTensor;
184 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
185 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
186 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
187 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
193 inputLayerNormWeights = MakeDecoder<float>(
194 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<
void>());
196 forgetLayerNormWeights = MakeDecoder<float>(
197 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<
void>());
198 cellLayerNormWeights = MakeDecoder<float>(
199 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<
void>());
200 outputLayerNormWeights = MakeDecoder<float>(
201 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<
void>());
206 inputToInputWeightsTensor = MakeDecoder<float>(
207 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<
void>());
208 inputGateBiasTensor = MakeDecoder<float>(
209 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<
void>());
210 recurrentToInputWeightsTensor = MakeDecoder<float>(
211 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<
void>());
216 cellToForgetWeightsTensor = MakeDecoder<float>(
217 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<
void>());
218 cellToOutputWeightsTensor = MakeDecoder<float>(
219 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<
void>());
222 if (!useCifg && usePeephole)
224 cellToInputWeightsTensor = MakeDecoder<float>(
225 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<
void>());
230 projectionWeightsTensor = MakeDecoder<float>(
231 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<
void>());
232 if (m_ProjectionBiasTensor)
234 projectionBiasTensor = MakeDecoder<float>(
235 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<
void>());
239 unsigned int batchInputSize = batchSize * inputSize;
240 unsigned int batchOutputSize = batchSize * nOutput;
242 for (
unsigned int t = 0; t < maxTime; ++t)
247 inputToOutputWeightsShape,
248 recurrentToOutputWeightsShape,
257 inputToInputWeightsTensor,
258 inputToForgetWeightsTensor,
259 inputToCellWeightsTensor,
260 inputToOutputWeightsTensor,
261 recurrentToInputWeightsTensor,
262 recurrentToForgetWeightsTensor,
263 recurrentToCellWeightsTensor,
264 recurrentToOutputWeightsTensor,
265 cellToInputWeightsTensor,
266 cellToForgetWeightsTensor,
267 cellToOutputWeightsTensor,
269 forgetGateBiasTensor,
271 outputGateBiasTensor,
272 projectionWeightsTensor,
273 projectionBiasTensor,
274 inputLayerNormWeights,
275 forgetLayerNormWeights,
276 cellLayerNormWeights,
277 outputLayerNormWeights,
282 inputGateScratchDecoder,
284 forgetGateScratchDecoder,
285 outputGateScratchDecoder,
288 currentInputData += batchInputSize;
289 inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
290 currentOutputData += batchOutputSize;
291 output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
292 outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
295 outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateOutData);
298 cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateOutData);
304 const PermutationVector& mappings = {1U, 0U, 2U};
305 auto outputData =
reinterpret_cast<float*
>(outputs[2]->Map());
306 std::vector<float> outputValue(outputData, outputData + outputInfo.
GetNumElements());