ArmNN
 25.11
Loading...
Searching...
No Matches
TosaSoftmaxOperatorUtils.hpp
Go to the documentation of this file.
1//
2// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5//
6// Copyright © 2020 The TensorFlow Authors. All Rights Reserved.
7// SPDX-License-Identifier: Apache-2.0
8//
9
10#pragma once
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

#include <gemmlowp/fixedpoint.h>
12
13static void SoftmaxScaling5Bits(double softmaxBeta,
14 double scale,
15 int32_t &input_beta_multiplier,
16 int &input_beta_left_shift)
17{
18 double max_real_multiplier = (1LL << 31) - 1.0;
19 double input_beta_real_multiplier = std::min<double>(softmaxBeta * scale * (1 << (31 - 5)), max_real_multiplier);
20
21 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(input_beta_real_multiplier > 1.,
22 "CalculateSoftmaxTableValues: QuantizeMultiplierGreaterThanOne must be greater than 1");
23
24 if (input_beta_real_multiplier == 0.)
25 {
26 input_beta_multiplier = 0;
27 input_beta_left_shift = 0;
28 return;
29 }
30 double q = std::frexp(input_beta_real_multiplier, &input_beta_left_shift);
31 auto q_fixed = static_cast<int64_t>(std::round(q * (1LL << 31)));
32
33 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(q_fixed <= (1LL << 31), "CalculateSoftmaxTableValues: Rounding not valid");
34
35 if (q_fixed == (1LL << 31))
36 {
37 q_fixed /= 2;
38 ++input_beta_left_shift;
39 }
40
41 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(q_fixed <= std::numeric_limits<int32_t>::max(),
42 "CalculateSoftmaxTableValues: All results would be zero");
43
44 if (input_beta_left_shift < -31)
45 {
46 input_beta_left_shift = 0;
47 q_fixed = 0;
48 }
49 input_beta_multiplier = static_cast<int32_t>(q_fixed);
50}
51
/// Returns floor(((2^5 - 1) * 2^(31 - 5)) / 2^input_beta_left_shift).
///
/// @param input_beta_left_shift  Left shift used as the power-of-two divisor. It is only
///                               read, never modified.
/// @return The floored quotient. A negative shift would give a value larger than an int
///         can hold, so the result saturates to std::numeric_limits<int>::max(); very
///         large shifts give 0.
static int CalculateInputRadius5Bits(int &input_beta_left_shift)
{
    // With a negative shift the quotient is at least 31 * 2^27, which exceeds INT_MAX.
    // Saturate instead of shifting by a negative amount (undefined behaviour).
    if (input_beta_left_shift < 0)
    {
        return std::numeric_limits<int>::max();
    }

    // std::ldexp scales by a power of two exactly and is well defined for any exponent,
    // unlike 1LL << shift, which is undefined once the shift reaches 64.
    const double max_input_rescaled =
        std::ldexp(static_cast<double>((1 << 5) - 1), (31 - 5) - input_beta_left_shift);

    return static_cast<int>(std::floor(max_input_rescaled));
}
59
60inline void CalculateSoftmaxTableValues(double softmaxBeta, double scale, std::array<std::vector<int16_t>, 4>& tables)
61{
62 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(softmaxBeta == 1.0f,
63 "CalculateSoftmaxTableValues: Beta values other than 1.0 are not supported");
64
65 int32_t input_beta_multiplier = 0;
66 int input_beta_left_shift = 0;
67
68 const int kScaledDiffIntegerBits = 5;
69 using FixedPointScaledDiff = gemmlowp::FixedPoint<int32_t, kScaledDiffIntegerBits>;
70 using gemmlowp::SaturatingRoundingDoublingHighMul;
71 using gemmlowp::exp_on_negative_values;
72
73 SoftmaxScaling5Bits(softmaxBeta, scale, input_beta_multiplier, input_beta_left_shift);
74 int diff_min = -1 * CalculateInputRadius5Bits(input_beta_left_shift);
75
76 for (int32_t input_diff = -256; input_diff <= 256; input_diff++)
77 {
78 int32_t output = 0;
79 if (input_diff >= diff_min)
80 {
81 int32_t input_diff_rescaled =
82 SaturatingRoundingDoublingHighMul(input_diff * (1 << input_beta_left_shift), input_beta_multiplier);
83 const FixedPointScaledDiff input_diff_fixed_point = FixedPointScaledDiff::FromRaw(input_diff_rescaled);
84 output = exp_on_negative_values(input_diff_fixed_point).raw();
85 }
86
87 // Only copy the 8-bit groups
88 int32_t first = (output >> 24) & 0xFF;
89 int32_t second = (output >> 16) & 0xFF;
90 int32_t third = (output >> 8) & 0xFF;
91 int32_t fourth = (output) & 0xFF;
92 tables[0].push_back(static_cast<int16_t>(first));
93 tables[1].push_back(static_cast<int16_t>(second));
94 tables[2].push_back(static_cast<int16_t>(third));
95 tables[3].push_back(static_cast<int16_t>(fourth));
96 }
97}
#define ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(_cond, _str)
void CalculateSoftmaxTableValues(double softmaxBeta, double scale, std::array< std::vector< int16_t >, 4 > &tables)