15 const std::vector<const TensorInfo*>& inputs,
16 const std::vector<const TensorInfo*>& outputs,
21 throw Exception(
"Support for adjoint not implemented.");
25 throw Exception(
"MatMul only supported in the last 2 dimensions");
28 std::string input0Name = std::string(
"input_0");
29 std::string input1Name = std::string(
"input_1");
30 std::string outputName = std::string(
"output_0");
49 DType inputDType =
ArmNNToDType(inputs[0]->GetDataType());
50 bool isInputInt8 = (inputDType == DType_INT8);
51 bool isInputInt16 = (inputDType == DType_INT16);
53 std::vector<TosaSerializationTensor*> tensors;
54 std::vector<TosaSerializationOperator*> operators;
59 if(input0Name.find(
"input_") != std::string::npos)
62 tensors.push_back(
new TosaSerializationTensor(input0Name, inputShape0, inputDType, {}));
64 if(input1Name.find(
"input_") != std::string::npos)
67 tensors.push_back(
new TosaSerializationTensor(input1Name, inputShape1, inputDType, {}));
70 std::string input0TransposeName = input0Name;
71 std::string input1TransposeName = input1Name;
74 std::string input0MatMulName = input0Name;
75 std::string input1MatMulName = input1Name;
82 uint32_t input0Dimensions = inputs[0]->GetNumDimensions();
83 if (input0Dimensions > 3)
86 for (uint32_t i = 0; i < (input0Dimensions - 2); ++i)
88 x *=(inputs[0]->GetShape()[i]);
91 targetShape0 = {
static_cast<int32_t
>(x),
92 static_cast<int32_t
>(inputs[0]->GetShape()[input0Dimensions - 2]),
93 static_cast<int32_t
>(inputs[0]->GetShape()[input0Dimensions - 1])};
95 TosaReshapeAttribute attribute(targetShape0);
97 auto* input0ReshapeOp =
new TosaSerializationOperator(Op_RESHAPE,
98 Attribute_ReshapeAttribute,
101 {outputReshape0Name});
103 operators.push_back(input0ReshapeOp);
104 tensors.push_back(
new TosaSerializationTensor(outputReshape0Name, targetShape0, inputDType, {}));
105 input0TransposeName = outputReshape0Name;
106 input0MatMulName = outputReshape0Name;
111 uint32_t input1Dimensions = inputs[1]->GetNumDimensions();
112 if (input1Dimensions > 3)
115 for (uint32_t i = 0; i < (input1Dimensions - 2); i++)
117 x *= (inputs[1]->GetShape()[i]);
120 targetShape1 = {
static_cast<int32_t
>(x),
121 static_cast<int32_t
>(inputs[1]->GetShape()[input1Dimensions - 2]),
122 static_cast<int32_t
>(inputs[1]->GetShape()[input1Dimensions - 1])};
124 TosaReshapeAttribute attribute(targetShape1);
126 auto* input1ReshapeOp =
new TosaSerializationOperator(Op_RESHAPE,
127 Attribute_ReshapeAttribute,
130 {outputReshape1Name});
132 operators.push_back(input1ReshapeOp);
133 tensors.push_back(
new TosaSerializationTensor(outputReshape1Name, targetShape1, inputDType, {}));
134 input1TransposeName = outputReshape1Name;
135 input1MatMulName = outputReshape1Name;
137 bool needsReshape = input0Dimensions > 3 || input1Dimensions > 3;
143 inputs[0]->GetShape());
144 std::vector<int32_t> mappings(permuteVec.begin(),
146 if (input0Dimensions > 3)
148 auto input0BatchedDims = input0Dimensions - 3;
149 mappings = {
static_cast<int>(permuteVec[0]),
150 static_cast<int>(permuteVec[input0Dimensions - 2] - input0BatchedDims),
151 static_cast<int>(permuteVec[input0Dimensions - 1] - input0BatchedDims)};
154 TosaTransposeAttribute transposeAttribute(mappings);
156 TosaSerializationOperator *transposeOp =
new TosaSerializationOperator(Op_TRANSPOSE,
157 Attribute_TransposeAttribute,
159 {input0TransposeName},
160 {outputTranspose0Name});
162 std::vector<int32_t> transpose0Shape =
164 targetShape0[
static_cast<unsigned int>(mappings[0])],
165 targetShape0[
static_cast<unsigned int>(mappings[1])],
166 targetShape0[
static_cast<unsigned int>(mappings[2])]
169 operators.push_back(transposeOp);
170 tensors.push_back(
new TosaSerializationTensor(outputTranspose0Name, transpose0Shape, inputDType, {}));
171 input0MatMulName = outputTranspose0Name;
177 inputs[1]->GetShape());
179 std::vector<int32_t> mappings(permuteVec.begin(),
182 auto input1BatchedDims = input1Dimensions - 3;
183 if (input1Dimensions > 3)
185 mappings = {
static_cast<int>(permuteVec[0]),
186 static_cast<int>(permuteVec[input1Dimensions - 2] - input1BatchedDims),
187 static_cast<int>(permuteVec[input1Dimensions - 1] - input1BatchedDims)};
190 TosaTransposeAttribute transposeAttribute(mappings);
192 TosaSerializationOperator *transposeOp =
new TosaSerializationOperator(Op_TRANSPOSE,
193 Attribute_TransposeAttribute,
195 {input1TransposeName},
196 {outputTranspose1Name});
197 std::vector<int32_t> transpose1Shape =
199 targetShape1[
static_cast<unsigned int>(mappings[0])],
200 targetShape1[
static_cast<unsigned int>(mappings[1])],
201 targetShape1[
static_cast<unsigned int>(mappings[2])]
204 operators.push_back(transposeOp);
205 tensors.push_back(
new TosaSerializationTensor(outputTranspose1Name, transpose1Shape, inputDType, {}));
206 input1MatMulName = outputTranspose1Name;
210 std::string matMulOutputStr = needsReshape || isInputInt8 || isInputInt16 ?
213 TosaMatMulAttribute matMulAttribute(0,0);
214 DType matMulOutDType =
ArmNNToDType(inputs[1]->GetDataType());
217 matMulAttribute = TosaMatMulAttribute(inputs[0]->GetQuantizationOffset(), inputs[1]->GetQuantizationOffset());
218 matMulOutDType = DType_INT32;
222 matMulAttribute = TosaMatMulAttribute(inputs[0]->GetQuantizationOffset(), inputs[1]->GetQuantizationOffset());
223 matMulOutDType = DType_INT48;
225 TosaSerializationOperator* matMulOp =
new TosaSerializationOperator(Op_MATMUL,
226 Attribute_MatMulAttribute,
228 {input0MatMulName, input1MatMulName},
231 uint32_t outputDimensions = outputs[0]->GetNumDimensions();
232 if (outputDimensions > 3)
235 for (uint32_t i = 0; i < (outputDimensions - 2); ++i)
237 x *=(outputs[0]->GetShape()[i]);
240 outputShape0 = {
static_cast<int32_t
>(x),
241 static_cast<int32_t
>(outputs[0]->GetShape()[outputDimensions - 2]),
242 static_cast<int32_t
>(outputs[0]->GetShape()[outputDimensions - 1])};
245 operators.push_back(matMulOp);
246 tensors.push_back(
new TosaSerializationTensor(matMulOutputStr, outputShape0, matMulOutDType, {}));
248 std::string outputRescale = needsReshape ?
250 std::string inputReshape2Name = isInputInt8 || isInputInt16 ? outputRescale : matMulOutputStr;
253 if (isInputInt8 || isInputInt16)
255 bool scale32 = isInputInt16 ? false :
true;
256 bool doubleRound = isInputInt16 ? false :
true;
258 int32_t output_zp = outputs[0]->GetQuantizationOffset();
259 double output_scale = outputs[0]->GetQuantizationScales()[0];
260 double input_scale = inputs[0]->GetQuantizationScales()[0];
261 const std::vector<float>& weight_scales = inputs[1]->GetQuantizationScales();
263 TosaSerializationOperator* rescaleOp =
nullptr;
277 tensors.push_back(
new TosaSerializationTensor(outputRescale,
281 operators.push_back(rescaleOp);
288 TosaReshapeAttribute attribute(targetShape);
290 auto* outputReshapeOp =
new TosaSerializationOperator(Op_RESHAPE,
291 Attribute_ReshapeAttribute,
296 operators.push_back(outputReshapeOp);
297 tensors.push_back(
new TosaSerializationTensor(outputName, targetShape, inputDType, {}));
300 return new TosaSerializationBasicBlock(blockName,
304 {input0Name, input1Name},