// Declaration of LstmImpl. Only the prototype is visible here; the definition
// lives elsewhere, so everything below describes the signature, not the
// implementation.
//
// The function returns void, so any results have to be delivered through the
// reference parameters (presumably via the Encoder<float> arguments -- confirm
// against the definition).
//
// Every Decoder<float> / Encoder<float> argument is taken as a non-const
// lvalue reference to a std::unique_ptr. Consequences visible from the
// signature alone:
//   - callers must pass named std::unique_ptr objects (temporaries will not
//     bind);
//   - the signature does not by itself transfer ownership, but it also does
//     not prevent the callee from resetting or reseating a pointer.
//
// NOTE(review): several tensors look optional by name (cell-to-* weights,
// projection weights/bias, layer-norm weights); whether a null unique_ptr is
// accepted for them, and which descriptor fields control that, cannot be
// determined from this declaration -- confirm against the definition.
17 void LstmImpl(
// Layer configuration; passed by const reference, so read-only for the callee.
const LstmDescriptor& descriptor,
// Tensor metadata for the input and output, and shapes of two of the weight
// tensors (input-to-output and recurrent-to-output, going by the names).
18 const TensorInfo& inputInfo,
19 const TensorInfo& outputInfo,
20 const TensorShape& inputToOutputWeightsShape,
21 const TensorShape& recurrentToOutputWeightsShape,
// Readers for the input and for the incoming output/cell state (by name).
22 std::unique_ptr<Decoder<float>>& inputData,
23 std::unique_ptr<Decoder<float>>& outputStateIn,
24 std::unique_ptr<Decoder<float>>& cellStateIn,
// Writers for the outgoing output state, cell state and the layer output.
25 std::unique_ptr<Encoder<float>>& outputStateOut,
26 std::unique_ptr<Encoder<float>>& cellStateOut,
27 std::unique_ptr<Encoder<float>>& output,
// Decoder counterparts named after cellStateOut / output above.
// NOTE(review): presumably these read back the same storage the matching
// encoders write to -- verify against the caller.
28 std::unique_ptr<Decoder<float>>& cellStateOutDecoder,
29 std::unique_ptr<Decoder<float>>& outputDecoder,
// Input-to-gate weight tensors (input, forget, cell, output -- by name).
30 std::unique_ptr<Decoder<float>>& inputToInputWeightsTensor,
31 std::unique_ptr<Decoder<float>>& inputToForgetWeightsTensor,
32 std::unique_ptr<Decoder<float>>& inputToCellWeightsTensor,
33 std::unique_ptr<Decoder<float>>& inputToOutputWeightsTensor,
// Recurrent-to-gate weight tensors.
34 std::unique_ptr<Decoder<float>>& recurrentToInputWeightsTensor,
35 std::unique_ptr<Decoder<float>>& recurrentToForgetWeightsTensor,
36 std::unique_ptr<Decoder<float>>& recurrentToCellWeightsTensor,
37 std::unique_ptr<Decoder<float>>& recurrentToOutputWeightsTensor,
// Cell-to-gate weight tensors (note: no cell-to-cell entry in this list).
38 std::unique_ptr<Decoder<float>>& cellToInputWeightsTensor,
39 std::unique_ptr<Decoder<float>>& cellToForgetWeightsTensor,
40 std::unique_ptr<Decoder<float>>& cellToOutputWeightsTensor,
// Per-gate bias tensors.
41 std::unique_ptr<Decoder<float>>& inputGateBiasTensor,
42 std::unique_ptr<Decoder<float>>& forgetGateBiasTensor,
43 std::unique_ptr<Decoder<float>>& cellBiasTensor,
44 std::unique_ptr<Decoder<float>>& outputGateBiasTensor,
// Projection weights and bias.
45 std::unique_ptr<Decoder<float>>& projectionWeightsTensor,
46 std::unique_ptr<Decoder<float>>& projectionBiasTensor,
// Per-gate layer-normalisation weights.
47 std::unique_ptr<Decoder<float>>& inputLayerNormWeights,
48 std::unique_ptr<Decoder<float>>& forgetLayerNormWeights,
49 std::unique_ptr<Decoder<float>>& cellLayerNormWeights,
50 std::unique_ptr<Decoder<float>>& outputLayerNormWeights,
// Scratch buffers, one per gate, exposed as writers...
51 std::unique_ptr<Encoder<float>>& inputGateScratch,
52 std::unique_ptr<Encoder<float>>& cellScratch,
53 std::unique_ptr<Encoder<float>>& forgetGateScratch,
54 std::unique_ptr<Encoder<float>>& outputGateScratch,
// ...and as readers with matching names.
// NOTE(review): presumably each decoder views the same scratch storage as the
// like-named encoder -- confirm with the caller that sets these up.
55 std::unique_ptr<Decoder<float>>& inputGateScratchDecoder,
56 std::unique_ptr<Decoder<float>>& cellScratchDecoder,
57 std::unique_ptr<Decoder<float>>& forgetGateScratchDecoder,
58 std::unique_ptr<Decoder<float>>& outputGateScratchDecoder,
// Epsilon passed by value; by name it is used in layer normalisation
// (likely as a numerical-stability term -- confirm in the definition).
59 float layerNormEpsilon);