24 #ifndef ARM_COMPUTE_CPP_SPLIT_H
25 #define ARM_COMPUTE_CPP_SPLIT_H
37 template <
typename SliceType,
typename TensorInterfaceType = ITensor>
41 CPPSplit() : _outputs_vector(), _slice_functions(), _num_outputs(0)
62 unsigned int total_output_shape_size = 0;
65 const bool using_split_shapes = std::none_of(outputs.begin(), outputs.end(),
68 unsigned int output_shape_size =
69 info->tensor_shape().total_size();
70 total_output_shape_size += output_shape_size;
71 return output_shape_size == 0;
74 if (using_split_shapes)
85 unsigned int axis_offset = 0;
86 for (
const auto &output : outputs)
89 if (using_split_shapes)
102 end_coords.
set(d, -1);
109 tmp_output_info =
input->clone()->set_is_resizable(
true).set_tensor_shape(
output_shape);
113 start_coords.
set(axis, axis_offset);
114 end_coords.
set(axis, axis_offset + axis_split_step);
117 axis_offset += axis_split_step;
132 configure(
const TensorInterfaceType *
input,
const std::vector<TensorInterfaceType *> &outputs,
unsigned int axis)
135 _num_outputs = outputs.size();
136 _slice_functions.resize(_num_outputs);
139 std::vector<ITensorInfo *> outputs_info;
140 for (
auto &output : outputs)
143 outputs_info.emplace_back(output->info());
147 const bool outputs_have_sizes =
148 std::none_of(outputs_info.begin(), outputs_info.end(),
149 [](
ITensorInfo *
info) { return info->tensor_shape().total_size() == 0; });
154 unsigned int axis_offset = 0;
173 end_coords.
set(d, -1);
177 start_coords.
set(axis, axis_offset);
178 end_coords.
set(axis, axis_offset + axis_split_step);
181 _slice_functions[i].configure(
input, outputs[i], start_coords, end_coords);
187 axis_offset += axis_split_step;
193 std::vector<TensorInterfaceType *> _outputs_vector;
194 std::vector<SliceType> _slice_functions;
195 unsigned int _num_outputs;