ArmNN
 25.11
Loading...
Searching...
No Matches
GatherOperator.hpp File Reference
Include dependency graph for GatherOperator.hpp:
This graph shows which files directly or indirectly include this file:

Go to the source code of this file.

Functions

TosaSerializationBasicBlock * ConvertGatherToTosaOperator (const Layer *layer, const std::vector< const TensorInfo * > &inputs, const std::vector< const TensorInfo * > &outputs, const GatherDescriptor *gatherDescriptor)

Function Documentation

◆ ConvertGatherToTosaOperator()

TosaSerializationBasicBlock * ConvertGatherToTosaOperator ( const Layer * layer,
const std::vector< const TensorInfo * > & inputs,
const std::vector< const TensorInfo * > & outputs,
const GatherDescriptor * gatherDescriptor )

Definition at line 16 of file GatherOperator.cpp.

20{
21 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(inputs.size() == 2,
22 "ConvertGatherToTosaOperator: Gather must have two inputs");
23
24 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(outputs.size() == 1,
25 "ConvertGatherToTosaOperator: Gather must have only one output");
26
27 unsigned int paramsRank = inputs[0]->GetNumDimensions();
28 unsigned int indicesRank = inputs[1]->GetNumDimensions();
29
30 int batch_dims = 0; // ArmNN does not currently support this parameter, setting it to the default value.
31
32 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(gatherDescriptor->m_Axis >= 0 &&
33 gatherDescriptor->m_Axis < static_cast<int32_t>(paramsRank),
34 "ConvertGatherToTosaOperator: axis must be < values rank");
35
36 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(batch_dims <= static_cast<int32_t>(indicesRank),
37 "ConvertGatherToTosaOperator: batch dimensions must be <= indices rank");
38
39 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(gatherDescriptor->m_Axis >= batch_dims,
40 "ConvertGatherToTosaOperator: axis must be >= batch dimensions.");
41
42 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(inputs[0]->GetDataType() != DataType::QAsymmU8,
43 "ConvertGatherToTosaOperator: Tosa gather does not support unsigned types.");
44
45 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(inputs[1]->GetDataType() != DataType::Signed64,
46 "ConvertGatherToTosaOperator: Tosa gather does not support int 64 indices.");
47
48 unsigned int axis = static_cast<unsigned int>(gatherDescriptor->m_Axis);
49 unsigned int batchDims = static_cast<unsigned int>(batch_dims);
50
51 std::string inputParamsName = std::string("input_0_params");
52 std::string inputIndicesName = std::string("input_1_indices");
53 std::string outputTransposeParamsName = std::string("intermediate_0_transpose_params") + GetUniqueTosaMappingID();
54 std::string outputReshapeParamsName = std::string("intermediate_1_reshape_params") + GetUniqueTosaMappingID();
55 std::string outputReshapeIndicesName = std::string("intermediate_2_reshape_indices") + GetUniqueTosaMappingID();
56 std::string outputGatherName = std::string("intermediate_3_gather") + GetUniqueTosaMappingID();
57 std::string outputReshapeGatherName = std::string("intermediate_4_reshape_gather") + GetUniqueTosaMappingID();
58 std::string outputName = std::string("output_0");
59
60 std::string blockName = std::string("Op_GATHER_block_") + GetUniqueTosaMappingID();
61
62 std::vector<TosaSerializationTensor*> tensors;
63 std::vector<TosaSerializationOperator*> operators;
64
65 // If a layer is present then the block will be used for execution, so input and output names need to be determined
66 // using the previous and following layers so the graph is connected correctly. For validation this doesn't matter.
67 if(layer)
68 {
69 // Get the layer connected to the input slot and determine unique tensor names.
70 inputParamsName = GenerateUniqueInputName(layer->GetInputSlot(0));
71 inputIndicesName = GenerateUniqueInputName(layer->GetInputSlot(1));
72 outputName = GenerateUniqueOutputName(*layer);
73 }
74
75 auto inputParamsDType = ArmNNToDType(inputs[0]->GetDataType());
76 auto inputIndicesDType = ArmNNToDType(inputs[1]->GetDataType());
77
78 // Only add input tensors if the connected layer is an input layer,
79 // as intermediate or constant tensors will be created separately.
80 // There also can't be duplicate tensors.
81 if(inputParamsName.find("input_") != std::string::npos)
82 {
83 std::vector<int32_t> inputParamsShape = GetTosaTensorShape(inputs[0]->GetShape());
84 tensors.push_back(new TosaSerializationTensor(inputParamsName, inputParamsShape, inputParamsDType, {}));
85 }
86 if(inputIndicesName.find("input_") != std::string::npos)
87 {
88 std::vector<int32_t> inputIndicesShape = GetTosaTensorShape(inputs[1]->GetShape());
89 tensors.push_back(new TosaSerializationTensor(inputIndicesName, inputIndicesShape, inputIndicesDType, {}));
90 }
91
92 std::vector<int32_t> outputShape0 = GetTosaTensorShape(outputs[0]->GetShape());
93 DType outputDType0 = ArmNNToDType(outputs[0]->GetDataType());
94 tensors.push_back(new TosaSerializationTensor(outputName, outputShape0, outputDType0, {}));
95
96 std::vector<int32_t> paramsShape = GetTosaTensorShape(inputs[0]->GetShape());
97 std::vector<int32_t> indicesShape = GetTosaTensorShape(inputs[1]->GetShape());
98
99 // Parameters needed to calculate output shapes and transpose permutations
100 std::vector<int32_t> paramsBatch;
101 std::vector<int32_t> paramsIndices;
102 std::vector<int32_t> paramsLeftChannels;
103 std::vector<int32_t> paramsRightChannels;
104
105 std::vector<int32_t> paramsIdxBatch;
106 std::vector<int32_t> paramsIdxIndices;
107 std::vector<int32_t> paramsIdxLeftChannels;
108 std::vector<int32_t> paramsIdxRightChannels;
109
110 for (unsigned int i = 0; i < paramsRank; i++)
111 {
112 if (i < batchDims && i < axis)
113 {
114 paramsBatch.push_back(paramsShape[i]);
115 paramsIdxBatch.push_back(static_cast<int32_t>(i));
116 }
117 else if (i < axis)
118 {
119 paramsLeftChannels.push_back(paramsShape[i]);
120 paramsIdxLeftChannels.push_back(static_cast<int32_t>(i));
121 }
122 else if (i < (axis + 1))
123 {
124 paramsIndices.push_back(paramsShape[i]);
125 paramsIdxIndices.push_back(static_cast<int32_t>(i));
126 }
127 else
128 {
129 paramsRightChannels.push_back(paramsShape[i]);
130 paramsIdxRightChannels.push_back(static_cast<int32_t>(i));
131 }
132 }
133
134 // Calculate N, K, W, C
135 // N: number of batches
136 // W: number of indices in each batch
137 // K: range of each index
138 // C: number of channels for each index
139 std::vector<int32_t> paramsLow;
140 std::vector<int32_t> paramsMid;
141 std::vector<int32_t> paramsHigh;
142 std::vector<int32_t> indicesMid;
143
144 // Copy the first batchDims number of paramsShape values to paramsLow
145 for (unsigned int i = 0; i < batchDims; i++)
146 {
147 paramsLow.push_back(paramsShape[i]);
148 }
149 // Starting at batchDims index, copy the next (axis - batchDims) number of paramsShape values to paramsMid
150 for (unsigned int i = 0; i < (axis - batchDims); i++)
151 {
152 paramsMid.push_back(paramsShape[batchDims + i]);
153 }
154 // Starting at (axis + 1) index, copy the next (paramsRank - axis - 1) number of paramsShape values to paramsHigh
155 for (unsigned int i = 0; i < (paramsRank - axis - 1); i++)
156 {
157 paramsHigh.push_back(paramsShape[axis + 1 + i]);
158 }
159 // Starting at batchDims index, copy the next (indicesRank - batchDims) number of indicesShape values to indicesMid
160 for (unsigned int i = 0; i < (indicesRank - batchDims); i++)
161 {
162 indicesMid.push_back(indicesShape[batchDims + i]);
163 }
164
165 auto lowProduct = static_cast<int32_t>(std::accumulate(std::begin(paramsMid),
166 std::end(paramsMid),
167 1,
168 std::multiplies<>() ));
169 auto highProduct = static_cast<int32_t>(std::accumulate(std::begin(paramsHigh),
170 std::end(paramsHigh),
171 1,
172 std::multiplies<>() ));
173
174 auto N = static_cast<int32_t>(std::accumulate(std::begin(paramsLow),
175 std::end(paramsLow),
176 1,
177 std::multiplies<>() ));
178 auto W = static_cast<int32_t>(std::accumulate(std::begin(indicesMid),
179 std::end(indicesMid),
180 1,
181 std::multiplies<>() ));
182 auto K = paramsShape[axis];
183 auto C = lowProduct * highProduct;
184
185 // Parameters needed for input transpose
186 std::vector<int32_t> inputTransposePermutation;
187 std::vector<int32_t> inputTransposeShape;
188 for (unsigned int i = 0; i < paramsBatch.size(); i++)
189 {
190 inputTransposePermutation.push_back(paramsIdxBatch[i]);
191 inputTransposeShape.push_back(paramsBatch[i]);
192 }
193 for (unsigned int i = 0; i < paramsIndices.size(); i++)
194 {
195 inputTransposePermutation.push_back(paramsIdxIndices[i]);
196 inputTransposeShape.push_back(paramsIndices[i]);
197 }
198 for (unsigned int i = 0; i < paramsLeftChannels.size(); i++)
199 {
200 inputTransposePermutation.push_back(paramsIdxLeftChannels[i]);
201 inputTransposeShape.push_back(paramsLeftChannels[i]);
202 }
203 for (unsigned int i = 0; i < paramsRightChannels.size(); i++)
204 {
205 inputTransposePermutation.push_back(paramsIdxRightChannels[i]);
206 inputTransposeShape.push_back(paramsRightChannels[i]);
207 }
208
209 // Parameters needed for result/output transpose
210 std::vector<int32_t> resultReshapeShape;
211 resultReshapeShape.insert(resultReshapeShape.end(), indicesShape.begin(), indicesShape.end());
212 resultReshapeShape.insert(resultReshapeShape.end(), paramsLeftChannels.begin(), paramsLeftChannels.end());
213 resultReshapeShape.insert(resultReshapeShape.end(), paramsRightChannels.begin(), paramsRightChannels.end());
214
215 std::vector<int32_t> resultTransposePerm;
216 for (unsigned int i = 0; i < batchDims; i++)
217 {
218 resultTransposePerm.push_back(static_cast<int32_t>(i));
219 }
220 for (unsigned int i = 0; i < paramsLeftChannels.size(); i++)
221 {
222 resultTransposePerm.push_back(static_cast<int32_t>(i + inputs[1]->GetNumDimensions()));
223 }
224 for (unsigned int i = batchDims; i < inputs[1]->GetNumDimensions(); i++)
225 {
226 resultTransposePerm.push_back(static_cast<int32_t>(i));
227 }
228 for (unsigned int i = 0; i < paramsRightChannels.size(); i++)
229 {
230 resultTransposePerm.push_back(static_cast<int32_t>(i + inputs[1]->GetNumDimensions() +
231 paramsLeftChannels.size()));
232 }
233
234 std::vector<int32_t> tosaValuesShape = {N, K, C};
235 std::vector<int32_t> tosaIndicesShape = {N, W};
236 std::vector<int32_t> tosaGatherResultShape = {N, W, C};
237
238 // 1. Transpose params values. This operation is only needed if the axis is not 0.
239 if (axis > 0)
240 {
241 tensors.emplace_back(new TosaSerializationTensor(outputTransposeParamsName,
242 inputTransposeShape,
243 inputParamsDType,
244 {}));
245
246 TosaTransposeAttribute transposeInputAttribute(inputTransposePermutation);
247
248 auto *transposeInputOp = new TosaSerializationOperator(Op_TRANSPOSE,
249 Attribute_TransposeAttribute,
250 &transposeInputAttribute,
251 {inputParamsName},
252 {outputTransposeParamsName});
253 operators.push_back(transposeInputOp);
254 }
255
256 // 2. Reshape params
257 std::string& reshapeOpInputParamsName = axis > 0 ? outputTransposeParamsName : inputParamsName;
258
259 tensors.emplace_back(new TosaSerializationTensor(outputReshapeParamsName,
260 tosaValuesShape,
261 inputParamsDType,
262 {}));
263
264 TosaReshapeAttribute reshapeValuesAttribute(tosaValuesShape);
265
266 auto* reshapeInputParamsOp = new TosaSerializationOperator(Op_RESHAPE,
267 Attribute_ReshapeAttribute,
268 &reshapeValuesAttribute,
269 {reshapeOpInputParamsName},
270 {outputReshapeParamsName});
271 operators.push_back(reshapeInputParamsOp);
272
273 // 3. Reshape indices
274 tensors.emplace_back(new TosaSerializationTensor(outputReshapeIndicesName,
275 tosaIndicesShape,
276 inputIndicesDType,
277 {}));
278
279 TosaReshapeAttribute reshapeIndicesAttribute(tosaIndicesShape);
280
281 auto* reshapeInputIndicesOp = new TosaSerializationOperator(Op_RESHAPE,
282 Attribute_ReshapeAttribute,
283 &reshapeIndicesAttribute,
284 {inputIndicesName},
285 {outputReshapeIndicesName});
286 operators.push_back(reshapeInputIndicesOp);
287
288 // 4. Gather params, indices
289 tensors.emplace_back(new TosaSerializationTensor(outputGatherName,
290 tosaGatherResultShape,
291 inputParamsDType,
292 {}));
293
294 auto* gatherOp = new TosaSerializationOperator(Op_GATHER,
295 Attribute_NONE,
296 nullptr,
297 {outputReshapeParamsName, outputReshapeIndicesName},
298 {outputGatherName});
299 operators.push_back(gatherOp);
300
301 // 5. Reshape gather output
302 if (axis > 0)
303 {
304 // If a Transpose op is needed below, an additional tensor is needed to store the reshape output.
305 tensors.emplace_back(new TosaSerializationTensor(outputReshapeGatherName,
306 resultReshapeShape,
307 outputDType0,
308 {}));
309 }
310
311 std::string& reshapeOpOutputName = axis > 0 ? outputReshapeGatherName : outputName;
312
313 TosaReshapeAttribute reshapeGatherAttribute(resultReshapeShape);
314
315 auto* reshapeGatherOutputOp = new TosaSerializationOperator(Op_RESHAPE,
316 Attribute_ReshapeAttribute,
317 &reshapeGatherAttribute,
318 {outputGatherName},
319 {reshapeOpOutputName});
320 operators.push_back(reshapeGatherOutputOp);
321
322 // 6. Transpose result output. This operator is only needed if the axis is not 0
323 if (axis > 0)
324 {
325 TosaTransposeAttribute transposeOutputAttribute(resultTransposePerm);
326
327 auto* transposeOutputOp = new TosaSerializationOperator(Op_TRANSPOSE,
328 Attribute_TransposeAttribute,
329 &transposeOutputAttribute,
330 {outputReshapeGatherName},
331 {outputName});
332 operators.push_back(transposeOutputOp);
333 }
334
335 // operatorInputNames/operatorOutputNames end up being the same as
336 // blockInputNames/blockOutputNames for one-to-one ArmNN to TOSA mappings
337 return new TosaSerializationBasicBlock(blockName, // name
338 mainName, // region name
339 operators, // operators
340 tensors, // tensors
341 {inputParamsName, inputIndicesName}, // inputs
342 {outputName}); // outputs
343}
#define ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(_cond, _str)
std::string GenerateUniqueOutputName(const Layer &layer, uint32_t layerSlot=0)
const std::string mainName
DType ArmNNToDType(const DataType &type)
std::string GenerateUniqueInputName(const armnn::InputSlot &slot)
std::string GetUniqueTosaMappingID()
std::vector< int32_t > GetTosaTensorShape(const TensorShape &shape)
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition Layer.hpp:337
int32_t m_Axis
The axis in params to gather indices from.

References ARMNN_THROW_INVALIDARG_MSG_IF_FALSE, ArmNNToDType(), GenerateUniqueInputName(), GenerateUniqueOutputName(), Layer::GetInputSlot(), GetTosaTensorShape(), GetUniqueTosaMappingID(), GatherDescriptor::m_Axis, and mainName.

Referenced by GetTosaMapping().