17 const std::vector<const TensorInfo*>& inputs,
18 const std::vector<const TensorInfo*>& outputs,
21 std::string inputName;
22 std::vector<std::string> inputNames;
23 std::vector<std::string> fcInputNames;
24 std::string outputName = std::string(
"output0_");
27 DType inputDType0 =
ArmNNToDType(inputs[0]->GetDataType());
28 DType outputDType0 =
ArmNNToDType(outputs[0]->GetDataType());
33 inputNames.emplace_back(
"input_0");
34 inputNames.emplace_back(
"constant_1");
37 inputNames.emplace_back(
"constant_2");
46 inputNames.push_back(inputName);
49 inputNames.push_back(inputName);
54 inputNames.push_back(inputName);
61 std::vector<TosaSerializationTensor*> tensors;
62 std::vector<TosaSerializationOperator*> operators;
68 if(inputNames[0].find(
"input_") != std::string::npos)
71 tensors.push_back(
new TosaSerializationTensor(inputNames[0], inputShape0, inputDType0, {}));
79 DType inputDType1 =
ArmNNToDType(inputs[1]->GetDataType());
80 tensors.push_back(
new TosaSerializationTensor(inputNames[1], inputShape1, inputDType1, {}));
85 if(!inputs[2]->IsConstant() || layer ==
nullptr)
88 DType inputDType2 =
ArmNNToDType(inputs[2]->GetDataType());
89 tensors.push_back(
new TosaSerializationTensor(inputNames[2], inputShape2, inputDType2, {}));
96 inputNames.push_back(inputName);
98 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputName}));
100 const DType dType = (inputDType0 == DType_INT8) ? DType_INT32 : outputDType0;
101 std::vector<float> data(outputs[0]->GetShape()[1], 0);
103 std::vector<uint8_t> uint8Data;
104 TosaSerializationHandler::ConvertF32toU8(data, uint8Data);
106 tensors.push_back(
new TosaSerializationTensor(inputName,
107 {
static_cast<int32_t
>(outputs[0]->GetShape()[1])},
112 fcInputNames = inputNames;
115 if (inputs[0]->GetShape().GetNumDimensions() != 2)
117 uint32_t num_elems = inputs[1]->GetShape()[1];
118 uint32_t num_batch = inputs[0]->GetShape().GetNumElements() / num_elems;
121 const std::vector<int32_t>& targetShape = {
static_cast<int32_t
>(num_batch),
static_cast<int32_t
>(num_elems)};
124 auto* reshapeOp =
new TosaSerializationOperator(Op_RESHAPE,
125 Attribute_ReshapeAttribute,
128 {outputReshapeName});
129 operators.push_back(reshapeOp);
131 tensors.push_back(
new TosaSerializationTensor(outputReshapeName, targetShape, inputDType0, {}));
133 fcInputNames[0] = outputReshapeName;
139 std::string fcOutputName;
140 bool isInputInt8 = (inputDType0 == DType_INT8);
144 tensors.push_back(
new TosaSerializationTensor(fcOutputName, outputShape0, DType_INT32, {}));
148 tensors.push_back(
new TosaSerializationTensor(outputName, outputShape0, outputDType0, {}));
152 TosaFullyConnectedAttribute attribute(inputs[0]->GetQuantizationOffset(),
153 inputs[1]->GetQuantizationOffset());
155 std::string& fcOutStr = isInputInt8 ? fcOutputName : outputName;
156 auto* fullyConnected_op =
new TosaSerializationOperator(Op_FULLY_CONNECTED,
157 Attribute_FullyConnectedAttribute,
161 operators.push_back(fullyConnected_op);
165 int32_t output_zp = outputs[0]->GetQuantizationOffset();
166 double output_scale = outputs[0]->GetQuantizationScales()[0];
167 double input_scale = inputs[0]->GetQuantizationScales()[0];
168 const std::vector<float>& weight_scales = inputs[1]->GetQuantizationScales();
170 TosaSerializationOperator* rescaleOp =
nullptr;
183 operators.push_back(rescaleOp);
184 tensors.push_back(
new TosaSerializationTensor(outputName,
191 return new TosaSerializationBasicBlock(blockName,
void CreateRescaleTosaOperatorForWeights(const std::string &inputName, const std::string &outputName, int32_t input_zp, int32_t output_zp, bool input_unsigned, bool output_unsigned, bool double_round, bool scale32, double input_scale, double output_scale, const std::vector< float > &weight_scales, TosaSerializationOperator **op)
Creates a TOSA rescale operator for weight tensors.