Compute Library
 23.08
CPPSplit.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2020-2021,2023 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_CPP_SPLIT_H
25 #define ARM_COMPUTE_CPP_SPLIT_H
26 
27 #include "arm_compute/core/Error.h"
30 #include "arm_compute/core/Types.h"
32 
34 
35 namespace arm_compute
36 {
37 /** Basic function to split a tensor along a given axis.
 *
 * configure() sets up one SliceType function per output tensor. Each output is given a
 * contiguous range along the split axis; the ranges start at offset 0 and are placed
 * back to back in the order the outputs appear in the vector.
 *
 * @tparam SliceType           Slice function type. The code below calls a static
 *                             SliceType::validate() and a member configure(), both taking
 *                             (input, output, start coordinates, end coordinates).
 * @tparam TensorInterfaceType Tensor type accepted by configure(). Defaults to ITensor.
 */
38 template <typename SliceType, typename TensorInterfaceType = ITensor>
39 class CPPSplit : public IFunction
40 {
41 public:
    // NOTE(review): listing line 42 is absent from this extract. The initialiser list below
    // presumably belongs to the default constructor CPPSplit() - confirm against upstream.
43  : _outputs_vector(), _slice_functions(), _num_outputs(0)
44  {
45  }
46  /** Static function to check if given info will lead to a valid configuration of @ref CPPSplit
47  *
    * Outputs are treated as explicitly sized only when every one of them has a non-zero
    * total size; otherwise a single shared output shape is used for all of them.
    *
48  * @param[in] input The input tensor info. Data types supported: All.
49  * @param[in] outputs A vector containing the output tensors' info. Data types supported: same as @p input.
50  * The output tensors should match the input tensor dimensions for all shape dimensions apart
51  * from the split dimension. At least two outputs are required.
52  * @param[in] axis Axis on which to split the input. Must be smaller than the number of input dimensions.
53  *
54  * @return An error status if any of the checks below fails, otherwise a default-constructed Status.
55  */
56  static Status validate(const ITensorInfo *input, const std::vector<ITensorInfo *> &outputs, unsigned int axis)
57  {
    // NOTE(review): listing line 58 is absent from this extract.
    // The split axis must be an existing dimension of the input, and a split needs at least two outputs.
59  ARM_COMPUTE_RETURN_ERROR_ON(axis >= input->num_dimensions());
60  ARM_COMPUTE_RETURN_ERROR_ON(outputs.size() < 2);
61 
62  // Get output shape
    // NOTE(review): listing line 63 is absent from this extract; output_shape, which is used
    // further down, is presumably declared there - confirm against upstream.
64  unsigned int total_output_shape_size = 0;
65 
66  // Sum the output sizes and fall back to evenly-sized splits if any are zero
    // The predicate accumulates the sizes as a side effect. std::none_of may stop at the first
    // zero-sized output and leave the sum incomplete, but the sum is only read in the branch
    // where every output has a size.
67  const bool using_split_shapes = std::none_of(outputs.begin(), outputs.end(), [&total_output_shape_size](ITensorInfo * info)
68  {
69  unsigned int output_shape_size = info->tensor_shape().total_size();
70  total_output_shape_size += output_shape_size;
71  return output_shape_size == 0;
72  });
73 
74  if(using_split_shapes)
75  {
    // Explicit sizes: together the outputs must hold exactly as many elements as the input
76  ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape().total_size() != total_output_shape_size);
77  }
78  else
79  {
    // NOTE(review): listing line 80 is absent from this extract; output_shape is presumably
    // assigned the evenly-split shape there - confirm against upstream.
81  ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() == 0);
82  }
83 
84  // Validate output tensors
    // axis_offset is the running start position of the current output along the split axis
85  unsigned int axis_offset = 0;
86  for(const auto &output : outputs)
87  {
    // NOTE(review): listing line 88 is absent from this extract.
89  if(using_split_shapes)
90  {
    // Explicit sizes: each output contributes its own shape
91  output_shape = output->tensor_shape();
92  ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() == 0);
93  }
94 
    // Extent of this output along the split axis
95  const size_t axis_split_step = output_shape[axis];
96 
97  // Start/End coordinates
    // Every end coordinate is first set to -1; the split axis is overwritten further down.
    // NOTE(review): presumably SliceType reads -1 as "up to the end of that dimension" - confirm.
98  Coordinates start_coords;
99  Coordinates end_coords;
100  for(unsigned int d = 0; d < output_shape.num_dimensions(); ++d)
101  {
102  end_coords.set(d, -1);
103  }
104 
105  // Output auto initialization if not yet initialized
    // NOTE(review): tmp_output_info is not read again in this function; the SliceType::validate
    // call below receives the original output, not this copy - confirm that this is intended.
106  TensorInfo tmp_output_info = *output->clone();
107  if(tmp_output_info.tensor_shape().total_size() == 0)
108  {
109  tmp_output_info = input->clone()->set_is_resizable(true).set_tensor_shape(output_shape);
110  }
111 
112  // Update coordinate on axis
113  start_coords.set(axis, axis_offset);
114  end_coords.set(axis, axis_offset + axis_split_step);
115 
    // Check the slice for this output, then advance to where the next output starts
116  ARM_COMPUTE_RETURN_ON_ERROR(SliceType::validate(input, output, start_coords, end_coords));
117  axis_offset += axis_split_step;
118  }
119 
120  return Status{};
121  }
122 
123  /** Initialise the kernel's input and outputs.
124  *
     * The arguments are first checked with validate(), whose status is handed to
     * ARM_COMPUTE_ERROR_THROW_ON. One slice function is then configured per output and the
     * valid region of each output is set from the shape used for its slice.
     *
125  * @param[in] input The input tensor. Data types supported: All
126  * @param[out] outputs A vector containing the output tensors. Data types supported: Same as @p input.
127  * The output tensors should match the input tensor dimensions for all shape dimensions apart
128  * from the split dimension.
129  * @param[in] axis Axis on which to split the input.
130  */
131  void configure(const TensorInterfaceType *input, const std::vector<TensorInterfaceType *> &outputs, unsigned int axis)
132  {
133  // Create Slice functions
134  _num_outputs = outputs.size();
135  _slice_functions.resize(_num_outputs);
136 
137  // Extract output tensor info
138  std::vector<ITensorInfo *> outputs_info;
139  for(auto &output : outputs)
140  {
     // NOTE(review): listing line 141 is absent from this extract.
142  outputs_info.emplace_back(output->info());
143  }
144 
145  // If any of the outputs have a zero size, fall-back to using evenly-sized output splits
146  const bool outputs_have_sizes = std::none_of(outputs_info.begin(), outputs_info.end(), [](ITensorInfo * info)
147  {
148  return info->tensor_shape().total_size() == 0;
149  });
150 
151  // Validate
152  ARM_COMPUTE_ERROR_THROW_ON(CPPSplit::validate(input->info(), outputs_info, axis));
153 
     // axis_offset: running start position along the split axis; i: index of the current output
154  unsigned int axis_offset = 0;
155  unsigned int i = 0;
156 
157  for(const auto &output_info : outputs_info)
158  {
159  // Get output shape
160  TensorShape output_shape = (outputs_have_sizes ?
161  output_info->tensor_shape() :
     // NOTE(review): listing line 162 is absent from this extract, so the fallback operand of
     // this conditional and the end of the statement are not shown - confirm against upstream.
163 
     // Extent of this output along the split axis
164  const size_t axis_split_step = output_shape[axis];
165 
166  // Start/End coordinates
     // Every end coordinate is first set to -1; the split axis is overwritten further down.
     // NOTE(review): presumably SliceType reads -1 as "up to the end of that dimension" - confirm.
167  Coordinates start_coords;
168  Coordinates end_coords;
169 
170  for(unsigned int d = 0; d < output_shape.num_dimensions(); ++d)
171  {
172  end_coords.set(d, -1);
173  }
174 
175  // Update coordinate on axis
176  start_coords.set(axis, axis_offset);
177  end_coords.set(axis, axis_offset + axis_split_step);
178 
179  // Configure slice function
180  _slice_functions[i].configure(input, outputs[i], start_coords, end_coords);
181 
182  // Set valid region from shape
183  outputs[i]->info()->set_valid_region(ValidRegion(Coordinates(), output_shape));
184 
185  // Update axis offset
186  axis_offset += axis_split_step;
187  ++i;
188  }
189  }
190 
191 protected:
     // Default-constructed by the constructor; not written by any other code shown in this listing
192  std::vector<TensorInterfaceType *> _outputs_vector;
     // One slice function per output; resized in configure()
193  std::vector<SliceType> _slice_functions;
     // Number of outputs; set in configure()
194  unsigned int _num_outputs;
195 };
196 
197 } // namespace arm_compute
198 #endif /* ARM_COMPUTE_CPP_SPLIT_H */
arm_compute::TensorInfo::clone
std::unique_ptr< ITensorInfo > clone() const override
Definition: TensorInfo.cpp:316
arm_compute::Dimensions::set
void set(size_t dimension, T value, bool increase_dim_unit=true)
Accessor to set the value of one of the dimensions.
Definition: Dimensions.h:76
Helpers.h
arm_compute::test::validation::output_shape
const auto output_shape
Definition: ConvolutionLayer.cpp:411
arm_compute::CPPSplit
Basic function to split a tensor along a given axis.
Definition: CPPSplit.h:39
arm_compute::IFunction
Base class for all functions.
Definition: IFunction.h:30
arm_compute::TensorShape
Shape of a tensor.
Definition: TensorShape.h:39
Types.h
TensorInfo.h
Error.h
arm_compute::CPPSplit::CPPSplit
CPPSplit()
Definition: CPPSplit.h:42
ARM_COMPUTE_RETURN_ON_ERROR
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
Definition: Error.h:204
ARM_COMPUTE_ERROR_ON_NULLPTR
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
Definition: Validate.h:161
arm_compute::test::validation::output_info
output_info
Definition: DirectConvolutionLayer.cpp:547
ARM_COMPUTE_ERROR_THROW_ON
#define ARM_COMPUTE_ERROR_THROW_ON(status)
Definition: Error.h:456
ARM_COMPUTE_RETURN_ERROR_ON
#define ARM_COMPUTE_RETURN_ERROR_ON(cond)
If the condition is true, an error is returned.
Definition: Error.h:297
arm_compute::TensorShape::total_size
size_t total_size() const
Collapses all dimensions to a single linear total size.
Definition: TensorShape.h:176
arm_compute::ValidRegion
Container for valid region of a window.
Definition: Types.h:144
arm_compute::Status
Status class.
Definition: Error.h:52
arm_compute::CPPSplit::validate
static Status validate(const ITensorInfo *input, const std::vector< ITensorInfo * > &outputs, unsigned int axis)
Static function to check if given info will lead to a valid configuration of CPPSplit.
Definition: CPPSplit.h:56
arm_compute::CPPSplit::configure
void configure(const TensorInterfaceType *input, const std::vector< TensorInterfaceType * > &outputs, unsigned int axis)
Initialise the kernel's input and outputs.
Definition: CPPSplit.h:131
arm_compute::Coordinates
Coordinates of an item.
Definition: Coordinates.h:37
ShapeCalculator.h
IFunction.h
arm_compute::TensorInfo
Store the tensor's metadata.
Definition: TensorInfo.h:42
arm_compute
Copyright (c) 2017-2023 Arm Limited.
Definition: introduction.dox:24
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR
#define ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(...)
Definition: Validate.h:163
arm_compute::validate
Status validate(const ITensorInfo *scores_in, const ITensorInfo *boxes_in, const ITensorInfo *batch_splits_in, const ITensorInfo *scores_out, const ITensorInfo *boxes_out, const ITensorInfo *classes, const ITensorInfo *batch_splits_out, const ITensorInfo *keeps, const ITensorInfo *keeps_size, const BoxNMSLimitInfo info)
Definition: CPPBoxWithNonMaximaSuppressionLimit.cpp:214
arm_compute::ITensorInfo
Store the tensor's metadata.
Definition: ITensorInfo.h:43
arm_compute::test::validation::info
ScaleKernelInfo info(interpolation_policy, default_border_mode, PixelValue(), sampling_policy, false)
arm_compute::misc::shape_calculator::compute_split_shape
TensorShape compute_split_shape(const ITensorInfo *input, unsigned int axis, unsigned int num_splits)
Calculate the split output shape of a tensor.
Definition: ShapeCalculator.h:1171
arm_compute::TensorInfo::tensor_shape
const TensorShape & tensor_shape() const override
Size for each dimension of the tensor.
Definition: TensorInfo.h:235
arm_compute::test::validation::input
auto input
Definition: LSTMLayerQuantized.cpp:486