ArmNN
 25.02
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
Gather.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017,2022-2024 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "Gather.hpp"
7 
9 
10 #include <fmt/format.h>
11 
12 namespace armnn
13 {
14 template<typename I, typename O>
15 void Gather(const TensorInfo& paramsInfo,
16  const TensorInfo& indicesInfo,
17  const TensorInfo& outputInfo,
18  Decoder<I>& params,
19  const int32_t* indices,
20  Encoder<O>& output,
21  const int32_t axis_int)
22 {
23  IgnoreUnused(outputInfo);
24 
25  const int paramsRank = static_cast<int>(paramsInfo.GetNumDimensions());
26  if((axis_int < -1 * paramsRank) || (paramsRank <= axis_int))
27  {
28  throw InvalidArgumentException((fmt::format("Gather: Axis {} is not within [-{}, {}) range",
29  axis_int, paramsRank, paramsRank)));
30  }
31  const unsigned int axis = (axis_int < 0) ? static_cast<unsigned int>(paramsRank + axis_int)
32  : static_cast<unsigned int>(axis_int);
33 
34  const TensorShape& paramsShape = paramsInfo.GetShape();
35 
36  // Product of all dimensions to the left side of the axis
37  unsigned int paramsOuterProduct = 1;
38  for (unsigned int i = 0; i < axis; ++i)
39  {
40  paramsOuterProduct *= paramsShape[i];
41  }
42  // Product of all dimensions to the right side of the axis
43  unsigned int paramsInnerProduct = 1;
44  for (unsigned int k = 1 + axis; k < paramsInfo.GetNumDimensions(); ++k)
45  {
46  paramsInnerProduct *= paramsShape[k];
47  }
48 
49  unsigned int offset = 0;
50  unsigned int outIndex = 0;
51  for (unsigned int i = 0; i < paramsOuterProduct; ++i)
52  {
53  for (unsigned int j = 0; j < indicesInfo.GetNumElements(); ++j)
54  {
55  unsigned int index =
56  (indices[j] < 0) ? static_cast<unsigned int>(static_cast<int>(paramsShape[axis]) + indices[j])
57  : static_cast<unsigned int>(indices[j]);
58 
59  if (index >= paramsShape[axis])
60  {
61  throw InvalidArgumentException((fmt::format("Gather: index >= paramsShape[axis]: {} >= {}",
62  index, paramsShape[axis] )));
63  }
64 
65  unsigned int startOffset = (paramsInnerProduct * index) + offset;
66  unsigned int endOffset = startOffset + paramsInnerProduct;
67 
68  for (unsigned int k = startOffset; k < endOffset; ++k)
69  {
70  params[k];
71  auto outputValue = params.Get();
72  output[outIndex];
73  output.Set(outputValue);
74  ++outIndex;
75  }
76  }
77  offset += paramsShape[axis] * paramsInnerProduct;
78  }
79 
80  if (outIndex != outputInfo.GetNumElements())
81  {
82  throw InvalidArgumentException((fmt::format("Gather: Invalid outIndex {} ", outIndex)));
83  }
84 }
85 
86 // Template method instantiation
87 template void Gather(const TensorInfo& paramsInfo,
88  const TensorInfo& indicesInfo,
89  const TensorInfo& outputInfo,
90  Decoder<float>& params,
91  const int32_t* indices,
92  Encoder<float>& output,
93  const int32_t axis_int);
94 
95 template void Gather(const TensorInfo& paramsInfo,
96  const TensorInfo& indicesInfo,
97  const TensorInfo& outputInfo,
98  Decoder<double_t>& params,
99  const int32_t* indices,
100  Encoder<double_t>& output,
101  const int32_t axis_int);
102 } //namespace armnn
virtual IType Get() const =0
virtual void Set(IType right)=0
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:197
unsigned int GetNumElements() const
Definition: Tensor.hpp:198
const TensorShape & GetShape() const
Definition: Tensor.hpp:193
Copyright (c) 2021 ARM Limited and Contributors.
void Gather(const TensorInfo &paramsInfo, const TensorInfo &indicesInfo, const TensorInfo &outputInfo, Decoder< I > &params, const int32_t *indices, Encoder< O > &output, const int32_t axis_int)
Definition: Gather.cpp:15
void IgnoreUnused(Ts &&...)