ArmNN
 24.08
ScatterNd.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ScatterNd.hpp"
7 #include "Encoders.hpp"
9 #include <armnn/Logging.hpp>
10 
11 #include <fmt/format.h>
12 
13 #include <numeric>
14 
15 namespace armnn
16 {
17 
19  float input,
20  float update)
21 {
22  switch (operation)
23  {
25  return update;
27  return input + update;
29  return input - update;
31  return std::max(input, update);
33  return std::min(input, update);
35  return input * update;
36  default:
37  throw InvalidArgumentException("ScatterNd: cannot execute this operation.");
38  }
39 }
40 
41 void ScatterNd(const TensorInfo& inputInfo,
42  const TensorInfo& indicesInfo,
43  const TensorInfo& updatesInfo,
44  Decoder<float>& input,
45  Decoder<int>& indices,
46  Decoder<float>& updates,
47  Encoder<float>& output,
48  ScatterNdDescriptor descriptor)
49 {
50  // Axis Unsupported
51  if (descriptor.m_AxisEnabled)
52  {
53  throw InvalidArgumentException("ScatterNd: axis param not supported.");
54  }
55 
56  // Get the shape for indices, updates, and input
57  TensorShape indicesShape = indicesInfo.GetShape();
58  TensorShape updatesShape = updatesInfo.GetShape();
59  TensorShape inputShape = inputInfo.GetShape();
60 
61  // Get the dimensions for indices and updates
62  unsigned int dimension = inputInfo.GetNumDimensions();
63  unsigned int indicesDim = indicesInfo.GetNumDimensions();
64  unsigned int updatesDim = updatesInfo.GetNumDimensions();
65 
66  // Calculate the outter and inner dimensions
67  unsigned int outterDim = indicesShape[indicesDim - 1];
68  unsigned int innerDim = dimension - outterDim;
69 
70  // Calculate the number of elements in each dimension
71  unsigned int numElementsCount = 1;
72  std::vector<unsigned int> elementInDim(dimension);
73  for (unsigned int dimIndex = dimension; dimIndex > 0; --dimIndex)
74  {
75  elementInDim[dimIndex - 1] = numElementsCount;
76  numElementsCount *= inputShape[dimIndex - 1];
77  }
78 
79  // Number of updates per index
80  unsigned int numUpdatesPerIndex = elementInDim[dimension - innerDim - 1];
81 
82  // Number of indices to update
83  unsigned int numIndices = indicesShape[0];
84 
85  // Check Input Requirements
86  // Requirement 1: Indices and Updates must have rank at least 1
87  if (indicesDim < 1 || updatesDim < 1)
88  {
89  throw InvalidArgumentException("ScatterNd: indices and updates must have rank >= 1.");
90  }
91 
92  // Requirement 2: Input, Indices and Updates must have values
93  if (inputInfo.GetNumElements() == 0 ||
94  indicesInfo.GetNumElements() == 0 ||
95  updatesInfo.GetNumElements() == 0)
96  {
97  throw InvalidArgumentException("ScatterNd: input, indices and updates tensor must have values.");
98  }
99 
100  // Requirement 3: Indices and Updates must match in shape
101  // The updates dimension should equals to 1 + inner dimension
102  if (updatesDim != 1 + innerDim)
103  {
104  throw InvalidArgumentException("ScatterNd: updates dimension should equal to 1 + inner dimension.");
105  }
106  // The inner dimension of updates has to match with shape of input
107  for (unsigned int dimBackIndex = 0; dimBackIndex < innerDim; ++dimBackIndex)
108  {
109  if (updatesShape[updatesDim - dimBackIndex - 1] != inputShape[dimension - dimBackIndex - 1])
110  {
112  fmt::format("ScatterNd: input and updates shape not match on dimension {}",
113  dimension - dimBackIndex));
114  }
115  }
116 
117  // Requirement 4: Check duplicate indices and out of bound indices
118  std::set<int> indicesSet;
119  std::vector<int> flattenIndices(numIndices);
120  for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx)
121  {
122  // Get the index
123  int flattenIndex = 0;
124 
125  for (unsigned int outterIdx = 0; outterIdx < outterDim; ++outterIdx) {
126 
127  int outterIndexValue = indices.Get();
128 
129  // Check bounds
130  if (outterIndexValue < 0 || outterIndexValue >= int(inputShape[outterIdx]))
131  {
133  fmt::format("ScatterNd: indices {} out of bound [0, {})",
134  outterIndexValue, inputShape[outterIdx]));
135  }
136 
137  flattenIndex += int(elementInDim[outterIdx]) * outterIndexValue;
138  ++indices;
139  }
140 
141  // Check duplicates when executing ScatterNd::Update
142  if (descriptor.m_Function == ScatterNdFunction::Update &&
143  indicesSet.find(flattenIndex) != indicesSet.end())
144  {
146  fmt::format("ScatterNd: duplicate indices occurs {}", flattenIndex));
147  }
148 
149  flattenIndices[indicesIdx] = flattenIndex;
150  indicesSet.insert(flattenIndex);
151  }
152 
153  // Set the input data to output
154  for (unsigned int idx = 0; idx < inputInfo.GetNumElements(); ++idx)
155  {
156  float inputValue = input.Get();
157  ++input;
158  output.Set(inputValue);
159  ++output;
160  }
161 
162  // Iterate through all indices to scatter updates
163  for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx)
164  {
165  // Get the index and calculate the flatten index
166  int flattenIndex = flattenIndices[indicesIdx];
167 
168  // FlattenIndex is the place that we are going to update the elements
169  unsigned int updatesStartIdx = indicesIdx * numUpdatesPerIndex;
170  for (unsigned int updatesIdx = 0; updatesIdx < numUpdatesPerIndex; ++updatesIdx)
171  {
172  updates[updatesStartIdx + updatesIdx];
173  input[static_cast<unsigned int>(flattenIndex) + updatesIdx];
174  float updateValue = ScatterOperation(descriptor.m_Function, input.Get(), updates.Get());
175  output[static_cast<unsigned int>(flattenIndex) + updatesIdx];
176  output.Set(updateValue);
177  }
178  }
179 }
180 
181 void ScatterNd(const TensorInfo& indicesInfo,
182  const TensorInfo& updatesInfo,
183  const TensorInfo& shapeInfo,
184  Decoder<int>& indices,
185  Decoder<float>& updates,
186  Decoder<int>& shape,
187  Encoder<float>& output,
188  ScatterNdDescriptor descriptor)
189 {
190  // Axis Unsupported
191  if (descriptor.m_AxisEnabled)
192  {
193  throw InvalidArgumentException("ScatterNd: axis param not supported.");
194  }
195 
196  // Get the shape for indices, updates, and input
197  TensorShape indicesShape = indicesInfo.GetShape();
198  TensorShape updatesShape = updatesInfo.GetShape();
199 
200  // Get the shape values
201  std::vector<float> shapeValues = shape.DecodeTensor(shapeInfo.GetShape());
202  // Check the shape
203  if (shapeInfo.GetNumElements() == 0)
204  {
205  throw InvalidArgumentException("ScatterNd: shape must have values.");
206  }
207  for (auto shapeValue : shapeValues)
208  {
209  if (shapeValue <= 0)
210  {
211  throw InvalidArgumentException("ScatterNd: shape values must >= 0.");
212  }
213  }
214  // Get the input shape
215  std::vector<unsigned int> inputShape (shapeValues.begin(), shapeValues.end());
216  unsigned int inputElementsNum = static_cast<unsigned int>(
217  std::accumulate(inputShape.begin(), inputShape.end(), 1, std::multiplies<unsigned int>()));
218 
219  // Get the dimensions for indices and updates
220  unsigned int dimension = shapeInfo.GetNumElements();
221  unsigned int indicesDim = indicesInfo.GetNumDimensions();
222  unsigned int updatesDim = updatesInfo.GetNumDimensions();
223 
224  // Calculate the outter and inner dimensions
225  unsigned int outterDim = indicesShape[indicesDim - 1];
226  unsigned int innerDim = dimension - outterDim;
227 
228  // Calculate the number of elements in each dimension
229  unsigned int numElementsCount = 1;
230  std::vector<unsigned int> elementInDim(dimension);
231  for (unsigned int dimIndex = dimension; dimIndex > 0; --dimIndex)
232  {
233  elementInDim[dimIndex - 1] = numElementsCount;
234  numElementsCount *= inputShape[dimIndex - 1];
235  }
236 
237  // Number of updates per index
238  unsigned int numUpdatesPerIndex = elementInDim[dimension - innerDim - 1];
239 
240  // Number of indices to update
241  unsigned int numIndices = indicesShape[0];
242 
243  // Check Input Requirements
244  // Requirement 1: Indices and Updates must have rank at least 1
245  if (indicesDim < 1 || updatesDim < 1)
246  {
247  throw InvalidArgumentException("ScatterNd: indices and updates must have rank >= 1.");
248  }
249 
250  // Requirement 2: shape, Indices and Updates must have values
251  if (indicesInfo.GetNumElements() == 0 ||
252  updatesInfo.GetNumElements() == 0)
253  {
254  throw InvalidArgumentException("ScatterNd: indices and updates tensor must have values.");
255  }
256 
257  // Requirement 3: Indices and Updates must match in shape
258  // The updates dimension should equals to 1 + inner dimension
259  if (updatesDim != 1 + innerDim)
260  {
261  throw InvalidArgumentException("ScatterNd: updates dimension should equal to 1 + inner dimension.");
262  }
263  // The inner dimension of updates has to match with shape of input
264  for (unsigned int dimBackIndex = 0; dimBackIndex < innerDim; ++dimBackIndex)
265  {
266  if (updatesShape[updatesDim - dimBackIndex - 1] != inputShape[dimension - dimBackIndex - 1])
267  {
269  fmt::format("ScatterNd: input and updates shape not match on dimension {}",
270  dimension - dimBackIndex));
271  }
272  }
273 
274  // Requirement 4: Check duplicate indices and out of bound indices
275  std::set<int> indicesSet;
276  std::vector<int> flattenIndices(numIndices);
277  for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx)
278  {
279  // Get the index
280  int flattenIndex = 0;
281 
282  for (unsigned int outterIdx = 0; outterIdx < outterDim; ++outterIdx) {
283 
284  int outterIndexValue = indices.Get();
285 
286  // Check bounds
287  if (outterIndexValue < 0 || outterIndexValue >= int(inputShape[outterIdx]))
288  {
290  fmt::format("ScatterNd: indices {} out of bound [0, {})",
291  outterIndexValue, inputShape[outterIdx]));
292  }
293 
294  flattenIndex += int(elementInDim[outterIdx]) * outterIndexValue;
295  ++indices;
296  }
297 
298  // Check duplicates when executing ScatterNd::Update
299  if (descriptor.m_Function == ScatterNdFunction::Update &&
300  indicesSet.find(flattenIndex) != indicesSet.end())
301  {
303  fmt::format("ScatterNd: duplicate indices {} occurs when executing ScatterNd::Update.",
304  flattenIndex));
305  }
306 
307  flattenIndices[indicesIdx] = flattenIndex;
308  indicesSet.insert(flattenIndex);
309  }
310 
311  // Set zeros to output
312  for (unsigned int idx = 0; idx < inputElementsNum; ++idx)
313  {
314  output.Set(0.0f);
315  ++output;
316  }
317 
318  // Iterate through all indices to scatter updates
319  for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx)
320  {
321  // Get the index and calculate the flatten index
322  int flattenIndex = flattenIndices[indicesIdx];
323 
324  // FlattenIndex is the place that we are going to update the elements
325  unsigned int updatesStartIdx = indicesIdx * numUpdatesPerIndex;
326  for (unsigned int updatesIdx = 0; updatesIdx < numUpdatesPerIndex; ++updatesIdx)
327  {
328  updates[updatesStartIdx + updatesIdx];
329  float updateValue = ScatterOperation(descriptor.m_Function, 0.0f, updates.Get());
330  output[static_cast<unsigned int>(flattenIndex) + updatesIdx];
331  output.Set(updateValue);
332  }
333  }
334 }
335 
336 } // namespace armnn
armnn::Decoder< float >
armnn::TensorInfo::GetNumElements
unsigned int GetNumElements() const
Definition: Tensor.hpp:198
armnn::ScatterNdFunction::Min
@ Min
armnn::Encoder::Set
virtual void Set(IType right)=0
WorkloadData.hpp
armnn::ScatterNdFunction::Sub
@ Sub
armnn::TensorInfo
Definition: Tensor.hpp:152
armnn::TensorInfo::GetNumDimensions
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:197
armnn::ScatterNdDescriptor::m_AxisEnabled
bool m_AxisEnabled
Flag for ScatterElement; set to false by default. Setting m_AxisEnabled = true is not supported for n...
Definition: Descriptors.hpp:1728
armnn::ScatterNdFunction::Mul
@ Mul
armnn::TensorShape
Definition: Tensor.hpp:20
armnn::Encoder< float >
Logging.hpp
armnn::InvalidArgumentException
Definition: Exceptions.hpp:80
armnn::ScatterNdFunction
ScatterNdFunction
Definition: Types.hpp:500
armnn::Decoder::DecodeTensor
virtual std::vector< float > DecodeTensor(const TensorShape &tensorShape, bool isDepthwise=false)=0
armnn::ScatterNdFunction::Add
@ Add
armnn::ScatterNd
void ScatterNd(const TensorInfo &inputInfo, const TensorInfo &indicesInfo, const TensorInfo &updatesInfo, Decoder< float > &input, Decoder< int > &indices, Decoder< float > &updates, Encoder< float > &output, ScatterNdDescriptor descriptor)
Definition: ScatterNd.cpp:41
armnn::Decoder::Get
virtual IType Get() const =0
ScatterNd.hpp
armnn::TensorInfo::GetShape
const TensorShape & GetShape() const
Definition: Tensor.hpp:193
armnn::ScatterNdFunction::Update
@ Update
armnn::ScatterNdFunction::Max
@ Max
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::ScatterNdDescriptor::m_Function
ScatterNdFunction m_Function
Specifies whether the function is update, add, sub, max, min or mul.
Definition: Descriptors.hpp:1719
armnn::ScatterNdDescriptor
A ScatterNdDescriptor for the ScatterNdLayer.
Definition: Descriptors.hpp:1679
Encoders.hpp
armnn::ScatterOperation
float ScatterOperation(ScatterNdFunction operation, float input, float update)
Definition: ScatterNd.cpp:18