22 "ConvertGatherToTosaOperator: Gather must have two inputs");
25 "ConvertGatherToTosaOperator: Gather must have only one output");
27 unsigned int paramsRank = inputs[0]->GetNumDimensions();
28 unsigned int indicesRank = inputs[1]->GetNumDimensions();
33 gatherDescriptor->
m_Axis <
static_cast<int32_t
>(paramsRank),
34 "ConvertGatherToTosaOperator: axis must be < values rank");
37 "ConvertGatherToTosaOperator: batch dimensions must be <= indices rank");
40 "ConvertGatherToTosaOperator: axis must be >= batch dimensions.");
43 "ConvertGatherToTosaOperator: Tosa gather does not support unsigned types.");
46 "ConvertGatherToTosaOperator: Tosa gather does not support int 64 indices.");
48 unsigned int axis =
static_cast<unsigned int>(gatherDescriptor->
m_Axis);
49 unsigned int batchDims =
static_cast<unsigned int>(batch_dims);
51 std::string inputParamsName = std::string(
"input_0_params");
52 std::string inputIndicesName = std::string(
"input_1_indices");
53 std::string outputTransposeParamsName = std::string(
"intermediate_0_transpose_params") +
GetUniqueTosaMappingID();
54 std::string outputReshapeParamsName = std::string(
"intermediate_1_reshape_params") +
GetUniqueTosaMappingID();
55 std::string outputReshapeIndicesName = std::string(
"intermediate_2_reshape_indices") +
GetUniqueTosaMappingID();
57 std::string outputReshapeGatherName = std::string(
"intermediate_4_reshape_gather") +
GetUniqueTosaMappingID();
58 std::string outputName = std::string(
"output_0");
62 std::vector<TosaSerializationTensor*> tensors;
63 std::vector<TosaSerializationOperator*> operators;
75 auto inputParamsDType =
ArmNNToDType(inputs[0]->GetDataType());
76 auto inputIndicesDType =
ArmNNToDType(inputs[1]->GetDataType());
81 if(inputParamsName.find(
"input_") != std::string::npos)
84 tensors.push_back(
new TosaSerializationTensor(inputParamsName, inputParamsShape, inputParamsDType, {}));
86 if(inputIndicesName.find(
"input_") != std::string::npos)
89 tensors.push_back(
new TosaSerializationTensor(inputIndicesName, inputIndicesShape, inputIndicesDType, {}));
93 DType outputDType0 =
ArmNNToDType(outputs[0]->GetDataType());
94 tensors.push_back(
new TosaSerializationTensor(outputName, outputShape0, outputDType0, {}));
100 std::vector<int32_t> paramsBatch;
101 std::vector<int32_t> paramsIndices;
102 std::vector<int32_t> paramsLeftChannels;
103 std::vector<int32_t> paramsRightChannels;
105 std::vector<int32_t> paramsIdxBatch;
106 std::vector<int32_t> paramsIdxIndices;
107 std::vector<int32_t> paramsIdxLeftChannels;
108 std::vector<int32_t> paramsIdxRightChannels;
110 for (
unsigned int i = 0; i < paramsRank; i++)
112 if (i < batchDims && i < axis)
114 paramsBatch.push_back(paramsShape[i]);
115 paramsIdxBatch.push_back(
static_cast<int32_t
>(i));
119 paramsLeftChannels.push_back(paramsShape[i]);
120 paramsIdxLeftChannels.push_back(
static_cast<int32_t
>(i));
122 else if (i < (axis + 1))
124 paramsIndices.push_back(paramsShape[i]);
125 paramsIdxIndices.push_back(
static_cast<int32_t
>(i));
129 paramsRightChannels.push_back(paramsShape[i]);
130 paramsIdxRightChannels.push_back(
static_cast<int32_t
>(i));
139 std::vector<int32_t> paramsLow;
140 std::vector<int32_t> paramsMid;
141 std::vector<int32_t> paramsHigh;
142 std::vector<int32_t> indicesMid;
145 for (
unsigned int i = 0; i < batchDims; i++)
147 paramsLow.push_back(paramsShape[i]);
150 for (
unsigned int i = 0; i < (axis - batchDims); i++)
152 paramsMid.push_back(paramsShape[batchDims + i]);
155 for (
unsigned int i = 0; i < (paramsRank - axis - 1); i++)
157 paramsHigh.push_back(paramsShape[axis + 1 + i]);
160 for (
unsigned int i = 0; i < (indicesRank - batchDims); i++)
162 indicesMid.push_back(indicesShape[batchDims + i]);
165 auto lowProduct =
static_cast<int32_t
>(std::accumulate(std::begin(paramsMid),
168 std::multiplies<>() ));
169 auto highProduct =
static_cast<int32_t
>(std::accumulate(std::begin(paramsHigh),
170 std::end(paramsHigh),
172 std::multiplies<>() ));
174 auto N =
static_cast<int32_t
>(std::accumulate(std::begin(paramsLow),
177 std::multiplies<>() ));
178 auto W =
static_cast<int32_t
>(std::accumulate(std::begin(indicesMid),
179 std::end(indicesMid),
181 std::multiplies<>() ));
182 auto K = paramsShape[axis];
183 auto C = lowProduct * highProduct;
186 std::vector<int32_t> inputTransposePermutation;
187 std::vector<int32_t> inputTransposeShape;
188 for (
unsigned int i = 0; i < paramsBatch.size(); i++)
190 inputTransposePermutation.push_back(paramsIdxBatch[i]);
191 inputTransposeShape.push_back(paramsBatch[i]);
193 for (
unsigned int i = 0; i < paramsIndices.size(); i++)
195 inputTransposePermutation.push_back(paramsIdxIndices[i]);
196 inputTransposeShape.push_back(paramsIndices[i]);
198 for (
unsigned int i = 0; i < paramsLeftChannels.size(); i++)
200 inputTransposePermutation.push_back(paramsIdxLeftChannels[i]);
201 inputTransposeShape.push_back(paramsLeftChannels[i]);
203 for (
unsigned int i = 0; i < paramsRightChannels.size(); i++)
205 inputTransposePermutation.push_back(paramsIdxRightChannels[i]);
206 inputTransposeShape.push_back(paramsRightChannels[i]);
210 std::vector<int32_t> resultReshapeShape;
211 resultReshapeShape.insert(resultReshapeShape.end(), indicesShape.begin(), indicesShape.end());
212 resultReshapeShape.insert(resultReshapeShape.end(), paramsLeftChannels.begin(), paramsLeftChannels.end());
213 resultReshapeShape.insert(resultReshapeShape.end(), paramsRightChannels.begin(), paramsRightChannels.end());
215 std::vector<int32_t> resultTransposePerm;
216 for (
unsigned int i = 0; i < batchDims; i++)
218 resultTransposePerm.push_back(
static_cast<int32_t
>(i));
220 for (
unsigned int i = 0; i < paramsLeftChannels.size(); i++)
222 resultTransposePerm.push_back(
static_cast<int32_t
>(i + inputs[1]->GetNumDimensions()));
224 for (
unsigned int i = batchDims; i < inputs[1]->GetNumDimensions(); i++)
226 resultTransposePerm.push_back(
static_cast<int32_t
>(i));
228 for (
unsigned int i = 0; i < paramsRightChannels.size(); i++)
230 resultTransposePerm.push_back(
static_cast<int32_t
>(i + inputs[1]->GetNumDimensions() +
231 paramsLeftChannels.size()));
234 std::vector<int32_t> tosaValuesShape = {N, K, C};
235 std::vector<int32_t> tosaIndicesShape = {N, W};
236 std::vector<int32_t> tosaGatherResultShape = {N, W, C};
241 tensors.emplace_back(
new TosaSerializationTensor(outputTransposeParamsName,
246 TosaTransposeAttribute transposeInputAttribute(inputTransposePermutation);
248 auto *transposeInputOp =
new TosaSerializationOperator(Op_TRANSPOSE,
249 Attribute_TransposeAttribute,
250 &transposeInputAttribute,
252 {outputTransposeParamsName});
253 operators.push_back(transposeInputOp);
257 std::string& reshapeOpInputParamsName = axis > 0 ? outputTransposeParamsName : inputParamsName;
259 tensors.emplace_back(
new TosaSerializationTensor(outputReshapeParamsName,
264 TosaReshapeAttribute reshapeValuesAttribute(tosaValuesShape);
266 auto* reshapeInputParamsOp =
new TosaSerializationOperator(Op_RESHAPE,
267 Attribute_ReshapeAttribute,
268 &reshapeValuesAttribute,
269 {reshapeOpInputParamsName},
270 {outputReshapeParamsName});
271 operators.push_back(reshapeInputParamsOp);
274 tensors.emplace_back(
new TosaSerializationTensor(outputReshapeIndicesName,
279 TosaReshapeAttribute reshapeIndicesAttribute(tosaIndicesShape);
281 auto* reshapeInputIndicesOp =
new TosaSerializationOperator(Op_RESHAPE,
282 Attribute_ReshapeAttribute,
283 &reshapeIndicesAttribute,
285 {outputReshapeIndicesName});
286 operators.push_back(reshapeInputIndicesOp);
289 tensors.emplace_back(
new TosaSerializationTensor(outputGatherName,
290 tosaGatherResultShape,
294 auto* gatherOp =
new TosaSerializationOperator(Op_GATHER,
297 {outputReshapeParamsName, outputReshapeIndicesName},
299 operators.push_back(gatherOp);
305 tensors.emplace_back(
new TosaSerializationTensor(outputReshapeGatherName,
311 std::string& reshapeOpOutputName = axis > 0 ? outputReshapeGatherName : outputName;
313 TosaReshapeAttribute reshapeGatherAttribute(resultReshapeShape);
315 auto* reshapeGatherOutputOp =
new TosaSerializationOperator(Op_RESHAPE,
316 Attribute_ReshapeAttribute,
317 &reshapeGatherAttribute,
319 {reshapeOpOutputName});
320 operators.push_back(reshapeGatherOutputOp);
325 TosaTransposeAttribute transposeOutputAttribute(resultTransposePerm);
327 auto* transposeOutputOp =
new TosaSerializationOperator(Op_TRANSPOSE,
328 Attribute_TransposeAttribute,
329 &transposeOutputAttribute,
330 {outputReshapeGatherName},
332 operators.push_back(transposeOutputOp);
337 return new TosaSerializationBasicBlock(blockName,
341 {inputParamsName, inputIndicesName},
#define ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(_cond, _str)
std::string GenerateUniqueOutputName(const Layer &layer, uint32_t layerSlot=0)
const std::string mainName
DType ArmNNToDType(const DataType &type)
std::vector< int32_t > GetTosaTensorShape(const TensorShape &shape)
std::string GenerateUniqueInputName(const armnn::InputSlot &slot)
std::string GetUniqueTosaMappingID()
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
int32_t m_Axis
The axis in params to gather indices from.