ArmNN
 25.11
Loading...
Searching...
No Matches
Gather.cpp
Go to the documentation of this file.
1//
2// Copyright © 2017,2022-2024 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "Gather.hpp"
7
9
10#include <fmt/format.h>
11
12namespace armnn
13{
14template<typename I, typename O>
15void Gather(const TensorInfo& paramsInfo,
16 const TensorInfo& indicesInfo,
17 const TensorInfo& outputInfo,
18 Decoder<I>& params,
19 const int32_t* indices,
20 Encoder<O>& output,
21 const int32_t axis_int)
22{
23 IgnoreUnused(outputInfo);
24
25 const int paramsRank = static_cast<int>(paramsInfo.GetNumDimensions());
26 if((axis_int < -1 * paramsRank) || (paramsRank <= axis_int))
27 {
28 throw InvalidArgumentException((fmt::format("Gather: Axis {} is not within [-{}, {}) range",
29 axis_int, paramsRank, paramsRank)));
30 }
31 const unsigned int axis = (axis_int < 0) ? static_cast<unsigned int>(paramsRank + axis_int)
32 : static_cast<unsigned int>(axis_int);
33
34 const TensorShape& paramsShape = paramsInfo.GetShape();
35
36 // Product of all dimensions to the left side of the axis
37 unsigned int paramsOuterProduct = 1;
38 for (unsigned int i = 0; i < axis; ++i)
39 {
40 paramsOuterProduct *= paramsShape[i];
41 }
42 // Product of all dimensions to the right side of the axis
43 unsigned int paramsInnerProduct = 1;
44 for (unsigned int k = 1 + axis; k < paramsInfo.GetNumDimensions(); ++k)
45 {
46 paramsInnerProduct *= paramsShape[k];
47 }
48
49 unsigned int offset = 0;
50 unsigned int outIndex = 0;
51 for (unsigned int i = 0; i < paramsOuterProduct; ++i)
52 {
53 for (unsigned int j = 0; j < indicesInfo.GetNumElements(); ++j)
54 {
55 unsigned int index =
56 (indices[j] < 0) ? static_cast<unsigned int>(static_cast<int>(paramsShape[axis]) + indices[j])
57 : static_cast<unsigned int>(indices[j]);
58
59 if (index >= paramsShape[axis])
60 {
61 throw InvalidArgumentException((fmt::format("Gather: index >= paramsShape[axis]: {} >= {}",
62 index, paramsShape[axis] )));
63 }
64
65 unsigned int startOffset = (paramsInnerProduct * index) + offset;
66 unsigned int endOffset = startOffset + paramsInnerProduct;
67
68 for (unsigned int k = startOffset; k < endOffset; ++k)
69 {
70 params[k];
71 auto outputValue = params.Get();
72 output[outIndex];
73 output.Set(outputValue);
74 ++outIndex;
75 }
76 }
77 offset += paramsShape[axis] * paramsInnerProduct;
78 }
79
80 if (outIndex != outputInfo.GetNumElements())
81 {
82 throw InvalidArgumentException((fmt::format("Gather: Invalid outIndex {} ", outIndex)));
83 }
84}
85
86// Template method instantiation
87template void Gather(const TensorInfo& paramsInfo,
88 const TensorInfo& indicesInfo,
89 const TensorInfo& outputInfo,
90 Decoder<float>& params,
91 const int32_t* indices,
92 Encoder<float>& output,
93 const int32_t axis_int);
94
95template void Gather(const TensorInfo& paramsInfo,
96 const TensorInfo& indicesInfo,
97 const TensorInfo& outputInfo,
98 Decoder<double_t>& params,
99 const int32_t* indices,
100 Encoder<double_t>& output,
101 const int32_t axis_int);
102} //namespace armnn
virtual IType Get() const =0
virtual void Set(IType right)=0
const TensorShape & GetShape() const
Definition Tensor.hpp:193
unsigned int GetNumDimensions() const
Definition Tensor.hpp:197
unsigned int GetNumElements() const
Definition Tensor.hpp:198
Copyright (c) 2021 ARM Limited and Contributors.
void IgnoreUnused(Ts &&...)