14 if (unaryDescriptor->
m_Operation != UnaryOperation::Rsqrt)
16 throw armnn::Exception(
"ConvertRsqrtOperator: Unsupported elementwise unary operation in descriptor.");
20 "ConvertRsqrtOperator: Rsqrt must have only one input");
23 "ConvertRsqrtOperator: Rsqrt must have only one output");
26 std::string inputName = std::string(
"input_");
27 std::string outputName = std::string(
"output0_");
29 std::string supportedTypes = std::string(
" Supported Types: FLOAT32, FLOAT16 & INT8.");
39 std::vector<TosaSerializationTensor*> tensors;
40 std::vector<TosaSerializationOperator *> operators;
42 DataType inputDType = inputs[0]->GetDataType();
44 if (inputDType == DataType::QAsymmS8 || inputDType == DataType::QSymmS8)
46 float input_scale = inputs[0]->GetQuantizationScale();
47 float output_scale = outputs[0]->GetQuantizationScale();
48 int32_t input_zp = inputs[0]->GetQuantizationOffset();
49 int32_t output_zp = outputs[0]->GetQuantizationOffset();
51 const float output_max =
static_cast<float>(127 - output_zp) * output_scale;
53 auto rsqrt_func = [&](
float x) ->
float
60 return 1.0f / std::sqrt(x);
63 TosaTableAttribute attribute(
66 operators.push_back(
new TosaSerializationOperator(tosa::Op_TABLE,
67 Attribute_TableAttribute,
72 else if (inputDType == DataType::Float32 || inputDType == DataType::Float16)
74 operators.push_back(
new TosaSerializationOperator(tosa::Op_RSQRT,
80 else if (inputDType == DataType::QSymmS16)
82 throw Exception(
"ConvertRsqrtOperator(): unsupported datatype INT16 is not implemented yet." + supportedTypes);
84 else if (inputDType == DataType::Signed32 || inputDType == DataType::Signed64)
86 throw Exception(
"ConvertRsqrtOperator(): unsupported datatype INT32 or INT64." + supportedTypes);
90 throw Exception(
"ConvertRsqrtOperator(): TOSA specification does not support this datatype." + supportedTypes);
96 if (inputName.find(
"input_") != std::string::npos)
100 tensors.push_back(
new TosaSerializationTensor(inputName, inputShape0, inputDType0, {}));
104 DType outputDType0 =
ArmNNToDType(outputs[0]->GetDataType());
106 tensors.push_back(
new TosaSerializationTensor(outputName, outputShape0, outputDType0, {}));
110 return new TosaSerializationBasicBlock(blockName,
#define ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(_cond, _str)
std::string GenerateUniqueOutputName(const Layer &layer, uint32_t layerSlot=0)
const std::string mainName
DType ArmNNToDType(const DataType &type)
std::vector< int32_t > GetTosaTensorShape(const TensorShape &shape)
std::string GenerateUniqueInputName(const armnn::InputSlot &slot)
std::string GetUniqueTosaMappingID()
std::vector< int16_t > getTosaConst8bitTable(float input_scale, int32_t input_zp, float output_scale, int32_t output_zp, std::function< float(float)> func)
Base class for all ArmNN exceptions so that users can filter to just those.
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
UnaryOperation m_Operation
Specifies the elementwiseUnary operation to execute.