6 #include <fmt/format.h>
36 std::unique_ptr<Decoder<float>> params_decoderPtr = MakeDecoder<float>(inputInfo0, inputs[0]->
Map());
38 const int32_t* indicesDataPtr =
reinterpret_cast<int32_t*
>(inputs[1]->Map());
39 std::vector<int32_t> indices(indicesDataPtr, indicesDataPtr + inputInfo1.
GetNumElements());
49 std::unique_ptr<Encoder<float>> output_encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->
Map());
56 TensorShape paramsShape = inputInfo0.
GetShape();
57 std::vector<unsigned int> flattenedCoeff(keyIndices[
"ND"], 1);
58 for (
unsigned int i = 1; i < keyIndices[
"ND"]; ++i)
60 flattenedCoeff[i-1] = paramsShape[i];
62 for (
unsigned int i = keyIndices[
"ND"]-1; i > 0; --i)
64 flattenedCoeff[i-1] *= flattenedCoeff[i];
70 flattenedIndices_Info.
SetShape({ keyIndices[
"W"] });
71 std::vector<int32_t> flattenedIndices(flattenedIndices_Info.
GetNumElements(), 0);
74 for (
unsigned int i = 0; i < keyIndices[
"W"]; ++i)
76 for (
unsigned int j = 0; j < keyIndices[
"ND"]; ++j)
78 flattenedIndices[i] += indices[i * keyIndices[
"ND"] + j] *
static_cast<int32_t
>(flattenedCoeff[j]);
85 params_K_C_Info.
SetShape({ keyIndices[
"K"], keyIndices[
"C"] });
89 indices_N_W_Info.
SetShape({ keyIndices[
"N"], keyIndices[
"W"] });
94 outputGather_Info.
SetShape({ keyIndices[
"N"], keyIndices[
"W"], keyIndices[
"C"] });
97 Gather(params_K_C_Info, indices_N_W_Info, outputGather_Info,
98 *params_decoderPtr, flattenedIndices.data(), *output_encoderPtr, 0);