ArmNN
 24.02
Gather.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017,2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "Gather.hpp"
7 
9 
10 #include <fmt/format.h>
11 
12 namespace armnn
13 {
14 
15 void Gather(const TensorInfo& paramsInfo,
16  const TensorInfo& indicesInfo,
17  const TensorInfo& outputInfo,
18  Decoder<float>& params,
19  const int32_t* indices,
20  Encoder<float>& output,
21  const int32_t axis_int)
22 {
23  IgnoreUnused(outputInfo);
24 
25  const int paramsRank = static_cast<int>(paramsInfo.GetNumDimensions());
26  if((axis_int < -1 * paramsRank) || (paramsRank <= axis_int))
27  {
28  throw InvalidArgumentException((fmt::format("Gather: Axis {} is not within [-{}, {}) range",
29  axis_int, paramsRank, paramsRank)));
30  }
31  const unsigned int axis = (axis_int < 0) ? static_cast<unsigned int>(paramsRank + axis_int)
32  : static_cast<unsigned int>(axis_int);
33 
34  const TensorShape& paramsShape = paramsInfo.GetShape();
35 
36  // Product of all dimensions to the left side of the axis
37  unsigned int paramsOuterProduct = 1;
38  for (unsigned int i = 0; i < axis; ++i)
39  {
40  paramsOuterProduct *= paramsShape[i];
41  }
42  // Product of all dimensions to the right side of the axis
43  unsigned int paramsInnerProduct = 1;
44  for (unsigned int k = 1 + axis; k < paramsInfo.GetNumDimensions(); ++k)
45  {
46  paramsInnerProduct *= paramsShape[k];
47  }
48 
49  unsigned int offset = 0;
50  unsigned int outIndex = 0;
51  for (unsigned int i = 0; i < paramsOuterProduct; ++i)
52  {
53  for (unsigned int j = 0; j < indicesInfo.GetNumElements(); ++j)
54  {
55  unsigned int index =
56  (indices[j] < 0) ? static_cast<unsigned int>(static_cast<int>(paramsShape[axis]) + indices[j])
57  : static_cast<unsigned int>(indices[j]);
58 
59  if (index >= paramsShape[axis])
60  {
61  throw InvalidArgumentException((fmt::format("Gather: index >= paramsShape[axis]: {} >= {}",
62  index, paramsShape[axis] )));
63  }
64 
65  unsigned int startOffset = (paramsInnerProduct * index) + offset;
66  unsigned int endOffset = startOffset + paramsInnerProduct;
67 
68  for (unsigned int k = startOffset; k < endOffset; ++k)
69  {
70  params[k];
71  float outputValue = params.Get();
72  output[outIndex];
73  output.Set(outputValue);
74  ++outIndex;
75  }
76  }
77  offset += paramsShape[axis] * paramsInnerProduct;
78  }
79 
80  if (outIndex != outputInfo.GetNumElements())
81  {
82  throw InvalidArgumentException((fmt::format("Gather: Invalid outIndex {} ", outIndex)));
83  }
84 }
85 
86 } //namespace armnn
armnn::Decoder< float >
armnn::TensorInfo::GetNumElements
unsigned int GetNumElements() const
Definition: Tensor.hpp:198
armnn::Encoder::Set
virtual void Set(IType right)=0
WorkloadData.hpp
armnn::TensorInfo
Definition: Tensor.hpp:152
armnn::TensorInfo::GetNumDimensions
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:197
armnn::TensorShape
Definition: Tensor.hpp:20
armnn::Encoder< float >
armnn::InvalidArgumentException
Definition: Exceptions.hpp:80
Gather.hpp
armnn::Decoder::Get
virtual IType Get() const =0
armnn::TensorInfo::GetShape
const TensorShape & GetShape() const
Definition: Tensor.hpp:193
armnn::IgnoreUnused
void IgnoreUnused(Ts &&...)
Definition: IgnoreUnused.hpp:14
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::Gather
void Gather(const TensorInfo &paramsInfo, const TensorInfo &indicesInfo, const TensorInfo &outputInfo, Decoder< float > &params, const int32_t *indices, Encoder< float > &output, const int32_t axis_int)
Definition: Gather.cpp:15