59 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->
Map());
60 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, outputs[2]->
Map());
61 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, outputs[3]->
Map());
63 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, outputs[2]->
Map());
64 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, outputs[3]->
Map());
66 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, inputs[0]->
Map());
67 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, inputs[1]->
Map());
68 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, inputs[2]->
Map());
70 const uint32_t nBatch = inputShape[0];
71 const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
73 const bool useCifg =
m_Data.m_Parameters.m_CifgEnabled;
74 const bool usePeephole =
m_Data.m_Parameters.m_PeepholeEnabled;
75 const bool useLayerNorm =
m_Data.m_Parameters.m_LayerNormEnabled;
78 std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->
Map());
79 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, outputs[0]->
Map());
80 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->
Map());
81 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->
Map());
83 std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
84 MakeDecoder<float>(outputInfo, outputs[0]->
Map());
85 std::unique_ptr<Decoder<float>> cellScratchDecoder =
86 MakeDecoder<float>(outputInfo, outputs[0]->
Map());
87 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
88 MakeDecoder<float>(outputInfo, outputs[0]->
Map());
89 std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
90 MakeDecoder<float>(outputInfo, outputs[0]->
Map());
94 *cellScratch += (0 * nCell * nBatch);
95 *forgetGateScratch += (1 * nCell * nBatch);
96 *outputGateScratch += (2 * nCell * nBatch);
98 *cellScratchDecoder += (0 * nCell * nBatch);
99 *forgetGateScratchDecoder += (1 * nCell * nBatch);
100 *outputGateScratchDecoder += (2 * nCell * nBatch);
104 *inputGateScratch += (0 * nCell * nBatch);
105 *cellScratch += (1 * nCell * nBatch);
106 *forgetGateScratch += (2 * nCell * nBatch);
107 *outputGateScratch += (3 * nCell * nBatch);
109 *inputGateScratchDecoder += (0 * nCell * nBatch);
110 *cellScratchDecoder += (1 * nCell * nBatch);
111 *forgetGateScratchDecoder += (2 * nCell * nBatch);
112 *outputGateScratchDecoder += (3 * nCell * nBatch);
115 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
116 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
117 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<
void>());
118 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
119 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<
void>());
120 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
121 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<
void>());
123 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
124 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
125 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<
void>());
126 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
127 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<
void>());
128 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
129 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<
void>());
131 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
132 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
133 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<
void>());
134 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
135 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<
void>());
136 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
137 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<
void>());
139 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
140 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
141 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
143 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
144 std::unique_ptr<Decoder<float>> projectionBiasTensor;
146 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
147 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
148 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
149 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
151 const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
152 const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
158 inputLayerNormWeights = MakeDecoder<float>(
159 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<
void>());
161 forgetLayerNormWeights = MakeDecoder<float>(
162 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<
void>());
163 cellLayerNormWeights = MakeDecoder<float>(
164 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<
void>());
165 outputLayerNormWeights = MakeDecoder<float>(
166 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<
void>());
171 inputToInputWeightsTensor = MakeDecoder<float>(
172 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<
void>());
173 inputGateBiasTensor = MakeDecoder<float>(
174 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<
void>());
175 recurrentToInputWeightsTensor = MakeDecoder<float>(
176 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<
void>());
181 cellToForgetWeightsTensor = MakeDecoder<float>(
182 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<
void>());
183 cellToOutputWeightsTensor = MakeDecoder<float>(
184 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<
void>());
187 if (!useCifg && usePeephole)
189 cellToInputWeightsTensor = MakeDecoder<float>(
190 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<
void>());
193 if (
m_Data.m_Parameters.m_ProjectionEnabled)
195 projectionWeightsTensor = MakeDecoder<float>(
196 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<
void>());
197 if (m_ProjectionBiasTensor)
199 projectionBiasTensor = MakeDecoder<float>(
200 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<
void>());
207 inputToOutputWeightsShape,
208 recurrentToOutputWeightsShape,
217 inputToInputWeightsTensor,
218 inputToForgetWeightsTensor,
219 inputToCellWeightsTensor,
220 inputToOutputWeightsTensor,
221 recurrentToInputWeightsTensor,
222 recurrentToForgetWeightsTensor,
223 recurrentToCellWeightsTensor,
224 recurrentToOutputWeightsTensor,
225 cellToInputWeightsTensor,
226 cellToForgetWeightsTensor,
227 cellToOutputWeightsTensor,
229 forgetGateBiasTensor,
231 outputGateBiasTensor,
232 projectionWeightsTensor,
233 projectionBiasTensor,
234 inputLayerNormWeights,
235 forgetLayerNormWeights,
236 cellLayerNormWeights,
237 outputLayerNormWeights,
242 inputGateScratchDecoder,
244 forgetGateScratchDecoder,
245 outputGateScratchDecoder,
std::unique_ptr< armnn::ScopedTensorHandle > AssignScopedTensorHandle(const armnn::ConstTensorHandle *ptr)
#define ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
void Execute() const override
const TensorShape & GetShape() const
Copyright (c) 2021 ARM Limited and Contributors.
void LstmImpl(const LstmDescriptor &descriptor, const TensorInfo &inputInfo, const TensorInfo &outputInfo, const TensorShape &inputToOutputWeightsShape, const TensorShape &recurrentToOutputWeightsShape, std::unique_ptr< Decoder< float >> &inputData, std::unique_ptr< Decoder< float >> &outputStateIn, std::unique_ptr< Decoder< float >> &cellStateIn, std::unique_ptr< Encoder< float >> &outputStateOut, std::unique_ptr< Encoder< float >> &cellStateOut, std::unique_ptr< Encoder< float >> &output, std::unique_ptr< Decoder< float >> &cellStateOutDecoder, std::unique_ptr< Decoder< float >> &outputDecoder, std::unique_ptr< Decoder< float >> &inputToInputWeightsTensor, std::unique_ptr< Decoder< float >> &inputToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &inputToCellWeightsTensor, std::unique_ptr< Decoder< float >> &inputToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToInputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToCellWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToInputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &cellToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &inputGateBiasTensor, std::unique_ptr< Decoder< float >> &forgetGateBiasTensor, std::unique_ptr< Decoder< float >> &cellBiasTensor, std::unique_ptr< Decoder< float >> &outputGateBiasTensor, std::unique_ptr< Decoder< float >> &projectionWeightsTensor, std::unique_ptr< Decoder< float >> &projectionBiasTensor, std::unique_ptr< Decoder< float >> &inputLayerNormWeights, std::unique_ptr< Decoder< float >> &forgetLayerNormWeights, std::unique_ptr< Decoder< float >> &cellLayerNormWeights, std::unique_ptr< Decoder< float >> &outputLayerNormWeights, std::unique_ptr< Encoder< float >> &inputGateScratch, std::unique_ptr< Encoder< float >> &cellScratch, std::unique_ptr< Encoder< float >> &forgetGateScratch, std::unique_ptr< Encoder< float >> &outputGateScratch, std::unique_ptr< Decoder< float >> &inputGateScratchDecoder, std::unique_ptr< Decoder< float >> &cellScratchDecoder, std::unique_ptr< Decoder< float >> &forgetGateScratchDecoder, std::unique_ptr< Decoder< float >> &outputGateScratchDecoder, float layerNormEpsilon)
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers
std::vector< ITensorHandle * > m_Inputs
std::vector< ITensorHandle * > m_Outputs
Contains information about TensorInfos of a layer.