ArmNN
 25.11
Loading...
Searching...
No Matches
TosaSoftmaxOperatorUtils.hpp File Reference
#include <gemmlowp/fixedpoint.h>
Include dependency graph for TosaSoftmaxOperatorUtils.hpp:
This graph shows which files directly or indirectly include this file:

Go to the source code of this file.

Functions

void CalculateSoftmaxTableValues (double softmaxBeta, double scale, std::array< std::vector< int16_t >, 4 > &tables)

Function Documentation

◆ CalculateSoftmaxTableValues()

inline void CalculateSoftmaxTableValues(double softmaxBeta, double scale, std::array<std::vector<int16_t>, 4>& tables)

Definition at line 60 of file TosaSoftmaxOperatorUtils.hpp.

61{ // Appends 513 entries (input_diff = -256..256) to each of tables[0..3]; tables[i] gets byte i (most significant first) of the fixed-point exp() result. The tables are not cleared here.
62 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(softmaxBeta == 1.0f, // only softmaxBeta exactly equal to 1.0 is accepted
63 "CalculateSoftmaxTableValues: Beta values other than 1.0 are not supported");
64
65 int32_t input_beta_multiplier = 0;
66 int input_beta_left_shift = 0;
67
68 const int kScaledDiffIntegerBits = 5; // scaled diffs are held as gemmlowp FixedPoint<int32_t> with 5 integer bits
69 using FixedPointScaledDiff = gemmlowp::FixedPoint<int32_t, kScaledDiffIntegerBits>;
70 using gemmlowp::SaturatingRoundingDoublingHighMul;
71 using gemmlowp::exp_on_negative_values;
72
73 SoftmaxScaling5Bits(softmaxBeta, scale, input_beta_multiplier, input_beta_left_shift); // helper defined elsewhere; presumably derives multiplier/shift from beta and scale — TODO confirm
74 int diff_min = -1 * CalculateInputRadius5Bits(input_beta_left_shift); // diffs below this lower bound leave output at 0
75
76 for (int32_t input_diff = -256; input_diff <= 256; input_diff++) // 513 iterations, -256..256 inclusive
77 {
78 int32_t output = 0;
79 if (input_diff >= diff_min) // NOTE(review): positive input_diff values also pass this check and reach exp_on_negative_values, whose name implies non-positive input — confirm those entries are unused by callers
80 {
81 int32_t input_diff_rescaled =
82 SaturatingRoundingDoublingHighMul(input_diff * (1 << input_beta_left_shift), input_beta_multiplier); // NOTE(review): this product may overflow int32 if input_beta_left_shift is large; its range comes from SoftmaxScaling5Bits — confirm
83 const FixedPointScaledDiff input_diff_fixed_point = FixedPointScaledDiff::FromRaw(input_diff_rescaled);
84 output = exp_on_negative_values(input_diff_fixed_point).raw();
85 }
86
87 // Split the 32-bit result into its four bytes, most significant first; each byte (0..255) is stored as an int16_t
88 int32_t first = (output >> 24) & 0xFF;
89 int32_t second = (output >> 16) & 0xFF;
90 int32_t third = (output >> 8) & 0xFF;
91 int32_t fourth = (output) & 0xFF;
92 tables[0].push_back(static_cast<int16_t>(first));
93 tables[1].push_back(static_cast<int16_t>(second));
94 tables[2].push_back(static_cast<int16_t>(third));
95 tables[3].push_back(static_cast<int16_t>(fourth));
96 }
97}
The listing above uses the macro `#define ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(_cond, _str)` (its definition is not shown on this page).

References ARMNN_THROW_INVALIDARG_MSG_IF_FALSE.

Referenced by ConvertSoftmaxToTosaOperator().