Compute Library
 21.02
CPPSplit.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_CPP_SPLIT_H
25 #define ARM_COMPUTE_CPP_SPLIT_H
26 
27 #include "arm_compute/core/Error.h"
30 #include "arm_compute/core/Types.h"
32 
34 
36 
37 namespace arm_compute
38 {
39 /** Basic function to split a tensor along a given axis */
40 template <typename SliceType, typename TensorInterfaceType = ITensor>
41 class CPPSplit : public IFunction
42 {
43 public:
45  : _outputs_vector(), _slice_functions(), _num_outputs(0)
46  {
47  }
48  /** Static function to check if given info will lead to a valid configuration of @ref CPPSplit
49  *
50  * @param[in] input The input tensor info. Data types supported: All.
51  * @param[in] outputs A vector containing the output tensors' info. Data types supported: same as @p input.
52  * The output tensors should match the input tensor dimensions for all shape dimensions apart
53  * from the split dimension
54  * @param[in] axis Axis on which to split the input.
55  *
56  * @return a status
57  */
58  static Status validate(const ITensorInfo *input, const std::vector<ITensorInfo *> &outputs, unsigned int axis)
59  {
62  ARM_COMPUTE_RETURN_ERROR_ON(outputs.size() < 2);
63 
64  // Get output shape
66  unsigned int total_output_shape_size = 0;
67 
68  // Sum the output sizes and fall back to evenly-sized splits if any are zero
69  const bool using_split_shapes = std::none_of(outputs.begin(), outputs.end(), [&total_output_shape_size](ITensorInfo * info)
70  {
71  unsigned int output_shape_size = info->tensor_shape().total_size();
72  total_output_shape_size += output_shape_size;
73  return output_shape_size == 0;
74  });
75 
76  if(using_split_shapes)
77  {
78  ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape().total_size() != total_output_shape_size);
79  }
80  else
81  {
84  }
85 
86  // Validate output tensors
87  unsigned int axis_offset = 0;
88  for(const auto &output : outputs)
89  {
91  if(using_split_shapes)
92  {
93  output_shape = output->tensor_shape();
95  }
96 
97  const size_t axis_split_step = output_shape[axis];
98 
99  // Start/End coordinates
100  Coordinates start_coords;
101  Coordinates end_coords;
102  for(unsigned int d = 0; d < output_shape.num_dimensions(); ++d)
103  {
104  end_coords.set(d, -1);
105  }
106 
107  // Output auto inizialitation if not yet initialized
108  TensorInfo tmp_output_info = *output->clone();
109  if(tmp_output_info.tensor_shape().total_size() == 0)
110  {
111  tmp_output_info = input->clone()->set_is_resizable(true).set_tensor_shape(output_shape);
112  }
113 
114  // Update coordinate on axis
115  start_coords.set(axis, axis_offset);
116  end_coords.set(axis, axis_offset + axis_split_step);
117 
118  ARM_COMPUTE_RETURN_ON_ERROR(SliceType::validate(input, output, start_coords, end_coords));
119  axis_offset += axis_split_step;
120  }
121 
122  return Status{};
123  }
124 
125  /** Initialise the kernel's input and outputs.
126  *
127  * @param[in] input The input tensor. Data types supported: All
128  * @param[out] outputs A vector containing the output tensors. Data types supported: Same as @p input.
129  * The output tensors should match the input tensor dimensions for all shape dimensions apart
130  * from the split dimension.
131  * @param[in] axis Axis on which to split the input.
132  */
133  void configure(const TensorInterfaceType *input, const std::vector<TensorInterfaceType *> &outputs, unsigned int axis)
134  {
135  // Create Slice functions
136  _num_outputs = outputs.size();
137  _slice_functions.resize(_num_outputs);
138 
139  // Extract output tensor info
140  std::vector<ITensorInfo *> outputs_info;
141  for(auto &output : outputs)
142  {
144  outputs_info.emplace_back(output->info());
145  }
146 
147  // If any of the outputs have a zero size, fall-back to using evenly-sized output splits
148  const bool outputs_have_sizes = std::none_of(outputs_info.begin(), outputs_info.end(), [](ITensorInfo * info)
149  {
150  return info->tensor_shape().total_size() == 0;
151  });
152 
153  // Validate
154  ARM_COMPUTE_ERROR_THROW_ON(CPPSplit::validate(input->info(), outputs_info, axis));
155 
156  unsigned int axis_offset = 0;
157  unsigned int i = 0;
158 
159  for(const auto &output_info : outputs_info)
160  {
161  // Get output shape
162  TensorShape output_shape = (outputs_have_sizes ?
163  output_info->tensor_shape() :
164  arm_compute::misc::shape_calculator::compute_split_shape(input->info(), axis, _num_outputs));
165 
166  const size_t axis_split_step = output_shape[axis];
167 
168  // Start/End coordinates
169  Coordinates start_coords;
170  Coordinates end_coords;
171 
172  for(unsigned int d = 0; d < output_shape.num_dimensions(); ++d)
173  {
174  end_coords.set(d, -1);
175  }
176 
177  // Update coordinate on axis
178  start_coords.set(axis, axis_offset);
179  end_coords.set(axis, axis_offset + axis_split_step);
180 
181  // Configure slice function
182  _slice_functions[i].configure(input, outputs[i], start_coords, end_coords);
183 
184  // Set valid region from shape
185  outputs[i]->info()->set_valid_region(ValidRegion(Coordinates(), output_shape));
186 
187  // Update axis offset
188  axis_offset += axis_split_step;
189  ++i;
190  }
191  }
192 
193 protected:
194  std::vector<TensorInterfaceType *> _outputs_vector;
195  std::vector<SliceType> _slice_functions;
196  unsigned int _num_outputs;
197 };
198 
199 } // namespace arm_compute
200 #endif /* ARM_COMPUTE_CPP_SPLIT_H */
void set(size_t dimension, T value, bool increase_dim_unit=true)
Accessor to set the value of one of the dimensions.
Definition: Dimensions.h:76
virtual size_t num_dimensions() const =0
The number of dimensions of the tensor (rank)
static Status validate(const ITensorInfo *input, const std::vector< ITensorInfo *> &outputs, unsigned int axis)
Static function to check if given info will lead to a valid configuration of CPPSplit.
Definition: CPPSplit.h:58
Shape of a tensor.
Definition: TensorShape.h:39
Base class for all functions.
Definition: IFunction.h:30
std::unique_ptr< ITensorInfo > clone() const override
Provide a clone of the current object of class T.
Definition: TensorInfo.cpp:316
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
Definition: Error.h:204
Store the tensor's metadata.
Definition: ITensorInfo.h:40
#define ARM_COMPUTE_ERROR_THROW_ON(status)
Definition: Error.h:455
Status class.
Definition: Error.h:52
#define ARM_COMPUTE_RETURN_ERROR_ON(cond)
If the condition is true, an error is returned.
Definition: Error.h:296
Copyright (c) 2017-2021 Arm Limited.
#define ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(...)
Definition: Validate.h:163
void configure(const TensorInterfaceType *input, const std::vector< TensorInterfaceType *> &outputs, unsigned int axis)
Initialise the kernel's input and outputs.
Definition: CPPSplit.h:133
virtual const TensorShape & tensor_shape() const =0
Size for each dimension of the tensor.
Coordinates of an item.
Definition: Coordinates.h:37
size_t total_size() const
Collapses all dimensions to a single linear total size.
Definition: TensorShape.h:172
virtual std::unique_ptr< T > clone() const =0
Provide a clone of the current object of class T.
ScaleKernelInfo info(interpolation_policy, default_border_mode, PixelValue(), sampling_policy, false)
unsigned int num_dimensions() const
Returns the effective dimensionality of the tensor.
Definition: Dimensions.h:143
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
Definition: Validate.h:161
Store the tensor's metadata.
Definition: TensorInfo.h:45
Container for valid region of a window.
Definition: Types.h:188
const TensorShape & tensor_shape() const override
Size for each dimension of the tensor.
Definition: TensorInfo.h:262
TensorShape compute_split_shape(const ITensorInfo *input, unsigned int axis, unsigned int num_splits)
Calculate the split output shape of a tensor.
Status validate(const ITensorInfo *scores_in, const ITensorInfo *boxes_in, const ITensorInfo *batch_splits_in, const ITensorInfo *scores_out, const ITensorInfo *boxes_out, const ITensorInfo *classes, const ITensorInfo *batch_splits_out, const ITensorInfo *keeps, const ITensorInfo *keeps_size, const BoxNMSLimitInfo info)
Basic function to split a tensor along a given axis.
Definition: CPPSplit.h:41