ArmNN 24.08
TosaSoftmaxOperatorUtils.hpp File Reference
#include <gemmlowp/fixedpoint.h>
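Include dependency graph for TosaSoftmaxOperatorUtils.hpp: (graph image not reproduced in this text version)
This graph shows which files directly or indirectly include this file: (graph image not reproduced in this text version)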

Go to the source code of this file.

Functions

void CalculateSoftmaxTableValues (double softmaxBeta, double scale, std::array< std::vector< int16_t >, 4 > &tables)
 

Function Documentation

◆ CalculateSoftmaxTableValues()

void CalculateSoftmaxTableValues ( double  softmaxBeta,
double  scale,
std::array< std::vector< int16_t >, 4 > &  tables 
)
inline

Definition at line 60 of file TosaSoftmaxOperatorUtils.hpp.

61 { // Appends 513 entries (input_diff = -256..256) to each of tables[0..3]; the tables are not cleared first.
62  ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(softmaxBeta == 1.0f, // only beta == 1.0 is accepted; 1.0f converts exactly to double 1.0
63  "CalculateSoftmaxTableValues: Beta values other than 1.0 are not supported"); // NOTE(review): macro lives in Exceptions.hpp, not a direct include of this header — confirm it is reachable transitively
64 
65  int32_t input_beta_multiplier = 0; // output of SoftmaxScaling5Bits below
66  int input_beta_left_shift = 0; // output of SoftmaxScaling5Bits below
67 
68  const int kScaledDiffIntegerBits = 5; // rescaled diffs are gemmlowp FixedPoint<int32_t, 5>, i.e. 5 integer bits
69  using FixedPointScaledDiff = gemmlowp::FixedPoint<int32_t, kScaledDiffIntegerBits>;
70  using gemmlowp::SaturatingRoundingDoublingHighMul;
71  using gemmlowp::exp_on_negative_values;
72 
73  SoftmaxScaling5Bits(softmaxBeta, scale, input_beta_multiplier, input_beta_left_shift); // presumably derives multiplier/shift from beta and scale — helper not shown on this page
74  int diff_min = -1 * CalculateInputRadius5Bits(input_beta_left_shift); // diffs below this threshold keep output == 0
75 
76  for (int32_t input_diff = -256; input_diff <= 256; input_diff++) // 513 iterations -> 513 entries per table
77  {
78  int32_t output = 0; // default entry value for diffs below diff_min
79  if (input_diff >= diff_min)
80  {
81  int32_t input_diff_rescaled = // input_diff rescaled into the 5-integer-bit fixed-point domain
82  SaturatingRoundingDoublingHighMul(input_diff * (1 << input_beta_left_shift), input_beta_multiplier); // NOTE(review): int overflow possible if the shift is large; depends on SoftmaxScaling5Bits' range — confirm
83  const FixedPointScaledDiff input_diff_fixed_point = FixedPointScaledDiff::FromRaw(input_diff_rescaled);
84  output = exp_on_negative_values(input_diff_fixed_point).raw(); // NOTE(review): gemmlowp documents this for non-positive inputs; entries with input_diff > 0 presumably are never looked up — confirm
85  }
86 
87  // Only copy the 8-bit groups
88  int32_t first = (output >> 24) & 0xFF; // bits 31..24 -> tables[0]
89  int32_t second = (output >> 16) & 0xFF; // bits 23..16 -> tables[1]
90  int32_t third = (output >> 8) & 0xFF; // bits 15..8 -> tables[2]
91  int32_t fourth = (output) & 0xFF; // bits 7..0 -> tables[3]
92  tables[0].push_back(static_cast<int16_t>(first)); // each masked byte is 0..255, so the int16_t cast is lossless
93  tables[1].push_back(static_cast<int16_t>(second));
94  tables[2].push_back(static_cast<int16_t>(third));
95  tables[3].push_back(static_cast<int16_t>(fourth));
96  }
97 }

References ARMNN_THROW_INVALIDARG_MSG_IF_FALSE.

Referenced by ConvertSoftmaxToTosaOperator().

ARMNN_THROW_INVALIDARG_MSG_IF_FALSE
#define ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(_cond, _str)
Definition at line 210 of file Exceptions.hpp.