24 #ifndef ARM_COMPUTE_TENSORSHAPE_H 25 #define ARM_COMPUTE_TENSORSHAPE_H 46 template <
typename... Ts>
51 if(_num_dimensions > 0)
53 std::fill(_id.begin() + _num_dimensions, _id.end(), 1);
57 apply_dimension_correction();
79 TensorShape &
set(
size_t dimension,
size_t value,
bool apply_dim_correction =
true,
bool increase_dim_unit =
true)
90 std::fill(_id.begin() + _num_dimensions, _id.end(), 1);
97 if(apply_dim_correction)
99 apply_dimension_correction();
116 std::copy(_id.begin() + n + 1, _id.end(), _id.begin() + n);
122 std::fill(_id.begin() + _num_dimensions, _id.end(), 1);
125 apply_dimension_correction();
138 std::fill(_id.begin() + _num_dimensions, _id.end(), 1);
149 _num_dimensions +=
step;
152 apply_dimension_correction();
174 return std::accumulate(_id.begin(), _id.end(), 1, std::multiplies<size_t>());
185 return std::accumulate(_id.begin() + dimension, _id.end(), 1, std::multiplies<size_t>());
197 return std::accumulate(_id.begin(), _id.begin() + dimension, 1, std::multiplies<size_t>());
210 template <
typename... Shapes>
215 auto broadcast = [&bc_shape](
const TensorShape & other)
221 else if(other.num_dimensions() != 0)
225 const size_t dim_min = std::min(bc_shape[d], other[d]);
226 const size_t dim_max = std::max(bc_shape[d], other[d]);
228 if((dim_min != 1) && (dim_min != dim_max))
234 bc_shape.
set(d, dim_max);
246 void apply_dimension_correction()
248 for(
int i = static_cast<int>(_num_dimensions) - 1; i > 0; --i)
void set(size_t dimension, T value, bool increase_dim_unit=true)
Accessor to set the value of one of the dimensions.
void shift_right(size_t step)
Shifts right the tensor shape increasing its dimensions.
void remove_dimension(size_t n)
Accessor to remove the dimension n from the tensor shape.
TensorShape collapsed_from(size_t start) const
Return a copy with collapsed dimensions starting from a given point.
TensorShape & operator=(const TensorShape &)=default
Allow instances of this class to be copied.
static TensorShape broadcast_shape(const Shapes &... shapes)
If shapes are broadcast compatible, return the broadcasted shape.
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true then an error message is printed and an exception thrown.
size_t total_size_upper(size_t dimension) const
Collapses given dimension and above.
SimpleTensor< T > copy(const SimpleTensor< T > &src, const TensorShape &output_shape)
SimpleTensor< T2 > accumulate(const SimpleTensor< T1 > &src, DataType output_data_type)
size_t total_size_lower(size_t dimension) const
Compute size of dimensions lower than the given one.
void collapse(const size_t n, const size_t first=0)
Collapse dimensions.
Copyright (c) 2017-2021 Arm Limited.
TensorShape(Ts... dims)
Constructor to initialize the tensor shape.
library fill(src, distribution, 0)
size_t total_size() const
Collapses all dimensions to a single linear total size.
Dimensions with dimensionality.
std::array< size_t, num_max_dimensions >::iterator begin()
Returns a read/write iterator that points to the first element in the dimension array.
void for_each(F &&)
Base case of for_each.
std::array< size_t, num_max_dimensions >::iterator end()
Returns a read/write iterator that points one past the last element in the dimension array...
unsigned int num_dimensions() const
Returns the effective dimensionality of the tensor.
~TensorShape()=default
Default destructor.
static constexpr size_t num_max_dimensions
Number of dimensions the tensor has.
void collapse(size_t n, size_t first=0)
Collapse the first n dimensions.
TensorShape & set(size_t dimension, size_t value, bool apply_dim_correction=true, bool increase_dim_unit=true)
Accessor to set the value of one of the dimensions.