22 if (inputs.size() != 1)
24 throw armnn::Exception(
"ConvertActivationToTosaOperator: 1 input tensors required.");
27 if (outputs.size() != 1)
29 throw armnn::Exception(
"ConvertActivationToTosaOperator: 1 output tensor required.");
32 std::string inputName = std::string(
"input0_");
35 std::string outputName = std::string(
"output0_");
50 std::vector<TosaSerializationTensor*> tensors;
55 std::vector<int32_t> inputShape0;
56 DType inputDType0 = DType::DType_UNKNOWN;
57 if(inputName.find(
"input0_") != std::string::npos)
61 tensors.push_back(
new TosaSerializationTensor(inputName, inputShape0, inputDType0, {}));
65 DType outputDType0 =
ArmNNToDType(outputs[0]->GetDataType());
66 tensors.push_back(
new TosaSerializationTensor(outputName, outputShape0, outputDType0, {}));
68 #if TOSA_COMPAT_VERSION(0, 60, 0)
71 if (inputDType0 == DType::DType_FP32)
74 TosaSerializationOperator* alphaOp =
nullptr;
75 TosaSerializationTensor* alphaTensor =
nullptr;
76 CreateConstTosaOperator<float>(outputNameAlpha,
77 activationDescriptor->
m_A,
82 tensors.push_back(alphaTensor);
86 TosaMulAttribute mulAttribute(shift);
87 TosaSerializationOperator* mulOp =
new TosaSerializationOperator(Op_MUL,
88 Attribute_MulAttribute,
90 {inputName, outputNameAlpha},
92 tensors.push_back(
new TosaSerializationTensor(outputNameMul, inputShape0, inputDType0, {}));
94 TosaSerializationOperator* op =
nullptr;
95 if (activationDescriptor->
m_A <= 1.0)
97 op =
new TosaSerializationOperator(Op_MAXIMUM,
100 {inputName, outputNameMul},
105 op =
new TosaSerializationOperator(Op_MINIMUM,
108 {inputName, outputNameMul},
115 return new TosaSerializationBasicBlock(blockName,
117 {alphaOp, mulOp, op},
128 DType rescale_type = DType::DType_INT32;
129 float alpha = activationDescriptor->
m_A;
130 double scale_alpha = inputs[0]->GetQuantizationScale() * alpha / outputs[0]->GetQuantizationScale();
131 double scale_identity = inputs[0]->GetQuantizationScale() / outputs[0]->GetQuantizationScale();
132 int32_t input_zp = inputs[0]->GetQuantizationOffset();
133 int32_t output_zp = outputs[0]->GetQuantizationOffset();
138 TosaSerializationOperator* rescaleAlphaOp =
nullptr;
139 TosaSerializationTensor* rescaleAlphaTensor =
nullptr;
141 outputNameRescaleAlpha,
150 &rescaleAlphaTensor);
151 tensors.push_back(rescaleAlphaTensor);
156 TosaSerializationOperator* rescaleIdentityOp =
nullptr;
157 TosaSerializationTensor* rescaleIdentityTensor =
nullptr;
159 outputNameRescaleIdentity,
168 &rescaleIdentityTensor);
169 tensors.push_back(rescaleIdentityTensor);
183 TosaSerializationOperator* op =
nullptr;
186 op =
new TosaSerializationOperator(Op_MAXIMUM,
189 {outputNameRescaleAlpha, outputNameRescaleIdentity},
190 {outputNameRescaleMaxMin});
194 op =
new TosaSerializationOperator(Op_MINIMUM,
197 {outputNameRescaleAlpha, outputNameRescaleIdentity},
198 {outputNameRescaleMaxMin});
201 tensors.push_back(
new TosaSerializationTensor(outputNameRescaleMaxMin, inputShape0, rescale_type, {}));
205 TosaSerializationOperator* rescaleOutputOp =
nullptr;
217 return new TosaSerializationBasicBlock(blockName,
219 {rescaleAlphaOp, rescaleIdentityOp, op, rescaleOutputOp},
229 TosaSerializationOperator* zeroOp =
nullptr;
230 TosaSerializationTensor* zeroTensor =
nullptr;
231 CreateConstTosaOperator<float>(outputNameZero,
237 tensors.push_back(zeroTensor);
240 TosaSerializationOperator* alphaOp =
nullptr;
241 TosaSerializationTensor* alphaTensor =
nullptr;
242 CreateConstTosaOperator<float>(outputNameAlpha,
243 activationDescriptor->
m_A,
248 tensors.push_back(alphaTensor);
252 TosaMulAttribute mulAttribute(shift);
253 TosaSerializationOperator* mulOp =
new TosaSerializationOperator(Op_MUL,
254 Attribute_MulAttribute,
256 {inputName, outputNameAlpha},
258 tensors.push_back(
new TosaSerializationTensor(outputNameMul, inputShape0, inputDType0, {}));
261 TosaSerializationOperator* geOp =
new TosaSerializationOperator(Op_GREATER_EQUAL,
264 {inputName, outputNameZero},
266 tensors.push_back(
new TosaSerializationTensor(outputNameGE, outputShape0, DType::DType_BOOL, {}));
269 TosaSerializationOperator* selOp =
new TosaSerializationOperator(Op_SELECT,
272 {outputNameGE, inputName, outputNameMul},
277 return new TosaSerializationBasicBlock(blockName,
279 {zeroOp, alphaOp, mulOp, geOp, selOp},