34 #if defined(ARM_COMPUTE_ENABLE_SVE)
51 const auto window_start_x =
static_cast<int>(window.x().start());
52 const auto window_end_x =
static_cast<int>(window.x().end());
54 Window win_collapsed = window.collapse_if_possible(window,
Window::DimZ);
55 win_collapsed.set(
Window::DimX, Window::Dimension(0, 1, 1));
58 Iterator output(
dst, win_collapsed);
60 const auto input_mean =
reinterpret_cast<const float *
>(mean->ptr_to_element(Coordinates(0, 0)));
61 const auto input_var =
reinterpret_cast<const float *
>(var->ptr_to_element(Coordinates(0, 0)));
62 const auto input_gamma =
63 (gamma !=
nullptr) ?
reinterpret_cast<const float *
>(gamma->ptr_to_element(Coordinates(0, 0))) :
nullptr;
64 const auto input_beta =
65 (beta !=
nullptr) ?
reinterpret_cast<const float *
>(beta->ptr_to_element(Coordinates(0, 0))) :
nullptr;
67 const auto epsilon_vec = svdup_n_f32(
epsilon);
68 const auto const_1 = svdup_n_f32(1.f);
69 const auto const_0 = svdup_n_f32(0.f);
70 const auto va = svdup_n_f32(
act_info.a());
71 const auto vb = svdup_n_f32(
act_info.b());
74 [&](
const Coordinates &)
76 const auto input_ptr =
reinterpret_cast<const float *
>(
input.ptr());
77 const auto output_ptr =
reinterpret_cast<float *
>(output.ptr());
80 int x = window_start_x;
81 svbool_t pg = svwhilelt_b32(x, window_end_x);
85 const auto mean_vec = svld1_f32(pg, input_mean + x);
86 const auto var_vec = svld1_f32(pg, input_var + x);
87 const auto gamma_vec = (input_gamma !=
nullptr) ? svld1_f32(pg, input_gamma + x) : const_1;
88 const auto beta_vec = (input_beta !=
nullptr) ? svld1_f32(pg, input_beta + x) : const_0;
91 const auto tmp = svadd_f32_z(pg, var_vec, epsilon_vec);
92 auto denominator = svrsqrte_f32(tmp);
94 svmul_f32_z(pg, svrsqrts_f32(svmul_f32_z(pg, tmp, denominator), denominator), denominator);
96 svmul_f32_z(pg, svrsqrts_f32(svmul_f32_z(pg, tmp, denominator), denominator), denominator);
99 const auto numerator = svsub_f32_z(pg, svld1_f32(pg, input_ptr + x), mean_vec);
100 const auto x_bar = svmul_f32_z(pg, numerator, denominator);
101 auto res = svmla_f32_z(pg, beta_vec, x_bar, gamma_vec);
106 if (
act_info.activation() == ActivationLayerInfo::ActivationFunction::RELU)
108 res = svmax_f32_z(pg, const_0, res);
110 else if (
act_info.activation() == ActivationLayerInfo::ActivationFunction::BOUNDED_RELU)
112 res = svmin_f32_z(pg, va, svmax_f32_z(pg, const_0, res));
114 else if (
act_info.activation() == ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU)
116 res = svmin_f32_z(pg, va, svmax_f32_z(pg, vb, res));
121 svst1_f32(pg, output_ptr + x, res);
124 pg = svwhilelt_b32(x, window_end_x);
125 }
while (svptest_any(svptrue_b32(), pg));