38 namespace quantization
45 if (multiplier >= 1.f)
58 int32_t *quant_multiplier,
62 const float internal_epsilon = ignore_epsilon ? 0.0f :
epsilon;
70 const double q = std::frexp(multiplier, &shift_exp);
71 *right_shift = -1 * shift_exp;
80 if (ignore_epsilon && *right_shift > 31)
88 *quant_multiplier =
static_cast<int32_t
>(q_fixed);
101 const double q = std::frexp(multiplier, &shift_exp);
102 *left_shift = shift_exp;
112 *quantized_multiplier =
static_cast<int32_t
>(q_fixed);
125 constexpr
unsigned int padding_elems = 32;
126 const unsigned int size = wq_info.
scale().size();
127 const size_t padded_size = (size == 1) ? 1 : size + padding_elems;
130 quant_multipliers.resize(padded_size);
131 quant_shifts.resize(padded_size);
133 const auto &w_scales = wq_info.
scale();
134 const float i_scale = iq_info.
scale().at(0);
135 const float o_scale = oq_info.
scale().at(0);
137 for (
unsigned int i = 0; i < size; ++i)
139 const float multiplier = i_scale * w_scales[i] / o_scale;
140 int32_t quant_multiplier = 0;
141 int32_t quant_shift = 0;
143 quant_multipliers[i] = quant_multiplier;
144 quant_shifts[i] = quant_shift;
156 int min_quant_val = 0;
157 int max_quant_val = 0;
161 min_quant_val = std::numeric_limits<uint8_t>::min();
162 max_quant_val = std::numeric_limits<uint8_t>::max();
166 min_quant_val = std::numeric_limits<int8_t>::min();
167 max_quant_val = std::numeric_limits<int8_t>::max();
170 min_quant_val = std::numeric_limits<uint16_t>::min();
171 max_quant_val = std::numeric_limits<uint16_t>::max();
174 min_quant_val = std::numeric_limits<int16_t>::min();
175 max_quant_val = std::numeric_limits<int16_t>::max();
180 return std::make_pair(min_quant_val, max_quant_val);
191 int32_t type_min = std::get<0>(min_max).get<int32_t>();
192 int32_t type_max = std::get<1>(min_max).get<int32_t>();
200 case ActivationLayerInfo::ActivationFunction::RELU:
203 case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU:
204 type_min = q_unif.offset;
208 case ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU:
220 return std::make_tuple(type_min, type_max);
226 int32_t *output_multipliers_ptr,
227 int32_t *output_shifts_ptr)
233 const unsigned int num_filters = wq_info.
scale().size();
235 for (
unsigned int i = 0; i < num_filters; ++i)
237 int32_t output_multiplier = 0;
238 int32_t output_shift = 0;
239 const float multiplier = iq_info.
scale * wq_info.
scale()[i] / oq_info.
scale;
242 output_multipliers_ptr[i] = output_multiplier;
243 output_shifts_ptr[i] = output_shift;
249 bool overflow = a ==
b && a == std::numeric_limits<int32_t>::min();
252 int64_t ab_64 = a_64 * b_64;
253 const bool is_positive_or_zero =
255 int32_t nudge = is_positive_or_zero ? (1 << 30) : (1 - (1 << 30));
256 int32_t ab_x2_high32 =
static_cast<int32_t
>((ab_64 + nudge) / (1ll << 31));
257 return overflow ? std::numeric_limits<int32_t>::max() : ab_x2_high32;
262 const int32_t mask = (1 << exponent) - 1;
263 const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
264 return (x >> exponent) + ((x & mask) > threshold ? 1 : 0);
269 const auto left_shift = shift > 0 ? shift : 0;
270 const auto right_shift = shift > 0 ? 0 : -shift;
280 else if (exponent < 0)
286 constexpr
auto min = std::numeric_limits<int32_t>::min();
287 constexpr
auto max = std::numeric_limits<int32_t>::max();
288 const auto width =
sizeof(int32_t) * 8;
290 const int32_t threshold = ((1 << (width - 1 - exponent)) - 1);
291 bool pos_mask = v > threshold;
292 bool neg_mask = v < -threshold;
293 int32_t result = v << exponent;
294 result = pos_mask ? max : result;
295 result = neg_mask ? min : result;
301 int32_t reverse_shift,
302 int32_t &output_inv_sqrt,
303 int32_t &output_shift)
310 output_inv_sqrt = std::numeric_limits<std::int32_t>::max();
317 while (
input >= (1 << 29))
323 const uint32_t max_left_shift_bits = __builtin_clz(
static_cast<uint32_t
>(
input)) - 1;
324 const uint32_t max_left_shift_bits_pairs = max_left_shift_bits / 2;
325 const uint32_t left_shift_bit_pairs = max_left_shift_bits_pairs - 1;
326 output_shift -= left_shift_bit_pairs;
327 input <<= 2 * left_shift_bit_pairs;
330 using FixedPointRawType = int32_t;
331 constexpr uint32_t fixedpoint_position = 3;
332 constexpr uint32_t fixedpoint_int_position =
sizeof(FixedPointRawType) * 8 - 1 - fixedpoint_position;
333 using FixedPoint3 = FixedPointRawType;
334 using FixedPoint0 = FixedPointRawType;
337 const FixedPoint3 fixedpoint_input = (
input >> 1);
339 const FixedPoint3 fixedpoint_half_three = (0x1 << fixedpoint_int_position) + (0x1 << (fixedpoint_int_position - 1));
342 FixedPoint3 x = 0x1 << fixedpoint_int_position;
345 auto fixed_point_mul = [](FixedPointRawType a, FixedPointRawType
b) -> FixedPointRawType
349 auto fixed_point_rescale = [](FixedPointRawType a, uint32_t src_bit, uint32_t dst_bit) -> FixedPointRawType
351 const uint32_t exponent = src_bit - dst_bit;
356 constexpr int32_t num_iteration = 5;
357 for (int32_t i = 0; i < num_iteration; ++i)
359 const auto x3 = fixed_point_rescale(fixed_point_mul(fixed_point_mul(x, x), x), 9, fixedpoint_position);
360 x = fixed_point_rescale(fixed_point_mul(fixedpoint_half_three, x) - fixed_point_mul(fixedpoint_half_input, x3),
361 6, fixedpoint_position);
365 const FixedPoint0 fixedpoint_half_sqrt_2 = 1518500250;
366 x = fixed_point_mul(fixedpoint_half_sqrt_2, x);
368 if (output_shift < 0)
370 output_inv_sqrt <<= -output_shift;
374 output_shift *= reverse_shift;