ArmNN
 25.11
Loading...
Searching...
No Matches
ScatterNd.cpp
Go to the documentation of this file.
1//
2// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "ScatterNd.hpp"
7#include "Encoders.hpp"
9#include <armnn/Logging.hpp>
10
11#include <fmt/format.h>
12
13#include <numeric>
14
15namespace armnn
16{
17
19 float input,
20 float update)
21{
22 switch (operation)
23 {
25 return update;
27 return input + update;
29 return input - update;
31 return std::max(input, update);
33 return std::min(input, update);
35 return input * update;
36 default:
37 throw InvalidArgumentException("ScatterNd: cannot execute this operation.");
38 }
39}
40
41void ScatterNd(const TensorInfo& inputInfo,
42 const TensorInfo& indicesInfo,
43 const TensorInfo& updatesInfo,
44 Decoder<float>& input,
45 Decoder<int>& indices,
46 Decoder<float>& updates,
47 Encoder<float>& output,
48 ScatterNdDescriptor descriptor)
49{
50 // Axis Unsupported
51 if (descriptor.m_AxisEnabled)
52 {
53 throw InvalidArgumentException("ScatterNd: axis param not supported.");
54 }
55
56 // Get the shape for indices, updates, and input
57 TensorShape indicesShape = indicesInfo.GetShape();
58 TensorShape updatesShape = updatesInfo.GetShape();
59 TensorShape inputShape = inputInfo.GetShape();
60
61 // Get the dimensions for indices and updates
62 unsigned int dimension = inputInfo.GetNumDimensions();
63 unsigned int indicesDim = indicesInfo.GetNumDimensions();
64 unsigned int updatesDim = updatesInfo.GetNumDimensions();
65
66 // Calculate the outter and inner dimensions
67 unsigned int outterDim = indicesShape[indicesDim - 1];
68 unsigned int innerDim = dimension - outterDim;
69
70 // Calculate the number of elements in each dimension
71 unsigned int numElementsCount = 1;
72 std::vector<unsigned int> elementInDim(dimension);
73 for (unsigned int dimIndex = dimension; dimIndex > 0; --dimIndex)
74 {
75 elementInDim[dimIndex - 1] = numElementsCount;
76 numElementsCount *= inputShape[dimIndex - 1];
77 }
78
79 // Number of updates per index
80 unsigned int numUpdatesPerIndex = elementInDim[dimension - innerDim - 1];
81
82 // Number of indices to update
83 unsigned int numIndices = indicesShape[0];
84
85 // Check Input Requirements
86 // Requirement 1: Indices and Updates must have rank at least 1
87 if (indicesDim < 1 || updatesDim < 1)
88 {
89 throw InvalidArgumentException("ScatterNd: indices and updates must have rank >= 1.");
90 }
91
92 // Requirement 2: Input, Indices and Updates must have values
93 if (inputInfo.GetNumElements() == 0 ||
94 indicesInfo.GetNumElements() == 0 ||
95 updatesInfo.GetNumElements() == 0)
96 {
97 throw InvalidArgumentException("ScatterNd: input, indices and updates tensor must have values.");
98 }
99
100 // Requirement 3: Indices and Updates must match in shape
101 // The updates dimension should equals to 1 + inner dimension
102 if (updatesDim != 1 + innerDim)
103 {
104 throw InvalidArgumentException("ScatterNd: updates dimension should equal to 1 + inner dimension.");
105 }
106 // The inner dimension of updates has to match with shape of input
107 for (unsigned int dimBackIndex = 0; dimBackIndex < innerDim; ++dimBackIndex)
108 {
109 if (updatesShape[updatesDim - dimBackIndex - 1] != inputShape[dimension - dimBackIndex - 1])
110 {
112 fmt::format("ScatterNd: input and updates shape not match on dimension {}",
113 dimension - dimBackIndex));
114 }
115 }
116
117 // Requirement 4: Check duplicate indices and out of bound indices
118 std::set<int> indicesSet;
119 std::vector<int> flattenIndices(numIndices);
120 for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx)
121 {
122 // Get the index
123 int flattenIndex = 0;
124
125 for (unsigned int outterIdx = 0; outterIdx < outterDim; ++outterIdx) {
126
127 int outterIndexValue = indices.Get();
128
129 // Check bounds
130 if (outterIndexValue < 0 || outterIndexValue >= int(inputShape[outterIdx]))
131 {
133 fmt::format("ScatterNd: indices {} out of bound [0, {})",
134 outterIndexValue, inputShape[outterIdx]));
135 }
136
137 flattenIndex += int(elementInDim[outterIdx]) * outterIndexValue;
138 ++indices;
139 }
140
141 // Check duplicates when executing ScatterNd::Update
142 if (descriptor.m_Function == ScatterNdFunction::Update &&
143 indicesSet.find(flattenIndex) != indicesSet.end())
144 {
146 fmt::format("ScatterNd: duplicate indices occurs {}", flattenIndex));
147 }
148
149 flattenIndices[indicesIdx] = flattenIndex;
150 indicesSet.insert(flattenIndex);
151 }
152
153 // Set the input data to output
154 for (unsigned int idx = 0; idx < inputInfo.GetNumElements(); ++idx)
155 {
156 float inputValue = input.Get();
157 ++input;
158 output.Set(inputValue);
159 ++output;
160 }
161
162 // Iterate through all indices to scatter updates
163 for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx)
164 {
165 // Get the index and calculate the flatten index
166 int flattenIndex = flattenIndices[indicesIdx];
167
168 // FlattenIndex is the place that we are going to update the elements
169 unsigned int updatesStartIdx = indicesIdx * numUpdatesPerIndex;
170 for (unsigned int updatesIdx = 0; updatesIdx < numUpdatesPerIndex; ++updatesIdx)
171 {
172 updates[updatesStartIdx + updatesIdx];
173 input[static_cast<unsigned int>(flattenIndex) + updatesIdx];
174 float updateValue = ScatterOperation(descriptor.m_Function, input.Get(), updates.Get());
175 output[static_cast<unsigned int>(flattenIndex) + updatesIdx];
176 output.Set(updateValue);
177 }
178 }
179}
180
181void ScatterNd(const TensorInfo& indicesInfo,
182 const TensorInfo& updatesInfo,
183 const TensorInfo& shapeInfo,
184 Decoder<int>& indices,
185 Decoder<float>& updates,
186 Decoder<int>& shape,
187 Encoder<float>& output,
188 ScatterNdDescriptor descriptor)
189{
190 // Axis Unsupported
191 if (descriptor.m_AxisEnabled)
192 {
193 throw InvalidArgumentException("ScatterNd: axis param not supported.");
194 }
195
196 // Get the shape for indices, updates, and input
197 TensorShape indicesShape = indicesInfo.GetShape();
198 TensorShape updatesShape = updatesInfo.GetShape();
199
200 // Get the shape values
201 std::vector<float> shapeValues = shape.DecodeTensor(shapeInfo.GetShape());
202 // Check the shape
203 if (shapeInfo.GetNumElements() == 0)
204 {
205 throw InvalidArgumentException("ScatterNd: shape must have values.");
206 }
207 for (auto shapeValue : shapeValues)
208 {
209 if (shapeValue <= 0)
210 {
211 throw InvalidArgumentException("ScatterNd: shape values must >= 0.");
212 }
213 }
214 // Get the input shape
215 std::vector<unsigned int> inputShape (shapeValues.begin(), shapeValues.end());
216 unsigned int inputElementsNum = static_cast<unsigned int>(
217 std::accumulate(inputShape.begin(), inputShape.end(), 1, std::multiplies<unsigned int>()));
218
219 // Get the dimensions for indices and updates
220 unsigned int dimension = shapeInfo.GetNumElements();
221 unsigned int indicesDim = indicesInfo.GetNumDimensions();
222 unsigned int updatesDim = updatesInfo.GetNumDimensions();
223
224 // Calculate the outter and inner dimensions
225 unsigned int outterDim = indicesShape[indicesDim - 1];
226 unsigned int innerDim = dimension - outterDim;
227
228 // Calculate the number of elements in each dimension
229 unsigned int numElementsCount = 1;
230 std::vector<unsigned int> elementInDim(dimension);
231 for (unsigned int dimIndex = dimension; dimIndex > 0; --dimIndex)
232 {
233 elementInDim[dimIndex - 1] = numElementsCount;
234 numElementsCount *= inputShape[dimIndex - 1];
235 }
236
237 // Number of updates per index
238 unsigned int numUpdatesPerIndex = elementInDim[dimension - innerDim - 1];
239
240 // Number of indices to update
241 unsigned int numIndices = indicesShape[0];
242
243 // Check Input Requirements
244 // Requirement 1: Indices and Updates must have rank at least 1
245 if (indicesDim < 1 || updatesDim < 1)
246 {
247 throw InvalidArgumentException("ScatterNd: indices and updates must have rank >= 1.");
248 }
249
250 // Requirement 2: shape, Indices and Updates must have values
251 if (indicesInfo.GetNumElements() == 0 ||
252 updatesInfo.GetNumElements() == 0)
253 {
254 throw InvalidArgumentException("ScatterNd: indices and updates tensor must have values.");
255 }
256
257 // Requirement 3: Indices and Updates must match in shape
258 // The updates dimension should equals to 1 + inner dimension
259 if (updatesDim != 1 + innerDim)
260 {
261 throw InvalidArgumentException("ScatterNd: updates dimension should equal to 1 + inner dimension.");
262 }
263 // The inner dimension of updates has to match with shape of input
264 for (unsigned int dimBackIndex = 0; dimBackIndex < innerDim; ++dimBackIndex)
265 {
266 if (updatesShape[updatesDim - dimBackIndex - 1] != inputShape[dimension - dimBackIndex - 1])
267 {
269 fmt::format("ScatterNd: input and updates shape not match on dimension {}",
270 dimension - dimBackIndex));
271 }
272 }
273
274 // Requirement 4: Check duplicate indices and out of bound indices
275 std::set<int> indicesSet;
276 std::vector<int> flattenIndices(numIndices);
277 for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx)
278 {
279 // Get the index
280 int flattenIndex = 0;
281
282 for (unsigned int outterIdx = 0; outterIdx < outterDim; ++outterIdx) {
283
284 int outterIndexValue = indices.Get();
285
286 // Check bounds
287 if (outterIndexValue < 0 || outterIndexValue >= int(inputShape[outterIdx]))
288 {
290 fmt::format("ScatterNd: indices {} out of bound [0, {})",
291 outterIndexValue, inputShape[outterIdx]));
292 }
293
294 flattenIndex += int(elementInDim[outterIdx]) * outterIndexValue;
295 ++indices;
296 }
297
298 // Check duplicates when executing ScatterNd::Update
299 if (descriptor.m_Function == ScatterNdFunction::Update &&
300 indicesSet.find(flattenIndex) != indicesSet.end())
301 {
303 fmt::format("ScatterNd: duplicate indices {} occurs when executing ScatterNd::Update.",
304 flattenIndex));
305 }
306
307 flattenIndices[indicesIdx] = flattenIndex;
308 indicesSet.insert(flattenIndex);
309 }
310
311 // Set zeros to output
312 for (unsigned int idx = 0; idx < inputElementsNum; ++idx)
313 {
314 output.Set(0.0f);
315 ++output;
316 }
317
318 // Iterate through all indices to scatter updates
319 for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx)
320 {
321 // Get the index and calculate the flatten index
322 int flattenIndex = flattenIndices[indicesIdx];
323
324 // FlattenIndex is the place that we are going to update the elements
325 unsigned int updatesStartIdx = indicesIdx * numUpdatesPerIndex;
326 for (unsigned int updatesIdx = 0; updatesIdx < numUpdatesPerIndex; ++updatesIdx)
327 {
328 updates[updatesStartIdx + updatesIdx];
329 float updateValue = ScatterOperation(descriptor.m_Function, 0.0f, updates.Get());
330 output[static_cast<unsigned int>(flattenIndex) + updatesIdx];
331 output.Set(updateValue);
332 }
333 }
334}
335
336} // namespace armnn
virtual std::vector< float > DecodeTensor(const TensorShape &tensorShape, bool isDepthwise=false)=0
virtual IType Get() const =0
virtual void Set(IType right)=0
const TensorShape & GetShape() const
Definition Tensor.hpp:193
unsigned int GetNumDimensions() const
Definition Tensor.hpp:197
unsigned int GetNumElements() const
Definition Tensor.hpp:198
Copyright (c) 2021 ARM Limited and Contributors.
ScatterNdFunction
Definition Types.hpp:503
float ScatterOperation(ScatterNdFunction operation, float input, float update)
Definition ScatterNd.cpp:18
A ScatterNdDescriptor for the ScatterNdLayer.
ScatterNdFunction m_Function
Specify if the function is update, add, sub, max or min.
bool m_AxisEnabled
Flag for ScatterElement, will be set to false by default, we do not support m_AxisEnable = true for n...