10 #include <fmt/format.h>
// Gather, templated on the Decoder element type I and the Encoder element type O.
// Selects slices of `params` along `axis`, using the positions given in `indices`,
// and writes them in order to `output`.
//
// NOTE(review): this is a fragmentary extract. The numeric prefixes (e.g. "14", "26")
// look like line numbers left over from extraction and are not valid C++. Several
// source lines are also missing, including:
//   - the function name and leading parameters, and the braces;
//   - the statements that report errors;
//   - the loop header over `k` for the inner product;
//   - the loop over `indices` (j) and the `index` declaration;
//   - the lines that position the decoder/encoder.
// Do not edit the logic here without the full original source.
//
// Parameters visible in this fragment:
//   indices  - int32 gather positions along `axis`; negative values count back
//              from the end of that dimension.
//   axis_int - axis to gather along; negative values count back from the last
//              dimension. It must lie in [-paramsRank, paramsRank).
14 template<
typename I,
typename O>
19 const int32_t* indices,
21 const int32_t axis_int)
// Reject axis values outside [-paramsRank, paramsRank). The statement that reports
// the error is not visible; the trailing argument list suggests a formatted message
// (presumably via fmt, which this file includes). Confirm against the full source.
26 if((axis_int < -1 * paramsRank) || (paramsRank <= axis_int))
29 axis_int, paramsRank, paramsRank)));
// Map a negative axis to its non-negative equivalent (e.g. -1 -> paramsRank - 1).
31 const unsigned int axis = (axis_int < 0) ? static_cast<unsigned int>(paramsRank + axis_int)
32 :
static_cast<unsigned int>(axis_int);
// Product of the dimensions before `axis`: the number of outer slices iterated below.
37 unsigned int paramsOuterProduct = 1;
38 for (
unsigned int i = 0; i < axis; ++i)
40 paramsOuterProduct *= paramsShape[i];
// Product of paramsShape[k] over a range of k. This is presumably the dimensions
// after `axis`, but the loop header is missing from this extract, so confirm.
// It is the number of contiguous elements copied for each gathered index.
43 unsigned int paramsInnerProduct = 1;
46 paramsInnerProduct *= paramsShape[k];
// offset: flat start, within params, of the current outer slice. It is advanced by
//         paramsShape[axis] * paramsInnerProduct for each outer slice.
// outIndex: write position in output. Its uses are not visible in this extract.
49 unsigned int offset = 0;
50 unsigned int outIndex = 0;
51 for (
unsigned int i = 0; i < paramsOuterProduct; ++i)
// Resolve indices[j] to an unsigned position along `axis`. A negative index is
// offset by paramsShape[axis]. An index below -paramsShape[axis] is still negative
// after that addition, so it wraps to a large unsigned value when cast. The
// bounds check that follows therefore rejects it as well.
56 (indices[j] < 0) ?
static_cast<unsigned int>(
static_cast<int>(paramsShape[axis]) + indices[j])
57 :
static_cast<unsigned int>(indices[j]);
// Out-of-range index: report an error. The reporting statement is not visible;
// only its trailing arguments (index, paramsShape[axis]) remain.
59 if (index >= paramsShape[axis])
62 index, paramsShape[axis] )));
// Flat range [startOffset, endOffset) of params selected by this index within
// the current outer slice.
65 unsigned int startOffset = (paramsInnerProduct * index) + offset;
66 unsigned int endOffset = startOffset + paramsInnerProduct;
// Copy each element of the selected range: read it through the params Decoder
// and write it through the output Encoder. The lines that position the decoder
// at k and the encoder at outIndex (and advance outIndex) are missing from this
// extract. Confirm against the full source.
68 for (
unsigned int k = startOffset; k < endOffset; ++k)
71 auto outputValue = params.
Get();
73 output.
Set(outputValue);
// Move offset past the whole `axis` extent of this outer slice.
77 offset += paramsShape[axis] * paramsInnerProduct;
// Explicit instantiation of Gather with I = O = float (float Decoder and Encoder).
// Presumably emitted here so other translation units can call this specialisation
// without seeing the template definition. Confirm against the declaring header.
// NOTE(review): the numeric prefixes on these lines (87, 88, ...) look like
// line numbers left over from extraction and are not valid C++.
87 template void Gather(
const TensorInfo& paramsInfo,
88 const TensorInfo& indicesInfo,
89 const TensorInfo& outputInfo,
90 Decoder<float>& params,
91 const int32_t* indices,
92 Encoder<float>& output,
93 const int32_t axis_int);
// Explicit instantiation of Gather with I = O = double_t, the <cmath> alias whose
// underlying type is implementation-defined (commonly double).
// Presumably emitted here so other translation units can call this specialisation
// without seeing the template definition. Confirm against the declaring header.
// NOTE(review): the numeric prefixes on these lines (95, 96, ...) look like
// line numbers left over from extraction and are not valid C++.
95 template void Gather(
const TensorInfo& paramsInfo,
96 const TensorInfo& indicesInfo,
97 const TensorInfo& outputInfo,
98 Decoder<double_t>& params,
99 const int32_t* indices,
100 Encoder<double_t>& output,
101 const int32_t axis_int);