ArmNN
 25.11
Loading...
Searching...
No Matches
StridedSliceOperator.cpp
Go to the documentation of this file.
1//
2// Copyright © 2024-2025 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "SliceOperator.hpp"
7
8// This function is paraphrased from:
9// tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
10
11TosaSerializationBasicBlock* ConvertStridedSliceToTosaOperator(const Layer* layer,
12 const std::vector<const TensorInfo*>& inputs,
13 const std::vector<const TensorInfo*>& outputs,
14 const StridedSliceDescriptor* stridedSliceDescriptor)
15{
16 // Limitations
17 if (stridedSliceDescriptor->m_EllipsisMask != 0)
18 {
19 throw armnn::Exception("ConvertStridedSliceToTosaOperator: Ellipses mask not supported.");
20 }
21
22 /// Begin with the slice
23 std::vector<int32_t> begin(stridedSliceDescriptor->m_Begin);
24 std::vector<int32_t> end(stridedSliceDescriptor->m_End);
25 std::vector<int32_t> strides(stridedSliceDescriptor->m_Stride);
26
27 for (auto stride : strides)
28 {
29 if (stride != 1)
30 {
31 // Only strides with values 1 supported otherwise reshape invoked which creates tensors with more than 5D
32 throw armnn::Exception("ConvertStridedSliceToTosaOperator: Strides greater than 1 not supported.");
33 }
34 }
35
36 std::string inputName = std::string("input_");
37 std::string outputNameSlice = std::string("layer_intermediate1_") + GetUniqueTosaMappingID();
38 std::string outputNameReshape = std::string("layer_intermediate2_") + GetUniqueTosaMappingID();
39 std::string outputName = std::string("output0_");
40 std::string blockName = std::string("Op_SLICE_block_") + GetUniqueTosaMappingID();
41
42 // If a layer is present then the block will be used for execution, so input and output names need to be determined
43 // using the previous and following layers so the graph is connected correctly. For validation this doesn't matter.
44 if(layer != nullptr)
45 {
46 inputName = GenerateUniqueInputName(layer->GetInputSlot(0));
47 outputName = GenerateUniqueOutputName(*layer);
48 }
49
50 std::vector<TosaSerializationTensor*> tensors;
51 std::vector<TosaSerializationOperator *> operators;
52
53 std::vector<int32_t> inputShape = GetTosaTensorShape(inputs[0]->GetShape());
54 DType inputDType = ArmNNToDType(inputs[0]->GetDataType());
55
56 // Only add input tensors if connected layer is an input layer.
57 // As intermediate or constant tensors will be created separately.
58 // There also can't be duplicate tensor.
59 if(inputName.find("input_") != std::string::npos)
60 {
61 tensors.push_back(new TosaSerializationTensor(inputName, inputShape, inputDType, {}));
62 }
63
64 DType outputDType = ArmNNToDType(outputs[0]->GetDataType());
65 std::vector<int32_t> outputShape = GetTosaTensorShape(outputs[0]->GetShape());
66
67 // Figure out size
68 uint32_t inputRank = inputs[0]->GetShape().GetNumDimensions();
69
70 // handle cases where end or begin values are negative
71 for (uint32_t i = 0; i < inputRank; ++i)
72 {
73 if (end[i] < 0)
74 {
75 end[i] = inputShape[i] + end[i];
76 }
77 if (begin[i] < 0)
78 {
79 begin[i] = inputShape[i] + begin[i];
80 }
81 }
82
83 std::vector<int32_t> a1_size(inputRank);
84
85 // If mask set default to begin and end size from input tensor
86 for (uint32_t i = 0; i < inputRank; ++i)
87 {
88 if (stridedSliceDescriptor->m_BeginMask & (1 << i))
89 {
90 begin[i] = 0;
91 }
92 if (stridedSliceDescriptor->m_EndMask & (1 << i))
93 {
94 end[i] = inputShape[i];
95 }
96
97 a1_size[i] = end[i] - begin[i];
98 }
99
100 TosaSliceAttribute sliceAttribute(begin, a1_size);
101
102 auto* sliceOp1 = new TosaSerializationOperator(Op_SLICE,
103 Attribute_SliceAttribute,
104 &sliceAttribute,
105 {inputName},
106 {outputNameSlice});
107
108 tensors.push_back(new TosaSerializationTensor(outputNameSlice, a1_size, outputDType, {}));
109 operators.push_back(sliceOp1);
110
111 // If unary striding is used we can reverse, reshape, and return the result.
112 std::vector<int32_t> newShape;
113
114 for (uint32_t i = 0; i < inputRank; ++i)
115 {
116 // Remove dimension specified in ShrinkAxisMask
117 if (!(stridedSliceDescriptor->m_ShrinkAxisMask & (1 << i)))
118 {
119 newShape.push_back(a1_size[i]);
120 }
121 }
122
123 TosaReshapeAttribute reshapeAttribute2(newShape);
124
125 auto* reshapeOp2 = new TosaSerializationOperator(Op_RESHAPE,
126 Attribute_ReshapeAttribute,
127 &reshapeAttribute2,
128 {outputNameSlice},
129 {outputName});
130
131 tensors.push_back(new TosaSerializationTensor(outputName, newShape, outputDType, {}));
132 operators.push_back(reshapeOp2);
133
134 // operatorInputNames/operatorOutputNames ends up being the same as
135 // blockInputNames/blockOutputNames for one-to-one ArmNN to TOSA mappings
136 return new TosaSerializationBasicBlock(blockName, // name
137 mainName, // region name
138 operators, // operators
139 tensors, // tensors
140 {inputName}, // inputs
141 {outputName}); // outputs
142}
TosaSerializationBasicBlock * ConvertStridedSliceToTosaOperator(const Layer *layer, const std::vector< const TensorInfo * > &inputs, const std::vector< const TensorInfo * > &outputs, const StridedSliceDescriptor *stridedSliceDescriptor)
std::string GenerateUniqueOutputName(const Layer &layer, uint32_t layerSlot=0)
const std::string mainName
DType ArmNNToDType(const DataType &type)
std::string GenerateUniqueInputName(const armnn::InputSlot &slot)
std::string GetUniqueTosaMappingID()
std::vector< int32_t > GetTosaTensorShape(const TensorShape &shape)
Base class for all ArmNN exceptions so that users can filter to just those.
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition Layer.hpp:337
A StridedSliceDescriptor for the StridedSliceLayer.
std::vector< int > m_Stride
Stride values for the input that will be sliced.
std::vector< int > m_Begin
Begin values for the input that will be sliced.
int32_t m_BeginMask
Begin mask value.
int32_t m_ShrinkAxisMask
Shrink axis mask value. If set, the nth specification shrinks the dimensionality by 1.
std::vector< int > m_End
End values for the input that will be sliced.
int32_t m_EndMask
End mask value.
int32_t m_EllipsisMask
Ellipsis mask value.