#ifndef ARM_COMPUTE_CPP_SPLIT_H
#define ARM_COMPUTE_CPP_SPLIT_H
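
#include "arm_compute/core/Error.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/IFunction.h"

#include <algorithm>
#include <vector>

namespace arm_compute
{
/** Basic function to split a tensor along a given axis */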
template <typename SliceType, typename TensorInterfaceType = ITensor>
class CPPSplit : public IFunction
{
public:
    CPPSplit() : _outputs_vector(), _slice_functions(), _num_outputs(0)
    {
    }
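
    /** Static function to check if splitting @p input into @p outputs along @p axis is valid */
    static Status validate(const ITensorInfo *input, const std::vector<ITensorInfo *> &outputs, unsigned int axis)
    {
        // Sum the requested output sizes; an empty output shape requests evenly-sized splits.
        TensorShape output_shape{};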
        unsigned int total_output_shape_size = 0;
        const bool using_split_shapes = std::none_of(outputs.begin(), outputs.end(), [&total_output_shape_size](ITensorInfo *info)
        {
            unsigned int output_shape_size = info->tensor_shape().total_size();
            total_output_shape_size += output_shape_size;
            return output_shape_size == 0;
        });

        if(using_split_shapes)
        {
            // The explicitly provided output shapes must add up to the input's total size.
            ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape().total_size() != total_output_shape_size);
        }
        else
        {
            // Otherwise compute evenly-sized split shapes from the input.
            output_shape = arm_compute::misc::shape_calculator::compute_split_shape(input, axis, outputs.size());
        }

        // Validate the slice that produces each output.
        unsigned int axis_offset = 0;
        for(const auto &output : outputs)
        {
            if(using_split_shapes)
            {
                output_shape = output->tensor_shape();
            }
            const size_t axis_split_step = output_shape[axis];
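
            // Start/end coordinates of the slice; an end coordinate of -1 selects the full extent of a dimension.
            Coordinates start_coords;
            Coordinates end_coords;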
            for(unsigned int d = 0; d < output_shape.num_dimensions(); ++d)
            {
                end_coords.set(d, -1);
            }

            // Expected output info: use the provided one, or derive it from the input if not initialised yet.
            TensorInfo tmp_output_info = *output->clone();
            if(tmp_output_info.tensor_shape().total_size() == 0)
            {
                tmp_output_info = input->clone()->set_is_resizable(true).set_tensor_shape(output_shape);
            }

            // Update the coordinates on the split axis and validate the corresponding slice.
            start_coords.set(axis, axis_offset);
            end_coords.set(axis, axis_offset + axis_split_step);

            ARM_COMPUTE_RETURN_ON_ERROR(SliceType::validate(input, &tmp_output_info, start_coords, end_coords));
            axis_offset += axis_split_step;
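        }

        return Status{};
    }

    /** Initialise the function's input and outputs */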
    void configure(const TensorInterfaceType *input, const std::vector<TensorInterfaceType *> &outputs, unsigned int axis)
    {
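        // Create one slice function per requested output.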
        _num_outputs = outputs.size();
        _slice_functions.resize(_num_outputs);
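
        // Extract the output tensor infos.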
        std::vector<ITensorInfo *> outputs_info;
        for(auto &output : outputs)
        {
            outputs_info.emplace_back(output->info());
        }
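
        // If any output has an empty shape, fall back to evenly-sized output splits.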
        const bool outputs_have_sizes = std::none_of(outputs_info.begin(), outputs_info.end(), [](ITensorInfo *info)
        {
            return info->tensor_shape().total_size() == 0;
        });

        unsigned int axis_offset = 0;
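        unsigned int i           = 0;
        for(const auto &output_info : outputs_info)
        {
            // Output shape: either the caller-provided shape or an evenly-sized split of the input.
            const TensorShape output_shape = outputs_have_sizes ? output_info->tensor_shape()
                                                                : arm_compute::misc::shape_calculator::compute_split_shape(input->info(), axis, _num_outputs);
            const size_t axis_split_step = output_shape[axis];

            // Start/end coordinates of the slice; an end coordinate of -1 selects the full extent of a dimension.
            Coordinates start_coords;
            Coordinates end_coords;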
            for(unsigned int d = 0; d < output_shape.num_dimensions(); ++d)
            {
                end_coords.set(d, -1);
            }

            // Update the coordinates on the split axis.
            start_coords.set(axis, axis_offset);
            end_coords.set(axis, axis_offset + axis_split_step);
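
            // Configure the slice function that produces the i-th output.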
            _slice_functions[i].configure(input, outputs[i], start_coords, end_coords);

            axis_offset += axis_split_step;
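            ++i;
        }
    }

protected: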
    std::vector<TensorInterfaceType *> _outputs_vector;
    std::vector<SliceType>             _slice_functions;
    unsigned int                       _num_outputs;
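};
} // namespace arm_compute
#endif /* ARM_COMPUTE_CPP_SPLIT_H */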