ArmNN
 24.08
TosaSoftmaxOperatorUtils.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 //
6 // Copyright © 2020 The TensorFlow Authors. All Rights Reserved.
7 // SPDX-License-Identifier: Apache-2.0
8 //
9 
10 #pragma once
11 #include <gemmlowp/fixedpoint.h>
12 
13 static void SoftmaxScaling5Bits(double softmaxBeta,
14  double scale,
15  int32_t &input_beta_multiplier,
16  int &input_beta_left_shift)
17 {
18  double max_real_multiplier = (1LL << 31) - 1.0;
19  double input_beta_real_multiplier = std::min<double>(softmaxBeta * scale * (1 << (31 - 5)), max_real_multiplier);
20 
21  ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(input_beta_real_multiplier > 1.,
22  "CalculateSoftmaxTableValues: QuantizeMultiplierGreaterThanOne must be greater than 1");
23 
24  if (input_beta_real_multiplier == 0.)
25  {
26  input_beta_multiplier = 0;
27  input_beta_left_shift = 0;
28  return;
29  }
30  double q = std::frexp(input_beta_real_multiplier, &input_beta_left_shift);
31  auto q_fixed = static_cast<int64_t>(std::round(q * (1LL << 31)));
32 
33  ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(q_fixed <= (1LL << 31), "CalculateSoftmaxTableValues: Rounding not valid");
34 
35  if (q_fixed == (1LL << 31))
36  {
37  q_fixed /= 2;
38  ++input_beta_left_shift;
39  }
40 
41  ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(q_fixed <= std::numeric_limits<int32_t>::max(),
42  "CalculateSoftmaxTableValues: All results would be zero");
43 
44  if (input_beta_left_shift < -31)
45  {
46  input_beta_left_shift = 0;
47  q_fixed = 0;
48  }
49  input_beta_multiplier = static_cast<int32_t>(q_fixed);
50 }
51 
/// Computes floor((2^5 - 1) * 2^(31 - 5) / 2^input_beta_left_shift) as an int.
///
/// @param input_beta_left_shift  Left shift used as the power-of-two divisor. It is read only;
///                               the referenced value is not modified.
/// @return The floored quotient described above.
static int CalculateInputRadius5Bits(int &input_beta_left_shift)
{
    const int largestIntegerPart = (1 << 5) - 1;
    const long long fractionalScale = 1LL << (31 - 5);
    const auto shiftDivisor = static_cast<double>(1LL << input_beta_left_shift);

    // Leading 1.0 keeps the whole expression in double precision before flooring.
    const double rescaledLimit = 1.0 * largestIntegerPart * fractionalScale / shiftDivisor;

    return static_cast<int>(std::floor(rescaledLimit));
}
59 
60 inline void CalculateSoftmaxTableValues(double softmaxBeta, double scale, std::array<std::vector<int16_t>, 4>& tables)
61 {
62  ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(softmaxBeta == 1.0f,
63  "CalculateSoftmaxTableValues: Beta values other than 1.0 are not supported");
64 
65  int32_t input_beta_multiplier = 0;
66  int input_beta_left_shift = 0;
67 
68  const int kScaledDiffIntegerBits = 5;
69  using FixedPointScaledDiff = gemmlowp::FixedPoint<int32_t, kScaledDiffIntegerBits>;
70  using gemmlowp::SaturatingRoundingDoublingHighMul;
71  using gemmlowp::exp_on_negative_values;
72 
73  SoftmaxScaling5Bits(softmaxBeta, scale, input_beta_multiplier, input_beta_left_shift);
74  int diff_min = -1 * CalculateInputRadius5Bits(input_beta_left_shift);
75 
76  for (int32_t input_diff = -256; input_diff <= 256; input_diff++)
77  {
78  int32_t output = 0;
79  if (input_diff >= diff_min)
80  {
81  int32_t input_diff_rescaled =
82  SaturatingRoundingDoublingHighMul(input_diff * (1 << input_beta_left_shift), input_beta_multiplier);
83  const FixedPointScaledDiff input_diff_fixed_point = FixedPointScaledDiff::FromRaw(input_diff_rescaled);
84  output = exp_on_negative_values(input_diff_fixed_point).raw();
85  }
86 
87  // Only copy the 8-bit groups
88  int32_t first = (output >> 24) & 0xFF;
89  int32_t second = (output >> 16) & 0xFF;
90  int32_t third = (output >> 8) & 0xFF;
91  int32_t fourth = (output) & 0xFF;
92  tables[0].push_back(static_cast<int16_t>(first));
93  tables[1].push_back(static_cast<int16_t>(second));
94  tables[2].push_back(static_cast<int16_t>(third));
95  tables[3].push_back(static_cast<int16_t>(fourth));
96  }
97 }
CalculateSoftmaxTableValues
void CalculateSoftmaxTableValues(double softmaxBeta, double scale, std::array< std::vector< int16_t >, 4 > &tables)
Definition: TosaSoftmaxOperatorUtils.hpp:60
ARMNN_THROW_INVALIDARG_MSG_IF_FALSE
#define ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(_cond, _str)
Definition: Exceptions.hpp:210