12 using namespace armnn::armcomputetensorutils;
26 indices_W_ND_Info.
SetShape({ keyIndices[
"W"], keyIndices[
"ND"] });
27 const arm_compute::TensorInfo aclIndicesInfo = BuildArmComputeTensorInfo(indices_W_ND_Info);
31 flattenedCoeff_Info.
SetShape({ keyIndices[
"ND"] });
32 const arm_compute::TensorInfo aclFlattenedCoeffInfo = BuildArmComputeTensorInfo(flattenedCoeff_Info);
35 const arm_compute::TensorInfo aclOutputMulInfo = BuildArmComputeTensorInfo(indices_W_ND_Info);
37 auto statusMul = arm_compute::CLPixelWiseMultiplication::validate(&aclIndicesInfo,
38 &aclFlattenedCoeffInfo,
41 arm_compute::ConvertPolicy::WRAP,
42 arm_compute::RoundingPolicy::TO_ZERO,
43 arm_compute::ActivationLayerInfo());
48 flattenedIndices_Info.
SetShape({ keyIndices[
"W"] });
49 const arm_compute::TensorInfo aclFlattenedIndicesInfo = BuildArmComputeTensorInfo(flattenedIndices_Info);
51 const std::vector<unsigned int> armnnReduceAxes(1, 1);
56 auto statusReduceSum = arm_compute::CLReductionOperation::validate(&aclOutputMulInfo,
57 &aclFlattenedIndicesInfo,
58 static_cast<unsigned int>(coords[0]),
59 arm_compute::ReductionOperation::SUM,
65 params_K_C_Info.
SetShape({ keyIndices[
"K"], keyIndices[
"C"] });
66 const arm_compute::TensorInfo aclParamsInfo = BuildArmComputeTensorInfo(params_K_C_Info);
70 outputGather_Info.
SetShape({ keyIndices[
"W"], keyIndices[
"C"] });
71 const arm_compute::TensorInfo aclOutputGatherInfo = BuildArmComputeTensorInfo(outputGather_Info);
75 arm_compute::CLGather::validate(&aclParamsInfo, &aclFlattenedIndicesInfo, &aclOutputGatherInfo, aclAxis);
78 const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(outputInfo);
80 auto statusReshape = arm_compute::CLReshapeLayer::validate(&aclOutputGatherInfo, &aclOutputInfo);
83 auto okCode = arm_compute::ErrorCode::OK;
84 if (statusMul.error_code() == okCode &&
85 statusReduceSum.error_code() == okCode &&
86 statusGather.error_code() == okCode &&
87 statusReshape.error_code() == okCode)
90 "All GatherND layers validate status OK.");
95 "GatherND layer validate status failed.");
101 const arm_compute::CLCompileContext& clCompileContext)
123 flattenedIndices_Info.
SetShape({ keyIndices[
"W"] });
124 BuildArmComputeTensor(m_FlattenedIndices, flattenedIndices_Info);
125 armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_FlattenedIndices);
128 indices.info()->set_tensor_shape(BuildArmComputeTensorShape({ keyIndices[
"W"], keyIndices[
"ND"] }));
132 std::vector<int32_t> flattenedCoeff(keyIndices[
"ND"], 1);
133 for (
unsigned int i = 1; i < keyIndices[
"ND"]; ++i)
135 flattenedCoeff[i - 1] =
static_cast<int32_t
>(paramsShape[i]);
137 for (
unsigned int i = keyIndices[
"ND"] - 1; i > 0; --i)
139 flattenedCoeff[i - 1] *= flattenedCoeff[i];
142 flattenedCoeff_Info.
SetShape({ keyIndices[
"ND"] });
143 BuildArmComputeTensor(m_FlattenedCoeff, flattenedCoeff_Info);
144 armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_FlattenedCoeff);
146 "flattenedCoeff must be same data type as m_FlattenedCoeff");
147 CopyArmComputeClTensorData<int32_t>(m_FlattenedCoeff, flattenedCoeff.data());
151 outputMul_Info.
SetShape({ keyIndices[
"W"], keyIndices[
"ND"] });
152 BuildArmComputeTensor(m_OutputMul, outputMul_Info);
153 armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_OutputMul);
156 m_MulLayer.configure(clCompileContext,
161 arm_compute::ConvertPolicy::WRAP,
162 arm_compute::RoundingPolicy::TO_ZERO,
163 arm_compute::ActivationLayerInfo());
166 const std::vector<unsigned int> armnnReduceAxes(1, 1);
170 m_ReduceSumLayer.configure(clCompileContext,
173 static_cast<unsigned int>(coords[0]),
174 arm_compute::ReductionOperation::SUM,
179 paramsInfo.
SetShape({ keyIndices[
"K"], keyIndices[
"C"] });
180 input.info()->set_tensor_shape(BuildArmComputeTensorShape(paramsInfo.
GetShape()));
185 outputGather_Info.
SetShape({ keyIndices[
"W"], keyIndices[
"C"] });
186 BuildArmComputeTensor(m_OutputGather, outputGather_Info);
187 armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_OutputGather);
191 m_GatherLayer.configure(clCompileContext, &input, &m_FlattenedIndices, &m_OutputGather, aclAxis);
195 m_ReshapeLayer.configure(clCompileContext, &m_OutputGather, &output);