42 TensorShape get_output_shape(TensorShape
shape,
unsigned int axis)
60 const int upper_dims = src.shape().total_size_upper(axis + 1);
61 const int lower_dims = src.shape().total_size_lower(axis + 1);
62 const int lower_dims_sum = sum.
shape().total_size_lower(axis + 1);
64 for(
int du = 0; du < upper_dims; ++du)
66 const T *src_row_ptr = src.data() + du * lower_dims;
67 T *dst_row_ptr =
dst.data() + du * lower_dims;
72 const int elems = src.shape()[0];
73 const T normalization_value = sqrt(std::max(sum[du], static_cast<T>(epsilon)));
74 std::transform(src_row_ptr, src_row_ptr + elems, dst_row_ptr, [normalization_value](T val)
76 return val / normalization_value;
83 for(
int ld = 0; ld < lower_dims; ++ld)
85 const T normalization_value = sqrt(std::max(sum[ld % lower_dims_sum + du * lower_dims_sum], static_cast<T>(epsilon)));
86 dst_row_ptr[ld] = src_row_ptr[ld] / normalization_value;
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
DATA_TYPE sum(__global const DATA_TYPE *input)
Calculate sum of a vector.
DataType data_type() const override
Data type of the tensor.
TensorShape shape() const override
Shape of the tensor.
SimpleTensor< float > src
Copyright (c) 2017-2021 Arm Limited.
SimpleTensor< T > l2_normalize(const SimpleTensor< T > &src, unsigned int axis, float epsilon)
Simple tensor object that stores elements in a consecutive chunk of memory.
TensorShape & set(size_t dimension, size_t value, bool apply_dim_correction=true, bool increase_dim_unit=true)
Accessor to set the value of one of the dimensions.