24 #ifndef SRC_CORE_NEON_KERNELS_ELEMENTWISE_IMPL_H
25 #define SRC_CORE_NEON_KERNELS_ELEMENTWISE_IMPL_H
// Fragment (extraction dropped interior lines): generic vectorised arithmetic
// op over one full NEON register of VectorType. NOTE(review): the op dispatch
// body between original lines 40 and 58 is missing from this view.
33 template <ArithmeticOperation op,
typename VectorType>
37 using scalar_type =
typename VectorType::scalar_type;
38 using tag_type =
typename VectorType::tag_type;
// Result accumulator, zero-initialised in the vector's scalar type.
40 vec_type res =
wrapper::vdup_n(
static_cast<scalar_type
>(0), tag_type{});
// Zero vector — presumably used for a sign/compare branch (e.g. PRELU's
// a > 0 test) in the missing op-specific code; confirm against full source.
58 const vec_type zero =
wrapper::vdup_n(
static_cast<scalar_type
>(0), tag_type{});
// Vector-vs-scalar variant: splats the broadcast scalar into a full vector and
// forwards to the vector/vector op. `reorder` swaps the operand order so
// non-commutative ops (SUB, DIV, POWER, PRELU) get the broadcast value on the
// correct side when input1/input2 were exchanged by the caller.
73 template <ArithmeticOperation op,
typename ScalarType,
typename VectorType>
75 const ScalarType &broadcast_value,
78 using tag_type =
typename VectorType::tag_type;
// Duplicate the scalar across all lanes.
81 vec_type broadcast_vector =
wrapper::vdup_n(broadcast_value, tag_type{});
82 return elementwise_arithm_op<op, VectorType>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
// Generic elementwise driver (fragment — several original lines are missing).
// Takes three callbacks: a scalar fallback, a vectorised broadcast loop and a
// vectorised vector/vector loop. The vector callbacks process as many full
// steps as fit and return the index where the scalar tail must resume.
85 template <
typename InputScalarType,
typename OutputScalarType,
typename InputVectorType>
91 OutputScalarType (*scalar_func)(
const InputScalarType &,
const InputScalarType &),
92 int (*broadcast_func)(
93 int,
int,
int,
const InputScalarType *,
const InputScalarType &, OutputScalarType *,
const bool),
94 int (*neon_func)(
int,
int,
int,
const InputScalarType *,
const InputScalarType *, OutputScalarType *))
// Elements per vector step: 16 bytes / sizeof(output), capped at 8.
104 const int window_step_x = std::min(16 /
static_cast<int>(
sizeof(OutputScalarType)), 8);
105 const auto window_start_x =
static_cast<int>(window.
x().
start());
106 const auto window_end_x =
static_cast<int>(window.
x().
end());
// Broadcast path: one of the inputs has x-stride 0 (scalar along x).
109 if (is_broadcast_across_x)
111 const bool is_broadcast_input_2 = input2_win.
x().
step() == 0;
112 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
113 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
114 const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
115 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
120 Iterator broadcast_input(broadcast_tensor, broadcast_win);
121 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
128 auto output_ptr =
reinterpret_cast<OutputScalarType *
>(output.
ptr());
129 const auto non_broadcast_input_ptr =
130 reinterpret_cast<const InputScalarType *
>(non_broadcast_input.
ptr());
131 const InputScalarType broadcast_value =
132 *
reinterpret_cast<const InputScalarType *
>(broadcast_input.
ptr());
// Vectorised head; `!is_broadcast_input_2` tells the callback to reorder
// operands so argument order matches the caller's in1/in2.
134 int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr,
135 broadcast_value, output_ptr, !is_broadcast_input_2);
// Scalar tail for the remainder that did not fill a vector step.
136 for (; x < window_end_x; ++x)
138 const auto a = *(non_broadcast_input_ptr + x);
139 *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? broadcast_value : a,
140 !is_broadcast_input_2 ? a : broadcast_value);
143 broadcast_input, non_broadcast_input, output);
// Non-broadcast path: both inputs advance along x.
159 auto output_ptr =
reinterpret_cast<OutputScalarType *
>(output.
ptr());
160 const auto input1_ptr =
reinterpret_cast<const InputScalarType *
>(input1.
ptr());
161 const auto input2_ptr =
reinterpret_cast<const InputScalarType *
>(input2.
ptr());
163 int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr);
164 for (; x < window_end_x; ++x)
166 const auto a = *(input1_ptr + x);
167 const auto b = *(input2_ptr + x);
168 *(output_ptr + x) = (*scalar_func)(a,
b);
171 input1, input2, output);
// Scalar fallback used for the tail of each row (fragment — the surrounding
// switch/case labels per ArithmeticOperation are missing from this view).
175 template <ArithmeticOperation op,
typename ScalarType>
178 auto res = ScalarType(0);
// MAX
183 res = std::max(a,
b);
// MIN
186 res = std::min(a,
b);
// SQUARED_DIFF: (a - b)^2
190 res = (a -
b) * (a -
b);
// PRELU: a if positive, otherwise a scaled by b (the slope).
195 res = (a > 0 ? a : a *
b);
// DIV, integral case: define x/0 as 0; the missing code after the sign
// check presumably adjusts the quotient toward negative infinity when the
// operands' signs differ and the division is inexact — confirm in full source.
201 if (std::is_integral<ScalarType>::value)
203 res = (
b == 0) ? 0 : res;
204 if (
static_cast<int32_t
>(a) %
static_cast<int32_t
>(
b) != 0 && ((a < 0) != (
b < 0)))
// POWER via std::pow.
213 res = std::pow(a,
b);
// Full-vector specializations (bodies missing in this view): DIV for s32/f32
// and POWER for f32 use dedicated vector routines instead of the generic path;
// fp16 variants are provided when the target supports FP16 vector arithmetic.
224 elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<int32_t, 4>>(
const int32x4_t &a,
232 elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float, 4>>(
const float32x4_t &a,
233 const float32x4_t &
b)
240 elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float, 4>>(
const float32x4_t &a,
241 const float32x4_t &
b)
246 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
248 inline float16x8_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float16_t, 8>>(
249 const float16x8_t &a,
const float16x8_t &
b)
256 elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float16_t, 8>>(
257 const float16x8_t &a,
const float16x8_t &
b)
261 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// Vector/vector inner loop: processes full vector steps and (per the driver's
// contract) returns the index where the scalar tail resumes.
263 template <ArithmeticOperation op,
typename ScalarType,
typename VectorType>
267 const ScalarType *input1_ptr,
268 const ScalarType *input2_ptr,
269 ScalarType *output_ptr)
271 int x = window_start_x;
272 for (; x <= (window_end_x - window_step_x); x += window_step_x)
276 wrapper::vstore(output_ptr + x, elementwise_arithm_op<op, VectorType>(a,
b));
// Vector/broadcast-scalar inner loop; `reorder` is forwarded so operand order
// matches the caller's original in1/in2. Returns the scalar-tail start index.
281 template <ArithmeticOperation op,
typename ScalarType,
typename VectorType>
285 const ScalarType *non_broadcast_input_ptr,
286 const ScalarType &broadcast_value,
287 ScalarType *output_ptr,
290 int x = window_start_x;
291 for (; x <= (window_end_x - window_step_x); x += window_step_x)
295 elementwise_arithm_op_broadcast<op, ScalarType, VectorType>(a, broadcast_value, reorder));
// Public entry point for arithmetic ops: wires the scalar fallback and the two
// vector loops for this op/VectorType into the generic elementwise_op driver.
300 template <ArithmeticOperation op,
typename VectorType>
303 using scalar_type =
typename VectorType::scalar_type;
305 elementwise_op<scalar_type, scalar_type, VectorType>(
306 in1, in2, out, window, &elementwise_arithm_op_scalar<op, scalar_type>,
307 &elementwise_arithm_op_broadcast_loop<op, scalar_type, VectorType>,
308 &elementwise_arithm_op_loop<op, scalar_type, VectorType>);
// Scalar comparison fallback (comparison body missing in this view): maps the
// boolean result to an all-ones (0xFF) or all-zeros uint8_t mask.
311 template <ComparisonOperation op,
typename InputScalarType>
339 return res ? ~static_cast<uint8_t>(0) :
static_cast<uint8_t
>(0);
// Vector comparison (per-op dispatch missing in this view): produces a lane
// mask in OutputVectorType; res starts cleared.
342 template <ComparisonOperation op,
typename InputVectorType,
typename OutputVectorType>
345 OutputVectorType res = {0, 0, 0, 0};
// Vector-vs-broadcast-scalar comparison: splats the scalar (splat code missing
// in this view) and forwards to the vector/vector comparison, with `reorder`
// restoring the caller's operand order for asymmetric ops (<, <=, >, >=).
374 template <ComparisonOperation op,
typename InputScalarType,
typename InputVectorType,
typename OutputVectorType>
375 inline OutputVectorType
379 return elementwise_comp_op<op, InputVectorType, OutputVectorType>(reorder ? broadcast_vector : a,
380 reorder ? a : broadcast_vector);
// Broadcast comparison inner loops, one per input element width (8/16/32-bit),
// each producing uint8_t masks. The 32-bit variant compares two uint32x4
// halves per step and has an extra half-vector (4-element) epilogue.
383 template <ComparisonOperation op,
typename InputScalarType,
typename InputVectorType>
387 const InputScalarType *non_broadcast_input_ptr,
388 const InputScalarType &broadcast_value,
392 int x = window_start_x;
393 for (; x <= (window_end_x - window_step_x); x += window_step_x)
395 const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint8x16_t>(
396 wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
// 16-bit inputs: mask computed as uint16x8 then narrowed (store code missing).
402 template <ComparisonOperation op,
typename InputScalarType,
typename InputVectorType>
406 const InputScalarType *non_broadcast_input_ptr,
407 const InputScalarType &broadcast_value,
411 int x = window_start_x;
412 for (; x <= (window_end_x - window_step_x); x += window_step_x)
414 const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint16x8_t>(
415 wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
// 32-bit inputs: two uint32x4 masks per step (x and x+4).
421 template <ComparisonOperation op,
typename InputScalarType,
typename InputVectorType>
425 const InputScalarType *non_broadcast_input_ptr,
426 const InputScalarType &broadcast_value,
430 int x = window_start_x;
431 for (; x <= (window_end_x - window_step_x); x += window_step_x)
433 const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(
434 wrapper::vloadq(non_broadcast_input_ptr + x), broadcast_value, reorder);
435 const auto b = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(
436 wrapper::vloadq(non_broadcast_input_ptr + x + 4), broadcast_value, reorder);
// Epilogue: one more group of 4 if it fits, written out lane by lane.
439 if (x <= window_end_x - 4)
441 const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(
442 wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
443 for (
int i = 0; i < 4; i++)
// Vector/vector comparison inner loops per input width (8/16/32-bit),
// mirroring the broadcast loops above; the 32-bit variant again does two
// uint32x4 compares per step plus a 4-element epilogue.
452 template <ComparisonOperation op,
typename InputScalarType,
typename InputVectorType>
456 const InputScalarType *input1_ptr,
457 const InputScalarType *input2_ptr,
460 int x = window_start_x;
461 for (; x <= (window_end_x - window_step_x); x += window_step_x)
465 const auto res = elementwise_comp_op<op, InputVectorType, uint8x16_t>(a,
b);
// 16-bit inputs.
471 template <ComparisonOperation op,
typename InputScalarType,
typename InputVectorType>
475 const InputScalarType *input1_ptr,
476 const InputScalarType *input2_ptr,
479 int x = window_start_x;
480 for (; x <= (window_end_x - window_step_x); x += window_step_x)
484 const auto res = elementwise_comp_op<op, InputVectorType, uint16x8_t>(a,
b);
// 32-bit inputs: two mask quads per step.
490 template <ComparisonOperation op,
typename InputScalarType,
typename InputVectorType>
494 const InputScalarType *input1_ptr,
495 const InputScalarType *input2_ptr,
498 int x = window_start_x;
499 for (; x <= (window_end_x - window_step_x); x += window_step_x)
503 const auto res = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a,
b);
506 const auto res2 = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a,
b);
// Epilogue: final group of 4 written lane by lane.
509 if (x <= window_end_x - 4)
513 const auto res = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a,
b);
514 for (
int i = 0; i < 4; i++)
// Public comparison entry points, one per input element width; each wires the
// scalar fallback and the width-matched vector loops into the generic driver.
// Output is always uint8_t masks.
523 template <ComparisonOperation op,
typename InputScalarType,
typename InputVectorType>
526 elementwise_op<InputScalarType, uint8_t, InputVectorType>(
527 in1, in2, out, window, &elementwise_comp_op_scalar<op, InputScalarType>,
528 &elementwise_comp_op_broadcast_8_loop<op, InputScalarType, InputVectorType>,
529 &elementwise_comp_op_8_loop<op, InputScalarType, InputVectorType>);
// 16-bit input variant.
532 template <ComparisonOperation op,
typename InputScalarType,
typename InputVectorType>
535 elementwise_op<InputScalarType, uint8_t, InputVectorType>(
536 in1, in2, out, window, &elementwise_comp_op_scalar<op, InputScalarType>,
537 &elementwise_comp_op_broadcast_16_loop<op, InputScalarType, InputVectorType>,
538 &elementwise_comp_op_16_loop<op, InputScalarType, InputVectorType>);
// 32-bit input variant.
541 template <ComparisonOperation op,
typename InputScalarType,
typename InputVectorType>
544 elementwise_op<InputScalarType, uint8_t, InputVectorType>(
545 in1, in2, out, window, &elementwise_comp_op_scalar<op, InputScalarType>,
546 &elementwise_comp_op_broadcast_32_loop<op, InputScalarType, InputVectorType>,
547 &elementwise_comp_op_32_loop<op, InputScalarType, InputVectorType>);
// Quantized load/store helpers (fragments). First: dequantize 16 x u8 to four
// float32x4 lanes — widen u8→u16→u32, subtract the zero-point offset, convert
// to float (scale multiply visible only on the last quarter here; the others
// are presumably scaled too in the missing lines).
553 const float32x4x4_t out = {{
555 vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(x))))),
offset)),
558 vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(x))))),
offset)),
561 vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(x))))),
offset)),
563 vmulq_f32(vcvtq_f32_s32(
564 vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(x))))),
offset)),
// Signed variant: dequantize 16 x s8 the same way (s8→s16→s32, minus offset,
// times scale).
573 const float32x4x4_t out = {{
574 vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_low_s8(x)))),
offset)),
scale),
575 vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_low_s8(x)))),
offset)),
scale),
576 vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_high_s8(x)))),
offset)),
scale),
577 vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_high_s8(x)))),
offset)),
scale),
// Store 16 lanes from four u32 quads: saturating narrow u32→u16→u8.
584 const uint8x8_t pa = vqmovn_u16(vcombine_u16(vqmovn_u32(out.val[0]), vqmovn_u32(out.val[1])));
585 const uint8x8_t pb = vqmovn_u16(vcombine_u16(vqmovn_u32(out.val[2]), vqmovn_u32(out.val[3])));
586 vst1q_u8(output_ptr, vcombine_u8(pa, pb));
// Store from four s32 quads to u8: saturating narrow s32→s16, then
// signed-to-unsigned narrow (vqmovun) clamps negatives to 0.
591 const uint8x8_t pa = vqmovun_s16(vcombine_s16(vqmovn_s32(out.val[0]), vqmovn_s32(out.val[1])));
592 const uint8x8_t pb = vqmovun_s16(vcombine_s16(vqmovn_s32(out.val[2]), vqmovn_s32(out.val[3])));
593 vst1q_u8(output_ptr, vcombine_u8(pa, pb));
// Requantizing stores: scale each float quad by 1/output_scale, add the
// output zero-point (fused multiply-add), convert to s32, then narrow to
// u8 (first overload) or s8 (signed overload, body partially missing).
597 store_quantized(uint8_t *output_ptr,
const float32x4x4_t &rf,
const float32x4_t &
offset,
const float32x4_t &invscale)
600 vcvtq_s32_f32(vmlaq_f32(
offset, rf.val[0], invscale)),
601 vcvtq_s32_f32(vmlaq_f32(
offset, rf.val[1], invscale)),
602 vcvtq_s32_f32(vmlaq_f32(
offset, rf.val[2], invscale)),
603 vcvtq_s32_f32(vmlaq_f32(
offset, rf.val[3], invscale)),
// Signed narrow path: s32→s16→s8 with saturation.
610 const int8x8_t pa = vqmovn_s16(vcombine_s16(vqmovn_s32(out.val[0]), vqmovn_s32(out.val[1])));
611 const int8x8_t pb = vqmovn_s16(vcombine_s16(vqmovn_s32(out.val[2]), vqmovn_s32(out.val[3])));
612 vst1q_s8(output_ptr, vcombine_s8(pa, pb));
// Signed requantizing store (same requantize formula as above).
616 const float32x4x4_t &rf,
617 const float32x4_t &
offset,
618 const float32x4_t &invscale)
621 vcvtq_s32_f32(vmlaq_f32(
offset, rf.val[0], invscale)),
622 vcvtq_s32_f32(vmlaq_f32(
offset, rf.val[1], invscale)),
623 vcvtq_s32_f32(vmlaq_f32(
offset, rf.val[2], invscale)),
624 vcvtq_s32_f32(vmlaq_f32(
offset, rf.val[3], invscale)),
// Dequantized-domain helpers (signatures at 629/635 truncated in this view):
// apply the arithmetic/comparison op to four float32x4 quads at once by
// delegating per-quad to the generic vector implementations.
629 template <ArithmeticOperation op>
635 template <ArithmeticOperation op>
// Arithmetic over a float32x4x4 pair, quad by quad.
642 template <ArithmeticOperation op>
646 float32x4x4_t out = {{
647 elementwise_arithm_op<op, neon_vector_float>(a.val[0],
b.val[0]),
648 elementwise_arithm_op<op, neon_vector_float>(a.val[1],
b.val[1]),
649 elementwise_arithm_op<op, neon_vector_float>(a.val[2],
b.val[2]),
650 elementwise_arithm_op<op, neon_vector_float>(a.val[3],
b.val[3]),
// Scalar comparison wrapper used by the quantized tail loops.
655 template <ComparisonOperation op>
659 return elementwise_comp_op_scalar<op>(a,
b);
// Comparison over a float32x4x4 pair, producing four uint32x4 masks.
662 template <ComparisonOperation op>
665 uint32x4x4_t out = {{elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[0],
b.val[0]),
666 elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[1],
b.val[1]),
667 elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[2],
b.val[2]),
668 elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[3],
b.val[3])}};
// Quantized inner loops (arith/comp x u8/s8 x plain/broadcast), each handling
// 16 elements per step: dequantize both operands to float32x4x4, apply the op
// in float, then requantize+store (store calls sit in lines missing from this
// view). All return the index where the scalar tail resumes.
672 template <ArithmeticOperation op>
676 const uint8_t *input1_ptr,
677 const uint8_t *input2_ptr,
683 float32x4_t voffseto,
684 float32x4_t invvscaleo)
686 int x = window_start_x;
687 for (; x <= (window_end_x - window_step_x); x += window_step_x)
690 const float32x4x4_t af =
load_quantized(input1_ptr + x, voffset1, vscale1);
691 const float32x4x4_t bf =
load_quantized(input2_ptr + x, voffset2, vscale2);
692 const float32x4x4_t rf = elementwise_arithm_op<op>(af, bf);
// Arithmetic, signed (s8) inputs.
698 template <ArithmeticOperation op>
702 const int8_t *input1_ptr,
703 const int8_t *input2_ptr,
709 float32x4_t voffseto,
710 float32x4_t invvscaleo)
712 int x = window_start_x;
713 for (; x <= (window_end_x - window_step_x); x += window_step_x)
718 const float32x4x4_t rf = elementwise_arithm_op<op>(af, bf);
// Arithmetic, broadcast (u8): one operand pre-dequantized once by the caller.
724 template <ArithmeticOperation op>
728 const uint8_t *non_broadcast_input_ptr,
729 float32x4x4_t broadcast_vector,
731 int32x4_t voffset_non_broadcast,
732 float32x4_t vscale_non_broadcast,
733 float32x4_t voffseto,
734 float32x4_t invvscaleo,
737 int x = window_start_x;
738 for (; x <= (window_end_x - window_step_x); x += window_step_x)
740 const float32x4x4_t af =
741 load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
742 const float32x4x4_t rf =
743 elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
// Arithmetic, broadcast (s8).
748 template <ArithmeticOperation op>
752 const int8_t *non_broadcast_input_ptr,
753 float32x4x4_t broadcast_vector,
755 int32x4_t voffset_non_broadcast,
756 float32x4_t vscale_non_broadcast,
757 float32x4_t voffseto,
758 float32x4_t invvscaleo,
761 int x = window_start_x;
762 for (; x <= (window_end_x - window_step_x); x += window_step_x)
764 const float32x4x4_t af =
766 const float32x4x4_t rf =
767 elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
// Comparison, u8 inputs — result is a uint32x4x4 mask set.
773 template <ComparisonOperation op>
777 const uint8_t *input1_ptr,
778 const uint8_t *input2_ptr,
784 float32x4_t voffseto,
785 float32x4_t invvscaleo)
788 int x = window_start_x;
789 for (; x <= (window_end_x - window_step_x); x += window_step_x)
791 const float32x4x4_t af =
load_quantized(input1_ptr + x, voffset1, vscale1);
792 const float32x4x4_t bf =
load_quantized(input2_ptr + x, voffset2, vscale2);
793 const uint32x4x4_t rf = elementwise_comp_op<op>(af, bf);
// Comparison, s8 inputs.
799 template <ComparisonOperation op>
803 const int8_t *input1_ptr,
804 const int8_t *input2_ptr,
810 float32x4_t voffseto,
811 float32x4_t invvscaleo)
814 int x = window_start_x;
815 for (; x <= (window_end_x - window_step_x); x += window_step_x)
819 const uint32x4x4_t rf = elementwise_comp_op<op>(af, bf);
// Comparison, broadcast (u8).
825 template <ComparisonOperation op>
829 const uint8_t *non_broadcast_input_ptr,
830 float32x4x4_t broadcast_vector,
832 int32x4_t voffset_non_broadcast,
833 float32x4_t vscale_non_broadcast,
834 float32x4_t voffseto,
835 float32x4_t invvscaleo,
839 int x = window_start_x;
840 for (; x <= (window_end_x - window_step_x); x += window_step_x)
842 const float32x4x4_t af =
843 load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
844 const uint32x4x4_t rf =
845 elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
// Comparison, broadcast (s8).
851 template <ComparisonOperation op>
855 const int8_t *non_broadcast_input_ptr,
856 float32x4x4_t broadcast_vector,
858 int32x4_t voffset_non_broadcast,
859 float32x4_t vscale_non_broadcast,
860 float32x4_t voffseto,
861 float32x4_t invvscaleo,
865 int x = window_start_x;
866 for (; x <= (window_end_x - window_step_x); x += window_step_x)
868 const float32x4x4_t af =
870 const uint32x4x4_t rf =
871 elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
// QASYMM8 (u8) driver: same structure as the generic elementwise_op, but the
// vector callbacks take per-tensor quantization offsets/scales and the scalar
// tail works on dequantized floats. Note the +0.5f on the output offset —
// presumably rounding compensation for the truncating float→int conversion in
// store_quantized; confirm against the full source.
882 int (*broadcast_func)(
int,
893 int (*neon_func)(
int,
914 const int window_step_x = 16;
915 const auto window_start_x =
static_cast<int>(window.
x().
start());
916 const auto window_end_x =
static_cast<int>(window.
x().
end());
922 const float32x4_t voffseto = vdupq_n_f32(output_qinfo.
offset + 0.5f);
923 const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.
scale);
// Broadcast path.
925 if (is_broadcast_across_x)
928 const bool is_broadcast_input_2 = input2_win.
x().
step() == 0;
929 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
930 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
931 const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
932 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
937 const int32x4_t voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.
offset);
938 const float32x4_t vscale_non_broadcast = vdupq_n_f32(non_broadcast_qinfo.
scale);
943 Iterator broadcast_input(broadcast_tensor, broadcast_win);
944 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
951 const auto non_broadcast_input_ptr =
reinterpret_cast<const uint8_t *
>(non_broadcast_input.
ptr());
952 const auto output_ptr =
reinterpret_cast<uint8_t *
>(output.
ptr());
954 const uint8_t broadcast_value = *
reinterpret_cast<const uint8_t *
>(broadcast_input.
ptr());
// Dequantize the broadcast scalar once per row.
955 const float32x4x4_t broadcast_vector =
vdequantize(vdupq_n_u8(broadcast_value), broadcast_qinfo);
957 int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr,
958 broadcast_vector, output_ptr, voffset_non_broadcast, vscale_non_broadcast,
959 voffseto, invvscaleo, !is_broadcast_input_2);
// Scalar tail in the dequantized (float) domain.
960 for (; x < window_end_x; ++x)
962 const float afs =
dequantize_qasymm8(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
964 *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ?
bfs : afs,
965 !is_broadcast_input_2 ? afs :
bfs, output_qinfo);
968 broadcast_input, non_broadcast_input, output);
// Non-broadcast path: per-input quantization vectors.
976 const int32x4_t voffset1 = vdupq_n_s32(input1_qinfo.
offset);
977 const float32x4_t vscale1 = vdupq_n_f32(input1_qinfo.
scale);
980 const int32x4_t voffset2 = vdupq_n_s32(input2_qinfo.
offset);
981 const float32x4_t vscale2 = vdupq_n_f32(input2_qinfo.
scale);
995 const auto input1_ptr =
reinterpret_cast<const uint8_t *
>(input1.
ptr());
996 const auto input2_ptr =
reinterpret_cast<const uint8_t *
>(input2.
ptr());
997 const auto output_ptr =
reinterpret_cast<uint8_t *
>(output.
ptr());
999 int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr,
1000 voffset1, voffset2, vscale1, vscale2, voffseto, invvscaleo);
1001 for (; x < window_end_x; ++x)
1005 *(output_ptr + x) = (*scalar_func)(afs,
bfs, output_qinfo);
1008 input1, input2, output);
// QASYMM8_SIGNED driver with u8 output (int8_t inputs, uint8_t output
// pointers) — presumably the quantized comparison variant, where results are
// u8 masks; confirm against the full source. Structure mirrors
// elementwise_op_quantized; note no +0.5f on the output offset here.
1018 int (*broadcast_func)(
int,
1029 int (*neon_func)(
int,
1050 const int window_step_x = 16;
1051 const auto window_start_x =
static_cast<int>(window.
x().
start());
1052 const auto window_end_x =
static_cast<int>(window.
x().
end());
1057 const float32x4_t voffseto = vdupq_n_f32(output_qinfo.
offset);
1058 const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.
scale);
// Broadcast path.
1060 if (is_broadcast_across_x)
1063 const bool is_broadcast_input_2 = input2_win.
x().
step() == 0;
1064 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
1065 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
1066 const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
1067 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
1072 const int32x4_t voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.
offset);
1073 const float32x4_t vscale_non_broadcast = vdupq_n_f32(non_broadcast_qinfo.
scale);
1078 Iterator broadcast_input(broadcast_tensor, broadcast_win);
1079 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
1086 const auto non_broadcast_input_ptr =
reinterpret_cast<const int8_t *
>(non_broadcast_input.
ptr());
1087 const auto output_ptr =
reinterpret_cast<uint8_t *
>(output.
ptr());
1089 const int8_t broadcast_value = *
reinterpret_cast<const int8_t *
>(broadcast_input.
ptr());
1090 const float32x4x4_t broadcast_vector =
vdequantize(vdupq_n_s8(broadcast_value), broadcast_qinfo);
1092 int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr,
1093 broadcast_vector, output_ptr, voffset_non_broadcast, vscale_non_broadcast,
1094 voffseto, invvscaleo, !is_broadcast_input_2);
// Scalar tail in the float domain.
1095 for (; x < window_end_x; ++x)
1099 *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ?
bfs : afs,
1100 !is_broadcast_input_2 ? afs :
bfs, output_qinfo);
1103 broadcast_input, non_broadcast_input, output);
// Non-broadcast path.
1111 const int32x4_t voffset1 = vdupq_n_s32(input1_qinfo.
offset);
1112 const float32x4_t vscale1 = vdupq_n_f32(input1_qinfo.
scale);
1115 const int32x4_t voffset2 = vdupq_n_s32(input2_qinfo.
offset);
1116 const float32x4_t vscale2 = vdupq_n_f32(input2_qinfo.
scale);
1130 const auto input1_ptr =
reinterpret_cast<const int8_t *
>(input1.
ptr());
1131 const auto input2_ptr =
reinterpret_cast<const int8_t *
>(input2.
ptr());
1132 const auto output_ptr =
reinterpret_cast<uint8_t *
>(output.
ptr());
1134 int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr,
1135 voffset1, voffset2, vscale1, vscale2, voffseto, invvscaleo);
1136 for (; x < window_end_x; ++x)
1140 *(output_ptr + x) = (*scalar_func)(afs,
bfs, output_qinfo);
1143 input1, input2, output);
// QASYMM8_SIGNED driver with s8 output (int8_t in and out) — presumably the
// signed arithmetic variant. Same broadcast/non-broadcast split as the
// drivers above, with signed dequantize/requantize throughout.
1153 int (*broadcast_func)(
int,
1164 int (*neon_func)(
int,
1185 const int window_step_x = 16;
1186 const auto window_start_x =
static_cast<int>(window.
x().
start());
1187 const auto window_end_x =
static_cast<int>(window.
x().
end());
1192 const float32x4_t voffseto = vdupq_n_f32(output_qinfo.
offset);
1193 const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.
scale);
// Broadcast path.
1195 if (is_broadcast_across_x)
1198 const bool is_broadcast_input_2 = input2_win.
x().
step() == 0;
1199 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
1200 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
1201 const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
1202 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
1207 const int32x4_t voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.
offset);
1208 const float32x4_t vscale_non_broadcast = vdupq_n_f32(non_broadcast_qinfo.
scale);
1213 Iterator broadcast_input(broadcast_tensor, broadcast_win);
1214 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
1221 const auto non_broadcast_input_ptr =
reinterpret_cast<const int8_t *
>(non_broadcast_input.
ptr());
1222 const auto output_ptr =
reinterpret_cast<int8_t *
>(output.
ptr());
1224 const int8_t broadcast_value = *
reinterpret_cast<const int8_t *
>(broadcast_input.
ptr());
1225 const float32x4x4_t broadcast_vector =
vdequantize(vdupq_n_s8(broadcast_value), broadcast_qinfo);
1227 int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr,
1228 broadcast_vector, output_ptr, voffset_non_broadcast, vscale_non_broadcast,
1229 voffseto, invvscaleo, !is_broadcast_input_2);
// Scalar tail in the float domain.
1230 for (; x < window_end_x; ++x)
1234 *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ?
bfs : afs,
1235 !is_broadcast_input_2 ? afs :
bfs, output_qinfo);
1238 broadcast_input, non_broadcast_input, output);
// Non-broadcast path.
1246 const int32x4_t voffset1 = vdupq_n_s32(input1_qinfo.
offset);
1247 const float32x4_t vscale1 = vdupq_n_f32(input1_qinfo.
scale);
1250 const int32x4_t voffset2 = vdupq_n_s32(input2_qinfo.
offset);
1251 const float32x4_t vscale2 = vdupq_n_f32(input2_qinfo.
scale);
1265 const auto input1_ptr =
reinterpret_cast<const int8_t *
>(input1.
ptr());
1266 const auto input2_ptr =
reinterpret_cast<const int8_t *
>(input2.
ptr());
1267 const auto output_ptr =
reinterpret_cast<int8_t *
>(output.
ptr());
1269 int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr,
1270 voffset1, voffset2, vscale1, vscale2, voffseto, invvscaleo);
1271 for (; x < window_end_x; ++x)
1275 *(output_ptr + x) = (*scalar_func)(afs,
bfs, output_qinfo);
1278 input1, input2, output);
// Public quantized entry points: wire the op-specific quantized loops into the
// quantized drivers (arith u8, arith s8, comp u8, comp s8).
1282 template <ArithmeticOperation op>
1286 &elementwise_arithm_op_quantized_broadcast_loop<op>,
1287 &elementwise_arithm_op_quantized_loop<op>);
1290 template <ArithmeticOperation op>
1294 &elementwise_arithm_op_quantized_signed_broadcast_loop<op>,
// NOTE(review): "singed" below looks like a typo for "signed", but it must
// match the callee's actual (unseen) definition name — verify before renaming.
1295 &elementwise_arithm_op_quantized_singed_loop<op>);
1298 template <ComparisonOperation op>
1302 &elementwise_comp_op_quantized_broadcast_loop<op>,
1303 &elementwise_comp_op_quantized_loop<op>);
1306 template <ComparisonOperation op>
1310 &elementwise_comp_op_quantized_signed_broadcast_loop<op>,
1311 &elementwise_comp_op_quantized_signed_loop<op>);