27 #if defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE)
29 #define CONCAT(a, b) a##b
31 #define ARM_DOT1(a, b, c) \
35 #define ARM_DOT2(a, b, c) \
37 c = fma(a.s0, b.s0, c); \
38 c = fma(a.s1, b.s1, c); \
40 #define ARM_DOT3(a, b, c) \
43 c = fma((a.s2), (b.s2), c); \
45 #define ARM_DOT4(a, b, c) \
48 c = fma((a.s3), (b.s3), c); \
50 #define ARM_DOT8(a, b, c) \
52 ARM_DOT4((a.lo), (b.lo), c); \
53 ARM_DOT4((a.hi), (b.hi), c); \
55 #define ARM_DOT16(a, b, c) \
57 ARM_DOT8((a.lo), (b.lo), c); \
58 ARM_DOT8((a.hi), (b.hi), c); \
62 #define ARM_DOT_K0XN0(k0, a, b, c) \
65 ((a), (b##0), (c.s0)); \
67 ((a), (b##1), (c.s1)); \
69 #elif N0 == 3 // N0 == 3
70 #define ARM_DOT_K0XN0(k0, a, b, c) \
73 ((a), (b##0), (c.s0)); \
75 ((a), (b##1), (c.s1)); \
77 ((a), (b##2), (c.s2)); \
79 #elif N0 == 4 // N0 == 4
80 #define ARM_DOT_K0XN0(k0, a, b, c) \
83 ((a), (b##0), (c.s0)); \
85 ((a), (b##1), (c.s1)); \
87 ((a), (b##2), (c.s2)); \
89 ((a), (b##3), (c.s3)); \
91 #elif N0 == 8 // N0 == 8
92 #define ARM_DOT_K0XN0(k0, a, b, c) \
95 ((a), (b##0), (c.s0)); \
97 ((a), (b##1), (c.s1)); \
99 ((a), (b##2), (c.s2)); \
100 CONCAT(ARM_DOT, k0) \
101 ((a), (b##3), (c.s3)); \
102 CONCAT(ARM_DOT, k0) \
103 ((a), (b##4), (c.s4)); \
104 CONCAT(ARM_DOT, k0) \
105 ((a), (b##5), (c.s5)); \
106 CONCAT(ARM_DOT, k0) \
107 ((a), (b##6), (c.s6)); \
108 CONCAT(ARM_DOT, k0) \
109 ((a), (b##7), (c.s7)); \
111 #elif N0 == 16 // N0 == 16
112 #define ARM_DOT_K0XN0(k0, a, b, c) \
114 CONCAT(ARM_DOT, k0) \
115 ((a), (b##0), (c.s0)); \
116 CONCAT(ARM_DOT, k0) \
117 ((a), (b##1), (c.s1)); \
118 CONCAT(ARM_DOT, k0) \
119 ((a), (b##2), (c.s2)); \
120 CONCAT(ARM_DOT, k0) \
121 ((a), (b##3), (c.s3)); \
122 CONCAT(ARM_DOT, k0) \
123 ((a), (b##4), (c.s4)); \
124 CONCAT(ARM_DOT, k0) \
125 ((a), (b##5), (c.s5)); \
126 CONCAT(ARM_DOT, k0) \
127 ((a), (b##6), (c.s6)); \
128 CONCAT(ARM_DOT, k0) \
129 ((a), (b##7), (c.s7)); \
130 CONCAT(ARM_DOT, k0) \
131 ((a), (b##8), (c.s8)); \
132 CONCAT(ARM_DOT, k0) \
133 ((a), (b##9), (c.s9)); \
134 CONCAT(ARM_DOT, k0) \
135 ((a), (b##A), (c.sA)); \
136 CONCAT(ARM_DOT, k0) \
137 ((a), (b##B), (c.sB)); \
138 CONCAT(ARM_DOT, k0) \
139 ((a), (b##C), (c.sC)); \
140 CONCAT(ARM_DOT, k0) \
141 ((a), (b##D), (c.sD)); \
142 CONCAT(ARM_DOT, k0) \
143 ((a), (b##E), (c.sE)); \
144 CONCAT(ARM_DOT, k0) \
145 ((a), (b##F), (c.sF)); \
147 #else // N0 not supported
148 #error "N0 value not supported"
149 #endif // N0 conditions
151 #if defined(GEMM_MM_RESHAPED_ONLY_RHS_T)
225 #
if defined(REINTERPRET_INPUT_AS_3D)
227 uint lhs_cross_plane_pad
229 #
if defined(REINTERPRET_OUTPUT_AS_3D)
231 uint dst_cross_plane_pad
239 #define RHS_BLOCK_SIZE ((K0) * (N0))
242 #if defined(RHS_INTERLEAVE)
243 #define RHS_OFFSET_X (K0)
244 #define RHS_STEP_X ((K0) * (H0))
245 #define RHS_STEP_LOOP (1)
246 #else // defined(RHS_INTERLEAVE)
247 #define RHS_OFFSET_X (RHS_BLOCK_SIZE)
248 #define RHS_STEP_X (K0)
249 #define RHS_STEP_LOOP (H0)
250 #endif // defined(RHS_INTERLEAVE)
252 uint x = get_global_id(0);
253 uint y = get_global_id(1);
254 uint z = get_global_id(2);
256 const bool cond_y = y == 0;
257 const bool cond_x = ((x + 1) * N0 >=
N);
259 #if defined(DUMMY_WORK_ITEMS)
260 if((x * N0 >=
N) || (y * M0 >=
M))
264 #endif // defined(DUMMY_WORK_ITEMS)
270 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X *
sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
272 #if defined(MATRIX_B_DEPTH)
274 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
275 #else // defined(MATRIX_B_DEPTH)
276 rhs_offset += z * rhs_stride_z;
277 #endif // defined(MATRIX_B_DEPTH)
282 #if defined(REINTERPRET_INPUT_AS_3D)
288 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
290 #else // defined(REINTERPRET_INPUT_AS_3D)
293 lhs_offset += z * lhs_stride_z;
295 #endif // defined(REINTERPRET_INPUT_AS_3D)
301 for(; i <= (
K - K0); i += K0)
313 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
316 LOAD_BLOCK(N0, K0, DATA_TYPE,
b, rhs_ptr, rhs_offset, RHS_STEP_X *
sizeof(DATA_TYPE), zero);
319 ARM_DOT_K0XN0(K0, a0,
b, c0);
321 ARM_DOT_K0XN0(K0, a1,
b, c1);
324 ARM_DOT_K0XN0(K0, a2,
b, c2);
327 ARM_DOT_K0XN0(K0, a3,
b, c3);
330 ARM_DOT_K0XN0(K0, a4,
b, c4);
333 ARM_DOT_K0XN0(K0, a5,
b, c5);
336 ARM_DOT_K0XN0(K0, a6,
b, c6);
339 ARM_DOT_K0XN0(K0, a7,
b, c7);
342 lhs_offset += K0 *
sizeof(DATA_TYPE);
343 rhs_offset += (N0 * RHS_STEP_X * RHS_STEP_LOOP) *
sizeof(DATA_TYPE);
350 LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
353 LOAD_BLOCK(N0, 1, DATA_TYPE,
b, rhs_ptr, rhs_offset, RHS_STEP_X *
sizeof(DATA_TYPE), zero);
356 ARM_DOT_K0XN0(1, a0,
b, c0);
358 ARM_DOT_K0XN0(1, a1,
b, c1);
361 ARM_DOT_K0XN0(1, a2,
b, c2);
364 ARM_DOT_K0XN0(1, a3,
b, c3);
367 ARM_DOT_K0XN0(1, a4,
b, c4);
370 ARM_DOT_K0XN0(1, a5,
b, c5);
373 ARM_DOT_K0XN0(1, a6,
b, c6);
376 ARM_DOT_K0XN0(1, a7,
b, c7);
379 lhs_offset +=
sizeof(DATA_TYPE);
380 rhs_offset +=
sizeof(DATA_TYPE);
387 #if defined(REINTERPRET_OUTPUT_AS_3D)
394 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
396 #else // defined(REINTERPRET_OUTPUT_AS_3D)
399 dst_addr += z * dst_stride_z;
401 #endif // defined(REINTERPRET_OUTPUT_AS_3D)
406 #endif // defined(ALPHA)
410 #if defined(BROADCAST_BIAS)
411 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(DATA_TYPE));
413 LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE,
bias, bias_addr, 0, bias_stride_y, zero, 1,
PARTIAL_STORE_N0,
false, cond_x);
422 #else // defined(BROADCAST_BIAS)
423 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 *
sizeof(DATA_TYPE)) + (
COMPUTE_M0_START_ROW(y, M0,
PARTIAL_STORE_M0) * bias_stride_y) + z * bias_stride_z;
425 LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE,
bias, bias_addr, 0, bias_stride_y, zero,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
434 #endif // defined(BROADCAST_BIAS)
435 #endif // defined(BETA)
437 #if defined(ACTIVATION_TYPE)
439 #endif // defined(ACTIVATION_TYPE)
442 STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
444 #undef RHS_BLOCK_SIZE
449 #endif // defined(GEMM_MM_RESHAPED_ONLY_RHS_T)
451 #if defined(OPENCL_IMAGE_SUPPORT) && defined(GEMM_MM_RESHAPED_ONLY_RHS_T_TEXTURE)
513 __read_only image2d_t rhs_img,
524 #
if defined(REINTERPRET_INPUT_AS_3D)
526 uint lhs_cross_plane_pad
528 #
if defined(REINTERPRET_OUTPUT_AS_3D)
530 uint dst_cross_plane_pad
538 #define PIXEL_UNIT CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(K0)
540 const uint LEFTOVER_K =
K % K0;
543 #define RHS_BLOCK_SIZE (PIXEL_UNIT * (N0))
546 #if defined(RHS_INTERLEAVE)
547 #define RHS_OFFSET_X (PIXEL_UNIT)
548 #define RHS_STEP_X (PIXEL_UNIT * (H0))
549 #define RHS_STEP_LOOP (1)
550 #else // defined(RHS_INTERLEAVE)
551 #define RHS_OFFSET_X (RHS_BLOCK_SIZE)
552 #define RHS_STEP_X PIXEL_UNIT
553 #define RHS_STEP_LOOP (H0)
554 #endif // defined(RHS_INTERLEAVE)
556 uint x = get_global_id(0);
557 uint y = get_global_id(1);
558 uint z = get_global_id(2);
560 const bool cond_y = y == 0;
561 const bool cond_x = ((x + 1) * N0 >=
N);
563 #if defined(DUMMY_WORK_ITEMS)
564 if((x * N0 >=
N) || (y * M0 >=
M))
568 #endif // defined(DUMMY_WORK_ITEMS)
573 #if defined(MATRIX_B_DEPTH)
575 const uint z_rhs = (get_global_id(2) % MATRIX_B_DEPTH);
576 #else // defined(MATRIX_B_DEPTH)
577 const uint z_rhs = get_global_id(2);
578 #endif // defined(MATRIX_B_DEPTH)
581 uint x_rhs = (get_global_id(0) % H0) * (uint)RHS_OFFSET_X;
582 const uint y_rhs = (get_global_id(0) / (uint)H0) + z_rhs * RHS_HEIGHT;
587 #if defined(REINTERPRET_INPUT_AS_3D)
593 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
595 #else // defined(REINTERPRET_INPUT_AS_3D)
598 lhs_offset += z * lhs_stride_z;
600 #endif // defined(REINTERPRET_INPUT_AS_3D)
606 for(; i <= (
K - K0); i += K0)
609 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
613 LOAD_TEXTURE2D(N0, PIXEL_UNIT, DATA_TYPE,
b, rhs_img, x_rhs, y_rhs, RHS_STEP_X, 0);
616 ARM_DOT_K0XN0(K0, a0,
b, c0);
618 ARM_DOT_K0XN0(K0, a1,
b, c1);
621 ARM_DOT_K0XN0(K0, a2,
b, c2);
624 ARM_DOT_K0XN0(K0, a3,
b, c3);
627 ARM_DOT_K0XN0(K0, a4,
b, c4);
630 ARM_DOT_K0XN0(K0, a5,
b, c5);
633 ARM_DOT_K0XN0(K0, a6,
b, c6);
636 ARM_DOT_K0XN0(K0, a7,
b, c7);
639 lhs_offset += K0 *
sizeof(DATA_TYPE);
640 x_rhs += N0 * RHS_STEP_X * RHS_STEP_LOOP;
656 union UNION_VEC_TYPE a0 = {.v = 0 };
658 union UNION_VEC_TYPE a1 = {.v = 0 };
661 union UNION_VEC_TYPE a2 = {.v = 0 };
664 union UNION_VEC_TYPE a3 = {.v = 0 };
667 union UNION_VEC_TYPE a4 = {.v = 0 };
670 union UNION_VEC_TYPE a5 = {.v = 0 };
673 union UNION_VEC_TYPE a6 = {.v = 0 };
676 union UNION_VEC_TYPE a7 = {.v = 0 };
682 LOAD_TEXTURE2D(N0, PIXEL_UNIT, DATA_TYPE,
b, rhs_img, x_rhs, y_rhs, RHS_STEP_X, 0);
685 for(
int k = 0; k < LEFTOVER_K; ++k)
687 a0.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zlhs0);
689 a1.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zlhs1);
692 a2.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zlhs2);
695 a3.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zlhs3);
698 a4.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zlhs4);
701 a5.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zlhs5);
704 a6.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zlhs6);
707 a7.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zlhs7);
710 lhs_offset +=
sizeof(DATA_TYPE);
714 ARM_DOT_K0XN0(K0, a0.v,
b, c0);
716 ARM_DOT_K0XN0(K0, a1.v,
b, c1);
719 ARM_DOT_K0XN0(K0, a2.v,
b, c2);
722 ARM_DOT_K0XN0(K0, a3.v,
b, c3);
725 ARM_DOT_K0XN0(K0, a4.v,
b, c4);
728 ARM_DOT_K0XN0(K0, a5.v,
b, c5);
731 ARM_DOT_K0XN0(K0, a6.v,
b, c6);
734 ARM_DOT_K0XN0(K0, a7.v,
b, c7);
742 #if defined(REINTERPRET_OUTPUT_AS_3D)
749 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
751 #else // defined(REINTERPRET_OUTPUT_AS_3D)
754 dst_addr += z * dst_stride_z;
756 #endif // defined(REINTERPRET_OUTPUT_AS_3D)
761 #endif // defined(ALPHA)
765 #if defined(BROADCAST_BIAS)
766 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(DATA_TYPE));
768 LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE,
bias, bias_addr, 0, bias_stride_y, zero, 1,
PARTIAL_STORE_N0,
false, cond_x);
777 #else // defined(BROADCAST_BIAS)
778 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 *
sizeof(DATA_TYPE)) + (
COMPUTE_M0_START_ROW(y, M0,
PARTIAL_STORE_M0) * bias_stride_y) + z * bias_stride_z;
780 LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE,
bias, bias_addr, 0, bias_stride_y, zero,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
789 #endif // defined(BROADCAST_BIAS)
790 #endif // defined(BETA)
792 #if defined(ACTIVATION_TYPE)
794 #endif // defined(ACTIVATION_TYPE)
797 STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
799 #undef RHS_BLOCK_SIZE
805 #endif // defined(OPENCL_IMAGE_SUPPORT) && defined(GEMM_MM_RESHAPED_ONLY_RHS_T_TEXTURE)
807 #define VFMA(a, b, c) \
813 #define VFMA_M0xN0(i, a, b, c) \
815 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
817 #elif M0 == 2 // M0 == 2
818 #define VFMA_M0xN0(i, a, b, c) \
820 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
821 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
823 #elif M0 == 3 // M0 == 3
824 #define VFMA_M0xN0(i, a, b, c) \
826 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
827 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
828 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
830 #elif M0 == 4 // M0 == 4
831 #define VFMA_M0xN0(i, a, b, c) \
833 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
834 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
835 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
836 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
838 #elif M0 == 5 // M0 == 5
839 #define VFMA_M0xN0(i, a, b, c) \
841 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
842 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
843 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
844 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
845 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
847 #elif M0 == 6 // M0 == 6
848 #define VFMA_M0xN0(i, a, b, c) \
850 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
851 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
852 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
853 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
854 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
855 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
857 #elif M0 == 7 // M0 == 7
858 #define VFMA_M0xN0(i, a, b, c) \
860 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
861 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
862 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
863 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
864 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
865 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
866 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
868 #elif M0 == 8 // M0 == 8
869 #define VFMA_M0xN0(i, a, b, c) \
871 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
872 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
873 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
874 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
875 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
876 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
877 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
878 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
880 #else // M0 not supported
881 #error "M0 not supported"
882 #endif // M0 not supported
884 #if defined(GEMM_MM_RESHAPED_ONLY_RHS_NT)
958 #
if defined(REINTERPRET_INPUT_AS_3D)
960 uint lhs_cross_plane_pad
962 #
if defined(REINTERPRET_OUTPUT_AS_3D)
964 uint dst_cross_plane_pad
972 #define RHS_BLOCK_SIZE ((K0) * (N0))
975 #if defined(RHS_INTERLEAVE)
976 #define RHS_OFFSET_X (N0)
977 #define RHS_STEP_X ((N0) * (H0))
978 #define RHS_STEP_LOOP (1)
979 #else // defined(RHS_INTERLEAVE)
980 #define RHS_OFFSET_X (RHS_BLOCK_SIZE)
981 #define RHS_STEP_X (N0)
982 #define RHS_STEP_LOOP (H0)
983 #endif // defined(RHS_INTERLEAVE)
985 uint x = get_global_id(0);
986 uint y = get_global_id(1);
987 uint z = get_global_id(2);
989 const bool cond_y = y == 0;
990 const bool cond_x = ((x + 1) * N0 >=
N);
992 #if defined(DUMMY_WORK_ITEMS)
993 if((x * N0 >=
N) || (y * M0 >=
M))
997 #endif // defined(DUMMY_WORK_ITEMS)
1003 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X *
sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1005 #if defined(MATRIX_B_DEPTH)
1007 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1008 #else // defined(MATRIX_B_DEPTH)
1009 rhs_offset += z * rhs_stride_z;
1010 #endif // defined(MATRIX_B_DEPTH)
1015 #if defined(REINTERPRET_INPUT_AS_3D)
1022 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1024 #else // defined(REINTERPRET_INPUT_AS_3D)
1027 lhs_offset += z * lhs_stride_z;
1029 #endif // defined(REINTERPRET_INPUT_AS_3D)
1035 for(; i <= (
K - K0); i += K0)
1047 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zin);
1052 b0 =
VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * RHS_STEP_X *
sizeof(DATA_TYPE)));
1053 VFMA_M0xN0(0, a, b0, c);
1054 b0 =
VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 1 * RHS_STEP_X *
sizeof(DATA_TYPE)));
1055 VFMA_M0xN0(1, a, b0, c);
1057 b0 =
VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 2 * RHS_STEP_X *
sizeof(DATA_TYPE)));
1058 VFMA_M0xN0(2, a, b0, c);
1061 b0 =
VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 3 * RHS_STEP_X *
sizeof(DATA_TYPE)));
1062 VFMA_M0xN0(3, a, b0, c);
1065 b0 =
VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 4 * RHS_STEP_X *
sizeof(DATA_TYPE)));
1066 VFMA_M0xN0(4, a, b0, c);
1067 b0 =
VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 5 * RHS_STEP_X *
sizeof(DATA_TYPE)));
1068 VFMA_M0xN0(5, a, b0, c);
1069 b0 =
VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 6 * RHS_STEP_X *
sizeof(DATA_TYPE)));
1070 VFMA_M0xN0(6, a, b0, c);
1071 b0 =
VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 7 * RHS_STEP_X *
sizeof(DATA_TYPE)));
1072 VFMA_M0xN0(7, a, b0, c);
1075 b0 =
VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 8 * RHS_STEP_X *
sizeof(DATA_TYPE)));
1076 VFMA_M0xN0(8, a, b0, c);
1077 b0 =
VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 9 * RHS_STEP_X *
sizeof(DATA_TYPE)));
1078 VFMA_M0xN0(9, a, b0, c);
1079 b0 =
VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 10 * RHS_STEP_X *
sizeof(DATA_TYPE)));
1080 VFMA_M0xN0(A, a, b0, c);
1081 b0 =
VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 11 * RHS_STEP_X *
sizeof(DATA_TYPE)));
1082 VFMA_M0xN0(B, a, b0, c);
1083 b0 =
VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 12 * RHS_STEP_X *
sizeof(DATA_TYPE)));
1084 VFMA_M0xN0(C, a, b0, c);
1085 b0 =
VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 13 * RHS_STEP_X *
sizeof(DATA_TYPE)));
1086 VFMA_M0xN0(D, a, b0, c);
1087 b0 =
VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 14 * RHS_STEP_X *
sizeof(DATA_TYPE)));
1088 VFMA_M0xN0(E, a, b0, c);
1089 b0 =
VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 15 * RHS_STEP_X *
sizeof(DATA_TYPE)));
1090 VFMA_M0xN0(F, a, b0, c);
1093 lhs_offset += K0 *
sizeof(DATA_TYPE);
1094 rhs_offset += K0 * RHS_STEP_X * RHS_STEP_LOOP *
sizeof(DATA_TYPE);
1102 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
1105 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
1109 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
1113 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
1117 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
1121 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
1125 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
1129 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
1135 b0 =
VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * RHS_STEP_X *
sizeof(DATA_TYPE)));
1136 VFMA_M0xN0(0, a, b0, c);
1138 lhs_offset +=
sizeof(DATA_TYPE);
1139 rhs_offset += RHS_STEP_X *
sizeof(DATA_TYPE);
1146 #if defined(REINTERPRET_OUTPUT_AS_3D)
1152 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1154 #else // defined(REINTERPRET_OUTPUT_AS_3D)
1157 dst_addr += z * dst_stride_z;
1159 #endif // defined(REINTERPRET_OUTPUT_AS_3D)
1164 #endif // defined(ALPHA)
1168 #if defined(BROADCAST_BIAS)
1169 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(DATA_TYPE));
1171 LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE,
bias, bias_addr, 0, bias_stride_y, zero, 1,
PARTIAL_STORE_N0,
false, cond_x);
1180 #else // defined(BROADCAST_BIAS)
1181 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 *
sizeof(DATA_TYPE)) + (
COMPUTE_M0_START_ROW(y, M0,
PARTIAL_STORE_M0) * bias_stride_y) + z * bias_stride_z;
1183 LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE,
bias, bias_addr, 0, bias_stride_y, zero,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
1192 #endif // defined(BROADCAST_BIAS)
1193 #endif // defined(BETA)
1195 #if defined(ACTIVATION_TYPE)
1197 #endif // defined(ACTIVATION_TYPE)
1200 STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
1202 #undef RHS_BLOCK_SIZE
1205 #undef RHS_STEP_LOOP
1207 #endif // defined(GEMM_MM_RESHAPED_ONLY_RHS_NT)
1209 #if defined(OPENCL_IMAGE_SUPPORT) && defined(GEMM_MM_RESHAPED_ONLY_RHS_NT_TEXTURE)
1271 __read_only image2d_t rhs_img,
1282 #
if defined(REINTERPRET_INPUT_AS_3D)
1284 uint lhs_cross_plane_pad
1286 #
if defined(REINTERPRET_OUTPUT_AS_3D)
1288 uint dst_cross_plane_pad
1296 #define PIXEL_UNIT CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(N0)
1299 #define RHS_BLOCK_SIZE ((K0) * (PIXEL_UNIT))
1302 #if defined(RHS_INTERLEAVE)
1303 #define RHS_OFFSET_X (PIXEL_UNIT)
1304 #define RHS_STEP_X ((PIXEL_UNIT) * (H0))
1305 #define RHS_STEP_LOOP 1
1306 #else // defined(RHS_INTERLEAVE)
1307 #define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1308 #define RHS_STEP_X (PIXEL_UNIT)
1309 #define RHS_STEP_LOOP (H0)
1310 #endif // defined(RHS_INTERLEAVE)
1312 uint x = get_global_id(0);
1313 uint y = get_global_id(1);
1314 uint z = get_global_id(2);
1316 const bool cond_y = y == 0;
1317 const bool cond_x = ((x + 1) * N0 >=
N);
1319 #if defined(DUMMY_WORK_ITEMS)
1320 if((x * N0 >=
N) || (y * M0 >=
M))
1324 #endif // defined(DUMMY_WORK_ITEMS)
1329 #if defined(MATRIX_B_DEPTH)
1331 const uint z_rhs = (z % MATRIX_B_DEPTH);
1332 #else // defined(MATRIX_B_DEPTH)
1333 const uint z_rhs = z;
1334 #endif // defined(MATRIX_B_DEPTH)
1337 uint x_rhs = (x % H0) * (uint)RHS_OFFSET_X;
1338 const uint y_rhs = (x / (uint)H0) + z_rhs * RHS_HEIGHT;
1343 #if defined(REINTERPRET_INPUT_AS_3D)
1350 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1352 #else // defined(REINTERPRET_INPUT_AS_3D)
1355 lhs_offset += z * lhs_stride_z;
1357 #endif // defined(REINTERPRET_INPUT_AS_3D)
1363 for(; i <= (
K - K0); i += K0)
1366 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zin);
1371 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 0 * RHS_STEP_X), (y_rhs));
1372 VFMA_M0xN0(0, a, b0, c);
1373 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 1 * RHS_STEP_X), (y_rhs));
1374 VFMA_M0xN0(1, a, b0, c);
1376 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 2 * RHS_STEP_X), (y_rhs));
1377 VFMA_M0xN0(2, a, b0, c);
1380 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 3 * RHS_STEP_X), (y_rhs));
1381 VFMA_M0xN0(3, a, b0, c);
1384 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 4 * RHS_STEP_X), (y_rhs));
1385 VFMA_M0xN0(4, a, b0, c);
1386 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 5 * RHS_STEP_X), (y_rhs));
1387 VFMA_M0xN0(5, a, b0, c);
1388 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 6 * RHS_STEP_X), (y_rhs));
1389 VFMA_M0xN0(6, a, b0, c);
1390 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 7 * RHS_STEP_X), (y_rhs));
1391 VFMA_M0xN0(7, a, b0, c);
1394 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 8 * RHS_STEP_X), (y_rhs));
1395 VFMA_M0xN0(8, a, b0, c);
1396 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 9 * RHS_STEP_X), (y_rhs));
1397 VFMA_M0xN0(9, a, b0, c);
1398 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 10 * RHS_STEP_X), (y_rhs));
1399 VFMA_M0xN0(A, a, b0, c);
1400 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 11 * RHS_STEP_X), (y_rhs));
1401 VFMA_M0xN0(B, a, b0, c);
1402 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 12 * RHS_STEP_X), (y_rhs));
1403 VFMA_M0xN0(C, a, b0, c);
1404 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 13 * RHS_STEP_X), (y_rhs));
1405 VFMA_M0xN0(D, a, b0, c);
1406 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 14 * RHS_STEP_X), (y_rhs));
1407 VFMA_M0xN0(E, a, b0, c);
1408 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 15 * RHS_STEP_X), (y_rhs));
1409 VFMA_M0xN0(F, a, b0, c);
1412 lhs_offset += K0 *
sizeof(DATA_TYPE);
1413 x_rhs += K0 * RHS_STEP_X * RHS_STEP_LOOP;
1421 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
1424 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
1428 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
1432 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
1436 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
1440 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
1444 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
1448 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
1453 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 0 * RHS_STEP_X), (y_rhs));
1455 VFMA_M0xN0(0, a, b0, c);
1457 lhs_offset +=
sizeof(DATA_TYPE);
1458 x_rhs += RHS_STEP_X;
1465 #if defined(REINTERPRET_OUTPUT_AS_3D)
1471 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1473 #else // defined(REINTERPRET_OUTPUT_AS_3D)
1476 dst_addr += z * dst_stride_z;
1478 #endif // defined(REINTERPRET_OUTPUT_AS_3D)
1483 #endif // defined(ALPHA)
1487 #if defined(BROADCAST_BIAS)
1488 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(DATA_TYPE));
1490 LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE,
bias, bias_addr, 0, bias_stride_y, zero, 1,
PARTIAL_STORE_N0,
false, cond_x);
1499 #else // defined(BROADCAST_BIAS)
1500 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 *
sizeof(DATA_TYPE)) + (
COMPUTE_M0_START_ROW(y, M0,
PARTIAL_STORE_M0) * bias_stride_y) + z * bias_stride_z;
1502 LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE,
bias, bias_addr, 0, bias_stride_y, zero,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
1511 #endif // defined(BROADCAST_BIAS)
1512 #endif // defined(BETA)
1514 #if defined(ACTIVATION_TYPE)
1516 #endif // defined(ACTIVATION_TYPE)
1519 STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
1521 #undef RHS_BLOCK_SIZE
1524 #undef RHS_STEP_LOOP
1526 #endif // defined(OPENCL_IMAGE_SUPPORT) && defined(GEMM_MM_RESHAPED_ONLY_RHS_NT_TEXTURE)
1527 #endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE)
1529 #if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(DATA_TYPE_ACCUMULATOR)
1531 #if defined(MIXED_PRECISION)
1533 #define ARM_DOT_K0(a, b, c) \
1538 #elif K0 == 3 // K0 == 3
1539 #define ARM_DOT_K0(a, b, c) \
1545 #elif K0 == 4 // K0 == 4
1546 #define ARM_DOT_K0(a, b, c) \
1553 #elif K0 == 8 // K0 == 8
1554 #define ARM_DOT_K0(a, b, c) \
1565 #elif K0 == 16 // K0 == 16
1566 #define ARM_DOT_K0(a, b, c) \
1585 #else // K0 not supported
1586 #error "K0 value not supported"
1587 #endif // K0 conditions
1588 #else // defined(MIXED_PRECISION)
1590 #define ARM_DOT_K0(a, b, c) \
1592 c = fma(a.s0, b.s0, c); \
1593 c = fma(a.s1, b.s1, c); \
1595 #elif K0 == 3 // K0 == 3
1596 #define ARM_DOT_K0(a, b, c) \
1598 c = fma(a.s0, b.s0, c); \
1599 c = fma(a.s1, b.s1, c); \
1600 c = fma(a.s2, b.s2, c); \
1602 #elif K0 == 4 // K0 == 4
1603 #define ARM_DOT_K0(a, b, c) \
1605 c = fma(a.s0, b.s0, c); \
1606 c = fma(a.s1, b.s1, c); \
1607 c = fma(a.s2, b.s2, c); \
1608 c = fma(a.s3, b.s3, c); \
1610 #elif K0 == 8 // K0 == 8
1611 #define ARM_DOT_K0(a, b, c) \
1613 c = fma(a.s0, b.s0, c); \
1614 c = fma(a.s1, b.s1, c); \
1615 c = fma(a.s2, b.s2, c); \
1616 c = fma(a.s3, b.s3, c); \
1617 c = fma(a.s4, b.s4, c); \
1618 c = fma(a.s5, b.s5, c); \
1619 c = fma(a.s6, b.s6, c); \
1620 c = fma(a.s7, b.s7, c); \
1622 #elif K0 == 16 // K0 == 16
1623 #define ARM_DOT_K0(a, b, c) \
1625 c = fma(a.s0, b.s0, c); \
1626 c = fma(a.s1, b.s1, c); \
1627 c = fma(a.s2, b.s2, c); \
1628 c = fma(a.s3, b.s3, c); \
1629 c = fma(a.s4, b.s4, c); \
1630 c = fma(a.s5, b.s5, c); \
1631 c = fma(a.s6, b.s6, c); \
1632 c = fma(a.s7, b.s7, c); \
1633 c = fma(a.s8, b.s8, c); \
1634 c = fma(a.s9, b.s9, c); \
1635 c = fma(a.sA, b.sA, c); \
1636 c = fma(a.sB, b.sB, c); \
1637 c = fma(a.sC, b.sC, c); \
1638 c = fma(a.sD, b.sD, c); \
1639 c = fma(a.sE, b.sE, c); \
1640 c = fma(a.sF, b.sF, c); \
1642 #else // K0 not supported
1643 #error "K0 value not supported"
1644 #endif // K0 conditions
1645 #endif // defined(MIXED_PRECISION)
1647 #if defined(ARM_DOT_K0XN0)
1648 #undef ARM_DOT_K0XN0
1649 #endif // defined(ARM_DOT_K0XN0)
1652 #define ARM_DOT_K0XN0(a, b, c) \
1654 ARM_DOT_K0((a), (b##0), (c.s0)); \
1655 ARM_DOT_K0((a), (b##1), (c.s1)); \
1657 #elif N0 == 3 // N0 == 3
1658 #define ARM_DOT_K0XN0(a, b, c) \
1660 ARM_DOT_K0((a), (b##0), (c.s0)); \
1661 ARM_DOT_K0((a), (b##1), (c.s1)); \
1662 ARM_DOT_K0((a), (b##2), (c.s2)); \
1664 #elif N0 == 4 // N0 == 4
1665 #define ARM_DOT_K0XN0(a, b, c) \
1667 ARM_DOT_K0((a), (b##0), (c.s0)); \
1668 ARM_DOT_K0((a), (b##1), (c.s1)); \
1669 ARM_DOT_K0((a), (b##2), (c.s2)); \
1670 ARM_DOT_K0((a), (b##3), (c.s3)); \
1672 #elif N0 == 8 // N0 == 8
1673 #define ARM_DOT_K0XN0(a, b, c) \
1675 ARM_DOT_K0((a), (b##0), (c.s0)); \
1676 ARM_DOT_K0((a), (b##1), (c.s1)); \
1677 ARM_DOT_K0((a), (b##2), (c.s2)); \
1678 ARM_DOT_K0((a), (b##3), (c.s3)); \
1679 ARM_DOT_K0((a), (b##4), (c.s4)); \
1680 ARM_DOT_K0((a), (b##5), (c.s5)); \
1681 ARM_DOT_K0((a), (b##6), (c.s6)); \
1682 ARM_DOT_K0((a), (b##7), (c.s7)); \
1684 #elif N0 == 16 // N0 == 16
1685 #define ARM_DOT_K0XN0(a, b, c) \
1687 ARM_DOT_K0((a), (b##0), (c.s0)); \
1688 ARM_DOT_K0((a), (b##1), (c.s1)); \
1689 ARM_DOT_K0((a), (b##2), (c.s2)); \
1690 ARM_DOT_K0((a), (b##3), (c.s3)); \
1691 ARM_DOT_K0((a), (b##4), (c.s4)); \
1692 ARM_DOT_K0((a), (b##5), (c.s5)); \
1693 ARM_DOT_K0((a), (b##6), (c.s6)); \
1694 ARM_DOT_K0((a), (b##7), (c.s7)); \
1695 ARM_DOT_K0((a), (b##8), (c.s8)); \
1696 ARM_DOT_K0((a), (b##9), (c.s9)); \
1697 ARM_DOT_K0((a), (b##A), (c.sA)); \
1698 ARM_DOT_K0((a), (b##B), (c.sB)); \
1699 ARM_DOT_K0((a), (b##C), (c.sC)); \
1700 ARM_DOT_K0((a), (b##D), (c.sD)); \
1701 ARM_DOT_K0((a), (b##E), (c.sE)); \
1702 ARM_DOT_K0((a), (b##F), (c.sF)); \
1704 #else // N0 not supported
1705 #error "N0 value not supported"
1706 #endif // N0 conditions
1708 #if defined(GEMM_MM_RESHAPED_LHS_NT_RHS_T)
1785 #
if defined(REINTERPRET_OUTPUT_AS_3D)
1787 uint dst_cross_plane_pad
1795 #define LHS_BLOCK_SIZE ((K0) * (M0))
1797 #if defined(LHS_INTERLEAVE)
1798 #define LHS_OFFSET_X (K0)
1799 #define LHS_STEP_X ((K0) * (V0))
1800 #define LHS_STEP_LOOP (1)
1801 #else // defined(INTERLEAVE)
1802 #define LHS_OFFSET_X (LHS_BLOCK_SIZE)
1803 #define LHS_STEP_X (K0)
1804 #define LHS_STEP_LOOP (V0)
1805 #endif // defined(INTERLEAVE)
1808 #define RHS_BLOCK_SIZE ((K0) * (N0))
1811 #if defined(RHS_INTERLEAVE)
1812 #define RHS_OFFSET_X (K0)
1813 #define RHS_STEP_X ((K0) * (H0))
1814 #define RHS_STEP_LOOP (1)
1815 #else // defined(RHS_INTERLEAVE)
1816 #define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1817 #define RHS_STEP_X (K0)
1818 #define RHS_STEP_LOOP (H0)
1819 #endif // defined(RHS_INTERLEAVE)
1821 #if defined(DUMMY_WORK_ITEMS)
1822 if((get_global_id(0) * N0 >=
N) || (get_global_id(1) * M0 >=
M))
1826 #endif // defined(DUMMY_WORK_ITEMS)
1829 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X *
sizeof(DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
1830 (get_global_id(2) * lhs_stride_z);
1833 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (get_global_id(0) % H0) * (uint)RHS_OFFSET_X *
sizeof(DATA_TYPE) + (get_global_id(0) / (uint)H0) * rhs_stride_y;
1835 #if defined(MATRIX_B_DEPTH)
1837 rhs_addr += (get_global_id(2) % MATRIX_B_DEPTH) * rhs_stride_z;
1838 #else // defined(MATRIX_B_DEPTH)
1839 rhs_addr += get_global_id(2) * rhs_stride_z;
1840 #endif // defined(MATRIX_B_DEPTH)
1848 for(
int i = 0; i <
K; i += K0)
1860 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X *
sizeof(DATA_TYPE), zlhs);
1863 LOAD_BLOCK(N0, K0, DATA_TYPE,
b, rhs_addr, 0, RHS_STEP_X *
sizeof(DATA_TYPE), zero);
1866 ARM_DOT_K0XN0(a0,
b, c0);
1868 ARM_DOT_K0XN0(a1,
b, c1);
1871 ARM_DOT_K0XN0(a2,
b, c2);
1874 ARM_DOT_K0XN0(a3,
b, c3);
1877 ARM_DOT_K0XN0(a4,
b, c4);
1880 ARM_DOT_K0XN0(a5,
b, c5);
1883 ARM_DOT_K0XN0(a6,
b, c6);
1886 ARM_DOT_K0XN0(a7,
b, c7);
1889 lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) *
sizeof(DATA_TYPE);
1890 rhs_addr += (N0 * RHS_STEP_X * RHS_STEP_LOOP) *
sizeof(DATA_TYPE);
1893 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
1897 const bool cond_y = ((get_global_id(1) + 1) * M0 >=
M);
1898 const bool cond_x = ((get_global_id(0) + 1) * N0 >=
N);
1900 #if defined(REINTERPRET_OUTPUT_AS_3D)
1903 CALCULATE_Z_OFFSET(M0, uint, zout, get_global_id(1) * (uint)M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
1906 dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
1908 #else // defined(REINTERPRET_OUTPUT_AS_3D)
1911 dst_addr += get_global_id(2) * dst_stride_z;
1913 #endif // defined(REINTERPRET_OUTPUT_AS_3D)
1918 #endif // defined(ALPHA)
1922 #if defined(BROADCAST_BIAS)
1923 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(DATA_TYPE));
1925 LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE,
bias, bias_addr, 0, bias_stride_y, zero, 1,
PARTIAL_STORE_N0,
false, cond_x);
1932 #if defined(MIXED_PRECISION)
1935 #else // defined(MIXED_PRECISION)
1937 #endif // defined(MIXED_PRECISION)
1939 #else // defined(BROADCAST_BIAS)
1940 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
1943 LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE,
bias, bias_addr, 0, bias_stride_y, zero,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
1950 #if defined(MIXED_PRECISION)
1953 #else // defined(MIXED_PRECISION)
1955 #endif // defined(MIXED_PRECISION)
1957 #endif // defined(BROADCAST_BIAS)
1958 #endif // defined(BETA)
1960 #if defined(ACTIVATION_TYPE)
1961 #if defined(MIXED_PRECISION)
1962 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE_ACCUMULATOR, N0, c, A_VAL, B_VAL);
1963 #else // defined(MIXED_PRECISION)
1965 #endif // defined(MIXED_PRECISION)
1966 #endif // defined(ACTIVATION_TYPE)
1969 #if defined(MIXED_PRECISION)
1971 STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c_lp, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
1972 #else // defined(MIXED_PRECISION)
1973 STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
1974 #endif // defined(MIXED_PRECISION)
1976 #undef LHS_BLOCK_SIZE
1979 #undef RHS_BLOCK_SIZE
1982 #undef LHS_STEP_LOOP
1983 #undef RHS_STEP_LOOP
1985 #endif // defined(GEMM_MM_RESHAPED_LHS_NT_RHS_T)
1987 #if defined(OPENCL_IMAGE_SUPPORT) && defined(GEMM_MM_RESHAPED_LHS_NT_RHS_T_TEXTURE)
2052 __read_only image2d_t rhs_img,
2063 #
if defined(REINTERPRET_OUTPUT_AS_3D)
2065 uint dst_cross_plane_pad
// Block-geometry macros for the texture (image2d_t RHS) variant of the
// reshaped LHS(NT) x RHS(T) kernel. RHS x-coordinates are measured in
// image pixels, so K0 elements are first converted to pixel units.
2073 #define PIXEL_UNIT CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(K0)
// LHS layout is identical to the buffer variant: K0 x M0 blocks, V0-way
// interleaved or back-to-back.
2076 #define LHS_BLOCK_SIZE ((K0) * (M0))
2078 #if defined(LHS_INTERLEAVE)
2079 #define LHS_OFFSET_X (K0)
2080 #define LHS_STEP_X ((K0) * (V0))
2081 #define LHS_STEP_LOOP (1)
2082 #else // defined(LHS_INTERLEAVE)
2083 #define LHS_OFFSET_X (LHS_BLOCK_SIZE)
2084 #define LHS_STEP_X (K0)
2085 #define LHS_STEP_LOOP (V0)
2086 #endif // defined(LHS_INTERLEAVE)
// RHS strides below are in PIXEL_UNITs (texture x-axis), not elements.
2089 #define RHS_BLOCK_SIZE (PIXEL_UNIT * (N0))
2092 #if defined(RHS_INTERLEAVE)
2093 #define RHS_OFFSET_X (PIXEL_UNIT)
2094 #define RHS_STEP_X (PIXEL_UNIT * (H0))
2095 #define RHS_STEP_LOOP (1)
2096 #else // defined(RHS_INTERLEAVE)
2097 #define RHS_OFFSET_X (RHS_BLOCK_SIZE)
// NOTE(review): RHS_STEP_X lacks the parentheses used by every sibling
// definition (e.g. line 2084/2093); harmless for the multiplicative uses
// visible here, but worth normalising to (PIXEL_UNIT) — confirm no caller
// relies on token-level expansion.
2098 #define RHS_STEP_X PIXEL_UNIT
2099 #define RHS_STEP_LOOP (H0)
2100 #endif // defined(RHS_INTERLEAVE)
2102 #if defined(DUMMY_WORK_ITEMS)
2103 if((get_global_id(0) * N0 >=
N) || (get_global_id(1) * M0 >=
M))
2107 #endif // defined(DUMMY_WORK_ITEMS)
2110 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X *
sizeof(DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
2111 (get_global_id(2) * lhs_stride_z);
2113 #if defined(MATRIX_B_DEPTH)
2115 const uint z_rhs = (get_global_id(2) % MATRIX_B_DEPTH);
2116 #else // defined(MATRIX_B_DEPTH)
2117 const uint z_rhs = get_global_id(2);
2118 #endif // defined(MATRIX_B_DEPTH)
2121 uint x_rhs = (get_global_id(0) % H0) * (uint)RHS_OFFSET_X;
2122 const uint y_rhs = (get_global_id(0) / (uint)H0) + z_rhs * RHS_HEIGHT;
2130 for(
int i = 0; i <
K; i += K0)
2133 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X *
sizeof(DATA_TYPE), zlhs);
2137 LOAD_TEXTURE2D(N0, PIXEL_UNIT, DATA_TYPE,
b, rhs_img, x_rhs, y_rhs, RHS_STEP_X, 0);
2140 ARM_DOT_K0XN0(a0,
b, c0);
2142 ARM_DOT_K0XN0(a1,
b, c1);
2145 ARM_DOT_K0XN0(a2,
b, c2);
2148 ARM_DOT_K0XN0(a3,
b, c3);
2151 ARM_DOT_K0XN0(a4,
b, c4);
2154 ARM_DOT_K0XN0(a5,
b, c5);
2157 ARM_DOT_K0XN0(a6,
b, c6);
2160 ARM_DOT_K0XN0(a7,
b, c7);
2163 lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) *
sizeof(DATA_TYPE);
2165 x_rhs += N0 * RHS_STEP_X * RHS_STEP_LOOP;
2168 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
2172 const bool cond_y = ((get_global_id(1) + 1) * M0 >=
M);
2173 const bool cond_x = ((get_global_id(0) + 1) * N0 >=
N);
2175 #if defined(REINTERPRET_OUTPUT_AS_3D)
2178 CALCULATE_Z_OFFSET(M0, uint, zout, get_global_id(1) * (uint)M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
2181 dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
2183 #else // defined(REINTERPRET_OUTPUT_AS_3D)
2186 dst_addr += get_global_id(2) * dst_stride_z;
2188 #endif // defined(REINTERPRET_OUTPUT_AS_3D)
2193 #endif // defined(ALPHA)
2197 #if defined(BROADCAST_BIAS)
2198 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(DATA_TYPE));
2200 LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE,
bias, bias_addr, 0, bias_stride_y, zero, 1,
PARTIAL_STORE_N0,
false, cond_x);
2207 #if defined(MIXED_PRECISION)
2210 #else // defined(MIXED_PRECISION)
2212 #endif // defined(MIXED_PRECISION)
2214 #else // defined(BROADCAST_BIAS)
2215 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
2218 LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE,
bias, bias_addr, 0, bias_stride_y, zero,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
2225 #if defined(MIXED_PRECISION)
2228 #else // defined(MIXED_PRECISION)
2230 #endif // defined(MIXED_PRECISION)
2232 #endif // defined(BROADCAST_BIAS)
2233 #endif // defined(BETA)
2235 #if defined(ACTIVATION_TYPE)
2236 #if defined(MIXED_PRECISION)
2237 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE_ACCUMULATOR, N0, c, A_VAL, B_VAL);
2238 #else // defined(MIXED_PRECISION)
2240 #endif // defined(MIXED_PRECISION)
2241 #endif // defined(ACTIVATION_TYPE)
2244 #if defined(MIXED_PRECISION)
2246 STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c_lp, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
2247 #else // defined(MIXED_PRECISION)
2248 STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
2249 #endif // defined(MIXED_PRECISION)
2251 #undef LHS_BLOCK_SIZE
2254 #undef RHS_BLOCK_SIZE
2258 #undef LHS_STEP_LOOP
2259 #undef RHS_STEP_LOOP
2261 #endif // defined(OPENCL_IMAGE_SUPPORT) && defined(GEMM_MM_RESHAPED_LHS_NT_RHS_T_TEXTURE)
2263 #if defined(LHS_TRANSPOSE)
// Shorthand for an OpenCL vector type of SIZE elements of TYPE.
2265 #define VTYPE(TYPE, SIZE) VEC_DATA_TYPE(TYPE, SIZE)
// ARM_VFMA(N0, a, b, c): c += a * b on N0-wide vectors.
// MIXED_PRECISION variant converts both operands to DATA_TYPE_ACCUMULATOR
// before accumulating. On Midgard GPUs a plain multiply-add is emitted
// instead of fma() (fma is slow on that architecture).
2267 #if defined(MIXED_PRECISION)
2269 #if(GPU_ARCH == GPU_ARCH_MIDGARD)
2270 #define ARM_VFMA(N0, a, b, c) c += (CONVERT(a, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))) * (CONVERT(b, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0)));
2271 #else // GPU_ARCH == GPU_ARCH_MIDGARD
2272 #define ARM_VFMA(N0, a, b, c) c = fma((CONVERT(a, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))), (CONVERT(b, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))), (c));
2273 #endif // GPU_ARCH == GPU_ARCH_MIDGARD
2275 #else // defined(MIXED_PRECISION)
2277 #if(GPU_ARCH == GPU_ARCH_MIDGARD)
2278 #define ARM_VFMA(N0, a, b, c) c += (a) * (b);
2279 #else // GPU_ARCH == GPU_ARCH_MIDGARD
2280 #define ARM_VFMA(N0, a, b, c) c = fma((a), (b), (c));
2281 #endif // GPU_ARCH == GPU_ARCH_MIDGARD
2283 #endif // defined(MIXED_PRECISION)
2285 #define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C) \
2287 ARM_VFMA(N0, (VTYPE(TYPE, N0))(a), b, (C##0)); \
2289 #define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C) \
2291 ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s0), b, (C##0)); \
2292 ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s1), b, (C##1)); \
2294 #define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C) \
2296 ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C); \
2297 ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s2), b, (C##2)); \
2299 #define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C) \
2301 ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C); \
2302 ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s3), b, (C##3)); \
2304 #define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C) \
2306 ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C); \
2307 ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s4), b, (C##4)); \
2308 ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s5), b, (C##5)); \
2309 ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s6), b, (C##6)); \
2310 ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s7), b, (C##7)); \
// Dispatch on M0: expands to ARM_VVM_T_NT_<M0>xN0x1, the rank-1 update of
// an M0 x N0 accumulator block C with column vector a and row vector b.
2319 #define ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, a, b, C) ARM_VVM_T_NT_##M0##xN0x1(N0, TYPE, a, b, C)
2321 #define ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, A, B, C) \
2323 ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##0), (B##0), C); \
2325 #define ARM_MM_T_NT_M0xN0x2(M0, N0, TYPE, A, B, C) \
2327 ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, A, B, C); \
2328 ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##1), (B##1), C); \
2330 #define ARM_MM_T_NT_M0xN0x3(M0, N0, TYPE, A, B, C) \
2332 ARM_MM_T_NT_M0xN0x2(M0, N0, TYPE, A, B, C); \
2333 ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##2), (B##2), C); \
2335 #define ARM_MM_T_NT_M0xN0x4(M0, N0, TYPE, A, B, C) \
2337 ARM_MM_T_NT_M0xN0x3(M0, N0, TYPE, A, B, C); \
2338 ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##3), (B##3), C); \
2340 #define ARM_MM_T_NT_M0xN0x8(M0, N0, TYPE, A, B, C) \
2342 ARM_MM_T_NT_M0xN0x4(M0, N0, TYPE, A, B, C); \
2343 ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##4), (B##4), C); \
2344 ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##5), (B##5), C); \
2345 ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##6), (B##6), C); \
2346 ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##7), (B##7), C); \
2348 #define ARM_MM_T_NT_M0xN0x16(M0, N0, TYPE, A, B, C) \
2350 ARM_MM_T_NT_M0xN0x8(M0, N0, TYPE, A, B, C); \
2351 ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##8), (B##8), C); \
2352 ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##9), (B##9), C); \
2353 ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##A), (B##A), C); \
2354 ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##B), (B##B), C); \
2355 ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##C), (B##C), C); \
2356 ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##D), (B##D), C); \
2357 ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##E), (B##E), C); \
2358 ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##F), (B##F), C); \
// ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C): (M0 x N0) += (transposed A) x B
// over K0 accumulation steps. Dispatches on K0 to the unrolled
// ARM_MM_T_NT_M0xN0x<K0> variant via token pasting.
2369 #define ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C) \
2370 CONCAT(ARM_MM_T_NT_M0xN0x, K0) \
2371 (M0, N0, TYPE, A, B, C)
2373 #if defined(GEMM_MM_RESHAPED_LHS_T_RHS_NT)
2448 #
if defined(REINTERPRET_OUTPUT_AS_3D)
2450 uint dst_cross_plane_pad
// Block-geometry macros for the reshaped LHS(T) x RHS(NT) kernel. Because
// the LHS is transposed, the x-stride within a block is M0 (not K0), and
// the RHS non-transposed stride is N0.
2458 #define LHS_BLOCK_SIZE ((K0) * (M0))
2460 #if defined(LHS_INTERLEAVE)
2461 #define LHS_OFFSET_X (M0)
2462 #define LHS_STEP_X ((M0) * (V0))
2463 #define LHS_STEP_LOOP (1)
2464 #else // defined(LHS_INTERLEAVE)
2465 #define LHS_OFFSET_X (LHS_BLOCK_SIZE)
2466 #define LHS_STEP_X (M0)
2467 #define LHS_STEP_LOOP (V0)
2468 #endif // defined(LHS_INTERLEAVE)
2471 #define RHS_BLOCK_SIZE ((K0) * (N0))
2474 #if defined(RHS_INTERLEAVE)
2475 #define RHS_OFFSET_X (N0)
2476 #define RHS_STEP_X ((N0) * (H0))
2477 #else // defined(RHS_INTERLEAVE)
2478 #define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2479 #define RHS_STEP_X (N0)
2480 #endif // defined(RHS_INTERLEAVE)
2482 const uint x = get_global_id(0);
2483 const uint y = get_global_id(1);
2484 const uint z = get_global_id(2);
2486 const bool cond_y = ((get_global_id(1) + 1) * M0 >=
M);
2487 const bool cond_x = ((get_global_id(0) + 1) * N0 >=
N);
2489 #if defined(DUMMY_WORK_ITEMS)
2490 if((x * N0 >=
N) || (y * M0 >=
M))
2494 #endif // defined(DUMMY_WORK_ITEMS)
2497 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (y % V0) * (uint)LHS_OFFSET_X *
sizeof(DATA_TYPE) + (y / V0) * (uint)lhs_stride_y + (z * lhs_stride_z);
2500 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X *
sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
2502 #if defined(MATRIX_B_DEPTH)
2504 rhs_addr += (z % MATRIX_B_DEPTH) * rhs_stride_z;
2505 #else // defined(MATRIX_B_DEPTH)
2506 rhs_addr += z * rhs_stride_z;
2507 #endif // defined(MATRIX_B_DEPTH)
2514 __global DATA_TYPE *lhs = (__global DATA_TYPE *)(lhs_addr);
2515 __global DATA_TYPE *rhs = (__global DATA_TYPE *)(rhs_addr);
2517 for(
int i = 0; i <
K; i += K0)
2524 a0 =
VLOAD(M0)(0, lhs);
2525 b0 =
VLOAD(N0)(0, rhs);
2527 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2533 a0 =
VLOAD(M0)(0, lhs);
2534 b0 =
VLOAD(N0)(0, rhs);
2536 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2543 a0 =
VLOAD(M0)(0, lhs);
2544 b0 =
VLOAD(N0)(0, rhs);
2546 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2553 a0 =
VLOAD(M0)(0, lhs);
2554 b0 =
VLOAD(N0)(0, rhs);
2556 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2563 a0 =
VLOAD(M0)(0, lhs);
2564 b0 =
VLOAD(N0)(0, rhs);
2566 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2571 a0 =
VLOAD(M0)(0, lhs);
2572 b0 =
VLOAD(N0)(0, rhs);
2574 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2579 a0 =
VLOAD(M0)(0, lhs);
2580 b0 =
VLOAD(N0)(0, rhs);
2582 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2587 a0 =
VLOAD(M0)(0, lhs);
2588 b0 =
VLOAD(N0)(0, rhs);
2590 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2597 a0 =
VLOAD(M0)(0, lhs);
2598 b0 =
VLOAD(N0)(0, rhs);
2600 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2605 a0 =
VLOAD(M0)(0, lhs);
2606 b0 =
VLOAD(N0)(0, rhs);
2608 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2613 a0 =
VLOAD(M0)(0, lhs);
2614 b0 =
VLOAD(N0)(0, rhs);
2616 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2621 a0 =
VLOAD(M0)(0, lhs);
2622 b0 =
VLOAD(N0)(0, rhs);
2624 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2629 a0 =
VLOAD(M0)(0, lhs);
2630 b0 =
VLOAD(N0)(0, rhs);
2632 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2637 a0 =
VLOAD(M0)(0, lhs);
2638 b0 =
VLOAD(N0)(0, rhs);
2640 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2645 a0 =
VLOAD(M0)(0, lhs);
2646 b0 =
VLOAD(N0)(0, rhs);
2648 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2653 a0 =
VLOAD(M0)(0, lhs);
2654 b0 =
VLOAD(N0)(0, rhs);
2656 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2662 #ifndef LHS_INTERLEAVE
2663 lhs += (M0 * K0 * (V0 - 1));
2664 #endif // LHS_INTERLEAVE
2666 #ifndef RHS_INTERLEAVE
2667 rhs += (N0 * K0 * (H0 - 1));
2668 #endif // RHS_INTERLEAVE
2671 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 *
sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
2675 #if defined(REINTERPRET_OUTPUT_AS_3D)
2678 CALCULATE_Z_OFFSET(M0, uint, zout, y * (uint)M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
2681 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2683 #else // defined(REINTERPRET_OUTPUT_AS_3D)
2686 dst_addr += z * dst_stride_z;
2688 #endif // defined(REINTERPRET_OUTPUT_AS_3D)
2693 #endif // defined(ALPHA)
2697 #if defined(BROADCAST_BIAS)
2698 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 *
sizeof(DATA_TYPE));
2700 LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE,
bias, bias_addr, 0, bias_stride_y, zero, 1,
PARTIAL_STORE_N0,
false, cond_x);
2707 #if defined(MIXED_PRECISION)
2710 #else // defined(MIXED_PRECISION)
2712 #endif // defined(MIXED_PRECISION)
2714 #else // defined(BROADCAST_BIAS)
2715 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
2718 LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE,
bias, bias_addr, 0, bias_stride_y, zero,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
2724 #if defined(MIXED_PRECISION)
2727 #else // defined(MIXED_PRECISION)
2729 #endif // defined(MIXED_PRECISION)
2731 #endif // defined(BROADCAST_BIAS)
2732 #endif // defined(BETA)
2734 #if defined(ACTIVATION_TYPE)
2735 #if defined(MIXED_PRECISION)
2736 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE_ACCUMULATOR, N0, c, A_VAL, B_VAL);
2737 #else // defined(MIXED_PRECISION)
2739 #endif // defined(MIXED_PRECISION)
2740 #endif // defined(ACTIVATION_TYPE)
2743 #if defined(MIXED_PRECISION)
2745 STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c_lp, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
2746 #else // defined(MIXED_PRECISION)
2747 STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
2748 #endif // defined(MIXED_PRECISION)
2750 #undef LHS_BLOCK_SIZE
2753 #undef RHS_BLOCK_SIZE
2757 #endif // defined(GEMM_MM_RESHAPED_LHS_T_RHS_NT)
2759 #if defined(OPENCL_IMAGE_SUPPORT) && defined(GEMM_MM_RESHAPED_LHS_T_RHS_NT_TEXTURE)
2822 __read_only image2d_t rhs_img,
2833 #
if defined(REINTERPRET_OUTPUT_AS_3D)
2835 uint dst_cross_plane_pad
// Block-geometry macros for the texture (image2d_t RHS) variant of the
// reshaped LHS(T) x RHS(NT) kernel. Here the RHS row is N0 elements wide,
// so N0 (not K0) is converted to image pixel units.
2843 #define PIXEL_UNIT CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(N0)
2846 #define LHS_BLOCK_SIZE ((K0) * (M0))
// Transposed LHS: x-stride within a block is M0.
2848 #if defined(LHS_INTERLEAVE)
2849 #define LHS_OFFSET_X (M0)
2850 #define LHS_STEP_X ((M0) * (V0))
2851 #define LHS_STEP_LOOP (1)
2852 #else // defined(LHS_INTERLEAVE)
2853 #define LHS_OFFSET_X (LHS_BLOCK_SIZE)
2854 #define LHS_STEP_X (M0)
2855 #define LHS_STEP_LOOP (V0)
2856 #endif // defined(LHS_INTERLEAVE)
// RHS strides are in PIXEL_UNITs along the texture x-axis.
2859 #define RHS_BLOCK_SIZE ((K0) * (PIXEL_UNIT))
2862 #if defined(RHS_INTERLEAVE)
2863 #define RHS_OFFSET_X (PIXEL_UNIT)
2864 #define RHS_STEP_X ((PIXEL_UNIT) * (H0))
2865 #else // defined(RHS_INTERLEAVE)
2866 #define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2867 #define RHS_STEP_X (PIXEL_UNIT)
2868 #endif // defined(RHS_INTERLEAVE)
2870 const uint x = get_global_id(0);
2871 const uint y = get_global_id(1);
2872 const uint z = get_global_id(2);
2874 #if defined(DUMMY_WORK_ITEMS)
2875 if((x * N0 >=
N) || (y * M0 >=
M))
2879 #endif // defined(DUMMY_WORK_ITEMS)
2882 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (y % V0) * (uint)LHS_OFFSET_X *
sizeof(DATA_TYPE) + (y / V0) * (uint)lhs_stride_y + (z * lhs_stride_z);
2884 #if defined(MATRIX_B_DEPTH)
2886 const uint z_rhs = (z % MATRIX_B_DEPTH);
2887 #else // defined(MATRIX_B_DEPTH)
2888 const uint z_rhs = z;
2889 #endif // defined(MATRIX_B_DEPTH)
2892 uint x_rhs = (x % H0) * (uint)RHS_OFFSET_X;
2893 const uint y_rhs = (x / (uint)H0) + z_rhs * RHS_HEIGHT;
2900 __global DATA_TYPE *lhs = (__global DATA_TYPE *)(lhs_addr);
2902 for(
int i = 0; i <
K; i += K0)
2909 a0 =
VLOAD(M0)(0, lhs);
2910 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 0 * RHS_STEP_X), (y_rhs));
2912 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2917 a0 =
VLOAD(M0)(0, lhs);
2918 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 1 * RHS_STEP_X), (y_rhs));
2920 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2926 a0 =
VLOAD(M0)(0, lhs);
2927 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 2 * RHS_STEP_X), (y_rhs));
2929 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2935 a0 =
VLOAD(M0)(0, lhs);
2936 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 3 * RHS_STEP_X), (y_rhs));
2938 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2944 a0 =
VLOAD(M0)(0, lhs);
2945 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 4 * RHS_STEP_X), (y_rhs));
2947 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2951 a0 =
VLOAD(M0)(0, lhs);
2952 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 5 * RHS_STEP_X), (y_rhs));
2954 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2958 a0 =
VLOAD(M0)(0, lhs);
2959 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 6 * RHS_STEP_X), (y_rhs));
2961 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2965 a0 =
VLOAD(M0)(0, lhs);
2966 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 7 * RHS_STEP_X), (y_rhs));
2968 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2974 a0 =
VLOAD(M0)(0, lhs);
2975 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 8 * RHS_STEP_X), (y_rhs));
2977 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2981 a0 =
VLOAD(M0)(0, lhs);
2982 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 9 * RHS_STEP_X), (y_rhs));
2984 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2988 a0 =
VLOAD(M0)(0, lhs);
2989 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 10 * RHS_STEP_X), (y_rhs));
2991 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
2995 a0 =
VLOAD(M0)(0, lhs);
2996 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 11 * RHS_STEP_X), (y_rhs));
2998 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
3002 a0 =
VLOAD(M0)(0, lhs);
3003 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 12 * RHS_STEP_X), (y_rhs));
3005 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
3009 a0 =
VLOAD(M0)(0, lhs);
3010 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 13 * RHS_STEP_X), (y_rhs));
3012 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
3016 a0 =
VLOAD(M0)(0, lhs);
3017 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 14 * RHS_STEP_X), (y_rhs));
3019 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
3023 a0 =
VLOAD(M0)(0, lhs);
3024 b0 =
READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 15 * RHS_STEP_X), (y_rhs));
3026 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a,
b, c);
3031 #ifndef LHS_INTERLEAVE
3032 lhs += (M0 * K0 * (V0 - 1));
3033 #endif // LHS_INTERLEAVE
3035 x_rhs += K0 * RHS_STEP_X;
3036 #ifndef RHS_INTERLEAVE
3037 x_rhs += (PIXEL_UNIT * K0 * (H0 - 1));
3038 #endif // RHS_INTERLEAVE
3041 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 *
sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
3045 const bool cond_y = ((get_global_id(1) + 1) * M0 >=
M);
3046 const bool cond_x = ((get_global_id(0) + 1) * N0 >=
N);
3048 #if defined(REINTERPRET_OUTPUT_AS_3D)
3051 CALCULATE_Z_OFFSET(M0, uint, zout, y * (uint)M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
3054 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3056 #else // defined(REINTERPRET_OUTPUT_AS_3D)
3059 dst_addr += z * dst_stride_z;
3061 #endif // defined(REINTERPRET_OUTPUT_AS_3D)
3066 #endif // defined(ALPHA)
3070 #if defined(BROADCAST_BIAS)
3071 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 *
sizeof(DATA_TYPE));
3073 LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE,
bias, bias_addr, 0, bias_stride_y, zero, 1,
PARTIAL_STORE_N0,
false, cond_x);
3080 #if defined(MIXED_PRECISION)
3083 #else // defined(MIXED_PRECISION)
3085 #endif // defined(MIXED_PRECISION)
3087 #else // defined(BROADCAST_BIAS)
3088 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 *
sizeof(DATA_TYPE)) + (y * (uint)M0 * bias_stride_y) + z * bias_stride_z;
3090 LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE,
bias, bias_addr, 0, bias_stride_y, zero,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
3096 #if defined(MIXED_PRECISION)
3099 #else // defined(MIXED_PRECISION)
3101 #endif // defined(MIXED_PRECISION)
3103 #endif // defined(BROADCAST_BIAS)
3104 #endif // defined(BETA)
3106 #if defined(ACTIVATION_TYPE)
3107 #if defined(MIXED_PRECISION)
3108 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE_ACCUMULATOR, N0, c, A_VAL, B_VAL);
3109 #else // defined(MIXED_PRECISION)
3111 #endif // defined(MIXED_PRECISION)
3112 #endif // defined(ACTIVATION_TYPE)
3115 #if defined(MIXED_PRECISION)
3117 STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c_lp, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
3118 #else // defined(MIXED_PRECISION)
3119 STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
3120 #endif // defined(MIXED_PRECISION)
3122 #undef LHS_BLOCK_SIZE
3125 #undef RHS_BLOCK_SIZE
3129 #undef LHS_STEP_LOOP
3130 #undef RHS_STEP_LOOP
3132 #endif // defined(OPENCL_IMAGE_SUPPORT) && defined(GEMM_MM_RESHAPED_LHS_T_RHS_NT_TEXTURE)
3134 #endif // defined(LHS_TRANSPOSE)
3136 #endif // defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(DATA_TYPE_ACCUMULATOR)
3138 #if defined(M0) && defined(N0) && defined(K0) && defined(DATA_TYPE)
3140 #define VFMA(a, b, c) \
3146 #define RHS_VFMA_M0xN0(i, a, b, c) \
3148 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3150 #elif M0 == 2 // M0 == 2
3151 #define RHS_VFMA_M0xN0(i, a, b, c) \
3153 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3154 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3156 #elif M0 == 3 // M0 == 3
3157 #define RHS_VFMA_M0xN0(i, a, b, c) \
3159 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3160 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3161 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3163 #elif M0 == 4 // M0 == 4
3164 #define RHS_VFMA_M0xN0(i, a, b, c) \
3166 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3167 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3168 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3169 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
3171 #elif M0 == 5 // M0 == 5
3172 #define RHS_VFMA_M0xN0(i, a, b, c) \
3174 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3175 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3176 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3177 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
3178 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
3180 #elif M0 == 6 // M0 == 6
3181 #define RHS_VFMA_M0xN0(i, a, b, c) \
3183 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3184 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3185 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3186 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
3187 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
3188 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
3190 #elif M0 == 7 // M0 == 7
3191 #define RHS_VFMA_M0xN0(i, a, b, c) \
3193 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3194 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3195 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3196 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
3197 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
3198 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
3199 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
3201 #elif M0 == 8 // M0 == 8
3202 #define RHS_VFMA_M0xN0(i, a, b, c) \
3204 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3205 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3206 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3207 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
3208 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
3209 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
3210 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
3211 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
3213 #else // M0 not supported
3214 #error "M0 not supported"
3215 #endif // M0 not supported
3217 #if defined(GEMM_MM_NATIVE)
3292 #
if defined(REINTERPRET_INPUT_AS_3D)
3294 uint lhs_cross_plane_pad
3296 #
if defined(REINTERPRET_OUTPUT_AS_3D)
3298 uint dst_cross_plane_pad
// Native (non-reshaped) GEMM: the RHS block is K0 rows x N0 columns read
// directly from the un-reshaped matrix B.
3303 #define RHS_BLOCK_SIZE ((K0) * (N0))
3306 #define RHS_OFFSET_X (RHS_BLOCK_SIZE)
3308 uint x = get_global_id(0);
3309 uint y = get_global_id(1);
3310 uint z = get_global_id(2);
3312 #if defined(DUMMY_WORK_ITEMS)
3313 if((x * N0 >=
N) || (y * M0 >=
M))
3317 #endif // defined(DUMMY_WORK_ITEMS)
3323 uint rhs_offset = rhs_offset_first_element_in_bytes + x * N0 *
sizeof(DATA_TYPE);
3325 #if defined(MATRIX_B_DEPTH)
3327 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
3328 #else // defined(MATRIX_B_DEPTH)
3329 rhs_offset += z * rhs_stride_z;
3330 #endif // defined(MATRIX_B_DEPTH)
3335 #if defined(REINTERPRET_INPUT_AS_3D)
3341 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
3343 #else // defined(REINTERPRET_INPUT_AS_3D)
3346 lhs_offset += z * lhs_stride_z;
3348 #endif // defined(REINTERPRET_INPUT_AS_3D)
3355 for(; i <= (
K - K0); i += K0)
3367 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
3370 LOAD_BLOCK(K0, N0, DATA_TYPE,
b, rhs_ptr, rhs_offset, rhs_stride_y, zero);
3372 RHS_VFMA_M0xN0(0, a, b0, c);
3373 RHS_VFMA_M0xN0(1, a, b1, c);
3375 RHS_VFMA_M0xN0(2, a, b2, c);
3378 RHS_VFMA_M0xN0(3, a, b3, c);
3381 RHS_VFMA_M0xN0(4, a, b4, c);
3382 RHS_VFMA_M0xN0(5, a, b5, c);
3383 RHS_VFMA_M0xN0(6, a, b6, c);
3384 RHS_VFMA_M0xN0(7, a, b7, c);
3387 RHS_VFMA_M0xN0(8, a, b8, c);
3388 RHS_VFMA_M0xN0(9, a, b9, c);
3389 RHS_VFMA_M0xN0(A, a, bA, c);
3390 RHS_VFMA_M0xN0(B, a, bB, c);
3391 RHS_VFMA_M0xN0(C, a, bC, c);
3392 RHS_VFMA_M0xN0(D, a, bD, c);
3393 RHS_VFMA_M0xN0(E, a, bE, c);
3394 RHS_VFMA_M0xN0(F, a, bF, c);
3397 lhs_offset += K0 *
sizeof(DATA_TYPE);
3398 rhs_offset += K0 * rhs_stride_y;
3406 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zlhs0));
3409 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zlhs1));
3413 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zlhs2));
3417 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zlhs3));
3421 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zlhs4));
3425 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zlhs5));
3429 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zlhs6));
3433 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zlhs7));
3437 b =
VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * rhs_stride_y));
3438 RHS_VFMA_M0xN0(0, a,
b, c);
3440 lhs_offset +=
sizeof(DATA_TYPE);
3441 rhs_offset += rhs_stride_y;
3448 #if defined(REINTERPRET_OUTPUT_AS_3D)
3454 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3456 #else // defined(REINTERPRET_OUTPUT_AS_3D)
3459 dst_addr += z * dst_stride_z;
3461 #endif // defined(REINTERPRET_OUTPUT_AS_3D)
3466 #endif // defined(ALPHA)
3470 #if defined(BROADCAST_BIAS)
3471 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(DATA_TYPE));
3473 LOAD_BLOCK(1, N0, DATA_TYPE,
bias, bias_addr, 0, bias_stride_y, zero);
3482 #else // defined(BROADCAST_BIAS)
3483 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 *
sizeof(DATA_TYPE)) + (
COMPUTE_M0_START_ROW(y, M0,
PARTIAL_STORE_M0) * bias_stride_y) + z * bias_stride_z;
3485 LOAD_BLOCK(M0, N0, DATA_TYPE,
bias, bias_addr, 0, bias_stride_y, zero);
3494 #endif // defined(BROADCAST_BIAS)
3495 #endif // defined(BETA)
3497 #if defined(ACTIVATION_TYPE)
3499 #endif // defined(ACTIVATION_TYPE)
3501 const bool cond_y = y == 0;
3502 const bool cond_x = ((x + 1) * N0 >=
N);
3505 STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
3507 #endif // defined(GEMM_MM_NATIVE)
3508 #endif // defined(M0) && defined(N0) && defined(K0) && defined(DATA_TYPE)
3540 float4 alpha_ab = vload4(0, (__global
float *)
dst.ptr);
3543 float4 c = vload4(0, (__global
float *)
src.ptr);
3546 float4 out = alpha_ab + (float4)BETA * c;
3549 vstore4(out, 0, (__global
float *)
dst.ptr);
3552 #if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
3582 half8 alpha_ab = vload8(0, (__global
half *)
dst.ptr);
3585 half8 c = vload8(0, (__global
half *)
src.ptr);
3588 half8 out = alpha_ab + (half8)BETA * c;
3591 vstore8(out, 0, (__global
half *)
dst.ptr);
3593 #endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
3594 #endif // defined(BETA)