12 using namespace armnn::armcomputetensorutils;
// NOTE(review): only part of this validation routine is visible in this excerpt. The declarations of
// the TensorInfo objects, keyIndices, coords, aclAxis and statusGather lie outside the shown lines.
// Describe the indices as a 2D tensor of shape [W, ND] and convert that to an ACL TensorInfo.
26 indices_W_ND_Info.
SetShape({ keyIndices[
"W"], keyIndices[
"ND"] });
27 const arm_compute::TensorInfo aclIndicesInfo = BuildArmComputeTensorInfo(indices_W_ND_Info);
// Flattening coefficients: a 1D tensor with ND elements.
31 flattenedCoeff_Info.
SetShape({ keyIndices[
"ND"] });
32 const arm_compute::TensorInfo aclFlattenedCoeffInfo = BuildArmComputeTensorInfo(flattenedCoeff_Info);
// The multiplication output info is built from the same [W, ND] description as the indices.
35 const arm_compute::TensorInfo aclOutputMulInfo = BuildArmComputeTensorInfo(indices_W_ND_Info);
// Stage 1: validate the pixel-wise multiplication of the indices info with the coefficient info,
// using WRAP conversion, TO_ZERO rounding and a default ActivationLayerInfo.
// (The remaining arguments of this call are outside the visible lines.)
37 auto statusMul = arm_compute::CLPixelWiseMultiplication::validate(&aclIndicesInfo,
38 &aclFlattenedCoeffInfo,
41 arm_compute::ConvertPolicy::WRAP,
42 arm_compute::RoundingPolicy::TO_ZERO,
43 arm_compute::ActivationLayerInfo());
// Flattened indices: a 1D tensor with W elements.
48 flattenedIndices_Info.
SetShape({ keyIndices[
"W"] });
49 const arm_compute::TensorInfo aclFlattenedIndicesInfo = BuildArmComputeTensorInfo(flattenedIndices_Info);
// A vector holding exactly one reduction axis, with value 1.
51 const std::vector<unsigned int> armnnReduceAxes(1, 1);
// Stage 2: validate a SUM reduction from the multiplication output into the flattened indices along
// coords[0]. coords is computed outside the visible lines, presumably from armnnReduceAxes
// (TODO confirm).
56 auto statusReduceSum = arm_compute::CLReductionOperation::validate(&aclOutputMulInfo,
57 &aclFlattenedIndicesInfo,
58 static_cast<unsigned int>(coords[0]),
59 arm_compute::ReductionOperation::SUM,
// Params described as a 2D tensor of shape [K, C].
65 params_K_C_Info.
SetShape({ keyIndices[
"K"], keyIndices[
"C"] });
66 const arm_compute::TensorInfo aclParamsInfo = BuildArmComputeTensorInfo(params_K_C_Info);
// Gather output described as a 2D tensor of shape [W, C].
70 outputGather_Info.
SetShape({ keyIndices[
"W"], keyIndices[
"C"] });
71 const arm_compute::TensorInfo aclOutputGatherInfo = BuildArmComputeTensorInfo(outputGather_Info);
// Stage 3: validate the gather from params, using the flattened indices, along aclAxis.
// (The left-hand side of this statement is outside the visible lines, presumably statusGather.)
75 arm_compute::CLGather::validate(&aclParamsInfo, &aclFlattenedIndicesInfo, &aclOutputGatherInfo, aclAxis);
// Stage 4: validate the reshape from the [W, C] gather output to the info built from outputInfo.
78 const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(outputInfo);
80 auto statusReshape = arm_compute::CLReshapeLayer::validate(&aclOutputGatherInfo, &aclOutputInfo);
// The combined check passes only when all four stage statuses report ErrorCode::OK.
// The two message strings below belong to the success and failure results respectively; how those
// results are constructed and returned is outside the visible lines.
83 auto okCode = arm_compute::ErrorCode::OK;
84 if (statusMul.error_code() == okCode &&
85 statusReduceSum.error_code() == okCode &&
86 statusGather.error_code() == okCode &&
87 statusReshape.error_code() == okCode)
90 "All GatherND layers validate status OK.");
95 "GatherND layer validate status failed.");
101 const arm_compute::CLCompileContext& clCompileContext)
// NOTE(review): the line above is the tail of a parameter list whose beginning is not visible here.
// The declarations of keyIndices, paramsShape, coords, aclAxis, input, indices, output and the
// TensorInfo objects used below also lie outside the shown lines.
// Intermediate tensor for the flattened indices: 1D with W elements, built and then initialised empty.
123 flattenedIndices_Info.
SetShape({ keyIndices[
"W"] });
124 BuildArmComputeTensor(m_FlattenedIndices, flattenedIndices_Info);
125 armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_FlattenedIndices);
// Overwrite the shape in the indices tensor's info with [W, ND].
128 indices.info()->set_tensor_shape(BuildArmComputeTensorShape({ keyIndices[
"W"], keyIndices[
"ND"] }));
// Flattening coefficients: ND entries, all starting at 1.
// Assuming the two assignments below are the bodies of their respective loops (the braces are not
// visible): the first loop copies paramsShape[1 .. ND-1] into slots 0 .. ND-2, and the second walks
// backwards multiplying each slot by its right-hand neighbour. Entry j then holds the product of
// paramsShape[j+1 .. ND-1], and the last entry stays 1.
// NOTE(review): the second loop starts at ND - 1; if ND is an unsigned value equal to 0 this would
// wrap around. Confirm that ND >= 1 is guaranteed by the caller.
132 std::vector<int32_t> flattenedCoeff(keyIndices[
"ND"], 1);
133 for (
unsigned int i = 1; i < keyIndices[
"ND"]; ++i)
135 flattenedCoeff[i - 1] =
static_cast<int32_t
>(paramsShape[i]);
137 for (
unsigned int i = keyIndices[
"ND"] - 1; i > 0; --i)
139 flattenedCoeff[i - 1] *= flattenedCoeff[i];
// Coefficient tensor: 1D with ND elements; built, initialised empty, then filled from flattenedCoeff.
142 flattenedCoeff_Info.
SetShape({ keyIndices[
"ND"] });
143 BuildArmComputeTensor(m_FlattenedCoeff, flattenedCoeff_Info);
144 armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_FlattenedCoeff);
145 CopyArmComputeClTensorData<int32_t>(m_FlattenedCoeff, flattenedCoeff.data());
// Multiplication output tensor of shape [W, ND], built and initialised empty.
149 outputMul_Info.
SetShape({ keyIndices[
"W"], keyIndices[
"ND"] });
150 BuildArmComputeTensor(m_OutputMul, outputMul_Info);
151 armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_OutputMul);
// Configure the multiplication layer with WRAP conversion, TO_ZERO rounding and a default
// ActivationLayerInfo. (The tensor arguments of this call are outside the visible lines.)
154 m_MulLayer.configure(clCompileContext,
159 arm_compute::ConvertPolicy::WRAP,
160 arm_compute::RoundingPolicy::TO_ZERO,
161 arm_compute::ActivationLayerInfo());
// A vector holding exactly one reduction axis, with value 1.
164 const std::vector<unsigned int> armnnReduceAxes(1, 1);
// Configure the SUM reduction along coords[0]. coords is computed outside the visible lines,
// presumably from armnnReduceAxes (TODO confirm).
168 m_ReduceSumLayer.configure(clCompileContext,
171 static_cast<unsigned int>(coords[0]),
172 arm_compute::ReductionOperation::SUM,
// Set paramsInfo to shape [K, C] and apply that shape to the input tensor's info.
177 paramsInfo.
SetShape({ keyIndices[
"K"], keyIndices[
"C"] });
178 input.info()->set_tensor_shape(BuildArmComputeTensorShape(paramsInfo.
GetShape()));
// Gather output tensor of shape [W, C], built and initialised empty.
183 outputGather_Info.
SetShape({ keyIndices[
"W"], keyIndices[
"C"] });
184 BuildArmComputeTensor(m_OutputGather, outputGather_Info);
185 armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_OutputGather);
// Configure the gather: read from input using m_FlattenedIndices along aclAxis, writing m_OutputGather.
189 m_GatherLayer.configure(clCompileContext, &input, &m_FlattenedIndices, &m_OutputGather, aclAxis);
// Configure the final reshape from the gather output into the output tensor.
193 m_ReshapeLayer.configure(clCompileContext, &m_OutputGather, &output);