15 const std::vector<const TensorInfo*>& inputs,
16 const std::vector<const TensorInfo*>& outputs,
21 throw Exception(
"Support for adjoint not implemented.");
25 throw Exception(
"MatMul only supported in the last 2 dimensions");
28 std::string input0Name = std::string(
"input_0");
29 std::string input1Name = std::string(
"input_1");
30 std::string outputName = std::string(
"output_0");
49 DType inputDType =
ArmNNToDType(inputs[0]->GetDataType());
50 bool isInputInt8 = (inputDType == DType_INT8);
51 bool isInputInt16 = (inputDType == DType_INT16);
53 std::vector<TosaSerializationTensor*> tensors;
54 std::vector<TosaSerializationOperator*> operators;
59 if(input0Name.find(
"input_") != std::string::npos)
62 tensors.push_back(
new TosaSerializationTensor(input0Name, inputShape0, inputDType, {}));
64 if(input1Name.find(
"input_") != std::string::npos)
67 tensors.push_back(
new TosaSerializationTensor(input1Name, inputShape1, inputDType, {}));
70 std::string input0TransposeName = input0Name;
71 std::string input1TransposeName = input1Name;
74 std::string input0MatMulName = input0Name;
75 std::string input1MatMulName = input1Name;
83 uint32_t input0Dimensions = inputs[0]->GetNumDimensions();
84 if (input0Dimensions > 3)
87 for (uint32_t i = 0; i < (input0Dimensions - 2); ++i)
89 x *=(inputs[0]->GetShape()[i]);
92 targetShape0 = {
static_cast<int32_t
>(x),
93 static_cast<int32_t
>(inputs[0]->GetShape()[input0Dimensions - 2]),
94 static_cast<int32_t
>(inputs[0]->GetShape()[input0Dimensions - 1])};
96 TosaReshapeAttribute attribute(targetShape0);
98 auto* input0ReshapeOp =
new TosaSerializationOperator(Op_RESHAPE,
99 Attribute_ReshapeAttribute,
102 {outputReshape0Name});
104 operators.push_back(input0ReshapeOp);
105 transpose0Shape = targetShape0;
106 tensors.push_back(
new TosaSerializationTensor(outputReshape0Name, targetShape0, inputDType, {}));
107 input0TransposeName = outputReshape0Name;
108 input0MatMulName = outputReshape0Name;
114 uint32_t input1Dimensions = inputs[1]->GetNumDimensions();
115 if (input1Dimensions > 3)
118 for (uint32_t i = 0; i < (input1Dimensions - 2); i++)
120 x *= (inputs[1]->GetShape()[i]);
123 targetShape1 = {
static_cast<int32_t
>(x),
124 static_cast<int32_t
>(inputs[1]->GetShape()[input1Dimensions - 2]),
125 static_cast<int32_t
>(inputs[1]->GetShape()[input1Dimensions - 1])};
127 TosaReshapeAttribute attribute(targetShape1);
129 auto* input1ReshapeOp =
new TosaSerializationOperator(Op_RESHAPE,
130 Attribute_ReshapeAttribute,
133 {outputReshape1Name});
135 operators.push_back(input1ReshapeOp);
136 transpose1Shape = targetShape1;
137 tensors.push_back(
new TosaSerializationTensor(outputReshape1Name, targetShape1, inputDType, {}));
138 input1TransposeName = outputReshape1Name;
139 input1MatMulName = outputReshape1Name;
141 bool needsReshape = input0Dimensions > 3 || input1Dimensions > 3;
146 auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(descriptor->
m_DataLayoutX,
147 inputs[0]->GetShape());
149 std::vector<int32_t> mappings(permuteVec.begin(),
151 TosaTransposeAttribute transposeAttribute(mappings);
153 TosaSerializationOperator *transposeOp =
new TosaSerializationOperator(Op_TRANSPOSE,
154 Attribute_TransposeAttribute,
156 {input0TransposeName},
157 {outputTranspose0Name});
158 operators.push_back(transposeOp);
159 tensors.push_back(
new TosaSerializationTensor(outputTranspose0Name, transpose0Shape, inputDType, {}));
160 input0MatMulName = outputTranspose0Name;
165 auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(descriptor->
m_DataLayoutY,
166 inputs[1]->GetShape());
169 std::vector<int32_t> mappings(permuteVec.begin(),
171 TosaTransposeAttribute transposeAttribute(mappings);
173 TosaSerializationOperator *transposeOp =
new TosaSerializationOperator(Op_TRANSPOSE,
174 Attribute_TransposeAttribute,
176 {input1TransposeName},
177 {outputTranspose1Name});
178 operators.push_back(transposeOp);
179 tensors.push_back(
new TosaSerializationTensor(outputTranspose1Name, transpose1Shape, inputDType, {}));
180 input1MatMulName = outputTranspose1Name;
184 std::string matMulOutputStr = needsReshape || isInputInt8 || isInputInt16 ?
187 TosaMatMulAttribute matMulAttribute(0,0);
188 DType matMulOutDType =
ArmNNToDType(inputs[1]->GetDataType());
191 matMulAttribute = TosaMatMulAttribute(inputs[0]->GetQuantizationOffset(), inputs[1]->GetQuantizationOffset());
192 matMulOutDType = DType_INT32;
196 matMulAttribute = TosaMatMulAttribute(inputs[0]->GetQuantizationOffset(), inputs[1]->GetQuantizationOffset());
197 matMulOutDType = DType_INT48;
199 TosaSerializationOperator* matMulOp =
new TosaSerializationOperator(Op_MATMUL,
200 Attribute_MatMulAttribute,
202 {input0MatMulName, input1MatMulName},
205 operators.push_back(matMulOp);
206 tensors.push_back(
new TosaSerializationTensor(matMulOutputStr, targetShape0, matMulOutDType, {}));
208 std::string outputRescale = needsReshape ?
210 std::string inputReshape2Name = isInputInt8 || isInputInt16 ? outputRescale : matMulOutputStr;
213 if (isInputInt8 || isInputInt16)
215 bool scale32 = isInputInt16 ? false :
true;
216 bool doubleRound = isInputInt16 ? false :
true;
218 double scale_alpha = inputs[0]->GetQuantizationScale() / outputs[0]->GetQuantizationScale();
219 int32_t input_zp = inputs[0]->GetQuantizationOffset();
220 int32_t output_zp = outputs[0]->GetQuantizationOffset();
222 TosaSerializationOperator* rescaleOp =
nullptr;
232 tensors.push_back(
new TosaSerializationTensor(outputRescale,
236 operators.push_back(rescaleOp);
243 TosaReshapeAttribute attribute(targetShape);
245 auto* outputReshapeOp =
new TosaSerializationOperator(Op_RESHAPE,
246 Attribute_ReshapeAttribute,
251 operators.push_back(outputReshapeOp);
252 tensors.push_back(
new TosaSerializationTensor(outputName, targetShape, inputDType, {}));
255 return new TosaSerializationBasicBlock(blockName,
259 {input0Name, input1Name},