#include <cfloat>
#include <vector>
#include <functional>
#include <cstdint>
#include <cmath>
Go to the source code of this file.

std::vector< int16_t > getTosaConst8bitTable (float input_scale, int32_t input_zp, float output_scale, int32_t output_zp, std::function< float(float)> func)

template<typename FloatT >
std::vector< int16_t > getTosaConst16bitTable (float input_scale, int32_t input_zp, float output_scale, int32_t output_zp, std::function< FloatT(FloatT)> func)

◆ getTosaConst16bitTable()
template<typename FloatT >
std::vector<int16_t> getTosaConst16bitTable (
    float input_scale,
    int32_t input_zp,
    float output_scale,
    int32_t output_zp,
    std::function< FloatT(FloatT)> func
)   [inline]
Definition at line 52 of file TosaTableUtils.hpp.
52 inline std::vector<int16_t> getTosaConst16bitTable(float input_scale,
53                                                    int32_t input_zp,
54                                                    float output_scale,
55                                                    int32_t output_zp,
56                                                    std::function<FloatT(FloatT)> func)
57 {
58     std::vector<int16_t> table;
61     FloatT input_min =
62         input_scale * static_cast<FloatT>(std::numeric_limits<int16_t>::min() - input_zp);
63     FloatT input_max =
64         input_scale * static_cast<FloatT>(std::numeric_limits<int16_t>::max() - input_zp);
65     FloatT output_min =
66         output_scale * static_cast<FloatT>(std::numeric_limits<int16_t>::min() - output_zp);
67     FloatT output_max =
68         output_scale * static_cast<FloatT>(std::numeric_limits<int16_t>::max() - output_zp);
70     FloatT step = (input_max - input_min) / 512;
71     FloatT half_step = step / 2;
72     FloatT output_scaling_inv = 65536 / (output_max - output_min);
74     for (int32_t i = 0; i < 512; i++)
75     {
76         FloatT iFloat = static_cast<FloatT>(i);
77         FloatT sample_val =
78             std::round(func(input_min + (iFloat * step)) * output_scaling_inv);
79         FloatT midpoint_interp_val = std::round(
80             ((func(input_min + (iFloat + 1) * step) * output_scaling_inv) +
81              std::round(func(input_min + (iFloat * step)) * output_scaling_inv)) /
82             2);
83         FloatT midpoint_val = std::round(func(input_min + (iFloat * step) + half_step) *
84                                          output_scaling_inv);
85         FloatT midpoint_err = midpoint_interp_val - midpoint_val;
86         FloatT bias = std::round(midpoint_err / 2);
88         table.push_back(static_cast<int16_t>(
89             std::min<FloatT>(std::max<FloatT>(sample_val - bias, -32768), 32767)));
90     }
92     FloatT max_val = std::round(func(input_max) * output_scaling_inv);
93     table.push_back(static_cast<int16_t>(
94         std::min<FloatT>(std::max<FloatT>(max_val, -32768), 32767)));
95     return table;
96 }
◆ getTosaConst8bitTable()
std::vector<int16_t> getTosaConst8bitTable (
    float input_scale,
    int32_t input_zp,
    float output_scale,
    int32_t output_zp,
    std::function< float(float)> func
)   [inline]
Definition at line 19 of file TosaTableUtils.hpp.
19 inline std::vector<int16_t> getTosaConst8bitTable(float input_scale,
20                                                   int32_t input_zp,
21                                                   float output_scale,
22                                                   int32_t output_zp,
23                                                   std::function<float(float)> func)
24 {
26     std::vector<int16_t> table;
28     float inverse_scale = 1.0f / output_scale;
29     for (int32_t i = -128; i < 128; i++)
30     {
31         float dequantized = input_scale * static_cast<float>(i - input_zp);
32         float transformed = func(dequantized);
34         float max = (output_scale > 1.0) ? FLT_MAX : (FLT_MAX * output_scale);
35         if (transformed >= max)
36         {
37             table.push_back(INT8_MAX);
38             continue;
39         }
41         int32_t rescaled = static_cast<int32_t>(std::round(transformed * inverse_scale));
42         int32_t quantized = static_cast<int32_t>(rescaled + output_zp);
43         table.push_back(
44             static_cast<int8_t>(std::min(std::max(quantized, -128), 127)));
45     }
46     return table;
47 }
Referenced by ConvertExpOperator(), ConvertGeluToTosaOperator(), and ConvertLogOperator().