15 const std::vector<const TensorInfo*>& inputs,
16 const std::vector<const TensorInfo*>& outputs,
20 "ConvertSplitToTosaOperator: Split must have only one input" );
23 "ConvertSplitToTosaOperator: Split must have at least one output" );
25 if (!inputs[0]->GetShape().AreAllDimensionsSpecified())
27 throw armnn::Exception(
"ConvertSplitToTosaOperator: Dynamic input dimensions are unsupported.");
30 std::string inputName = std::string(
"input_");
31 std::vector<std::string> outputNames;
34 unsigned int numSplit = splitDescriptor->
GetNumViews();
41 for (
unsigned int i=0; i < numSplit; ++i)
45 outputNames.push_back(outputName);
50 for (
unsigned int i=0; i < numSplit; ++i)
53 std::string outputName =
"output" + std::to_string(i) +
"_";
54 outputNames.push_back(outputName);
59 std::set<unsigned int> splitAxis =
ComputeSplitAxis(*splitDescriptor, inputs[0]->GetShape());
60 if (splitAxis.size() != 1)
64 uint32_t axis = *splitAxis.begin();
66 std::vector<TosaSerializationOperator*> ops;
67 std::vector<int32_t> beginVals(inputs[0]->GetNumDimensions(), 0);
68 for (
unsigned int i = 0; i < numSplit; ++i)
71 TosaSliceAttribute attribute(beginVals, sizeVals);
72 auto* op =
new TosaSerializationOperator(Op_SLICE,
73 Attribute_SliceAttribute,
81 beginVals[axis] += sizeVals[axis];
84 std::vector<TosaSerializationTensor*> tensors;
88 if(inputName.find(
"input_") != std::string::npos)
91 DType inputDType =
ArmNNToDType(inputs[0]->GetDataType());
93 tensors.push_back(
new TosaSerializationTensor(inputName, inputShape, inputDType, {}));
96 DType outputDType =
ArmNNToDType(outputs[0]->GetDataType());
97 for (
unsigned int i = 0; i < numSplit; ++i)
100 tensors.push_back(
new TosaSerializationTensor(outputNames[i], outputShape, outputDType, {}));
105 return new TosaSerializationBasicBlock(blockName,