24 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
35 void mean_stddev_normalization<float16_t, 8>(ITensor *
input, ITensor *output,
float epsilon,
const Window &window)
41 const int window_step_x = 8;
42 const auto window_start_x =
static_cast<int>(window.x().start());
43 const auto window_end_x =
static_cast<int>(window.x().end());
45 Iterator input_itr(
input, win);
46 Iterator output_itr(output, win);
50 [&](
const Coordinates &)
52 int x = window_start_x;
53 auto in_ptr =
reinterpret_cast<const float16_t *
>(input_itr.ptr());
54 auto out_ptr =
reinterpret_cast<float16_t *
>(output_itr.ptr());
56 float16x8_t sum_vec = vdupq_n_f16(
static_cast<float16_t
>(0.0f));
57 float32x4_t sum_sq_vec = vdupq_n_f32(0.0f);
59 for (; x <= (window_end_x - window_step_x); x += window_step_x)
61 float16x8_t data = vld1q_f16(in_ptr + x);
63 float32x4_t
dl = vcvt_f32_f16(vget_low_f16(data));
64 float32x4_t dh = vcvt_f32_f16(vget_high_f16(data));
65 sum_sq_vec = vaddq_f32(sum_sq_vec, vmulq_f32(
dl,
dl));
66 sum_sq_vec = vaddq_f32(sum_sq_vec, vmulq_f32(dh, dh));
69 float32x4_t sum_carry_res =
70 vpaddq_f32(vcvt_f32_f16(vget_high_f16(sum_vec)), vcvt_f32_f16(vget_low_f16(sum_vec)));
71 float sum = vaddvq_f32(sum_carry_res);
72 float sum_sq = vaddvq_f32(sum_sq_vec);
75 for (; x < window_end_x; ++x)
77 const float fdata =
static_cast<float>(*(in_ptr + x));
79 sum_sq += fdata * fdata;
82 float16_t mean =
static_cast<float16_t
>(sum /
input->info()->dimension(0));
83 float var = (sum_sq /
input->info()->dimension(0)) - (mean * mean);
84 float16_t stddev_inv =
static_cast<float16_t
>(1.f / sqrt(var +
epsilon));
86 float16x8_t mean_vec = vdupq_n_f16(mean);
87 float16x8_t stddev_inv_vec = vdupq_n_f16(stddev_inv);
89 for (x = window_start_x; x <= (window_end_x - window_step_x); x += window_step_x)
91 float16x8_t data = vld1q_f16(in_ptr + x);
94 vst1q_f16(out_ptr + x, res);
96 for (; x < window_end_x; ++x)
98 *(out_ptr + x) = (*(in_ptr + x) - mean) * stddev_inv;
101 input_itr, output_itr);
106 return mean_stddev_normalization<float16_t, 8>(
input, output,
epsilon, window);