13 const std::vector<const TensorInfo*>& inputs,
14 const std::vector<const TensorInfo*>& outputs,
58 "ConvertSpaceToBatchToTosaOperator: SpaceToBatch must have only one input");
61 "ConvertSpaceToBatchToTosaOperator: SpaceToBatch must have only one output");
63 std::string inputName =
"input_";
67 std::string outputName =
"output0_";
76 const auto& paddings = spaceToBatchDescriptor->
m_PadList;
77 const auto& blockShape = spaceToBatchDescriptor->
m_BlockShape;
78 const unsigned int inputRank = inputs[0]->GetShape().GetNumDimensions();
79 const unsigned int blockRank =
static_cast<unsigned int>(blockShape.size());
82 if (inputRank <= blockRank)
84 throw armnn::Exception(
"ConvertSpaceToBatchToTosaOperator: input rank must be greater than block rank");
87 std::vector<TosaSerializationTensor*> tensors;
88 std::vector<TosaSerializationOperator*> operators;
92 std::vector<int32_t> a0Pad(2 * inputRank, 0);
93 std::vector<int32_t> paddedShape = inputShape;
95 DType inputDType =
ArmNNToDType(inputs[0]->GetDataType());
97 if (inputName.find(
"input_") != std::string::npos)
99 tensors.push_back(
new TosaSerializationTensor(inputName, inputShape, inputDType, {}));
102 for (
size_t i = 0; i < blockShape.size(); ++i)
104 int32_t loPad =
static_cast<int32_t
>(paddings[i].first);
105 int32_t hiPad =
static_cast<int32_t
>(paddings[i].second);
106 size_t dimIndex = i + 1;
107 a0Pad[2 * dimIndex] = loPad;
108 a0Pad[2 * dimIndex + 1] = hiPad;
109 paddedShape[dimIndex] = inputShape[dimIndex] + loPad + hiPad;
112 std::string padOutput = outputNamePad +
"_padded";
114 tensors.push_back(
new TosaSerializationTensor(padOutput, paddedShape, inputDType, {}));
117 float padValue = 0.0f;
118 if (inputs[0]->IsQuantized())
120 padValue =
static_cast<float>(inputs[0]->GetQuantizationOffset()) * inputs[0]->GetQuantizationScale();
123 TosaPadAttribute padAttr(a0Pad, 0, padValue);
124 operators.push_back(
new TosaSerializationOperator(Op_PAD,
125 Attribute_PadAttribute,
131 std::vector<int32_t> reshape1;
133 reshape1.push_back(inputShape[0]);
136 int32_t blockNumElems = 1;
139 for (
size_t i = 0; i < blockShape.size(); ++i)
141 int32_t paddedDim = paddedShape[i + 1];
142 int32_t blockDim =
static_cast<int32_t
>(blockShape[i]);
143 if (paddedDim % blockDim != 0)
145 throw armnn::Exception(
"ConvertSpaceToBatchToTosaOperator: padded spatial dim not divisible by block size");
147 reshape1.push_back(paddedDim / blockDim);
148 reshape1.push_back(blockDim);
150 blockNumElems *= blockDim;
154 for (
size_t i = 1 + blockShape.size(); i < inputShape.size(); ++i)
156 reshape1.push_back(inputShape[i]);
159 tensors.push_back(
new TosaSerializationTensor(outputNameReshape1, reshape1, inputDType, {}));
160 TosaReshapeAttribute reshapeAttr(reshape1);
161 operators.push_back(
new TosaSerializationOperator(Op_RESHAPE,
162 Attribute_ReshapeAttribute,
165 {outputNameReshape1}));
167 std::vector<int32_t> transposeVec;
170 for (
size_t i = 0; i < blockShape.size(); ++i)
172 transposeVec.push_back(
static_cast<int32_t
>(1 + 2 * i + 1));
175 transposeVec.push_back(0);
178 for (
size_t i = 0; i < blockShape.size(); ++i)
180 transposeVec.push_back(
static_cast<int32_t
>(1 + 2 * i));
184 for (
size_t i = 1 + 2 * blockShape.size(); i < reshape1.size(); ++i)
186 transposeVec.push_back(
static_cast<int32_t
>(i));
189 std::vector<int32_t> transposeShape(transposeVec.size());
190 for (
size_t i = 0; i < transposeVec.size(); ++i)
192 transposeShape[i] = reshape1[
static_cast<size_t>(transposeVec[i])];
194 tensors.push_back(
new TosaSerializationTensor(outputNameTranspose, transposeShape, inputDType, {}));
196 TosaTransposeAttribute transposeAttr(transposeVec);
198 operators.push_back(
new TosaSerializationOperator(Op_TRANSPOSE,
199 Attribute_TransposeAttribute,
201 {outputNameReshape1},
202 {outputNameTranspose}));
205 std::vector<int32_t> reshape2;
207 const int32_t newBatch =
static_cast<int32_t
>(inputShape[0]) *
static_cast<int32_t
>(blockNumElems);
208 reshape2.push_back(newBatch);
211 for (
size_t i = 0; i < blockShape.size(); ++i)
213 int32_t paddedDim = paddedShape[i + 1];
214 int32_t blockDim =
static_cast<int32_t
>(blockShape[i]);
216 if (blockDim == 0 || paddedDim % blockDim != 0)
218 throw armnn::Exception(
"ConvertSpaceToBatchToTosaOperator: Invalid block Shape or padding in final reshape");
221 reshape2.push_back(paddedDim / blockDim);
225 reshape2.push_back(inputShape.back());
226 tensors.push_back(
new TosaSerializationTensor(outputName, reshape2, inputDType, {}));
228 TosaReshapeAttribute reshape2Attr(reshape2);
229 operators.push_back(
new TosaSerializationOperator(Op_RESHAPE,
230 Attribute_ReshapeAttribute,
232 {outputNameTranspose},
237 if (reshape2 != expectedShape)
239 throw armnn::Exception(
"ConvertSpaceToBatchToTosaOperator: Mismatch expected output and generated shape differ");
242 return new TosaSerializationBasicBlock(blockName,
mainName, operators, tensors, {inputName}, {outputName});