42 TensorShape get_output_shape(TensorShape
shape,
unsigned int axis)
60 const int upper_dims =
src.shape().total_size_upper(axis + 1);
61 const int lower_dims =
src.shape().total_size_lower(axis + 1);
62 const int lower_dims_sum = sum.
shape().total_size_lower(axis + 1);
64 for(
int du = 0; du < upper_dims; ++du)
66 const T *src_row_ptr =
src.data() + du * lower_dims;
67 T *dst_row_ptr =
dst.data() + du * lower_dims;
72 const int elems =
src.shape()[0];
73 const T normalization_value = sqrt(std::max(sum[du],
static_cast<T
>(
epsilon)));
74 std::transform(src_row_ptr, src_row_ptr + elems, dst_row_ptr, [normalization_value](T val)
76 return val / normalization_value;
83 for(
int ld = 0; ld < lower_dims; ++ld)
85 const T normalization_value = sqrt(std::max(sum[ld % lower_dims_sum + du * lower_dims_sum],
static_cast<T
>(
epsilon)));
86 dst_row_ptr[ld] = src_row_ptr[ld] / normalization_value;