24 #ifndef ACL_SRC_CPU_KERNELS_FUSE_BATCH_NORMALIZATION_GENERIC_IMPL_H
25 #define ACL_SRC_CPU_KERNELS_FUSE_BATCH_NORMALIZATION_GENERIC_IMPL_H
// NOTE(review): only part of this function template is visible in this excerpt
// (the signature, braces and several statements are missing). The comments
// below describe the visible statements only.
// Template parameters: T is the element type, fused_activation selects whether
// activation_functor is applied to each result, F is presumably the type of
// activation_functor - TODO confirm against the full signature.
35 template <
typename T,
bool fused_activation,
typename F>
// Elements of T consumed per iteration of the first x loop: 16 / sizeof(T)
// (presumably one 16-byte vector register - TODO confirm).
49 const int window_step_x = 16 /
sizeof(T);
// Bounds of the x dimension, taken from the execution window.
50 const auto window_start_x =
static_cast<int>(window.
x().
start());
51 const auto window_end_x =
static_cast<int>(window.
x().
end());
// Local copy of the window; how it is adjusted afterwards is not visible here.
53 Window win_to_use = window;
// The initialisers of input_gamma / input_beta are not visible; both are later
// compared against nullptr, so they are pointers that may be null.
67 const auto input_gamma =
69 const auto input_beta =
// Scalar defaults: gamma = 1 and beta = 0 are kept when the corresponding
// pointer is null, i.e. no scaling and no offset.
72 T mean =
static_cast<T
>(0);
73 T var =
static_cast<T
>(0);
74 T gamma =
static_cast<T
>(1);
75 T beta =
static_cast<T
>(0);
// NOTE(review): the statement that assigns the real value of denominator is not
// visible in this excerpt - presumably derived from var and an epsilon; verify.
76 T denominator =
static_cast<T
>(0);
// Typed views of the current input / output positions.
88 const auto input_ptr =
reinterpret_cast<const T *
>(
input.ptr());
89 const auto output_ptr =
reinterpret_cast<T *
>(output.
ptr());
// Statistics are looked up with the z coordinate of the current position.
93 mean = input_mean[
id.z()];
94 var = input_var[
id.z()];
// gamma is optional: overwrite the default only when provided.
97 if (input_gamma !=
nullptr)
99 gamma = input_gamma[
id.z()];
// beta is optional: overwrite the default only when provided.
102 if (input_beta !=
nullptr)
104 beta = input_beta[
id.z()];
// First loop: advances window_step_x elements at a time while a full step
// still fits before window_end_x.
115 int x = window_start_x;
116 for (; x <= (window_end_x - window_step_x); x += window_step_x)
// Normalised value: numerator multiplied by denominator_vec (both defined on
// lines not visible here).
120 const auto x_bar =
wrapper::vmul(numerator, denominator_vec);
// Optional activation applied to res.
124 if (fused_activation)
126 activation_functor(res);
// Second loop: handles the remaining elements one at a time.
134 for (; x < window_end_x; ++x)
// Scalar path: res = beta + ((input - mean) * denominator) * gamma.
136 const T numerator = input_ptr[x] - mean;
137 const T x_bar = numerator * denominator;
138 T res = beta + x_bar * gamma;
// Optional activation applied to the scalar result.
141 if (fused_activation)
143 activation_functor(res);
// Store the result at the same x offset in the output.
147 *(output_ptr + x) = res;
// NOTE(review): only part of this function template is visible in this excerpt
// (the signature, braces and the vectorised loop body are missing). The visible
// statements fold batch-normalization parameters (mean, var, gamma, beta) into
// convolution weights and bias. Comments below describe visible code only.
153 template <
typename T>
165 using ScalarType = T;
// Weights are updated in place when no separate destination is given or when
// the destination is the source tensor itself.
169 const bool run_in_place_weights = (fused_weights ==
nullptr) || (fused_weights == conv_weights);
// Bias is treated as in place when no destination is given, or when a source
// bias exists and the destination is that same tensor.
170 const bool run_in_place_bias = (fused_bias ==
nullptr) || (conv_bias !=
nullptr && fused_bias == conv_bias);
// Step of the first x loop; `size` is defined on a line not visible here.
176 const int window_step_x = size;
// Bounds of the x dimension, taken from the execution window.
177 const auto window_start_x =
static_cast<int>(window.
x().
start());
178 const auto window_end_x =
static_cast<int>(window.
x().
end());
// Input iterator over the source weights; the output iterator targets the same
// tensor when running in place, otherwise fused_weights.
180 Iterator conv_w_in(conv_weights, win);
181 Iterator conv_w_out(run_in_place_weights ? conv_weights : fused_weights, win);
// The initialiser of conv_bias_in and the rest of the following expression are
// not visible; conv_bias_in is later compared against nullptr.
183 const auto conv_bias_in =
186 (run_in_place_bias ? conv_bias_in
// gamma and beta are optional; the expressions selected when they are present
// are not visible in this excerpt.
191 const auto input_gamma = (bn_gamma !=
nullptr)
194 const auto input_beta = (bn_beta !=
nullptr)
// Scalar defaults: gamma = 1, beta = 0 and bias = 0 are kept when the
// corresponding optional input is absent.
205 auto mean = ScalarType(0.0);
206 auto var = ScalarType(0.0);
207 auto gamma = ScalarType(1.0);
208 auto beta = ScalarType(0.0);
209 auto conv_bias_in_scalar = ScalarType(0.0);
// Per-position parameters are indexed with id[3].
214 var = input_var[
id[3]];
215 if (input_gamma !=
nullptr)
217 gamma = input_gamma[
id[3]];
// Guard on the first position of dimensions 0..2. NOTE(review): the braces are
// not visible, so the exact extent of this guard cannot be confirmed here;
// presumably it makes the bias update below run once per id[3] - verify.
220 if ((
id[0] == 0) && (
id[1] == 0) && (
id[2] == 0))
222 if (input_beta !=
nullptr)
224 beta = input_beta[
id[3]];
229 mean = input_mean[
id[3]];
// Use the existing bias when there is one, otherwise the default of 0.
232 if (conv_bias_in !=
nullptr)
234 conv_bias_in_scalar = conv_bias_in[
id[3]];
// Fused bias: ((bias - mean) / sqrt(var + epsilon)) * gamma + beta.
236 auto conv_bias_tmp_scalar = (conv_bias_in_scalar - mean) / std::sqrt(var + ScalarType(
epsilon));
237 conv_bias_out[
id[3]] = (conv_bias_tmp_scalar * gamma) + beta;
// Typed views of the current weight positions.
240 int x = window_start_x;
241 auto conv_w_in_ptr =
reinterpret_cast<const ScalarType *
>(conv_w_in.
ptr());
242 auto conv_w_out_ptr =
reinterpret_cast<ScalarType *
>(conv_w_out.
ptr());
// First loop: advances window_step_x elements at a time (body not visible).
247 for (; x <= (window_end_x - window_step_x); x += window_step_x)
// Second loop: remaining elements, one at a time.
258 for (; x < window_end_x; ++x)
// Fused weight: w / sqrt(var + epsilon) * gamma.
// NOTE(review): sqrt(var + epsilon) does not depend on x and is recomputed for
// every element here; it looks hoistable, to be confirmed with the full body.
260 *(conv_w_out_ptr + x) = *(conv_w_in_ptr + x) / std::sqrt(var + ScalarType(
epsilon)) * gamma;
// Trailing arguments of a call whose beginning is not visible in this excerpt.
263 conv_w_in, conv_w_out);
// NOTE(review): only part of this function template is visible in this excerpt
// (the signature, braces and the vectorised loop body are missing). The visible
// statements fold batch-normalization parameters (mean, var, gamma, beta) into
// depthwise-convolution (dwc_*) weights and bias. Comments below describe
// visible code only.
265 template <
typename T>
277 using ScalarType = T;
// Weights are updated in place when no separate destination is given or when
// the destination is the source tensor itself.
281 const bool run_in_place_weights = (fused_weights ==
nullptr) || (fused_weights == dwc_weights);
// Bias is treated as in place when no destination is given, or when a source
// bias exists and the destination is that same tensor.
282 const bool run_in_place_bias = (fused_bias ==
nullptr) || (dwc_bias !=
nullptr && fused_bias == dwc_bias);
// Step of the first x loop; `size` is defined on a line not visible here.
288 const int window_step_x = size;
// Bounds of the x dimension, taken from the execution window.
289 const auto window_start_x =
static_cast<int>(window.
x().
start());
290 const auto window_end_x =
static_cast<int>(window.
x().
end());
// Input iterator over the source weights; the output iterator targets the same
// tensor when running in place, otherwise fused_weights.
292 Iterator dwc_w_in(dwc_weights, win);
293 Iterator dwc_w_out(run_in_place_weights ? dwc_weights : fused_weights, win);
// The initialiser of dwc_bias_in and the rest of the following expression are
// not visible; dwc_bias_in is later compared against nullptr.
295 const auto dwc_bias_in =
298 (run_in_place_bias ? dwc_bias_in
// gamma and beta are optional; the expressions selected when they are present
// are not visible in this excerpt.
303 const auto input_gamma = (bn_gamma !=
nullptr)
306 const auto input_beta = (bn_beta !=
nullptr)
// Scalar defaults: gamma = 1, beta = 0 and bias = 0 are kept when the
// corresponding optional input is absent.
317 auto mean = ScalarType(0.0);
318 auto var = ScalarType(0.0);
319 auto gamma = ScalarType(1.0);
320 auto beta = ScalarType(0.0);
321 auto dwc_bias_in_scalar = ScalarType(0.0);
// Per-position parameters are indexed with id[2] in this variant.
326 var = input_var[
id[2]];
327 if (input_gamma !=
nullptr)
329 gamma = input_gamma[
id[2]];
// NOTE(review): any condition guarding the bias update below is not visible in
// this excerpt - check the full source before relying on how often it runs.
334 mean = input_mean[
id[2]];
338 if (input_beta !=
nullptr)
340 beta = input_beta[
id[2]];
// Use the existing bias when there is one, otherwise the default of 0.
344 if (dwc_bias_in !=
nullptr)
346 dwc_bias_in_scalar = dwc_bias_in[
id[2]];
// Fused bias: ((bias - mean) / sqrt(var + epsilon)) * gamma + beta.
349 auto dwc_bias_tmp_scalar = (dwc_bias_in_scalar - mean) / std::sqrt(var + ScalarType(
epsilon));
350 dwc_bias_out[
id[2]] = (dwc_bias_tmp_scalar * gamma) + beta;
// Typed views of the current weight positions.
353 int x = window_start_x;
354 auto dwc_w_in_ptr =
reinterpret_cast<const ScalarType *
>(dwc_w_in.
ptr());
355 auto dwc_w_out_ptr =
reinterpret_cast<ScalarType *
>(dwc_w_out.
ptr());
// First loop: advances window_step_x elements at a time (body not visible).
360 for (; x <= (window_end_x - window_step_x); x += window_step_x)
// Second loop: remaining elements, one at a time.
371 for (; x < window_end_x; ++x)
// Fused weight: w / sqrt(var + epsilon) * gamma.
// NOTE(review): sqrt(var + epsilon) does not depend on x and is recomputed for
// every element here; it looks hoistable, to be confirmed with the full body.
373 *(dwc_w_out_ptr + x) = *(dwc_w_in_ptr + x) / std::sqrt(var + ScalarType(
epsilon)) * gamma;
// Trailing arguments of a call whose beginning is not visible in this excerpt.
376 dwc_w_in, dwc_w_out);
381 #endif // ACL_SRC_CPU_KERNELS_FUSE_BATCH_NORMALIZATION_GENERIC_IMPL_H