ArmNN
 25.11
Loading...
Searching...
No Matches
BatchToSpaceOperator.cpp File Reference
Include dependency graph for BatchToSpaceOperator.cpp:

Go to the source code of this file.

Functions

TosaSerializationBasicBlock * ConvertBatchToSpaceToTosaOperator (const Layer *layer, const std::vector< const TensorInfo * > &inputs, const std::vector< const TensorInfo * > &outputs, const BatchToSpaceNdDescriptor *batchToSpaceDescriptor)

Function Documentation

◆ ConvertBatchToSpaceToTosaOperator()

TosaSerializationBasicBlock * ConvertBatchToSpaceToTosaOperator ( const Layer * layer,
const std::vector< const TensorInfo * > & inputs,
const std::vector< const TensorInfo * > & outputs,
const BatchToSpaceNdDescriptor * batchToSpaceDescriptor )

Definition at line 12 of file BatchToSpaceOperator.cpp.

16{
17
18 /*
19 * BatchToSpaceND - TOSA Lowering Overview
20 * --------------------------------------
21 * This operation takes a tensor shaped like [B, D1, D2, ..., DN, C]
22 * and moves data from the batch dimension into the spatial dimensions.
23 *
24 * It essentially reverses the logic of SpaceToBatchND, undoing the folding of spatial data into the batch.
25 *
26 * List of the steps involved:
27 *
28 * 1. Reshape:
29 * - We begin by expanding the batch dimension into block shapes.
30 * Specifically, B is split into: [b1, b2, ..., bN, B’], where B’ is the original batch divided by the product
31 * of block sizes.
32 * - This produces an intermediate shape like:
33 * [b1, b2, ..., bN, B’, D1, D2, ..., DN, C]
34 * e.g. if input is [4, 2, 2, 1] with block size [2,2], then:
35 * Reshape to [2, 2, 1, 2, 2, 1]
36 *
37 * 2. Transpose:
38 * - We rearrange the dimensions so that the blocks align with the spatial dimensions.
39 * - The transpose permutation reorders the tensor to:
40 * [B’, D1, b1, D2, b2, ..., DN, bN, C]
41 * e.g. [2, 2, 1, 2, 2, 1] becomes [1, 2, 2, 2, 2, 1]
42 *
43 * 3. Reshape:
44 * - Each spatial dimension is now expanded:
45 * Di' = Di * bi
46 * - After reshaping, the tensor looks like:
47 * [B’, D1 * b1, D2 * b2, ..., DN * bN, C]
48 * Continuing the example: [1, 2, 2, 2, 2, 1] → [1, 4, 4, 1]
49 *
50 * 4. Slice:
51 * - The final step removes any excess padding that may have existed in the original SpaceToBatchND.
52 * - Begin and end paddings are subtracted from the spatial dimensions.
53 * This restores the original unpadded spatial shape.
54 * e.g. if padded spatial shape was [4,4] and crop sizes were [[0,0],[0,0]] → no slice needed,
55 * but with crops [[1,1],[1,1]] → output becomes [1,2,2,1]
56 */
57
58 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(inputs.size() == 1,
59 "ConvertBatchToSpaceToTosaOperator: BatchToSpace must have only one input");
60
61 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(outputs.size() == 1,
62 "ConvertBatchToSpaceToTosaOperator: BatchToSpace must have only one output");
63
64 std::string inputName = "input_";
65 std::string outputNameReshape1 = "layer_intermediate1_" + GetUniqueTosaMappingID();
66 std::string outputNameTranspose = "layer_intermediate2_" + GetUniqueTosaMappingID();
67 std::string outputNameReshape2 = "layer_intermediate3_" + GetUniqueTosaMappingID();
68 std::string outputName = "output0_";
69 std::string blockName = "Op_BATCHTOSPACE_block_" + GetUniqueTosaMappingID();
70
71 if (layer != nullptr)
72 {
73 inputName = GenerateUniqueInputName(layer->GetInputSlot(0));
74 outputName = GenerateUniqueOutputName(*layer);
75 }
76 std::vector<TosaSerializationTensor*> tensors;
77 std::vector<TosaSerializationOperator*> operators;
78
79 const auto& crops = batchToSpaceDescriptor->m_Crops;
80 const auto& blockShape = batchToSpaceDescriptor->m_BlockShape;
81 const std::vector<int32_t> inputShape = GetTosaTensorShape(inputs[0]->GetShape());
82 const DType inputDType = ArmNNToDType(inputs[0]->GetDataType());
83 const size_t inputRank = inputShape.size();
84 const size_t blockRank = blockShape.size();
85 const size_t remRank = inputRank - blockRank - 1;
86
87 if (inputName.find("input_") != std::string::npos)
88 {
89 tensors.push_back(new TosaSerializationTensor(inputName, inputShape, inputDType, {}));
90 }
91
92 if (inputRank < 2 || blockRank < 1 || blockShape.size() != crops.size())
93 {
94 throw armnn::Exception("ConvertBatchToSpaceToTosaOperator: Unsupported BatchToSpaceND config.");
95 return nullptr;
96 }
97
98 if (layer != nullptr)
99 {
100 inputName = GenerateUniqueInputName(layer->GetInputSlot(0));
101 outputName = GenerateUniqueOutputName(*layer);
102 }
103
104 // calculate the total number of blockElements
105 int32_t blockNumElems = 1;
106 for (size_t i = 0; i < blockShape.size(); ++i)
107 {
108 blockNumElems *= static_cast<int32_t>(blockShape[i]);
109 }
110 // using the input batch work out the new batch value
111 const int32_t inputBatch = inputShape[0];
112 const int32_t newBatch = inputBatch / blockNumElems;
113
114 // Reshape input to [block_shape..., batch / product(block), input_dims[1..]]
115 std::vector<int32_t> reshape1Shape;
116 for (size_t i = 0; i < blockRank; ++i)
117 {
118 reshape1Shape.push_back(static_cast<int32_t>(blockShape[i]));
119 }
120
121 reshape1Shape.push_back(newBatch);
122 reshape1Shape.insert(reshape1Shape.end(), inputShape.begin() + 1, inputShape.end());
123
124 tensors.push_back(new TosaSerializationTensor(outputNameReshape1, reshape1Shape, inputDType, {}));
125 TosaReshapeAttribute reshape1Attr(reshape1Shape);
126
127 operators.push_back(new TosaSerializationOperator(Op_RESHAPE,
128 Attribute_ReshapeAttribute,
129 &reshape1Attr,
130 {inputName},
131 {outputNameReshape1}));
132
133 // interleave block dimensions with spatial dims
134 std::vector<int32_t> perm;
135 perm.push_back(static_cast<int32_t>(blockRank));
136 for (size_t i = 0; i < blockRank; ++i)
137 {
138 perm.push_back(static_cast<int32_t>(blockRank + 1 + i));
139 perm.push_back(static_cast<int32_t>(i));
140 }
141 for (size_t i = 0; i < remRank; ++i)
142 {
143 perm.push_back(static_cast<int32_t>(2 * blockRank + 1 + i));
144 }
145
146 std::vector<int32_t> transposeShape(perm.size());
147 for (size_t i = 0; i < perm.size(); ++i)
148 {
149 transposeShape[i] = reshape1Shape[static_cast<size_t>(perm[i])];
150 }
151
152 tensors.push_back(new TosaSerializationTensor(outputNameTranspose, transposeShape, inputDType, {}));
153 TosaTransposeAttribute transposeAttr(perm);
154
155 operators.push_back(new TosaSerializationOperator(Op_TRANSPOSE,
156 Attribute_TransposeAttribute,
157 &transposeAttr,
158 {outputNameReshape1},
159 {outputNameTranspose}));
160
161 // Reshape data to [new_batch, spatial dims * block, remainder]
162 std::vector<int32_t> reshape2Shape;
163 reshape2Shape.push_back(newBatch);
164
165 for (size_t i = 0; i < blockRank; ++i)
166 {
167 int32_t value = inputShape[1 + i] * static_cast<int32_t>(blockShape[i]);
168 reshape2Shape.push_back(value);
169 }
170
171 for (size_t i = 0; i < remRank; ++i)
172 {
173 reshape2Shape.push_back(inputShape[1 + blockRank + i]);
174 }
175
176 tensors.push_back(new TosaSerializationTensor(outputNameReshape2, reshape2Shape, inputDType, {}));
177 TosaReshapeAttribute reshape2Attr(reshape2Shape);
178
179 operators.push_back(new TosaSerializationOperator(Op_RESHAPE,
180 Attribute_ReshapeAttribute,
181 &reshape2Attr,
182 {outputNameTranspose},
183 {outputNameReshape2}));
184
185 // slice the data to remove cropped areas from spatial dims
186 std::vector<int32_t> begin(reshape2Shape.size(), 0);
187 std::vector<int32_t> slicedShape = reshape2Shape;
188
189 for (size_t i = 0; i < crops.size(); ++i)
190 {
191 begin[1 + i] = static_cast<int32_t>(crops[i].first);
192 slicedShape[1 + i] -= static_cast<int32_t>(crops[i].first + crops[i].second);
193 }
194
195 tensors.push_back(new TosaSerializationTensor(outputName, slicedShape, inputDType, {}));
196
197 TosaSliceAttribute sliceAttr(begin, slicedShape);
198 operators.push_back(new TosaSerializationOperator(Op_SLICE,
199 Attribute_SliceAttribute,
200 &sliceAttr,
201 {outputNameReshape2},
202 {outputName}));
203
204 std::vector<int32_t> expectedShape = GetTosaTensorShape(outputs[0]->GetShape());
205
206 if (slicedShape != expectedShape)
207 {
208 throw armnn::Exception("ConvertSpaceToBatchToTosaOperator: Mismatch expected output and generated shape differ");
209 }
210 return new TosaSerializationBasicBlock(blockName, mainName, operators, tensors, {inputName}, {outputName});
211}
#define ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(_cond, _str)
std::string GenerateUniqueOutputName(const Layer &layer, uint32_t layerSlot=0)
const std::string mainName
DType ArmNNToDType(const DataType &type)
std::string GenerateUniqueInputName(const armnn::InputSlot &slot)
std::string GetUniqueTosaMappingID()
std::vector< int32_t > GetTosaTensorShape(const TensorShape &shape)
Base class for all ArmNN exceptions so that users can filter to just those.
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition Layer.hpp:337
std::vector< unsigned int > m_BlockShape
Block shape values.
std::vector< std::pair< unsigned int, unsigned int > > m_Crops
The values to crop from the input dimension.

References ARMNN_THROW_INVALIDARG_MSG_IF_FALSE, ArmNNToDType(), GenerateUniqueInputName(), GenerateUniqueOutputName(), Layer::GetInputSlot(), GetTosaTensorShape(), GetUniqueTosaMappingID(), BatchToSpaceNdDescriptor::m_BlockShape, BatchToSpaceNdDescriptor::m_Crops, and mainName.

Referenced by GetTosaMapping().