20{
22 "ConvertGatherToTosaOperator: Gather must have two inputs");
23
25 "ConvertGatherToTosaOperator: Gather must have only one output");
26
27 unsigned int paramsRank = inputs[0]->GetNumDimensions();
28 unsigned int indicesRank = inputs[1]->GetNumDimensions();
29
30 int batch_dims = 0;
31
33 gatherDescriptor->
m_Axis <
static_cast<int32_t
>(paramsRank),
34 "ConvertGatherToTosaOperator: axis must be < values rank");
35
37 "ConvertGatherToTosaOperator: batch dimensions must be <= indices rank");
38
40 "ConvertGatherToTosaOperator: axis must be >= batch dimensions.");
41
43 "ConvertGatherToTosaOperator: Tosa gather does not support unsigned types.");
44
46 "ConvertGatherToTosaOperator: Tosa gather does not support int 64 indices.");
47
48 unsigned int axis =
static_cast<unsigned int>(gatherDescriptor->
m_Axis);
49 unsigned int batchDims = static_cast<unsigned int>(batch_dims);
50
51 std::string inputParamsName = std::string("input_0_params");
52 std::string inputIndicesName = std::string("input_1_indices");
53 std::string outputTransposeParamsName = std::string(
"intermediate_0_transpose_params") +
GetUniqueTosaMappingID();
54 std::string outputReshapeParamsName = std::string(
"intermediate_1_reshape_params") +
GetUniqueTosaMappingID();
55 std::string outputReshapeIndicesName = std::string(
"intermediate_2_reshape_indices") +
GetUniqueTosaMappingID();
57 std::string outputReshapeGatherName = std::string(
"intermediate_4_reshape_gather") +
GetUniqueTosaMappingID();
58 std::string outputName = std::string("output_0");
59
61
62 std::vector<TosaSerializationTensor*> tensors;
63 std::vector<TosaSerializationOperator*> operators;
64
65
66
67 if(layer)
68 {
69
73 }
74
75 auto inputParamsDType =
ArmNNToDType(inputs[0]->GetDataType());
76 auto inputIndicesDType =
ArmNNToDType(inputs[1]->GetDataType());
77
78
79
80
81 if(inputParamsName.find("input_") != std::string::npos)
82 {
84 tensors.push_back(new TosaSerializationTensor(inputParamsName, inputParamsShape, inputParamsDType, {}));
85 }
86 if(inputIndicesName.find("input_") != std::string::npos)
87 {
89 tensors.push_back(new TosaSerializationTensor(inputIndicesName, inputIndicesShape, inputIndicesDType, {}));
90 }
91
93 DType outputDType0 =
ArmNNToDType(outputs[0]->GetDataType());
94 tensors.push_back(new TosaSerializationTensor(outputName, outputShape0, outputDType0, {}));
95
98
99
100 std::vector<int32_t> paramsBatch;
101 std::vector<int32_t> paramsIndices;
102 std::vector<int32_t> paramsLeftChannels;
103 std::vector<int32_t> paramsRightChannels;
104
105 std::vector<int32_t> paramsIdxBatch;
106 std::vector<int32_t> paramsIdxIndices;
107 std::vector<int32_t> paramsIdxLeftChannels;
108 std::vector<int32_t> paramsIdxRightChannels;
109
110 for (unsigned int i = 0; i < paramsRank; i++)
111 {
112 if (i < batchDims && i < axis)
113 {
114 paramsBatch.push_back(paramsShape[i]);
115 paramsIdxBatch.push_back(static_cast<int32_t>(i));
116 }
117 else if (i < axis)
118 {
119 paramsLeftChannels.push_back(paramsShape[i]);
120 paramsIdxLeftChannels.push_back(static_cast<int32_t>(i));
121 }
122 else if (i < (axis + 1))
123 {
124 paramsIndices.push_back(paramsShape[i]);
125 paramsIdxIndices.push_back(static_cast<int32_t>(i));
126 }
127 else
128 {
129 paramsRightChannels.push_back(paramsShape[i]);
130 paramsIdxRightChannels.push_back(static_cast<int32_t>(i));
131 }
132 }
133
134
135
136
137
138
139 std::vector<int32_t> paramsLow;
140 std::vector<int32_t> paramsMid;
141 std::vector<int32_t> paramsHigh;
142 std::vector<int32_t> indicesMid;
143
144
145 for (unsigned int i = 0; i < batchDims; i++)
146 {
147 paramsLow.push_back(paramsShape[i]);
148 }
149
150 for (unsigned int i = 0; i < (axis - batchDims); i++)
151 {
152 paramsMid.push_back(paramsShape[batchDims + i]);
153 }
154
155 for (unsigned int i = 0; i < (paramsRank - axis - 1); i++)
156 {
157 paramsHigh.push_back(paramsShape[axis + 1 + i]);
158 }
159
160 for (unsigned int i = 0; i < (indicesRank - batchDims); i++)
161 {
162 indicesMid.push_back(indicesShape[batchDims + i]);
163 }
164
165 auto lowProduct = static_cast<int32_t>(std::accumulate(std::begin(paramsMid),
166 std::end(paramsMid),
167 1,
168 std::multiplies<>() ));
169 auto highProduct = static_cast<int32_t>(std::accumulate(std::begin(paramsHigh),
170 std::end(paramsHigh),
171 1,
172 std::multiplies<>() ));
173
174 auto N = static_cast<int32_t>(std::accumulate(std::begin(paramsLow),
175 std::end(paramsLow),
176 1,
177 std::multiplies<>() ));
178 auto W = static_cast<int32_t>(std::accumulate(std::begin(indicesMid),
179 std::end(indicesMid),
180 1,
181 std::multiplies<>() ));
182 auto K = paramsShape[axis];
183 auto C = lowProduct * highProduct;
184
185
186 std::vector<int32_t> inputTransposePermutation;
187 std::vector<int32_t> inputTransposeShape;
188 for (unsigned int i = 0; i < paramsBatch.size(); i++)
189 {
190 inputTransposePermutation.push_back(paramsIdxBatch[i]);
191 inputTransposeShape.push_back(paramsBatch[i]);
192 }
193 for (unsigned int i = 0; i < paramsIndices.size(); i++)
194 {
195 inputTransposePermutation.push_back(paramsIdxIndices[i]);
196 inputTransposeShape.push_back(paramsIndices[i]);
197 }
198 for (unsigned int i = 0; i < paramsLeftChannels.size(); i++)
199 {
200 inputTransposePermutation.push_back(paramsIdxLeftChannels[i]);
201 inputTransposeShape.push_back(paramsLeftChannels[i]);
202 }
203 for (unsigned int i = 0; i < paramsRightChannels.size(); i++)
204 {
205 inputTransposePermutation.push_back(paramsIdxRightChannels[i]);
206 inputTransposeShape.push_back(paramsRightChannels[i]);
207 }
208
209
210 std::vector<int32_t> resultReshapeShape;
211 resultReshapeShape.insert(resultReshapeShape.end(), indicesShape.begin(), indicesShape.end());
212 resultReshapeShape.insert(resultReshapeShape.end(), paramsLeftChannels.begin(), paramsLeftChannels.end());
213 resultReshapeShape.insert(resultReshapeShape.end(), paramsRightChannels.begin(), paramsRightChannels.end());
214
215 std::vector<int32_t> resultTransposePerm;
216 for (unsigned int i = 0; i < batchDims; i++)
217 {
218 resultTransposePerm.push_back(static_cast<int32_t>(i));
219 }
220 for (unsigned int i = 0; i < paramsLeftChannels.size(); i++)
221 {
222 resultTransposePerm.push_back(static_cast<int32_t>(i + inputs[1]->GetNumDimensions()));
223 }
224 for (unsigned int i = batchDims; i < inputs[1]->GetNumDimensions(); i++)
225 {
226 resultTransposePerm.push_back(static_cast<int32_t>(i));
227 }
228 for (unsigned int i = 0; i < paramsRightChannels.size(); i++)
229 {
230 resultTransposePerm.push_back(static_cast<int32_t>(i + inputs[1]->GetNumDimensions() +
231 paramsLeftChannels.size()));
232 }
233
234 std::vector<int32_t> tosaValuesShape = {N, K, C};
235 std::vector<int32_t> tosaIndicesShape = {N, W};
236 std::vector<int32_t> tosaGatherResultShape = {N, W, C};
237
238
239 if (axis > 0)
240 {
241 tensors.emplace_back(new TosaSerializationTensor(outputTransposeParamsName,
242 inputTransposeShape,
243 inputParamsDType,
244 {}));
245
246 TosaTransposeAttribute transposeInputAttribute(inputTransposePermutation);
247
248 auto *transposeInputOp = new TosaSerializationOperator(Op_TRANSPOSE,
249 Attribute_TransposeAttribute,
250 &transposeInputAttribute,
251 {inputParamsName},
252 {outputTransposeParamsName});
253 operators.push_back(transposeInputOp);
254 }
255
256
257 std::string& reshapeOpInputParamsName = axis > 0 ? outputTransposeParamsName : inputParamsName;
258
259 tensors.emplace_back(new TosaSerializationTensor(outputReshapeParamsName,
260 tosaValuesShape,
261 inputParamsDType,
262 {}));
263
264 TosaReshapeAttribute reshapeValuesAttribute(tosaValuesShape);
265
266 auto* reshapeInputParamsOp = new TosaSerializationOperator(Op_RESHAPE,
267 Attribute_ReshapeAttribute,
268 &reshapeValuesAttribute,
269 {reshapeOpInputParamsName},
270 {outputReshapeParamsName});
271 operators.push_back(reshapeInputParamsOp);
272
273
274 tensors.emplace_back(new TosaSerializationTensor(outputReshapeIndicesName,
275 tosaIndicesShape,
276 inputIndicesDType,
277 {}));
278
279 TosaReshapeAttribute reshapeIndicesAttribute(tosaIndicesShape);
280
281 auto* reshapeInputIndicesOp = new TosaSerializationOperator(Op_RESHAPE,
282 Attribute_ReshapeAttribute,
283 &reshapeIndicesAttribute,
284 {inputIndicesName},
285 {outputReshapeIndicesName});
286 operators.push_back(reshapeInputIndicesOp);
287
288
289 tensors.emplace_back(new TosaSerializationTensor(outputGatherName,
290 tosaGatherResultShape,
291 inputParamsDType,
292 {}));
293
294 auto* gatherOp = new TosaSerializationOperator(Op_GATHER,
295 Attribute_NONE,
296 nullptr,
297 {outputReshapeParamsName, outputReshapeIndicesName},
298 {outputGatherName});
299 operators.push_back(gatherOp);
300
301
302 if (axis > 0)
303 {
304
305 tensors.emplace_back(new TosaSerializationTensor(outputReshapeGatherName,
306 resultReshapeShape,
307 outputDType0,
308 {}));
309 }
310
311 std::string& reshapeOpOutputName = axis > 0 ? outputReshapeGatherName : outputName;
312
313 TosaReshapeAttribute reshapeGatherAttribute(resultReshapeShape);
314
315 auto* reshapeGatherOutputOp = new TosaSerializationOperator(Op_RESHAPE,
316 Attribute_ReshapeAttribute,
317 &reshapeGatherAttribute,
318 {outputGatherName},
319 {reshapeOpOutputName});
320 operators.push_back(reshapeGatherOutputOp);
321
322
323 if (axis > 0)
324 {
325 TosaTransposeAttribute transposeOutputAttribute(resultTransposePerm);
326
327 auto* transposeOutputOp = new TosaSerializationOperator(Op_TRANSPOSE,
328 Attribute_TransposeAttribute,
329 &transposeOutputAttribute,
330 {outputReshapeGatherName},
331 {outputName});
332 operators.push_back(transposeOutputOp);
333 }
334
335
336
337 return new TosaSerializationBasicBlock(blockName,
339 operators,
340 tensors,
341 {inputParamsName, inputIndicesName},
342 {outputName});
343}
#define ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(_cond, _str)
std::string GenerateUniqueOutputName(const Layer &layer, uint32_t layerSlot=0)
const std::string mainName
DType ArmNNToDType(const DataType &type)
std::string GenerateUniqueInputName(const armnn::InputSlot &slot)
std::string GetUniqueTosaMappingID()
std::vector< int32_t > GetTosaTensorShape(const TensorShape &shape)
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
int32_t m_Axis
The axis in params to gather indices from.