13 const std::vector<const TensorInfo*>& inputs,
14 const std::vector<const TensorInfo*>& outputs,
59 "ConvertBatchToSpaceToTosaOperator: BatchToSpace must have only one input");
62 "ConvertBatchToSpaceToTosaOperator: BatchToSpace must have only one output");
64 std::string inputName =
"input_";
68 std::string outputName =
"output0_";
76 std::vector<TosaSerializationTensor*> tensors;
77 std::vector<TosaSerializationOperator*> operators;
79 const auto& crops = batchToSpaceDescriptor->
m_Crops;
80 const auto& blockShape = batchToSpaceDescriptor->
m_BlockShape;
82 const DType inputDType =
ArmNNToDType(inputs[0]->GetDataType());
83 const size_t inputRank = inputShape.size();
84 const size_t blockRank = blockShape.size();
85 const size_t remRank = inputRank - blockRank - 1;
87 if (inputName.find(
"input_") != std::string::npos)
89 tensors.push_back(
new TosaSerializationTensor(inputName, inputShape, inputDType, {}));
92 if (inputRank < 2 || blockRank < 1 || blockShape.size() != crops.size())
94 throw armnn::Exception(
"ConvertBatchToSpaceToTosaOperator: Unsupported BatchToSpaceND config.");
105 int32_t blockNumElems = 1;
106 for (
size_t i = 0; i < blockShape.size(); ++i)
108 blockNumElems *=
static_cast<int32_t
>(blockShape[i]);
111 const int32_t inputBatch = inputShape[0];
112 const int32_t newBatch = inputBatch / blockNumElems;
115 std::vector<int32_t> reshape1Shape;
116 for (
size_t i = 0; i < blockRank; ++i)
118 reshape1Shape.push_back(
static_cast<int32_t
>(blockShape[i]));
121 reshape1Shape.push_back(newBatch);
122 reshape1Shape.insert(reshape1Shape.end(), inputShape.begin() + 1, inputShape.end());
124 tensors.push_back(
new TosaSerializationTensor(outputNameReshape1, reshape1Shape, inputDType, {}));
125 TosaReshapeAttribute reshape1Attr(reshape1Shape);
127 operators.push_back(
new TosaSerializationOperator(Op_RESHAPE,
128 Attribute_ReshapeAttribute,
131 {outputNameReshape1}));
134 std::vector<int32_t> perm;
135 perm.push_back(
static_cast<int32_t
>(blockRank));
136 for (
size_t i = 0; i < blockRank; ++i)
138 perm.push_back(
static_cast<int32_t
>(blockRank + 1 + i));
139 perm.push_back(
static_cast<int32_t
>(i));
141 for (
size_t i = 0; i < remRank; ++i)
143 perm.push_back(
static_cast<int32_t
>(2 * blockRank + 1 + i));
146 std::vector<int32_t> transposeShape(perm.size());
147 for (
size_t i = 0; i < perm.size(); ++i)
149 transposeShape[i] = reshape1Shape[
static_cast<size_t>(perm[i])];
152 tensors.push_back(
new TosaSerializationTensor(outputNameTranspose, transposeShape, inputDType, {}));
153 TosaTransposeAttribute transposeAttr(perm);
155 operators.push_back(
new TosaSerializationOperator(Op_TRANSPOSE,
156 Attribute_TransposeAttribute,
158 {outputNameReshape1},
159 {outputNameTranspose}));
162 std::vector<int32_t> reshape2Shape;
163 reshape2Shape.push_back(newBatch);
165 for (
size_t i = 0; i < blockRank; ++i)
167 int32_t value = inputShape[1 + i] *
static_cast<int32_t
>(blockShape[i]);
168 reshape2Shape.push_back(value);
171 for (
size_t i = 0; i < remRank; ++i)
173 reshape2Shape.push_back(inputShape[1 + blockRank + i]);
176 tensors.push_back(
new TosaSerializationTensor(outputNameReshape2, reshape2Shape, inputDType, {}));
177 TosaReshapeAttribute reshape2Attr(reshape2Shape);
179 operators.push_back(
new TosaSerializationOperator(Op_RESHAPE,
180 Attribute_ReshapeAttribute,
182 {outputNameTranspose},
183 {outputNameReshape2}));
186 std::vector<int32_t> begin(reshape2Shape.size(), 0);
187 std::vector<int32_t> slicedShape = reshape2Shape;
189 for (
size_t i = 0; i < crops.size(); ++i)
191 begin[1 + i] =
static_cast<int32_t
>(crops[i].first);
192 slicedShape[1 + i] -=
static_cast<int32_t
>(crops[i].first + crops[i].second);
195 tensors.push_back(
new TosaSerializationTensor(outputName, slicedShape, inputDType, {}));
197 TosaSliceAttribute sliceAttr(begin, slicedShape);
198 operators.push_back(
new TosaSerializationOperator(Op_SLICE,
199 Attribute_SliceAttribute,
201 {outputNameReshape2},
206 if (slicedShape != expectedShape)
208 throw armnn::Exception(
"ConvertSpaceToBatchToTosaOperator: Mismatch expected output and generated shape differ");
210 return new TosaSerializationBasicBlock(blockName,
mainName, operators, tensors, {inputName}, {outputName});