27 #if defined(__ARM_FEATURE_SVE) 30 #define M_PI (3.14159265358979323846) 35 inline svfloat32_t svtaylor_poly_f32_z(svbool_t pg, svfloat32_t x,
const std::array<svfloat32_t, 8> &coeffs)
37 const auto A = svmla_f32_z(pg, coeffs[0], coeffs[4], x);
38 const auto B = svmla_f32_z(pg, coeffs[2], coeffs[6], x);
39 const auto C = svmla_f32_z(pg, coeffs[1], coeffs[5], x);
40 const auto D = svmla_f32_z(pg, coeffs[3], coeffs[7], x);
41 const auto x2 = svmul_f32_z(pg, x, x);
42 const auto x4 = svmul_f32_z(pg, x2, x2);
43 const auto res = svmla_f32_z(pg, svmla_f32_z(pg, A, B, x2), svmla_f32_z(pg, C, D, x2), x4);
47 inline svfloat16_t svtaylor_poly_f16_z(svbool_t pg, svfloat16_t x,
const std::array<svfloat16_t, 8> &coeffs)
49 const auto A = svmla_f16_z(pg, coeffs[0], coeffs[4], x);
50 const auto B = svmla_f16_z(pg, coeffs[2], coeffs[6], x);
51 const auto C = svmla_f16_z(pg, coeffs[1], coeffs[5], x);
52 const auto D = svmla_f16_z(pg, coeffs[3], coeffs[7], x);
53 const auto x2 = svmul_f16_z(pg, x, x);
54 const auto x4 = svmul_f16_z(pg, x2, x2);
55 const auto res = svmla_f16_z(pg, svmla_f16_z(pg, A, B, x2), svmla_f16_z(pg, C, D, x2), x4);
59 inline svfloat16_t svinv_f16_z(svbool_t pg, svfloat16_t x)
61 auto recip = svrecpe_f16(x);
62 recip = svmul_f16_z(pg, svrecps_f16(x, recip), recip);
63 recip = svmul_f16_z(pg, svrecps_f16(x, recip), recip);
67 inline svfloat32_t svinv_f32_z(svbool_t pg, svfloat32_t x)
69 auto recip = svrecpe_f32(x);
70 recip = svmul_f32_z(pg, svrecps_f32(x, recip), recip);
71 recip = svmul_f32_z(pg, svrecps_f32(x, recip), recip);
// Per-lane single-precision exponential under predicate pg.
// Splits x into m * ln(2) + val with integer m, evaluates a polynomial in val through
// svtaylor_poly_f32_z, and then scales by 2^m by adding m to the exponent bits of the result.
// NOTE(review): this listing carries stray line-number tokens and shows no braces or
// return statement; the gaps in the numbering suggest lines were dropped during extraction.
75 inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x)
// Broadcast constants: ln(2), 1 / ln(2), +infinity, the overflow threshold, zero and the
// exponent bound used for the underflow test.
77 const auto CONST_LN2 = svdup_n_f32(0.6931471805f);
78 const auto CONST_INV_LN2 = svdup_n_f32(1.4426950408f);
79 const auto CONST_INF = svdup_n_f32(std::numeric_limits<float>::infinity());
80 const auto CONST_MAX_INPUT = svdup_n_f32(88.7f);
81 const auto CONST_0 = svdup_n_f32(0.f);
82 const auto CONST_NEGATIVE_126 = svdup_n_s32(-126);
// Polynomial coefficients, in the interleaved order expected by svtaylor_poly_f32_z.
// NOTE(review): the array holds eight elements but only seven initialisers are visible
// (numbering jumps from 85 to 89); an entry, presumably the leading constant term, appears
// to be missing from this view - confirm against the upstream file.
85 const std::array<svfloat32_t, 8>
exp_tab =
89 svdup_n_f32(0.0416598916054f),
90 svdup_n_f32(0.500000596046f),
91 svdup_n_f32(0.0014122662833f),
92 svdup_n_f32(1.00000011921f),
93 svdup_n_f32(0.00833693705499f),
94 svdup_n_f32(0.166665703058f),
95 svdup_n_f32(0.000195780929062f),
// Range reduction: m = x / ln(2) converted to a signed integer, val = x - m * ln(2).
100 auto m = svcvt_s32_f32_z(pg, svmul_f32_z(pg, x, CONST_INV_LN2));
101 auto val = svmls_f32_z(pg, x, svcvt_f32_s32_z(pg, m), CONST_LN2);
// Polynomial approximation of exp(val).
104 auto poly = svtaylor_poly_f32_z(pg, val, exp_tab);
// Reconstruction: add (m << 23) to the raw bits of poly with a saturating add, i.e. add m
// to the float exponent field, which multiplies poly by 2^m.
107 poly = svreinterpret_f32_s32(svqadd_s32(svreinterpret_s32_f32(poly), svlsl_n_s32_z(pg, m, 23)));
// Underflow: lanes where m < -126 are forced to 0.
110 svbool_t ltpg = svcmplt_s32(pg, m, CONST_NEGATIVE_126);
111 poly = svsel_f32(ltpg, CONST_0, poly);
// Overflow: lanes where x > 88.7 are forced to +infinity.
114 svbool_t gtpg = svcmpgt_f32(pg, x, CONST_MAX_INPUT);
115 poly = svsel_f32(gtpg, CONST_INF, poly);
120 inline svfloat16_t svexp_f16_z(svbool_t pg, svfloat16_t x)
122 auto bottom = svcvt_f32_z(pg, x);
123 #if defined(__ARM_FEATURE_SVE2) 124 auto top = svcvtlt_f32_x(pg, x);
127 auto pg_top = svptrue_b16();
128 auto top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(x))));
131 bottom = svexp_f32_z(pg, bottom);
132 top = svexp_f32_z(pg_top, top);
134 #if defined(__ARM_FEATURE_SVE2) 135 return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top);
137 return svtrn1(svcvt_f16_z(pg, bottom), svcvt_f16_z(pg_top, top));
141 inline svfloat32_t svtanh_f32_z(svbool_t pg, svfloat32_t val)
143 const svfloat32_t CONST_1 = svdup_n_f32(1.f);
144 const svfloat32_t CONST_2 = svdup_n_f32(2.f);
145 const svfloat32_t CONST_MIN_TANH = svdup_n_f32(-10.f);
146 const svfloat32_t CONST_MAX_TANH = svdup_n_f32(10.f);
148 svfloat32_t x = svmin_f32_z(pg, svmax_f32_z(pg, val, CONST_MIN_TANH), CONST_MAX_TANH);
149 svfloat32_t exp2x = svexp_f32_z(pg, svmul_f32_z(pg, CONST_2, x));
150 svfloat32_t num = svsub_f32_z(pg, exp2x, CONST_1);
151 svfloat32_t den = svadd_f32_z(pg, exp2x, CONST_1);
152 svfloat32_t tanh = svdiv_f32_z(pg, num, den);
156 inline svfloat16_t svtanh_f16_z(svbool_t pg, svfloat16_t val)
158 const svfloat16_t CONST_1 = svdup_n_f16(1.f);
159 const svfloat16_t CONST_2 = svdup_n_f16(2.f);
160 const svfloat16_t CONST_MIN_TANH = svdup_n_f16(-10.f);
161 const svfloat16_t CONST_MAX_TANH = svdup_n_f16(10.f);
163 const svfloat16_t x = svmin_f16_z(pg, svmax_f16_z(pg, val, CONST_MIN_TANH), CONST_MAX_TANH);
164 const svfloat16_t exp2x = svexp_f16_z(pg, svmul_f16_z(pg, CONST_2, x));
165 const svfloat16_t num = svsub_f16_z(pg, exp2x, CONST_1);
166 const svfloat16_t den = svadd_f16_z(pg, exp2x, CONST_1);
167 const svfloat16_t tanh = svdiv_f16_z(pg, num, den);
171 inline svfloat32_t svlog_f32_z(svbool_t pg, svfloat32_t x)
174 const std::array<svfloat32_t, 8>
log_tab =
177 svdup_n_f32(-2.29561495781f),
178 svdup_n_f32(-2.47071170807f),
179 svdup_n_f32(-5.68692588806f),
180 svdup_n_f32(-0.165253549814f),
181 svdup_n_f32(5.17591238022f),
182 svdup_n_f32(0.844007015228f),
183 svdup_n_f32(4.58445882797f),
184 svdup_n_f32(0.0141278216615f),
188 const auto CONST_127 = svdup_n_s32(127);
189 const auto CONST_LN2 = svdup_n_f32(0.6931471805f);
192 auto m = svsub_s32_z(pg, svasr_n_s32_z(pg, svreinterpret_s32_f32(x), 23), CONST_127);
193 auto val = svreinterpret_f32_s32(svsub_s32_z(pg, svreinterpret_s32_f32(x), svlsl_n_s32_z(pg, m, 23)));
196 auto poly = svtaylor_poly_f32_z(pg, val, log_tab);
199 poly = svmla_f32_z(pg, poly, svcvt_f32_s32_z(pg, m), CONST_LN2);
204 inline svfloat16_t svlog_f16_z(svbool_t pg, svfloat16_t x)
206 auto bottom = svcvt_f32_z(pg, x);
207 #if defined(__ARM_FEATURE_SVE2) 208 auto top = svcvtlt_f32_x(pg, x);
211 auto pg_top = svptrue_b16();
212 auto top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(x))));
215 bottom = svlog_f32_z(pg, bottom);
216 top = svlog_f32_z(pg_top, top);
218 #if defined(__ARM_FEATURE_SVE2) 219 return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top);
221 return svtrn1(svcvt_f16_z(pg, bottom), svcvt_f16_z(pg_top, top));
// Per-lane single-precision sine under predicate pg.
// Reduces |val| into [0, pi/2] using the periodicity and symmetry of sine, evaluates a short
// alternating series in the reduced argument, and finally negates the lanes whose sign flips.
// NOTE(review): this listing carries stray line-number tokens and shows no braces or return
// statement; the numbering jumps from 228 to 235, so some lines are missing from this view.
// te_sin_coeff2..te_sin_coeff5 are not defined in the visible block - confirm where they come from.
225 inline svfloat32_t svsin_f32_z(svbool_t pg, svfloat32_t val)
227 using ScalarType = float;
228 using IntType = uint32_t;
// Broadcast constants: pi, pi / 2 and 1 / pi.
235 const auto pi_v = wrapper::svdup_n(ScalarType(
M_PI));
236 const auto pio2_v = wrapper::svdup_n(ScalarType(
M_PI / 2));
237 const auto ipi_v = wrapper::svdup_n(ScalarType(1 /
M_PI));
// c_v: |val / pi| converted to int32, i.e. how many multiples of pi to remove.
240 const auto c_v = svabs_z(pg, wrapper::svcvt_z<int32_t>(pg, svmul_z(pg, val, ipi_v)));
// sign_v: lanes where val <= 0.
241 const auto sign_v = svcmple(pg, val, wrapper::svdup_n(ScalarType(0)));
// odd_v: lanes where c_v is odd (lowest bit set).
242 const auto odd_v = svcmpne(pg, svand_z(pg, wrapper::svreinterpret<IntType>(c_v), wrapper::svdup_n(IntType(1))), wrapper::svdup_n(IntType(0)));
// neg_v: lanes whose result is negated at the end - exactly one of "odd multiple" and
// "non-positive input" holds.
244 auto neg_v = sveor_z(pg, odd_v, sign_v);
// ma = |val| - pi * c_v, the remainder after removing whole multiples of pi.
247 auto ma = svsub_z(pg, svabs_z(pg, val), svmul_z(pg, pi_v, wrapper::svcvt_z<ScalarType>(pg, c_v)));
248 const auto reb_v = svcmpge(pg, ma, pio2_v);
// Where ma >= pi / 2, reflect it to pi - ma.
251 ma = svsel(reb_v, svsub_z(pg, pi_v, ma), ma);
254 const auto ma2 = svmul_z(pg, ma, ma);
// Series evaluation: each term is the previous term times ma^2 times the next coefficient,
// and the terms are subtracted and added alternately, starting from res = ma - ma^3 * te_sin_coeff2.
257 auto elem = svmul_z(pg, svmul_z(pg, ma, ma2), wrapper::svdup_n(ScalarType(te_sin_coeff2)));
258 auto res = svsub_z(pg, ma, elem);
261 elem = svmul_z(pg, svmul_z(pg, elem, ma2), wrapper::svdup_n(ScalarType(te_sin_coeff3)));
262 res = svadd_z(pg, res, elem);
265 elem = svmul_z(pg, svmul_z(pg, elem, ma2), wrapper::svdup_n(ScalarType(te_sin_coeff4)));
266 res = svsub_z(pg, res, elem);
269 elem = svmul_z(pg, svmul_z(pg, elem, ma2), wrapper::svdup_n(ScalarType(te_sin_coeff5)));
270 res = svadd_z(pg, res, elem);
// Negate the lanes selected by neg_v; the other lanes keep res unchanged.
273 res = svneg_m(res, neg_v, res);
277 inline svfloat16_t svsin_f16_z(svbool_t pg, svfloat16_t val)
279 auto bottom = svcvt_f32_z(pg, val);
280 #if defined(__ARM_FEATURE_SVE2) 281 auto top = svcvtlt_f32_x(pg, val);
284 auto pg_top = svptrue_b16();
285 auto top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(val))));
288 bottom = svsin_f32_z(pg, bottom);
289 top = svsin_f32_z(pg_top, top);
291 #if defined(__ARM_FEATURE_SVE2) 292 return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top);
294 return svtrn1(svcvt_f16_z(pg, bottom), svcvt_f16_z(pg_top, top));
298 inline svfloat32_t svpow_f32_z(svbool_t pg, svfloat32_t a, svfloat32_t
b)
300 return svexp_f32_z(pg, svmul_z(pg, b, svlog_f32_z(pg, a)));
303 inline svfloat16_t svpow_f16_z(svbool_t pg, svfloat16_t a, svfloat16_t b)
305 auto a_bottom = svcvt_f32_z(pg, a);
306 auto b_bottom = svcvt_f32_z(pg, b);
308 #if defined(__ARM_FEATURE_SVE2) 310 auto a_top = svcvtlt_f32_x(pg, a);
311 auto b_top = svcvtlt_f32_x(pg, b);
313 auto pg_top = svptrue_b16();
314 auto a_top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(a))));
315 auto b_top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(b))));
318 auto res_bottom = svpow_f32_z(pg, a_bottom, b_bottom);
319 auto res_top = svpow_f32_z(pg_top, a_top, b_top);
321 #if defined(__ARM_FEATURE_SVE2) 322 return svcvtnt_f16_m(svcvt_f16_z(pg, res_bottom), pg_top, res_top);
324 return svtrn1(svcvt_f16_z(pg, res_bottom), svcvt_f16_z(pg_top, res_top));
constexpr float te_sin_coeff5
Copyright (c) 2017-2021 Arm Limited.
const std::array< float32x4_t, 8 > exp_tab
Exponent polynomial coefficients.
constexpr float te_sin_coeff3
constexpr float te_sin_coeff4
constexpr float te_sin_coeff2
Sin polynomial coefficients.
const std::array< float32x4_t, 8 > log_tab
Logarithm polynomial coefficients.