#if defined(BIAS)
// This function performs in-place bias addition for float and half datatypes when bias is enabled.
// The tile dimensions M0 and N0 must be passed at compile time (e.g. -DM0=4 -DN0=8).
inline void perform_bias_addition(uchar *bias_ptr, uint bias_offset_first_element_in_bytes, TILE(DATA_TYPE, M0, N0, acc), uint x)
{
    TILE(DATA_TYPE, 1, N0, bias_tile);

    // Below expands to use bias_ptr and bias_offset_first_element_in_bytes
    T_LOAD(DATA_TYPE, 1, N0, BUFFER, bias, x, 0, 1, 0, bias_tile);

    // acc = acc + bias, with the bias row broadcast across the M0 rows of the accumulator
    T_ELTWISE_BROADCAST_ADD_X(DATA_TYPE, M0, N0, acc, bias_tile, acc);
}
#endif // defined(BIAS)
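// The kernels in this file are specialized at build time. All of the following macros are
// referenced below and are expected as -D build options; the values shown here are purely
// illustrative, not normative:
//   -DDATA_TYPE=half -DM=64 -DN=64 -DK=64 -DM0=4 -DN0=4
//   -DMMUL_M0=4 -DMMUL_N0=4 -DMMUL_K0=4 -DM0_LEFTOVER=0 -DN0_LEFTOVER=0
//   -DMAT_MUL_NATIVE_MMUL_NT_NT -DHALF_PRECISION -DBIAS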
#if defined(MAT_MUL_NATIVE_MMUL_NT_NT)
/** This kernel performs the batch matrix multiplication (BatchMatMul) dst = lhs * rhs, plus bias
 * when BIAS is defined, using the MMUL extension. Both lhs and rhs are non-transposed.
 */
__kernel void mat_mul_native_mmul_nt_nt(
    TENSOR3D_T(lhs, BUFFER),
    TENSOR3D_T(rhs, BUFFER),
#ifdef BIAS
    TENSOR3D_T(bias, BUFFER),
#endif // defined(BIAS)
    TENSOR3D_T(dst, BUFFER))
{
#define MMUL_BLOCK_SIZE (MMUL_M0 * MMUL_N0) // MMUL block size for the output matrix

    const uint x0 = get_global_id(0); // consecutive groups of MMUL_BLOCK_SIZE ids form one MMUL block
    const uint y0 = get_global_id(1); // one id per block of M0 * MMUL_M0 output rows
    const uint z  = get_global_id(2); // batch

    // Get section and thread IDs within the MMUL block
    const uint section_x = x0 / MMUL_BLOCK_SIZE;
    const uint section_y = y0;

    const uint thread_id = x0 % MMUL_BLOCK_SIZE;
    const uint thread_x = thread_id % MMUL_N0;
    const uint thread_y = thread_id / MMUL_N0;

    // Compute output block coordinates; clamp to the last valid block so that edge threads never
    // address out-of-bounds memory (their stores are skipped or masked later).
    const uint dst_x_unclamped = thread_x * N0 + section_x * N0 * MMUL_N0;
    const uint dst_y_unclamped = thread_y * M0 + section_y * M0 * MMUL_M0;
    const uint dst_x = min(dst_x_unclamped, (uint)(N - N0));
    const uint dst_y = min(dst_y_unclamped, (uint)(M - M0));
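    // Worked example (illustrative): with M0 = N0 = MMUL_M0 = MMUL_N0 = 4, MMUL_BLOCK_SIZE is 16.
    // Thread x0 = 21 falls in section_x = 1 with thread_id = 5, i.e. thread_x = 1 and thread_y = 1,
    // so dst_x_unclamped = 1 * 4 + 1 * 16 = 20 and, for y0 = 0, dst_y_unclamped = 1 * 4 + 0 = 4.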
    // Starting LHS coordinates: lhs is M x K, so x walks k and y walks the output rows
    const uint lhs_x = thread_x;
    const uint lhs_y = dst_y;

    // Starting RHS coordinates: rhs is K x N, so x walks the output columns and y walks k
    const uint rhs_x = dst_x;
    const uint rhs_y = thread_y;

    // Compute LHS/RHS/DST matrix addresses
    lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
    rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
    dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
    // The MMUL extension accumulates in single precision
    TILE(float, M0, N0, c_f32);

    // Initialize the accumulators
    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        c_f32[i].v = 0.f;
    })

    for(int k = 0; k < K; k += MMUL_K0)
    {
        // A tile: M0 rows of one k-slice of lhs; B tile: one k-slice of N0 columns of rhs
        TILE(DATA_TYPE, M0, 1, a);
        TILE(DATA_TYPE, 1, N0, b);

        T_LOAD(DATA_TYPE, M0, 1, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, 1, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);

        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            LOOP_UNROLLING(int, n0, 0, 1, N0,
            {
                c_f32[m0].s[n0] = arm_matrix_multiply(a[m0].s[0], b[0].s[n0], c_f32[m0].s[n0]);
            })
        })

        // Advance along k: lhs moves MMUL_K0 columns, rhs moves MMUL_K0 rows
        lhs_offset_first_element_in_bytes += MMUL_K0 * sizeof(DATA_TYPE);
        rhs_offset_first_element_in_bytes += MMUL_K0 * rhs_stride_y;
    }
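    // After the loop, each thread holds the finished dot products for its M0 x N0 output block:
    // conceptually, c_f32[m0].s[n0] = sum over k of lhs[dst_y + m0][k] * rhs[k][dst_x + n0],
    // with each MMUL_K0-deep step computed cooperatively by the MMUL_BLOCK_SIZE threads.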
    // A thread whose unclamped block lies entirely outside the output only took part in the
    // cooperative MMUL computation; it has nothing of its own to store.
    if(dst_x_unclamped >= N || dst_y_unclamped >= M)
    {
        return;
    }

#if defined(HALF_PRECISION)
    TILE(DATA_TYPE, M0, N0, c);

    // Convert the f32 accumulators to the half precision output type
    LOOP_UNROLLING(int, m0, 0, 1, M0,
    {
        LOOP_UNROLLING(int, n0, 0, 1, N0,
        {
            c[m0].s[n0] = c_f32[m0].s[n0];
        })
    })
#else // defined(HALF_PRECISION)
#define c c_f32
#endif // defined(HALF_PRECISION)

#if defined(BIAS)
    perform_bias_addition(bias_ptr, bias_offset_first_element_in_bytes, c, dst_x);
#endif // defined(BIAS)

    // Store full N0-wide vectors when there is no leftover on x; otherwise store partial vectors
    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
    {
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            if(dst_y + m0 < M || M0_LEFTOVER == 0)
            {
                VSTORE(N0)
                (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
            }
        })
    }
    else
    {
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            if(dst_y + m0 < M || M0_LEFTOVER == 0)
            {
                VSTORE_PARTIAL(N0, N0_LEFTOVER)
                (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
            }
        })
    }
#undef MMUL_BLOCK_SIZE
}
#endif // defined(MAT_MUL_NATIVE_MMUL_NT_NT)
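// The three kernels below share this structure and differ only in operand layout: the lhs/rhs
// starting coordinates, the a/b tile shapes, and the per-iteration k advances change with each
// transposition, while the ID computation, leftover handling, bias and stores stay the same.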
#if defined(MAT_MUL_NATIVE_MMUL_T_NT)
/** This kernel performs the batch matrix multiplication (BatchMatMul) dst = lhs^T * rhs, plus bias
 * when BIAS is defined, using the MMUL extension. lhs is transposed, rhs is non-transposed.
 */
__kernel void mat_mul_native_mmul_t_nt(
    TENSOR3D_T(lhs, BUFFER),
    TENSOR3D_T(rhs, BUFFER),
#ifdef BIAS
    TENSOR3D_T(bias, BUFFER),
#endif // defined(BIAS)
    TENSOR3D_T(dst, BUFFER))
{
#define MMUL_BLOCK_SIZE (MMUL_M0 * MMUL_N0) // MMUL block size for the output matrix

    const uint x0 = get_global_id(0); // consecutive groups of MMUL_BLOCK_SIZE ids form one MMUL block
    const uint y0 = get_global_id(1); // one id per block of M0 * MMUL_M0 output rows
    const uint z  = get_global_id(2); // batch

    // Get section and thread IDs within the MMUL block
    const uint section_x = x0 / MMUL_BLOCK_SIZE;
    const uint section_y = y0;

    const uint thread_id = x0 % MMUL_BLOCK_SIZE;
    const uint thread_x = thread_id % MMUL_N0;
    const uint thread_y = thread_id / MMUL_N0;

    // Compute output block coordinates, clamped to the last valid block as in NT_NT
    const uint dst_x_unclamped = thread_x * N0 + section_x * N0 * MMUL_N0;
    const uint dst_y_unclamped = thread_y * M0 + section_y * M0 * MMUL_M0;
    const uint dst_x = min(dst_x_unclamped, (uint)(N - N0));
    const uint dst_y = min(dst_y_unclamped, (uint)(M - M0));
    // Starting LHS coordinates: lhs is transposed (K x M), so x walks the output rows and y walks k
    const uint lhs_x = dst_y;
    const uint lhs_y = thread_x;

    // Starting RHS coordinates: rhs is K x N, as in NT_NT
    const uint rhs_x = dst_x;
    const uint rhs_y = thread_y;

    // Compute LHS/RHS/DST matrix addresses
    lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
    rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
    dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
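    // Because k runs along the rows of the transposed lhs, the a tile in the loop below is a
    // single row of M0 values (TILE(DATA_TYPE, 1, M0, ...)) instead of M0 rows of one value.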
    // The MMUL extension accumulates in single precision
    TILE(float, M0, N0, c_f32);

    // Initialize the accumulators
    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        c_f32[i].v = 0.f;
    })

    for(int k = 0; k < K; k += MMUL_K0)
    {
        TILE(DATA_TYPE, 1, M0, a);
        TILE(DATA_TYPE, 1, N0, b);

        T_LOAD(DATA_TYPE, 1, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, 1, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);

        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            LOOP_UNROLLING(int, n0, 0, 1, N0,
            {
                c_f32[m0].s[n0] = arm_matrix_multiply(a[0].s[m0], b[0].s[n0], c_f32[m0].s[n0]);
            })
        })

        // Advance along k: both operands move MMUL_K0 rows, since k runs along y for both
        lhs_offset_first_element_in_bytes += MMUL_K0 * lhs_stride_y;
        rhs_offset_first_element_in_bytes += MMUL_K0 * rhs_stride_y;
    }
    // Threads whose unclamped block lies outside the output have nothing to store
    if(dst_x_unclamped >= N || dst_y_unclamped >= M)
    {
        return;
    }

#if defined(HALF_PRECISION)
    TILE(DATA_TYPE, M0, N0, c);

    // Convert the f32 accumulators to the half precision output type
    LOOP_UNROLLING(int, m0, 0, 1, M0,
    {
        LOOP_UNROLLING(int, n0, 0, 1, N0,
        {
            c[m0].s[n0] = c_f32[m0].s[n0];
        })
    })
#else // defined(HALF_PRECISION)
#define c c_f32
#endif // defined(HALF_PRECISION)

#if defined(BIAS)
    perform_bias_addition(bias_ptr, bias_offset_first_element_in_bytes, c, dst_x);
#endif // defined(BIAS)

    // Store full N0-wide vectors when there is no leftover on x; otherwise store partial vectors
    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
    {
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            if(dst_y + m0 < M || M0_LEFTOVER == 0)
            {
                VSTORE(N0)
                (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
            }
        })
    }
    else
    {
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            if(dst_y + m0 < M || M0_LEFTOVER == 0)
            {
                VSTORE_PARTIAL(N0, N0_LEFTOVER)
                (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
            }
        })
    }
#undef MMUL_BLOCK_SIZE
}
#endif // defined(MAT_MUL_NATIVE_MMUL_T_NT)
#if defined(MAT_MUL_NATIVE_MMUL_NT_T)
/** This kernel performs the batch matrix multiplication (BatchMatMul) dst = lhs * rhs^T, plus bias
 * when BIAS is defined, using the MMUL extension. lhs is non-transposed, rhs is transposed.
 */
__kernel void mat_mul_native_mmul_nt_t(
    TENSOR3D_T(lhs, BUFFER),
    TENSOR3D_T(rhs, BUFFER),
#ifdef BIAS
    TENSOR3D_T(bias, BUFFER),
#endif // defined(BIAS)
    TENSOR3D_T(dst, BUFFER))
{
#define MMUL_BLOCK_SIZE (MMUL_M0 * MMUL_N0) // MMUL block size for the output matrix

    const uint x0 = get_global_id(0); // consecutive groups of MMUL_BLOCK_SIZE ids form one MMUL block
    const uint y0 = get_global_id(1); // one id per block of M0 * MMUL_M0 output rows
    const uint z  = get_global_id(2); // batch

    // Get section and thread IDs within the MMUL block
    const uint section_x = x0 / MMUL_BLOCK_SIZE;
    const uint section_y = y0;

    const uint thread_id = x0 % MMUL_BLOCK_SIZE;
    const uint thread_x = thread_id % MMUL_N0;
    const uint thread_y = thread_id / MMUL_N0;

    const uint dst_x_unclamped = thread_x * N0 + section_x * N0 * MMUL_N0;
    const uint dst_y_unclamped = thread_y * M0 + section_y * M0 * MMUL_M0;
    const uint dst_x = min(dst_x_unclamped, (uint)(N - N0));
    const uint dst_y = min(dst_y_unclamped, (uint)(M - M0));
    // Starting LHS coordinates: lhs is M x K, as in NT_NT
    const uint lhs_x = thread_x;
    const uint lhs_y = dst_y;

    // Starting RHS coordinates: rhs is transposed (N x K), so x walks k and y walks the output columns
    const uint rhs_x = thread_y;
    const uint rhs_y = dst_x;

    // Compute LHS/RHS/DST matrix addresses
    lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
    rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
    dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
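    // Because k runs along the columns of the transposed rhs, the b tile in the loop below is
    // N0 rows of one value (TILE(DATA_TYPE, N0, 1, ...)) instead of a single row of N0 values.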
    // The MMUL extension accumulates in single precision
    TILE(float, M0, N0, c_f32);

    // Initialize the accumulators
    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        c_f32[i].v = 0.f;
    })

    for(int k = 0; k < K; k += MMUL_K0)
    {
        TILE(DATA_TYPE, M0, 1, a);
        TILE(DATA_TYPE, N0, 1, b);

        T_LOAD(DATA_TYPE, M0, 1, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, N0, 1, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);

        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            LOOP_UNROLLING(int, n0, 0, 1, N0,
            {
                c_f32[m0].s[n0] = arm_matrix_multiply(a[m0].s[0], b[n0].s[0], c_f32[m0].s[n0]);
            })
        })

        // Advance along k: lhs moves MMUL_K0 columns, rhs moves MMUL_N0 columns
        lhs_offset_first_element_in_bytes += MMUL_K0 * sizeof(DATA_TYPE);
        rhs_offset_first_element_in_bytes += MMUL_N0 * sizeof(DATA_TYPE);
    }
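    // Note: the k loop steps by MMUL_K0 while the transposed rhs advances MMUL_N0 columns per
    // iteration; the two line up when MMUL_N0 == MMUL_K0, which this kernel implicitly assumes.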
    // Threads whose unclamped block lies outside the output have nothing to store
    if(dst_x_unclamped >= N || dst_y_unclamped >= M)
    {
        return;
    }

#if defined(HALF_PRECISION)
    TILE(DATA_TYPE, M0, N0, c);

    // Convert the f32 accumulators to the half precision output type
    LOOP_UNROLLING(int, m0, 0, 1, M0,
    {
        LOOP_UNROLLING(int, n0, 0, 1, N0,
        {
            c[m0].s[n0] = c_f32[m0].s[n0];
        })
    })
#else // defined(HALF_PRECISION)
#define c c_f32
#endif // defined(HALF_PRECISION)

#if defined(BIAS)
    perform_bias_addition(bias_ptr, bias_offset_first_element_in_bytes, c, dst_x);
#endif // defined(BIAS)

    // Store full N0-wide vectors when there is no leftover on x; otherwise store partial vectors
    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
    {
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            if(dst_y + m0 < M || M0_LEFTOVER == 0)
            {
                VSTORE(N0)
                (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
            }
        })
    }
    else
    {
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            if(dst_y + m0 < M || M0_LEFTOVER == 0)
            {
                VSTORE_PARTIAL(N0, N0_LEFTOVER)
                (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
            }
        })
    }
#undef MMUL_BLOCK_SIZE
}
#endif // defined(MAT_MUL_NATIVE_MMUL_NT_T)
#if defined(MAT_MUL_NATIVE_MMUL_T_T)
/** This kernel performs the batch matrix multiplication (BatchMatMul) dst = lhs^T * rhs^T, plus
 * bias when BIAS is defined, using the MMUL extension. Both lhs and rhs are transposed.
 */
__kernel void mat_mul_native_mmul_t_t(
    TENSOR3D_T(lhs, BUFFER),
    TENSOR3D_T(rhs, BUFFER),
#ifdef BIAS
    TENSOR3D_T(bias, BUFFER),
#endif // defined(BIAS)
    TENSOR3D_T(dst, BUFFER))
{
#define MMUL_BLOCK_SIZE (MMUL_M0 * MMUL_N0) // MMUL block size for the output matrix

    const uint x0 = get_global_id(0); // consecutive groups of MMUL_BLOCK_SIZE ids form one MMUL block
    const uint y0 = get_global_id(1); // one id per block of M0 * MMUL_M0 output rows
    const uint z  = get_global_id(2); // batch

    // Get section and thread IDs within the MMUL block
    const uint section_x = x0 / MMUL_BLOCK_SIZE;
    const uint section_y = y0;

    const uint thread_id = x0 % MMUL_BLOCK_SIZE;
    const uint thread_x = thread_id % MMUL_N0;
    const uint thread_y = thread_id / MMUL_N0;

    const uint dst_x_unclamped = thread_x * N0 + section_x * N0 * MMUL_N0;
    const uint dst_y_unclamped = thread_y * M0 + section_y * M0 * MMUL_M0;
    const uint dst_x = min(dst_x_unclamped, (uint)(N - N0));
    const uint dst_y = min(dst_y_unclamped, (uint)(M - M0));
    // Starting LHS coordinates: lhs is transposed (K x M), as in T_NT
    const uint lhs_x = dst_y;
    const uint lhs_y = thread_x;

    // Starting RHS coordinates: rhs is transposed (N x K), as in NT_T
    const uint rhs_x = thread_y;
    const uint rhs_y = dst_x;

    // Compute LHS/RHS/DST matrix addresses
    lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
    rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
    dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
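    // With both operands transposed, the loop below combines the T_NT and NT_T access patterns:
    // a is a 1 x M0 row tile and b is an N0 x 1 column tile.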
    // The MMUL extension accumulates in single precision
    TILE(float, M0, N0, c_f32);

    // Initialize the accumulators
    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        c_f32[i].v = 0.f;
    })

    for(int k = 0; k < K; k += MMUL_K0)
    {
        TILE(DATA_TYPE, 1, M0, a);
        TILE(DATA_TYPE, N0, 1, b);

        T_LOAD(DATA_TYPE, 1, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, N0, 1, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);

        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            LOOP_UNROLLING(int, n0, 0, 1, N0,
            {
                c_f32[m0].s[n0] = arm_matrix_multiply(a[0].s[m0], b[n0].s[0], c_f32[m0].s[n0]);
            })
        })

        // Advance along k: lhs moves MMUL_K0 rows, rhs moves MMUL_N0 columns (see the NT_T note)
        lhs_offset_first_element_in_bytes += MMUL_K0 * lhs_stride_y;
        rhs_offset_first_element_in_bytes += MMUL_N0 * sizeof(DATA_TYPE);
    }
    // Threads whose unclamped block lies outside the output have nothing to store
    if(dst_x_unclamped >= N || dst_y_unclamped >= M)
    {
        return;
    }

#if defined(HALF_PRECISION)
    TILE(DATA_TYPE, M0, N0, c);

    // Convert the f32 accumulators to the half precision output type
    LOOP_UNROLLING(int, m0, 0, 1, M0,
    {
        LOOP_UNROLLING(int, n0, 0, 1, N0,
        {
            c[m0].s[n0] = c_f32[m0].s[n0];
        })
    })
#else // defined(HALF_PRECISION)
#define c c_f32
#endif // defined(HALF_PRECISION)

#if defined(BIAS)
    perform_bias_addition(bias_ptr, bias_offset_first_element_in_bytes, c, dst_x);
#endif // defined(BIAS)

    // Store full N0-wide vectors when there is no leftover on x; otherwise store partial vectors
    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
    {
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            if(dst_y + m0 < M || M0_LEFTOVER == 0)
            {
                VSTORE(N0)
                (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
            }
        })
    }
    else
    {
        LOOP_UNROLLING(int, m0, 0, 1, M0,
        {
            if(dst_y + m0 < M || M0_LEFTOVER == 0)
            {
                VSTORE_PARTIAL(N0, N0_LEFTOVER)
                (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
            }
        })
    }
#undef MMUL_BLOCK_SIZE
}
#endif // defined(MAT_MUL_NATIVE_MMUL_T_T)
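// For bring-up and debugging it can help to validate the MMUL kernels above against a naive
// reference on small shapes. The sketch below is not part of the original kernel set: the
// MAT_MUL_VALIDATION_REFERENCE guard and the kernel name are hypothetical, and it assumes
// tightly packed row-major, non-batched buffers with no bias, reusing only the DATA_TYPE,
// M, N and K build options.
#if defined(MAT_MUL_VALIDATION_REFERENCE)
// Launch with global size { N, M }: one thread computes one dst element.
__kernel void mat_mul_reference_nt_nt(__global DATA_TYPE *lhs,
                                      __global DATA_TYPE *rhs,
                                      __global DATA_TYPE *dst)
{
    const uint x = get_global_id(0); // output column, in [0, N)
    const uint y = get_global_id(1); // output row, in [0, M)

    // Accumulate in single precision, as the MMUL kernels above do
    float acc = 0.f;
    for(uint k = 0; k < K; ++k)
    {
        acc += (float)lhs[y * K + k] * (float)rhs[k * N + x];
    }

    dst[y * N + x] = (DATA_TYPE)acc;
}
#endif // defined(MAT_MUL_VALIDATION_REFERENCE)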