24 #ifndef SRC_CORE_SVE_KERNELS_ELEMENTWISE_QUANTIZED_LIST_H
25 #define SRC_CORE_SVE_KERNELS_ELEMENTWISE_QUANTIZED_LIST_H
36 auto x = svld1(pg, ptr);
38 const auto widened = svcreate4(svmovlb(svmovlb(x)), svmovlt(svmovlb(x)), svmovlb(svmovlt(x)), svmovlt(svmovlt(x)));
42 return svcreate4(svmul_z(pg, svcvt_f32_z(pg, svsub_z(pg, svget4(widened, 0),
offset)),
scale),
43 svmul_z(pg, svcvt_f32_z(pg, svsub_z(pg, svget4(widened, 1),
offset)),
scale),
44 svmul_z(pg, svcvt_f32_z(pg, svsub_z(pg, svget4(widened, 2),
offset)),
scale),
45 svmul_z(pg, svcvt_f32_z(pg, svsub_z(pg, svget4(widened, 3),
offset)),
scale));
50 auto x = svld1(pg, ptr);
54 const auto widened = svcreate4(svmovlb(svmovlb(x)), svmovlt(svmovlb(x)), svmovlb(svmovlt(x)), svmovlt(svmovlt(x)));
58 return svcreate4(svmul_z(pg, svcvt_f32_z(pg, svsub_z(pg, svreinterpret_s32(svget4(widened, 0)),
offset)),
scale),
59 svmul_z(pg, svcvt_f32_z(pg, svsub_z(pg, svreinterpret_s32(svget4(widened, 1)),
offset)),
scale),
60 svmul_z(pg, svcvt_f32_z(pg, svsub_z(pg, svreinterpret_s32(svget4(widened, 2)),
offset)),
scale),
61 svmul_z(pg, svcvt_f32_z(pg, svsub_z(pg, svreinterpret_s32(svget4(widened, 3)),
offset)),
scale));
65 store_quantized(uint8_t *ptr, svbool_t pg, svfloat32x4_t data,
const svint32_t &
offset,
const svfloat32_t &inv_scale)
67 const auto quantized =
68 svcreate4(svadd_z(pg, svcvt_s32_z(pg, svrinta_z(pg, svmul_z(pg, svget4(data, 0), inv_scale))),
offset),
69 svadd_z(pg, svcvt_s32_z(pg, svrinta_z(pg, svmul_z(pg, svget4(data, 1), inv_scale))),
offset),
70 svadd_z(pg, svcvt_s32_z(pg, svrinta_z(pg, svmul_z(pg, svget4(data, 2), inv_scale))),
offset),
71 svadd_z(pg, svcvt_s32_z(pg, svrinta_z(pg, svmul_z(pg, svget4(data, 3), inv_scale))),
offset));
73 const auto narrowed_bottom = svqxtunt(svqxtunb(svget4(quantized, 0)), svget4(quantized, 1));
74 const auto narrowed_top = svqxtunt(svqxtunb(svget4(quantized, 2)), svget4(quantized, 3));
75 const auto narrowed = svqxtnt(svqxtnb(narrowed_bottom), narrowed_top);
76 svst1(pg, ptr, narrowed);
80 store_quantized(int8_t *ptr, svbool_t pg, svfloat32x4_t data,
const svint32_t &
offset,
const svfloat32_t &inv_scale)
82 const auto quantized =
83 svcreate4(svadd_z(pg, svcvt_s32_z(pg, svrinta_z(pg, svmul_z(pg, svget4(data, 0), inv_scale))),
offset),
84 svadd_z(pg, svcvt_s32_z(pg, svrinta_z(pg, svmul_z(pg, svget4(data, 1), inv_scale))),
offset),
85 svadd_z(pg, svcvt_s32_z(pg, svrinta_z(pg, svmul_z(pg, svget4(data, 2), inv_scale))),
offset),
86 svadd_z(pg, svcvt_s32_z(pg, svrinta_z(pg, svmul_z(pg, svget4(data, 3), inv_scale))),
offset));
88 const auto narrowed_bottom = svqxtnt(svqxtnb(svget4(quantized, 0)), svget4(quantized, 1));
89 const auto narrowed_top = svqxtnt(svqxtnb(svget4(quantized, 2)), svget4(quantized, 3));
90 const auto narrowed = svqxtnt(svqxtnb(narrowed_bottom), narrowed_top);
92 svst1(pg, ptr, narrowed);
95 template <
typename ScalarType>
99 const auto all_true_pg = wrapper::svptrue<ScalarType>();
109 const auto window_start_x =
static_cast<int>(window.
x().
start());
110 const auto window_end_x =
static_cast<int>(window.
x().
end());
116 if (is_broadcast_across_x)
118 const bool is_broadcast_input_2 = input2_win.
x().
step() == 0;
119 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
120 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
121 const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
122 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
124 const auto non_broadcast_qinfo =
126 const auto broadcast_qinfo =
129 const auto non_broadcast_voffset = svdup_n(non_broadcast_qinfo.uniform().offset);
130 const auto non_broadcast_vscale = svdup_n(non_broadcast_qinfo.uniform().scale);
135 Iterator broadcast_input(broadcast_tensor, broadcast_win);
136 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
143 auto output_ptr =
reinterpret_cast<ScalarType *
>(output.
ptr());
144 const auto non_broadcast_input_ptr =
reinterpret_cast<const ScalarType *
>(non_broadcast_input.
ptr());
145 const ScalarType broadcast_value = *
reinterpret_cast<const ScalarType *
>(broadcast_input.
ptr());
146 const float broadcast_value_f =
148 const auto in2 = svcreate4(svdup_n(broadcast_value_f), svdup_n(broadcast_value_f),
149 svdup_n(broadcast_value_f), svdup_n(broadcast_value_f));
151 int x = window_start_x;
153 svbool_t pg = wrapper::svwhilelt<ScalarType>(x, window_end_x);
157 load_quantized(non_broadcast_input_ptr + x, pg, non_broadcast_voffset, non_broadcast_vscale);
159 svfloat32x4_t result{};
161 if (!is_broadcast_input_2)
164 svcreate4(elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in2, 0), svget4(in1, 0), op),
165 elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in2, 1), svget4(in1, 1), op),
166 elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in2, 2), svget4(in1, 2), op),
167 elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in2, 3), svget4(in1, 3), op));
172 svcreate4(elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in1, 0), svget4(in2, 0), op),
173 elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in1, 1), svget4(in2, 1), op),
174 elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in1, 2), svget4(in2, 2), op),
175 elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in1, 3), svget4(in2, 3), op));
178 store_quantized(output_ptr + x, pg, result, output_voffset, output_vscale);
180 x += wrapper::svcnt<ScalarType>();
181 pg = wrapper::svwhilelt<ScalarType>(x, window_end_x);
182 }
while (svptest_any(all_true_pg, pg));
184 broadcast_input, non_broadcast_input, output);
206 auto output_ptr =
reinterpret_cast<ScalarType *
>(output.
ptr());
207 const auto input1_ptr =
reinterpret_cast<const ScalarType *
>(input1.
ptr());
208 const auto input2_ptr =
reinterpret_cast<const ScalarType *
>(input2.
ptr());
210 int x = window_start_x;
212 svbool_t pg = wrapper::svwhilelt<ScalarType>(x, window_end_x);
215 const auto in1 =
load_quantized(input1_ptr + x, pg, in1_voffset, in1_vscale);
216 const auto in2 =
load_quantized(input2_ptr + x, pg, in2_voffset, in2_vscale);
219 svcreate4(elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in1, 0), svget4(in2, 0), op),
220 elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in1, 1), svget4(in2, 1), op),
221 elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in1, 2), svget4(in2, 2), op),
222 elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in1, 3), svget4(in2, 3), op));
224 store_quantized(output_ptr + x, pg, result, output_voffset, output_vscale);
226 x += wrapper::svcnt<ScalarType>();
227 pg = wrapper::svwhilelt<ScalarType>(x, window_end_x);
228 }
while (svptest_any(all_true_pg, pg));
230 input1, input2, output);
234 template <
typename InputScalarType,
typename OutputScalarType = u
int8_t>
238 static_assert(
sizeof(InputScalarType) >=
sizeof(OutputScalarType),
239 "input data type's width should be equal to or greater than output data type's width");
242 const auto all_true_pg = wrapper::svptrue<InputScalarType>();
252 const auto window_start_x =
static_cast<int>(window.
x().
start());
253 const auto window_end_x =
static_cast<int>(window.
x().
end());
256 if (is_broadcast_across_x)
258 const bool is_broadcast_input_2 = input2_win.
x().
step() == 0;
259 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
260 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
261 const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
262 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
264 const auto non_broadcast_qinfo =
266 const auto broadcast_qinfo =
269 const auto non_broadcast_voffset = svdup_n(non_broadcast_qinfo.uniform().offset);
270 const auto non_broadcast_vscale = svdup_n(non_broadcast_qinfo.uniform().scale);
275 Iterator broadcast_input(broadcast_tensor, broadcast_win);
276 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
283 auto output_ptr =
reinterpret_cast<OutputScalarType *
>(output.
ptr());
284 const auto non_broadcast_input_ptr =
285 reinterpret_cast<const InputScalarType *
>(non_broadcast_input.
ptr());
286 const InputScalarType broadcast_value =
287 *
reinterpret_cast<const InputScalarType *
>(broadcast_input.
ptr());
288 const float broadcast_value_f =
290 const auto in2 = svcreate4(svdup_n(broadcast_value_f), svdup_n(broadcast_value_f),
291 svdup_n(broadcast_value_f), svdup_n(broadcast_value_f));
293 int x = window_start_x;
295 svbool_t pg = wrapper::svwhilelt<InputScalarType>(x, window_end_x);
299 load_quantized(non_broadcast_input_ptr + x, pg, non_broadcast_voffset, non_broadcast_vscale);
301 svuint8x4_t result{};
303 if (!is_broadcast_input_2)
305 result = svcreate4(elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in2, 0),
307 elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in2, 1),
309 elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in2, 2),
311 elementwise_comparison_op<svfloat32_t, OutputVectorType>(
312 pg, svget4(in2, 3), svget4(in1, 3), op));
316 result = svcreate4(elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in1, 0),
318 elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in1, 1),
320 elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in1, 2),
322 elementwise_comparison_op<svfloat32_t, OutputVectorType>(
323 pg, svget4(in1, 3), svget4(in2, 3), op));
326 const auto zipped_bottom = svzip1(svget4(result, 0), svget4(result, 1));
327 const auto zipped_top = svzip1(svget4(result, 2), svget4(result, 3));
328 const auto zipped = svzip1(zipped_bottom, zipped_top);
329 svst1(pg, output_ptr + x, zipped);
331 x += wrapper::svcnt<InputScalarType>();
332 pg = wrapper::svwhilelt<InputScalarType>(x, window_end_x);
333 }
while (svptest_any(all_true_pg, pg));
335 broadcast_input, non_broadcast_input, output);
357 auto output_ptr =
reinterpret_cast<OutputScalarType *
>(output.
ptr());
358 const auto input1_ptr =
reinterpret_cast<const InputScalarType *
>(input1.
ptr());
359 const auto input2_ptr =
reinterpret_cast<const InputScalarType *
>(input2.
ptr());
361 int x = window_start_x;
363 svbool_t pg = wrapper::svwhilelt<InputScalarType>(x, window_end_x);
366 const auto in1 =
load_quantized(input1_ptr + x, pg, in1_voffset, in1_vscale);
367 const auto in2 =
load_quantized(input2_ptr + x, pg, in2_voffset, in2_vscale);
369 svcreate4(elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in1, 0),
371 elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in1, 1),
373 elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in1, 2),
375 elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in1, 3),
376 svget4(in2, 3), op));
378 const auto zipped_bottom = svzip1(svget4(result, 0), svget4(result, 1));
379 const auto zipped_top = svzip1(svget4(result, 2), svget4(result, 3));
380 const auto zipped = svzip1(zipped_bottom, zipped_top);
381 svst1(pg, output_ptr + x, zipped);
383 x += wrapper::svcnt<InputScalarType>();
384 pg = wrapper::svwhilelt<InputScalarType>(x, window_end_x);
385 }
while (svptest_any(all_true_pg, pg));
387 input1, input2, output);