ArmNN
 25.11
Loading...
Searching...
No Matches
TosaTableUtils.hpp File Reference
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstdint>
#include <functional>
#include <limits>
#include <vector>
Include dependency graph for TosaTableUtils.hpp:
This graph shows which files directly or indirectly include this file:

Go to the source code of this file.

Functions

std::vector< int16_t > getTosaConst8bitTable (float input_scale, int32_t input_zp, float output_scale, int32_t output_zp, std::function< float(float)> func)
template<typename FloatT>
std::vector< int16_t > getTosaConst16bitTable (float input_scale, int32_t input_zp, float output_scale, int32_t output_zp, std::function< FloatT(FloatT)> func)

Function Documentation

◆ getTosaConst16bitTable()

template<typename FloatT>
std::vector< int16_t > getTosaConst16bitTable ( float input_scale,
int32_t input_zp,
float output_scale,
int32_t output_zp,
std::function< FloatT(FloatT)> func )
inline

Definition at line 52 of file TosaTableUtils.hpp.

/// Builds a 513-entry int16 lookup table that samples @p func over the real-valued
/// range covered by an int16 input quantized with (@p input_scale, @p input_zp).
///
/// The input range is split into 512 equal intervals. Entry i holds the rounded, rescaled
/// value of @p func at the start of interval i, shifted by half of the difference between
/// the linearly interpolated midpoint and the true midpoint value of that interval.
/// The final entry holds the value at the upper end of the input range.
/// (Presumably the consumer interpolates linearly between entries - confirm against the
/// TOSA TABLE specification.)
///
/// @tparam FloatT        Floating-point type used for all intermediate arithmetic.
/// @param input_scale    Quantization scale of the input.
/// @param input_zp       Quantization zero point of the input.
/// @param output_scale   Quantization scale of the output.
/// @param output_zp      Quantization zero point of the output. It only enters through
///                       (output_max - output_min), where it cancels out.
/// @param func           Real-valued function to tabulate. It is assumed to be pure: it is
///                       evaluated once per distinct sample point, so a stateful callable
///                       would observe fewer calls than one call per use.
/// @return               513 values saturated to the int16 range. NaN samples are stored as 0.
template <typename FloatT>
inline std::vector<int16_t> getTosaConst16bitTable(float input_scale,
                                                   int32_t input_zp,
                                                   float output_scale,
                                                   int32_t output_zp,
                                                   std::function<FloatT(FloatT)> func)
{
    constexpr int32_t kIntervals = 512;
    constexpr int32_t kQMin = std::numeric_limits<int16_t>::min();
    constexpr int32_t kQMax = std::numeric_limits<int16_t>::max();

    std::vector<int16_t> table;
    table.reserve(kIntervals + 1);

    const FloatT input_min  = input_scale * static_cast<FloatT>(kQMin - input_zp);
    const FloatT input_max  = input_scale * static_cast<FloatT>(kQMax - input_zp);
    const FloatT output_min = output_scale * static_cast<FloatT>(kQMin - output_zp);
    const FloatT output_max = output_scale * static_cast<FloatT>(kQMax - output_zp);

    const FloatT step               = (input_max - input_min) / kIntervals;
    const FloatT half_step          = step / 2;
    const FloatT output_scaling_inv = 65536 / (output_max - output_min);

    // Real-valued input at the start of interval `index`.
    const auto sample_point = [&](int32_t index) -> FloatT
    {
        return input_min + (static_cast<FloatT>(index) * step);
    };

    // func(x) expressed in (unrounded) output table units.
    const auto scaled = [&](FloatT x) -> FloatT
    {
        return func(x) * output_scaling_inv;
    };

    // Converting NaN to an integer is undefined behaviour and std::min/std::max let NaN
    // through, so it is mapped to 0 explicitly. +/-inf saturate like any other value.
    const auto saturate = [&](FloatT value) -> int16_t
    {
        if (std::isnan(value))
        {
            return 0;
        }
        const FloatT bounded = std::min<FloatT>(std::max<FloatT>(value, static_cast<FloatT>(kQMin)),
                                                static_cast<FloatT>(kQMax));
        return static_cast<int16_t>(bounded);
    };

    // The right endpoint of interval i is the left endpoint of interval i + 1, so its
    // value is carried over instead of being recomputed.
    FloatT current = scaled(sample_point(0));
    for (int32_t i = 0; i < kIntervals; ++i)
    {
        const FloatT next = scaled(sample_point(i + 1));

        const FloatT sample_val          = std::round(current);
        const FloatT midpoint_interp_val = std::round((next + sample_val) / 2);
        const FloatT midpoint_val        = std::round(scaled(sample_point(i) + half_step));
        const FloatT midpoint_err        = midpoint_interp_val - midpoint_val;

        // A non-finite error (inf - inf, or an infinite neighbour) carries no usable
        // correction and would otherwise push a valid sample to the wrong rail.
        FloatT bias = std::round(midpoint_err / 2);
        if (!std::isfinite(bias))
        {
            bias = 0;
        }

        table.push_back(saturate(sample_val - bias));
        current = next;
    }

    // The last entry is sampled at input_max itself rather than at
    // input_min + 512 * step, which may differ from it by rounding.
    table.push_back(saturate(std::round(scaled(input_max))));
    return table;
}

◆ getTosaConst8bitTable()

std::vector< int16_t > getTosaConst8bitTable ( float input_scale,
int32_t input_zp,
float output_scale,
int32_t output_zp,
std::function< float(float)> func )
inline

Definition at line 19 of file TosaTableUtils.hpp.

/// Builds the 256-entry lookup table that maps every int8 input code to its int8 output
/// code under @p func.
///
/// For each input code i in [-128, 127] the value is dequantized with
/// (@p input_scale, @p input_zp), passed through @p func, then requantized with
/// (@p output_scale, @p output_zp) and saturated to [-128, 127].
///
/// @param input_scale    Quantization scale of the input.
/// @param input_zp       Quantization zero point of the input.
/// @param output_scale   Quantization scale of the output.
/// @param output_zp      Quantization zero point of the output.
/// @param func           Real-valued function to tabulate.
/// @return               256 values, each within [-128, 127], ordered by input code from
///                       -128 to 127. A NaN result is stored as the (saturated) output
///                       zero point, i.e. the encoding of real 0.
inline std::vector<int16_t> getTosaConst8bitTable(float input_scale,
                                                  int32_t input_zp,
                                                  float output_scale,
                                                  int32_t output_zp,
                                                  std::function<float(float)> func)
{
    constexpr int32_t kQMin = -128;
    constexpr int32_t kQMax = 127;

    // TosaTableAttribute requires int16 vector input. However, TOSA TABLE legalizations are performed using int8.
    std::vector<int16_t> table;
    table.reserve(256);

    const float inverse_scale = 1.0f / output_scale;

    // Results at or above this threshold would overflow float when rescaled; they are
    // stored as INT8_MAX directly. Loop-invariant, so computed once.
    const float overflow_threshold = (output_scale > 1.0) ? FLT_MAX : (FLT_MAX * output_scale);

    // Bounds for the rescaled value *before* the zero point is added. Clamping to them in
    // floating point keeps the later integer conversion and addition inside
    // [kQMin, kQMax]; converting an out-of-range float to an integer is undefined behaviour.
    const double rescaled_lo = static_cast<double>(kQMin) - static_cast<double>(output_zp);
    const double rescaled_hi = static_cast<double>(kQMax) - static_cast<double>(output_zp);

    // Table entry used when func yields NaN, for which no rescaled value exists.
    const int32_t nan_entry = std::min(std::max(output_zp, kQMin), kQMax);

    for (int32_t i = kQMin; i <= kQMax; ++i)
    {
        const float dequantized = input_scale * static_cast<float>(i - input_zp);
        const float transformed = func(dequantized);

        if (transformed >= overflow_threshold)
        {
            table.push_back(INT8_MAX);
            continue;
        }

        const double rescaled = static_cast<double>(std::round(transformed * inverse_scale));
        if (std::isnan(rescaled))
        {
            table.push_back(static_cast<int16_t>(nan_entry));
            continue;
        }

        // Saturate first, convert second: after the clamp the value is an integer in
        // [kQMin - output_zp, kQMax - output_zp], so the sum below cannot overflow.
        const double bounded = std::min(std::max(rescaled, rescaled_lo), rescaled_hi);
        const int64_t quantized = static_cast<int64_t>(bounded) + static_cast<int64_t>(output_zp);
        table.push_back(static_cast<int16_t>(quantized));
    }
    return table;
}

Referenced by ConvertExpOperator(), ConvertGeluToTosaOperator(), ConvertLogOperator(), ConvertRsqrtOperator(), ConvertSigmoidToTosaOperator(), and ConvertTanHToTosaOperator().