template<bool do_shift_correction, bool per_channel, bool do_left_shift>
void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigned int height,
                             const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
                             const int32_t *row_bias, const int32_t *col_bias, const unsigned int start_col) {
    const int32x4_t v_mul         = vdupq_n_s32(qp.per_layer_mul);
    const int32x4_t v_right_shift = vdupq_n_s32(qp.per_layer_right_shift);
    const int32x4_t v_left_shift  = vdupq_n_s32(qp.per_layer_left_shift);
    const int32x4_t v_minval      = vdupq_n_s32(qp.minval);
    const int32x4_t v_maxval      = vdupq_n_s32(qp.maxval);
    const int32x4_t v_c_offset    = vdupq_n_s32(qp.c_offset);
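
    // Process two rows per pass so plenty of accumulators are in flight; if
    // the height is odd, the final row is simply processed twice.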
    for (unsigned int row=0; row<height; row+=2) {
        unsigned int blocks=(width / 16);
        unsigned int regs=(width % 16) / 4;
        unsigned int odds=(width % 4);
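
        // Per-channel requantize parameters for the strip of columns
        // starting at start_col (only dereferenced when per_channel is set).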
        const int32_t *colptr = col_bias;
        const int32_t *perch_mul_ptr    = qp.per_channel_muls + start_col;
        const int32_t *perch_shift_ptr  = qp.per_channel_right_shifts + start_col;
        const int32_t *perch_shiftl_ptr = qp.per_channel_left_shifts + start_col;
        const int32_t *in_ptr = input + (row * in_stride);
        int8_t *out_ptr = output + (row * out_stride);
        int32_t row_sum = row_bias[row];

        const int32_t *in_ptr1;
        int8_t *out_ptr1;
        int32_t row_sum1;
        if (row == height-1) {
            // Last (odd) row: reuse the same pointers so the row is
            // processed twice rather than reading past the end.
            in_ptr1 = in_ptr;
            out_ptr1 = out_ptr;
            row_sum1 = row_sum;
        } else {
            in_ptr1 = in_ptr + in_stride;
            out_ptr1 = out_ptr + out_stride;
            row_sum1 = row_bias[row+1];
        }
        const int32x4_t v_row_sum  = vdupq_n_s32(row_sum);
        const int32x4_t v_row_sum1 = vdupq_n_s32(row_sum1);
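
        // Main loop: requantize 16 columns (four vectors per row) at a time.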
        while (blocks--) {
            int32x4_t v_mul0, v_mul1, v_mul2, v_mul3;
            int32x4_t v_shf0, v_shf1, v_shf2, v_shf3;
            int32x4_t v_shf0l, v_shf1l, v_shf2l, v_shf3l;

            if (per_channel) {
                v_mul0 = vld1q_s32(perch_mul_ptr);
                v_mul1 = vld1q_s32(perch_mul_ptr + 4);
                v_mul2 = vld1q_s32(perch_mul_ptr + 8);
                v_mul3 = vld1q_s32(perch_mul_ptr + 12);
                perch_mul_ptr += 16;

                v_shf0 = vld1q_s32(perch_shift_ptr);
                v_shf1 = vld1q_s32(perch_shift_ptr + 4);
                v_shf2 = vld1q_s32(perch_shift_ptr + 8);
                v_shf3 = vld1q_s32(perch_shift_ptr + 12);
                perch_shift_ptr += 16;

                if (do_left_shift) {
                    v_shf0l = vld1q_s32(perch_shiftl_ptr);
                    v_shf1l = vld1q_s32(perch_shiftl_ptr + 4);
                    v_shf2l = vld1q_s32(perch_shiftl_ptr + 8);
                    v_shf3l = vld1q_s32(perch_shiftl_ptr + 12);
                    perch_shiftl_ptr += 16;
                }
            } else {
                v_mul0=v_mul1=v_mul2=v_mul3=v_mul;
                v_shf0=v_shf1=v_shf2=v_shf3=v_right_shift;
                v_shf0l=v_shf1l=v_shf2l=v_shf3l=v_left_shift;
            }
            // Load the column bias terms and both rows of input, advancing
            // the pointers past what was consumed.
            int32x4_t v_col0 = vld1q_s32(colptr);
            int32x4_t v_col1 = vld1q_s32(colptr + 4);
            int32x4_t v_col2 = vld1q_s32(colptr + 8);
            int32x4_t v_col3 = vld1q_s32(colptr + 12);
            colptr += 16;

            int32x4_t v_in00 = vld1q_s32(in_ptr);
            int32x4_t v_in01 = vld1q_s32(in_ptr + 4);
            int32x4_t v_in02 = vld1q_s32(in_ptr + 8);
            int32x4_t v_in03 = vld1q_s32(in_ptr + 12);
            in_ptr += 16;

            int32x4_t v_in10 = vld1q_s32(in_ptr1);
            int32x4_t v_in11 = vld1q_s32(in_ptr1 + 4);
            int32x4_t v_in12 = vld1q_s32(in_ptr1 + 8);
            int32x4_t v_in13 = vld1q_s32(in_ptr1 + 12);
            in_ptr1 += 16;
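
            // Add the row bias (row sums pre-scaled by -b_offset) and the
            // per-column bias terms to the accumulators.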
            v_in00 = vaddq_s32(v_in00, v_row_sum);
            v_in01 = vaddq_s32(v_in01, v_row_sum);
            v_in02 = vaddq_s32(v_in02, v_row_sum);
            v_in03 = vaddq_s32(v_in03, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);
            v_in11 = vaddq_s32(v_in11, v_row_sum1);
            v_in12 = vaddq_s32(v_in12, v_row_sum1);
            v_in13 = vaddq_s32(v_in13, v_row_sum1);

            v_in00 = vaddq_s32(v_in00, v_col0);
            v_in01 = vaddq_s32(v_in01, v_col1);
            v_in02 = vaddq_s32(v_in02, v_col2);
            v_in03 = vaddq_s32(v_in03, v_col3);

            v_in10 = vaddq_s32(v_in10, v_col0);
            v_in11 = vaddq_s32(v_in11, v_col1);
            v_in12 = vaddq_s32(v_in12, v_col2);
            v_in13 = vaddq_s32(v_in13, v_col3);
            // Quantize: optional left shift first...
            if (do_left_shift) {
                v_in00 = vrshlq_s32(v_in00, v_shf0l);
                v_in01 = vrshlq_s32(v_in01, v_shf1l);
                v_in02 = vrshlq_s32(v_in02, v_shf2l);
                v_in03 = vrshlq_s32(v_in03, v_shf3l);

                v_in10 = vrshlq_s32(v_in10, v_shf0l);
                v_in11 = vrshlq_s32(v_in11, v_shf1l);
                v_in12 = vrshlq_s32(v_in12, v_shf2l);
                v_in13 = vrshlq_s32(v_in13, v_shf3l);
            }
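
            // ...then the fixed-point multiply: saturating, rounding,
            // doubling multiply returning the high halves.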
            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
            v_in01 = vqrdmulhq_s32(v_in01, v_mul1);
            v_in02 = vqrdmulhq_s32(v_in02, v_mul2);
            v_in03 = vqrdmulhq_s32(v_in03, v_mul3);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
            v_in11 = vqrdmulhq_s32(v_in11, v_mul1);
            v_in12 = vqrdmulhq_s32(v_in12, v_mul2);
            v_in13 = vqrdmulhq_s32(v_in13, v_mul3);
            if (do_shift_correction) {
                // Bias negative values down by one (saturating) before the
                // rounding shift so negative results round consistently.
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
                int32x4_t v_temp01 = vandq_s32(v_in01, v_shf1);
                int32x4_t v_temp02 = vandq_s32(v_in02, v_shf2);
                int32x4_t v_temp03 = vandq_s32(v_in03, v_shf3);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);
                int32x4_t v_temp11 = vandq_s32(v_in11, v_shf1);
                int32x4_t v_temp12 = vandq_s32(v_in12, v_shf2);
                int32x4_t v_temp13 = vandq_s32(v_in13, v_shf3);

                v_temp00 = vshrq_n_s32(v_temp00, 31);
                v_temp01 = vshrq_n_s32(v_temp01, 31);
                v_temp02 = vshrq_n_s32(v_temp02, 31);
                v_temp03 = vshrq_n_s32(v_temp03, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);
                v_temp11 = vshrq_n_s32(v_temp11, 31);
                v_temp12 = vshrq_n_s32(v_temp12, 31);
                v_temp13 = vshrq_n_s32(v_temp13, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);
                v_in01 = vqaddq_s32(v_in01, v_temp01);
                v_in02 = vqaddq_s32(v_in02, v_temp02);
                v_in03 = vqaddq_s32(v_in03, v_temp03);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
                v_in11 = vqaddq_s32(v_in11, v_temp11);
                v_in12 = vqaddq_s32(v_in12, v_temp12);
                v_in13 = vqaddq_s32(v_in13, v_temp13);
            }
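
            // Rounding right shift: the shift counts are stored negated, so
            // vrshlq shifts right by the requantize amount.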
            v_in00 = vrshlq_s32(v_in00, v_shf0);
            v_in01 = vrshlq_s32(v_in01, v_shf1);
            v_in02 = vrshlq_s32(v_in02, v_shf2);
            v_in03 = vrshlq_s32(v_in03, v_shf3);

            v_in10 = vrshlq_s32(v_in10, v_shf0);
            v_in11 = vrshlq_s32(v_in11, v_shf1);
            v_in12 = vrshlq_s32(v_in12, v_shf2);
            v_in13 = vrshlq_s32(v_in13, v_shf3);
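
            // Add the output offset (c_offset) and clamp to [minval, maxval].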
            v_in00 = vaddq_s32(v_in00, v_c_offset);
            v_in01 = vaddq_s32(v_in01, v_c_offset);
            v_in02 = vaddq_s32(v_in02, v_c_offset);
            v_in03 = vaddq_s32(v_in03, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);
            v_in11 = vaddq_s32(v_in11, v_c_offset);
            v_in12 = vaddq_s32(v_in12, v_c_offset);
            v_in13 = vaddq_s32(v_in13, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);
            v_in01 = vmaxq_s32(v_in01, v_minval);
            v_in02 = vmaxq_s32(v_in02, v_minval);
            v_in03 = vmaxq_s32(v_in03, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);
            v_in11 = vmaxq_s32(v_in11, v_minval);
            v_in12 = vmaxq_s32(v_in12, v_minval);
            v_in13 = vmaxq_s32(v_in13, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);
            v_in01 = vminq_s32(v_in01, v_maxval);
            v_in02 = vminq_s32(v_in02, v_maxval);
            v_in03 = vminq_s32(v_in03, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);
            v_in11 = vminq_s32(v_in11, v_maxval);
            v_in12 = vminq_s32(v_in12, v_maxval);
            v_in13 = vminq_s32(v_in13, v_maxval);
            // Narrow to bytes with uzp1 (values are already clamped to the
            // 8-bit range) and store 16 outputs per row.
            int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in01));
            int16x8_t v_uz01 = vuzp1q_s16(vreinterpretq_s16_s32(v_in02), vreinterpretq_s16_s32(v_in03));

            int16x8_t v_uz10 = vuzp1q_s16(vreinterpretq_s16_s32(v_in10), vreinterpretq_s16_s32(v_in11));
            int16x8_t v_uz11 = vuzp1q_s16(vreinterpretq_s16_s32(v_in12), vreinterpretq_s16_s32(v_in13));

            int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz01));
            int8x16_t v_uz1 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz10), vreinterpretq_s8_s16(v_uz11));

            vst1q_s8(out_ptr, v_uz0);
            out_ptr += 16;
            vst1q_s8(out_ptr1, v_uz1);
            out_ptr1 += 16;
        }
        // Tail of exactly 12 remaining columns (three vectors per row), a
        // common kernel block width.
        if (regs == 3) {
            regs = 0;

            int32x4_t v_mul0, v_mul1, v_mul2;
            int32x4_t v_shf0, v_shf1, v_shf2;
            int32x4_t v_shf0l, v_shf1l, v_shf2l;

            if (per_channel) {
                v_mul0 = vld1q_s32(perch_mul_ptr);
                v_mul1 = vld1q_s32(perch_mul_ptr + 4);
                v_mul2 = vld1q_s32(perch_mul_ptr + 8);
                perch_mul_ptr += 12;

                v_shf0 = vld1q_s32(perch_shift_ptr);
                v_shf1 = vld1q_s32(perch_shift_ptr + 4);
                v_shf2 = vld1q_s32(perch_shift_ptr + 8);
                perch_shift_ptr += 12;

                if (do_left_shift) {
                    v_shf0l = vld1q_s32(perch_shiftl_ptr);
                    v_shf1l = vld1q_s32(perch_shiftl_ptr + 4);
                    v_shf2l = vld1q_s32(perch_shiftl_ptr + 8);
                    perch_shiftl_ptr += 12;
                }
            } else {
                v_mul0=v_mul1=v_mul2=v_mul;
                v_shf0=v_shf1=v_shf2=v_right_shift;
                v_shf0l=v_shf1l=v_shf2l=v_left_shift;
            }
            int32x4_t v_col0 = vld1q_s32(colptr);
            int32x4_t v_col1 = vld1q_s32(colptr + 4);
            int32x4_t v_col2 = vld1q_s32(colptr + 8);
            colptr += 12;

            int32x4_t v_in00 = vld1q_s32(in_ptr);
            int32x4_t v_in01 = vld1q_s32(in_ptr + 4);
            int32x4_t v_in02 = vld1q_s32(in_ptr + 8);
            in_ptr += 12;

            int32x4_t v_in10 = vld1q_s32(in_ptr1);
            int32x4_t v_in11 = vld1q_s32(in_ptr1 + 4);
            int32x4_t v_in12 = vld1q_s32(in_ptr1 + 8);
            in_ptr1 += 12;
            v_in00 = vaddq_s32(v_in00, v_row_sum);
            v_in01 = vaddq_s32(v_in01, v_row_sum);
            v_in02 = vaddq_s32(v_in02, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);
            v_in11 = vaddq_s32(v_in11, v_row_sum1);
            v_in12 = vaddq_s32(v_in12, v_row_sum1);

            v_in00 = vaddq_s32(v_in00, v_col0);
            v_in01 = vaddq_s32(v_in01, v_col1);
            v_in02 = vaddq_s32(v_in02, v_col2);

            v_in10 = vaddq_s32(v_in10, v_col0);
            v_in11 = vaddq_s32(v_in11, v_col1);
            v_in12 = vaddq_s32(v_in12, v_col2);
            if (do_left_shift) {
                v_in00 = vrshlq_s32(v_in00, v_shf0l);
                v_in01 = vrshlq_s32(v_in01, v_shf1l);
                v_in02 = vrshlq_s32(v_in02, v_shf2l);

                v_in10 = vrshlq_s32(v_in10, v_shf0l);
                v_in11 = vrshlq_s32(v_in11, v_shf1l);
                v_in12 = vrshlq_s32(v_in12, v_shf2l);
            }
            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
            v_in01 = vqrdmulhq_s32(v_in01, v_mul1);
            v_in02 = vqrdmulhq_s32(v_in02, v_mul2);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
            v_in11 = vqrdmulhq_s32(v_in11, v_mul1);
            v_in12 = vqrdmulhq_s32(v_in12, v_mul2);
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
                int32x4_t v_temp01 = vandq_s32(v_in01, v_shf1);
                int32x4_t v_temp02 = vandq_s32(v_in02, v_shf2);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);
                int32x4_t v_temp11 = vandq_s32(v_in11, v_shf1);
                int32x4_t v_temp12 = vandq_s32(v_in12, v_shf2);

                v_temp00 = vshrq_n_s32(v_temp00, 31);
                v_temp01 = vshrq_n_s32(v_temp01, 31);
                v_temp02 = vshrq_n_s32(v_temp02, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);
                v_temp11 = vshrq_n_s32(v_temp11, 31);
                v_temp12 = vshrq_n_s32(v_temp12, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);
                v_in01 = vqaddq_s32(v_in01, v_temp01);
                v_in02 = vqaddq_s32(v_in02, v_temp02);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
                v_in11 = vqaddq_s32(v_in11, v_temp11);
                v_in12 = vqaddq_s32(v_in12, v_temp12);
            }
            v_in00 = vrshlq_s32(v_in00, v_shf0);
            v_in01 = vrshlq_s32(v_in01, v_shf1);
            v_in02 = vrshlq_s32(v_in02, v_shf2);

            v_in10 = vrshlq_s32(v_in10, v_shf0);
            v_in11 = vrshlq_s32(v_in11, v_shf1);
            v_in12 = vrshlq_s32(v_in12, v_shf2);
            v_in00 = vaddq_s32(v_in00, v_c_offset);
            v_in01 = vaddq_s32(v_in01, v_c_offset);
            v_in02 = vaddq_s32(v_in02, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);
            v_in11 = vaddq_s32(v_in11, v_c_offset);
            v_in12 = vaddq_s32(v_in12, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);
            v_in01 = vmaxq_s32(v_in01, v_minval);
            v_in02 = vmaxq_s32(v_in02, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);
            v_in11 = vmaxq_s32(v_in11, v_minval);
            v_in12 = vmaxq_s32(v_in12, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);
            v_in01 = vminq_s32(v_in01, v_maxval);
            v_in02 = vminq_s32(v_in02, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);
            v_in11 = vminq_s32(v_in11, v_maxval);
            v_in12 = vminq_s32(v_in12, v_maxval);
            // Narrow to bytes via uzp1; there are only three result vectors
            // here, so the last one is paired with itself.
            int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in01));
            int16x8_t v_uz01 = vuzp1q_s16(vreinterpretq_s16_s32(v_in02), vreinterpretq_s16_s32(v_in02));

            int16x8_t v_uz10 = vuzp1q_s16(vreinterpretq_s16_s32(v_in10), vreinterpretq_s16_s32(v_in11));
            int16x8_t v_uz11 = vuzp1q_s16(vreinterpretq_s16_s32(v_in12), vreinterpretq_s16_s32(v_in12));

            int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz01));
            int8x16_t v_uz1 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz10), vreinterpretq_s8_s16(v_uz11));
            // Store 12 bytes per row: an 8-byte lane store plus a 4-byte lane store.
            vst1q_lane_s64(reinterpret_cast<int64_t *>(out_ptr), vreinterpretq_s64_s8(v_uz0), 0);
            vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr + 8), vreinterpretq_s32_s8(v_uz0), 2);
            out_ptr += 12;

            vst1q_lane_s64(reinterpret_cast<int64_t *>(out_ptr1), vreinterpretq_s64_s8(v_uz1), 0);
            vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr1 + 8), vreinterpretq_s32_s8(v_uz1), 2);
            out_ptr1 += 12;
        }
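
        // Remaining whole groups of four columns, one vector per row.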
        while (regs--) {
            int32x4_t v_mul0;
            int32x4_t v_shf0;
            int32x4_t v_shf0l;

            if (per_channel) {
                v_mul0 = vld1q_s32(perch_mul_ptr);
                perch_mul_ptr += 4;

                v_shf0 = vld1q_s32(perch_shift_ptr);
                perch_shift_ptr += 4;

                if (do_left_shift) {
                    v_shf0l = vld1q_s32(perch_shiftl_ptr);
                    perch_shiftl_ptr += 4;
                }
            } else {
                v_mul0=v_mul;
                v_shf0=v_right_shift;
                v_shf0l=v_left_shift;
            }
            int32x4_t v_col0 = vld1q_s32(colptr);
            colptr += 4;

            int32x4_t v_in00 = vld1q_s32(in_ptr);
            in_ptr += 4;

            int32x4_t v_in10 = vld1q_s32(in_ptr1);
            in_ptr1 += 4;

            v_in00 = vaddq_s32(v_in00, v_row_sum);
            v_in10 = vaddq_s32(v_in10, v_row_sum1);

            v_in00 = vaddq_s32(v_in00, v_col0);
            v_in10 = vaddq_s32(v_in10, v_col0);
            if (do_left_shift) {
                v_in00 = vrshlq_s32(v_in00, v_shf0l);
                v_in10 = vrshlq_s32(v_in10, v_shf0l);
            }
            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);

                v_temp00 = vshrq_n_s32(v_temp00, 31);
                v_temp10 = vshrq_n_s32(v_temp10, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);
                v_in10 = vqaddq_s32(v_in10, v_temp10);
            }
            v_in00 = vrshlq_s32(v_in00, v_shf0);
            v_in10 = vrshlq_s32(v_in10, v_shf0);

            v_in00 = vaddq_s32(v_in00, v_c_offset);
            v_in10 = vaddq_s32(v_in10, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);
            v_in10 = vmaxq_s32(v_in10, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);
            v_in10 = vminq_s32(v_in10, v_maxval);
            // Both rows' four results fit in one vector: row 0 in 32-bit
            // lane 0, row 1 in lane 1.
            int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in10));
            int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz00));

            vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr), vreinterpretq_s32_s8(v_uz0), 0);
            out_ptr += 4;
            vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr1), vreinterpretq_s32_s8(v_uz0), 1);
            out_ptr1 += 4;
        }
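
        // Final 1-3 odd columns: gather inputs and parameters lane by lane.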
        if (odds) {
            int32x4_t v_col0 = vdupq_n_s32(0);
            int32x4_t v_in00 = vdupq_n_s32(0);
            int32x4_t v_in10 = vdupq_n_s32(0);
            int32x4_t v_mul0 = vdupq_n_s32(0);
            int32x4_t v_shf0 = vdupq_n_s32(0);
            int32x4_t v_shf0l = vdupq_n_s32(0);

            if (!per_channel) {
                v_mul0 = v_mul;
                v_shf0 = v_right_shift;
                v_shf0l = v_left_shift;
            }

            do {
                v_col0 = vld1q_lane_s32(colptr, v_col0, 0);
                v_in00 = vld1q_lane_s32(in_ptr, v_in00, 0);
                v_in10 = vld1q_lane_s32(in_ptr1, v_in10, 0);
                if (per_channel) {
                    v_mul0 = vld1q_lane_s32(perch_mul_ptr, v_mul0, 0);
                    v_shf0 = vld1q_lane_s32(perch_shift_ptr, v_shf0, 0);
                    if (do_left_shift) {
                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr, v_shf0l, 0);
                    }
                }
                if (odds == 1) { break; }
                v_col0 = vld1q_lane_s32(colptr + 1, v_col0, 1);
                v_in00 = vld1q_lane_s32(in_ptr + 1, v_in00, 1);
                v_in10 = vld1q_lane_s32(in_ptr1 + 1, v_in10, 1);
                if (per_channel) {
                    v_mul0 = vld1q_lane_s32(perch_mul_ptr + 1, v_mul0, 1);
                    v_shf0 = vld1q_lane_s32(perch_shift_ptr + 1, v_shf0, 1);
                    if (do_left_shift) {
                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr + 1, v_shf0l, 1);
                    }
                }
                if (odds == 2) { break; }
                v_col0 = vld1q_lane_s32(colptr + 2, v_col0, 2);
                v_in00 = vld1q_lane_s32(in_ptr + 2, v_in00, 2);
                v_in10 = vld1q_lane_s32(in_ptr1 + 2, v_in10, 2);
                if (per_channel) {
                    v_mul0 = vld1q_lane_s32(perch_mul_ptr + 2, v_mul0, 2);
                    v_shf0 = vld1q_lane_s32(perch_shift_ptr + 2, v_shf0, 2);
                    if (do_left_shift) {
                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr + 2, v_shf0l, 2);
                    }
                }
            } while (0);
            v_in00 = vaddq_s32(v_in00, v_row_sum);
            v_in10 = vaddq_s32(v_in10, v_row_sum1);

            v_in00 = vaddq_s32(v_in00, v_col0);
            v_in10 = vaddq_s32(v_in10, v_col0);
            if (do_left_shift) {
                v_in00 = vrshlq_s32(v_in00, v_shf0l);
                v_in10 = vrshlq_s32(v_in10, v_shf0l);
            }
            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);

                v_temp00 = vshrq_n_s32(v_temp00, 31);
                v_temp10 = vshrq_n_s32(v_temp10, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);
                v_in10 = vqaddq_s32(v_in10, v_temp10);
            }
            v_in00 = vrshlq_s32(v_in00, v_shf0);
            v_in10 = vrshlq_s32(v_in10, v_shf0);

            v_in00 = vaddq_s32(v_in00, v_c_offset);
            v_in10 = vaddq_s32(v_in10, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);
            v_in10 = vmaxq_s32(v_in10, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);
            v_in10 = vminq_s32(v_in10, v_maxval);
            // Store the surviving lanes one byte at a time (the low byte of
            // each 32-bit lane is byte lane 0, 4 or 8).
            do {
                vst1q_lane_s8(out_ptr, vreinterpretq_s8_s32(v_in00), 0);
                vst1q_lane_s8(out_ptr1, vreinterpretq_s8_s32(v_in10), 0);

                if (odds==1) { break; }

                vst1q_lane_s8(out_ptr + 1, vreinterpretq_s8_s32(v_in00), 4);
                vst1q_lane_s8(out_ptr1 + 1, vreinterpretq_s8_s32(v_in10), 4);

                if (odds==2) { break; }

                vst1q_lane_s8(out_ptr + 2, vreinterpretq_s8_s32(v_in00), 8);
                vst1q_lane_s8(out_ptr1 + 2, vreinterpretq_s8_s32(v_in10), 8);
            } while (0);
        }
    }
}
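
// Dispatch to the specialization matching the quantization parameters.
// When minval >= c_offset, any value the shift correction could change
// clamps to minval anyway, so the correction is compiled out.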
template<typename Tin, typename Tout>
void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                         const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride,
                         const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col) {
    if (qp.per_channel_requant) {
        if (qp.minval >= qp.c_offset) {
            if (qp.per_channel_left_shifts) {
                requantize_block_32_int<false, true, true>(qp, width, height,
                    reinterpret_cast<const int32_t *>(input), in_stride,
                    reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            } else {
                requantize_block_32_int<false, true, false>(qp, width, height,
                    reinterpret_cast<const int32_t *>(input), in_stride,
                    reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            }
        } else {
            if (qp.per_channel_left_shifts) {
                requantize_block_32_int<true, true, true>(qp, width, height,
                    reinterpret_cast<const int32_t *>(input), in_stride,
                    reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            } else {
                requantize_block_32_int<true, true, false>(qp, width, height,
                    reinterpret_cast<const int32_t *>(input), in_stride,
                    reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            }
        }
    } else {
        if (qp.minval >= qp.c_offset) {
            if (qp.per_layer_left_shift > 0) {
                requantize_block_32_int<false, false, true>(qp, width, height,
                    reinterpret_cast<const int32_t *>(input), in_stride,
                    reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            } else {
                requantize_block_32_int<false, false, false>(qp, width, height,
                    reinterpret_cast<const int32_t *>(input), in_stride,
                    reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            }
        } else {
            if (qp.per_layer_left_shift > 0) {
                requantize_block_32_int<true, false, true>(qp, width, height,
                    reinterpret_cast<const int32_t *>(input), in_stride,
                    reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            } else {
                requantize_block_32_int<true, false, false>(qp, width, height,
                    reinterpret_cast<const int32_t *>(input), in_stride,
                    reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            }
        }
    }
}
template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                                  const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
                                  const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col);
template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                                  const uint32_t *input, unsigned int in_stride, uint8_t *output, unsigned int out_stride,
                                  const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col);
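
// Helpers for computing row sums of 8-bit input blocks: bytes are pairwise
// accumulated into 16-bit lanes, which are periodically folded into 32-bit
// totals so the narrower accumulators cannot overflow.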
struct row_sum_helpers {
    const Requantize32 &qp;

    // Accumulate a full 16-element vector into 'sum'.
    template<typename T>
    inline int16x8_t accumulate_16(const T *ptr, int16x8_t sum);

    // Accumulate a 16-byte read with unwanted tail elements masked off.
    template<typename T>
    inline int16x8_t accumulate_masked_16(const T *ptr, int16x8_t sum, uint64x2_t mask);

    // Accumulate an 8-byte read with unwanted tail elements masked off.
    template<typename T>
    inline int16x8_t accumulate_masked_8(const T *ptr, int16x8_t sum, uint64x2_t mask);
    template<unsigned int rows, typename T>
    void compute_some_rows(unsigned int blocks, const T *input, unsigned int in_stride, int32_t *row_bias,
                           unsigned int mask_mode, uint64x2_t mask, int32x4_t offset_mul) {
        int16x8_t sums[rows];
        int32x4_t finalsums[rows];

        for (unsigned int i=0; i<rows; i++) {
            sums[i] = vdupq_n_s16(0);
            finalsums[i] = vdupq_n_s32(0);
        }
        for (unsigned int i=0; i<blocks; i++) {
            for (unsigned int r=0; r<rows; r++) {
                // Fold into the 32-bit totals every 32 blocks so the 16-bit
                // accumulators cannot overflow.
                if (i > 0 && ((i & 31) == 0)) {
                    finalsums[r] = vpadalq_s16(finalsums[r], sums[r]);
                    sums[r] = vdupq_n_s16(0);
                }
                sums[r] = accumulate_16(input + (r * in_stride) + (i * 16), sums[r]);
            }
        }
        // Handle the masked tail read, if there is one.
        if (mask_mode > 0) {
            for (unsigned int r=0; r<rows; r++) {
                if (mask_mode == 1) {
                    sums[r] = accumulate_masked_8(input + (r * in_stride) + (blocks * 16), sums[r], mask);
                } else {
                    sums[r] = accumulate_masked_16(input + (r * in_stride) + (blocks * 16), sums[r], mask);
                }
            }
        }
        for (unsigned int i=0; i<rows; i++) {
            finalsums[i] = vpadalq_s16(finalsums[i], sums[i]);
        }
        int32x4_t t0, t1;
        int32x2_t t2;

        // Reduce each per-row vector of partial sums to a single value,
        // scaling by -b_offset on the way out, and write one result per row.
        switch (rows) {
            case 1:
                // Multiply before the cross-lane add so the scaled sum cannot overflow.
                t0 = vmulq_s32(finalsums[0], offset_mul);
                *row_bias = vaddvq_s32(t0);
                break;

            case 2:
                t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                t0 = vpaddq_s32(t0, t0);
                t2 = vmul_s32(vget_low_s32(t0), vget_low_s32(offset_mul));
                vst1_s32(row_bias, t2);
                break;

            case 3:
                t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                t1 = vpaddq_s32(finalsums[2], finalsums[2]);
                t0 = vpaddq_s32(t0, t1);
                t0 = vmulq_s32(t0, offset_mul);
                vst1_s32(row_bias, vget_low_s32(t0));
                row_bias[2] = vgetq_lane_s32(t0, 2);
                break;

            case 4:
                t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                t1 = vpaddq_s32(finalsums[2], finalsums[3]);
                t0 = vpaddq_s32(t0, t1);
                t0 = vmulq_s32(t0, offset_mul);
                vst1q_s32(row_bias, t0);
                break;
        }
    }
    row_sum_helpers(const Requantize32 &qp) : qp(qp) { }
};
template<>
inline int16x8_t row_sum_helpers::accumulate_16(const uint8_t *ptr, int16x8_t sum) {
    return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), vld1q_u8(ptr)));
}

template<>
inline int16x8_t row_sum_helpers::accumulate_16(const int8_t *ptr, int16x8_t sum) {
    return vpadalq_s8(sum, vld1q_s8(ptr));
}
template<>
inline int16x8_t row_sum_helpers::accumulate_masked_16(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) {
    int8x16_t v = vandq_s8(vld1q_s8(ptr), vreinterpretq_s8_u64(mask));
    return vpadalq_s8(sum, v);
}

template<>
inline int16x8_t row_sum_helpers::accumulate_masked_16(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) {
    uint8x16_t v = vandq_u8(vld1q_u8(ptr), vreinterpretq_u8_u64(mask));
    return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v));
}
template<>
inline int16x8_t row_sum_helpers::accumulate_masked_8(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) {
    int8x16_t v = vcombine_s8(vld1_s8(ptr), vdup_n_s8(0));
    v = vreinterpretq_s8_u64(vandq_u64(mask, vreinterpretq_u64_s8(v)));
    return vpadalq_s8(sum, v);
}

template<>
inline int16x8_t row_sum_helpers::accumulate_masked_8(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) {
    uint8x16_t v = vcombine_u8(vld1_u8(ptr), vdup_n_u8(0));
    v = vreinterpretq_u8_u64(vandq_u64(mask, vreinterpretq_u64_u8(v)));
    return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v));
}
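
// Compute the per-row bias terms: -b_offset times the sum of each row's
// elements, over 'height' rows of 'width' elements.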
template<typename T>
void compute_row_sums(const Requantize32 &qp, unsigned int width, unsigned int height,
                      const T *input, unsigned int in_stride, int32_t *row_bias) {
    // If the 'b' offset is zero the row sums contribute nothing; just clear the bias.
    if (qp.b_offset == 0) {
        memset(row_bias, 0, height * sizeof(int32_t));
        return;
    }
    row_sum_helpers thehelpers(qp);

    const int32x4_t offset_mul = vdupq_n_s32(-qp.b_offset);

    unsigned int blocks = (width / 16);
    const unsigned int odds = width % 16;
    // Build a byte mask covering the 1-15 tail elements, if any.
    uint64x2_t mask = vdupq_n_u64(0);
    unsigned int mask_mode = 0;

    if (odds > 0 && odds <= 8) {
        // 1-8 tail elements: a single masked 8-byte read.
        uint64_t maskval = (~0ULL) >> (8 * (8-odds));
        mask = vsetq_lane_u64(maskval, vdupq_n_u64(0), 0);
        mask_mode = 1;
    } else if (odds > 8) {
        // 9-15 tail elements: a full lower half plus a masked upper half.
        uint64_t maskval = (~0ULL) >> (8 * (16-odds));
        mask = vsetq_lane_u64(maskval, vdupq_n_u64(~0ULL), 1);
        mask_mode = 2;
    }
    for (unsigned int row=0; row<height; row+=4) {
        switch (height-row) {
            default:
            case 4:
                thehelpers.compute_some_rows<4>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;

            case 3:
                thehelpers.compute_some_rows<3>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;

            case 2:
                thehelpers.compute_some_rows<2>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;

            case 1:
                thehelpers.compute_some_rows<1>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
        }
    }
}
template void compute_row_sums(const Requantize32 &, unsigned int, unsigned int, const int8_t *, unsigned int, int32_t *);
template void compute_row_sums(const Requantize32 &, unsigned int, unsigned int, const uint8_t *, unsigned int, int32_t *);
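
// Accumulate one 16-column block of up to four rows into the running 32-bit
// column sums; rows beyond 'active_rows' are treated as zero.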
template<unsigned int active_rows, typename T>
inline void add_block(const T *input, unsigned int in_stride, int32_t *output);
template<unsigned int active_rows>
inline void add_block(const uint8_t *input, unsigned int in_stride, int32_t *output) {
    uint8x16_t inputs[4];

    for (unsigned int i=0; i<4; i++) {
        if (i < active_rows) {
            inputs[i] = vld1q_u8(input + i * in_stride);
        } else {
            inputs[i] = vdupq_n_u8(0);
        }
    }

    int16x8_t sums_16b[4];

    // Widen to 16 bits, summing rows 0+1 and rows 2+3 (low then high halves).
    sums_16b[0]=vreinterpretq_s16_u16(vaddl_u8(vget_low_u8(inputs[0]), vget_low_u8(inputs[1])));
    sums_16b[1]=vreinterpretq_s16_u16(vaddl_u8(vget_low_u8(inputs[2]), vget_low_u8(inputs[3])));
    sums_16b[2]=vreinterpretq_s16_u16(vaddl_high_u8(inputs[0], inputs[1]));
    sums_16b[3]=vreinterpretq_s16_u16(vaddl_high_u8(inputs[2], inputs[3]));

    int32x4_t sums_32b[4];

    // Widen to 32 bits, combining both row pairs; results end up in column order.
    sums_32b[0]=vaddl_s16(vget_low_s16(sums_16b[0]), vget_low_s16(sums_16b[1]));
    sums_32b[1]=vaddl_high_s16(sums_16b[0], sums_16b[1]);
    sums_32b[2]=vaddl_s16(vget_low_s16(sums_16b[2]), vget_low_s16(sums_16b[3]));
    sums_32b[3]=vaddl_high_s16(sums_16b[2], sums_16b[3]);

    for (unsigned int i=0; i<4; i++) {
        vst1q_s32(output + 4*i, vaddq_s32(sums_32b[i], vld1q_s32(output + 4*i)));
    }
}
template<unsigned int active_rows>
inline void add_block(const int8_t *input, unsigned int in_stride, int32_t *output) {
    int8x16_t inputs[4];

    for (unsigned int i=0; i<4; i++) {
        if (i < active_rows) {
            inputs[i] = vld1q_s8(input + i * in_stride);
        } else {
            inputs[i] = vdupq_n_s8(0);
        }
    }

    int16x8_t sums_16b[4];

    // Widen to 16 bits, summing rows 0+1 and rows 2+3 (low then high halves).
    sums_16b[0]=vaddl_s8(vget_low_s8(inputs[0]), vget_low_s8(inputs[1]));
    sums_16b[1]=vaddl_s8(vget_low_s8(inputs[2]), vget_low_s8(inputs[3]));
    sums_16b[2]=vaddl_high_s8(inputs[0], inputs[1]);
    sums_16b[3]=vaddl_high_s8(inputs[2], inputs[3]);

    int32x4_t sums_32b[4];

    // Widen to 32 bits, combining both row pairs; results end up in column order.
    sums_32b[0]=vaddl_s16(vget_low_s16(sums_16b[0]), vget_low_s16(sums_16b[1]));
    sums_32b[1]=vaddl_high_s16(sums_16b[0], sums_16b[1]);
    sums_32b[2]=vaddl_s16(vget_low_s16(sums_16b[2]), vget_low_s16(sums_16b[3]));
    sums_32b[3]=vaddl_high_s16(sums_16b[2], sums_16b[3]);

    for (unsigned int i=0; i<4; i++) {
        vst1q_s32(output + 4*i, vaddq_s32(sums_32b[i], vld1q_s32(output + 4*i)));
    }
}
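
// Compute per-column bias terms: a_offset*b_offset*depth minus a_offset
// times the column sum, plus the user-supplied bias when present.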
template<typename T>
void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const T *input,
                      unsigned int in_stride, int32_t *col_bias, unsigned int depth,
                      unsigned int multi, unsigned int first_col) {
    // Only the a_offset term needs actual column sums.
    if (qp.a_offset != 0) {
        memset(reinterpret_cast<void *>(col_bias), 0, width * sizeof(int32_t));
        for (unsigned int row=0; row<height; row+=4) {
            unsigned int numrows=std::min(height-row, 4u);

            for (unsigned int col=0; col<width; col+=16) {
                unsigned int numcols=std::min(width-col, 16u);

                if (numcols == 16) {
                    // Full-width block: use the vectorized accumulator.
                    switch (numrows) {
                        case 1:
                            add_block<1>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;
                        case 2:
                            add_block<2>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;
                        case 3:
                            add_block<3>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;
                        case 4:
                            add_block<4>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;
                    }
                } else {
                    // Partial-width tail: finish the row strip with scalar adds.
                    for (; col<width; col++) {
                        int32_t sum=0;
                        for (unsigned int r=0; r<numrows; r++) {
                            sum += input[(row + r)*in_stride + col];
                        }
                        col_bias[col] += sum;
                    }
                }
            }
        }
    }
    for (unsigned int col=0; col<width; col++) {
        int32_t result = col_bias[col];

        result = (qp.a_offset * qp.b_offset * depth) - (result * qp.a_offset);

        if (qp.bias != nullptr) {
            result += qp.bias[multi * qp.bias_multi_stride + col + first_col];
        }

        col_bias[col] = result;
    }
}
template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height,
                               const int8_t *input, unsigned int in_stride, int32_t *col_bias,
                               unsigned int depth, unsigned int multi, unsigned int first_col);

template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height,
                               const uint8_t *input, unsigned int in_stride, int32_t *col_bias,
                               unsigned int depth, unsigned int multi, unsigned int first_col);
#endif // __aarch64__