#ifdef BIAS
// This function performs in-place bias addition for integer datatypes when bias is enabled.
// The tile's dimensions (M0, N0) must be passed at compile time using -DN0, -DM0 (e.g. -DN0=8, -DM0=4).
inline void perform_bias_addition(uchar *bias_ptr, uint bias_offset_first_element_in_bytes, TILE(int, M0, N0, acc), uint x)
{
    TILE(int, 1, N0, bias_tile);

    // T_LOAD below expands to use bias_ptr and bias_offset_first_element_in_bytes
    T_LOAD(int, 1, N0, BUFFER, bias, x, 0, 1, 0, bias_tile);

    // acc = acc + bias[broadcasted]
    T_ELTWISE_BROADCAST_ADD_X(int, M0, N0, acc, bias_tile, acc);
}
#endif // defined(BIAS)
#define MMUL_BLOCK_SIZE (MMUL_M0 * MMUL_N0) // MMUL block size for the output matrix
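// Each output block of M0 x MMUL_M0 rows and N0 x MMUL_N0 columns is computed cooperatively by
// MMUL_BLOCK_SIZE threads, each of which owns an M0 x N0 sub-tile and issues arm_matrix_multiply
// instructions together with the other threads in the block.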
#if defined(MAT_MUL_NATIVE_QUANTIZED_MMUL_NT_NT)
__kernel void mat_mul_native_quantized_mmul_nt_nt(
    TENSOR3D_T(lhs, BUFFER),
    TENSOR3D_T(rhs, BUFFER),
#ifdef BIAS
    TENSOR3D_T(bias, BUFFER),
#endif // defined(BIAS)
    TENSOR3D_T(dst, BUFFER))
{
    const uint x0 = get_global_id(0); // Encodes both the section index along N and the thread id within the MMUL block
    const uint y0 = get_global_id(1); // Section index along M
    const uint z  = get_global_id(2); // Batch
    // Get section and thread IDs
    const uint section_x = (x0 / MMUL_BLOCK_SIZE);
    const uint section_y = y0;
    const uint thread_id = (x0 % MMUL_BLOCK_SIZE);
    const uint thread_x  = thread_id % MMUL_N0;
    const uint thread_y  = (thread_id / MMUL_N0);
    // Calculate dst coordinates
    const uint dst_x_unclamped = thread_x * N0 + section_x * N0 * MMUL_N0;
    const uint dst_y_unclamped = thread_y * M0 + section_y * M0 * MMUL_M0;
    const uint dst_x           = min(dst_x_unclamped, (uint)(N - N0));
    const uint dst_y           = min(dst_y_unclamped, (uint)(M - M0));
    // Starting LHS coordinates
    const uint lhs_x = K0 * thread_x;
    const uint lhs_y = dst_y;

    // Starting RHS coordinates
    const uint rhs_x = dst_x;
    const uint rhs_y = K0 * thread_y;
    // Compute LHS/RHS/DST matrix addresses
    lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
    rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
    dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
    // Initialize the accumulators
    TILE(int, M0, N0, c);

    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        c[i].v = K * ((int)LHS_OFFSET) * ((int)RHS_OFFSET);
    })
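    // Seeding the accumulators with K * LHS_OFFSET * RHS_OFFSET accounts for the constant term of the
    // quantization-offset expansion; the row/column-sum terms are subtracted after the main loop.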
    // Sums of RHS columns and LHS rows, needed for the offset correction
    TILE(int, 1, N0, b_sum);
    b_sum[0].v = 0;

    TILE(int, 1, M0, a_sum);
    a_sum[0].v = 0;
    // Vector of 1's used to compute row/column sums with the MMUL instruction
    VEC_DATA_TYPE(DATA_TYPE, K0)
    vec_1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(1, 1, 1, 1);

    for(int k = 0; k < lhs_w; k += MMUL_K0)
    {
        // A tile of M0 x K0 from the LHS and K0 x N0 from the RHS
        TILE(DATA_TYPE, M0, K0, a);
        TILE(DATA_TYPE, K0, N0, b);

        // Load tiles from the LHS/RHS tensors
        T_LOAD(DATA_TYPE, M0, K0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, K0, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);
        LOOP_UNROLLING(int, n0, 0, 1, N0,
        {
            // Gather one column of the RHS tile into a vector
            VEC_DATA_TYPE(DATA_TYPE, K0)
            vec_b = (VEC_DATA_TYPE(DATA_TYPE, K0))(b[0].s[n0], b[1].s[n0], b[2].s[n0], b[3].s[n0]);

            LOOP_UNROLLING(int, m0, 0, 1, M0,
            {
                c[m0].s[n0] = arm_matrix_multiply(a[m0].v, vec_b, c[m0].s[n0]);
            })

#if LHS_OFFSET != 0
            // Column sum of B: multiply B with a vector of 1's
            b_sum[0].s[n0] = arm_matrix_multiply(vec_1, vec_b, b_sum[0].s[n0]);
#endif // LHS_OFFSET != 0
        })
#if RHS_OFFSET != 0
        // Row sum of A: multiply A with a vector of 1's
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            a_sum[0].s[m0] = arm_matrix_multiply(a[m0].v, vec_1, a_sum[0].s[m0]);
        })
#endif // RHS_OFFSET != 0
        // The non-transposed LHS steps MMUL_K0 elements along a row; the non-transposed RHS steps MMUL_K0 rows
        lhs_offset_first_element_in_bytes += MMUL_K0 * sizeof(DATA_TYPE);
        rhs_offset_first_element_in_bytes += MMUL_K0 * rhs_stride_y;
    }
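    // All threads in the MMUL block have now finished the cooperative computation, so threads mapped
    // past the matrix edge can safely drop out before the write-back.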
    if(dst_x_unclamped >= N || dst_y_unclamped >= M)
    {
        return;
    }
#if RHS_OFFSET != 0 || LHS_OFFSET != 0
    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        const int A = ((int)RHS_OFFSET) * a_sum[0].s[i];
        LOOP_UNROLLING(int, j, 0, 1, N0,
        {
            c[i].s[j] -= A + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
        })
    })
#endif // RHS_OFFSET != 0 || LHS_OFFSET != 0
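    // This completes the expansion sum((lhs - LHS_OFFSET) * (rhs - RHS_OFFSET)) =
    //     sum(lhs * rhs) - RHS_OFFSET * rowsum(lhs) - LHS_OFFSET * colsum(rhs) + K * LHS_OFFSET * RHS_OFFSET,
    // whose constant term seeded the accumulators before the main loop.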
#ifdef BIAS
    perform_bias_addition(bias_ptr, bias_offset_first_element_in_bytes, c, dst_x);
#endif // defined(BIAS)
    // Quantize the accumulated int32 tile to the destination data type
    TILE(DATA_TYPE, M0, N0, cq);
    T_QUANTIZE8_ASYMMETRIC(int, DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, c, cq);
    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
    {
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            if(dst_y + m0 < M || M0_LEFTOVER == 0)
            {
                VSTORE(N0)
                (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
            }
        })
    }
    else
    {
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            if(dst_y + m0 < M || M0_LEFTOVER == 0)
            {
                VSTORE_PARTIAL(N0, N0_LEFTOVER)
                (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
            }
        })
    }
}
#endif // defined(MAT_MUL_NATIVE_QUANTIZED_MMUL_NT_NT)
#if defined(MAT_MUL_NATIVE_QUANTIZED_MMUL_NT_T)
__kernel void mat_mul_native_quantized_mmul_nt_t(
    TENSOR3D_T(lhs, BUFFER),
    TENSOR3D_T(rhs, BUFFER),
#ifdef BIAS
    TENSOR3D_T(bias, BUFFER),
#endif // defined(BIAS)
    TENSOR3D_T(dst, BUFFER))
{
    const uint x0 = get_global_id(0);
    const uint y0 = get_global_id(1);
    const uint z  = get_global_id(2); // Batch
    // Get section and thread IDs
    const uint section_x = (x0 / MMUL_BLOCK_SIZE);
    const uint section_y = y0;
    const uint thread_id = (x0 % MMUL_BLOCK_SIZE);
    const uint thread_x  = thread_id % MMUL_N0;
    const uint thread_y  = (thread_id / MMUL_N0);
    // Calculate dst coordinates
    const uint dst_x_unclamped = thread_x * N0 + section_x * N0 * MMUL_N0;
    const uint dst_y_unclamped = thread_y * M0 + section_y * M0 * MMUL_M0;
    const uint dst_x           = min(dst_x_unclamped, (uint)(N - N0));
    const uint dst_y           = min(dst_y_unclamped, (uint)(M - M0));
    // Starting LHS coordinates
    const uint lhs_x = K0 * thread_x;
    const uint lhs_y = dst_y;

    // Starting RHS coordinates (RHS is transposed)
    const uint rhs_x = K0 * thread_y;
    const uint rhs_y = dst_x;
    // Compute LHS/RHS/DST matrix addresses
    lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
    rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
    dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
    // Initialize the accumulators with the constant offset term
    TILE(int, M0, N0, c);

    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        c[i].v = K * ((int)LHS_OFFSET) * ((int)RHS_OFFSET);
    })
    // Sums of RHS columns and LHS rows for the offset correction
    TILE(int, 1, N0, b_sum);
    b_sum[0].v = 0;

    TILE(int, 1, M0, a_sum);
    a_sum[0].v = 0;
    VEC_DATA_TYPE(DATA_TYPE, K0)
    vec_1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(1, 1, 1, 1);

    for(int k = 0; k < lhs_w; k += MMUL_K0)
    {
        // A tile of M0 x K0 from the LHS and N0 x K0 from the transposed RHS
        TILE(DATA_TYPE, M0, K0, a);
        TILE(DATA_TYPE, N0, K0, b);

        // Load tiles from the LHS/RHS tensors
        T_LOAD(DATA_TYPE, M0, K0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, N0, K0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            LOOP_UNROLLING(int, n0, 0, 1, N0,
            {
                c[m0].s[n0] = arm_matrix_multiply(a[m0].v, b[n0].v, c[m0].s[n0]);
            })
        })
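        // With the RHS transposed, each row b[n0].v is already a contiguous K0-length vector,
        // so no per-column gather is needed here (contrast with the NT_NT variant above).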
#if RHS_OFFSET != 0
        // Row sum of A: multiply A with a vector of 1's
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            a_sum[0].s[m0] = arm_matrix_multiply(a[m0].v, vec_1, a_sum[0].s[m0]);
        })
#endif // RHS_OFFSET != 0
#if LHS_OFFSET != 0
        // Column sum of B: multiply B with a vector of 1's
        LOOP_UNROLLING(int, n0, 0, 1, N0,
        {
            b_sum[0].s[n0] = arm_matrix_multiply(vec_1, b[n0].v, b_sum[0].s[n0]);
        })
#endif // LHS_OFFSET != 0
        // K runs along a row in both layouts here, so both operands advance MMUL_K0 elements
        lhs_offset_first_element_in_bytes += MMUL_K0 * sizeof(DATA_TYPE);
        rhs_offset_first_element_in_bytes += MMUL_K0 * sizeof(DATA_TYPE);
    }
    // Drop out-of-bound threads now that the cooperative computation is complete
    if(dst_x_unclamped >= N || dst_y_unclamped >= M)
    {
        return;
    }
#if RHS_OFFSET != 0 || LHS_OFFSET != 0
    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        const int A = ((int)RHS_OFFSET) * a_sum[0].s[i];
        LOOP_UNROLLING(int, j, 0, 1, N0,
        {
            c[i].s[j] -= A + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
        })
    })
#endif // RHS_OFFSET != 0 || LHS_OFFSET != 0
#ifdef BIAS
    perform_bias_addition(bias_ptr, bias_offset_first_element_in_bytes, c, dst_x);
#endif // defined(BIAS)
    // Quantize the accumulated int32 tile to the destination data type
    TILE(DATA_TYPE, M0, N0, cq);
    T_QUANTIZE8_ASYMMETRIC(int, DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, c, cq);
    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
    {
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            if(dst_y + m0 < M || M0_LEFTOVER == 0)
            {
                VSTORE(N0)
                (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
            }
        })
    }
    else
    {
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            if(dst_y + m0 < M || M0_LEFTOVER == 0)
            {
                VSTORE_PARTIAL(N0, N0_LEFTOVER)
                (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
            }
        })
    }
}
#endif // defined(MAT_MUL_NATIVE_QUANTIZED_MMUL_NT_T)
#if defined(MAT_MUL_NATIVE_QUANTIZED_MMUL_T_NT)
__kernel void mat_mul_native_quantized_mmul_t_nt(
    TENSOR3D_T(lhs, BUFFER),
    TENSOR3D_T(rhs, BUFFER),
#ifdef BIAS
    TENSOR3D_T(bias, BUFFER),
#endif // defined(BIAS)
    TENSOR3D_T(dst, BUFFER))
{
    const uint x0 = get_global_id(0);
    const uint y0 = get_global_id(1);
    const uint z  = get_global_id(2); // Batch
    // Get section and thread IDs
    const uint section_x = (x0 / MMUL_BLOCK_SIZE);
    const uint section_y = y0;
    const uint thread_id = (x0 % MMUL_BLOCK_SIZE);
    const uint thread_x  = thread_id % MMUL_N0;
    const uint thread_y  = (thread_id / MMUL_N0);
    // Calculate dst coordinates
    const uint dst_x_unclamped = thread_x * N0 + section_x * N0 * MMUL_N0;
    const uint dst_y_unclamped = thread_y * M0 + section_y * M0 * MMUL_M0;
    const uint dst_x           = min(dst_x_unclamped, (uint)(N - N0));
    const uint dst_y           = min(dst_y_unclamped, (uint)(M - M0));
    // Starting LHS coordinates (LHS is transposed)
    const uint lhs_x = dst_y;
    const uint lhs_y = K0 * thread_x;

    // Starting RHS coordinates
    const uint rhs_x = dst_x;
    const uint rhs_y = K0 * thread_y;
    // Compute LHS/RHS/DST matrix addresses
    lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
    rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
    dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
    // Initialize the accumulators with the constant offset term
    TILE(int, M0, N0, c);

    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        c[i].v = K * ((int)LHS_OFFSET) * ((int)RHS_OFFSET);
    })
    // Sums of RHS columns and LHS rows for the offset correction
    TILE(int, 1, N0, b_sum);
    b_sum[0].v = 0;

    TILE(int, 1, M0, a_sum);
    a_sum[0].v = 0;
    VEC_DATA_TYPE(DATA_TYPE, K0)
    vec_1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(1, 1, 1, 1);

    for(int k = 0; k < lhs_h; k += MMUL_K0)
    {
        // A tile of K0 x M0 from the transposed LHS and K0 x N0 from the RHS
        TILE(DATA_TYPE, K0, M0, a);
        TILE(DATA_TYPE, K0, N0, b);

        // Load tiles from the LHS/RHS tensors
        T_LOAD(DATA_TYPE, K0, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, K0, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            // Gather one column of the transposed LHS tile into a vector
            VEC_DATA_TYPE(DATA_TYPE, K0)
            vec_a = (VEC_DATA_TYPE(DATA_TYPE, K0))(a[0].s[m0], a[1].s[m0], a[2].s[m0], a[3].s[m0]);

            LOOP_UNROLLING(int, n0, 0, 1, N0,
            {
                // Gather one column of the RHS tile into a vector
                VEC_DATA_TYPE(DATA_TYPE, K0)
                vec_b = (VEC_DATA_TYPE(DATA_TYPE, K0))(b[0].s[n0], b[1].s[n0], b[2].s[n0], b[3].s[n0]);

                c[m0].s[n0] = arm_matrix_multiply(vec_a, vec_b, c[m0].s[n0]);
            })

#if RHS_OFFSET != 0
            // Row sum of A: multiply A with a vector of 1's
            a_sum[0].s[m0] = arm_matrix_multiply(vec_a, vec_1, a_sum[0].s[m0]);
#endif // RHS_OFFSET != 0
        })

#if LHS_OFFSET != 0
        // Column sum of B: multiply B with a vector of 1's
        LOOP_UNROLLING(int, n0, 0, 1, N0,
        {
            VEC_DATA_TYPE(DATA_TYPE, K0)
            vec_b = (VEC_DATA_TYPE(DATA_TYPE, K0))(b[0].s[n0], b[1].s[n0], b[2].s[n0], b[3].s[n0]);

            b_sum[0].s[n0] = arm_matrix_multiply(vec_1, vec_b, b_sum[0].s[n0]);
        })
#endif // LHS_OFFSET != 0
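        // In this T_NT variant K does not run along a row for either operand, so both the LHS and
        // RHS columns must be gathered element by element before each arm_matrix_multiply.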
        // K runs down the rows in both layouts here, so both operands advance MMUL_K0 rows
        lhs_offset_first_element_in_bytes += MMUL_K0 * lhs_stride_y;
        rhs_offset_first_element_in_bytes += MMUL_K0 * rhs_stride_y;
    }
    if(dst_x_unclamped >= N || dst_y_unclamped >= M)
    {
        return;
    }
#if RHS_OFFSET != 0 || LHS_OFFSET != 0
    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        const int A = ((int)RHS_OFFSET) * a_sum[0].s[i];
        LOOP_UNROLLING(int, j, 0, 1, N0,
        {
            c[i].s[j] -= A + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
        })
    })
#endif // RHS_OFFSET != 0 || LHS_OFFSET != 0
#ifdef BIAS
    perform_bias_addition(bias_ptr, bias_offset_first_element_in_bytes, c, dst_x);
#endif // defined(BIAS)
    // Quantize the accumulated int32 tile to the destination data type
    TILE(DATA_TYPE, M0, N0, cq);
    T_QUANTIZE8_ASYMMETRIC(int, DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, c, cq);
    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
    {
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            if(dst_y + m0 < M || M0_LEFTOVER == 0)
            {
                VSTORE(N0)
                (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
            }
        })
    }
    else
    {
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            if(dst_y + m0 < M || M0_LEFTOVER == 0)
            {
                VSTORE_PARTIAL(N0, N0_LEFTOVER)
                (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
            }
        })
    }
}
#endif // defined(MAT_MUL_NATIVE_QUANTIZED_MMUL_T_NT)
#if defined(MAT_MUL_NATIVE_QUANTIZED_MMUL_T_T)
__kernel void mat_mul_native_quantized_mmul_t_t(
    TENSOR3D_T(lhs, BUFFER),
    TENSOR3D_T(rhs, BUFFER),
#ifdef BIAS
    TENSOR3D_T(bias, BUFFER),
#endif // defined(BIAS)
    TENSOR3D_T(dst, BUFFER))
{
    const uint x0 = get_global_id(0);
    const uint y0 = get_global_id(1);
    const uint z  = get_global_id(2); // Batch
    // Get section and thread IDs
    const uint section_x = (x0 / MMUL_BLOCK_SIZE);
    const uint section_y = y0;
    const uint thread_id = (x0 % MMUL_BLOCK_SIZE);
    const uint thread_x  = thread_id % MMUL_N0;
    const uint thread_y  = (thread_id / MMUL_N0);
    // Calculate dst coordinates
    const uint dst_x_unclamped = thread_x * N0 + section_x * N0 * MMUL_N0;
    const uint dst_y_unclamped = thread_y * M0 + section_y * M0 * MMUL_M0;
    const uint dst_x           = min(dst_x_unclamped, (uint)(N - N0));
    const uint dst_y           = min(dst_y_unclamped, (uint)(M - M0));
    // Starting LHS coordinates (LHS is transposed)
    const uint lhs_x = dst_y;
    const uint lhs_y = K0 * thread_x;

    // Starting RHS coordinates (RHS is transposed)
    const uint rhs_x = K0 * thread_y;
    const uint rhs_y = dst_x;
    // Compute LHS/RHS/DST matrix addresses
    lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
    rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
    dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
    // Initialize the accumulators with the constant offset term
    TILE(int, M0, N0, c);

    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        c[i].v = K * ((int)LHS_OFFSET) * ((int)RHS_OFFSET);
    })
    // Sums of RHS columns and LHS rows for the offset correction
    TILE(int, 1, N0, b_sum);
    b_sum[0].v = 0;

    TILE(int, 1, M0, a_sum);
    a_sum[0].v = 0;
    VEC_DATA_TYPE(DATA_TYPE, K0)
    vec_1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(1, 1, 1, 1);

    for(int k = 0; k < lhs_h; k += MMUL_K0)
    {
        // A tile of K0 x M0 from the transposed LHS and N0 x K0 from the transposed RHS
        TILE(DATA_TYPE, K0, M0, a);
        TILE(DATA_TYPE, N0, K0, b);

        // Load tiles from the LHS/RHS tensors
        T_LOAD(DATA_TYPE, K0, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, N0, K0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            // Gather one column of the transposed LHS tile into a vector
            VEC_DATA_TYPE(DATA_TYPE, K0)
            vec_a = (VEC_DATA_TYPE(DATA_TYPE, K0))(a[0].s[m0], a[1].s[m0], a[2].s[m0], a[3].s[m0]);

            LOOP_UNROLLING(int, n0, 0, 1, N0,
            {
                c[m0].s[n0] = arm_matrix_multiply(vec_a, b[n0].v, c[m0].s[n0]);
            })

#if RHS_OFFSET != 0
            // Row sum of A: multiply A with a vector of 1's
            a_sum[0].s[m0] = arm_matrix_multiply(vec_a, vec_1, a_sum[0].s[m0]);
#endif // RHS_OFFSET != 0
        })

#if LHS_OFFSET != 0
        // Column sum of B: multiply B with a vector of 1's
        LOOP_UNROLLING(int, n0, 0, 1, N0,
        {
            b_sum[0].s[n0] = arm_matrix_multiply(vec_1, b[n0].v, b_sum[0].s[n0]);
        })
#endif // LHS_OFFSET != 0
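        // Here only the transposed LHS needs a gather; the transposed RHS rows b[n0].v are already
        // contiguous K0-length vectors.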
        // The transposed LHS advances MMUL_K0 rows; the transposed RHS advances MMUL_K0 elements along a row
        lhs_offset_first_element_in_bytes += MMUL_K0 * lhs_stride_y;
        rhs_offset_first_element_in_bytes += MMUL_K0 * sizeof(DATA_TYPE);
    }
    if(dst_x_unclamped >= N || dst_y_unclamped >= M)
    {
        return;
    }
#if RHS_OFFSET != 0 || LHS_OFFSET != 0
    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        const int A = ((int)RHS_OFFSET) * a_sum[0].s[i];
        LOOP_UNROLLING(int, j, 0, 1, N0,
        {
            c[i].s[j] -= A + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
        })
    })
#endif // RHS_OFFSET != 0 || LHS_OFFSET != 0
#ifdef BIAS
    perform_bias_addition(bias_ptr, bias_offset_first_element_in_bytes, c, dst_x);
#endif // defined(BIAS)
    // Quantize the accumulated int32 tile to the destination data type
    TILE(DATA_TYPE, M0, N0, cq);
    T_QUANTIZE8_ASYMMETRIC(int, DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, c, cq);
    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
    {
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            if(dst_y + m0 < M || M0_LEFTOVER == 0)
            {
                VSTORE(N0)
                (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
            }
        })
    }
    else
    {
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            if(dst_y + m0 < M || M0_LEFTOVER == 0)
            {
                VSTORE_PARTIAL(N0, N0_LEFTOVER)
                (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
            }
        })
    }
}
#endif // defined(MAT_MUL_NATIVE_QUANTIZED_MMUL_T_T)