ArmNN 24.08
TosaTableUtils.hpp File Reference
#include <cfloat>
#include <cmath>
#include <cstdint>

#include <algorithm>
#include <functional>
#include <limits>
#include <vector>
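Include dependency graph for TosaTableUtils.hpp: (graph image not reproduced in this text rendering)
This graph shows which files directly or indirectly include this file: (graph image not reproduced in this text rendering)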

Go to the source code of this file.

Functions

std::vector< int16_t > getTosaConst8bitTable (float input_scale, int32_t input_zp, float output_scale, int32_t output_zp, std::function< float(float)> func)
 
template<typename FloatT >
std::vector< int16_t > getTosaConst16bitTable (float input_scale, int32_t input_zp, float output_scale, int32_t output_zp, std::function< FloatT(FloatT)> func)
 

Function Documentation

◆ getTosaConst16bitTable()

// Definition at line 52 of file TosaTableUtils.hpp (ArmNN 24.08).
/// Builds the 513-entry int16 lookup table used to legalize a 16-bit quantized element-wise
/// function as a TOSA TABLE operator.
///
/// The input range [input_min, input_max] is split into 512 equal segments. Entry i holds the
/// scaled value of func at the start of segment i. It is corrected by half of the (rounded)
/// difference between the linearly interpolated midpoint and func's true midpoint value. The
/// final entry holds the scaled value of func(input_max). Every entry is saturated to the
/// int16 range.
///
/// @tparam FloatT      Floating-point type used for the intermediate computation. It cannot be
///                     deduced from a lambda argument, so callers must specify it explicitly.
/// @param input_scale  Scale used to dequantize the int16 input range.
/// @param input_zp     Zero point used to dequantize the int16 input range.
/// @param output_scale Scale of the int16 output range.
/// @param output_zp    Zero point of the int16 output range. It only enters through
///                     (output_max - output_min), where it algebraically cancels.
/// @param func         Real-valued function to tabulate. It is assumed to be pure: each sample
///                     point is evaluated once and the result is reused.
/// @return 513 int16 table entries. A NaN intermediate value maps to the int16 minimum.
template <typename FloatT>
inline std::vector<int16_t> getTosaConst16bitTable(float input_scale,
                                                   int32_t input_zp,
                                                   float output_scale,
                                                   int32_t output_zp,
                                                   std::function<FloatT(FloatT)> func)
{
    constexpr int32_t numSegments = 512;
    // Number of representable int16 values (65536).
    constexpr int32_t int16RangeSize =
        static_cast<int32_t>(std::numeric_limits<int16_t>::max()) -
        static_cast<int32_t>(std::numeric_limits<int16_t>::min()) + 1;

    // Clamp a table value to the int16 range. std::min/std::max pass a NaN through unchanged,
    // and casting NaN to an integer is undefined behaviour. NaN can come from func itself or
    // from inf - inf in the bias term, so it is mapped deterministically to the minimum.
    const auto saturate = [](FloatT value) -> int16_t
    {
        constexpr FloatT lowest  = static_cast<FloatT>(std::numeric_limits<int16_t>::min());
        constexpr FloatT highest = static_cast<FloatT>(std::numeric_limits<int16_t>::max());
        if (std::isnan(value))
        {
            return std::numeric_limits<int16_t>::min();
        }
        return static_cast<int16_t>(std::min<FloatT>(std::max<FloatT>(value, lowest), highest));
    };

    std::vector<int16_t> table;
    table.reserve(static_cast<std::size_t>(numSegments) + 1);

    const FloatT input_min =
        input_scale * static_cast<FloatT>(std::numeric_limits<int16_t>::min() - input_zp);
    const FloatT input_max =
        input_scale * static_cast<FloatT>(std::numeric_limits<int16_t>::max() - input_zp);
    const FloatT output_min =
        output_scale * static_cast<FloatT>(std::numeric_limits<int16_t>::min() - output_zp);
    const FloatT output_max =
        output_scale * static_cast<FloatT>(std::numeric_limits<int16_t>::max() - output_zp);

    const FloatT step = (input_max - input_min) / numSegments;
    const FloatT half_step = step / 2;
    const FloatT output_scaling_inv = static_cast<FloatT>(int16RangeSize) / (output_max - output_min);

    for (int32_t i = 0; i < numSegments; i++)
    {
        const FloatT iFloat = static_cast<FloatT>(i);
        const FloatT segment_start = input_min + (iFloat * step);

        // Rounded scaled sample at the start of the segment. It also serves as the left end
        // of the linear interpolation.
        const FloatT sample_val = std::round(func(segment_start) * output_scaling_inv);
        // Unrounded scaled sample at the start of the next segment.
        const FloatT next_scaled = func(input_min + (iFloat + 1) * step) * output_scaling_inv;

        const FloatT midpoint_interp_val = std::round((next_scaled + sample_val) / 2);
        const FloatT midpoint_val = std::round(func(segment_start + half_step) * output_scaling_inv);

        // Split the midpoint error between interpolation and the true value evenly.
        const FloatT bias = std::round((midpoint_interp_val - midpoint_val) / 2);

        table.push_back(saturate(sample_val - bias));
    }

    // The 513th entry is the end point of the last segment.
    const FloatT max_val = std::round(func(input_max) * output_scaling_inv);
    table.push_back(saturate(max_val));
    return table;
}

◆ getTosaConst8bitTable()

// Definition at line 19 of file TosaTableUtils.hpp (ArmNN 24.08).
/// Builds the 256-entry lookup table used to legalize an 8-bit quantized element-wise
/// function as a TOSA TABLE operator.
///
/// For each int8 input value i in [-128, 127], the entry at index (i + 128) is
/// round(func(input_scale * (i - input_zp)) * (1 / output_scale)) + output_zp, saturated to
/// [-128, 127].
///
/// @param input_scale  Scale used to dequantize the int8 input values.
/// @param input_zp     Zero point used to dequantize the int8 input values.
/// @param output_scale Scale used to requantize func's result.
/// @param output_zp    Zero point added to the requantized result.
/// @param func         Real-valued function to tabulate.
/// @return 256 entries, each in [-128, 127], stored as int16. A NaN result of func maps to -128.
inline std::vector<int16_t> getTosaConst8bitTable(float input_scale,
                                                  int32_t input_zp,
                                                  float output_scale,
                                                  int32_t output_zp,
                                                  std::function<float(float)> func)
{
    // TosaTableAttribute requires int16 vector input. However, TOSA TABLE legalizations are performed using int8.
    std::vector<int16_t> table;
    table.reserve(256);

    const float inverse_scale = 1.0f / output_scale;
    // Loop invariant. Results at or above this threshold would overflow float when multiplied
    // by inverse_scale, so they saturate directly to the int8 maximum.
    const float overflow_threshold = (output_scale > 1.0f) ? FLT_MAX : (FLT_MAX * output_scale);

    for (int32_t i = INT8_MIN; i <= INT8_MAX; i++)
    {
        const float dequantized = input_scale * static_cast<float>(i - input_zp);
        const float transformed = func(dequantized);

        if (transformed >= overflow_threshold)
        {
            table.push_back(INT8_MAX);
            continue;
        }
        if (std::isnan(transformed))
        {
            // NaN has no quantized representation, and converting it to an integer is undefined
            // behaviour. Map it deterministically to the lowest int8 value.
            table.push_back(INT8_MIN);
            continue;
        }

        // Requantize and saturate in double rather than int32. The rescaled value can lie far
        // outside the int32 range (e.g. a very large func result, or -inf), and converting such
        // a float to int32 is undefined behaviour. Any int32 value, or the sum of two int32
        // values, is exact in double, so in-range results are unchanged.
        const double rescaled = static_cast<double>(std::round(transformed * inverse_scale));
        const double quantized = rescaled + static_cast<double>(output_zp);
        const double saturated = std::min(std::max(quantized, static_cast<double>(INT8_MIN)),
                                          static_cast<double>(INT8_MAX));
        table.push_back(static_cast<int16_t>(saturated));
    }
    return table;
}

Referenced by ConvertExpOperator(), ConvertGeluToTosaOperator(), and ConvertLogOperator().