9 TosaSerializationOperator*
AddRescaleOp(
const string &inputName,
10 const string &outputName,
11 std::vector<TosaSerializationTensor *> &tensors,
12 const std::vector<const TensorInfo *> &inputs,
13 const std::vector<const TensorInfo *> &outputs)
15 double scale_alpha = inputs[1]->GetQuantizationScale() / outputs[0]->GetQuantizationScale();
16 int32_t input_zp = inputs[1]->GetQuantizationOffset();
17 int32_t output_zp = outputs[0]->GetQuantizationOffset();
19 TosaSerializationOperator* rescaleOp =
nullptr;
30 tensors.push_back(
new TosaSerializationTensor(outputName,
38 const std::vector<const TensorInfo*>& inputs,
39 const std::vector<const TensorInfo*>& outputs,
42 std::string input0Name = std::string(
"input_0");
43 std::string input1Name = std::string(
"input_1");
44 std::string outputName = std::string(
"output0_");
47 std::string blockName;
58 TosaSerializationOperator* op =
nullptr;
60 std::vector<TosaSerializationTensor*> tensors;
61 std::vector<TosaSerializationOperator*> operators;
62 DType inputDType0 =
ArmNNToDType(inputs[0]->GetDataType());
63 DType inputDType1 =
ArmNNToDType(inputs[1]->GetDataType());
64 DType outputDType0 =
ArmNNToDType(outputs[0]->GetDataType());
65 bool isInputInt8 = (inputDType0 == DType_INT8);
70 if(input0Name.find(
"input_") != std::string::npos)
73 tensors.push_back(
new TosaSerializationTensor(input0Name, inputShape0, inputDType0, {}));
75 if(input1Name.find(
"input_") != std::string::npos)
78 tensors.push_back(
new TosaSerializationTensor(input1Name, inputShape1, inputDType1, {}));
85 std::string outputElemenwiseBinaryName;
89 tensors.push_back(
new TosaSerializationTensor(outputElemenwiseBinaryName, outputShape0, DType_INT32, {}));
93 tensors.push_back(
new TosaSerializationTensor(outputName, outputShape0, outputDType0, {}));
96 std::string& elementwiseInput0Str = isInputInt8 ? input0ElemenwiseBinaryName : input0Name;
97 std::string& elementwiseInput1Str = isInputInt8 ? input1ElemenwiseBinaryName : input1Name;
98 std::string& elementwiseOutputStr = isInputInt8 ? outputElemenwiseBinaryName : outputName;
101 case LayerType::Addition:
103 op =
new TosaSerializationOperator(Op_ADD,
106 {input0Name, input1Name},
111 case LayerType::ElementwiseBinary:
118 if (inputDType0 == DType_INT8)
121 AddRescaleOp(input0Name, input0ElemenwiseBinaryName, tensors, inputs, outputs));
124 AddRescaleOp(input1Name, input1ElemenwiseBinaryName, tensors, inputs, outputs));
126 op =
new TosaSerializationOperator(Op_ADD,
129 {elementwiseInput0Str, elementwiseInput1Str},
130 {elementwiseOutputStr});
137 if (inputDType0 == DType_INT8)
140 AddRescaleOp(input0Name, input0ElemenwiseBinaryName, tensors, inputs, outputs));
143 AddRescaleOp(input1Name, input1ElemenwiseBinaryName, tensors, inputs, outputs));
145 op =
new TosaSerializationOperator(Op_MAXIMUM,
148 {elementwiseInput0Str, elementwiseInput1Str},
149 {elementwiseOutputStr});
156 TosaMulAttribute mulAttribute(shift);
160 op =
new TosaSerializationOperator(Op_MUL,
161 Attribute_MulAttribute,
163 {input0Name, input1Name},
164 {elementwiseOutputStr});
171 if (inputDType0 == DType_INT8)
174 AddRescaleOp(input0Name, input0ElemenwiseBinaryName, tensors, inputs, outputs));
177 AddRescaleOp(input1Name, input1ElemenwiseBinaryName, tensors, inputs, outputs));
180 op =
new TosaSerializationOperator(Op_SUB,
183 {elementwiseInput0Str, elementwiseInput1Str},
184 {elementwiseOutputStr});
189 throw armnn::Exception(
"ConvertElementwiseBinaryToTosaOperator: Unsupported layer type.");
193 case LayerType::Multiplication:
196 TosaMulAttribute mulAttribute(shift);
197 op =
new TosaSerializationOperator(Op_MUL,
198 Attribute_MulAttribute,
200 {input0Name, input1Name},
205 case LayerType::Subtraction:
207 op =
new TosaSerializationOperator(Op_SUB,
210 {input0Name, input1Name},
216 throw armnn::Exception(
"ConvertElementwiseBinaryToTosaOperator: Unsupported layer type.");
219 operators.push_back(op);
222 if (inputDType0 == DType_INT8)
225 AddRescaleOp(outputElemenwiseBinaryName, outputName, tensors, inputs, outputs));
228 return new TosaSerializationBasicBlock(blockName,
232 {input0Name, input1Name},