12 using namespace armnn::armcomputetensorutils;
26 indices_W_ND_Info.
SetShape({ keyIndices[
"W"], keyIndices[
"ND"] });
27 const arm_compute::TensorInfo aclIndicesInfo = BuildArmComputeTensorInfo(indices_W_ND_Info);
31 flattenedCoeff_Info.
SetShape({ keyIndices[
"ND"] });
32 const arm_compute::TensorInfo aclFlattenedCoeffInfo = BuildArmComputeTensorInfo(flattenedCoeff_Info);
35 const arm_compute::TensorInfo aclOutputMulInfo = BuildArmComputeTensorInfo(indices_W_ND_Info);
37 auto statusMul = arm_compute::CLPixelWiseMultiplication::validate(&aclIndicesInfo,
38 &aclFlattenedCoeffInfo,
41 arm_compute::ConvertPolicy::WRAP,
42 arm_compute::RoundingPolicy::TO_ZERO,
43 arm_compute::ActivationLayerInfo());
48 flattenedIndices_Info.
SetShape({ keyIndices[
"W"] });
49 const arm_compute::TensorInfo aclFlattenedIndicesInfo = BuildArmComputeTensorInfo(flattenedIndices_Info);
51 const std::vector<unsigned int> armnnReduceAxes(1, 1);
56 auto statusReduceSum = arm_compute::CLReductionOperation::validate(&aclOutputMulInfo,
57 &aclFlattenedIndicesInfo,
58 static_cast<unsigned int>(coords[0]),
59 arm_compute::ReductionOperation::SUM,
65 params_K_C_Info.
SetShape({ keyIndices[
"K"], keyIndices[
"C"] });
66 const arm_compute::TensorInfo aclParamsInfo = BuildArmComputeTensorInfo(params_K_C_Info);
70 outputGather_Info.
SetShape({ keyIndices[
"W"], keyIndices[
"C"] });
71 const arm_compute::TensorInfo aclOutputGatherInfo = BuildArmComputeTensorInfo(outputGather_Info);
75 arm_compute::CLGather::validate(&aclParamsInfo, &aclFlattenedIndicesInfo, &aclOutputGatherInfo, aclAxis);
78 const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(outputInfo);
79 const arm_compute::TensorInfo aclParamsOriginalShapeInfo = BuildArmComputeTensorInfo(paramsInfo);
80 const arm_compute::TensorInfo aclIndicesOriginalShapeInfo = BuildArmComputeTensorInfo(indicesInfo);
81 const arm_compute::TensorInfo aclParamsReshapeInfo = BuildArmComputeTensorInfo(paramsInfo);
82 const arm_compute::TensorInfo aclIndicesReshapeInfo = BuildArmComputeTensorInfo(indicesInfo);
84 auto statusOutputReshape = arm_compute::CLReshapeLayer::validate(&aclOutputGatherInfo, &aclOutputInfo);
85 auto statusParamsReshape = arm_compute::CLReshapeLayer::validate(&aclParamsOriginalShapeInfo,
86 &aclParamsReshapeInfo);
87 auto statusIndicesReshape = arm_compute::CLReshapeLayer::validate(&aclIndicesOriginalShapeInfo,
88 &aclIndicesReshapeInfo);
91 auto okCode = arm_compute::ErrorCode::OK;
92 if (statusMul.error_code() == okCode &&
93 statusReduceSum.error_code() == okCode &&
94 statusGather.error_code() == okCode &&
95 statusParamsReshape.error_code() == okCode &&
96 statusIndicesReshape.error_code() == okCode &&
97 statusOutputReshape.error_code() == okCode)
100 "All GatherND layers validate status OK.");
105 "GatherND layer validate status failed.");
111 const arm_compute::CLCompileContext& clCompileContext)
133 flattenedIndices_Info.
SetShape({ keyIndices[
"W"] });
134 BuildArmComputeTensor(m_FlattenedIndices, flattenedIndices_Info);
135 armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_FlattenedIndices);
139 indicesInfoReshape.
SetShape({ keyIndices[
"W"], keyIndices[
"ND"] });
140 BuildArmComputeTensor(m_IndicesReshaped, indicesInfoReshape);
141 armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_IndicesReshaped);
145 std::vector<int32_t> flattenedCoeff(keyIndices[
"ND"], 1);
146 for (
unsigned int i = 1; i < keyIndices[
"ND"]; ++i)
148 flattenedCoeff[i - 1] =
static_cast<int32_t
>(paramsShape[i]);
150 for (
unsigned int i = keyIndices[
"ND"] - 1; i > 0; --i)
152 flattenedCoeff[i - 1] *= flattenedCoeff[i];
155 flattenedCoeff_Info.
SetShape({ keyIndices[
"ND"] });
156 BuildArmComputeTensor(m_FlattenedCoeff, flattenedCoeff_Info);
157 armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_FlattenedCoeff);
158 CopyArmComputeClTensorData<int32_t>(m_FlattenedCoeff, flattenedCoeff.data());
162 outputMul_Info.
SetShape({ keyIndices[
"W"], keyIndices[
"ND"] });
163 BuildArmComputeTensor(m_OutputMul, outputMul_Info);
164 armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_OutputMul);
167 m_ReshapeIndicesLayer.configure(&indices, &m_IndicesReshaped);
170 m_MulLayer.configure(clCompileContext,
175 arm_compute::ConvertPolicy::WRAP,
176 arm_compute::RoundingPolicy::TO_ZERO,
177 arm_compute::ActivationLayerInfo());
180 const std::vector<unsigned int> armnnReduceAxes(1, 1);
184 m_ReduceSumLayer.configure(clCompileContext,
187 static_cast<unsigned int>(coords[0]),
188 arm_compute::ReductionOperation::SUM,
194 paramsInfoReshape.
SetShape({ keyIndices[
"K"], keyIndices[
"C"] });
195 BuildArmComputeTensor(m_InputGather, paramsInfoReshape);
196 armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_InputGather);
199 m_ReshapeInputLayer.configure(&input, &m_InputGather);
204 outputGather_Info.
SetShape({ keyIndices[
"W"], keyIndices[
"C"] });
205 BuildArmComputeTensor(m_OutputGather, outputGather_Info);
206 armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_OutputGather);
210 m_GatherLayer.configure(clCompileContext, &m_InputGather, &m_FlattenedIndices, &m_OutputGather, aclAxis);
214 m_ReshapeOutputLayer.configure(clCompileContext, &m_OutputGather, &output);
#define ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
ClGatherNdWorkload(const GatherNdQueueDescriptor &descriptor, const WorkloadInfo &info, const arm_compute::CLCompileContext &clCompileContext)
virtual void Execute() const override
unsigned int GetNumDimensions() const
const TensorShape & GetShape() const
void SetShape(const TensorShape &newShape)
Copyright (c) 2021 ARM Limited and Contributors.
int ComputeAclAxis(const int &armnnAxis, const armnn::TensorInfo &tensor)
Function to convert ArmNN axis (left to right) to ACL axis (right to left) ranging from [-rank,...
std::map< std::string, unsigned int > CalculateGatherNdKeyIndices(TensorInfo inputInfo0, TensorInfo inputInfo1)
Calculates the key index values needed for GatherNd: N, ND, K, W, C (N is always 1)
std::array< unsigned int, MaxNumOfTensorDimensions > Coordinates
arm_compute::Status ClGatherNdWorkloadValidate(const TensorInfo &paramsInfo, const TensorInfo &indicesInfo, const TensorInfo &outputInfo)
void RunClFunction(arm_compute::IFunction &function, const CheckLocation &location)
std::vector< ITensorHandle * > m_Inputs
std::vector< ITensorHandle * > m_Outputs
void ValidateInputsOutputs(const std::string &descName, unsigned int numExpectedIn, unsigned int numExpectedOut) const
Contains information about TensorInfos of a layer.