24 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
// Accumulation helper, generic over the input vector type (InputType) and the
// accumulator vector type (AccType).
//
// result        - accumulator, passed by mutable reference.
// result_square - second accumulator, passed by mutable reference.
// inputs        - input vector, read-only.
//
// NOTE(review): going by the parameter names and the call sites, this
// presumably adds `inputs` to `result` and `inputs * inputs` to
// `result_square` - confirm against the body.
36 template <
typename InputType,
typename AccType>
37 void vector_float_sum_fp16(AccType &result, AccType &result_square,
const InputType &inputs)
// Normalisation helper, generic over the input vector type (InputType) and the
// type of the three parameter vectors (AccType). Returns a vector of the same
// type as `inputs`.
//
// inputs     - values to normalise.
// vec_mean   - mean, as a vector.
// vec_multip - multiplier, as a vector.
// vec_beta   - beta offset, as a vector.
//
// NOTE(review): presumably evaluates (inputs - vec_mean) * vec_multip + vec_beta,
// mirroring the scalar tail loop of instance_normalization_nchw_fp16 - confirm
// against the body.
43 template <
typename InputType,
typename AccType>
44 InputType vector_float_norm_fp16(
const InputType &inputs,
45 const AccType &vec_mean,
46 const AccType &vec_multip,
47 const AccType &vec_beta)
// Overload for a float16x8_t input accumulated into float32x4_t accumulators.
// The input is split into its low and high halves; each half is converted to
// float with wrapper::vcvt<float> and forwarded to vector_float_sum_fp16 with
// the same two accumulators.
53 inline void vector_float_sum_fp16(float32x4_t &result, float32x4_t &result_square,
const float16x8_t &inputs)
// Low half of the input vector.
55 vector_float_sum_fp16(result, result_square, wrapper::vcvt<float>(
wrapper::vgetlow(inputs)));
// High half of the input vector.
56 vector_float_sum_fp16(result, result_square, wrapper::vcvt<float>(
wrapper::vgethigh(inputs)));
// Overload for a float16x8_t input with float32x4_t mean / multiplier / beta.
// Two partial inputs (`input_low`, `input_high`) are each normalised through
// vector_float_norm_fp16 and the results converted back to float16_t with
// wrapper::vcvt<float16_t>. Returns a float16x8_t.
59 inline float16x8_t vector_float_norm_fp16(
const float16x8_t &inputs,
60 const float32x4_t &vec_mean,
61 const float32x4_t &vec_multip,
62 const float32x4_t &vec_beta)
// Normalise the low part, then narrow the result to float16_t.
66 const auto result_low = wrapper::vcvt<float16_t>(vector_float_norm_fp16(input_low, vec_mean, vec_multip, vec_beta));
// Same for the high part.
67 const auto result_high =
68 wrapper::vcvt<float16_t>(vector_float_norm_fp16(input_high, vec_mean, vec_multip, vec_beta));
// Instance normalisation of a float16_t tensor, processed one plane at a time.
//
// For every plane selected by the outer callback (fixed id[2] / id[3]) the
// function makes two passes over the data:
//   1. accumulate the sum and the sum of squares of the plane;
//   2. write (x - mean) * gamma / sqrt(var + epsilon) + beta to the output.
//
// AccType is the type used for the accumulators and for the per-element
// arithmetic of the scalar tail loops.
//
// input   - source tensor, read as float16_t.
// output  - destination tensor, written as float16_t.
// gamma   - scale applied to the normalised value.
// beta    - offset added after scaling.
// epsilon - added to the variance before the square root.
// window  - execution window; its x range drives the inner loops.
74 template <
typename AccType>
75 void instance_normalization_nchw_fp16(
76 const ITensor *
input, ITensor *output,
float gamma,
float beta,
float epsilon,
const Window &window)
// Tag type handed to wrapper::vdup_n below (128-bit vector selection).
79 using ExactTagType =
typename wrapper::traits::neon_bitvector_tag_t<float16_t, wrapper::traits::BitWidth::W128>;
// Number of float16_t elements that fit in 16 bytes; step of the vector loops.
86 constexpr
int window_step_x = 16 /
sizeof(float16_t);
// Element count of one plane, used as the divisor for mean and variance.
// NOTE(review): dimension(0) is read from `input` but dimension(1) from
// `output`; this is only correct if both tensors share the same shape -
// confirm, or read both from the same tensor.
87 const unsigned int elements_plane =
input->info()->dimension(0) * output->info()->dimension(1);
89 Iterator input_it(
input, win);
// Outer callback: `id` identifies the plane being processed.
92 [&](
const Coordinates &
id)
// Start from the caller's window and pin it to this plane's id[2] / id[3].
94 Window win_plane = window;
96 win_plane.set(
Window::DimZ, Window::Dimension(
id[2],
id[2] + 1, 1));
97 win_plane.set(3, Window::Dimension(
id[3],
id[3] + 1, 1));
// Iterators over the input and output restricted to the current plane.
99 Iterator input_plane_it(
input, win_plane);
100 Iterator output_plane_it(output, win_plane);
// Scalar accumulators for the whole plane: sum and sum of squares.
102 auto sum_h_w =
static_cast<AccType
>(0.f);
103 auto sum_squares_h_w =
static_cast<AccType
>(0.f);
// ---- Pass 1: accumulate statistics ----
107 [&](
const Coordinates &)
109 const auto input_ptr =
reinterpret_cast<const float16_t *
>(input_plane_it.ptr());
// Vector accumulators, zero-initialised for each invocation of this callback.
111 auto vec_sum_h_w =
wrapper::vdup_n(
static_cast<AccType
>(0.f), ExactTagType{});
112 auto vec_sum_squares_h_w =
wrapper::vdup_n(
static_cast<AccType
>(0.f), ExactTagType{});
// Vectorised part: whole window_step_x-wide steps across the window's x range.
115 int x = window.x().start();
116 for (; x <= (window.x().
end() - window_step_x); x += window_step_x)
119 vector_float_sum_fp16(vec_sum_h_w, vec_sum_squares_h_w, vec_input_val);
// Reduction of the vector accumulator using pairwise adds (vpadd).
123 auto vec2_sum_squares_h_w =
127 vec2_sum_squares_h_w =
wrapper::vpadd(vec2_sum_squares_h_w, vec2_sum_squares_h_w);
// Scalar tail: elements left over after the last full vector step.
133 for (; x < window.x().
end(); ++x)
135 const auto value =
static_cast<AccType
>(*(input_ptr + x));
137 sum_squares_h_w += value * value;
140 input_plane_it, output_plane_it);
// mean = E[x]; var = E[x^2] - mean^2.
// NOTE(review): with this formula rounding can make var_h_w slightly negative
// (more likely when AccType is float16_t); if epsilon does not compensate, the
// sqrt below yields NaN - confirm this is acceptable.
142 const auto mean_h_w = sum_h_w / elements_plane;
143 const auto var_h_w = sum_squares_h_w / elements_plane - mean_h_w * mean_h_w;
// Per-plane scale factor, computed once and reused for every element.
145 const auto multip_h_w = gamma / std::sqrt(var_h_w +
epsilon);
// Broadcast mean, multiplier and beta into vectors for the vectorised pass 2.
146 const auto vec_mean_h_w =
wrapper::vdup_n(
static_cast<AccType
>(mean_h_w), ExactTagType{});
147 const auto vec_multip_h_w =
wrapper::vdup_n(
static_cast<AccType
>(multip_h_w), ExactTagType{});
148 const auto vec_beta =
wrapper::vdup_n(
static_cast<AccType
>(beta), ExactTagType{});
// ---- Pass 2: normalise and store ----
152 [&](
const Coordinates &)
154 auto input_ptr =
reinterpret_cast<const float16_t *
>(input_plane_it.ptr());
155 auto output_ptr =
reinterpret_cast<float16_t *
>(output_plane_it.ptr());
// Vectorised part: normalise window_step_x elements per iteration.
158 int x = window.x().start();
159 for (; x <= (window.x().end() - window_step_x); x += window_step_x)
162 const auto normalized_vec =
163 vector_float_norm_fp16(vec_val, vec_mean_h_w, vec_multip_h_w, vec_beta);
// Scalar tail: same formula applied element by element, narrowed to float16_t.
168 for (; x < window.x().
end(); ++x)
170 const auto val =
static_cast<AccType
>(*(input_ptr + x));
171 *(output_ptr + x) =
static_cast<float16_t
>((val - mean_h_w) * multip_h_w + beta);
174 input_plane_it, output_plane_it);
185 bool use_mixed_precision,
186 const Window &window)
188 if (use_mixed_precision)
190 return instance_normalization_nchw_fp16<float>(
input, output, gamma, beta,
epsilon, window);
194 return instance_normalization_nchw_fp16<float16_t>(
input, output, gamma, beta,
epsilon, window);