// NOTE(review): this file is a fragmented/sampled excerpt — original source
// lines are missing between the numbered statements; comments are best-effort.
// Compute each input's requantization scale relative to the output scale and
// decide whether the integer fixed-point kernel can be used.
52 const auto oq =
dst->quantization_info().uniform();
54 const auto scale0 = iq0.scale / oq.scale;
55 const auto scale1 = iq1.scale / oq.scale;
// Scales must fit the signed 5-integer-bit / 11-fractional-bit (5p11) format
// used by the fixed-point kernel — reject anything outside [-15, 15].
// (The rejection branch body is not visible in this excerpt.)
57 if (scale0 < -15.f || scale0 > 15.f || scale1 < -15.f || scale1 > 15.f)
// Fold both input zero-points and the output offset into one float bias.
63 const auto offset = float(oq.offset) - scale0 * float(iq0.offset) - scale1 * float(iq1.offset);
// Worst-case accumulator magnitude for 8-bit inputs.
// NOTE(review): the subtraction branch uses '-' between the two |scale| terms,
// giving a SMALLER bound than the '+' used for addition; since subtraction is
// elsewhere implemented as addition with a negated scale, '+' would seem to be
// the safe bound in both cases — confirm against the full source before use.
65 const auto max_acc = is_addition ? ((std::abs(scale0) + std::abs(scale1)) * 256.f + std::abs(
offset))
66 : ((std::abs(scale0) - std::abs(scale1)) * 256.f + std::abs(
offset));
// 1048575 = 2^20 - 1: the accumulator must fit the 21p11 fixed-point format
// (21 integer bits) used by the vector loops below.
68 if (max_acc > 1048575.f)
// Thin dispatch wrapper: forwards to the fixed-point implementation with the
// trailing boolean argument (presumably is_addition) hard-coded to true.
// (The wrapper's own signature lines are missing from this excerpt.)
77 template <
typename ScalarType>
81 add_sub_q8_neon_fixedpoint<ScalarType>(src0, src1,
dst, policy, window,
true );
// add_sub_q8_neon_fixedpoint — quantized elementwise add/sub computed in
// integer fixed-point arithmetic. This span is the setup phase; signature,
// braces and several statements are missing from this excerpt.
84 template <
typename ScalarType>
94 const auto in0_info = src0->
info();
95 const auto in1_info = src1->
info();
98 const auto &in1_shape = in1_info->tensor_shape();
// 16 elements (one 128-bit register of 8-bit lanes) per vector iteration.
108 constexpr
int window_step_x = 16;
109 const auto window_start_x = window.
x().
start();
110 const auto window_end_x = window.
x().
end();
// Differing x-extents mean one of the inputs is broadcast along x.
111 const auto is_broadcast_across_x = in0_shape.x() != in1_shape.x();
113 const auto iq0_info = in0_info->quantization_info().uniform();
114 const auto iq1_info = in1_info->quantization_info().uniform();
115 const auto oq_info =
dst->info()->quantization_info().uniform();
// Requantization scales relative to the output scale; subtraction is
// implemented as addition with a negated second scale.
116 const auto in0_scale = iq0_info.scale / oq_info.scale;
117 const auto in1_scale = is_addition ? (iq1_info.scale / oq_info.scale) : (-(iq1_info.scale / oq_info.scale));
// Combined bias folding in both zero-points and the output offset.
118 const auto offset = float(oq_info.offset) - in0_scale * float(iq0_info.offset) - in1_scale * float(iq1_info.offset);
// 2^11: conversion factor to the Q formats with 11 fractional bits used
// below (the "5p11" / "21p11" variable names elsewhere in this file).
120 constexpr
float _2pow11 = 2048;
// The 11 fractional bits are removed in two saturating rounding narrows:
// first by 3 (21p11 -> 8p8 via vqrshrn_ex), then by 8 (8p8 -> 8p0).
125 constexpr uint8_t shift_amount_remainder = 3;
// Broadcast-along-x path: the input whose window x-step is 0 is treated as a
// per-row scalar ("b"); the other tensor ("a") is processed vector-wide.
// (Excerpt — braces, #ifdef __aarch64__ guards and several statements are
// missing between the numbered lines; only #endif markers remain visible.)
127 if (is_broadcast_across_x)
131 const auto is_broadcast_input_1 = in1_win.
x().
step() == 0;
// Swap roles so "a" is always the non-broadcast (vector) input.
132 auto a_win = is_broadcast_input_1 ? in0_win : in1_win;
133 auto b_win = is_broadcast_input_1 ? in1_win : in0_win;
134 const auto a_tensor = is_broadcast_input_1 ? src0 : src1;
135 const auto b_tensor = is_broadcast_input_1 ? src1 : src0;
// Fixed-point (5p11) scale for the vector input; float scale for "b".
137 const auto a_scale_5p11 = is_broadcast_input_1 ? in0_scale_5p11 : in1_scale_5p11;
138 const auto b_scale = is_broadcast_input_1 ? in1_scale : in0_scale;
// Float scale for the non-__aarch64__ scalar fallback, per the #endif below.
142 const auto a_scale = is_broadcast_input_1 ? in0_scale : in1_scale;
143 #endif // __aarch64__
148 Iterator a_input_it(a_tensor, a_win);
149 Iterator b_input_it(b_tensor, b_win);
156 const auto a_ptr =
reinterpret_cast<const ScalarType *
>(a_input_it.
ptr());
157 const auto b_ptr =
reinterpret_cast<const ScalarType *
>(b_input_it.
ptr());
158 const auto out_ptr =
reinterpret_cast<ScalarType *
>(out_it.
ptr());
// Pre-combine the broadcast scalar with the bias once per row:
// b*scale + offset, held in 21p11 fixed point for the vector loop.
160 const auto b_val = *b_ptr;
161 const auto b_scaled = b_scale * b_val;
163 const auto b_scaled_offseted_21p11 = b_scaled_21p11 + offset_21p11;
164 const auto b_vscaled_offseted_21p11 =
168 const auto b_scaled_offseted = b_scaled +
offset;
169 #endif // __aarch64__
171 int x = window_start_x;
// Vector loop: 16 outputs per iteration. The right-hand sides of the four
// 21p11 accumulators are missing from this excerpt.
173 for (; x <= (window_end_x - window_step_x); x += window_step_x)
184 const auto vout_21p11_00 =
186 const auto vout_21p11_01 =
188 const auto vout_21p11_10 =
190 const auto vout_21p11_11 =
// Saturating rounding shift-narrow by 3 (21p11 -> 8p8), then by 8
// (8p8 -> 8p0), removing all 11 fractional bits.
194 const auto vout_8p8_0 =
195 wrapper::vcombine(wrapper::vqrshrn_ex<shift_amount_remainder, ScalarType>(vout_21p11_00),
196 wrapper::vqrshrn_ex<shift_amount_remainder, ScalarType>(vout_21p11_01));
197 const auto vout_8p8_1 =
198 wrapper::vcombine(wrapper::vqrshrn_ex<shift_amount_remainder, ScalarType>(vout_21p11_10),
199 wrapper::vqrshrn_ex<shift_amount_remainder, ScalarType>(vout_21p11_11));
202 const auto vout_8p0 =
203 wrapper::vcombine(wrapper::vqrshrn<8>(vout_8p8_0), wrapper::vqrshrn<8>(vout_8p8_1));
// Scalar tail: same fixed-point math per element; the alternate (clamped
// float) path below presumably serves non-__aarch64__ builds — confirm.
210 for (; x < window_end_x; ++x)
213 out_ptr[x] = wrapper::vqrshrn<8>(wrapper::vqrshrn_ex<shift_amount_remainder, ScalarType>(
214 int32_t(a_ptr[x]) * a_scale_5p11 + b_scaled_offseted_21p11));
216 out_ptr[x] = utility::clamp<int, ScalarType>(
218 #endif // __aarch64__
221 b_input_it, a_input_it, out_it);
// Non-broadcast path: both inputs are consumed vector-wide.
// (Excerpt — braces, #ifdef __aarch64__ guards and the accumulator
// right-hand sides are missing between the numbered lines.)
241 const auto in0_ptr =
reinterpret_cast<const ScalarType *
>(in0_it.
ptr());
242 const auto in1_ptr =
reinterpret_cast<const ScalarType *
>(in1_it.
ptr());
243 const auto out_ptr =
reinterpret_cast<ScalarType *
>(out_it.
ptr());
245 int x = window_start_x;
// Vector loop: 16 outputs per iteration, accumulated in 21p11 fixed point
// (in0*scale0 offset first, then in1*scale1 added — RHS lines missing).
247 for (; x <= (window_end_x - window_step_x); x += window_step_x)
261 const auto vscaled0_offseted_21p11_00 =
263 const auto vscaled0_offseted_21p11_01 =
265 const auto vscaled0_offseted_21p11_10 =
267 const auto vscaled0_offseted_21p11_11 =
270 const auto vout_21p11_00 =
272 const auto vout_21p11_01 =
274 const auto vout_21p11_10 =
276 const auto vout_21p11_11 =
// Saturating rounding shift-narrow by 3 (21p11 -> 8p8), then by 8
// (8p8 -> 8p0), removing all 11 fractional bits.
280 const auto vout_8p8_0 =
281 wrapper::vcombine(wrapper::vqrshrn_ex<shift_amount_remainder, ScalarType>(vout_21p11_00),
282 wrapper::vqrshrn_ex<shift_amount_remainder, ScalarType>(vout_21p11_01));
283 const auto vout_8p8_1 =
284 wrapper::vcombine(wrapper::vqrshrn_ex<shift_amount_remainder, ScalarType>(vout_21p11_10),
285 wrapper::vqrshrn_ex<shift_amount_remainder, ScalarType>(vout_21p11_11));
288 const auto vout_8p0 =
289 wrapper::vcombine(wrapper::vqrshrn<8>(vout_8p8_0), wrapper::vqrshrn<8>(vout_8p8_1));
// Scalar tail: out = in0*s0_5p11 + in1*s1_5p11 + offset_21p11, then the
// same double rounding narrow; the clamped-float variant presumably serves
// non-__aarch64__ builds (only the #endif marker is visible) — confirm.
296 for (; x < window_end_x; ++x)
299 out_ptr[x] = wrapper::vqrshrn<8>(wrapper::vqrshrn_ex<shift_amount_remainder, ScalarType>(
300 int32_t(in0_ptr[x]) * in0_scale_5p11 + int32_t(in1_ptr[x]) * in1_scale_5p11 + offset_21p11));
302 out_ptr[x] = utility::clamp<int, ScalarType>(
304 #endif // __aarch64__
307 in0_it, in1_it, out_it);
// Float-based QASYMM8 (u8) add/sub kernel: setup and broadcast-along-x path.
// Dequantize to f32, fused multiply-add, convert back and saturate to u8.
// (Excerpt — signature, braces and #ifdef __aarch64__ guards are missing.)
328 constexpr
int window_step_x = 16;
329 const auto window_start_x =
static_cast<int>(window.
x().
start());
330 const auto window_end_x =
static_cast<int>(window.
x().
end());
// Requantization scales; subtraction negates the second input's scale.
337 const auto scale1 = iq1_info.
scale / oq_info.
scale;
338 const auto scale2 = is_addition ? (iq2_info.
scale / oq_info.
scale) : (-(iq2_info.
scale / oq_info.
scale));
341 if (is_broadcast_across_x)
343 const bool is_broadcast_input_2 = input2_win.
x().
step() == 0;
344 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
345 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
346 const ITensor *broadcast_tensor = is_broadcast_input_2 ? src1 : src0;
347 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? src1 : src0;
// af_scale applies to the non-broadcast input, bf_scale to the broadcast one.
349 const auto af_scale = is_broadcast_input_2 ? scale1 : scale2;
350 const auto bf_scale = is_broadcast_input_2 ? scale2 : scale1;
351 const auto vscale1 = vdupq_n_f32(af_scale);
356 Iterator broadcast_input(broadcast_tensor, broadcast_win);
357 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
364 const auto non_broadcast_input_ptr = non_broadcast_input.
ptr();
365 const auto output_ptr = output.
ptr();
367 const auto broadcast_value = *broadcast_input.
ptr();
// Pre-fold broadcast value and bias: bf for the vector loop, bfs for the tail.
// NOTE(review): bf uses scale2 while bfs uses bf_scale — these differ when
// input1 is the broadcast operand, so the vector and scalar paths would then
// disagree. Cannot be fixed safely from this fragment; confirm upstream.
368 const auto bf = vdupq_n_f32(
float(broadcast_value) * scale2 +
offset);
369 const auto bfs = float(broadcast_value) * bf_scale +
offset;
372 int x = window_start_x;
// Vector loop: widen u8 -> u16 -> u32 -> f32 and fuse a*scale into bf.
373 for (; x <= (window_end_x - window_step_x); x += window_step_x)
375 const uint8x16_t a = vld1q_u8(non_broadcast_input_ptr + x);
377 const auto a_u16_0 = vmovl_u8(vget_low_u8(a));
378 const auto a_u16_1 = vmovl_u8(vget_high_u8(a));
380 const auto af_0 = vmlaq_f32(bf, vcvtq_f32_u32(vmovl_u16(vget_low_u16(a_u16_0))), vscale1);
381 const auto af_1 = vmlaq_f32(bf, vcvtq_f32_u32(vmovl_u16(vget_high_u16(a_u16_0))), vscale1);
382 const auto af_2 = vmlaq_f32(bf, vcvtq_f32_u32(vmovl_u16(vget_low_u16(a_u16_1))), vscale1);
383 const auto af_3 = vmlaq_f32(bf, vcvtq_f32_u32(vmovl_u16(vget_high_u16(a_u16_1))), vscale1);
// f32 -> s32: vcvtnq rounds to nearest, vcvtq truncates toward zero; the
// missing #ifdef presumably selects vcvtnq on __aarch64__ — confirm.
391 rf_0 = vcvtnq_s32_f32(af_0);
392 rf_1 = vcvtnq_s32_f32(af_1);
393 rf_2 = vcvtnq_s32_f32(af_2);
394 rf_3 = vcvtnq_s32_f32(af_3);
396 rf_0 = vcvtq_s32_f32(af_0);
397 rf_1 = vcvtq_s32_f32(af_1);
398 rf_2 = vcvtq_s32_f32(af_2);
399 rf_3 = vcvtq_s32_f32(af_3);
// Saturating narrow s32 -> s16 -> u8 (vqmovun clamps negatives to 0).
402 const uint8x8_t pa = vqmovun_s16(vcombine_s16(vqmovn_s32(rf_0), vqmovn_s32(rf_1)));
403 const uint8x8_t pb = vqmovun_s16(vcombine_s16(vqmovn_s32(rf_2), vqmovn_s32(rf_3)));
404 vst1q_u8(output_ptr + x, vcombine_u8(pa, pb));
// Scalar tail over the remaining elements.
408 for (; x < window_end_x; ++x)
410 const auto result = float(non_broadcast_input_ptr[x]) * af_scale +
bfs;
415 #endif // __aarch64__
418 broadcast_input, non_broadcast_input, output);
// QASYMM8 (u8) non-broadcast path: out = a*scale1 + b*scale2 + offset,
// computed in f32 and saturating-packed back to u8.
// (Excerpt — braces and #ifdef __aarch64__ guards are missing.)
430 const auto vscale1 = vdupq_n_f32(scale1);
431 const auto vscale2 = vdupq_n_f32(scale2);
432 const auto voffset = vdupq_n_f32(
offset);
438 const auto input1_ptr = input1.
ptr();
439 const auto input2_ptr = input2.
ptr();
440 const auto output_ptr = output.
ptr();
443 int x = window_start_x;
// Vector loop: widen both inputs u8 -> f32, then two chained fused
// multiply-adds per 4-lane group (offset + a*s1, then + b*s2).
444 for (; x <= (window_end_x - window_step_x); x += window_step_x)
446 const uint8x16_t a = vld1q_u8(input1_ptr + x);
447 const uint8x16_t
b = vld1q_u8(input2_ptr + x);
449 const auto a_u16_0 = vmovl_u8(vget_low_u8(a));
450 const auto a_u16_1 = vmovl_u8(vget_high_u8(a));
451 const auto b_u16_0 = vmovl_u8(vget_low_u8(
b));
452 const auto b_u16_1 = vmovl_u8(vget_high_u8(
b));
454 const auto af_0 = vmlaq_f32(voffset, vcvtq_f32_u32(vmovl_u16(vget_low_u16(a_u16_0))), vscale1);
455 const auto af_1 = vmlaq_f32(voffset, vcvtq_f32_u32(vmovl_u16(vget_high_u16(a_u16_0))), vscale1);
456 const auto af_2 = vmlaq_f32(voffset, vcvtq_f32_u32(vmovl_u16(vget_low_u16(a_u16_1))), vscale1);
457 const auto af_3 = vmlaq_f32(voffset, vcvtq_f32_u32(vmovl_u16(vget_high_u16(a_u16_1))), vscale1);
459 const auto bf_0 = vmlaq_f32(af_0, vcvtq_f32_u32(vmovl_u16(vget_low_u16(b_u16_0))), vscale2);
460 const auto bf_1 = vmlaq_f32(af_1, vcvtq_f32_u32(vmovl_u16(vget_high_u16(b_u16_0))), vscale2);
461 const auto bf_2 = vmlaq_f32(af_2, vcvtq_f32_u32(vmovl_u16(vget_low_u16(b_u16_1))), vscale2);
462 const auto bf_3 = vmlaq_f32(af_3, vcvtq_f32_u32(vmovl_u16(vget_high_u16(b_u16_1))), vscale2);
// f32 -> s32: vcvtnq rounds to nearest, vcvtq truncates; the missing
// #ifdef presumably selects vcvtnq on __aarch64__ — confirm.
470 rf_0 = vcvtnq_s32_f32(bf_0);
471 rf_1 = vcvtnq_s32_f32(bf_1);
472 rf_2 = vcvtnq_s32_f32(bf_2);
473 rf_3 = vcvtnq_s32_f32(bf_3);
475 rf_0 = vcvtq_s32_f32(bf_0);
476 rf_1 = vcvtq_s32_f32(bf_1);
477 rf_2 = vcvtq_s32_f32(bf_2);
478 rf_3 = vcvtq_s32_f32(bf_3);
// Saturating narrow s32 -> s16 -> u8 (vqmovun clamps negatives to 0).
481 const uint8x8_t pa = vqmovun_s16(vcombine_s16(vqmovn_s32(rf_0), vqmovn_s32(rf_1)));
482 const uint8x8_t pb = vqmovun_s16(vcombine_s16(vqmovn_s32(rf_2), vqmovn_s32(rf_3)));
483 vst1q_u8(output_ptr + x, vcombine_u8(pa, pb));
// Scalar tail over the remaining elements.
487 for (; x < window_end_x; ++x)
489 const auto result = float(input1_ptr[x]) * scale1 + float(input2_ptr[x]) * scale2 +
offset;
494 #endif // __aarch64__
497 input1, input2, output);
// Float-based QASYMM8_SIGNED (s8) add/sub kernel: setup and broadcast path.
// Mirrors the u8 kernel above but with signed widen/narrow intrinsics.
// (Excerpt — signature, braces and #ifdef __aarch64__ guards are missing.)
518 constexpr
int window_step_x = 16;
519 const auto window_start_x =
static_cast<int>(window.
x().
start());
520 const auto window_end_x =
static_cast<int>(window.
x().
end());
// Requantization scales; subtraction negates the second input's scale.
527 const auto scale1 = iq1_info.
scale / oq_info.
scale;
528 const auto scale2 = is_addition ? (iq2_info.
scale / oq_info.
scale) : (-(iq2_info.
scale / oq_info.
scale));
531 if (is_broadcast_across_x)
533 const bool is_broadcast_input_2 = input2_win.
x().
step() == 0;
534 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
535 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
536 const ITensor *broadcast_tensor = is_broadcast_input_2 ? src1 : src0;
537 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? src1 : src0;
// af_scale applies to the non-broadcast input, bf_scale to the broadcast one.
539 const auto af_scale = is_broadcast_input_2 ? scale1 : scale2;
540 const auto bf_scale = is_broadcast_input_2 ? scale2 : scale1;
541 const auto vscale1 = vdupq_n_f32(af_scale);
546 Iterator broadcast_input(broadcast_tensor, broadcast_win);
547 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
554 const auto non_broadcast_input_ptr =
reinterpret_cast<const int8_t *
>(non_broadcast_input.
ptr());
555 const auto output_ptr =
reinterpret_cast<int8_t *
>(output.
ptr());
557 const auto broadcast_value = *
reinterpret_cast<const int8_t *
>(broadcast_input.
ptr());
// Pre-fold broadcast value and bias: bf for the vector loop, bfs for the tail.
// NOTE(review): bf uses scale2 while bfs uses bf_scale — these differ when
// input1 is the broadcast operand (same concern as the u8 kernel); confirm
// against the full source.
558 const auto bf = vdupq_n_f32(
float(broadcast_value) * scale2 +
offset);
559 const auto bfs = float(broadcast_value) * bf_scale +
offset;
562 int x = window_start_x;
// Vector loop: widen s8 -> s16 -> s32 -> f32 and fuse a*scale into bf.
563 for (; x <= (window_end_x - window_step_x); x += window_step_x)
565 const int8x16_t a = vld1q_s8(non_broadcast_input_ptr + x);
567 const auto a_s16_0 = vmovl_s8(vget_low_s8(a));
568 const auto a_s16_1 = vmovl_s8(vget_high_s8(a));
570 const auto af_0 = vmlaq_f32(bf, vcvtq_f32_s32(vmovl_s16(vget_low_s16(a_s16_0))), vscale1);
571 const auto af_1 = vmlaq_f32(bf, vcvtq_f32_s32(vmovl_s16(vget_high_s16(a_s16_0))), vscale1);
572 const auto af_2 = vmlaq_f32(bf, vcvtq_f32_s32(vmovl_s16(vget_low_s16(a_s16_1))), vscale1);
573 const auto af_3 = vmlaq_f32(bf, vcvtq_f32_s32(vmovl_s16(vget_high_s16(a_s16_1))), vscale1);
// f32 -> s32: vcvtnq rounds to nearest, vcvtq truncates; the missing
// #ifdef presumably selects vcvtnq on __aarch64__ — confirm.
581 rf_0 = vcvtnq_s32_f32(af_0);
582 rf_1 = vcvtnq_s32_f32(af_1);
583 rf_2 = vcvtnq_s32_f32(af_2);
584 rf_3 = vcvtnq_s32_f32(af_3);
586 rf_0 = vcvtq_s32_f32(af_0);
587 rf_1 = vcvtq_s32_f32(af_1);
588 rf_2 = vcvtq_s32_f32(af_2);
589 rf_3 = vcvtq_s32_f32(af_3);
// Saturating narrow s32 -> s16 -> s8.
592 const int8x8_t pa = vqmovn_s16(vcombine_s16(vqmovn_s32(rf_0), vqmovn_s32(rf_1)));
593 const int8x8_t pb = vqmovn_s16(vcombine_s16(vqmovn_s32(rf_2), vqmovn_s32(rf_3)));
594 vst1q_s8(output_ptr + x, vcombine_s8(pa, pb));
// Scalar tail over the remaining elements.
598 for (; x < window_end_x; ++x)
600 const auto result = float(non_broadcast_input_ptr[x]) * af_scale +
bfs;
605 #endif // __aarch64__
608 broadcast_input, non_broadcast_input, output);
// QASYMM8_SIGNED (s8) non-broadcast path: out = a*scale1 + b*scale2 + offset,
// computed in f32 and saturating-packed back to s8.
// (Excerpt — braces and #ifdef __aarch64__ guards are missing.)
620 const auto vscale1 = vdupq_n_f32(scale1);
621 const auto vscale2 = vdupq_n_f32(scale2);
622 const auto voffset = vdupq_n_f32(
offset);
628 const auto input1_ptr =
reinterpret_cast<const int8_t *
>(input1.
ptr());
629 const auto input2_ptr =
reinterpret_cast<const int8_t *
>(input2.
ptr());
630 const auto output_ptr =
reinterpret_cast<int8_t *
>(output.
ptr());
633 int x = window_start_x;
// Vector loop: widen both inputs s8 -> f32, then two chained fused
// multiply-adds per 4-lane group (offset + a*s1, then + b*s2).
634 for (; x <= (window_end_x - window_step_x); x += window_step_x)
636 const int8x16_t a = vld1q_s8(input1_ptr + x);
637 const int8x16_t
b = vld1q_s8(input2_ptr + x);
639 const auto a_s16_0 = vmovl_s8(vget_low_s8(a));
640 const auto a_s16_1 = vmovl_s8(vget_high_s8(a));
641 const auto b_s16_0 = vmovl_s8(vget_low_s8(
b));
642 const auto b_s16_1 = vmovl_s8(vget_high_s8(
b));
644 const auto af_0 = vmlaq_f32(voffset, vcvtq_f32_s32(vmovl_s16(vget_low_s16(a_s16_0))), vscale1);
645 const auto af_1 = vmlaq_f32(voffset, vcvtq_f32_s32(vmovl_s16(vget_high_s16(a_s16_0))), vscale1);
646 const auto af_2 = vmlaq_f32(voffset, vcvtq_f32_s32(vmovl_s16(vget_low_s16(a_s16_1))), vscale1);
647 const auto af_3 = vmlaq_f32(voffset, vcvtq_f32_s32(vmovl_s16(vget_high_s16(a_s16_1))), vscale1);
649 const auto bf_0 = vmlaq_f32(af_0, vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_s16_0))), vscale2);
650 const auto bf_1 = vmlaq_f32(af_1, vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_s16_0))), vscale2);
651 const auto bf_2 = vmlaq_f32(af_2, vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_s16_1))), vscale2);
652 const auto bf_3 = vmlaq_f32(af_3, vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_s16_1))), vscale2);
// f32 -> s32: vcvtnq rounds to nearest, vcvtq truncates; the missing
// #ifdef presumably selects vcvtnq on __aarch64__ — confirm.
660 rf_0 = vcvtnq_s32_f32(bf_0);
661 rf_1 = vcvtnq_s32_f32(bf_1);
662 rf_2 = vcvtnq_s32_f32(bf_2);
663 rf_3 = vcvtnq_s32_f32(bf_3);
665 rf_0 = vcvtq_s32_f32(bf_0);
666 rf_1 = vcvtq_s32_f32(bf_1);
667 rf_2 = vcvtq_s32_f32(bf_2);
668 rf_3 = vcvtq_s32_f32(bf_3);
// Saturating narrow s32 -> s16 -> s8.
671 const int8x8_t pa = vqmovn_s16(vcombine_s16(vqmovn_s32(rf_0), vqmovn_s32(rf_1)));
672 const int8x8_t pb = vqmovn_s16(vcombine_s16(vqmovn_s32(rf_2), vqmovn_s32(rf_3)));
673 vst1q_s8(output_ptr + x, vcombine_s8(pa, pb));
// Scalar tail over the remaining elements.
677 for (; x < window_end_x; ++x)
679 const auto result = float(input1_ptr[x]) * scale1 + float(input2_ptr[x]) * scale2 +
offset;
684 #endif // __aarch64__
687 input1, input2, output);