ArmNN 24.11
ClGatherNdWorkload Class Reference

#include <ClGatherNdWorkload.hpp>

Inheritance diagram for ClGatherNdWorkload: (diagram not reproduced)
Collaboration diagram for ClGatherNdWorkload: (diagram not reproduced)

Public Member Functions

 ClGatherNdWorkload (const GatherNdQueueDescriptor &descriptor, const WorkloadInfo &info, const arm_compute::CLCompileContext &clCompileContext)
 
virtual void Execute () const override
 
- Public Member Functions inherited from ClBaseWorkload< GatherNdQueueDescriptor >
 ClBaseWorkload (const GatherNdQueueDescriptor &descriptor, const WorkloadInfo &info)
 
void ReplaceInputTensorHandle (ITensorHandle *tensorHandle, unsigned int slot) override
 
void ReplaceOutputTensorHandle (ITensorHandle *tensorHandle, unsigned int slot) override
 
- Public Member Functions inherited from BaseWorkload< GatherNdQueueDescriptor >
 BaseWorkload (const GatherNdQueueDescriptor &descriptor, const WorkloadInfo &info)
 
virtual const std::string & GetName () const override
 
void ExecuteAsync (ExecutionData &executionData) override
 
void PostAllocationConfigure () override
 
const GatherNdQueueDescriptor & GetData () const
 
arm::pipe::ProfilingGuid GetGuid () const final
 
virtual bool SupportsTensorHandleReplacement () const override
 
- Public Member Functions inherited from IWorkload
virtual ~IWorkload ()
 
virtual arm::pipe::ProfilingGuid GetGuid () const =0
 
virtual bool SupportsTensorHandleReplacement () const =0
 
virtual const std::string & GetName () const =0
 
virtual void RegisterDebugCallback (const DebugCallbackFunction &)
 
virtual armnn::Optional< armnn::MemoryRequirements > GetMemoryRequirements ()
 

Additional Inherited Members

- Protected Member Functions inherited from ClBaseWorkload< GatherNdQueueDescriptor >
virtual void Reconfigure ()
 
- Protected Attributes inherited from BaseWorkload< GatherNdQueueDescriptor >
GatherNdQueueDescriptor m_Data
 
const arm::pipe::ProfilingGuid m_Guid
 
const std::string m_Name
 

Detailed Description

ClGatherNdWorkload executes the GatherNd operation on the GPU via Arm Compute Library's OpenCL functions. The operation is decomposed into a chain of CL functions: reshapes of the indices and params tensors, a pixel-wise multiplication and reduce-sum that flatten each N-dimensional index tuple into a single row index, a gather over the flattened params, and a final reshape to the GatherNd output shape (see the constructor documentation below).

Definition at line 22 of file ClGatherNdWorkload.hpp.

Constructor & Destructor Documentation

◆ ClGatherNdWorkload()

ClGatherNdWorkload ( const GatherNdQueueDescriptor &  descriptor,
                     const WorkloadInfo &  info,
                     const arm_compute::CLCompileContext &  clCompileContext
                   )

Calculate flattened indices: m_FlattenedIndices = indices * m_FlattenedCoeff. This could be done using MatMul instead of a multiplication followed by a reduce-sum operation, but GeMM does not support s32 at the moment.

Call Gather with adequate shapes
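The decomposition is driven by the key index values ND, K, W and C that CalculateGatherNdKeyIndices() derives from the params and indices shapes. As an illustration of what these values mean, the sketch below uses a hypothetical KeyIndices() helper (a stand-in written for this example, not the ArmNN implementation) with concrete example shapes:

#include <functional>
#include <iostream>
#include <map>
#include <numeric>
#include <string>
#include <vector>

// Hypothetical stand-in for armnn::CalculateGatherNdKeyIndices(), shown only
// to illustrate what the key index values mean:
//   ND = length of each index tuple (last dimension of the indices tensor)
//   W  = number of index tuples (product of the other indices dimensions)
//   K  = flattened size of the first ND params dimensions
//   C  = size of the trailing params block copied per gathered element
std::map<std::string, unsigned int> KeyIndices(const std::vector<unsigned int>& paramsShape,
                                               const std::vector<unsigned int>& indicesShape)
{
    const unsigned int nd = indicesShape.back();
    const unsigned int w  = std::accumulate(indicesShape.begin(), indicesShape.end() - 1,
                                            1u, std::multiplies<>());
    const unsigned int k  = std::accumulate(paramsShape.begin(), paramsShape.begin() + nd,
                                            1u, std::multiplies<>());
    const unsigned int c  = std::accumulate(paramsShape.begin() + nd, paramsShape.end(),
                                            1u, std::multiplies<>());
    return { { "ND", nd }, { "W", w }, { "K", k }, { "C", c } };
}

int main()
{
    // params shape { 3, 4, 5 } and indices shape { 6, 2 } give
    // ND = 2, W = 6, K = 3 * 4 = 12, C = 5: six 2-component index
    // tuples each select one of 12 rows of 5 elements.
    for (const auto& [key, value] : KeyIndices({ 3, 4, 5 }, { 6, 2 }))
    {
        std::cout << key << " = " << value << "\n";
    }
    return 0;
}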

Definition at line 109 of file ClGatherNdWorkload.cpp.

    : ClBaseWorkload<GatherNdQueueDescriptor>(descriptor, info)
{
    m_Data.ValidateInputsOutputs("ClGatherNdWorkload", 2, 1);

    TensorInfo paramsInfo  = info.m_InputTensorInfos[0];
    TensorInfo indicesInfo = info.m_InputTensorInfos[1];
    TensorInfo outputInfo  = info.m_OutputTensorInfos[0];

    arm_compute::ICLTensor& input   = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
    arm_compute::ICLTensor& indices = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
    arm_compute::ICLTensor& output  = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();

    // Calculate ND, K, W, C.
    std::map<std::string, unsigned int> keyIndices = CalculateGatherNdKeyIndices(paramsInfo, indicesInfo);

    /// Calculate flattened indices: m_FlattenedIndices = indices * m_FlattenedCoeff.
    /// This could be done using MatMul instead of multiplication followed by reduce sum operation,
    /// but GeMM does not support s32 at the moment.

    // Prepare the tensor to store the output of the reduce_sum operation
    armnn::TensorInfo flattenedIndices_Info = indicesInfo;
    flattenedIndices_Info.SetShape({ keyIndices["W"] });
    BuildArmComputeTensor(m_FlattenedIndices, flattenedIndices_Info);
    armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_FlattenedIndices);

    // Reshape indices into { W, ND }
    armnn::TensorInfo indicesInfoReshape = indicesInfo;
    indicesInfoReshape.SetShape({ keyIndices["W"], keyIndices["ND"] });
    BuildArmComputeTensor(m_IndicesReshaped, indicesInfoReshape);
    armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_IndicesReshaped);

    // Calculate the m_FlattenedCoeff
    TensorShape paramsShape = paramsInfo.GetShape();
    std::vector<int32_t> flattenedCoeff(keyIndices["ND"], 1);
    for (unsigned int i = 1; i < keyIndices["ND"]; ++i)
    {
        flattenedCoeff[i - 1] = static_cast<int32_t>(paramsShape[i]);
    }
    for (unsigned int i = keyIndices["ND"] - 1; i > 0; --i)
    {
        flattenedCoeff[i - 1] *= flattenedCoeff[i];
    }
    armnn::TensorInfo flattenedCoeff_Info = indicesInfo;
    flattenedCoeff_Info.SetShape({ keyIndices["ND"] });
    BuildArmComputeTensor(m_FlattenedCoeff, flattenedCoeff_Info);
    armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_FlattenedCoeff);
    CopyArmComputeClTensorData<int32_t>(m_FlattenedCoeff, flattenedCoeff.data());

    // Prepare the tensor to store the output of the multiplication
    armnn::TensorInfo outputMul_Info = indicesInfo;
    outputMul_Info.SetShape({ keyIndices["W"], keyIndices["ND"] });
    BuildArmComputeTensor(m_OutputMul, outputMul_Info);
    armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_OutputMul);

    // Reshape indices to the mul layer input shape
    m_ReshapeIndicesLayer.configure(&indices, &m_IndicesReshaped);

    // Multiply
    m_MulLayer.configure(clCompileContext,
                         &m_IndicesReshaped,
                         &m_FlattenedCoeff,
                         &m_OutputMul,
                         1.0f,
                         arm_compute::ConvertPolicy::WRAP,
                         arm_compute::RoundingPolicy::TO_ZERO,
                         arm_compute::ActivationLayerInfo());

    // Reduce Sum
    const std::vector<unsigned int> armnnReduceAxes(1, 1);
    arm_compute::Coordinates coords = BuildArmComputeReductionCoordinates(m_OutputMul.info()->num_dimensions(),
                                                                          outputMul_Info.GetNumDimensions(),
                                                                          armnnReduceAxes);
    m_ReduceSumLayer.configure(clCompileContext,
                               &m_OutputMul,
                               &m_FlattenedIndices,
                               static_cast<unsigned int>(coords[0]),
                               arm_compute::ReductionOperation::SUM,
                               false);

    /// Call Gather with adequate shapes
    // Reshape params into { K, C }
    armnn::TensorInfo paramsInfoReshape = paramsInfo;
    paramsInfoReshape.SetShape({ keyIndices["K"], keyIndices["C"] });
    BuildArmComputeTensor(m_InputGather, paramsInfoReshape);
    armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_InputGather);

    // Reshape input to the gather params input shape
    m_ReshapeInputLayer.configure(&input, &m_InputGather);

    // Reshape output to have the shape given by gather { W, C }
    // (the original outputInfo has the shape given by gatherNd)
    armnn::TensorInfo outputGather_Info = outputInfo;
    outputGather_Info.SetShape({ keyIndices["W"], keyIndices["C"] });
    BuildArmComputeTensor(m_OutputGather, outputGather_Info);
    armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_OutputGather);
    {
        ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID("ClGatherNdWorkload_configure");
        auto aclAxis = ComputeAclAxis(0, paramsInfoReshape);
        m_GatherLayer.configure(clCompileContext, &m_InputGather, &m_FlattenedIndices, &m_OutputGather, aclAxis);
    }

    // Reshape output to the original output shape
    m_ReshapeOutputLayer.configure(clCompileContext, &m_OutputGather, &output);
}
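To make the two coefficient loops above concrete, this standalone sketch reproduces their computation for the example shapes used earlier (paramsShape = { 3, 4, 5 } and ND = 2 are illustrative values, not taken from ArmNN):

#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    // Mirrors the two flattenedCoeff loops in the constructor for
    // paramsShape = { 3, 4, 5 } and ND = 2. Each coefficient is the stride,
    // counted in rows of C elements, contributed by one index component.
    std::vector<int32_t> paramsShape = { 3, 4, 5 };
    const unsigned int nd = 2;

    std::vector<int32_t> flattenedCoeff(nd, 1);
    // Shift in the sizes of the indexed dimensions: { 4, 1 }.
    for (unsigned int i = 1; i < nd; ++i)
    {
        flattenedCoeff[i - 1] = paramsShape[i];
    }
    // Turn sizes into strides via a suffix product (unchanged here: { 4, 1 }).
    for (unsigned int i = nd - 1; i > 0; --i)
    {
        flattenedCoeff[i - 1] *= flattenedCoeff[i];
    }

    // Prints "4 1": an index tuple (i, j) maps to flat row i * 4 + j
    // in the { K, C } = { 12, 5 } reshaped params tensor.
    for (int32_t c : flattenedCoeff)
    {
        std::cout << c << " ";
    }
    return 0;
}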

References armnn::CalculateGatherNdKeyIndices(), armnn::info, BaseWorkload< GatherNdQueueDescriptor >::m_Data, QueueDescriptor::m_Inputs, QueueDescriptor::m_Outputs, TensorInfo::SetShape(), and QueueDescriptor::ValidateInputsOutputs().

Member Function Documentation

◆ Execute()

void Execute ( ) const
override virtual

Implements IWorkload.

Definition at line 217 of file ClGatherNdWorkload.cpp.

{
    ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID("ClGatherNdWorkload_Execute");
    RunClFunction(m_ReshapeInputLayer, CHECK_LOCATION());
    RunClFunction(m_ReshapeIndicesLayer, CHECK_LOCATION());
    RunClFunction(m_MulLayer, CHECK_LOCATION());
    RunClFunction(m_ReduceSumLayer, CHECK_LOCATION());
    RunClFunction(m_GatherLayer, CHECK_LOCATION());
    RunClFunction(m_ReshapeOutputLayer, CHECK_LOCATION());
}

References ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID, CHECK_LOCATION, and armnn::RunClFunction().
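Taken together, the six CL functions run in Execute() compute standard GatherNd semantics. The following CPU-side reference sketch (a hypothetical helper written for this page, not part of ArmNN) shows what the pipeline produces, using the { W, ND } indices layout, { ND } coefficients and { K, C } params layout established in the constructor:

#include <cstdint>
#include <vector>

// CPU-side reference for the six-stage pipeline: treat indices as W rows of
// ND components, multiply each row by the flattened coefficients and
// reduce-sum it to a single flat row index, then gather C contiguous
// elements per row from params reshaped to { K, C }.
std::vector<float> GatherNdReference(const std::vector<float>& params,    // K * C elements
                                     const std::vector<int32_t>& indices, // W * ND elements
                                     const std::vector<int32_t>& coeff,   // ND strides
                                     unsigned int w, unsigned int nd, unsigned int c)
{
    std::vector<float> output(w * c);
    for (unsigned int row = 0; row < w; ++row)
    {
        // Mul + ReduceSum: flatten one index tuple to a single row index.
        int32_t flat = 0;
        for (unsigned int d = 0; d < nd; ++d)
        {
            flat += indices[row * nd + d] * coeff[d];
        }
        // Gather: copy the C elements of the selected params row.
        for (unsigned int col = 0; col < c; ++col)
        {
            output[row * c + col] = params[static_cast<unsigned int>(flat) * c + col];
        }
    }
    return output;
}

With the example shapes used earlier (W = 6, ND = 2, C = 5), the result holds the six gathered rows in { W, C } layout; the final reshape layer restores the original GatherNd output shape.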


The documentation for this class was generated from the following files:
ClGatherNdWorkload.hpp
ClGatherNdWorkload.cpp
Referenced symbol definitions:

armnn::RunClFunction
    void RunClFunction(arm_compute::IFunction &function, const CheckLocation &location)
    Definition: ClWorkloadUtils.hpp:167
armnn::QueueDescriptor::ValidateInputsOutputs
    void ValidateInputsOutputs(const std::string &descName, unsigned int numExpectedIn, unsigned int numExpectedOut) const
    Definition: WorkloadData.cpp:447
armnn::TensorInfo
    Definition: Tensor.hpp:152
armnn::TensorInfo::GetNumDimensions
    unsigned int GetNumDimensions() const
    Definition: Tensor.hpp:197
CHECK_LOCATION
    #define CHECK_LOCATION()
    Definition: Exceptions.hpp:203
armnn::Coordinates
    std::array< unsigned int, MaxNumOfTensorDimensions > Coordinates
    Definition: InternalTypes.hpp:15
ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID
    #define ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID(label)
    Creates a profiling event that uses GetGuid() and GetName() from the calling class.
    Definition: ClWorkloadUtils.hpp:36
armnn::CalculateGatherNdKeyIndices
    std::map< std::string, unsigned int > CalculateGatherNdKeyIndices(TensorInfo inputInfo0, TensorInfo inputInfo1)
    Calculates the key index values needed for GatherNd: N, ND, K, W, C (N is always 1)
    Definition: WorkloadUtils.cpp:313
armnn::BoostLogSeverityMapping::info
    enumerator info
armnn::QueueDescriptor::m_Outputs
    std::vector< ITensorHandle * > m_Outputs
    Definition: WorkloadData.hpp:27
armnn::BaseWorkload< GatherNdQueueDescriptor >::m_Data
    GatherNdQueueDescriptor m_Data
    Definition: Workload.hpp:89
armnn::TensorInfo::SetShape
    void SetShape(const TensorShape &newShape)
    Definition: Tensor.hpp:195
armnn::ComputeAclAxis
    int ComputeAclAxis(const int &armnnAxis, const armnn::TensorInfo &tensor)
    Function to convert ArmNN axis (left to right) to ACL axis (right to left) ranging from [-rank,...
    Definition: ArmComputeUtils.hpp:246
armnn::QueueDescriptor::m_Inputs
    std::vector< ITensorHandle * > m_Inputs
    Definition: WorkloadData.hpp:26