ArmNN
 25.11
BatchMatMulOperator.hpp File Reference
Include dependency graph for BatchMatMulOperator.hpp: graph omitted; it shows which files directly or indirectly include this file.

Go to the source code of this file.

Functions

TosaSerializationBasicBlock * ConvertBatchMatMulToTosaOperator (const Layer *layer, const std::vector< const TensorInfo * > &inputs, const std::vector< const TensorInfo * > &outputs, const BatchMatMulDescriptor *descriptor=nullptr)

Function Documentation

◆ ConvertBatchMatMulToTosaOperator()

TosaSerializationBasicBlock * ConvertBatchMatMulToTosaOperator (const Layer * layer,
                                                                const std::vector< const TensorInfo * > & inputs,
                                                                const std::vector< const TensorInfo * > & outputs,
                                                                const BatchMatMulDescriptor * descriptor = nullptr )
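
Converts an ArmNN BatchMatMul layer into a TOSA serialization basic block, wrapping the rank-3 TOSA MATMUL with RESHAPE, TRANSPOSE and RESCALE operators as required by the input shapes, descriptor and data types. Below is a minimal usage sketch, not taken from the ArmNN sources, showing how the conversion might be driven in validation mode (layer == nullptr, so the generic "input_0"/"input_1"/"output_0" tensor names are used); the shapes, the Float32 data type and the include paths are illustrative assumptions.

    // Sketch only: shapes, data type and include paths are assumptions for illustration.
    #include <armnn/Descriptors.hpp>
    #include <armnn/Tensor.hpp>
    #include "BatchMatMulOperator.hpp"   // assumed location of ConvertBatchMatMulToTosaOperator

    TosaSerializationBasicBlock* MakeBatchMatMulBlockForValidation()
    {
        using namespace armnn;

        // [2, 3, 4] x [2, 4, 5] -> [2, 3, 5]; already rank 3, so no extra RESHAPE is inserted.
        TensorInfo inputX(TensorShape({2, 3, 4}), DataType::Float32);
        TensorInfo inputY(TensorShape({2, 4, 5}), DataType::Float32);
        TensorInfo output(TensorShape({2, 3, 5}), DataType::Float32);

        BatchMatMulDescriptor descriptor;   // defaults: no transpose, no adjoint

        // layer == nullptr: the block is built for validation only, so no graph wiring is needed.
        return ConvertBatchMatMulToTosaOperator(nullptr, {&inputX, &inputY}, {&output}, &descriptor);
    }

Passing a real Layer instead wires the block into the surrounding graph by deriving unique tensor names from the connected input and output slots.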

Definition at line 14 of file BatchMatMulOperator.cpp.

{
    if (descriptor->m_AdjointX || descriptor->m_AdjointY)
    {
        throw Exception("Support for adjoint not implemented.");
    }
    if (descriptor->m_DataLayoutX != DataLayout::NCHW || descriptor->m_DataLayoutY != DataLayout::NCHW)
    {
        throw Exception("MatMul only supported in the last 2 dimensions");
    }

    std::string input0Name = std::string("input_0");
    std::string input1Name = std::string("input_1");
    std::string outputName = std::string("output_0");
    std::string outputReshape0Name = std::string("layer_intermediate0_") + GetUniqueTosaMappingID();
    std::string outputReshape1Name = std::string("layer_intermediate0_") + GetUniqueTosaMappingID();
    std::string outputTranspose0Name = std::string("layer_intermediate1_") + GetUniqueTosaMappingID();
    std::string outputTranspose1Name = std::string("layer_intermediate1_") + GetUniqueTosaMappingID();

    std::string blockName = std::string("Op_BATCHMATMUL_block_") + GetUniqueTosaMappingID();

    // If a layer is present then the block will be used for execution, so input and output names need to be determined
    // using the previous and following layers so the graph is connected correctly. For validation this doesn't matter.
    if (layer != nullptr)
    {
        // Get the layer connected to the input slot and determine unique tensor names.
        input0Name = GenerateUniqueInputName(layer->GetInputSlot(0));
        input1Name = GenerateUniqueInputName(layer->GetInputSlot(1));
        outputName = GenerateUniqueOutputName(*layer);
    }

    // Assumes both inputs have the same data type.
    DType inputDType = ArmNNToDType(inputs[0]->GetDataType());
    bool isInputInt8 = (inputDType == DType_INT8);
    bool isInputInt16 = (inputDType == DType_INT16);

    std::vector<TosaSerializationTensor*> tensors;
    std::vector<TosaSerializationOperator*> operators;

    // Only add input tensors if the connected layer is an input layer,
    // as intermediate or constant tensors will be created separately.
    // There also can't be duplicate tensors.
    if (input0Name.find("input_") != std::string::npos)
    {
        std::vector<int32_t> inputShape0 = GetTosaTensorShape(inputs[0]->GetShape());
        tensors.push_back(new TosaSerializationTensor(input0Name, inputShape0, inputDType, {}));
    }
    if (input1Name.find("input_") != std::string::npos)
    {
        std::vector<int32_t> inputShape1 = GetTosaTensorShape(inputs[1]->GetShape());
        tensors.push_back(new TosaSerializationTensor(input1Name, inputShape1, inputDType, {}));
    }

    std::string input0TransposeName = input0Name;
    std::string input1TransposeName = input1Name;
    std::vector<int32_t> outputShape0 = GetTosaTensorShape(outputs[0]->GetShape());

    std::string input0MatMulName = input0Name;
    std::string input1MatMulName = input1Name;

    // *** ADD OP STEPS ***

    // ADD RESHAPE OPs if BATCH DIMS > 1
    // RESHAPE input 1
    std::vector<int32_t> targetShape0 = GetTosaTensorShape(outputs[0]->GetShape());
    uint32_t input0Dimensions = inputs[0]->GetNumDimensions();
    if (input0Dimensions > 3)
    {
        uint32_t x = 1;
        for (uint32_t i = 0; i < (input0Dimensions - 2); ++i)
        {
            x *= (inputs[0]->GetShape()[i]);
        }

        targetShape0 = {static_cast<int32_t>(x),
                        static_cast<int32_t>(inputs[0]->GetShape()[input0Dimensions - 2]),
                        static_cast<int32_t>(inputs[0]->GetShape()[input0Dimensions - 1])};

        TosaReshapeAttribute attribute(targetShape0);

        auto* input0ReshapeOp = new TosaSerializationOperator(Op_RESHAPE,
                                                              Attribute_ReshapeAttribute,
                                                              &attribute,
                                                              {input0Name},
                                                              {outputReshape0Name});

        operators.push_back(input0ReshapeOp);
        tensors.push_back(new TosaSerializationTensor(outputReshape0Name, targetShape0, inputDType, {}));
        input0TransposeName = outputReshape0Name;
        input0MatMulName = outputReshape0Name;
    }
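
    // Illustrative note (not in the original source): for example, a rank-4 input of shape
    // [2, 3, 5, 7] is collapsed by the RESHAPE above into [6, 5, 7], folding the two leading
    // batch dimensions (2 * 3 = 6) into the single batch dimension expected by the rank-3
    // TOSA MATMUL.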

    // RESHAPE input 2
    std::vector<int32_t> targetShape1 = GetTosaTensorShape(outputs[0]->GetShape());
    uint32_t input1Dimensions = inputs[1]->GetNumDimensions();
    if (input1Dimensions > 3)
    {
        uint32_t x = 1;
        for (uint32_t i = 0; i < (input1Dimensions - 2); i++)
        {
            x *= (inputs[1]->GetShape()[i]);
        }

        targetShape1 = {static_cast<int32_t>(x),
                        static_cast<int32_t>(inputs[1]->GetShape()[input1Dimensions - 2]),
                        static_cast<int32_t>(inputs[1]->GetShape()[input1Dimensions - 1])};

        TosaReshapeAttribute attribute(targetShape1);

        auto* input1ReshapeOp = new TosaSerializationOperator(Op_RESHAPE,
                                                              Attribute_ReshapeAttribute,
                                                              &attribute,
                                                              {input1Name},
                                                              {outputReshape1Name});

        operators.push_back(input1ReshapeOp);
        tensors.push_back(new TosaSerializationTensor(outputReshape1Name, targetShape1, inputDType, {}));
        input1TransposeName = outputReshape1Name;
        input1MatMulName = outputReshape1Name;
    }
    bool needsReshape = input0Dimensions > 3 || input1Dimensions > 3;

    // ADD a TRANSPOSE OP for one/both inputs if transpose set to true
    if (descriptor->m_TransposeX)
    {
        auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(descriptor->m_DataLayoutX,
                                                               inputs[0]->GetShape());
        std::vector<int32_t> mappings(permuteVec.begin(),
                                      permuteVec.end());
        if (input0Dimensions > 3)
        {
            auto input0BatchedDims = input0Dimensions - 3;
            mappings = {static_cast<int>(permuteVec[0]),
                        static_cast<int>(permuteVec[input0Dimensions - 2] - input0BatchedDims),
                        static_cast<int>(permuteVec[input0Dimensions - 1] - input0BatchedDims)};
        }

        TosaTransposeAttribute transposeAttribute(mappings);

        TosaSerializationOperator* transposeOp = new TosaSerializationOperator(Op_TRANSPOSE,
                                                                               Attribute_TransposeAttribute,
                                                                               &transposeAttribute,
                                                                               {input0TransposeName},
                                                                               {outputTranspose0Name});

        std::vector<int32_t> transpose0Shape =
        {
            targetShape0[static_cast<unsigned int>(mappings[0])],
            targetShape0[static_cast<unsigned int>(mappings[1])],
            targetShape0[static_cast<unsigned int>(mappings[2])]
        };

        operators.push_back(transposeOp);
        tensors.push_back(new TosaSerializationTensor(outputTranspose0Name, transpose0Shape, inputDType, {}));
        input0MatMulName = outputTranspose0Name;
    }
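
    // Illustrative note (not in the original source): assuming GetPermuteVec() returns
    // {0, 1, 3, 2} for a rank-4 input, the clamp above folds it to the rank-3 permutation
    // {0, 2, 1} on the reshaped tensor, i.e. it still just swaps the last two dimensions.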

    if (descriptor->m_TransposeY)
    {
        auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(descriptor->m_DataLayoutY,
                                                               inputs[1]->GetShape());

        std::vector<int32_t> mappings(permuteVec.begin(),
                                      permuteVec.end());

        auto input1BatchedDims = input1Dimensions - 3;
        if (input1Dimensions > 3)
        {
            mappings = {static_cast<int>(permuteVec[0]),
                        static_cast<int>(permuteVec[input1Dimensions - 2] - input1BatchedDims),
                        static_cast<int>(permuteVec[input1Dimensions - 1] - input1BatchedDims)};
        }

        TosaTransposeAttribute transposeAttribute(mappings);

        TosaSerializationOperator* transposeOp = new TosaSerializationOperator(Op_TRANSPOSE,
                                                                               Attribute_TransposeAttribute,
                                                                               &transposeAttribute,
                                                                               {input1TransposeName},
                                                                               {outputTranspose1Name});
        std::vector<int32_t> transpose1Shape =
        {
            targetShape1[static_cast<unsigned int>(mappings[0])],
            targetShape1[static_cast<unsigned int>(mappings[1])],
            targetShape1[static_cast<unsigned int>(mappings[2])]
        };

        operators.push_back(transposeOp);
        tensors.push_back(new TosaSerializationTensor(outputTranspose1Name, transpose1Shape, inputDType, {}));
        input1MatMulName = outputTranspose1Name;
    }

    // ADD MAT MUL layer
    std::string matMulOutputStr = needsReshape || isInputInt8 || isInputInt16 ?
                                  std::string("layer_intermediate2_") + GetUniqueTosaMappingID() : outputName;

    TosaMatMulAttribute matMulAttribute(0, 0); // input0_zp, input1_zp
    DType matMulOutDType = ArmNNToDType(inputs[1]->GetDataType());
    if (isInputInt8)
    {
        matMulAttribute = TosaMatMulAttribute(inputs[0]->GetQuantizationOffset(), inputs[1]->GetQuantizationOffset());
        matMulOutDType = DType_INT32;
    }
    if (isInputInt16)
    {
        matMulAttribute = TosaMatMulAttribute(inputs[0]->GetQuantizationOffset(), inputs[1]->GetQuantizationOffset());
        matMulOutDType = DType_INT48;
    }
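
    // Note (added for clarity, not in the original source): the TOSA specification widens the
    // MATMUL accumulator for quantized types, so INT8 x INT8 accumulates into INT32 and
    // INT16 x INT16 accumulates into INT48. That is why matMulOutDType is changed above and a
    // RESCALE back to the output type is appended further down.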
    TosaSerializationOperator* matMulOp = new TosaSerializationOperator(Op_MATMUL,
                                                                        Attribute_MatMulAttribute,
                                                                        &matMulAttribute,
                                                                        {input0MatMulName, input1MatMulName},
                                                                        {matMulOutputStr});

    uint32_t outputDimensions = outputs[0]->GetNumDimensions();
    if (outputDimensions > 3)
    {
        uint32_t x = 1;
        for (uint32_t i = 0; i < (outputDimensions - 2); ++i)
        {
            x *= (outputs[0]->GetShape()[i]);
        }

        outputShape0 = {static_cast<int32_t>(x),
                        static_cast<int32_t>(outputs[0]->GetShape()[outputDimensions - 2]),
                        static_cast<int32_t>(outputs[0]->GetShape()[outputDimensions - 1])};
    }

    operators.push_back(matMulOp);
    tensors.push_back(new TosaSerializationTensor(matMulOutputStr, outputShape0, matMulOutDType, {}));

    std::string outputRescale = needsReshape ?
                                std::string("layer_intermediate3_") + GetUniqueTosaMappingID() : outputName;
    std::string inputReshape2Name = isInputInt8 || isInputInt16 ? outputRescale : matMulOutputStr;

    // ADD Rescale layer if the input is int8 or int16
    if (isInputInt8 || isInputInt16)
    {
        bool scale32 = isInputInt16 ? false : true;
        bool doubleRound = isInputInt16 ? false : true;

        int32_t output_zp = outputs[0]->GetQuantizationOffset();
        double output_scale = outputs[0]->GetQuantizationScales()[0];
        double input_scale = inputs[0]->GetQuantizationScales()[0];
        const std::vector<float>& weight_scales = inputs[1]->GetQuantizationScales();

        TosaSerializationOperator* rescaleOp = nullptr;
        CreateRescaleTosaOperatorForWeights(matMulOutputStr,
                                            outputRescale,
                                            0,
                                            output_zp,
                                            false,
                                            false,
                                            doubleRound,
                                            scale32,
                                            input_scale,
                                            output_scale,
                                            weight_scales,
                                            &rescaleOp);

        tensors.push_back(new TosaSerializationTensor(outputRescale,
                                                      outputShape0,
                                                      inputDType, {}));

        operators.push_back(rescaleOp);
    }
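
    // Illustrative note (not in the original source; the exact computation lives in
    // CreateRescaleTosaOperatorForWeights): in typical TOSA requantization the widened
    // accumulator is scaled by roughly input_scale * weight_scale / output_scale per weight
    // channel and re-centred on output_zp. For example, with input_scale = 0.5,
    // weight_scale = 0.25 and output_scale = 0.125 the effective multiplier is 1.0.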

    // ADD a RESHAPE back to expected rank
    if (needsReshape)
    {
        const std::vector<int32_t>& targetShape = GetTosaTensorShape(TensorShape(outputs[0]->GetShape()));
        TosaReshapeAttribute attribute(targetShape);

        auto* outputReshapeOp = new TosaSerializationOperator(Op_RESHAPE,
                                                              Attribute_ReshapeAttribute,
                                                              &attribute,
                                                              {inputReshape2Name},
                                                              {outputName});

        operators.push_back(outputReshapeOp);
        tensors.push_back(new TosaSerializationTensor(outputName, targetShape, inputDType, {}));
    }

    return new TosaSerializationBasicBlock(blockName,                 // name
                                           mainName,                  // region name
                                           {operators},               // operators
                                           tensors,                   // tensors
                                           {input0Name, input1Name},  // inputs
                                           {outputName});             // outputs
}
Declarations referenced in the listing above:

    std::string GenerateUniqueOutputName(const Layer &layer, uint32_t layerSlot=0)
    std::string GenerateUniqueInputName(const armnn::InputSlot &slot)
    std::string GetUniqueTosaMappingID()
    const std::string mainName
    DType ArmNNToDType(const DataType &type)
    std::vector< int32_t > GetTosaTensorShape(const TensorShape &shape)
    void CreateRescaleTosaOperatorForWeights(const std::string &inputName, const std::string &outputName, int32_t input_zp, int32_t output_zp, bool input_unsigned, bool output_unsigned, bool double_round, bool scale32, double input_scale, double output_scale, const std::vector< float > &weight_scales, TosaSerializationOperator **op)
        Creates a TOSA rescale operator for weight tensors.
    Exception
        Base class for all ArmNN exceptions so that users can filter to just those.
    const InputSlot & Layer::GetInputSlot(unsigned int index) const override
        Get a const input slot handle by slot index. Definition: Layer.hpp:337.
    bool BatchMatMulDescriptor::m_AdjointX
        Adjoint the slices of each input tensor. Transpose and Adjoint can not both be set to true for the same tensor at the same time.
    bool BatchMatMulDescriptor::m_TransposeX
        Transpose the slices of each input tensor. Transpose and Adjoint can not both be set to true for the same tensor at the same time.
    DataLayout BatchMatMulDescriptor::m_DataLayoutX
        Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout).
    static PermutationVector BatchMatMulDescriptor::GetPermuteVec(DataLayout dataLayout, const TensorShape &tensorShape)
        Static helper to get the axes which will be transposed.

References ArmNNToDType(), CreateRescaleTosaOperatorForWeights(), GenerateUniqueInputName(), GenerateUniqueOutputName(), Layer::GetInputSlot(), BatchMatMulDescriptor::GetPermuteVec(), GetTosaTensorShape(), GetUniqueTosaMappingID(), BatchMatMulDescriptor::m_AdjointX, BatchMatMulDescriptor::m_AdjointY, BatchMatMulDescriptor::m_DataLayoutX, BatchMatMulDescriptor::m_DataLayoutY, BatchMatMulDescriptor::m_TransposeX, BatchMatMulDescriptor::m_TransposeY, mainName, and armnn::NCHW.

Referenced by GetTosaMapping().