ArmNN
 24.08
TosaTableUtils.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 //
6 // Copyright © 2020 The TensorFlow Authors. All Rights Reserved.
7 // SPDX-License-Identifier: Apache-2.0
8 //
9 
10 #include <cfloat>
11 #include <vector>
12 #include <functional>
13 #include <cstdint>
14 #include <cmath>
15 
16 
// Adapted from getTosaConst8bitTable() in:
// tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc
//
// Builds the 256-entry lookup table for an int8 quantized input. For every quantized input
// value i in [-128, 127] (in that order) the entry is
//     round(func(input_scale * (i - input_zp)) / output_scale) + output_zp
// saturated to [-128, 127].
//
// input_scale / input_zp   : quantization parameters used to dequantize the table index.
// output_scale / output_zp : quantization parameters used to requantize func's result.
// func                     : the real-valued function the table approximates.
//
// Returns a vector of 256 int16_t values, each within the int8 range.
//
// Saturation is performed in floating point, before any integer conversion, so results that
// do not fit in int32_t (large magnitudes, infinities) clamp to the correct end of the range
// instead of going through an out-of-range float-to-int conversion (undefined behaviour).
// A NaN result is quantized as the real value 0.0, i.e. the clamped output zero point.
inline std::vector<int16_t> getTosaConst8bitTable(float input_scale,
                                                  int32_t input_zp,
                                                  float output_scale,
                                                  int32_t output_zp,
                                                  std::function<float(float)> func)
{
    // TosaTableAttribute requires int16 vector input. However, TOSA TABLE legalizations are performed using int8.
    std::vector<int16_t> table;
    table.reserve(256);

    const float inverse_scale = 1.0f / output_scale;

    // Loop-invariant threshold: results at or above it are stored directly as INT8_MAX.
    const float max = (output_scale > 1.0f) ? FLT_MAX : (FLT_MAX * output_scale);

    for (int32_t i = -128; i < 128; i++)
    {
        const float dequantized = input_scale * static_cast<float>(i - input_zp);
        const float transformed = func(dequantized);

        if (transformed >= max)
        {
            table.push_back(INT8_MAX);
            continue;
        }

        const float rescaled = std::round(transformed * inverse_scale);
        if (std::isnan(rescaled))
        {
            // No meaningful quantized value exists; store real 0.0 (the zero point), saturated.
            table.push_back(static_cast<int16_t>(std::min(std::max(output_zp, -128), 127)));
            continue;
        }

        // Add the zero point and clamp in double: exact for every value that previously fitted
        // in int32_t, and well defined for those that did not (including +/-infinity).
        const double quantized = static_cast<double>(rescaled) + static_cast<double>(output_zp);
        const double clamped   = std::min(std::max(quantized, -128.0), 127.0);
        table.push_back(static_cast<int16_t>(clamped));
    }
    return table;
}
48 
// Adapted from getTosaConst16bitTable() in:
// tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc
//
// Builds the 513-entry lookup table for an int16 quantized input.
//
// The dequantized int16 input range [input_min, input_max] is split into 512 equal steps.
// Entries 0..511 hold func sampled at the start of each step, scaled by
// 65536 / (output_max - output_min) and rounded, then corrected by a bias equal to half of
// the (rounded) error that interpolating between neighbouring samples would make at the
// middle of the step. Entry 512 holds the scaled, rounded value of func(input_max).
// Every entry is saturated to [-32768, 32767].
//
// FloatT                   : floating-point type used for all intermediate arithmetic.
// input_scale / input_zp   : quantization parameters defining the dequantized input range.
// output_scale / output_zp : quantization parameters defining the output range.
// func                     : the real-valued function the table approximates. It is evaluated
//                            once per distinct argument within an iteration, so it is expected
//                            to return the same value for the same argument.
//
// NaN handling: std::min/std::max pass NaN through, and converting NaN to int16_t is undefined
// behaviour. A NaN bias (NaN from func, or inf - inf) is therefore treated as no correction,
// and a NaN entry is stored as 0.
template <typename FloatT>
inline std::vector<int16_t> getTosaConst16bitTable(float input_scale,
                                                   int32_t input_zp,
                                                   float output_scale,
                                                   int32_t output_zp,
                                                   std::function<FloatT(FloatT)> func)
{
    std::vector<int16_t> table;
    table.reserve(513);

    const FloatT input_min =
        input_scale * static_cast<FloatT>(std::numeric_limits<int16_t>::min() - input_zp);
    const FloatT input_max =
        input_scale * static_cast<FloatT>(std::numeric_limits<int16_t>::max() - input_zp);
    const FloatT output_min =
        output_scale * static_cast<FloatT>(std::numeric_limits<int16_t>::min() - output_zp);
    const FloatT output_max =
        output_scale * static_cast<FloatT>(std::numeric_limits<int16_t>::max() - output_zp);

    const FloatT step = (input_max - input_min) / 512;
    const FloatT half_step = step / 2;
    const FloatT output_scaling_inv = 65536 / (output_max - output_min);

    // Saturating conversion to int16_t that is well defined for every input, including NaN.
    const auto saturate = [](FloatT value) -> int16_t
    {
        if (std::isnan(value))
        {
            return 0;
        }
        return static_cast<int16_t>(
            std::min<FloatT>(std::max<FloatT>(value, -32768), 32767));
    };

    for (int32_t i = 0; i < 512; i++)
    {
        const FloatT iFloat = static_cast<FloatT>(i);

        // Scaled, rounded value of func at the start of this step.
        const FloatT sample_val =
            std::round(func(input_min + (iFloat * step)) * output_scaling_inv);

        // Value linear interpolation would produce at the middle of the step; it reuses the
        // rounded sample above rather than evaluating func at the same point a second time.
        const FloatT midpoint_interp_val = std::round(
            ((func(input_min + (iFloat + 1) * step) * output_scaling_inv) + sample_val) / 2);

        // Scaled, rounded value of func at the middle of the step.
        const FloatT midpoint_val =
            std::round(func(input_min + (iFloat * step) + half_step) * output_scaling_inv);

        const FloatT midpoint_err = midpoint_interp_val - midpoint_val;
        FloatT bias = std::round(midpoint_err / 2);
        if (std::isnan(bias))
        {
            // An undefined correction must not poison an otherwise valid sample.
            bias = 0;
        }

        table.push_back(saturate(sample_val - bias));
    }

    const FloatT max_val = std::round(func(input_max) * output_scaling_inv);
    table.push_back(saturate(max_val));
    return table;
}
getTosaConst8bitTable
std::vector< int16_t > getTosaConst8bitTable(float input_scale, int32_t input_zp, float output_scale, int32_t output_zp, std::function< float(float)> func)
Definition: TosaTableUtils.hpp:19
getTosaConst16bitTable
std::vector< int16_t > getTosaConst16bitTable(float input_scale, int32_t input_zp, float output_scale, int32_t output_zp, std::function< FloatT(FloatT)> func)
Definition: TosaTableUtils.hpp:52