// Load 16 consecutive int32 accumulators from the matrix-multiply result.
inline int32x4x4_t load_results_input(const Iterator &mm_result_it, int32_t x)
{
    return {{vld1q_s32(reinterpret_cast<const int32_t *>(mm_result_it.ptr()) + x + 0),
             vld1q_s32(reinterpret_cast<const int32_t *>(mm_result_it.ptr()) + x + 4),
             vld1q_s32(reinterpret_cast<const int32_t *>(mm_result_it.ptr()) + x + 8),
             vld1q_s32(reinterpret_cast<const int32_t *>(mm_result_it.ptr()) + x + 12)}};
}

// Load 16 consecutive int32 values starting at element x of a raw pointer.
inline int32x4x4_t load(const int32_t *ptr, int32_t x)
{
    return {{vld1q_s32(ptr + x + 0), vld1q_s32(ptr + x + 4), vld1q_s32(ptr + x + 8), vld1q_s32(ptr + x + 12)}};
}
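
// Both loaders fetch 16 consecutive int32 accumulators as four int32x4_t
// registers, matching the window_step_x of 16 used by the vectorized loops
// further down in this file.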
// Add a single vector to all four registers of a 16-wide accumulator group.
inline int32x4x4_t add_s32(int32x4x4_t a, int32x4_t b)
{
    return {{vaddq_s32(a.val[0], b), vaddq_s32(a.val[1], b), vaddq_s32(a.val[2], b), vaddq_s32(a.val[3], b)}};
}

inline int32x4x4_t add_s32(int32x4x4_t a, int32x4x4_t b)
{
    return {{vaddq_s32(a.val[0], b.val[0]), vaddq_s32(a.val[1], b.val[1]), vaddq_s32(a.val[2], b.val[2]),
             vaddq_s32(a.val[3], b.val[3])}};
}

// Multiply all 16 lanes by a uniform scalar multiplier.
inline int32x4x4_t mul_s32(int32x4x4_t &a, int32_t mul_scalar)
{
    return {{vmulq_n_s32(a.val[0], mul_scalar), vmulq_n_s32(a.val[1], mul_scalar), vmulq_n_s32(a.val[2], mul_scalar),
             vmulq_n_s32(a.val[3], mul_scalar)}};
}

// Multiply the 16 lanes by 16 per-channel multipliers loaded from memory.
inline int32x4x4_t mul_s32(int32x4x4_t &a, const int32_t *multiplier)
{
    return {{vmulq_s32(a.val[0], vld1q_s32(multiplier)), vmulq_s32(a.val[1], vld1q_s32(multiplier + 4)),
             vmulq_s32(a.val[2], vld1q_s32(multiplier + 8)), vmulq_s32(a.val[3], vld1q_s32(multiplier + 12))}};
}
// Per-column term: a_offset * vector_sum_col[x .. x + 15].
inline int32x4x4_t get_a_offset(const int32_t *vector_sum_col_ptr, int32_t a_offset, int32_t x)
{
    int32x4x4_t a_offset_term_s32 = load(vector_sum_col_ptr, x);

    a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], a_offset);
    a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], a_offset);
    a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], a_offset);
    a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], a_offset);
    return a_offset_term_s32;
}

// Per-row term: the row sum is constant along x, so broadcast one element.
inline int32x4_t get_b_offset(const int32_t *vector_sum_row_ptr, int32_t b_offset)
{
    int32x4_t b_offset_term_s32 = vld1q_dup_s32(vector_sum_row_ptr);
    b_offset_term_s32           = vmulq_n_s32(b_offset_term_s32, b_offset);
    return b_offset_term_s32;
}

// Constant term: a_offset * b_offset * k, broadcast to all 16 lanes.
inline int32x4x4_t get_k_offset(int32_t k_offset)
{
    return {{vdupq_n_s32(k_offset), vdupq_n_s32(k_offset), vdupq_n_s32(k_offset), vdupq_n_s32(k_offset)}};
}
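
// Taken together, the three helpers above materialize the usual GEMMLowp
// offset contribution, result += a_offset * sum_col(x) + b_offset * sum_row(y)
// + a_offset * b_offset * k, with any sign conventions folded into the offsets
// by the caller. A minimal scalar sketch of the same arithmetic (illustrative
// only; scalar_offset_contribution is not part of this kernel):
inline int32_t scalar_offset_contribution(
    int32_t acc, int32_t sum_col, int32_t sum_row, int32_t a_offset, int32_t b_offset, int32_t k)
{
    return acc + a_offset * sum_col + b_offset * sum_row + a_offset * b_offset * k;
}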
inline uint8x16_t finalize_quantization_floating_point(
    int32x4x4_t &in_s32, int32x4_t result_shift_s32, uint8x16_t min_u8, uint8x16_t max_u8, bool is_bounded_relu)
{
    const static int32x4_t zero_s32 = vdupq_n_s32(0);

    // Shift final result (a negative shift value shifts right)
    in_s32.val[0] = vshlq_s32(in_s32.val[0], result_shift_s32);
    in_s32.val[1] = vshlq_s32(in_s32.val[1], result_shift_s32);
    in_s32.val[2] = vshlq_s32(in_s32.val[2], result_shift_s32);
    in_s32.val[3] = vshlq_s32(in_s32.val[3], result_shift_s32);

    // Saturate negative values
    in_s32.val[0] = vmaxq_s32(in_s32.val[0], zero_s32);
    in_s32.val[1] = vmaxq_s32(in_s32.val[1], zero_s32);
    in_s32.val[2] = vmaxq_s32(in_s32.val[2], zero_s32);
    in_s32.val[3] = vmaxq_s32(in_s32.val[3], zero_s32);

    // Convert S32 to S16
    const int16x8x2_t in_s16 = {{vcombine_s16(vqmovn_s32(in_s32.val[0]), vqmovn_s32(in_s32.val[1])),
                                 vcombine_s16(vqmovn_s32(in_s32.val[2]), vqmovn_s32(in_s32.val[3]))}};

    // Convert S16 to U8
    uint8x16_t out_u8 = vcombine_u8(vqmovun_s16(in_s16.val[0]), vqmovun_s16(in_s16.val[1]));

    if (is_bounded_relu)
    {
        out_u8 = vmaxq_u8(out_u8, min_u8);
        out_u8 = vminq_u8(out_u8, max_u8);
    }

    return out_u8;
}
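
// Scalar view of the vector routine above, assuming the caller has already
// negated the shift so that vshlq_s32 shifts right (see the result_shift_s32
// setup in run_offset_contribution_output_stage below). Illustrative only;
// scalar_requantize_u8 is not part of this kernel:
inline uint8_t scalar_requantize_u8(int32_t in, int32_t right_shift, uint8_t min_u8, uint8_t max_u8, bool is_bounded_relu)
{
    in                = std::max<int32_t>(in >> right_shift, 0);          // shift, then clamp negatives
    const uint8_t out = static_cast<uint8_t>(std::min<int32_t>(in, 255)); // saturating narrow to U8
    return is_bounded_relu ? std::min<uint8_t>(std::max<uint8_t>(out, min_u8), max_u8) : out;
}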
inline int8x16_t finalize_quantization_floating_point(
    int32x4x4_t &in_s32, int32x4_t result_shift_s32, int8x16_t min_s8, int8x16_t max_s8, bool is_bounded_relu)
{
    const static int32x4_t zero_s32 = vdupq_n_s32(0);

    // Shift final result (a negative shift value shifts right)
    in_s32.val[0] = vshlq_s32(in_s32.val[0], result_shift_s32);
    in_s32.val[1] = vshlq_s32(in_s32.val[1], result_shift_s32);
    in_s32.val[2] = vshlq_s32(in_s32.val[2], result_shift_s32);
    in_s32.val[3] = vshlq_s32(in_s32.val[3], result_shift_s32);

    // Saturate negative values
    in_s32.val[0] = vmaxq_s32(in_s32.val[0], zero_s32);
    in_s32.val[1] = vmaxq_s32(in_s32.val[1], zero_s32);
    in_s32.val[2] = vmaxq_s32(in_s32.val[2], zero_s32);
    in_s32.val[3] = vmaxq_s32(in_s32.val[3], zero_s32);

    // Convert S32 to S16
    const int16x8x2_t in_s16 = {{vcombine_s16(vqmovn_s32(in_s32.val[0]), vqmovn_s32(in_s32.val[1])),
                                 vcombine_s16(vqmovn_s32(in_s32.val[2]), vqmovn_s32(in_s32.val[3]))}};

    // Convert S16 to S8
    int8x16_t out_s8 = vcombine_s8(vqmovn_s16(in_s16.val[0]), vqmovn_s16(in_s16.val[1]));

    if (is_bounded_relu)
    {
        out_s8 = vmaxq_s8(out_s8, min_s8);
        out_s8 = vminq_s8(out_s8, max_s8);
    }

    return out_s8;
}
// Per-channel variant: each group of four lanes carries its own shift.
inline int8x16_t finalize_quantization_floating_point(
    int32x4x4_t &in_s32, int32x4x4_t result_shift_s32, int8x16_t min_s8, int8x16_t max_s8, bool is_bounded_relu)
{
    const static int32x4_t zero_s32 = vdupq_n_s32(0);

    // Shift final result (a negative shift value shifts right)
    in_s32.val[0] = vshlq_s32(in_s32.val[0], vnegq_s32(result_shift_s32.val[0]));
    in_s32.val[1] = vshlq_s32(in_s32.val[1], vnegq_s32(result_shift_s32.val[1]));
    in_s32.val[2] = vshlq_s32(in_s32.val[2], vnegq_s32(result_shift_s32.val[2]));
    in_s32.val[3] = vshlq_s32(in_s32.val[3], vnegq_s32(result_shift_s32.val[3]));

    // Saturate negative values
    in_s32.val[0] = vmaxq_s32(in_s32.val[0], zero_s32);
    in_s32.val[1] = vmaxq_s32(in_s32.val[1], zero_s32);
    in_s32.val[2] = vmaxq_s32(in_s32.val[2], zero_s32);
    in_s32.val[3] = vmaxq_s32(in_s32.val[3], zero_s32);

    // Convert S32 to S16
    const int16x8x2_t in_s16 = {{vcombine_s16(vqmovn_s32(in_s32.val[0]), vqmovn_s32(in_s32.val[1])),
                                 vcombine_s16(vqmovn_s32(in_s32.val[2]), vqmovn_s32(in_s32.val[3]))}};

    // Convert S16 to S8
    int8x16_t out_s8 = vcombine_s8(vqmovn_s16(in_s16.val[0]), vqmovn_s16(in_s16.val[1]));

    if (is_bounded_relu)
    {
        out_s8 = vmaxq_s8(out_s8, min_s8);
        out_s8 = vminq_s8(out_s8, max_s8);
    }

    return out_s8;
}
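
// All three overloads above narrow with the same saturating chain: vqmovn_s32
// clamps each lane to [-32768, 32767], then vqmovun_s16 / vqmovn_s16 clamp to
// [0, 255] / [-128, 127]. For example, an accumulator of 70000 becomes 32767
// after the S32 -> S16 step and 255 (unsigned) or 127 (signed) after the
// final narrowing.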
template <typename T>
struct VectorTyper
{
    using stype = T;
    using vtype = typename wrapper::traits::neon_bitvector_t<T, wrapper::traits::BitWidth::W128>;
};

inline Window get_win_vector_sum(const Window &window)
{
    Window win_vector_sum(window);
    win_vector_sum.set(Window::DimY, Window::Dimension(0, 0, 0));
    win_vector_sum.set(Window::DimZ, Window::Dimension(0, 0, 0));
    return win_vector_sum;
}
inline Iterator get_vector_sum_col_it(const Window &window, const ITensor *vector_sum_col)
{
    Iterator vector_sum_col_it(vector_sum_col, get_win_vector_sum(window));
    return vector_sum_col_it;
}

inline Iterator get_vector_sum_row_it(const Window &window, const ITensor *vector_sum_row)
{
    Window win_vector_sum_row = get_win_vector_sum(window);
    win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
    Iterator vector_sum_row_it(vector_sum_row, win_vector_sum_row);
    return vector_sum_row_it;
}

inline Iterator get_bias_it(const Window &window, const ITensor *bias)
{
    Window win_bias(window);
    win_bias.set(Window::DimY, Window::Dimension(0, 1, 1));
    win_bias.set(Window::DimZ, Window::Dimension(0, 1, 1));
    Iterator bias_it(bias, win_bias);
    return bias_it;
}
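
// Note on the window helpers above: collapsing a dimension to a degenerate
// range keeps the iterator pointer fixed along that axis, so the 1-D
// row/column sums and the bias vector are effectively re-read (broadcast) for
// every row and batch that the kernel processes.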
template <typename VT>
inline void run_offset_contribution_output_stage_window(
    const int32_t *vector_sum_col_ptr, const int32_t *vector_sum_row_ptr, const int32_t *bias_ptr,
    Iterator mm_result_it, Iterator out_it, const int32x4_t result_offset_s32, const int32x4_t result_shift_s32,
    typename VT::vtype min_vec, typename VT::vtype max_vec, int32_t a_offset, int32_t b_offset, int32_t k_offset,
    int32_t multiplier, int32_t shift, int32_t offset, int32_t min_bound, int32_t max_bound, int window_step_x,
    int window_start_x, int window_end_x, bool has_a_offset, bool has_b_offset, bool has_bias, bool is_bounded_relu,
    bool is_fixed_point)
{
    int32x4x4_t offset_term_s32 = {0, 0, 0, 0};
    if (!is_fixed_point)
    {
        // Combine quantization offset with other offsets.
        offset_term_s32 = add_s32(offset_term_s32, result_offset_s32);
    }
    if (has_a_offset && has_b_offset)
    {
        offset_term_s32 = add_s32(offset_term_s32, get_k_offset(k_offset));
    }
    if (has_b_offset)
    {
        offset_term_s32 = add_s32(offset_term_s32, get_b_offset(vector_sum_row_ptr, b_offset));
    }

    int x = window_start_x;
    for (; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        int32x4x4_t in_s32 = load_results_input(mm_result_it, x);

        if (has_a_offset)
        {
            in_s32 = add_s32(in_s32, get_a_offset(vector_sum_col_ptr, a_offset, x));
        }
        if (has_bias)
        {
            in_s32 = add_s32(in_s32, load(bias_ptr, x));
        }
        if (!is_fixed_point || has_b_offset)
        {
            in_s32 = add_s32(in_s32, offset_term_s32);
        }
        if (!is_fixed_point)
        {
            in_s32 = mul_s32(in_s32, multiplier);
        }

        if (is_fixed_point)
        {
            wrapper::vstore(reinterpret_cast<typename VT::stype *>(out_it.ptr() + x),
                            finalize_quantization(in_s32, multiplier, shift, result_offset_s32, min_vec, max_vec,
                                                  is_bounded_relu));
        }
        else
        {
            wrapper::vstore(reinterpret_cast<typename VT::stype *>(out_it.ptr() + x),
                            finalize_quantization_floating_point(in_s32, result_shift_s32, min_vec, max_vec,
                                                                 is_bounded_relu));
        }
    }
    // Compute left-over elements
    for (; x < window_end_x; ++x)
    {
        int32_t in_value = *(reinterpret_cast<const int32_t *>(mm_result_it.ptr()) + x) +
                           wrapper::vgetlane(offset_term_s32.val[0], 0);

        if (has_a_offset)
        {
            in_value += (*(vector_sum_col_ptr + x) * a_offset);
        }
        if (has_bias)
        {
            in_value += *(bias_ptr + x);
        }

        if (is_fixed_point)
        {
            // Finalize and store the result
            *reinterpret_cast<typename VT::stype *>(out_it.ptr() + x) =
                finalize_quantization(in_value, multiplier, shift, offset, static_cast<typename VT::stype>(min_bound),
                                      static_cast<typename VT::stype>(max_bound), is_bounded_relu);
        }
        else
        {
            // Finalize quantization
            in_value = (in_value * multiplier) >> shift;

            // Bound the result
            if (is_bounded_relu)
            {
                in_value = static_cast<typename VT::stype>(
                    std::max<int32_t>(min_bound, std::min<int32_t>(max_bound, in_value)));
            }
            *reinterpret_cast<typename VT::stype *>(out_it.ptr() + x) = static_cast<typename VT::stype>(
                std::max<int32_t>(static_cast<int32_t>(std::numeric_limits<typename VT::stype>::lowest()),
                                  std::min<int32_t>(static_cast<int32_t>(std::numeric_limits<typename VT::stype>::max()),
                                                    in_value)));
        }
    }
}
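
// The has_a_offset / has_b_offset / has_bias flags are passed as literals at
// every call site below, so after inlining the compiler can fold the
// corresponding branches out of the hot loop for each of the eight variants.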
inline void run_offset_contribution_output_stage_window_symm(
    const int32_t *vector_sum_col_ptr, const int32_t *bias_ptr, Iterator mm_result_it, Iterator out_it,
    const int32_t *result_multipliers, const int32_t *result_shifts, const int32x4_t result_offset, int8x16_t min_s8,
    int8x16_t max_s8, int32_t a_offset, int32_t offset, int32_t min_bound, int32_t max_bound, int window_step_x,
    int window_start_x, int window_end_x, bool has_a_offset, bool has_bias, bool is_bounded_relu, bool is_fixed_point)
{
    int32x4x4_t offset_term_s32 = {0, 0, 0, 0};
    if (!is_fixed_point)
    {
        // Combine quantization offset with other offsets.
        offset_term_s32 = add_s32(offset_term_s32, result_offset);
    }

    int x = window_start_x;
    for (; x <= (window_end_x - window_step_x); x += window_step_x)
    {
        int32x4x4_t in_s32 = load_results_input(mm_result_it, x);

        if (has_a_offset)
        {
            in_s32 = add_s32(in_s32, get_a_offset(vector_sum_col_ptr, a_offset, x));
        }
        if (has_bias)
        {
            in_s32 = add_s32(in_s32, load(bias_ptr, x));
        }
        if (!is_fixed_point)
        {
            in_s32 = add_s32(in_s32, offset_term_s32);
            in_s32 = mul_s32(in_s32, result_multipliers + x);
        }

        if (is_fixed_point)
        {
            vst1q_s8(reinterpret_cast<int8_t *>(out_it.ptr() + x),
                     finalize_quantization_symm(in_s32, load(result_multipliers, x), load(result_shifts, x),
                                                result_offset, min_s8, max_s8, is_bounded_relu));
        }
        else
        {
            vst1q_s8(reinterpret_cast<int8_t *>(out_it.ptr() + x),
                     finalize_quantization_floating_point(in_s32, load(result_shifts, x), min_s8, max_s8,
                                                          is_bounded_relu));
        }
    }
    // Compute left-over elements
    for (; x < window_end_x; ++x)
    {
        int32_t in_value = *(reinterpret_cast<const int32_t *>(mm_result_it.ptr()) + x) +
                           wrapper::vgetlane(offset_term_s32.val[0], 0);

        if (has_a_offset)
        {
            in_value += (*(vector_sum_col_ptr + x) * a_offset);
        }
        if (has_bias)
        {
            in_value += *(bias_ptr + x);
        }

        if (is_fixed_point)
        {
            // Finalize and store the result
            *(out_it.ptr() + x) =
                finalize_quantization(in_value, result_multipliers[x], result_shifts[x], offset,
                                      static_cast<int8_t>(min_bound), static_cast<int8_t>(max_bound), is_bounded_relu);
        }
        else
        {
            // Finalize quantization
            in_value = (in_value * result_multipliers[x]) >> (-result_shifts[x]);

            // Bound the result
            if (is_bounded_relu)
            {
                in_value = static_cast<int8_t>(std::max<int32_t>(min_bound, std::min<int32_t>(max_bound, in_value)));
            }
            *(out_it.ptr() + x) = static_cast<int8_t>(std::max<int32_t>(-128, std::min<int32_t>(127, in_value)));
        }
    }
}
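
// Unlike the uniform routine above, this symmetric per-channel variant indexes
// result_multipliers / result_shifts with the x coordinate: every output
// channel carries its own requantization parameters, and no vector_sum_row /
// b_offset term is needed.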
template <typename T>
void run_offset_contribution_output_stage(const Window  &window,
                                          const ITensor *mm_result,
                                          const ITensor *vector_sum_col,
                                          const ITensor *vector_sum_row,
                                          const ITensor *bias,
                                          ITensor       *output,
                                          int32_t        a_offset,
                                          int32_t        b_offset,
                                          int32_t        k_offset,
                                          bool           is_vector_sum_col_batched,
                                          GEMMLowpOutputStageInfo output_stage,
                                          bool           is_gemm3d,
                                          bool           is_bounded_relu,
                                          bool           is_fixed_point)
{
    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
    using Typer        = VectorTyper<T>;

    const int height_input = is_gemm3d ? mm_result->info()->dimension(1) : 0;
    const int depth_input  = is_gemm3d ? mm_result->info()->dimension(2) : 1;

    const int32_t multiplier = output_stage.gemmlowp_multiplier;
    const int32_t shift      = output_stage.gemmlowp_shift;
    const int32_t offset     = output_stage.gemmlowp_offset;
    const int32_t min_bound  = output_stage.gemmlowp_min_bound;
    const int32_t max_bound  = output_stage.gemmlowp_max_bound;

    // For the non-fixed-point path the shift is negated so that vshlq_s32 shifts right.
    const int32x4_t result_offset_s32 = vdupq_n_s32(offset);
    const int32x4_t result_shift_s32  = vdupq_n_s32(is_fixed_point ? shift : -shift);
    const auto      min_vec           = wrapper::vdup_n(static_cast<T>(min_bound), ExactTagType{});
    const auto      max_vec           = wrapper::vdup_n(static_cast<T>(max_bound), ExactTagType{});

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    // The x dimension is processed manually; the window iterates over y/z only.
    Window win(window);
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Window collapsed_window = win.collapse_if_possible(win, Window::DimZ);

    Iterator mm_result_it(mm_result, win);
    Iterator out_it(output, win);
    if ((a_offset != 0) && (b_offset != 0))
    {
        ARM_COMPUTE_ERROR_ON_NULLPTR(vector_sum_col);
        ARM_COMPUTE_ERROR_ON_NULLPTR(vector_sum_row);

        Iterator vector_sum_col_it = get_vector_sum_col_it(collapsed_window, vector_sum_col);
        Iterator vector_sum_row_it = get_vector_sum_row_it(collapsed_window, vector_sum_row);

        const size_t sum_row_stride_y = vector_sum_row->info()->strides_in_bytes().y();

        // Offset in case vector_sum_col is batched in y dimension
        const int vector_sum_col_stride_batch =
            is_vector_sum_col_batched ? vector_sum_col->info()->strides_in_bytes().y() : 0;

        if (bias != nullptr)
        {
            Iterator bias_it = get_bias_it(collapsed_window, bias);
            execute_window_loop(
                collapsed_window,
                [&](const Coordinates &id)
                {
                    const int  batch_id           = id.z() / depth_input;
                    const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(
                        vector_sum_col_it.ptr() + batch_id * vector_sum_col_stride_batch);
                    const auto vector_sum_row_ptr =
                        reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) +
                        id.y() + (id.z() % depth_input) * height_input;
                    run_offset_contribution_output_stage_window<Typer>(
                        vector_sum_col_ptr, vector_sum_row_ptr, reinterpret_cast<const int32_t *>(bias_it.ptr()),
                        mm_result_it, out_it, result_offset_s32, result_shift_s32, min_vec, max_vec, a_offset, b_offset,
                        k_offset, multiplier, shift, offset, min_bound, max_bound, window_step_x, window_start_x,
                        window_end_x, true, true, true, is_bounded_relu, is_fixed_point);
                },
                vector_sum_col_it, vector_sum_row_it, bias_it, mm_result_it, out_it);
        }
        else
        {
            execute_window_loop(
                collapsed_window,
                [&](const Coordinates &id)
                {
                    const int  batch_id           = id.z() / depth_input;
                    const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(
                        vector_sum_col_it.ptr() + batch_id * vector_sum_col_stride_batch);
                    const auto vector_sum_row_ptr =
                        reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) +
                        id.y() + (id.z() % depth_input) * height_input;
                    run_offset_contribution_output_stage_window<Typer>(
                        vector_sum_col_ptr, vector_sum_row_ptr, nullptr, mm_result_it, out_it, result_offset_s32,
                        result_shift_s32, min_vec, max_vec, a_offset, b_offset, k_offset, multiplier, shift, offset,
                        min_bound, max_bound, window_step_x, window_start_x, window_end_x, true, true, false,
                        is_bounded_relu, is_fixed_point);
                },
                vector_sum_col_it, vector_sum_row_it, mm_result_it, out_it);
        }
    }
    else if ((a_offset == 0) && (b_offset != 0))
    {
        ARM_COMPUTE_ERROR_ON_NULLPTR(vector_sum_row);

        Iterator vector_sum_row_it = get_vector_sum_row_it(collapsed_window, vector_sum_row);

        const size_t sum_row_stride_y = vector_sum_row->info()->strides_in_bytes().y();

        if (bias != nullptr)
        {
            Iterator bias_it = get_bias_it(collapsed_window, bias);
            execute_window_loop(
                collapsed_window,
                [&](const Coordinates &id)
                {
                    const int  batch_id           = id.z() / depth_input;
                    const auto vector_sum_row_ptr =
                        reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) +
                        id.y() + (id.z() % depth_input) * height_input;
                    run_offset_contribution_output_stage_window<Typer>(
                        nullptr, vector_sum_row_ptr, reinterpret_cast<const int32_t *>(bias_it.ptr()), mm_result_it,
                        out_it, result_offset_s32, result_shift_s32, min_vec, max_vec, a_offset, b_offset, k_offset,
                        multiplier, shift, offset, min_bound, max_bound, window_step_x, window_start_x, window_end_x,
                        false, true, true, is_bounded_relu, is_fixed_point);
                },
                vector_sum_row_it, bias_it, mm_result_it, out_it);
        }
        else
        {
            execute_window_loop(
                collapsed_window,
                [&](const Coordinates &id)
                {
                    const int  batch_id           = id.z() / depth_input;
                    const auto vector_sum_row_ptr =
                        reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) +
                        id.y() + (id.z() % depth_input) * height_input;
                    run_offset_contribution_output_stage_window<Typer>(
                        nullptr, vector_sum_row_ptr, nullptr, mm_result_it, out_it, result_offset_s32, result_shift_s32,
                        min_vec, max_vec, a_offset, b_offset, k_offset, multiplier, shift, offset, min_bound, max_bound,
                        window_step_x, window_start_x, window_end_x, false, true, false, is_bounded_relu,
                        is_fixed_point);
                },
                vector_sum_row_it, mm_result_it, out_it);
        }
    }
    else if ((a_offset != 0) && (b_offset == 0))
    {
        ARM_COMPUTE_ERROR_ON_NULLPTR(vector_sum_col);

        Iterator vector_sum_col_it = get_vector_sum_col_it(collapsed_window, vector_sum_col);

        // Offset in case vector_sum_col is batched in y dimension
        const int vector_sum_col_stride_batch =
            is_vector_sum_col_batched ? vector_sum_col->info()->strides_in_bytes().y() : 0;

        if (bias != nullptr)
        {
            Iterator bias_it = get_bias_it(collapsed_window, bias);
            execute_window_loop(
                collapsed_window,
                [&](const Coordinates &id)
                {
                    const int  batch_id           = id.z() / depth_input;
                    const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(
                        vector_sum_col_it.ptr() + batch_id * vector_sum_col_stride_batch);
                    run_offset_contribution_output_stage_window<Typer>(
                        vector_sum_col_ptr, nullptr, reinterpret_cast<const int32_t *>(bias_it.ptr()), mm_result_it,
                        out_it, result_offset_s32, result_shift_s32, min_vec, max_vec, a_offset, b_offset, k_offset,
                        multiplier, shift, offset, min_bound, max_bound, window_step_x, window_start_x, window_end_x,
                        true, false, true, is_bounded_relu, is_fixed_point);
                },
                vector_sum_col_it, bias_it, mm_result_it, out_it);
        }
        else
        {
            execute_window_loop(
                collapsed_window,
                [&](const Coordinates &id)
                {
                    const int  batch_id           = id.z() / depth_input;
                    const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(
                        vector_sum_col_it.ptr() + batch_id * vector_sum_col_stride_batch);
                    run_offset_contribution_output_stage_window<Typer>(
                        vector_sum_col_ptr, nullptr, nullptr, mm_result_it, out_it, result_offset_s32, result_shift_s32,
                        min_vec, max_vec, a_offset, b_offset, k_offset, multiplier, shift, offset, min_bound, max_bound,
                        window_step_x, window_start_x, window_end_x, true, false, false, is_bounded_relu,
                        is_fixed_point);
                },
                vector_sum_col_it, mm_result_it, out_it);
        }
    }
    else
    {
        if (bias != nullptr)
        {
            Iterator bias_it = get_bias_it(collapsed_window, bias);
            execute_window_loop(
                collapsed_window,
                [&](const Coordinates &)
                {
                    run_offset_contribution_output_stage_window<Typer>(
                        nullptr, nullptr, reinterpret_cast<const int32_t *>(bias_it.ptr()), mm_result_it, out_it,
                        result_offset_s32, result_shift_s32, min_vec, max_vec, a_offset, b_offset, k_offset, multiplier,
                        shift, offset, min_bound, max_bound, window_step_x, window_start_x, window_end_x, false, false,
                        true, is_bounded_relu, is_fixed_point);
                },
                bias_it, mm_result_it, out_it);
        }
        else
        {
            execute_window_loop(
                collapsed_window,
                [&](const Coordinates &)
                {
                    run_offset_contribution_output_stage_window<Typer>(
                        nullptr, nullptr, nullptr, mm_result_it, out_it, result_offset_s32, result_shift_s32, min_vec,
                        max_vec, a_offset, b_offset, k_offset, multiplier, shift, offset, min_bound, max_bound,
                        window_step_x, window_start_x, window_end_x, false, false, false, is_bounded_relu,
                        is_fixed_point);
                },
                mm_result_it, out_it);
        }
    }
}
void run_offset_contribution_output_stage_symm(const Window  &window,
                                               const ITensor *mm_result,
                                               const ITensor *vector_sum_col,
                                               const ITensor *vector_sum_row,
                                               const ITensor *bias,
                                               ITensor       *output,
                                               int32_t        a_offset,
                                               int32_t        b_offset,
                                               int32_t        k_offset,
                                               bool           is_vector_sum_col_batched,
                                               GEMMLowpOutputStageInfo output_stage,
                                               bool           is_gemm3d,
                                               bool           is_bounded_relu,
                                               bool           is_fixed_point)
{
    ARM_COMPUTE_UNUSED(vector_sum_row, b_offset, k_offset);

    const int depth_input = is_gemm3d ? mm_result->info()->dimension(2) : 1;

    const int32_t offset    = output_stage.gemmlowp_offset;
    const int32_t min_bound = output_stage.gemmlowp_min_bound;
    const int32_t max_bound = output_stage.gemmlowp_max_bound;

    const int32_t  *result_multipliers = output_stage.gemmlowp_multipliers.data();
    const int32_t  *result_shifts      = output_stage.gemmlowp_shifts.data();
    const int32x4_t result_offset_s32  = vdupq_n_s32(offset);
    const int8x16_t min_s8             = vdupq_n_s8(static_cast<int8_t>(min_bound));
    const int8x16_t max_s8             = vdupq_n_s8(static_cast<int8_t>(max_bound));

    const int  window_step_x  = 16;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    // The x dimension is processed manually; the window iterates over y/z only.
    Window win(window);
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Window collapsed_window = win.collapse_if_possible(win, Window::DimZ);

    Iterator mm_result_it(mm_result, win);
    Iterator out_it(output, win);
    if (a_offset != 0)
    {
        ARM_COMPUTE_ERROR_ON_NULLPTR(vector_sum_col);

        Iterator vector_sum_col_it = get_vector_sum_col_it(collapsed_window, vector_sum_col);

        // Offset in case vector_sum_col is batched in y dimension
        const int vector_sum_col_stride_batch =
            is_vector_sum_col_batched ? vector_sum_col->info()->strides_in_bytes().y() : 0;

        if (bias != nullptr)
        {
            Iterator bias_it = get_bias_it(collapsed_window, bias);
            execute_window_loop(
                collapsed_window,
                [&](const Coordinates &id)
                {
                    const int  batch_id           = id.z() / depth_input;
                    const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(
                        vector_sum_col_it.ptr() + batch_id * vector_sum_col_stride_batch);
                    run_offset_contribution_output_stage_window_symm(
                        vector_sum_col_ptr, reinterpret_cast<const int32_t *>(bias_it.ptr()), mm_result_it, out_it,
                        result_multipliers, result_shifts, result_offset_s32, min_s8, max_s8, a_offset, offset,
                        min_bound, max_bound, window_step_x, window_start_x, window_end_x, true, true, is_bounded_relu,
                        is_fixed_point);
                },
                vector_sum_col_it, bias_it, mm_result_it, out_it);
        }
        else
        {
            execute_window_loop(
                collapsed_window,
                [&](const Coordinates &id)
                {
                    const int  batch_id           = id.z() / depth_input;
                    const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(
                        vector_sum_col_it.ptr() + batch_id * vector_sum_col_stride_batch);
                    run_offset_contribution_output_stage_window_symm(
                        vector_sum_col_ptr, nullptr, mm_result_it, out_it, result_multipliers, result_shifts,
                        result_offset_s32, min_s8, max_s8, a_offset, offset, min_bound, max_bound, window_step_x,
                        window_start_x, window_end_x, true, false, is_bounded_relu, is_fixed_point);
                },
                vector_sum_col_it, mm_result_it, out_it);
        }
    }
    else
    {
        if (bias != nullptr)
        {
            Iterator bias_it = get_bias_it(collapsed_window, bias);
            execute_window_loop(
                collapsed_window,
                [&](const Coordinates &)
                {
                    run_offset_contribution_output_stage_window_symm(
                        nullptr, reinterpret_cast<const int32_t *>(bias_it.ptr()), mm_result_it, out_it,
                        result_multipliers, result_shifts, result_offset_s32, min_s8, max_s8, a_offset, offset,
                        min_bound, max_bound, window_step_x, window_start_x, window_end_x, false, true, is_bounded_relu,
                        is_fixed_point);
                },
                bias_it, mm_result_it, out_it);
        }
        else
        {
            execute_window_loop(
                collapsed_window,
                [&](const Coordinates &)
                {
                    run_offset_contribution_output_stage_window_symm(
                        nullptr, nullptr, mm_result_it, out_it, result_multipliers, result_shifts, result_offset_s32,
                        min_s8, max_s8, a_offset, offset, min_bound, max_bound, window_step_x, window_start_x,
                        window_end_x, false, false, is_bounded_relu, is_fixed_point);
                },
                mm_result_it, out_it);
        }
    }
}
Status validate_arguments(const ITensorInfo *mm_result,
                          const ITensorInfo *vector_sum_col,
                          const ITensorInfo *vector_sum_row,
                          const ITensorInfo *bias,
                          const ITensorInfo *output,
                          int32_t            a_offset,
                          int32_t            b_offset,
                          GEMMLowpOutputStageInfo output_stage)
{
    // Check if input is a 3D reinterpretation
    const bool reinterpret_as_3d =
        mm_result->num_dimensions() > 1 && mm_result->tensor_shape().y() != vector_sum_row->tensor_shape().x();

    // Validate input
    ARM_COMPUTE_RETURN_ERROR_ON(reinterpret_as_3d && vector_sum_row->dimension(0) !=
                                                         (mm_result->dimension(1) * mm_result->dimension(2)));

    const unsigned int output_batch_idx = reinterpret_as_3d ? 3 : 2;

    TensorShape vector_sum_row_shape = vector_sum_row->tensor_shape();
    vector_sum_row_shape.collapse_from(1);
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(vector_sum_row_shape[1] != output->tensor_shape()[output_batch_idx],
                                    "mm_result tensor must have the same number of batches of output tensor");

    TensorShape vector_sum_col_shape = vector_sum_col->tensor_shape();
    vector_sum_col_shape.collapse_from(1);
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(vector_sum_col_shape[1] != 1 &&
                                        vector_sum_col_shape[1] != vector_sum_row_shape[1],
                                    "vector_sum_col tensor must have the same number of batches of "
                                    "vector_sum_row_shape or the number of batches must be set to 1");

    if (output->total_size() != 0)
    _a_offset = a_offset;
    _b_offset = b_offset;
    _k_offset = a_offset * b_offset * k;
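
    // Example of the precomputed constant term (values illustrative only):
    // with a_offset = -3, b_offset = -5 and reduction depth k = 64, the
    // contribution a_offset * b_offset * k = (-3) * (-5) * 64 = 960 is folded
    // in once per output element when both offsets are non-zero.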
    ICpuKernel::configure(win);
    PixelValue type_min{};
    PixelValue type_max{};
    std::tie(type_min, type_max) = get_min_max(dst->info()->data_type());
    int32_t type_min_int = type_min.get<int32_t>();
    int32_t type_max_int = type_max.get<int32_t>();

    const bool reinterpret_as_3d = vector_sum_row != nullptr && mm_result->info()->num_dimensions() > 1 &&
                                   mm_result->info()->tensor_shape().y() != vector_sum_row->info()->tensor_shape().x();

    // The ReLU bound is real only if it is tighter than the output type's own range
    const bool is_bounded_relu =
        !(_output_stage.gemmlowp_min_bound <= type_min_int && _output_stage.gemmlowp_max_bound >= type_max_int);

    // Check if we need to perform fixed point requantization
    const bool is_fixed_point = _output_stage.type != GEMMLowpOutputStageType::QUANTIZE_DOWN;

    const bool is_signed = dst->info()->data_type() == DataType::QASYMM8_SIGNED;
    const bool is_symm   = _output_stage.is_quantized_per_channel;

    if (is_symm)
    {
        run_offset_contribution_output_stage_symm(window, mm_result, vector_sum_col, vector_sum_row, bias, dst,
                                                  _a_offset, _b_offset, _k_offset, _is_vector_sum_col_batched,
                                                  _output_stage, reinterpret_as_3d, is_bounded_relu, is_fixed_point);
    }
    else if (is_signed)
    {
        run_offset_contribution_output_stage<int8_t>(
            window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset,
            _is_vector_sum_col_batched, _output_stage, reinterpret_as_3d, is_bounded_relu, is_fixed_point);
    }
    else
    {
        run_offset_contribution_output_stage<uint8_t>(
            window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset,
            _is_vector_sum_col_batched, _output_stage, reinterpret_as_3d, is_bounded_relu, is_fixed_point);
    }
}
const char *CpuGemmLowpOffsetContributionOutputStageKernel::name() const
{
    return "CpuGemmLowpOffsetContributionOutputStageKernel";
}