ArmNN
 25.11
Loading...
Searching...
No Matches
ClQLstmWorkload Class Reference

#include <ClQLstmWorkload.hpp>

Inheritance diagram for ClQLstmWorkload:
[legend]
Collaboration diagram for ClQLstmWorkload:
[legend]

Public Member Functions

 ClQLstmWorkload (const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info, const arm_compute::CLCompileContext &clCompileContext)
virtual void Execute () const override
Public Member Functions inherited from ClBaseWorkload< QLstmQueueDescriptor >
 ClBaseWorkload (const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info)
void ReplaceInputTensorHandle (ITensorHandle *tensorHandle, unsigned int slot) override
void ReplaceOutputTensorHandle (ITensorHandle *tensorHandle, unsigned int slot) override
Public Member Functions inherited from BaseWorkload< QLstmQueueDescriptor >
 BaseWorkload (const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info)
virtual const std::string & GetName () const override
void PostAllocationConfigure () override
const QLstmQueueDescriptor & GetData () const
arm::pipe::ProfilingGuid GetGuid () const final
virtual bool SupportsTensorHandleReplacement () const override
Public Member Functions inherited from IWorkload
virtual ~IWorkload ()
virtual void RegisterDebugCallback (const DebugCallbackFunction &)
virtual armnn::Optional< armnn::MemoryRequirements > GetMemoryRequirements ()

Additional Inherited Members

Protected Member Functions inherited from ClBaseWorkload< QLstmQueueDescriptor >
virtual void Reconfigure ()
Protected Attributes inherited from BaseWorkload< QLstmQueueDescriptor >
QLstmQueueDescriptor m_Data
const arm::pipe::ProfilingGuid m_Guid
const std::string m_Name

Detailed Description

Definition at line 19 of file ClQLstmWorkload.hpp.

Constructor & Destructor Documentation

◆ ClQLstmWorkload()

ClQLstmWorkload ( const QLstmQueueDescriptor & descriptor,
const WorkloadInfo & info,
const arm_compute::CLCompileContext & clCompileContext )

Definition at line 17 of file ClQLstmWorkload.cpp.

20 : ClBaseWorkload<QLstmQueueDescriptor>(descriptor, info)
21{ // Builds ACL tensors for the QLSTM weights/biases, configures m_QLstmLayer, copies in the constant data, then calls prepare() and FreeUnusedTensors().
22 // Report Profiling Details
23 ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClQLstmWorkload_Construct",
24 descriptor.m_Parameters,
25 info,
26 this->GetGuid());
27 // qLstmParams collects the optional-gate tensors and scalar descriptor values that are handed to configure() below.
28 arm_compute::LSTMParams<arm_compute::ICLTensor> qLstmParams;
29
30 // Mandatory params
31 m_InputToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
32 BuildArmComputeTensor(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights->GetTensorInfo());
33
34 m_InputToCellWeightsTensor = std::make_unique<arm_compute::CLTensor>();
35 BuildArmComputeTensor(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights->GetTensorInfo());
36
37 m_InputToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
38 BuildArmComputeTensor(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights->GetTensorInfo());
39
40 m_RecurrentToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
41 BuildArmComputeTensor(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights->GetTensorInfo());
42
43 m_RecurrentToCellWeightsTensor = std::make_unique<arm_compute::CLTensor>();
44 BuildArmComputeTensor(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights->GetTensorInfo());
45
46 m_RecurrentToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
47 BuildArmComputeTensor(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights->GetTensorInfo());
48
49 m_ForgetGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
50 BuildArmComputeTensor(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias->GetTensorInfo());
51
52 m_CellBiasTensor = std::make_unique<arm_compute::CLTensor>();
53 BuildArmComputeTensor(*m_CellBiasTensor, m_Data.m_CellBias->GetTensorInfo());
54
55 m_OutputGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
56 BuildArmComputeTensor(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias->GetTensorInfo());
57
58 // Create tensors for optional params if they are enabled
59 if (m_Data.m_Parameters.m_PeepholeEnabled)
60 {
61 m_CellToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
62
63 if (!m_Data.m_Parameters.m_CifgEnabled)
64 {
65 // In ACL this is categorised as a CIFG param and not a Peephole param
66 BuildArmComputeTensor(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights->GetTensorInfo());
67 }
68
69 m_CellToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
70 BuildArmComputeTensor(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights->GetTensorInfo());
71
72 m_CellToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
73 BuildArmComputeTensor(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights->GetTensorInfo());
74
75 // Set Peephole params
76 qLstmParams.set_peephole_params(m_CellToForgetWeightsTensor.get(),
77 m_CellToOutputWeightsTensor.get());
78 }
79
80 if (m_Data.m_Parameters.m_ProjectionEnabled)
81 {
82 m_ProjectionWeightsTensor = std::make_unique<arm_compute::CLTensor>();
83 BuildArmComputeTensor(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights->GetTensorInfo());
84 // The bias tensor object is always created, but it is only built and passed to ACL when m_ProjectionBias is set.
85 m_ProjectionBiasTensor = std::make_unique<arm_compute::CLTensor>();
86 if (m_Data.m_ProjectionBias != nullptr)
87 {
88 BuildArmComputeTensor(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias->GetTensorInfo());
89 }
90
91 // Set projection params
92 qLstmParams.set_projection_params(
93 m_ProjectionWeightsTensor.get(),
94 m_Data.m_ProjectionBias != nullptr ? m_ProjectionBiasTensor.get() : nullptr);
95 }
96
97 if (m_Data.m_Parameters.m_LayerNormEnabled)
98 {
99 m_InputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
100 // The input layer-norm weights are only built when CIFG is disabled (same condition as the input-gate tensors below).
101 if (!m_Data.m_Parameters.m_CifgEnabled)
102 {
103 BuildArmComputeTensor(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights->GetTensorInfo());
104 }
105
106 m_ForgetLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
107 BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights->GetTensorInfo());
108
109 m_CellLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
110 BuildArmComputeTensor(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights->GetTensorInfo());
111
112 m_OutputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
113 BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights->GetTensorInfo());
114 // NOTE(review): the first argument below is chosen by m_InputLayerNormWeights != nullptr, while the tensor is only built when CIFG is disabled; presumably validation keeps the two consistent - confirm.
115 // Set layer norm params
116 qLstmParams.set_layer_normalization_params(
117 m_Data.m_InputLayerNormWeights != nullptr ? m_InputLayerNormWeightsTensor.get() : nullptr,
118 m_ForgetLayerNormWeightsTensor.get(),
119 m_CellLayerNormWeightsTensor.get(),
120 m_OutputLayerNormWeightsTensor.get());
121 }
122 // Input-gate tensors exist only when CIFG is disabled. If peephole is disabled, m_CellToInputWeightsTensor is not assigned in this constructor.
123 if (!m_Data.m_Parameters.m_CifgEnabled)
124 {
125 m_InputToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
126 BuildArmComputeTensor(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights->GetTensorInfo());
127
128 m_RecurrentToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
129 BuildArmComputeTensor(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights->GetTensorInfo());
130
131 m_InputGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
132 BuildArmComputeTensor(*m_InputGateBiasTensor, m_Data.m_InputGateBias->GetTensorInfo());
133
134 // Set CIFG params
135 qLstmParams.set_cifg_params(
136 m_InputToInputWeightsTensor.get(),
137 m_RecurrentToInputWeightsTensor.get(),
138 m_Data.m_CellToInputWeights != nullptr ? m_CellToInputWeightsTensor.get() : nullptr,
139 m_InputGateBiasTensor.get());
140 }
141
142 // Input/Output tensors (slot order: inputs = input, outputStateIn, cellStateIn; outputs = outputStateOut, cellStateOut, output)
143 const arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
144 arm_compute::ICLTensor& outputStateIn = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
145 arm_compute::ICLTensor& cellStateIn = static_cast<IClTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
146
147 arm_compute::ICLTensor& outputStateOut = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
148 arm_compute::ICLTensor& cellStateOut = static_cast<IClTensorHandle*>(m_Data.m_Outputs[1])->GetTensor();
149 arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[2])->GetTensor();
150
151 // Set scalar descriptor params
152 qLstmParams.set_cell_clip_params(m_Data.m_Parameters.m_CellClip);
153 qLstmParams.set_projection_clip_params(m_Data.m_Parameters.m_ProjectionClip);
154 qLstmParams.set_hidden_state_params(m_Data.m_Parameters.m_HiddenStateZeroPoint,
155 m_Data.m_Parameters.m_HiddenStateScale);
156 qLstmParams.set_matmul_scale_params(m_Data.m_Parameters.m_InputIntermediateScale,
157 m_Data.m_Parameters.m_ForgetIntermediateScale,
158 m_Data.m_Parameters.m_CellIntermediateScale,
159 m_Data.m_Parameters.m_OutputIntermediateScale);
160
161 { // presumably these braces bound the scoped profiling event to the configure() call
162 ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID("ClQLstmWorkload_configure");
163 // QLSTM CL configure (cell-state tensors are passed before their output-state counterparts; presumably this matches the ACL configure() parameter order - confirm)
164 m_QLstmLayer.configure(clCompileContext,
165 &input,
166 m_InputToForgetWeightsTensor.get(),
167 m_InputToCellWeightsTensor.get(),
168 m_InputToOutputWeightsTensor.get(),
169 m_RecurrentToForgetWeightsTensor.get(),
170 m_RecurrentToCellWeightsTensor.get(),
171 m_RecurrentToOutputWeightsTensor.get(),
172 m_ForgetGateBiasTensor.get(),
173 m_CellBiasTensor.get(),
174 m_OutputGateBiasTensor.get(),
175 &cellStateIn,
176 &outputStateIn,
177 &cellStateOut,
178 &outputStateOut,
179 &output,
180 qLstmParams);
181 }
182 // Constant weight/bias data is copied into the ACL tensors after configure() and before prepare().
183 // Initialise ACL tensor data for mandatory params
184 InitializeArmComputeClTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights);
185 InitializeArmComputeClTensorData(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights);
186 InitializeArmComputeClTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights);
187
188 InitializeArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights);
189 InitializeArmComputeClTensorData(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights);
190 InitializeArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights);
191
192 InitializeArmComputeClTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias);
193 InitializeArmComputeClTensorData(*m_CellBiasTensor, m_Data.m_CellBias);
194 InitializeArmComputeClTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias);
195
196 // Initialise ACL tensor data for optional params
197 if (!m_Data.m_Parameters.m_CifgEnabled)
198 {
199 InitializeArmComputeClTensorData(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights);
200 InitializeArmComputeClTensorData(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights);
201 InitializeArmComputeClTensorData(*m_InputGateBiasTensor, m_Data.m_InputGateBias);
202 }
203
204 if (m_Data.m_Parameters.m_ProjectionEnabled)
205 {
206 InitializeArmComputeClTensorData(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights);
207
208 if (m_Data.m_ProjectionBias != nullptr)
209 {
210 InitializeArmComputeClTensorData(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias);
211 }
212 }
213
214 if (m_Data.m_Parameters.m_PeepholeEnabled)
215 {
216 if (!m_Data.m_Parameters.m_CifgEnabled)
217 {
218 InitializeArmComputeClTensorData(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights);
219 }
220
221 InitializeArmComputeClTensorData(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights);
222 InitializeArmComputeClTensorData(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights);
223 }
224
225 if (m_Data.m_Parameters.m_LayerNormEnabled)
226 {
227 if (!m_Data.m_Parameters.m_CifgEnabled)
228 {
229 InitializeArmComputeClTensorData(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights);
230 }
231 InitializeArmComputeClTensorData(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights);
232 InitializeArmComputeClTensorData(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights);
233 InitializeArmComputeClTensorData(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights);
234 }
235
236 m_QLstmLayer.prepare();
237 // FreeUnusedTensors() is defined elsewhere; presumably it releases staging tensors no longer needed after prepare() - confirm.
238 FreeUnusedTensors();
239}
#define ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
#define ARMNN_REPORT_PROFILING_WORKLOAD_DESC(name, desc, infos, guid)
void InitializeArmComputeClTensorData(arm_compute::CLTensor &clTensor, const ConstTensorHandle *handle)

References ARMNN_REPORT_PROFILING_WORKLOAD_DESC, ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID, ClBaseWorkload< QLstmQueueDescriptor >::ClBaseWorkload(), armnn::info, armnn::InitializeArmComputeClTensorData(), BaseWorkload< QLstmQueueDescriptor >::m_Data, and QueueDescriptorWithParameters< LayerDescriptor >::m_Parameters.

Member Function Documentation

◆ Execute()

void Execute ( ) const
override virtual

Implements IWorkload.

Definition at line 241 of file ClQLstmWorkload.cpp.

242{ // Runs m_QLstmLayer, which the constructor configured (configure()) and prepared (prepare()).
243 ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID("ClQLstmWorkload_Execute"); // label uses this class's own name, consistent with "ClQLstmWorkload_Construct" and "ClQLstmWorkload_configure"
244 m_QLstmLayer.run();
245}

References ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID.


The documentation for this class was generated from the following files: