10 #include <fmt/format.h>
14 template<
typename I,
typename O>
19 const int32_t* indices,
21 const int32_t axis_int)
26 if((axis_int < -1 * paramsRank) || (paramsRank <= axis_int))
29 axis_int, paramsRank, paramsRank)));
31 const unsigned int axis = (axis_int < 0) ? static_cast<unsigned int>(paramsRank + axis_int)
32 :
static_cast<unsigned int>(axis_int);
37 unsigned int paramsOuterProduct = 1;
38 for (
unsigned int i = 0; i < axis; ++i)
40 paramsOuterProduct *= paramsShape[i];
43 unsigned int paramsInnerProduct = 1;
46 paramsInnerProduct *= paramsShape[k];
49 unsigned int offset = 0;
50 unsigned int outIndex = 0;
51 for (
unsigned int i = 0; i < paramsOuterProduct; ++i)
56 (indices[j] < 0) ?
static_cast<unsigned int>(
static_cast<int>(paramsShape[axis]) + indices[j])
57 :
static_cast<unsigned int>(indices[j]);
59 if (index >= paramsShape[axis])
62 index, paramsShape[axis] )));
65 unsigned int startOffset = (paramsInnerProduct * index) + offset;
66 unsigned int endOffset = startOffset + paramsInnerProduct;
68 for (
unsigned int k = startOffset; k < endOffset; ++k)
71 auto outputValue = params.
Get();
73 output.
Set(outputValue);
77 offset += paramsShape[axis] * paramsInnerProduct;
91 const int32_t* indices,
93 const int32_t axis_int);
99 const int32_t* indices,
101 const int32_t axis_int);
virtual IType Get() const =0
virtual void Set(IType right)=0
unsigned int GetNumDimensions() const
unsigned int GetNumElements() const
const TensorShape & GetShape() const
Copyright (c) 2021 ARM Limited and Contributors.
void Gather(const TensorInfo &paramsInfo, const TensorInfo &indicesInfo, const TensorInfo &outputInfo, Decoder< I > &params, const int32_t *indices, Encoder< O > &output, const int32_t axis_int)
void IgnoreUnused(Ts &&...)