231 if (descriptor->
m_Operation != BinaryOperation::SqDiff)
233 throw Exception(
"ElementwiseBinaryDescriptor operation must be SqDiff"
234 "in ConvertSquaredDifferenceToTosaOperator().");
237 std::string input0Name = std::string(
"input_0");
238 std::string input1Name = std::string(
"input_1");
239 std::string outputName = std::string(
"output0_");
245 if (layer !=
nullptr)
258 std::vector<TosaSerializationTensor*> tensors {};
259 std::vector<TosaSerializationOperator*> operators {};
260 DType inputDType0 =
ArmNNToDType(inputs[0]->GetDataType());
261 DType inputDType1 =
ArmNNToDType(inputs[1]->GetDataType());
262 DType outputDType0 =
ArmNNToDType(outputs[0]->GetDataType());
263 bool isInputInt8 = (inputDType0 == DType_INT8);
268 if(input0Name.find(
"input_") != std::string::npos)
271 tensors.push_back(
new TosaSerializationTensor(input0Name, inputShape0, inputDType0, {}));
273 if(input1Name.find(
"input_") != std::string::npos)
276 tensors.push_back(
new TosaSerializationTensor(input1Name, inputShape1, inputDType1, {}));
281 if (inputDType0 == DType_FP32 ||
282 inputDType0 == DType_FP16 ||
283 inputDType0 == DType_INT32)
285 operators.push_back(
new TosaSerializationOperator(
289 {input0Name, input1Name},
290 {interElemenwiseBinaryName}));
291 tensors.push_back(
new TosaSerializationTensor(interElemenwiseBinaryName,
297 TosaMulAttribute mulAttribute(shift);
299 operators.push_back(
new TosaSerializationOperator(
301 Attribute_MulAttribute,
303 {interElemenwiseBinaryName, interElemenwiseBinaryName},
306 else if (isInputInt8)
317 double in_x_scale = inputs[0]->GetQuantizationScale();
318 double in_y_scale = inputs[1]->GetQuantizationScale();
319 double result_scale = outputs[0]->GetQuantizationScale();
320 double twice_max_input_scale = 2.0 * std::max(in_x_scale, in_y_scale);
321 const int32_t LEFT_SHIFT = 7;
322 double x_rescale_scale = in_x_scale / twice_max_input_scale;
323 double y_rescale_scale = in_y_scale / twice_max_input_scale;
324 double output_rescale_scale =
325 (twice_max_input_scale * twice_max_input_scale) /
326 ((
static_cast<double>(1 << LEFT_SHIFT * 2)) * result_scale);
328 TosaSerializationOperator* xShiftOp =
nullptr;
332 inputs[0]->GetQuantizationOffset(),
337 operators.push_back(xShiftOp);
338 tensors.push_back(
new TosaSerializationTensor(rescale0Output0Name,
343 TosaSerializationOperator* yShiftOp =
nullptr;
347 inputs[1]->GetQuantizationOffset(),
352 operators.push_back(yShiftOp);
353 tensors.push_back(
new TosaSerializationTensor(rescale0Output1Name,
358 TosaSerializationOperator* xScaledOp =
nullptr;
367 operators.push_back(xScaledOp);
368 tensors.push_back(
new TosaSerializationTensor(rescale1Output0Name,
373 TosaSerializationOperator* yScaledOp =
nullptr;
382 operators.push_back(yScaledOp);
383 tensors.push_back(
new TosaSerializationTensor(rescale1Output1Name,
390 operators.push_back(
new TosaSerializationOperator(
394 {rescale1Output0Name, rescale1Output1Name},
395 {interElemenwiseBinaryName}));
396 tensors.push_back(
new TosaSerializationTensor(interElemenwiseBinaryName,
401 TosaMulAttribute mulAttribute(shift);
403 operators.push_back(
new TosaSerializationOperator(
405 Attribute_MulAttribute,
407 {interElemenwiseBinaryName, interElemenwiseBinaryName},
409 tensors.push_back(
new TosaSerializationTensor(mulOutputName,
415 TosaSerializationOperator* rescaleOutputOp =
nullptr;
418 output_rescale_scale,
420 outputs[0]->GetQuantizationOffset(),
424 operators.push_back(rescaleOutputOp);
428 throw Exception(
"TOSA spec only supports INT8, INT32, FP16 and FP32 datatypes for SqDiff.");
431 tensors.push_back(
new TosaSerializationTensor(outputName, outputShape0, outputDType0, {}));
433 return new TosaSerializationBasicBlock(blockName,
437 {input0Name, input1Name},