11 #include <fmt/format.h>
21 const unsigned int& concatAxis,
22 unsigned int inputIndex,
23 unsigned int& mergeDimOrigin)
31 "The number of dimensions: {0} for input tensors of the "
32 "concatenation op should be {1} {2}",
38 for (
unsigned int j = 0; j < concatAxis; ++j)
44 mergeDimOrigin += inputTensorInfo.
GetShape()[concatAxis];
46 for (
unsigned int j = concatAxis + 1; j < inputRank; ++j)
53 const std::set<unsigned int>& axisSet,
57 std::vector<unsigned int> outputShapeVector;
58 bool dimensionFound =
false;
59 unsigned int size = 1;
63 dimensionFound =
false;
64 for (
unsigned int axis: axisSet)
68 dimensionFound =
true;
75 size *= inputTensorInfo.
GetShape()[i];
79 outputShapeVector.push_back(inputTensorInfo.
GetShape()[i]);
86 outputShapeVector.push_back(1);
109 std::vector<unsigned int> outputShapeVector;
121 int newSize = stride > 0 ? ((stop - start) + stride - 1) / stride :
122 ((start - stop) - stride - 1) / -stride;
124 newSize = std::max(0, newSize);
126 outputShapeVector.push_back(
static_cast<unsigned int>(newSize));