63 std::unique_ptr<Decoder<float>> cellStateOutDecoder =
MakeDecoder<float>(outputInfo, outputs[2]->
Map());
70 const uint32_t nBatch = inputShape[0];
71 const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
80 std::unique_ptr<Encoder<float>> forgetGateScratch =
MakeEncoder<float>(outputInfo, outputs[0]->
Map());
81 std::unique_ptr<Encoder<float>> outputGateScratch =
MakeEncoder<float>(outputInfo, outputs[0]->
Map());
83 std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
85 std::unique_ptr<Decoder<float>> cellScratchDecoder =
87 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
89 std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
94 *cellScratch += (0 * nCell * nBatch);
95 *forgetGateScratch += (1 * nCell * nBatch);
96 *outputGateScratch += (2 * nCell * nBatch);
98 *cellScratchDecoder += (0 * nCell * nBatch);
99 *forgetGateScratchDecoder += (1 * nCell * nBatch);
100 *outputGateScratchDecoder += (2 * nCell * nBatch);
104 *inputGateScratch += (0 * nCell * nBatch);
105 *cellScratch += (1 * nCell * nBatch);
106 *forgetGateScratch += (2 * nCell * nBatch);
107 *outputGateScratch += (3 * nCell * nBatch);
109 *inputGateScratchDecoder += (0 * nCell * nBatch);
110 *cellScratchDecoder += (1 * nCell * nBatch);
111 *forgetGateScratchDecoder += (2 * nCell * nBatch);
112 *outputGateScratchDecoder += (3 * nCell * nBatch);
115 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
117 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<
void>());
119 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<
void>());
121 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<
void>());
123 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
125 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<
void>());
127 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<
void>());
129 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<
void>());
131 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
133 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<
void>());
135 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<
void>());
137 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<
void>());
139 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
140 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
141 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
143 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
144 std::unique_ptr<Decoder<float>> projectionBiasTensor;
146 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
147 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
148 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
149 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
151 const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
152 const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
159 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<
void>());
162 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<
void>());
164 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<
void>());
166 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<
void>());
172 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<
void>());
174 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<
void>());
176 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<
void>());
182 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<
void>());
184 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<
void>());
187 if (!useCifg && usePeephole)
190 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<
void>());
193 if (
m_Data.m_Parameters.m_ProjectionEnabled)
196 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<
void>());
197 if (m_ProjectionBiasTensor)
200 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<
void>());
207 inputToOutputWeightsShape,
208 recurrentToOutputWeightsShape,
217 inputToInputWeightsTensor,
218 inputToForgetWeightsTensor,
219 inputToCellWeightsTensor,
220 inputToOutputWeightsTensor,
221 recurrentToInputWeightsTensor,
222 recurrentToForgetWeightsTensor,
223 recurrentToCellWeightsTensor,
224 recurrentToOutputWeightsTensor,
225 cellToInputWeightsTensor,
226 cellToForgetWeightsTensor,
227 cellToOutputWeightsTensor,
229 forgetGateBiasTensor,
231 outputGateBiasTensor,
232 projectionWeightsTensor,
233 projectionBiasTensor,
234 inputLayerNormWeights,
235 forgetLayerNormWeights,
236 cellLayerNormWeights,
237 outputLayerNormWeights,
242 inputGateScratchDecoder,
244 forgetGateScratchDecoder,
245 outputGateScratchDecoder,
std::unique_ptr< armnn::ScopedTensorHandle > AssignScopedTensorHandle(const armnn::ConstTensorHandle *ptr)
#define ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
LstmQueueDescriptor m_Data
RefBaseWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
void Execute() const override
const TensorShape & GetShape() const
Copyright (c) 2021 ARM Limited and Contributors.
void LstmImpl(const LstmDescriptor &descriptor, const TensorInfo &inputInfo, const TensorInfo &outputInfo, const TensorShape &inputToOutputWeightsShape, const TensorShape &recurrentToOutputWeightsShape, std::unique_ptr< Decoder< float > > &inputData, std::unique_ptr< Decoder< float > > &outputStateIn, std::unique_ptr< Decoder< float > > &cellStateIn, std::unique_ptr< Encoder< float > > &outputStateOut, std::unique_ptr< Encoder< float > > &cellStateOut, std::unique_ptr< Encoder< float > > &output, std::unique_ptr< Decoder< float > > &cellStateOutDecoder, std::unique_ptr< Decoder< float > > &outputDecoder, std::unique_ptr< Decoder< float > > &inputToInputWeightsTensor, std::unique_ptr< Decoder< float > > &inputToForgetWeightsTensor, std::unique_ptr< Decoder< float > > &inputToCellWeightsTensor, std::unique_ptr< Decoder< float > > &inputToOutputWeightsTensor, std::unique_ptr< Decoder< float > > &recurrentToInputWeightsTensor, std::unique_ptr< Decoder< float > > &recurrentToForgetWeightsTensor, std::unique_ptr< Decoder< float > > &recurrentToCellWeightsTensor, std::unique_ptr< Decoder< float > > &recurrentToOutputWeightsTensor, std::unique_ptr< Decoder< float > > &cellToInputWeightsTensor, std::unique_ptr< Decoder< float > > &cellToForgetWeightsTensor, std::unique_ptr< Decoder< float > > &cellToOutputWeightsTensor, std::unique_ptr< Decoder< float > > &inputGateBiasTensor, std::unique_ptr< Decoder< float > > &forgetGateBiasTensor, std::unique_ptr< Decoder< float > > &cellBiasTensor, std::unique_ptr< Decoder< float > > &outputGateBiasTensor, std::unique_ptr< Decoder< float > > &projectionWeightsTensor, std::unique_ptr< Decoder< float > > &projectionBiasTensor, std::unique_ptr< Decoder< float > > &inputLayerNormWeights, std::unique_ptr< Decoder< float > > &forgetLayerNormWeights, std::unique_ptr< Decoder< float > > &cellLayerNormWeights, std::unique_ptr< Decoder< float > > &outputLayerNormWeights, std::unique_ptr< Encoder< float > > &inputGateScratch, std::unique_ptr< Encoder< float > > &cellScratch, std::unique_ptr< Encoder< float > > &forgetGateScratch, std::unique_ptr< Encoder< float > > &outputGateScratch, std::unique_ptr< Decoder< float > > &inputGateScratchDecoder, std::unique_ptr< Decoder< float > > &cellScratchDecoder, std::unique_ptr< Decoder< float > > &forgetGateScratchDecoder, std::unique_ptr< Decoder< float > > &outputGateScratchDecoder, float layerNormEpsilon)
std::unique_ptr< Decoder< T > > MakeDecoder(const TensorInfo &info, const void *data=nullptr)
std::unique_ptr< Encoder< T > > MakeEncoder(const TensorInfo &info, void *data=nullptr)
armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches, unsigned int numberOfChannels, unsigned int height, unsigned int width, const armnn::DataLayout dataLayout, const armnn::DataType dataType)
bool m_PeepholeEnabled
Enable/disable peephole.
bool m_LayerNormEnabled
Enable/disable layer normalization.
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
LayerDescriptor m_Parameters
Contains information about TensorInfos of a layer.