ArmNN
 25.02
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
TosaRescaleOperatorUtils.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <armnn/Exceptions.hpp>
7 
8 #pragma once
9 
10 inline void CreateRawRescaleTosaOperator(const std::string& inputName,
11  const std::string& outputName,
12  const std::vector<int32_t>& multipliers,
13  const std::vector<int32_t>& shifts,
14  int32_t input_zp,
15  int32_t output_zp,
16  bool input_unsigned,
17  bool output_unsigned,
18  bool double_round,
19  bool scale32,
20  bool per_channel,
21  TosaSerializationOperator** op)
22 {
23  if (!op)
24  {
25  throw armnn::Exception("CreateRawRescaleTosaOperator: nullptr op.");
26  }
27 
28  if (multipliers.empty())
29  {
30  throw armnn::Exception("CreateRawRescaleTosaOperator: multipliers is empty.");
31  }
32 
33  if (multipliers.size() != shifts.size())
34  {
35  throw armnn::Exception("CreateRawRescaleTosaOperator: multipliers and shift not same size.");
36  }
37 
38  if (multipliers.size() == 1 && per_channel)
39  {
40  throw armnn::Exception("CreateRawRescaleTosaOperator: \
41  multipliers must be greater than 1 if per_channel is true.");
42  }
43 
44  if (multipliers.size() == 1 && per_channel)
45  {
46  throw armnn::Exception("CreateRawRescaleTosaOperator: \
47  multipliers size must be greater than 1 if per_channel is true.");
48  }
49 
50  if (multipliers.size() > 1 && !per_channel)
51  {
52  throw armnn::Exception("CreateRawRescaleTosaOperator: \
53  multipliers size must be 1 if per_channel is false.");
54  }
55 
56  TosaRescaleAttribute attribute(input_zp,
57  output_zp,
58  multipliers,
59  shifts,
60  scale32,
61  double_round,
62  per_channel,
63  input_unsigned,
64  output_unsigned);
65 
66  // op
67  *op = new TosaSerializationOperator(Op_RESCALE, Attribute_RescaleAttribute, &attribute, {inputName}, {outputName});
68  if (!(*op))
69  {
70  throw armnn::Exception("CreateRescaleTosaOperator: failed to created operator");
71  }
72 }
73 
74 /// The following is taken from mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp in the LLVM project
75 /// From a scale value, generates multiplier and shift values where
76 /// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that
77 /// multiplier = mantissa*2^shift for 32-bit scaling.
78 inline void ComputeMultiplierAndShiftTosaScale32(double scale,
79  int32_t &multiplier,
80  int32_t &shift)
81 {
82  const double mantissa = std::frexp(scale, &shift);
83  auto shiftedM = std::round(mantissa * (int64_t(1) << 31));
84 
85  // Can't be greater than 1.0.
86  if (!(shiftedM <= (int64_t(1) << 31)))
87  {
88  throw armnn::Exception("Shifted mantissa exceeds 32 signed bits");
89  }
90 
91  if (shiftedM == (int64_t(1) << 31))
92  {
93  shiftedM /= 2;
94  shift++;
95  }
96 
97  // TOSA expects right shift to be positive, and embed (1 << 31) into right
98  // shift bits.
99  shift = (-shift) + 31;
100 
101  if (!(shiftedM <= std::numeric_limits<int32_t>::max()))
102  {
103  throw armnn::Exception("Shifted mantissa exceeds 32-bit signed output type");
104  }
105 
106  multiplier = static_cast<int32_t>(shiftedM);
107 
108  // Shifting tops out at 62 bits. Right shift to make 62 bits the max.
109  // The limit of 62 on shift allows the shift to be decomposed as
110  // two right shifts of 31.
111  if (shift > 62)
112  {
113  // Shifting the multiplier by more than 32-bits is unnecessary.
114  multiplier = multiplier >> std::min<int32_t>(31, shift - 62);
115  shift = 62;
116  }
117 }
118 
119 /// The following is taken from mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp in the LLVM project
120 /// From a scale value, generates multiplier and shift values where
121 /// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that
122 /// multiplier = mantissa*2^shift for 16-bit scaling.
123 inline void ComputeMultiplierAndShiftTosaScale16(double scale,
124  int32_t &multiplier,
125  int32_t &shift)
126 {
127  const double mantissa = std::frexp(scale, &shift);
128  auto shiftedM = std::round(mantissa * (int64_t(1) << 15));
129 
130  // Can't be greater than 1.0.
131  if (!(shiftedM <= (int64_t(1) << 15)))
132  {
133  throw armnn::Exception("Shifted mantissa exceeds 16 signed bits");
134  }
135 
136  if (shiftedM == (int64_t(1) << 15))
137  {
138  shiftedM /= 2;
139  shift++;
140  }
141 
142  // TOSA expects right shift to be positive and embed (1 << 15) into right
143  // shift bits.
144  shift = (-shift) + 15;
145 
146  if (!(shiftedM <= std::numeric_limits<int32_t>::max()))
147  {
148  throw armnn::Exception("Shifted mantissa exceeds 32-bit signed output type");
149  }
150 
151  multiplier = static_cast<int32_t>(shiftedM);
152 
153  // Shifting tops out at 62 bits. Right shift to make 62 bits the max.
154  // The limit of 62 on shift allows the shift to be decomposed as
155  // two right shifts of 31.
156  if (shift > 62)
157  {
158  // Shifting the multiplier by more than 31-bits is unnecessary.
159  multiplier = multiplier >> std::min<int32_t>(31, shift - 62);
160  shift = 62;
161  }
162 }
163 
164 inline void CreateRescaleTosaOperator(const std::string& inputName,
165  const std::string& outputName,
166  double scale,
167  int32_t input_zp,
168  int32_t output_zp,
169  bool input_unsigned,
170  bool output_unsigned,
171  bool double_round,
172  bool scale32,
173  TosaSerializationOperator** op)
174 {
175  int32_t multiplier;
176  int32_t shift;
177 
178  if (scale32)
179  {
180  ComputeMultiplierAndShiftTosaScale32(scale, multiplier, shift);
181  }
182  else
183  {
184  ComputeMultiplierAndShiftTosaScale16(scale, multiplier, shift);
185  }
186 
187  const std::vector<int32_t> multipliers{multiplier};
188  const std::vector<int32_t> shifts{shift};
189  CreateRawRescaleTosaOperator(inputName, outputName, multipliers, shifts,
190  input_zp, output_zp, input_unsigned, output_unsigned,
191  double_round, scale32, false, op);
192 }
193 
194 inline void CreateRescaleTosaOperatorForWeights(const std::string& inputName,
195  const std::string& outputName,
196  int32_t input_zp,
197  int32_t output_zp,
198  bool input_unsigned,
199  bool output_unsigned,
200  bool double_round,
201  bool scale32,
202  double input_scale,
203  double output_scale,
204  const std::vector<float>& weight_scales,
205  TosaSerializationOperator** op)
206 {
207  std::vector<int32_t> op_tensor_multipliers;
208  std::vector<int32_t> op_tensor_shifts;
209  op_tensor_multipliers.reserve(weight_scales.size());
210  op_tensor_shifts.reserve(weight_scales.size());
211 
212  for (const float& weight_scale : weight_scales)
213  {
214  double op_tensor_scale = (input_scale * weight_scale) / output_scale;
215  int32_t multiplier;
216  int32_t shift;
217 
218  if (scale32)
219  {
220  ComputeMultiplierAndShiftTosaScale32(op_tensor_scale, multiplier, shift);
221  }
222  else
223  {
224  ComputeMultiplierAndShiftTosaScale16(op_tensor_scale, multiplier, shift);
225  }
226 
227  op_tensor_multipliers.push_back(multiplier);
228  op_tensor_shifts.push_back(shift);
229  }
230 
231  bool per_channel = weight_scales.size() == 1 ? false : true;
232  CreateRawRescaleTosaOperator(inputName, outputName, op_tensor_multipliers, op_tensor_shifts,
233  input_zp, output_zp, input_unsigned, output_unsigned, double_round,
234  scale32, per_channel, op);
235 }
void ComputeMultiplierAndShiftTosaScale16(double scale, int32_t &multiplier, int32_t &shift)
The following is taken from mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp in the LLVM project From a sca...
void CreateRawRescaleTosaOperator(const std::string &inputName, const std::string &outputName, const std::vector< int32_t > &multipliers, const std::vector< int32_t > &shifts, int32_t input_zp, int32_t output_zp, bool input_unsigned, bool output_unsigned, bool double_round, bool scale32, bool per_channel, TosaSerializationOperator **op)
void CreateRescaleTosaOperator(const std::string &inputName, const std::string &outputName, double scale, int32_t input_zp, int32_t output_zp, bool input_unsigned, bool output_unsigned, bool double_round, bool scale32, TosaSerializationOperator **op)
void ComputeMultiplierAndShiftTosaScale32(double scale, int32_t &multiplier, int32_t &shift)
The following is taken from mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp in the LLVM project From a sca...
void CreateRescaleTosaOperatorForWeights(const std::string &inputName, const std::string &outputName, int32_t input_zp, int32_t output_zp, bool input_unsigned, bool output_unsigned, bool double_round, bool scale32, double input_scale, double output_scale, const std::vector< float > &weight_scales, TosaSerializationOperator **op)
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:47