10 const string &outputName,
11 std::vector<TosaSerializationTensor*>& tensors,
12 const std::vector<const TensorInfo*>& inputs,
13 const std::vector<const TensorInfo*>& outputs,
14 std::vector<TosaSerializationOperator*>& operators)
16 double scale_alpha = inputs[1]->GetQuantizationScale() / outputs[0]->GetQuantizationScale();
17 int32_t input_zp = inputs[1]->GetQuantizationOffset();
18 int32_t output_zp = outputs[0]->GetQuantizationOffset();
20 TosaSerializationOperator* rescaleOp =
nullptr;
31 tensors.push_back(
new TosaSerializationTensor(outputName,
34 operators.push_back(rescaleOp);
39 const std::vector<const TensorInfo*>& inputs,
40 const std::vector<const TensorInfo*>& outputs,
43 std::string input0Name = std::string(
"input_0");
44 std::string input1Name = std::string(
"input_1");
45 std::string outputName = std::string(
"output0_");
48 std::string blockName;
59 TosaSerializationOperator* op =
nullptr;
61 std::vector<TosaSerializationTensor*> tensors;
62 std::vector<TosaSerializationOperator*> operators;
63 DType inputDType0 =
ArmNNToDType(inputs[0]->GetDataType());
64 DType inputDType1 =
ArmNNToDType(inputs[1]->GetDataType());
65 DType outputDType0 =
ArmNNToDType(outputs[0]->GetDataType());
66 bool isInputInt8 = (inputDType0 == DType_INT8);
71 if(input0Name.find(
"input_") != std::string::npos)
74 tensors.push_back(
new TosaSerializationTensor(input0Name, inputShape0, inputDType0, {}));
76 if(input1Name.find(
"input_") != std::string::npos)
79 tensors.push_back(
new TosaSerializationTensor(input1Name, inputShape1, inputDType1, {}));
86 std::string outputElemenwiseBinaryName;
90 tensors.push_back(
new TosaSerializationTensor(outputElemenwiseBinaryName, outputShape0, DType_INT32, {}));
94 tensors.push_back(
new TosaSerializationTensor(outputName, outputShape0, outputDType0, {}));
99 bool isMulDesc = descriptor ? descriptor->
m_Operation == BinaryOperation::Mul :
false;
100 bool isMulOp = (type == LayerType::Multiplication) || isMulDesc ?
true :
false;
101 if (isInputInt8 && !isMulOp)
103 AddRescaleOp(input0Name, input0ElemenwiseBinaryName, tensors, inputs, outputs, operators);
104 AddRescaleOp(input1Name, input1ElemenwiseBinaryName, tensors, inputs, outputs, operators);
107 std::string& elementwiseInput0Str = isInputInt8 ? input0ElemenwiseBinaryName : input0Name;
108 std::string& elementwiseInput1Str = isInputInt8 ? input1ElemenwiseBinaryName : input1Name;
109 std::string& elementwiseOutputStr = isInputInt8 ? outputElemenwiseBinaryName : outputName;
113 case LayerType::Addition:
115 op =
new TosaSerializationOperator(Op_ADD,
118 {input0Name, input1Name},
123 case LayerType::ElementwiseBinary:
127 case BinaryOperation::Add:
129 op =
new TosaSerializationOperator(Op_ADD,
132 {elementwiseInput0Str, elementwiseInput1Str},
133 {elementwiseOutputStr});
137 case BinaryOperation::Maximum:
139 op =
new TosaSerializationOperator(Op_MAXIMUM,
142 {elementwiseInput0Str, elementwiseInput1Str},
143 {elementwiseOutputStr});
147 case BinaryOperation::Mul:
150 TosaMulAttribute mulAttribute(shift);
154 op =
new TosaSerializationOperator(Op_MUL,
155 Attribute_MulAttribute,
157 {input0Name, input1Name},
158 {elementwiseOutputStr});
162 case BinaryOperation::Sub:
164 op =
new TosaSerializationOperator(Op_SUB,
167 {elementwiseInput0Str, elementwiseInput1Str},
168 {elementwiseOutputStr});
172 case BinaryOperation::SqDiff:
174 throw Exception(
"TOSA mappings of Squared Difference operator "
175 "implemented under ConvertSquaredDifferenceToTosaOperator().");
178 throw Exception(
"ConvertElementwiseBinaryToTosaOperator: Unsupported layer type.");
182 case LayerType::Multiplication:
185 TosaMulAttribute mulAttribute(shift);
186 op =
new TosaSerializationOperator(Op_MUL,
187 Attribute_MulAttribute,
189 {input0Name, input1Name},
194 case LayerType::Subtraction:
196 op =
new TosaSerializationOperator(Op_SUB,
199 {input0Name, input1Name},
205 throw Exception(
"ConvertElementwiseBinaryToTosaOperator: Unsupported layer type.");
208 operators.push_back(op);
212 if (inputDType0 == DType_INT8)
214 AddRescaleOp(outputElemenwiseBinaryName, outputName, tensors, inputs, outputs, operators);
217 return new TosaSerializationBasicBlock(blockName,
221 {input0Name, input1Name},
227 const std::vector<const TensorInfo*>& inputs,
228 const std::vector<const TensorInfo*>& outputs,
231 if (descriptor->
m_Operation != BinaryOperation::SqDiff)
233 throw Exception(
"ElementwiseBinaryDescriptor operation must be SqDiff"
234 "in ConvertSquaredDifferenceToTosaOperator().");
237 std::string input0Name = std::string(
"input_0");
238 std::string input1Name = std::string(
"input_1");
239 std::string outputName = std::string(
"output0_");
245 if (layer !=
nullptr)
258 std::vector<TosaSerializationTensor*> tensors {};
259 std::vector<TosaSerializationOperator*> operators {};
260 DType inputDType0 =
ArmNNToDType(inputs[0]->GetDataType());
261 DType inputDType1 =
ArmNNToDType(inputs[1]->GetDataType());
262 DType outputDType0 =
ArmNNToDType(outputs[0]->GetDataType());
263 bool isInputInt8 = (inputDType0 == DType_INT8);
268 if(input0Name.find(
"input_") != std::string::npos)
271 tensors.push_back(
new TosaSerializationTensor(input0Name, inputShape0, inputDType0, {}));
273 if(input1Name.find(
"input_") != std::string::npos)
276 tensors.push_back(
new TosaSerializationTensor(input1Name, inputShape1, inputDType1, {}));
281 if (inputDType0 == DType_FP32 ||
282 inputDType0 == DType_FP16 ||
283 inputDType0 == DType_INT32)
285 operators.push_back(
new TosaSerializationOperator(
289 {input0Name, input1Name},
290 {interElemenwiseBinaryName}));
291 tensors.push_back(
new TosaSerializationTensor(interElemenwiseBinaryName,
297 TosaMulAttribute mulAttribute(shift);
299 operators.push_back(
new TosaSerializationOperator(
301 Attribute_MulAttribute,
303 {interElemenwiseBinaryName, interElemenwiseBinaryName},
306 else if (isInputInt8)
317 double in_x_scale = inputs[0]->GetQuantizationScale();
318 double in_y_scale = inputs[1]->GetQuantizationScale();
319 double result_scale = outputs[0]->GetQuantizationScale();
320 double twice_max_input_scale = 2.0 * std::max(in_x_scale, in_y_scale);
321 const int32_t LEFT_SHIFT = 7;
322 double x_rescale_scale = in_x_scale / twice_max_input_scale;
323 double y_rescale_scale = in_y_scale / twice_max_input_scale;
324 double output_rescale_scale =
325 (twice_max_input_scale * twice_max_input_scale) /
326 ((
static_cast<double>(1 << LEFT_SHIFT * 2)) * result_scale);
328 TosaSerializationOperator* xShiftOp =
nullptr;
332 inputs[0]->GetQuantizationOffset(),
337 operators.push_back(xShiftOp);
338 tensors.push_back(
new TosaSerializationTensor(rescale0Output0Name,
343 TosaSerializationOperator* yShiftOp =
nullptr;
347 inputs[1]->GetQuantizationOffset(),
352 operators.push_back(yShiftOp);
353 tensors.push_back(
new TosaSerializationTensor(rescale0Output1Name,
358 TosaSerializationOperator* xScaledOp =
nullptr;
367 operators.push_back(xScaledOp);
368 tensors.push_back(
new TosaSerializationTensor(rescale1Output0Name,
373 TosaSerializationOperator* yScaledOp =
nullptr;
382 operators.push_back(yScaledOp);
383 tensors.push_back(
new TosaSerializationTensor(rescale1Output1Name,
390 operators.push_back(
new TosaSerializationOperator(
394 {rescale1Output0Name, rescale1Output1Name},
395 {interElemenwiseBinaryName}));
396 tensors.push_back(
new TosaSerializationTensor(interElemenwiseBinaryName,
401 TosaMulAttribute mulAttribute(shift);
403 operators.push_back(
new TosaSerializationOperator(
405 Attribute_MulAttribute,
407 {interElemenwiseBinaryName, interElemenwiseBinaryName},
409 tensors.push_back(
new TosaSerializationTensor(mulOutputName,
415 TosaSerializationOperator* rescaleOutputOp =
nullptr;
418 output_rescale_scale,
420 outputs[0]->GetQuantizationOffset(),
424 operators.push_back(rescaleOutputOp);
428 throw Exception(
"TOSA spec only supports INT8, INT32, FP16 and FP32 datatypes for SqDiff.");
431 tensors.push_back(
new TosaSerializationTensor(outputName, outputShape0, outputDType0, {}));
433 return new TosaSerializationBasicBlock(blockName,
437 {input0Name, input1Name},