36 : _num_splits(num_splits), _axis(axis), _size_splits(size_splits)
61 if(_size_splits.empty())
63 output_descriptor.
shape.
set(tmp_axis, split_size);
64 coords.
set(tmp_axis, idx * split_size);
68 int split_size = _size_splits[idx];
71 split_size = input_descriptor.
shape[tmp_axis];
72 for(
unsigned int i = 0; i < _size_splits.size() - 1; ++i)
73 split_size -= _size_splits[i];
75 output_descriptor.
shape.
set(tmp_axis, split_size);
77 for(
unsigned int i = 0; i < idx; ++i)
78 coord_value += _size_splits[i];
79 coords.
set(tmp_axis, coord_value);
82 return std::make_pair(output_descriptor, coords);
90 for(
unsigned int i = 0; i < _outputs.size(); ++i)
116 int num_dimension =
static_cast<int32_t
>(src->desc().shape.num_dimensions());
119 int split_size = (_size_splits.empty()) ? (input_descriptor.
shape[tmp_axis] / _num_splits) : _size_splits[idx];
122 split_size = input_descriptor.
shape[tmp_axis];
123 for(
unsigned int i = 0; i < _size_splits.size() - 1; ++i)
124 split_size -= _size_splits[i];
126 output_descriptor.
shape.
set(tmp_axis, split_size);
128 return output_descriptor;
141 if(_size_splits.empty())
TensorShape shape
Tensor shape.
void set(size_t dimension, T value, bool increase_dim_unit=true)
Accessor to set the value of one of the dimensions.
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true then an error message is printed and an exception thrown.
NodeType type() const override
Returns node's type.
unsigned int num_splits() const
Number of splits accessor.
#define ARM_COMPUTE_RETURN_ERROR_ON(cond)
If the condition is true, an error is returned.
SimpleTensor< float > src
Copyright (c) 2017-2021 Arm Limited.
TensorDescriptor & desc()
Tensor descriptor (metadata) accessor.
T wrap_around(T x, T m)
Wrap-around a number within the range 0 <= x < m.
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
virtual void visit(INode &n)=0
Visit INode.
Tensor * output(size_t idx) const
Returns the tensor of a given output of the node.
TensorID input_id(size_t idx) const
Returns the tensor ID of a given input of the node.
unsigned int axis() const
Split axis accessor.
SplitLayerNode(unsigned int num_splits, int axis=0, std::vector< int > size_splits=std::vector< int >())
Constructor.
constexpr EdgeID EmptyEdgeID
void accept(INodeVisitor &v) override
Accepts a node visitor.
unsigned int num_dimensions() const
Returns the effective dimensionality of the tensor.
std::pair< TensorDescriptor, Coordinates > compute_output_descriptor(const TensorDescriptor &input_descriptor, unsigned int num_splits, int axis, unsigned int idx)
Computes split layer output descriptor.
TensorID output_id(size_t idx) const
Returns the tensor ID of a given output of the node.
#define ARM_COMPUTE_RETURN_ERROR_ON_MSG(cond, msg)
If the condition is true, an error carrying the given message is returned.
bool forward_descriptors() override
Forwards descriptor information to outputs if possible.
constexpr TensorID NullTensorID
Constant TensorID specifying an equivalent of null tensor.
Tensor * input(size_t idx) const
Returns the tensor of a given input of the node.
Status validate() const override
Validate node.
TensorShape & set(size_t dimension, size_t value, bool apply_dim_correction=true, bool increase_dim_unit=true)
Accessor to set the value of one of the dimensions.
TensorDescriptor configure_output(size_t idx) const override
Calculates output configuration.