36 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
// Signature of a type-specialised fused batch-normalization kernel;
// fused_map below stores pointers of this type.
// NOTE(review): "Nomalization" is a typo for "Normalization" — kept as-is
// because every user of this alias would need renaming in the same change.
// NOTE(review): this extracted view is missing the middle parameter lines;
// judging from the call at the bottom of the file the full list is
// (src, dst, mean, var, beta, gamma, epsilon, act_info, window) — confirm.
41 using BatchNomalizationPtr = void (*)(ITensor *
src,
49 const Window &window);
// fp16 NEON batch-normalization kernel, instantiated per activation functor
// (see fused_map: detail::relu / brelu / lubrelu, and detail::dummy for the
// non-fused path). Computes res = beta + gamma * (x - mean) / sqrt(var + eps)
// and applies the fused activation in place via activation_functor.
// NOTE(review): several interior lines (full parameter list, Iterator input,
// the vector loads/compute/stores and the execute_window_loop call) are
// missing from this extracted view — comments describe only what is visible.
52 void batch_normalization(ITensor *
src,
// 128-bit NEON vector tag for float16_t, used to build vector constants.
63 using ExactTagType =
typename wrapper::traits::neon_bitvector_tag_t<float16_t, wrapper::traits::BitWidth::W128>;
// Vector loop processes 8 fp16 lanes per iteration.
65 const int window_step_x = 8;
66 const auto window_start_x =
static_cast<int>(window.x().start());
67 const auto window_end_x =
static_cast<int>(window.x().end());
// Collapse dimensions up to Z where possible, then pin X to a single step:
// the loops below walk the X extent manually.
69 Window win_collapsed = window.collapse_if_possible(window,
Window::DimZ);
70 win_collapsed.set(
Window::DimX, Window::Dimension(0, 1, 1));
73 Iterator output(
dst, win_collapsed);
// Per-channel statistics/parameters read from element (0, 0) along X.
75 const auto input_mean =
reinterpret_cast<const float16_t *
>(mean->ptr_to_element(Coordinates(0, 0)));
76 const auto input_var =
reinterpret_cast<const float16_t *
>(var->ptr_to_element(Coordinates(0, 0)));
// gamma and beta are optional: a nullptr tensor selects the defaults
// (1.f and 0.f respectively) in the loops below.
77 const auto input_gamma =
78 (gamma !=
nullptr) ?
reinterpret_cast<const float16_t *
>(gamma->ptr_to_element(Coordinates(0, 0))) :
nullptr;
79 const auto input_beta =
80 (beta !=
nullptr) ?
reinterpret_cast<const float16_t *
>(beta->ptr_to_element(Coordinates(0, 0))) :
nullptr;
// Body executed for each collapsed-window position.
87 [&](
const Coordinates &)
89 const auto input_ptr =
reinterpret_cast<const float16_t *
>(
input.ptr());
90 const auto output_ptr =
reinterpret_cast<float16_t *
>(output.ptr());
// Vectorized main loop: 8 elements at a time until fewer than 8 remain.
93 int x = window_start_x;
94 for (; x <= (window_end_x - window_step_x); x += window_step_x)
99 const auto gamma_vec = (input_gamma !=
nullptr)
102 const auto beta_vec = (input_beta !=
nullptr)
// Fused activation applied to the vector result in place.
117 activation_functor(res);
// Scalar tail loop for the leftover (< 8) elements.
125 for (; x < window_end_x; ++x)
// Missing gamma/beta fall back to the neutral values 1.f / 0.f.
128 const float16_t gamma = (input_gamma !=
nullptr) ? input_gamma[x] : 1.f;
129 const float16_t beta = (input_beta !=
nullptr) ? input_beta[x] : 0.f;
// x_bar = (x - mean) / sqrt(var + epsilon); res = beta + gamma * x_bar.
131 const float16_t denominator = sqrt(input_var[x] +
epsilon);
132 const float16_t numerator = input_ptr[x] - input_mean[x];
133 const float16_t x_bar = numerator / denominator;
134 float16_t res = beta + x_bar * gamma;
// Fused activation applied to the scalar result in place.
139 activation_functor(res);
// Write the normalized (and possibly activated) value.
143 *
reinterpret_cast<float16_t *
>(output_ptr + x) = res;
150 static std::map<ActivationLayerInfo::ActivationFunction, BatchNomalizationPtr> fused_map = {
151 {ActivationLayerInfo::ActivationFunction::RELU, &batch_normalization<detail::relu<float16_t, 8>>},
152 {ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &batch_normalization<detail::brelu<float16_t, 8>>},
153 {ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &batch_normalization<detail::lubrelu<float16_t, 8>>}};
// Tail of the public entry point — its opening signature and the fused
// branch lie outside this extracted view. The visible fall-back path
// dispatches to the kernel instantiated with detail::dummy (presumably a
// no-op activation — confirm against detail's definition), i.e. plain
// batch normalization without a fused activation.
162 const ITensor *gamma,
165 const Window &window)
173 batch_normalization<detail::dummy<float16_t, 8>>(
src,
dst, mean, var, beta, gamma,
epsilon,
act_info, window);