11 #include <gemmlowp/fixedpoint.h>
// Converts the real-valued softmax input scaling factor into a fixed-point
// multiplier plus a power-of-two shift, for a fixed-point format with 5 integer bits.
//
// softmaxBeta              real-valued beta factor of the softmax.
// input_beta_multiplier    [out] the mantissa scaled by 2^31 and narrowed to int32_t
//                          (0 when the real multiplier is exactly zero).
// input_beta_left_shift    [out] binary exponent produced by std::frexp, adjusted below.
//
// NOTE(review): the expression below also reads `scale`, and the call site passes four
// arguments, but that parameter's declaration is not visible in this extract.
13 static void SoftmaxScaling5Bits(
double softmaxBeta,
15 int32_t &input_beta_multiplier,
16 int &input_beta_left_shift)
// Upper bound for the real multiplier: 2^31 - 1, expressed as a double.
18 double max_real_multiplier = (1LL << 31) - 1.0;
// beta * scale * 2^(31 - 5), clamped from above to max_real_multiplier.
19 double input_beta_real_multiplier = std::min<double>(softmaxBeta * scale * (1 << (31 - 5)), max_real_multiplier);
// Message text of a check whose condition is not visible in this extract.
22 "CalculateSoftmaxTableValues: QuantizeMultiplierGreaterThanOne must be greater than 1");
// A multiplier of exactly zero is special-cased: both outputs are set to zero.
// NOTE(review): whether the function returns right after this is not visible here.
24 if (input_beta_real_multiplier == 0.)
26 input_beta_multiplier = 0;
27 input_beta_left_shift = 0;
// std::frexp splits the multiplier into a mantissa q (magnitude in [0.5, 1)) and a
// binary exponent, which is written straight into input_beta_left_shift.
30 double q = std::frexp(input_beta_real_multiplier, &input_beta_left_shift);
// Mantissa scaled by 2^31 and rounded; held in 64 bits because rounding can yield
// exactly 2^31, which does not fit in int32_t.
31 auto q_fixed =
static_cast<int64_t
>(std::round(q * (1LL << 31)));
// Rounding pushed the mantissa up to 2^31: the exponent is incremented.
// NOTE(review): presumably q_fixed is halved on a line not shown here - confirm.
35 if (q_fixed == (1LL << 31))
38 ++input_beta_left_shift;
// Message text of a check whose condition is not visible in this extract.
42 "CalculateSoftmaxTableValues: All results would be zero");
// Exponents below -31 are replaced by a shift of 0.
// NOTE(review): any accompanying change to q_fixed is not visible here.
44 if (input_beta_left_shift < -31)
46 input_beta_left_shift = 0;
// Narrow the 64-bit fixed-point mantissa to the 32-bit output.
49 input_beta_multiplier =
static_cast<int32_t
>(q_fixed);
// Returns floor((2^5 - 1) * 2^(31 - 5) / 2^input_beta_left_shift) as an int.
//
// input_beta_left_shift    shift amount used as the divisor's exponent. It is taken by
//                          non-const reference but only read in the lines visible here.
//
// NOTE(review): `1LL << input_beta_left_shift` is only defined for shift counts in
// [0, 63]; the visible code does not check this here - confirm the caller guarantees it.
52 static int CalculateInputRadius5Bits(
int &input_beta_left_shift)
// (2^5 - 1) * 2^26 divided by 2^input_beta_left_shift, evaluated in double precision.
54 const double max_input_rescaled = 1.0 * ((1 << 5) - 1) * (1LL << (31 - 5)) /
55 (
static_cast<double>(1LL << input_beta_left_shift));
// Round down to a whole number.
57 return static_cast<int>(std::floor(max_input_rescaled));
// Body of the table-building routine. Its signature, the declaration of `output` and
// `tables`, and the closing of the loop are not visible in this extract.
// For every input_diff in [-256, 256] it evaluates a fixed-point exponential and appends
// the four bytes of the 32-bit result to tables[0]..tables[3].
//
// Message text of a check whose condition is not visible in this extract.
63 "CalculateSoftmaxTableValues: Beta values other than 1.0 are not supported");
// Filled in by SoftmaxScaling5Bits below.
65 int32_t input_beta_multiplier = 0;
66 int input_beta_left_shift = 0;
// Fixed-point representation of the scaled difference: int32_t raw value, 5 integer bits.
68 const int kScaledDiffIntegerBits = 5;
69 using FixedPointScaledDiff = gemmlowp::FixedPoint<int32_t, kScaledDiffIntegerBits>;
70 using gemmlowp::SaturatingRoundingDoublingHighMul;
71 using gemmlowp::exp_on_negative_values;
73 SoftmaxScaling5Bits(softmaxBeta, scale, input_beta_multiplier, input_beta_left_shift);
// Lower bound for diffs that get evaluated: the negated input radius.
74 int diff_min = -1 * CalculateInputRadius5Bits(input_beta_left_shift);
// One table entry per input_diff, bounds inclusive (513 iterations).
// NOTE(review): the range includes positive diffs while the helper is named
// exp_on_negative_values - confirm how positive values are meant to be handled.
76 for (int32_t input_diff = -256; input_diff <= 256; input_diff++)
// Only diffs at or above diff_min are evaluated here; the value `output` holds
// otherwise is set on lines not visible in this extract.
79 if (input_diff >= diff_min)
// Apply the left shift, then multiply by the fixed-point beta multiplier.
81 int32_t input_diff_rescaled =
82 SaturatingRoundingDoublingHighMul(input_diff * (1 << input_beta_left_shift), input_beta_multiplier);
// Reinterpret the rescaled integer as a raw fixed-point value and take its exponential.
83 const FixedPointScaledDiff input_diff_fixed_point = FixedPointScaledDiff::FromRaw(input_diff_rescaled);
84 output = exp_on_negative_values(input_diff_fixed_point).raw();
// Split the 32-bit result into four bytes, most significant first.
88 int32_t first = (output >> 24) & 0xFF;
89 int32_t second = (output >> 16) & 0xFF;
90 int32_t third = (output >> 8) & 0xFF;
91 int32_t fourth = (output) & 0xFF;
// Each byte goes to its own table, stored as int16_t: tables[0] receives the most
// significant byte and tables[3] the least significant.
92 tables[0].push_back(
static_cast<int16_t
>(first));
93 tables[1].push_back(
static_cast<int16_t
>(second));
94 tables[2].push_back(
static_cast<int16_t
>(third));
95 tables[3].push_back(
static_cast<int16_t
>(fourth));