18 const std::vector<const TensorInfo*>& inputs,
19 const std::vector<const TensorInfo*>& outputs,
25 throw armnn::Exception(
"ConvertReduceOperator: Must provide a valid input tensor.");
28 if (inputs[0]->IsQuantized() ^ outputs[0]->IsQuantized())
31 "Both input and output tensors must be either quantised or non-quantised data types.");
34 if (reduceDescriptor->
m_vAxis.empty())
36 throw armnn::Exception(
"ConvertReduceOperator: Reduce Operation with empty axis not implemented.");
40 std::string inputName =
"input_";
42 std::size_t intermediateCounter = 0;
44 std::string outputName =
"output0_";
58 std::vector<TosaSerializationTensor*> tensors;
59 std::vector<std::string> inputNames{inputName};
61 DType inputType =
ArmNNToDType(inputs[0]->GetDataType());
63 if (inputName.substr(0, 6) ==
"input_")
65 tensors.emplace_back(
new TosaSerializationTensor(inputName,
72 int64_t output_zp = 0;
74 double input_scale = 1.0;
75 double output_scale = 1.0;
77 int32_t input_multiplier = 1;
78 int32_t output_multiplier = 1;
80 int32_t input_shift = 0;
81 int32_t output_shift = 0;
83 int64_t numElemsOnReducedAxis = 1;
85 std::vector<int32_t> axes(reduceDescriptor->
m_vAxis.begin(), reduceDescriptor->
m_vAxis.end());
87 for (int64_t axis : axes)
89 numElemsOnReducedAxis *= inputShape[
static_cast<uint64_t
>(axis)];
92 std::vector<TosaSerializationOperator*> operators;
94 bool inputQuantised = inputs[0]->IsQuantized();
99 input_zp = inputs[0]->GetQuantizationOffset();
100 output_zp = outputs[0]->GetQuantizationOffset();
102 std::string outputNameRescale =
105 TosaSerializationOperator* rescaleOp1 =
nullptr;
109 case ReduceOperation::Sum:
112 input_scale =
static_cast<double>(1 << input_shift) * inputs[0]->GetQuantizationScale();
113 output_scale = 1.0 / (outputs[0]->GetQuantizationScale() *
static_cast<double>(1 << input_shift));
118 static_cast<int32_t
>(input_zp),
127 case ReduceOperation::Mean:
133 static_cast<double>(inputs[0]->GetQuantizationScale()) /
134 static_cast<double>(outputs[0]->GetQuantizationScale()),
139 int shift = 63 - __builtin_clzl(
static_cast<uint64_t
>(numElemsOnReducedAxis));
140 shift = std::min(shift, 32);
141 shift = std::min(shift, 62 - output_shift);
143 output_multiplier =
static_cast<int32_t
>(
144 (
static_cast<int64_t
>(output_multiplier) << shift) / numElemsOnReducedAxis);
146 output_shift += shift;
152 static_cast<int32_t
>(input_zp),
163 throw armnn::Exception(
"ConvertReduceOperator: Reduce Operation not implemented.");
166 operators.emplace_back(rescaleOp1);
168 tensors.emplace_back(
new TosaSerializationTensor(outputNameRescale,
174 std::string outputNameReduce;
175 bool reuseOutputName = !inputQuantised && reduceDescriptor->
m_ReduceOperation == ReduceOperation::Sum;
178 for (
const auto axis : axes)
180 auto rank =
static_cast<int64_t
>(inputs[0]->GetNumDimensions());
182 if (axis < 0 || axis >= rank)
187 TosaAxisAttribute reduceAttribute(axis);
189 std::vector<int32_t> outputShapeReduce = tensors.back()->GetShape();
190 outputShapeReduce[
static_cast<std::size_t
>(axis)] = 1;
192 outputNameReduce = (reuseOutputName && outputShapeReduce == outputShape)
198 case ReduceOperation::Sum:
199 case ReduceOperation::Mean:
200 operators.emplace_back(
new TosaSerializationOperator(Op_REDUCE_SUM,
201 Attribute_AxisAttribute,
203 { tensors.back()->GetName() },
204 { outputNameReduce }));
207 throw armnn::Exception(
"ConvertReduceOperator: Reduce Operation not implemented.");
210 tensors.emplace_back(
new TosaSerializationTensor(outputNameReduce,
212 tensors.back()->GetDtype(),
216 std::string outputNameReshape;
217 bool reshapeLogic =
false;
221 if (inputShape.size() == outputShape.size() && inputShape != outputShape && !axes.empty())
223 bool onlyMeanAxisChanged =
true;
225 for (
size_t i = 0; i < inputShape.size(); ++i)
227 if (inputShape[i] != outputShape[i] &&
228 std::find(axes.begin(), axes.end(),
static_cast<int64_t
>(i)) == axes.end())
230 onlyMeanAxisChanged =
false;
236 reshapeLogic = !onlyMeanAxisChanged;
238 else if (inputShape.size() != outputShape.size())
243 std::string outputNameRescale;
251 TosaReshapeAttribute reshapeAttribute(outputShape);
252 outputNameReshape = !inputQuantised && reduceDescriptor->
m_ReduceOperation == ReduceOperation::Mean
255 if(!outputNameRescale.empty())
257 outputNameReshape = outputNameRescale;
260 operators.emplace_back(
new TosaSerializationOperator(Op_RESHAPE,
261 Attribute_ReshapeAttribute,
263 { tensors.back()->GetName() },
264 { outputNameReshape }));
265 if(outputNameReshape != outputName)
267 tensors.emplace_back(
new TosaSerializationTensor(outputNameReshape,
269 tensors.back()->GetDtype(),
277 TosaSerializationOperator* rescaleOp2 =
nullptr;
281 case ReduceOperation::Sum:
286 static_cast<int32_t
>(output_zp),
293 case ReduceOperation::Mean:
299 static_cast<int32_t
>(output_zp),
308 throw armnn::Exception(
"ConvertReduceOperator: Reduce Operation not implemented.");
311 operators.emplace_back(rescaleOp2);
316 if (!inputQuantised && reduceDescriptor->
m_ReduceOperation == ReduceOperation::Mean)
320 inputNames.emplace_back(constNameDivScale);
322 operators.push_back(
new TosaSerializationOperator(Op_CONST,
326 { constNameDivScale }));
328 float divScale = 1.0f /
static_cast<float>(numElemsOnReducedAxis);
330 std::vector<uint8_t> uint8DivScale;
334 TosaSerializationHandler::ConvertF32toU8({divScale}, uint8DivScale);
337 TosaSerializationHandler::ConvertF16toU8({divScale}, uint8DivScale);
344 std::vector<int32_t> divConstantShape(outputShape.size(), 1);
346 tensors.push_back(
new TosaSerializationTensor(constNameDivScale,
353 TosaMulAttribute mulAttribute(shift);
354 if(reshapeLogic && !outputNameReshape.empty())
356 operators.emplace_back(
new TosaSerializationOperator(Op_MUL,
357 Attribute_MulAttribute,
359 { constNameDivScale, outputNameReshape },
362 else if (!outputNameReduce.empty())
364 operators.emplace_back(
new TosaSerializationOperator(Op_MUL,
365 Attribute_MulAttribute,
367 { constNameDivScale, outputNameReduce },
373 if(tensors.back()->GetName() != outputName)
375 tensors.emplace_back(
new TosaSerializationTensor(outputName,
381 return new TosaSerializationBasicBlock(blockName,