24 #ifndef ARM_COMPUTE_CPP_SPLIT_H 25 #define ARM_COMPUTE_CPP_SPLIT_H 40 template <
typename SliceType,
typename TensorInterfaceType = ITensor>
45 : _outputs_vector(), _slice_functions(), _num_outputs(0)
66 unsigned int total_output_shape_size = 0;
69 const bool using_split_shapes = std::none_of(outputs.begin(), outputs.end(), [&total_output_shape_size](
ITensorInfo *
info)
71 unsigned int output_shape_size =
info->tensor_shape().total_size();
72 total_output_shape_size += output_shape_size;
73 return output_shape_size == 0;
76 if(using_split_shapes)
87 unsigned int axis_offset = 0;
88 for(
const auto &output : outputs)
91 if(using_split_shapes)
104 end_coords.
set(d, -1);
111 tmp_output_info = input->
clone()->set_is_resizable(
true).set_tensor_shape(
output_shape);
115 start_coords.
set(axis, axis_offset);
116 end_coords.
set(axis, axis_offset + axis_split_step);
119 axis_offset += axis_split_step;
133 void configure(
const TensorInterfaceType *
input,
const std::vector<TensorInterfaceType *> &outputs,
unsigned int axis)
136 _num_outputs = outputs.size();
137 _slice_functions.resize(_num_outputs);
140 std::vector<ITensorInfo *> outputs_info;
141 for(
auto &output : outputs)
144 outputs_info.emplace_back(output->info());
148 const bool outputs_have_sizes = std::none_of(outputs_info.begin(), outputs_info.end(), [](
ITensorInfo *
info)
150 return info->tensor_shape().total_size() == 0;
156 unsigned int axis_offset = 0;
166 const size_t axis_split_step = output_shape[axis];
174 end_coords.
set(d, -1);
178 start_coords.
set(axis, axis_offset);
179 end_coords.
set(axis, axis_offset + axis_split_step);
182 _slice_functions[i].configure(input, outputs[i], start_coords, end_coords);
188 axis_offset += axis_split_step;
194 std::vector<TensorInterfaceType *> _outputs_vector;
195 std::vector<SliceType> _slice_functions;
196 unsigned int _num_outputs;
void set(size_t dimension, T value, bool increase_dim_unit=true)
Accessor to set the value of one of the dimensions.
virtual size_t num_dimensions() const =0
The number of dimensions of the tensor (i.e. its rank).
static Status validate(const ITensorInfo *input, const std::vector< ITensorInfo *> &outputs, unsigned int axis)
Static function to check if the given info will lead to a valid configuration of CPPSplit.
Base class for all functions.
std::unique_ptr< ITensorInfo > clone() const override
Provide a clone of the current object of class T.
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
Store the tensor's metadata.
#define ARM_COMPUTE_ERROR_THROW_ON(status)
#define ARM_COMPUTE_RETURN_ERROR_ON(cond)
If the condition is true, an error is returned.
Copyright (c) 2017-2021 Arm Limited.
#define ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(...)
void configure(const TensorInterfaceType *input, const std::vector< TensorInterfaceType *> &outputs, unsigned int axis)
Initialise the kernel's input and outputs.
virtual const TensorShape & tensor_shape() const =0
Size for each dimension of the tensor.
size_t total_size() const
Collapses all dimensions to a single linear total size.
virtual std::unique_ptr< T > clone() const =0
Provide a clone of the current object of class T.
ScaleKernelInfo info(interpolation_policy, default_border_mode, PixelValue(), sampling_policy, false)
unsigned int num_dimensions() const
Returns the effective dimensionality of the tensor.
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
Store the tensor's metadata.
Container for valid region of a window.
const TensorShape & tensor_shape() const override
Size for each dimension of the tensor.
TensorShape compute_split_shape(const ITensorInfo *input, unsigned int axis, unsigned int num_splits)
Calculate the split output shape of a tensor.
Status validate(const ITensorInfo *scores_in, const ITensorInfo *boxes_in, const ITensorInfo *batch_splits_in, const ITensorInfo *scores_out, const ITensorInfo *boxes_out, const ITensorInfo *classes, const ITensorInfo *batch_splits_out, const ITensorInfo *keeps, const ITensorInfo *keeps_size, const BoxNMSLimitInfo info)
Basic function to split a tensor along a given axis.