ArmNN
 25.02
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
GatherOperator.hpp File Reference
Include dependency graph for GatherOperator.hpp:
This graph shows which files directly or indirectly include this file:

Go to the source code of this file.

Functions

TosaSerializationBasicBlock * ConvertGatherToTosaOperator (const Layer *layer, const std::vector< const TensorInfo * > &inputs, const std::vector< const TensorInfo * > &outputs, const GatherDescriptor *gatherDescriptor)
 

Function Documentation

◆ ConvertGatherToTosaOperator()

TosaSerializationBasicBlock* ConvertGatherToTosaOperator ( const Layer * layer,
const std::vector< const TensorInfo * > &  inputs,
const std::vector< const TensorInfo * > &  outputs,
const GatherDescriptor * gatherDescriptor 
)

Definition at line 16 of file GatherOperator.cpp.

20 {
21  ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(inputs.size() == 2,
22  "ConvertGatherToTosaOperator: Gather must have two inputs");
23 
24  ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(outputs.size() == 1,
25  "ConvertGatherToTosaOperator: Gather must have only one output");
26 
27  unsigned int paramsRank = inputs[0]->GetNumDimensions();
28  unsigned int indicesRank = inputs[1]->GetNumDimensions();
29 
30  int batch_dims = 0; // ArmNN does not currently support this parameter, setting it to the default value.
31 
32  ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(gatherDescriptor->m_Axis >= 0 &&
33  gatherDescriptor->m_Axis < static_cast<int32_t>(paramsRank),
34  "ConvertGatherToTosaOperator: axis must be < values rank");
35 
36  ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(batch_dims <= static_cast<int32_t>(indicesRank),
37  "ConvertGatherToTosaOperator: batch dimensions must be <= indices rank");
38 
39  ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(gatherDescriptor->m_Axis >= batch_dims,
40  "ConvertGatherToTosaOperator: axis must be >= batch dimensions.");
41 
42  ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(inputs[0]->GetDataType() != DataType::QAsymmU8,
43  "ConvertGatherToTosaOperator: Tosa gather does not support unsigned types.");
44 
45  ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(inputs[1]->GetDataType() != DataType::Signed64,
46  "ConvertGatherToTosaOperator: Tosa gather does not support int 64 indices.");
47 
48  unsigned int axis = static_cast<unsigned int>(gatherDescriptor->m_Axis);
49  unsigned int batchDims = static_cast<unsigned int>(batch_dims);
50 
51  std::string inputParamsName = std::string("input_0_params");
52  std::string inputIndicesName = std::string("input_1_indices");
53  std::string outputTransposeParamsName = std::string("intermediate_0_transpose_params") + GetUniqueTosaMappingID();
54  std::string outputReshapeParamsName = std::string("intermediate_1_reshape_params") + GetUniqueTosaMappingID();
55  std::string outputReshapeIndicesName = std::string("intermediate_2_reshape_indices") + GetUniqueTosaMappingID();
56  std::string outputGatherName = std::string("intermediate_3_gather") + GetUniqueTosaMappingID();
57  std::string outputReshapeGatherName = std::string("intermediate_4_reshape_gather") + GetUniqueTosaMappingID();
58  std::string outputName = std::string("output_0");
59 
60  std::string blockName = std::string("Op_GATHER_block_") + GetUniqueTosaMappingID();
61 
62  std::vector<TosaSerializationTensor*> tensors;
63  std::vector<TosaSerializationOperator*> operators;
64 
65  // If a layer is present then the block will be used for execution, so input and output names need to be determined
66  // using the previous and following layers so the graph is connected correctly. For validation this doesn't matter.
67  if(layer)
68  {
69  // Get the layer connected to the input slot and determine unique tensor names.
70  inputParamsName = GenerateUniqueInputName(layer->GetInputSlot(0));
71  inputIndicesName = GenerateUniqueInputName(layer->GetInputSlot(1));
72  outputName = GenerateUniqueOutputName(*layer);
73  }
74 
75  auto inputParamsDType = ArmNNToDType(inputs[0]->GetDataType());
76  auto inputIndicesDType = ArmNNToDType(inputs[1]->GetDataType());
77 
78  // Only add input tensors if connected layer is an input layer.
79  // As intermediate or constant tensors will be created separately.
80  // There also can't be duplicate tensors.
81  if(inputParamsName.find("input_") != std::string::npos)
82  {
83  std::vector<int32_t> inputParamsShape = GetTosaTensorShape(inputs[0]->GetShape());
84  tensors.push_back(new TosaSerializationTensor(inputParamsName, inputParamsShape, inputParamsDType, {}));
85  }
86  if(inputIndicesName.find("input_") != std::string::npos)
87  {
88  std::vector<int32_t> inputIndicesShape = GetTosaTensorShape(inputs[1]->GetShape());
89  tensors.push_back(new TosaSerializationTensor(inputIndicesName, inputIndicesShape, inputIndicesDType, {}));
90  }
91 
92  std::vector<int32_t> outputShape0 = GetTosaTensorShape(outputs[0]->GetShape());
93  DType outputDType0 = ArmNNToDType(outputs[0]->GetDataType());
94  tensors.push_back(new TosaSerializationTensor(outputName, outputShape0, outputDType0, {}));
95 
96  std::vector<int32_t> paramsShape = GetTosaTensorShape(inputs[0]->GetShape());
97  std::vector<int32_t> indicesShape = GetTosaTensorShape(inputs[1]->GetShape());
98 
99  // Parameters needed to calculate output shapes and transpose permutations
100  std::vector<int32_t> paramsBatch;
101  std::vector<int32_t> paramsIndices;
102  std::vector<int32_t> paramsLeftChannels;
103  std::vector<int32_t> paramsRightChannels;
104 
105  std::vector<int32_t> paramsIdxBatch;
106  std::vector<int32_t> paramsIdxIndices;
107  std::vector<int32_t> paramsIdxLeftChannels;
108  std::vector<int32_t> paramsIdxRightChannels;
109 
110  for (unsigned int i = 0; i < paramsRank; i++)
111  {
112  if (i < batchDims && i < axis)
113  {
114  paramsBatch.push_back(paramsShape[i]);
115  paramsIdxBatch.push_back(static_cast<int32_t>(i));
116  }
117  else if (i < axis)
118  {
119  paramsLeftChannels.push_back(paramsShape[i]);
120  paramsIdxLeftChannels.push_back(static_cast<int32_t>(i));
121  }
122  else if (i < (axis + 1))
123  {
124  paramsIndices.push_back(paramsShape[i]);
125  paramsIdxIndices.push_back(static_cast<int32_t>(i));
126  }
127  else
128  {
129  paramsRightChannels.push_back(paramsShape[i]);
130  paramsIdxRightChannels.push_back(static_cast<int32_t>(i));
131  }
132  }
133 
134  // Calculate N, K, W, C
135  // N: number of batches
136  // W: number of indices in each batch
137  // K: range of each index
138  // C: number of channels for each index
139  std::vector<int32_t> paramsLow;
140  std::vector<int32_t> paramsMid;
141  std::vector<int32_t> paramsHigh;
142  std::vector<int32_t> indicesMid;
143 
144  // Copy the first batchDims number of paramsShape values to paramsLow
145  for (unsigned int i = 0; i < batchDims; i++)
146  {
147  paramsLow.push_back(paramsShape[i]);
148  }
149  // Starting at batchDims index, copy the next (axis - batchDims) number of paramsShape values to paramsMid
150  for (unsigned int i = 0; i < (axis - batchDims); i++)
151  {
152  paramsMid.push_back(paramsShape[batchDims + i]);
153  }
154  // Starting at (axis + 1) index, copy the next (paramsRank - axis - 1) number of paramsShape values to paramsHigh
155  for (unsigned int i = 0; i < (paramsRank - axis - 1); i++)
156  {
157  paramsHigh.push_back(paramsShape[axis + 1 + i]);
158  }
159  // Starting at batchDims index, copy the next (indicesRank - batchDims) number of indicesShape values to indicesMid
160  for (unsigned int i = 0; i < (indicesRank - batchDims); i++)
161  {
162  indicesMid.push_back(indicesShape[batchDims + i]);
163  }
164 
165  auto lowProduct = static_cast<int32_t>(std::accumulate(std::begin(paramsMid),
166  std::end(paramsMid),
167  1,
168  std::multiplies<>() ));
169  auto highProduct = static_cast<int32_t>(std::accumulate(std::begin(paramsHigh),
170  std::end(paramsHigh),
171  1,
172  std::multiplies<>() ));
173 
174  auto N = static_cast<int32_t>(std::accumulate(std::begin(paramsLow),
175  std::end(paramsLow),
176  1,
177  std::multiplies<>() ));
178  auto W = static_cast<int32_t>(std::accumulate(std::begin(indicesMid),
179  std::end(indicesMid),
180  1,
181  std::multiplies<>() ));
182  auto K = paramsShape[axis];
183  auto C = lowProduct * highProduct;
184 
185  // Parameters needed for input transpose
186  std::vector<int32_t> inputTransposePermutation;
187  std::vector<int32_t> inputTransposeShape;
188  for (unsigned int i = 0; i < paramsBatch.size(); i++)
189  {
190  inputTransposePermutation.push_back(paramsIdxBatch[i]);
191  inputTransposeShape.push_back(paramsBatch[i]);
192  }
193  for (unsigned int i = 0; i < paramsIndices.size(); i++)
194  {
195  inputTransposePermutation.push_back(paramsIdxIndices[i]);
196  inputTransposeShape.push_back(paramsIndices[i]);
197  }
198  for (unsigned int i = 0; i < paramsLeftChannels.size(); i++)
199  {
200  inputTransposePermutation.push_back(paramsIdxLeftChannels[i]);
201  inputTransposeShape.push_back(paramsLeftChannels[i]);
202  }
203  for (unsigned int i = 0; i < paramsRightChannels.size(); i++)
204  {
205  inputTransposePermutation.push_back(paramsIdxRightChannels[i]);
206  inputTransposeShape.push_back(paramsRightChannels[i]);
207  }
208 
209  // Parameters needed for result/output transpose
210  std::vector<int32_t> resultReshapeShape;
211  resultReshapeShape.insert(resultReshapeShape.end(), indicesShape.begin(), indicesShape.end());
212  resultReshapeShape.insert(resultReshapeShape.end(), paramsLeftChannels.begin(), paramsLeftChannels.end());
213  resultReshapeShape.insert(resultReshapeShape.end(), paramsRightChannels.begin(), paramsRightChannels.end());
214 
215  std::vector<int32_t> resultTransposePerm;
216  for (unsigned int i = 0; i < batchDims; i++)
217  {
218  resultTransposePerm.push_back(static_cast<int32_t>(i));
219  }
220  for (unsigned int i = 0; i < paramsLeftChannels.size(); i++)
221  {
222  resultTransposePerm.push_back(static_cast<int32_t>(i + inputs[1]->GetNumDimensions()));
223  }
224  for (unsigned int i = batchDims; i < inputs[1]->GetNumDimensions(); i++)
225  {
226  resultTransposePerm.push_back(static_cast<int32_t>(i));
227  }
228  for (unsigned int i = 0; i < paramsRightChannels.size(); i++)
229  {
230  resultTransposePerm.push_back(static_cast<int32_t>(i + inputs[1]->GetNumDimensions() +
231  paramsLeftChannels.size()));
232  }
233 
234  std::vector<int32_t> tosaValuesShape = {N, K, C};
235  std::vector<int32_t> tosaIndicesShape = {N, W};
236  std::vector<int32_t> tosaGatherResultShape = {N, W, C};
237 
238  // 1. Transpose params values. This operation is only needed if the axis is not 0.
239  if (axis > 0)
240  {
241  tensors.emplace_back(new TosaSerializationTensor(outputTransposeParamsName,
242  inputTransposeShape,
243  inputParamsDType,
244  {}));
245 
246  TosaTransposeAttribute transposeInputAttribute(inputTransposePermutation);
247 
248  auto *transposeInputOp = new TosaSerializationOperator(Op_TRANSPOSE,
249  Attribute_TransposeAttribute,
250  &transposeInputAttribute,
251  {inputParamsName},
252  {outputTransposeParamsName});
253  operators.push_back(transposeInputOp);
254  }
255 
256  // 2. Reshape params
257  std::string& reshapeOpInputParamsName = axis > 0 ? outputTransposeParamsName : inputParamsName;
258 
259  tensors.emplace_back(new TosaSerializationTensor(outputReshapeParamsName,
260  tosaValuesShape,
261  inputParamsDType,
262  {}));
263 
264  TosaReshapeAttribute reshapeValuesAttribute(tosaValuesShape);
265 
266  auto* reshapeInputParamsOp = new TosaSerializationOperator(Op_RESHAPE,
267  Attribute_ReshapeAttribute,
268  &reshapeValuesAttribute,
269  {reshapeOpInputParamsName},
270  {outputReshapeParamsName});
271  operators.push_back(reshapeInputParamsOp);
272 
273  // 3. Reshape indices
274  tensors.emplace_back(new TosaSerializationTensor(outputReshapeIndicesName,
275  tosaIndicesShape,
276  inputIndicesDType,
277  {}));
278 
279  TosaReshapeAttribute reshapeIndicesAttribute(tosaIndicesShape);
280 
281  auto* reshapeInputIndicesOp = new TosaSerializationOperator(Op_RESHAPE,
282  Attribute_ReshapeAttribute,
283  &reshapeIndicesAttribute,
284  {inputIndicesName},
285  {outputReshapeIndicesName});
286  operators.push_back(reshapeInputIndicesOp);
287 
288  // 4. Gather params, indices
289  tensors.emplace_back(new TosaSerializationTensor(outputGatherName,
290  tosaGatherResultShape,
291  inputParamsDType,
292  {}));
293 
294  auto* gatherOp = new TosaSerializationOperator(Op_GATHER,
295  Attribute_NONE,
296  nullptr,
297  {outputReshapeParamsName, outputReshapeIndicesName},
298  {outputGatherName});
299  operators.push_back(gatherOp);
300 
301  // 5. Reshape gather output
302  if (axis > 0)
303  {
304  // If a Transpose op is needed below, an additional tensor is needed to store the reshape output.
305  tensors.emplace_back(new TosaSerializationTensor(outputReshapeGatherName,
306  resultReshapeShape,
307  outputDType0,
308  {}));
309  }
310 
311  std::string& reshapeOpOutputName = axis > 0 ? outputReshapeGatherName : outputName;
312 
313  TosaReshapeAttribute reshapeGatherAttribute(resultReshapeShape);
314 
315  auto* reshapeGatherOutputOp = new TosaSerializationOperator(Op_RESHAPE,
316  Attribute_ReshapeAttribute,
317  &reshapeGatherAttribute,
318  {outputGatherName},
319  {reshapeOpOutputName});
320  operators.push_back(reshapeGatherOutputOp);
321 
322  // 6. Transpose result output. This operator is only needed if the axis is not 0
323  if (axis > 0)
324  {
325  TosaTransposeAttribute transposeOutputAttribute(resultTransposePerm);
326 
327  auto* transposeOutputOp = new TosaSerializationOperator(Op_TRANSPOSE,
328  Attribute_TransposeAttribute,
329  &transposeOutputAttribute,
330  {outputReshapeGatherName},
331  {outputName});
332  operators.push_back(transposeOutputOp);
333  }
334 
335  // operatorInputNames/operatorOutputNames ends up being the same as
336  // blockInputNames/blockOutputNames for one-to-one ArmNN to TOSA mappings
337  return new TosaSerializationBasicBlock(blockName, // name
338  mainName, // region name
339  operators, // operators
340  tensors, // tensors
341  {inputParamsName, inputIndicesName}, // inputs
342  {outputName}); // outputs
343 }
#define ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(_cond, _str)
Definition: Exceptions.hpp:210
std::string GenerateUniqueOutputName(const Layer &layer, uint32_t layerSlot=0)
const std::string mainName
DType ArmNNToDType(const DataType &type)
std::vector< int32_t > GetTosaTensorShape(const TensorShape &shape)
std::string GenerateUniqueInputName(const armnn::InputSlot &slot)
std::string GetUniqueTosaMappingID()
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition: Layer.hpp:337
int32_t m_Axis
The axis in params to gather indices from.

References ARMNN_THROW_INVALIDARG_MSG_IF_FALSE, ArmNNToDType(), GenerateUniqueInputName(), GenerateUniqueOutputName(), Layer::GetInputSlot(), GetTosaTensorShape(), GetUniqueTosaMappingID(), and GatherDescriptor::m_Axis.

Referenced by GetTosaMapping().