45 const size_t w_size =
src.shape()[0];
46 const size_t h_size =
src.shape()[1];
47 const size_t c_size =
src.shape()[2];
48 const size_t n_size =
src.shape()[3];
50 #pragma omp parallel for collapse(2)
52 for(
size_t n_i = 0; n_i < n_size; ++n_i)
54 for(
size_t c_i = 0; c_i < c_size; ++c_i)
59 for(
size_t h_i = 0; h_i < h_size; ++h_i)
61 for(
size_t w_i = 0; w_i < w_size; ++w_i)
65 sum_sq_h_w += val * val;
69 const float mean_h_w = sum_h_w / (h_size * w_size);
71 const float var_h_w = sum_sq_h_w / (h_size * w_size) - mean_h_w * mean_h_w;
75 for(
size_t h_i = 0; h_i < h_size; ++h_i)
77 for(
size_t w_i = 0; w_i < w_size; ++w_i)
81 dst[index] = (
src[index] - mean_h_w) * gamma / std::sqrt(var_h_w +
epsilon) + beta;