12 const std::vector<const TensorInfo*>& inputs,
13 const std::vector<const TensorInfo*>& outputs,
19 throw armnn::Exception(
"ConvertStridedSliceToTosaOperator: Ellipses mask not supported.");
23 std::vector<int32_t> begin(stridedSliceDescriptor->
m_Begin);
24 std::vector<int32_t> end(stridedSliceDescriptor->
m_End);
25 std::vector<int32_t> strides(stridedSliceDescriptor->
m_Stride);
27 for (
auto stride : strides)
32 throw armnn::Exception(
"ConvertStridedSliceToTosaOperator: Strides greater than 1 not supported.");
36 std::string inputName = std::string(
"input_");
39 std::string outputName = std::string(
"output0_");
50 std::vector<TosaSerializationTensor*> tensors;
51 std::vector<TosaSerializationOperator *> operators;
54 DType inputDType =
ArmNNToDType(inputs[0]->GetDataType());
59 if(inputName.find(
"input_") != std::string::npos)
61 tensors.push_back(
new TosaSerializationTensor(inputName, inputShape, inputDType, {}));
64 DType outputDType =
ArmNNToDType(outputs[0]->GetDataType());
68 uint32_t inputRank = inputs[0]->GetShape().GetNumDimensions();
71 for (uint32_t i = 0; i < inputRank; ++i)
75 end[i] = inputShape[i] + end[i];
79 begin[i] = inputShape[i] + begin[i];
83 std::vector<int32_t> a1_size(inputRank);
86 for (uint32_t i = 0; i < inputRank; ++i)
92 if (stridedSliceDescriptor->
m_EndMask & (1 << i))
94 end[i] = inputShape[i];
97 a1_size[i] = end[i] - begin[i];
100 TosaSliceAttribute sliceAttribute(begin, a1_size);
102 auto* sliceOp1 =
new TosaSerializationOperator(Op_SLICE,
103 Attribute_SliceAttribute,
108 tensors.push_back(
new TosaSerializationTensor(outputNameSlice, a1_size, outputDType, {}));
109 operators.push_back(sliceOp1);
112 std::vector<int32_t> newShape;
114 for (uint32_t i = 0; i < inputRank; ++i)
119 newShape.push_back(a1_size[i]);
123 TosaReshapeAttribute reshapeAttribute2(newShape);
125 auto* reshapeOp2 =
new TosaSerializationOperator(Op_RESHAPE,
126 Attribute_ReshapeAttribute,
131 tensors.push_back(
new TosaSerializationTensor(outputName, newShape, outputDType, {}));
132 operators.push_back(reshapeOp2);
136 return new TosaSerializationBasicBlock(blockName,