38 const bool run_in_place_weights = (fused_weights ==
nullptr) || (fused_weights == conv_weights);
39 const bool run_in_place_bias = (fused_bias ==
nullptr) || (conv_bias !=
nullptr && fused_bias == conv_bias);
45 const int window_step_x = size;
46 const auto window_start_x =
static_cast<int>(window.
x().
start());
47 const auto window_end_x =
static_cast<int>(window.
x().
end());
49 Iterator conv_w_in(conv_weights, win);
50 Iterator conv_w_out(run_in_place_weights ? conv_weights : fused_weights, win);
52 const auto conv_bias_in = (conv_bias !=
nullptr ?
reinterpret_cast<ScalarType *
>(conv_bias->
ptr_to_element(
Coordinates(0, 0))) :
nullptr);
53 auto conv_bias_out = (run_in_place_bias ? conv_bias_in :
reinterpret_cast<ScalarType *
>(fused_bias->
ptr_to_element(
Coordinates(0, 0))));
57 const auto input_gamma = (bn_gamma !=
nullptr) ? reinterpret_cast<const ScalarType *>(bn_gamma->
ptr_to_element(
Coordinates(0, 0))) :
nullptr;
58 const auto input_beta = (bn_beta !=
nullptr) ? reinterpret_cast<const ScalarType *>(bn_beta->
ptr_to_element(
Coordinates(0, 0))) :
nullptr;
65 const auto epsilon_vec =
wrapper::vdup_n(ScalarType(epsilon), ExactTagType{});
67 auto mean = ScalarType(0.0);
68 auto var = ScalarType(0.0);
69 auto gamma = ScalarType(1.0);
70 auto beta = ScalarType(0.0);
71 auto conv_bias_in_scalar = ScalarType(0.0);
74 var = input_var[
id[3]];
75 if(input_gamma !=
nullptr)
77 gamma = input_gamma[
id[3]];
80 if((
id[0] == 0) && (
id[1] == 0) && (
id[2] == 0))
82 if(input_beta !=
nullptr)
84 beta = input_beta[
id[3]];
89 mean = input_mean[
id[3]];
92 if(conv_bias_in !=
nullptr)
94 conv_bias_in_scalar = conv_bias_in[
id[3]];
96 auto conv_bias_tmp_scalar = (conv_bias_in_scalar - mean) / std::sqrt(var + ScalarType(epsilon));
97 conv_bias_out[
id[3]] = (conv_bias_tmp_scalar * gamma) + beta;
100 int x = window_start_x;
101 auto conv_w_in_ptr =
reinterpret_cast<const ScalarType *
>(conv_w_in.
ptr());
102 auto conv_w_out_ptr =
reinterpret_cast<ScalarType *
>(conv_w_out.
ptr());
107 for(; x <= (window_end_x - window_step_x); x += window_step_x)
118 for(; x < window_end_x; ++x)
120 *(conv_w_out_ptr + x) = *(conv_w_in_ptr + x) / std::sqrt(var + ScalarType(epsilon)) * gamma;
123 conv_w_in, conv_w_out);
129 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) 130 template void fused_batch_normalization_conv<float16_t>(
const ITensor *conv_weights,
const ITensor *conv_bias,
ITensor *fused_weights,
ITensor *fused_bias,
uint8_t * ptr_to_element(const Coordinates &id) const
Return a pointer to the element at the passed coordinates.
float32x2_t vinvsqrt(const float32x2_t &a)
uint8x16_t vloadq(const uint8_t *ptr)
uint8x8_t vadd(const uint8x8_t &a, const uint8x8_t &b)
Describe one of the image's dimensions with a start, end and step.
Interface for CPU tensor.
Copyright (c) 2017-2022 Arm Limited.
void fused_batch_normalization_conv(const ITensor *conv_weights, const ITensor *conv_bias, ITensor *fused_weights, ITensor *fused_bias, const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
typename neon_bitvector< T, BW >::tag_type neon_bitvector_tag_t
Helper type template to get the tag type of a neon vector.
static constexpr size_t DimX
Alias for dimension 0 also known as X dimension.
virtual ITensorInfo * info() const =0
Interface to be implemented by the child class to return the tensor's metadata.
constexpr uint8_t * ptr() const
Return a pointer to the current pixel.
virtual size_t element_size() const =0
Element size in bytes, calculated as data_size() * num_channels().
void set(size_t dimension, const Dimension &dim)
Set the values of a given dimension.
template void fused_batch_normalization_conv< float32_t >(const ITensor *conv_weights, const ITensor *conv_bias, ITensor *fused_weights, ITensor *fused_bias, const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
uint8x8_t vmul(const uint8x8_t &a, const uint8x8_t &b)
void vstore(uint8_t *ptr, uint8x8_t val)
uint8x8_t vdup_n(uint8_t value, traits::vector_64_tag)
void execute_window_loop(const Window &w, L &&lambda_function, Ts &&... iterators)
Iterate through the passed window, automatically adjusting the iterators and calling the lambda_function for each element.
constexpr int end() const
Return the end of the dimension.
Iterator updated by execute_window_loop for each window element.
constexpr int start() const
Return the start of the dimension.
Describe a multidimensional execution window.
constexpr const Dimension & x() const
Alias to access the first dimension of the window.