Compute Library
 23.08
utils.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2017-2023 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
25 #pragma once
26 
28 
29 #include <cstddef>
30 #include <limits>
31 #include <tuple>
32 
// Macro for unreachable code (e.g. impossible default cases on switch).
//
// The 'why' argument is documentation for the reader only; none of the definitions below evaluate it.
// The compiler builtin is selected per toolchain so that the header does not depend on a GNU-only
// builtin being available everywhere.
#if defined(__GNUC__) || defined(__clang__)
#define UNREACHABLE(why) __builtin_unreachable()
#elif defined(_MSC_VER)
#define UNREACHABLE(why) __assume(0)
#else
// No known "unreachable" hint on this compiler: expand to a no-op.
#define UNREACHABLE(why) ((void)0)
#endif

// Paranoid option for the above with assert
// #define UNREACHABLE(why) assert(0 && why)
38 
39 namespace arm_gemm {
40 
// get_type_name(): Returns a short name for type "T", extracted from the compiler-generated function signature.
//
// The signature string is searched for the marker "cls_"; the text between that marker and the next ';' or ']'
// is returned.  If the marker or the terminator cannot be found the result is "(unknown)".  On compilers that do
// not define __GNUC__ the result is always "(unsupported)".
template<typename T>
std::string get_type_name() {
#ifdef __GNUC__
    const std::string signature = __PRETTY_FUNCTION__;

    const auto marker_pos = signature.find("cls_");
    if (marker_pos == std::string::npos) {
        return "(unknown)";
    }

    // Skip over the 4-character marker itself.
    const auto name_begin = marker_pos + 4;

    // The name runs up to the first ';' or ']' that follows the marker.
    const auto name_end = signature.find_first_of(";]", name_begin);
    if (name_end == std::string::npos) {
        return "(unknown)";
    }

    return signature.substr(name_begin, name_end - name_begin);
#else
    return "(unsupported)";
#endif
}
63 
// iceildiv(): Integer division of "a" by "b", rounded up.
//
// Computed from the quotient and remainder rather than as (a + b - 1) / b, so that the intermediate value cannot
// wrap or overflow when "a" is close to the maximum of T.  For non-negative "a" and positive "b" the result equals
// that of the additive formula whenever that formula does not overflow.  For negative "a" (with positive "b") this
// returns the mathematical ceiling.  "b" must be non-zero.
template<typename T>
inline T iceildiv(const T a, const T b) {
    const T quotient  = a / b;
    const T remainder = a % b;

    // A positive remainder means the truncated quotient is below the exact value, so round up by one.
    return (remainder > 0) ? static_cast<T>(quotient + 1) : quotient;
}
68 
// roundup(): Returns "a" rounded up to the next multiple of "b"; "a" is returned unchanged if it is already a
// multiple.  "b" must be non-zero.
template <typename T>
inline T roundup(const T a, const T b) {
    const T excess = a % b;

    // Already aligned: nothing to add.
    if (!excess) {
        return a;
    }

    // Step to the next multiple by adding what is missing from a full "b".
    return a + b - excess;
}
79 
// Selector for the vector-length query used by get_vector_length(VLType).
// NOTE(review): "None" presumably denotes fixed-length (non-scalable) vectors - confirm against callers.
enum class VLType {
    None = 0,
    SVE  = 1,
    SME  = 2,
    SME2 = 3
};
86 
// Describes an output location in one of two forms, selected by "is_indirect":
//  - direct:   a base pointer plus a stride;
//  - indirect: a pointer to an array of pointers plus an offset.
// Both sub-structures are value-initialized, so the form that is not in use holds null/zero.
// NOTE(review): the units of "stride" and "offset" are not visible here - confirm against the kernels.
template<typename T>
struct IndirectOutputArg {
    struct {
        T      *base;
        size_t  stride;
    } direct = {};
    struct {
        T * const *ptr;
        size_t     offset;
    } indirect = {};
    bool is_indirect;

    // Direct
    IndirectOutputArg(T *base, size_t stride) : is_indirect(false) {
        direct.base = base;
        direct.stride = stride;
    }

    // Indirect
    IndirectOutputArg(T * const * ptr, size_t offset) : is_indirect(true) {
        indirect.ptr = ptr;
        indirect.offset = offset;
    }

    // Default: direct form with a null base and zero stride.
    IndirectOutputArg() : is_indirect(false) {
        direct.base = nullptr;
        direct.stride = 0;
    }
};
116 
117 // Check that the provided Requantize32 doesn't have a left shift.
118 inline bool quant_no_left_shift(const Requantize32 &qp) {
119  if (qp.per_channel_requant) {
120  return (qp.per_channel_left_shifts == nullptr);
121  } else {
122  return (qp.per_layer_left_shift == 0);
123  }
124 }
125 
126 // Check that the provided Requantize32 is compatible with the "symmetric" hybrid kernels. These don't include row
127 // sums, so the 'b_offset' has to be zero.
128 inline bool quant_hybrid_symmetric(const Requantize32 &qp) {
129  return quant_no_left_shift(qp) && qp.b_offset == 0;
130 }
131 
132 // Check that the provided Requantize32 is compatible with the "asymmetric" hybrid kernels. These don't support per
133 // channel quantization. Technically b_offset==0 cases would work, but it is a waste to sum and then multiply by 0...
134 inline bool quant_hybrid_asymmetric(const Requantize32 &qp) {
135  return quant_no_left_shift(qp) /* && qp.b_offset != 0 */ && qp.per_channel_requant==false;
136 }
137 
// Describes an input location in one of two forms, selected by "is_indirect":
//  - direct:   a base pointer plus a stride;
//  - indirect: a pointer to an array of pointer arrays, plus a starting row and starting column.
// Both sub-structures are value-initialized, so the form that is not in use holds null/zero.
// NOTE(review): the units of "stride" and the exact meaning of start_row/start_col are not visible here -
// confirm against the kernels.
template<typename T>
struct IndirectInputArg {
    struct {
        const T *base;
        size_t   stride;
    } direct = {};
    struct {
        const T * const * const * ptr;
        unsigned int start_row;
        unsigned int start_col;
    } indirect = {};
    bool is_indirect;

    // Direct
    IndirectInputArg(const T *base, size_t stride) : is_indirect(false) {
        direct.base = base;
        direct.stride = stride;
    }

    // Indirect
    IndirectInputArg(const T * const * const *ptr, unsigned int start_row, unsigned int start_col) : is_indirect(true) {
        indirect.ptr = ptr;
        indirect.start_row = start_row;
        indirect.start_col = start_col;
    }

    // Default: direct form with a null base and zero stride.
    IndirectInputArg() : is_indirect(false) {
        direct.base = nullptr;
        direct.stride = 0;
    }
};
169 
170 namespace utils {
171 
// get_vector_length(): Returns SVE vector length for type "T", i.e. the vector size in bytes divided by sizeof(T).
//
// This must be compilable by a compiler that is not in SVE mode, but it must not be executed at runtime unless SVE
// is enabled.  The typical user is switchyard/driver code built in normal mode, which calls SVE kernels (compiled
// accordingly) iff SVE is detected at runtime.
//
// On non-AArch64 builds a 16-byte vector is assumed.
template <typename T>
inline unsigned long get_vector_length() {
#if defined(__aarch64__)
    uint64_t vl_bytes;

    // The instruction is emitted as a raw encoding and always writes X0, hence the explicit copy into the
    // output operand and the "x0" clobber.
    __asm __volatile (
        ".inst 0x0420e3e0\n" // CNTB X0, ALL, MUL #1
        "mov %0, X0\n"
        : "=r" (vl_bytes)
        :
        : "x0"
    );

    return vl_bytes / sizeof(T);
#else // !defined(__aarch64__)
    return 16 / sizeof(T);
#endif // defined(__aarch64__)
}
195 
#ifdef ARM_COMPUTE_ENABLE_SME
namespace sme {

// Implemented out of line in misc-sve.cpp.
// NOTE(review): presumably returns the vector length in bytes - confirm against the definition.
unsigned int raw_vector_length();

// get_vector_length(): Returns the value of raw_vector_length() divided by sizeof(T).
template <typename T>
inline unsigned long get_vector_length() {
    const unsigned int raw = raw_vector_length();

    return raw / sizeof(T);
}

} // namespace sme
#endif // ARM_COMPUTE_ENABLE_SME
209 
210 // get_vector_length(VLType): Returns vector length for type "T".
211 //
212 // This has the same requirements and constraints as the SVE-only form above, so we call into that code for SVE.
213 
214 template <typename T>
215 inline unsigned long get_vector_length(VLType vl_type) {
216  switch (vl_type) {
217 #ifdef ARM_COMPUTE_ENABLE_SME
218  case VLType::SME:
219  return sme::get_vector_length<T>();
220 #endif // ARM_COMPUTE_ENABLE_SME
221  case VLType::SVE:
222  return get_vector_length<T>();
223  default:
224  return 16 / sizeof(T);
225  }
226 }
227 
// get_default_activation_values(): Returns the default (min, max) activation pair for integer activation, i.e. the
// full representable range of "T" as reported by std::numeric_limits.
template <typename T>
inline std::tuple<T, T> get_default_activation_values()
{
    using limits = std::numeric_limits<T>;

    return std::tuple<T, T>(static_cast<T>(limits::min()), static_cast<T>(limits::max()));
}

// get_default_activation_values(): Returns the default (min, max) activation pair for float activation, which is
// (-infinity, +infinity).
template <>
inline std::tuple<float, float> get_default_activation_values()
{
    const float inf = std::numeric_limits<float>::infinity();

    return std::tuple<float, float>(-inf, inf);
}

#if defined(__ARM_FP16_ARGS)
// get_default_activation_values(): Returns the default (min, max) activation pair for __fp16 activation, which is
// the float (-infinity, +infinity) pair converted to __fp16.
template <>
inline std::tuple<__fp16, __fp16> get_default_activation_values()
{
    const float inf = std::numeric_limits<float>::infinity();

    return std::tuple<__fp16, __fp16>(static_cast<__fp16>(-inf), static_cast<__fp16>(inf));
}
#endif // defined(__ARM_FP16_ARGS)
259 } // utils namespace
260 } // arm_gemm namespace
261 
262 using namespace arm_gemm::utils;
arm_gemm::IndirectOutputArg::IndirectOutputArg
IndirectOutputArg()
Definition: utils.hpp:111
arm_gemm::IndirectInputArg::stride
size_t stride
Definition: utils.hpp:142
arm_gemm::IndirectOutputArg::ptr
T *const * ptr
Definition: utils.hpp:94
arm_gemm::IndirectOutputArg::stride
size_t stride
Definition: utils.hpp:91
arm_gemm::VLType::SME2
@ SME2
arm_gemm::IndirectInputArg::IndirectInputArg
IndirectInputArg(const T *base, size_t stride)
Definition: utils.hpp:152
arm_gemm::IndirectOutputArg::is_indirect
bool is_indirect
Definition: utils.hpp:97
arm_gemm::utils::get_default_activation_values
std::tuple< T, T > get_default_activation_values()
Definition: utils.hpp:230
arm_gemm::IndirectInputArg::IndirectInputArg
IndirectInputArg()
Definition: utils.hpp:164
arm_gemm::quant_no_left_shift
bool quant_no_left_shift(const Requantize32 &qp)
Definition: utils.hpp:118
arm_gemm::roundup
T roundup(const T a, const T b)
Definition: utils.hpp:70
arm_gemm::quant_hybrid_symmetric
bool quant_hybrid_symmetric(const Requantize32 &qp)
Definition: utils.hpp:128
arm_gemm::VLType::SME
@ SME
arm_gemm::IndirectOutputArg::base
T * base
Definition: utils.hpp:90
arm_gemm::IndirectInputArg
Definition: utils.hpp:139
arm_gemm::get_type_name
std::string get_type_name()
Definition: utils.hpp:42
arm_gemm::IndirectInputArg::direct
struct arm_gemm::IndirectInputArg::@3 direct
arm_gemm::Requantize32::per_channel_requant
bool per_channel_requant
Definition: arm_gemm.hpp:177
arm_gemm::utils::get_vector_length
unsigned long get_vector_length()
Definition: utils.hpp:178
arm_gemm::Requantize32::b_offset
int32_t b_offset
Definition: arm_gemm.hpp:175
arm_gemm::IndirectInputArg::start_col
unsigned int start_col
Definition: utils.hpp:147
arm_gemm::IndirectOutputArg::IndirectOutputArg
IndirectOutputArg(T *base, size_t stride)
Definition: utils.hpp:100
arm_gemm::IndirectInputArg::IndirectInputArg
IndirectInputArg(const T *const *const *ptr, unsigned int start_row, unsigned int start_col)
Definition: utils.hpp:158
arm_gemm::Requantize32::per_layer_left_shift
int32_t per_layer_left_shift
Definition: arm_gemm.hpp:178
arm_gemm
Definition: barrier.hpp:30
arm_gemm::Requantize32::per_channel_left_shifts
const int32_t * per_channel_left_shifts
Definition: arm_gemm.hpp:181
arm_gemm::IndirectInputArg::base
const T * base
Definition: utils.hpp:141
arm_gemm::IndirectOutputArg::direct
struct arm_gemm::IndirectOutputArg::@1 direct
arm_gemm.hpp
arm_gemm::quant_hybrid_asymmetric
bool quant_hybrid_asymmetric(const Requantize32 &qp)
Definition: utils.hpp:134
arm_gemm::VLType::SVE
@ SVE
arm_gemm::iceildiv
T iceildiv(const T a, const T b)
Definition: utils.hpp:65
arm_gemm::IndirectOutputArg
Definition: utils.hpp:88
arm_gemm::VLType::None
@ None
arm_gemm::VLType
VLType
Definition: utils.hpp:80
arm_compute::test::validation::b
SimpleTensor< float > b
Definition: DFT.cpp:157
arm_gemm::IndirectOutputArg::IndirectOutputArg
IndirectOutputArg(T *const *ptr, size_t offset)
Definition: utils.hpp:106
arm_gemm::IndirectOutputArg::indirect
struct arm_gemm::IndirectOutputArg::@2 indirect
arm_gemm::IndirectInputArg::ptr
const T *const *const * ptr
Definition: utils.hpp:145
arm_gemm::IndirectOutputArg::offset
size_t offset
Definition: utils.hpp:95
arm_gemm::IndirectInputArg::start_row
unsigned int start_row
Definition: utils.hpp:146
arm_gemm::Requantize32
Definition: arm_gemm.hpp:169
arm_gemm::IndirectInputArg::is_indirect
bool is_indirect
Definition: utils.hpp:149
arm_gemm::utils
Definition: misc-sve.cpp:27
arm_gemm::IndirectInputArg::indirect
struct arm_gemm::IndirectInputArg::@4 indirect