15 using namespace armcomputetensorutils;
19 const arm_compute::CLCompileContext& clCompileContext)
28 arm_compute::LSTMParams<arm_compute::ICLTensor> qLstmParams;
31 m_InputToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
32 BuildArmComputeTensor(*m_InputToForgetWeightsTensor,
m_Data.m_InputToForgetWeights->GetTensorInfo());
34 m_InputToCellWeightsTensor = std::make_unique<arm_compute::CLTensor>();
35 BuildArmComputeTensor(*m_InputToCellWeightsTensor,
m_Data.m_InputToCellWeights->GetTensorInfo());
37 m_InputToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
38 BuildArmComputeTensor(*m_InputToOutputWeightsTensor,
m_Data.m_InputToOutputWeights->GetTensorInfo());
40 m_RecurrentToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
41 BuildArmComputeTensor(*m_RecurrentToForgetWeightsTensor,
m_Data.m_RecurrentToForgetWeights->GetTensorInfo());
43 m_RecurrentToCellWeightsTensor = std::make_unique<arm_compute::CLTensor>();
44 BuildArmComputeTensor(*m_RecurrentToCellWeightsTensor,
m_Data.m_RecurrentToCellWeights->GetTensorInfo());
46 m_RecurrentToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
47 BuildArmComputeTensor(*m_RecurrentToOutputWeightsTensor,
m_Data.m_RecurrentToOutputWeights->GetTensorInfo());
49 m_ForgetGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
50 BuildArmComputeTensor(*m_ForgetGateBiasTensor,
m_Data.m_ForgetGateBias->GetTensorInfo());
52 m_CellBiasTensor = std::make_unique<arm_compute::CLTensor>();
53 BuildArmComputeTensor(*m_CellBiasTensor,
m_Data.m_CellBias->GetTensorInfo());
55 m_OutputGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
56 BuildArmComputeTensor(*m_OutputGateBiasTensor,
m_Data.m_OutputGateBias->GetTensorInfo());
59 if (
m_Data.m_Parameters.m_PeepholeEnabled)
61 m_CellToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
63 if (!
m_Data.m_Parameters.m_CifgEnabled)
66 BuildArmComputeTensor(*m_CellToInputWeightsTensor,
m_Data.m_CellToInputWeights->GetTensorInfo());
69 m_CellToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
70 BuildArmComputeTensor(*m_CellToForgetWeightsTensor,
m_Data.m_CellToForgetWeights->GetTensorInfo());
72 m_CellToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
73 BuildArmComputeTensor(*m_CellToOutputWeightsTensor,
m_Data.m_CellToOutputWeights->GetTensorInfo());
76 qLstmParams.set_peephole_params(m_CellToForgetWeightsTensor.get(),
77 m_CellToOutputWeightsTensor.get());
80 if (
m_Data.m_Parameters.m_ProjectionEnabled)
82 m_ProjectionWeightsTensor = std::make_unique<arm_compute::CLTensor>();
83 BuildArmComputeTensor(*m_ProjectionWeightsTensor,
m_Data.m_ProjectionWeights->GetTensorInfo());
85 m_ProjectionBiasTensor = std::make_unique<arm_compute::CLTensor>();
86 if (
m_Data.m_ProjectionBias !=
nullptr)
88 BuildArmComputeTensor(*m_ProjectionBiasTensor,
m_Data.m_ProjectionBias->GetTensorInfo());
92 qLstmParams.set_projection_params(
93 m_ProjectionWeightsTensor.get(),
94 m_Data.m_ProjectionBias !=
nullptr ? m_ProjectionBiasTensor.get() :
nullptr);
97 if (
m_Data.m_Parameters.m_LayerNormEnabled)
99 m_InputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
101 if (!
m_Data.m_Parameters.m_CifgEnabled)
103 BuildArmComputeTensor(*m_InputLayerNormWeightsTensor,
m_Data.m_InputLayerNormWeights->GetTensorInfo());
106 m_ForgetLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
107 BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor,
m_Data.m_ForgetLayerNormWeights->GetTensorInfo());
109 m_CellLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
110 BuildArmComputeTensor(*m_CellLayerNormWeightsTensor,
m_Data.m_CellLayerNormWeights->GetTensorInfo());
112 m_OutputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
113 BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor,
m_Data.m_OutputLayerNormWeights->GetTensorInfo());
116 qLstmParams.set_layer_normalization_params(
117 m_Data.m_InputLayerNormWeights !=
nullptr ? m_InputLayerNormWeightsTensor.get() :
nullptr,
118 m_ForgetLayerNormWeightsTensor.get(),
119 m_CellLayerNormWeightsTensor.get(),
120 m_OutputLayerNormWeightsTensor.get());
123 if (!
m_Data.m_Parameters.m_CifgEnabled)
125 m_InputToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
126 BuildArmComputeTensor(*m_InputToInputWeightsTensor,
m_Data.m_InputToInputWeights->GetTensorInfo());
128 m_RecurrentToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
129 BuildArmComputeTensor(*m_RecurrentToInputWeightsTensor,
m_Data.m_RecurrentToInputWeights->GetTensorInfo());
131 m_InputGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
132 BuildArmComputeTensor(*m_InputGateBiasTensor,
m_Data.m_InputGateBias->GetTensorInfo());
135 qLstmParams.set_cifg_params(
136 m_InputToInputWeightsTensor.get(),
137 m_RecurrentToInputWeightsTensor.get(),
138 m_Data.m_CellToInputWeights !=
nullptr ? m_CellToInputWeightsTensor.get() :
nullptr,
139 m_InputGateBiasTensor.get());
152 qLstmParams.set_cell_clip_params(
m_Data.m_Parameters.m_CellClip);
153 qLstmParams.set_projection_clip_params(
m_Data.m_Parameters.m_ProjectionClip);
154 qLstmParams.set_hidden_state_params(
m_Data.m_Parameters.m_HiddenStateZeroPoint,
155 m_Data.m_Parameters.m_HiddenStateScale);
156 qLstmParams.set_matmul_scale_params(
m_Data.m_Parameters.m_InputIntermediateScale,
157 m_Data.m_Parameters.m_ForgetIntermediateScale,
158 m_Data.m_Parameters.m_CellIntermediateScale,
159 m_Data.m_Parameters.m_OutputIntermediateScale);
164 m_QLstmLayer.configure(clCompileContext,
166 m_InputToForgetWeightsTensor.get(),
167 m_InputToCellWeightsTensor.get(),
168 m_InputToOutputWeightsTensor.get(),
169 m_RecurrentToForgetWeightsTensor.get(),
170 m_RecurrentToCellWeightsTensor.get(),
171 m_RecurrentToOutputWeightsTensor.get(),
172 m_ForgetGateBiasTensor.get(),
173 m_CellBiasTensor.get(),
174 m_OutputGateBiasTensor.get(),
197 if (!
m_Data.m_Parameters.m_CifgEnabled)
204 if (
m_Data.m_Parameters.m_ProjectionEnabled)
208 if (
m_Data.m_ProjectionBias !=
nullptr)
214 if (
m_Data.m_Parameters.m_PeepholeEnabled)
216 if (!
m_Data.m_Parameters.m_CifgEnabled)
225 if (
m_Data.m_Parameters.m_LayerNormEnabled)
227 if (!
m_Data.m_Parameters.m_CifgEnabled)
236 m_QLstmLayer.prepare();
256 arm_compute::LSTMParams<arm_compute::ITensorInfo> aclParamsInfo;
259 const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input);
260 const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);
261 const arm_compute::TensorInfo aclCellStateInInfo = BuildArmComputeTensorInfo(cellStateIn);
263 const arm_compute::TensorInfo aclOutputStateOutInfo = BuildArmComputeTensorInfo(outputStateOut);
264 const arm_compute::TensorInfo aclCellStateOutInfo = BuildArmComputeTensorInfo(cellStateOut);
265 const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
268 const arm_compute::TensorInfo aclInputToForgetWeightsInfo
270 const arm_compute::TensorInfo aclInputToCellWeightsInfo
272 const arm_compute::TensorInfo aclInputToOutputWeightsInfo
274 const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
276 const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
278 const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
280 const arm_compute::TensorInfo aclForgetGateBiasInfo
282 const arm_compute::TensorInfo aclCellBiasInfo
283 = BuildArmComputeTensorInfo(paramsInfo.
GetCellBias());
284 const arm_compute::TensorInfo aclOutputGateBiasInfo
288 arm_compute::TensorInfo aclInputToInputWeightsInfo;
289 arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
291 arm_compute::TensorInfo aclCellToInputWeightsInfo;
292 arm_compute::TensorInfo aclCellToForgetWeightsInfo;
293 arm_compute::TensorInfo aclCellToOutputWeightsInfo;
295 arm_compute::TensorInfo aclInputGateBiasInfo;
297 arm_compute::TensorInfo aclProjectionWeightsInfo;
298 arm_compute::TensorInfo aclProjectionBiasInfo;
300 arm_compute::TensorInfo aclInputLayerNormWeightsInfo;
301 arm_compute::TensorInfo aclForgetLayerNormWeightsInfo;
302 arm_compute::TensorInfo aclCellLayerNormWeightsInfo;
303 arm_compute::TensorInfo aclOutputLayerNormWeightsInfo;
317 aclParamsInfo.set_peephole_params(&aclCellToForgetWeightsInfo,
318 &aclCellToOutputWeightsInfo);
327 aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.
GetProjectionBias());
331 aclParamsInfo.set_projection_params(
332 &aclProjectionWeightsInfo,
348 aclParamsInfo.set_layer_normalization_params(
350 &aclForgetLayerNormWeightsInfo,
351 &aclCellLayerNormWeightsInfo,
352 &aclOutputLayerNormWeightsInfo);
359 aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.
GetInputGateBias());
362 aclParamsInfo.set_cifg_params(
363 &aclInputToInputWeightsInfo,
364 &aclRecurrentToInputWeightsInfo,
366 &aclInputGateBiasInfo);
370 aclParamsInfo.set_cell_clip_params(descriptor.
m_CellClip);
379 return arm_compute::CLQLSTMLayer::validate(&aclInputInfo,
380 &aclInputToForgetWeightsInfo,
381 &aclInputToCellWeightsInfo,
382 &aclInputToOutputWeightsInfo,
383 &aclRecurrentToForgetWeightsInfo,
384 &aclRecurrentToCellWeightsInfo,
385 &aclRecurrentToOutputWeightsInfo,
386 &aclForgetGateBiasInfo,
388 &aclOutputGateBiasInfo,
390 &aclOutputStateInInfo,
391 &aclCellStateOutInfo,
392 &aclOutputStateOutInfo,
// Passes every CL weight/bias tensor owned by this workload to FreeTensorIfUnused.
// The constructor creates these as std::make_unique<arm_compute::CLTensor>() members.
// NOTE(review): FreeTensorIfUnused is not defined in this view. Presumably it releases a tensor
// only when ACL no longer needs it (e.g. after m_QLstmLayer.prepare() has consumed the constant
// weights) — confirm against its definition.
// Several members are only allocated conditionally in the constructor:
//   - input-gate tensors only when CIFG is disabled;
//   - cell-to-gate (peephole) tensors only when peephole is enabled;
//   - projection tensors only when projection is enabled;
//   - layer-norm tensors only when layer norm is enabled.
// So some of them may still be null here. The helper is assumed to tolerate null — TODO confirm.
397 void ClQLstmWorkload::FreeUnusedTensors()
// Input-to-gate weights (input gate first; it is absent when CIFG is enabled).
399 FreeTensorIfUnused(m_InputToInputWeightsTensor);
400 FreeTensorIfUnused(m_InputToForgetWeightsTensor);
401 FreeTensorIfUnused(m_InputToCellWeightsTensor);
402 FreeTensorIfUnused(m_InputToOutputWeightsTensor);
// Recurrent-to-gate weights.
404 FreeTensorIfUnused(m_RecurrentToInputWeightsTensor);
405 FreeTensorIfUnused(m_RecurrentToForgetWeightsTensor);
406 FreeTensorIfUnused(m_RecurrentToCellWeightsTensor);
407 FreeTensorIfUnused(m_RecurrentToOutputWeightsTensor);
// Peephole (cell-to-gate) weights; only allocated when peephole is enabled.
409 FreeTensorIfUnused(m_CellToInputWeightsTensor);
410 FreeTensorIfUnused(m_CellToForgetWeightsTensor);
411 FreeTensorIfUnused(m_CellToOutputWeightsTensor);
// Gate biases.
413 FreeTensorIfUnused(m_InputGateBiasTensor);
414 FreeTensorIfUnused(m_ForgetGateBiasTensor);
415 FreeTensorIfUnused(m_CellBiasTensor);
416 FreeTensorIfUnused(m_OutputGateBiasTensor);
// Projection weights/bias; only allocated when projection is enabled.
418 FreeTensorIfUnused(m_ProjectionWeightsTensor);
419 FreeTensorIfUnused(m_ProjectionBiasTensor);
// Layer-normalization weights; only allocated when layer norm is enabled.
421 FreeTensorIfUnused(m_InputLayerNormWeightsTensor);
422 FreeTensorIfUnused(m_ForgetLayerNormWeightsTensor);
423 FreeTensorIfUnused(m_CellLayerNormWeightsTensor);
424 FreeTensorIfUnused(m_OutputLayerNormWeightsTensor);
#define ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
#define ARMNN_REPORT_PROFILING_WORKLOAD_DESC(name, desc, infos, guid)
ClQLstmWorkload(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info, const arm_compute::CLCompileContext &clCompileContext)
virtual void Execute() const override
Copyright (c) 2021 ARM Limited and Contributors.
void InitializeArmComputeClTensorData(arm_compute::CLTensor &clTensor, const ConstTensorHandle *handle)
arm_compute::Status ClQLstmWorkloadValidate(const TensorInfo &input, const TensorInfo &cellStateIn, const TensorInfo &outputStateIn, const TensorInfo &cellStateOut, const TensorInfo &outputStateOut, const TensorInfo &output, const QLstmDescriptor &descriptor, const LstmInputParamsInfo &paramsInfo)
A QLstmDescriptor for the QLstmLayer.
float m_CellIntermediateScale
Cell intermediate quantization scale.
float m_InputIntermediateScale
Input intermediate quantization scale.
bool m_PeepholeEnabled
Enable/disable peephole.
int32_t m_HiddenStateZeroPoint
Hidden State zero point.
bool m_LayerNormEnabled
Enable/disable layer normalization.
bool m_ProjectionEnabled
Enable/disable the projection layer.
float m_OutputIntermediateScale
Output intermediate quantization scale.
float m_ProjectionClip
Clipping threshold value for the projection.
float m_CellClip
Clipping threshold value for the cell state.
bool m_CifgEnabled
Enable/disable CIFG (coupled input & forget gate).
float m_HiddenStateScale
Hidden State quantization scale.
float m_ForgetIntermediateScale
Forget intermediate quantization scale.
std::vector< ITensorHandle * > m_Inputs
std::vector< ITensorHandle * > m_Outputs
LayerDescriptor m_Parameters
Contains information about TensorInfos of a layer.