24.02.1
|
Go to the documentation of this file.
24 #ifndef SRC_CORE_NEON_KERNELS_FUSE_BATCH_NORMALIZATION_IMPL_H
25 #define SRC_CORE_NEON_KERNELS_FUSE_BATCH_NORMALIZATION_IMPL_H
51 const bool run_in_place_weights = (fused_weights ==
nullptr) || (fused_weights == dwc_weights);
52 const bool run_in_place_bias = (fused_bias ==
nullptr) || (dwc_bias !=
nullptr && fused_bias == dwc_bias);
58 const int window_step_x = size;
59 const auto window_start_x =
static_cast<int>(window.
x().
start());
60 const auto window_end_x =
static_cast<int>(window.
x().
end());
63 Iterator dwc_w_out(run_in_place_weights ? dwc_weights : fused_weights, win);
65 const auto dwc_bias_in =
68 (run_in_place_bias ? dwc_bias_in
73 const auto input_gamma = (bn_gamma !=
nullptr)
76 const auto input_beta = (bn_beta !=
nullptr)
88 auto gamma = ScalarType(1.0);
89 auto beta = ScalarType(0.0);
90 auto dwc_bias_in_scalar = ScalarType(0);
96 int x = window_start_x;
97 for (; x <= (window_end_x - window_step_x); x += window_step_x)
100 if (input_gamma !=
nullptr)
105 if ((
id[2] == 0) && (
id[1] == 0))
110 if (input_beta !=
nullptr)
115 if (dwc_bias_in !=
nullptr)
126 auto dwc_w_in_ptr =
reinterpret_cast<const ScalarType *
>(dwc_w_in.
ptr());
127 auto dwc_w_out_ptr =
reinterpret_cast<ScalarType *
>(dwc_w_out.
ptr());
139 for (; x < window_end_x; ++x)
141 auto var = input_var[x];
142 if (input_gamma !=
nullptr)
144 gamma = input_gamma[x];
147 if (
id[2] == 0 &&
id[1] == 0)
149 auto mean = input_mean[x];
150 if (input_beta !=
nullptr)
152 beta = input_beta[x];
154 if (dwc_bias_in !=
nullptr)
156 dwc_bias_in_scalar = dwc_bias_in[x];
159 auto dwc_bias_tmp_scalar = (dwc_bias_in_scalar - mean) / std::sqrt(var + ScalarType(
epsilon));
160 dwc_bias_out[x] = (dwc_bias_tmp_scalar * gamma) + beta;
163 const auto dwc_w_in_ptr =
reinterpret_cast<const ScalarType *
>(dwc_w_in.
ptr());
164 auto dwc_w_out_ptr =
reinterpret_cast<ScalarType *
>(dwc_w_out.
ptr());
166 *(dwc_w_out_ptr + x) = *(dwc_w_in_ptr + x) / std::sqrt(var + ScalarType(
epsilon)) * gamma;
169 dwc_w_in, dwc_w_out);
173 #endif //SRC_CORE_NEON_KERNELS_FUSE_BATCH_NORMALIZATION_IMPL_H
uint8x8_t vadd(const uint8x8_t &a, const uint8x8_t &b)
constexpr int start() const
Return the start of the dimension.
uint8x8_t vsub(const uint8x8_t &a, const uint8x8_t &b)
float32x2_t vinvsqrt(const float32x2_t &a)
virtual size_t element_size() const =0
Element size in bytes calculated as data_size() * num_channels()
static constexpr size_t DimX
Alias for dimension 0 also known as X dimension.
Interface for CPU tensor.
uint8x16_t vloadq(const uint8_t *ptr)
Includes all wrapper headers at once.
virtual ITensorInfo * info() const =0
Interface to be implemented by the child class to return the tensor's metadata.
constexpr uint8_t * ptr() const
Return a pointer to the current pixel.
uint8x8_t vmul(const uint8x8_t &a, const uint8x8_t &b)
void execute_window_loop(const Window &w, L &&lambda_function, Ts &&...iterators)
Iterate through the passed window, automatically adjusting the iterators and calling the lambda_function for each element.
Iterator updated by execute_window_loop for each window element.
Describe one of the image's dimensions with a start, end and step.
void set(size_t dimension, const Dimension &dim)
Set the values of a given dimension.
void vstore(uint8_t *ptr, uint8x8_t val)
Describe a multidimensional execution window.
typename neon_bitvector< T, BW >::tag_type neon_bitvector_tag_t
Helper type template to get the tag type of a neon vector.
Copyright (c) 2017-2024 Arm Limited.
uint8_t * ptr_to_element(const Coordinates &id) const
Return a pointer to the element at the passed coordinates.
void fused_batch_normalization_dwc_nhwc(const ITensor *dwc_weights, const ITensor *dwc_bias, ITensor *fused_weights, ITensor *fused_bias, const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
constexpr int end() const
Return the end of the dimension.
constexpr const Dimension & x() const
Alias to access the first dimension of the window.
uint8x8_t vdup_n(uint8_t value, traits::vector_64_tag)