11 const std::vector<const TensorInfo*>& inputs,
12 const std::vector<const TensorInfo*>& outputs,
15 auto input0Name = std::string(
"input_0");
16 auto input1Name = std::string(
"input_1");
17 auto outputName = std::string(
"output0_");
21 std::string blockName;
32 TosaSerializationOperator* op =
nullptr;
33 std::vector<TosaSerializationTensor*> tensors;
34 std::vector<TosaSerializationOperator*> operators;
36 DType inputDType0 =
ArmNNToDType(inputs[0]->GetDataType());
37 DType inputDType1 =
ArmNNToDType(inputs[1]->GetDataType());
38 DType outputDType0 =
ArmNNToDType(outputs[0]->GetDataType());
40 bool isInputInt8 = (inputDType0 == DType_INT8);
45 if(input0Name.find(
"input_") != std::string::npos)
48 tensors.emplace_back(
new TosaSerializationTensor(input0Name, inputShape0, inputDType0, {}));
50 if(input1Name.find(
"input_") != std::string::npos && input0Name != input1Name)
53 tensors.emplace_back(
new TosaSerializationTensor(input1Name, inputShape1, inputDType1, {}));
58 std::string outputElemenwiseBinaryName;
63 tensors.emplace_back(
new TosaSerializationTensor(outputElemenwiseBinaryName, outputShape0, DType_INT32, {}));
67 tensors.emplace_back(
new TosaSerializationTensor(outputName, outputShape0, outputDType0, {}));
70 float input0Scale = 0;
71 float input1Scale = 0;
72 float outputScale = 0;
76 input0Scale = inputs[0]->GetQuantizationScale();
77 input1Scale = inputs[1]->GetQuantizationScale();
78 outputScale = outputs[0]->GetQuantizationScale();
82 TosaSerializationOperator* rescaleOp0 =
nullptr;
85 inputs[0]->GetQuantizationOffset(),
93 tensors.emplace_back(
new TosaSerializationTensor(input0ElementwiseBinaryName,
97 operators.emplace_back(rescaleOp0);
99 TosaSerializationOperator* rescaleOp1 =
nullptr;
101 bool isSub = type == LayerType::Subtraction || (descriptor && descriptor->
m_Operation == BinaryOperation::Sub);
105 auto maxScale = 2.0 * std::max(inputs[0]->GetQuantizationScale(), inputs[1]->GetQuantizationScale());
106 auto rescaleScale =
static_cast<float>((inputs[0]->GetQuantizationScale() / maxScale) * (1 << 21));
108 input1ElementwiseBinaryName,
110 inputs[1]->GetQuantizationOffset(),
117 operators.emplace_back(rescaleOp1);
118 tensors.emplace_back(
new TosaSerializationTensor(input1ElementwiseBinaryName,
123 TosaSerializationOperator* rescaleOp2 =
nullptr;
125 input2ElementwiseBinaryName,
134 operators.emplace_back(rescaleOp2);
135 tensors.emplace_back(
new TosaSerializationTensor(input2ElementwiseBinaryName,
143 input1ElementwiseBinaryName,
145 inputs[1]->GetQuantizationOffset(),
152 operators.emplace_back(rescaleOp1);
153 tensors.emplace_back(
new TosaSerializationTensor(input1ElementwiseBinaryName,
160 std::string
const& elementwiseInput0Str = isInputInt8 ? input0ElementwiseBinaryName : input0Name;
161 std::string elementwiseInput1Str = isInputInt8 ? input1ElementwiseBinaryName : input1Name;
162 std::string
const& elementwiseOutputStr = isInputInt8 ? outputElemenwiseBinaryName : outputName;
166 case LayerType::ElementwiseBinary:
170 case BinaryOperation::Add:
173 {elementwiseOutputStr},
178 case BinaryOperation::Maximum:
180 op =
new TosaSerializationOperator(Op_MAXIMUM,
183 {elementwiseInput0Str, elementwiseInput1Str},
184 {elementwiseOutputStr});
188 case BinaryOperation::Mul:
191 {elementwiseOutputStr},
196 case BinaryOperation::Sub:
200 elementwiseInput1Str = input2ElementwiseBinaryName;
204 {elementwiseOutputStr},
209 case BinaryOperation::SqDiff:
211 throw Exception(
"TOSA mappings of Squared Difference operator "
212 "implemented under ConvertSquaredDifferenceToTosaOperator().");
215 throw Exception(
"ConvertElementwiseBinaryToTosaOperator: Unsupported layer type.");
219 case LayerType::Addition:
227 case LayerType::Multiplication:
235 case LayerType::Subtraction:
244 throw Exception(
"ConvertElementwiseBinaryToTosaOperator: Unsupported layer type.");
249 operators.emplace_back(op);
254 if (inputDType0 == DType_INT8)
256 TosaSerializationOperator* rescaleOp =
nullptr;
261 outputs[0]->GetQuantizationOffset(),
267 tensors.emplace_back(
new TosaSerializationTensor(outputName,
271 operators.emplace_back(rescaleOp);
274 if(input0Name == input1Name)
276 return new TosaSerializationBasicBlock(blockName,
284 return new TosaSerializationBasicBlock(blockName,
288 {input0Name, input1Name},
333 auto maxScale = 2.0 * std::max(input0Scale, input1Scale);
336 auto inputShift = 20;
338 input0Scale =
static_cast<float>((input0Scale / maxScale) * (1 << inputShift));
339 input1Scale =
static_cast<float>((input1Scale / maxScale) * (1 << inputShift));
340 outputScale =
static_cast<float>(maxScale / (outputScale * (
static_cast<float>(1 << inputShift))));
345 if(input0Scale > input1Scale)
347 outputScale = (input0Scale * input1Scale) / outputScale;
348 input1Scale =
static_cast<float>((input0Scale / maxScale) * (1 << inputShift));
349 input0Scale =
static_cast<float>((input0Scale / maxScale) * (1 << inputShift));
353 outputScale = (input0Scale * input1Scale) / outputScale;
354 input0Scale =
static_cast<float>((input1Scale / maxScale) * (1 << inputShift));
355 input1Scale =
static_cast<float>((input1Scale / maxScale) * (1 << inputShift));
360 auto inputShift = 20;
362 input0Scale =
static_cast<float>((input0Scale / maxScale) * (1 << inputShift));
363 input1Scale =
static_cast<float>((input1Scale / maxScale) * (1 << 0));
364 outputScale =
static_cast<float>(maxScale / (outputScale * (
static_cast<float>(1 << inputShift))));
370 const std::vector<const TensorInfo*>& inputs,
371 const std::vector<const TensorInfo*>& outputs,
374 if (descriptor->
m_Operation != BinaryOperation::SqDiff)
376 throw Exception(
"ElementwiseBinaryDescriptor operation must be SqDiff"
377 "in ConvertSquaredDifferenceToTosaOperator().");
380 auto input0Name = std::string(
"input_0");
381 auto input1Name = std::string(
"input_1");
382 auto outputName = std::string(
"output0_");
388 if (layer !=
nullptr)
401 std::vector<TosaSerializationTensor*> tensors {};
402 std::vector<TosaSerializationOperator*> operators {};
403 DType inputDType0 =
ArmNNToDType(inputs[0]->GetDataType());
404 DType inputDType1 =
ArmNNToDType(inputs[1]->GetDataType());
405 DType outputDType0 =
ArmNNToDType(outputs[0]->GetDataType());
406 bool isInputInt8 = (inputDType0 == DType_INT8);
411 if(input0Name.find(
"input_") != std::string::npos)
414 tensors.emplace_back(
new TosaSerializationTensor(input0Name, inputShape0, inputDType0, {}));
416 if(input1Name.find(
"input_") != std::string::npos)
419 tensors.emplace_back(
new TosaSerializationTensor(input1Name, inputShape1, inputDType1, {}));
424 if (inputDType0 == DType_FP32 ||
425 inputDType0 == DType_FP16 ||
426 inputDType0 == DType_INT32)
429 {interElemenwiseBinaryName},
432 tensors.emplace_back(
new TosaSerializationTensor(interElemenwiseBinaryName,
440 else if (isInputInt8)
451 double in_x_scale = inputs[0]->GetQuantizationScale();
452 double in_y_scale = inputs[1]->GetQuantizationScale();
453 double result_scale = outputs[0]->GetQuantizationScale();
454 double twice_max_input_scale = 2.0 * std::max(in_x_scale, in_y_scale);
455 const int32_t LEFT_SHIFT = 7;
456 double x_rescale_scale = in_x_scale / twice_max_input_scale;
457 double y_rescale_scale = in_y_scale / twice_max_input_scale;
458 double output_rescale_scale =
459 (twice_max_input_scale * twice_max_input_scale) /
460 ((
static_cast<double>(1 << LEFT_SHIFT * 2)) * result_scale);
462 TosaSerializationOperator* xShiftOp =
nullptr;
466 inputs[0]->GetQuantizationOffset(),
473 operators.emplace_back(xShiftOp);
474 tensors.emplace_back(
new TosaSerializationTensor(rescale0Output0Name,
479 TosaSerializationOperator* yShiftOp =
nullptr;
483 inputs[1]->GetQuantizationOffset(),
490 operators.emplace_back(yShiftOp);
491 tensors.emplace_back(
new TosaSerializationTensor(rescale0Output1Name,
496 TosaSerializationOperator* xScaledOp =
nullptr;
507 operators.emplace_back(xScaledOp);
508 tensors.emplace_back(
new TosaSerializationTensor(rescale1Output0Name,
513 TosaSerializationOperator* yScaledOp =
nullptr;
524 operators.emplace_back(yScaledOp);
525 tensors.emplace_back(
new TosaSerializationTensor(rescale1Output1Name,
531 {interElemenwiseBinaryName},
534 tensors.emplace_back(
new TosaSerializationTensor(interElemenwiseBinaryName,
543 tensors.emplace_back(
new TosaSerializationTensor(mulOutputName,
549 TosaSerializationOperator* rescaleOutputOp =
nullptr;
552 output_rescale_scale,
554 outputs[0]->GetQuantizationOffset(),
560 operators.emplace_back(rescaleOutputOp);
564 throw Exception(
"TOSA spec only supports INT8, INT32, FP16 and FP32 datatypes for SqDiff.");
567 tensors.emplace_back(
new TosaSerializationTensor(outputName, outputShape0, outputDType0, {}));
569 return new TosaSerializationBasicBlock(blockName,
573 {input0Name, input1Name},