// NOTE(review): this is a non-contiguous excerpt of a larger function body. The
// number at the start of each statement looks like a line number from the
// original listing; braces and the statements in between are not shown here.
// Names such as sum, sum_sq, rescaled, out_val, weight, bias, output and src
// are declared outside the visible lines.

// Value-initialised to zero, then assigned zero again below.
46 int output_multiplier{};
53 output_multiplier = 0;
// Extents read from the source shape: index 1 is used as the batch count and
// index 0 as the number of inputs per batch.
57 const uint32_t num_batch =
src.shape()[1];
58 const uint32_t num_input =
src.shape()[0];
// Outer loop: one iteration per batch.
60 for(uint32_t batch_idx = 0; batch_idx < num_batch; ++batch_idx)
// First inner pass over the inputs of the current batch. Each element is read
// as int32_t; presumably it feeds the sum / sum_sq accumulators used below
// (the accumulation itself is not visible here).
65 for(uint32_t input_idx = 0; input_idx < num_input; ++input_idx)
// Flat offset of the element: batches are laid out num_input elements apart.
67 const auto index = batch_idx * num_input + input_idx;
68 const auto val =
static_cast<int32_t
>(
src[index]);
// Integer statistics for the batch, using power-of-two scale factors:
//   temp     = 0x100000 (2^20) / num_input, truncated by integer division
//   mean     = sum * 1024 (2^10) / num_input, i.e. the mean scaled by 2^10
//   variance = (sum_sq * temp - mean * mean) / 2^20; both terms carry a 2^20
//              scale, which the final division removes.
// NOTE(review): both divisions use num_input as the divisor; whether
// num_input == 0 is ruled out earlier cannot be seen from this excerpt.
73 const auto temp =
static_cast<int64_t
>(0x100000) / num_input;
74 const auto mean = sum * 1024 /
static_cast<int64_t
>(num_input);
75 const auto variance = ((sum_sq * temp) - (mean * mean)) / 0x100000;
// Zero-initialised multiplier/shift pair; presumably filled in from the
// variance by a call that is not part of this excerpt - TODO confirm.
77 int32_t stddev_invsqrt_mul{};
78 int32_t stddev_invsqrt_shift{};
// Second inner pass over the same batch: produce one output per input.
81 for(uint32_t input_idx = 0; input_idx < num_input; ++input_idx)
83 const auto index = batch_idx * num_input + input_idx;
84 const auto val =
static_cast<int32_t
>(
src[index]);
// Bring the element to the same 2^10 scale as mean, then subtract the mean.
// NOTE(review): the shift is evaluated on an int32_t operand; confirm the
// range of src values keeps val << 10 within 32 bits.
85 const auto shifted = (val << 10) - mean;
// Apply the per-input weight and bias to the rescaled value (rescaled is
// computed on a line not shown here), accumulating in 64 bits.
87 const int64_t weighted = rescaled * weight[input_idx] +
bias[input_idx];
// Drop the 2^10 scale again; adding 512 (half of 1024) before the right shift
// rounds instead of truncating.
88 const auto reverse_shifted =
static_cast<int32_t
>((weighted + 512) >> 10);
// Clamp out_val with the int16_t minimum as the explicit lower bound. No upper
// bound is passed; presumably the helper defaults it to the int16_t maximum -
// TODO confirm against the clamp declaration.
90 out_val = arm_compute::utility::clamp<decltype(out_val), int16_t>(out_val, std::numeric_limits<int16_t>::min());
// Store the result narrowed to int16_t at the same flat offset.
91 output[index] =
static_cast<int16_t
>(out_val);