65 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->
Map());
66 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, outputs[2]->
Map());
67 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, outputs[3]->
Map());
69 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, outputs[2]->
Map());
70 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, outputs[3]->
Map());
72 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, inputs[0]->
Map());
73 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, inputs[1]->
Map());
74 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, inputs[2]->
Map());
76 const uint32_t nBatch = inputShape[0];
77 const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
84 std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->
Map());
85 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, outputs[0]->
Map());
86 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->
Map());
87 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->
Map());
89 std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
90 MakeDecoder<float>(outputInfo, outputs[0]->
Map());
91 std::unique_ptr<Decoder<float>> cellScratchDecoder =
92 MakeDecoder<float>(outputInfo, outputs[0]->
Map());
93 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
94 MakeDecoder<float>(outputInfo, outputs[0]->
Map());
95 std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
96 MakeDecoder<float>(outputInfo, outputs[0]->
Map());
100 *cellScratch += (0 * nCell * nBatch);
101 *forgetGateScratch += (1 * nCell * nBatch);
102 *outputGateScratch += (2 * nCell * nBatch);
104 *cellScratchDecoder += (0 * nCell * nBatch);
105 *forgetGateScratchDecoder += (1 * nCell * nBatch);
106 *outputGateScratchDecoder += (2 * nCell * nBatch);
110 *inputGateScratch += (0 * nCell * nBatch);
111 *cellScratch += (1 * nCell * nBatch);
112 *forgetGateScratch += (2 * nCell * nBatch);
113 *outputGateScratch += (3 * nCell * nBatch);
115 *inputGateScratchDecoder += (0 * nCell * nBatch);
116 *cellScratchDecoder += (1 * nCell * nBatch);
117 *forgetGateScratchDecoder += (2 * nCell * nBatch);
118 *outputGateScratchDecoder += (3 * nCell * nBatch);
121 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
122 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
123 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<
void>());
124 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
125 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<
void>());
126 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
127 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<
void>());
129 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
130 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
131 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<
void>());
132 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
133 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<
void>());
134 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
135 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<
void>());
137 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
138 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
139 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<
void>());
140 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
141 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<
void>());
142 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
143 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<
void>());
145 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
146 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
147 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
149 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
150 std::unique_ptr<Decoder<float>> projectionBiasTensor;
152 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
153 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
154 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
155 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
157 const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
158 const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
164 inputLayerNormWeights = MakeDecoder<float>(
165 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<
void>());
167 forgetLayerNormWeights = MakeDecoder<float>(
168 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<
void>());
169 cellLayerNormWeights = MakeDecoder<float>(
170 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<
void>());
171 outputLayerNormWeights = MakeDecoder<float>(
172 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<
void>());
177 inputToInputWeightsTensor = MakeDecoder<float>(
178 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<
void>());
179 inputGateBiasTensor = MakeDecoder<float>(
180 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<
void>());
181 recurrentToInputWeightsTensor = MakeDecoder<float>(
182 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<
void>());
187 cellToForgetWeightsTensor = MakeDecoder<float>(
188 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<
void>());
189 cellToOutputWeightsTensor = MakeDecoder<float>(
190 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<
void>());
193 if (!useCifg && usePeephole)
195 cellToInputWeightsTensor = MakeDecoder<float>(
196 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<
void>());
201 projectionWeightsTensor = MakeDecoder<float>(
202 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<
void>());
203 if (m_ProjectionBiasTensor)
205 projectionBiasTensor = MakeDecoder<float>(
206 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<
void>());
213 inputToOutputWeightsShape,
214 recurrentToOutputWeightsShape,
223 inputToInputWeightsTensor,
224 inputToForgetWeightsTensor,
225 inputToCellWeightsTensor,
226 inputToOutputWeightsTensor,
227 recurrentToInputWeightsTensor,
228 recurrentToForgetWeightsTensor,
229 recurrentToCellWeightsTensor,
230 recurrentToOutputWeightsTensor,
231 cellToInputWeightsTensor,
232 cellToForgetWeightsTensor,
233 cellToOutputWeightsTensor,
235 forgetGateBiasTensor,
237 outputGateBiasTensor,
238 projectionWeightsTensor,
239 projectionBiasTensor,
240 inputLayerNormWeights,
241 forgetLayerNormWeights,
242 cellLayerNormWeights,
243 outputLayerNormWeights,
248 inputGateScratchDecoder,
250 forgetGateScratchDecoder,
251 outputGateScratchDecoder,