11 #include <fmt/format.h>
27 return input + update;
29 return input - update;
31 return std::max(input, update);
33 return std::min(input, update);
35 return input * update;
// Excerpt of a ScatterNd routine: only some of its lines are present here, so the comments below
// describe the visible statements only.
// outterDim is read from the last entry of indicesShape; innerDim is what remains of `dimension`.
// NOTE(review): indicesShape[indicesDim - 1] is evaluated here, ahead of the `indicesDim < 1`
// check further down — confirm the rank is validated before this point.
67 unsigned int outterDim = indicesShape[indicesDim - 1];
68 unsigned int innerDim = dimension - outterDim;
// Fill elementInDim from the last dimension to the first: each slot receives the product of the
// inputShape entries that come after it. Counting dimIndex from `dimension` down to 1 keeps the
// unsigned counter from wrapping below zero.
71 unsigned int numElementsCount = 1;
72 std::vector<unsigned int> elementInDim(dimension);
73 for (
unsigned int dimIndex = dimension; dimIndex > 0; --dimIndex)
75 elementInDim[dimIndex - 1] = numElementsCount;
76 numElementsCount *= inputShape[dimIndex - 1];
// Same slot as elementInDim[outterDim - 1], because innerDim == dimension - outterDim.
// NOTE(review): only in range when 1 <= outterDim <= dimension — confirm that is enforced.
80 unsigned int numUpdatesPerIndex = elementInDim[dimension - innerDim - 1];
// The number of indices to process is taken from the first entry of indicesShape.
83 unsigned int numIndices = indicesShape[0];
// Shape checks follow; the statements they guard are not part of this excerpt.
// Condition holds when either indices or updates has rank below 1.
87 if (indicesDim < 1 || updatesDim < 1)
// Condition holds when the rank of updates differs from 1 + innerDim.
102 if (updatesDim != 1 + innerDim)
// Walk the trailing innerDim dimensions from the back, comparing updatesShape against inputShape;
// a mismatch feeds the formatted message below.
107 for (
unsigned int dimBackIndex = 0; dimBackIndex < innerDim; ++dimBackIndex)
109 if (updatesShape[updatesDim - dimBackIndex - 1] != inputShape[dimension - dimBackIndex - 1])
112 fmt::format(
"ScatterNd: input and updates shape not match on dimension {}",
113 dimension - dimBackIndex));
// indicesSet records flattened offsets already produced; flattenIndices stores one offset per index.
118 std::set<int> indicesSet;
119 std::vector<int> flattenIndices(numIndices);
120 for (
unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx)
123 int flattenIndex = 0;
// Each indices.Get() supplies one coordinate; it is compared against [0, inputShape[outterIdx])
// and, scaled by elementInDim[outterIdx], accumulated into flattenIndex.
125 for (
unsigned int outterIdx = 0; outterIdx < outterDim; ++outterIdx) {
127 int outterIndexValue = indices.
Get();
130 if (outterIndexValue < 0 || outterIndexValue >=
int(inputShape[outterIdx]))
133 fmt::format(
"ScatterNd: indices {} out of bound [0, {})",
134 outterIndexValue, inputShape[outterIdx]));
137 flattenIndex += int(elementInDim[outterIdx]) * outterIndexValue;
// Condition fragment: true when this flattened offset is already present in indicesSet.
143 indicesSet.find(flattenIndex) != indicesSet.end())
146 fmt::format(
"ScatterNd: duplicate indices occurs {}", flattenIndex));
149 flattenIndices[indicesIdx] = flattenIndex;
150 indicesSet.insert(flattenIndex);
// One pass of inputInfo.GetNumElements() iterations: a value is read with input.Get() and handed
// to output.Set().
154 for (
unsigned int idx = 0; idx < inputInfo.
GetNumElements(); ++idx)
156 float inputValue = input.
Get();
158 output.
Set(inputValue);
// For every stored offset, numUpdatesPerIndex consecutive positions starting at flattenIndex are
// addressed in input and output, paired with updates starting at indicesIdx * numUpdatesPerIndex.
163 for (
unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx)
166 int flattenIndex = flattenIndices[indicesIdx];
169 unsigned int updatesStartIdx = indicesIdx * numUpdatesPerIndex;
170 for (
unsigned int updatesIdx = 0; updatesIdx < numUpdatesPerIndex; ++updatesIdx)
172 updates[updatesStartIdx + updatesIdx];
173 input[
static_cast<unsigned int>(flattenIndex) + updatesIdx];
175 output[
static_cast<unsigned int>(flattenIndex) + updatesIdx];
// updateValue is defined on a line that is not part of this excerpt.
176 output.
Set(updateValue);
// Excerpt of a ScatterNd routine that derives the input shape from shapeValues: only some of its
// lines are present here, so the comments below describe the visible statements only.
// Iterates over shapeValues; the loop body is not part of this excerpt.
207 for (
auto shapeValue : shapeValues)
// inputShape is constructed from shapeValues, and inputElementsNum is the product of its entries.
// NOTE(review): the accumulate seed is the int literal 1, so the running product is held as int
// between steps before the final cast — confirm large shapes cannot overflow it.
215 std::vector<unsigned int> inputShape (shapeValues.begin(), shapeValues.end());
216 unsigned int inputElementsNum =
static_cast<unsigned int>(
217 std::accumulate(inputShape.begin(), inputShape.end(), 1, std::multiplies<unsigned int>()));
// outterDim is read from the last entry of indicesShape; innerDim is what remains of `dimension`.
// NOTE(review): indicesShape[indicesDim - 1] is evaluated here, ahead of the `indicesDim < 1`
// check further down — confirm the rank is validated before this point.
225 unsigned int outterDim = indicesShape[indicesDim - 1];
226 unsigned int innerDim = dimension - outterDim;
// Fill elementInDim from the last dimension to the first: each slot receives the product of the
// inputShape entries that come after it. Counting dimIndex from `dimension` down to 1 keeps the
// unsigned counter from wrapping below zero.
229 unsigned int numElementsCount = 1;
230 std::vector<unsigned int> elementInDim(dimension);
231 for (
unsigned int dimIndex = dimension; dimIndex > 0; --dimIndex)
233 elementInDim[dimIndex - 1] = numElementsCount;
234 numElementsCount *= inputShape[dimIndex - 1];
// Same slot as elementInDim[outterDim - 1], because innerDim == dimension - outterDim.
// NOTE(review): only in range when 1 <= outterDim <= dimension — confirm that is enforced.
238 unsigned int numUpdatesPerIndex = elementInDim[dimension - innerDim - 1];
// The number of indices to process is taken from the first entry of indicesShape.
241 unsigned int numIndices = indicesShape[0];
// Shape checks follow; the statements they guard are not part of this excerpt.
// Condition holds when either indices or updates has rank below 1.
245 if (indicesDim < 1 || updatesDim < 1)
// Condition holds when the rank of updates differs from 1 + innerDim.
259 if (updatesDim != 1 + innerDim)
// Walk the trailing innerDim dimensions from the back, comparing updatesShape against inputShape;
// a mismatch feeds the formatted message below.
264 for (
unsigned int dimBackIndex = 0; dimBackIndex < innerDim; ++dimBackIndex)
266 if (updatesShape[updatesDim - dimBackIndex - 1] != inputShape[dimension - dimBackIndex - 1])
269 fmt::format(
"ScatterNd: input and updates shape not match on dimension {}",
270 dimension - dimBackIndex));
// indicesSet records flattened offsets already produced; flattenIndices stores one offset per index.
275 std::set<int> indicesSet;
276 std::vector<int> flattenIndices(numIndices);
277 for (
unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx)
280 int flattenIndex = 0;
// Each indices.Get() supplies one coordinate; it is compared against [0, inputShape[outterIdx])
// and, scaled by elementInDim[outterIdx], accumulated into flattenIndex.
282 for (
unsigned int outterIdx = 0; outterIdx < outterDim; ++outterIdx) {
284 int outterIndexValue = indices.
Get();
287 if (outterIndexValue < 0 || outterIndexValue >=
int(inputShape[outterIdx]))
290 fmt::format(
"ScatterNd: indices {} out of bound [0, {})",
291 outterIndexValue, inputShape[outterIdx]));
294 flattenIndex += int(elementInDim[outterIdx]) * outterIndexValue;
// Condition fragment: true when this flattened offset is already present in indicesSet.
// The arguments of the format call that follows are not part of this excerpt.
300 indicesSet.find(flattenIndex) != indicesSet.end())
303 fmt::format(
"ScatterNd: duplicate indices {} occurs when executing ScatterNd::Update.",
307 flattenIndices[indicesIdx] = flattenIndex;
308 indicesSet.insert(flattenIndex);
// Loop of inputElementsNum iterations; its body is not part of this excerpt.
312 for (
unsigned int idx = 0; idx < inputElementsNum; ++idx)
// For every stored offset, numUpdatesPerIndex consecutive positions starting at flattenIndex are
// addressed in output, paired with updates starting at indicesIdx * numUpdatesPerIndex.
319 for (
unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx)
322 int flattenIndex = flattenIndices[indicesIdx];
325 unsigned int updatesStartIdx = indicesIdx * numUpdatesPerIndex;
326 for (
unsigned int updatesIdx = 0; updatesIdx < numUpdatesPerIndex; ++updatesIdx)
328 updates[updatesStartIdx + updatesIdx];
330 output[
static_cast<unsigned int>(flattenIndex) + updatesIdx];
// updateValue is defined on a line that is not part of this excerpt.
331 output.
Set(updateValue);