27 #if defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH) && defined(SRC_HEIGHT) && defined(PARTIAL_LOAD_M0) && defined(PARTIAL_LOAD_K0) 28 #define INC2 (VEC_DATA_TYPE(uint, 2))(0, 1) 29 #define INC3 (VEC_DATA_TYPE(uint, 3))(0, 1, 2) 30 #define INC4 (VEC_DATA_TYPE(uint, 4))(0, 1, 2, 3) 31 #define INC8 (VEC_DATA_TYPE(uint, 8))(0, 1, 2, 3, 4, 5, 6, 7) 32 #define INC16 (VEC_DATA_TYPE(uint, 16))(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15) 33 #define CONCAT_INC(K0) INC##K0 34 #define INC(K0) CONCAT_INC(K0) 37 #define BOUNDARY_CONDITION_X(x, a) \ 39 a = select(0, a, CONVERT(((x * (VEC_DATA_TYPE(uint, K0))K0 + INC(K0)) < (VEC_DATA_TYPE(uint, K0))SRC_WIDTH), VEC_DATA_TYPE(DATA_TYPE, K0))); \ 41 #else // (SRC_WIDTH % K0) 42 #define BOUNDARY_CONDITION_X(x, a) \ 44 #endif // (SRC_WIDTH % K0) 46 #define LOAD_TENSOR_BOUNDARY_AWARE_M0XK0(M0, K0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \ 48 if(y * M0 + M0 >= SRC_HEIGHT && PARTIAL_LOAD_M0 != 0) \ 50 if(x * K0 + K0 >= SRC_WIDTH && (PARTIAL_LOAD_K0 != 0)) \ 52 LOAD_TENSOR_M0XN0(PARTIAL_LOAD_M0, PARTIAL_LOAD_K0, DATA_TYPE, a, input_ptr, src_stride_y, zin); \ 56 LOAD_TENSOR_M0XN0(PARTIAL_LOAD_M0, K0, DATA_TYPE, a, input_ptr, src_stride_y, zin); \ 61 if(x * K0 + K0 >= SRC_WIDTH && (PARTIAL_LOAD_K0 != 0)) \ 63 LOAD_TENSOR_M0XN0(M0, PARTIAL_LOAD_K0, DATA_TYPE, a, input_ptr, src_stride_y, zin); \ 67 LOAD_TENSOR_M0XN0(M0, K0, DATA_TYPE, a, input_ptr, src_stride_y, zin); \ 113 #
if defined(REINTERPRET_INPUT_AS_3D)
120 #define BLOCK_SIZE ((M0) * (K0)) 123 #if defined(INTERLEAVE) 124 #define OUTPUT_OFFSET_X (K0) 125 #else // defined(INTERLEAVE) 126 #define OUTPUT_OFFSET_X (BLOCK_SIZE) 127 #endif // defined(INTERLEAVE) 130 #if defined(INTERLEAVE) 131 #define OUTPUT_STEP_X (K0) * (V0) 132 #else // Do not interleave 133 #define OUTPUT_STEP_X (K0) 134 #endif // defined(INTERLEAVE) 137 uint x = get_global_id(0);
138 uint y = get_global_id(1);
139 uint z = get_global_id(2);
144 __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 *
sizeof(
DATA_TYPE) + y * (uint)M0 * src_stride_y;
147 __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 *
sizeof(
DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % V0) *
148 (uint)OUTPUT_OFFSET_X *
sizeof(
DATA_TYPE));
153 #if defined(REINTERPRET_INPUT_AS_3D) 157 input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;
160 CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);
162 #else // defined(REINTERPRET_INPUT_AS_3D) 164 input_ptr += z * (uint)src_stride_z;
166 #endif // defined(REINTERPRET_INPUT_AS_3D) 169 output_ptr += z * (uint)dst_stride_z;
175 LOAD_TENSOR_BOUNDARY_AWARE_M0XK0(M0, K0,
DATA_TYPE, a, input_ptr, src_stride_y, zin);
182 #undef OUTPUT_OFFSET_X 187 #define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \ 189 VEC_DATA_TYPE(DATA_TYPE, M0) \ 190 res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i); \ 192 (res, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \ 194 #elif M0 == 3 // M0 == 3 195 #define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \ 197 VEC_DATA_TYPE(DATA_TYPE, M0) \ 198 res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i); \ 200 (res, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \ 202 #elif M0 == 4 // M0 == 4 203 #define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \ 205 VEC_DATA_TYPE(DATA_TYPE, M0) \ 206 res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \ 208 (res, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \ 210 #elif M0 == 5 // M0 == 5 211 #define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \ 213 VEC_DATA_TYPE(DATA_TYPE, 4) \ 214 res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \ 215 DATA_TYPE res1 = a4.s##i; \ 217 (res0, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \ 218 *((__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE)) + 4) = res1; \ 220 #elif M0 == 6 // M0 == 6 221 #define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \ 223 VEC_DATA_TYPE(DATA_TYPE, 4) \ 224 res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \ 225 VEC_DATA_TYPE(DATA_TYPE, 2) \ 226 res1 = (VEC_DATA_TYPE(DATA_TYPE, 2))(a4.s##i, a5.s##i); \ 228 (res0, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \ 230 (res1, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE)) + 4); \ 232 #elif M0 == 7 // M0 == 7 233 #define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \ 235 VEC_DATA_TYPE(DATA_TYPE, 4) \ 236 res0 = 
(VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \ 237 VEC_DATA_TYPE(DATA_TYPE, 3) \ 238 res1 = (VEC_DATA_TYPE(DATA_TYPE, 3))(a4.s##i, a5.s##i, a6.s##i); \ 240 (res0, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \ 242 (res1, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE)) + 4); \ 244 #elif M0 == 8 // M0 == 8 245 #define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \ 247 VEC_DATA_TYPE(DATA_TYPE, M0) \ 248 res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i, a4.s##i, a5.s##i, a6.s##i, a7.s##i); \ 250 (res, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \ 252 #else // M0 not supported 253 #error "M0 value not supported" 254 #endif // M0 conditions 297 #
if defined(REINTERPRET_INPUT_AS_3D)
304 #define BLOCK_SIZE ((M0) * (K0)) 307 #if defined(INTERLEAVE) 308 #define OUTPUT_OFFSET_X (M0) 309 #else // defined(INTERLEAVE) 310 #define OUTPUT_OFFSET_X (BLOCK_SIZE) 311 #endif // defined(INTERLEAVE) 314 #if defined(INTERLEAVE) 315 #define OUTPUT_STEP_X (M0) * (V0) 316 #else // Do not interleave 317 #define OUTPUT_STEP_X (M0) 318 #endif // defined(INTERLEAVE) 321 uint x = get_global_id(0);
322 uint y = get_global_id(1);
323 uint z = get_global_id(2);
328 __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 *
sizeof(
DATA_TYPE) + y * (uint)M0 * src_stride_y;
331 __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 *
sizeof(
DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % V0) *
332 (uint)OUTPUT_OFFSET_X *
sizeof(
DATA_TYPE));
337 #if defined(REINTERPRET_INPUT_AS_3D) 341 input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;
344 CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);
346 #else // defined(REINTERPRET_INPUT_AS_3D) 348 input_ptr += z * (uint)src_stride_z;
350 #endif // defined(REINTERPRET_INPUT_AS_3D) 353 output_ptr += z * (uint)dst_stride_z;
358 LOAD_TENSOR_BOUNDARY_AWARE_M0XK0(M0, K0,
DATA_TYPE, a, input_ptr, src_stride_y, zin);
362 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 0);
363 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 1);
365 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 2);
368 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 3);
371 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 4);
372 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 5);
373 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 6);
374 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 7);
377 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 8);
378 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 9);
379 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, A);
380 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, B);
381 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, C);
382 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, D);
383 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, E);
384 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, F);
388 #undef OUTPUT_OFFSET_X 391 #endif // defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH) && defined(SRC_HEIGHT) && defined(PARTIAL_LOAD_M0) && defined(PARTIAL_LOAD_K0) 393 #if defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT) 428 #define BLOCK_SIZE ((K0) * (N0)) 431 #if defined(INTERLEAVE) 432 #define OUTPUT_OFFSET_X (N0) 433 #else // defined(INTERLEAVE) 434 #define OUTPUT_OFFSET_X (BLOCK_SIZE) 435 #endif // defined(INTERLEAVE) 438 #if defined(INTERLEAVE) 439 #define OUTPUT_STEP_X (N0) * (H0) 440 #else // Do not interleave 441 #define OUTPUT_STEP_X (N0) 442 #endif // defined(INTERLEAVE) 445 uint x = get_global_id(0);
446 uint y = get_global_id(1);
447 uint z = get_global_id(2);
452 __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 *
sizeof(
DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;
455 __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 *
sizeof(
DATA_TYPE)) + ((x % (uint)H0) * (uint)OUTPUT_OFFSET_X *
sizeof(
DATA_TYPE)) + ((
457 * (uint)dst_stride_y)
458 + z * (uint)dst_stride_z;
465 a0 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 0 * src_stride_y));
467 if(y * (uint)K0 + 1 < SRC_HEIGHT)
469 a1 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 1 * src_stride_y));
473 if(y * (uint)K0 + 2 < SRC_HEIGHT)
475 a2 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 2 * src_stride_y));
479 if(y * (uint)K0 + 3 < SRC_HEIGHT)
481 a3 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 3 * src_stride_y));
485 if(y * (uint)K0 + 4 < SRC_HEIGHT)
487 a4 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 4 * src_stride_y));
489 if(y * (uint)K0 + 5 < SRC_HEIGHT)
491 a5 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 5 * src_stride_y));
493 if(y * (uint)K0 + 6 < SRC_HEIGHT)
495 a6 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 6 * src_stride_y));
497 if(y * (uint)K0 + 7 < SRC_HEIGHT)
499 a7 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 7 * src_stride_y));
503 if(y * (uint)K0 + 8 < SRC_HEIGHT)
505 a8 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 8 * src_stride_y));
507 if(y * (uint)K0 + 9 < SRC_HEIGHT)
509 a9 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 9 * src_stride_y));
511 if(y * (uint)K0 + 10 < SRC_HEIGHT)
513 aA =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 10 * src_stride_y));
515 if(y * (uint)K0 + 11 < SRC_HEIGHT)
517 aB =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 11 * src_stride_y));
519 if(y * (uint)K0 + 12 < SRC_HEIGHT)
521 aC =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 12 * src_stride_y));
523 if(y * (uint)K0 + 13 < SRC_HEIGHT)
525 aD =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 13 * src_stride_y));
527 if(y * (uint)K0 + 14 < SRC_HEIGHT)
529 aE =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 14 * src_stride_y));
531 if(y * (uint)K0 + 15 < SRC_HEIGHT)
533 aF =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 15 * src_stride_y));
542 #undef OUTPUT_OFFSET_X 546 #if defined(TRANSPOSE) 582 #define BLOCK_SIZE ((K0) * (N0)) 585 #if defined(INTERLEAVE) 586 #define OUTPUT_OFFSET_X (K0) 587 #else // defined(INTERLEAVE) 588 #define OUTPUT_OFFSET_X (BLOCK_SIZE) 589 #endif // defined(INTERLEAVE) 592 #if defined(INTERLEAVE) 593 #define OUTPUT_STEP_X (K0) * (H0) 594 #else // Do not interleave 595 #define OUTPUT_STEP_X (K0) 596 #endif // defined(INTERLEAVE) 599 uint x = get_global_id(0);
600 uint y = get_global_id(1);
601 uint z = get_global_id(2);
606 __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 *
sizeof(
DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;
609 __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 *
sizeof(
DATA_TYPE)) + ((x % H0) * (uint)OUTPUT_OFFSET_X *
sizeof(
DATA_TYPE)) + ((x /
610 (uint)H0) * (uint)dst_stride_y) + z * (uint)dst_stride_z;
616 a0 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 0 * src_stride_y));
617 if(y * (uint)K0 + 1 < SRC_HEIGHT)
619 a1 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 1 * src_stride_y));
622 if(y * (uint)K0 + 2 < SRC_HEIGHT)
624 a2 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 2 * src_stride_y));
628 if(y * (uint)K0 + 3 < SRC_HEIGHT)
630 a3 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 3 * src_stride_y));
634 if(y * (uint)K0 + 4 < SRC_HEIGHT)
636 a4 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 4 * src_stride_y));
638 if(y * (uint)K0 + 5 < SRC_HEIGHT)
640 a5 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 5 * src_stride_y));
642 if(y * (uint)K0 + 6 < SRC_HEIGHT)
644 a6 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 6 * src_stride_y));
646 if(y * (uint)K0 + 7 < SRC_HEIGHT)
648 a7 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 7 * src_stride_y));
652 if(y * (uint)K0 + 8 < SRC_HEIGHT)
654 a8 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 8 * src_stride_y));
656 if(y * (uint)K0 + 9 < SRC_HEIGHT)
658 a9 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 9 * src_stride_y));
660 if(y * (uint)K0 + 10 < SRC_HEIGHT)
662 aA =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 10 * src_stride_y));
664 if(y * (uint)K0 + 11 < SRC_HEIGHT)
666 aB =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 11 * src_stride_y));
668 if(y * (uint)K0 + 12 < SRC_HEIGHT)
670 aC =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 12 * src_stride_y));
672 if(y * (uint)K0 + 13 < SRC_HEIGHT)
674 aD =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 13 * src_stride_y));
676 if(y * (uint)K0 + 14 < SRC_HEIGHT)
678 aE =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 14 * src_stride_y));
680 if(y * (uint)K0 + 15 < SRC_HEIGHT)
682 aF =
VLOAD(N0)(0, (__global
DATA_TYPE *)(input_ptr + 15 * src_stride_y));
720 #elif K0 == 3 // K0 == 3 751 #elif K0 == 4 // K0 == 4 782 #elif K0 == 8 // K0 == 8 813 #elif K0 == 16 // K0 == 16 821 a8.s0, a9.s0, aA.s0, aB.s0, aC.s0, aD.s0, aE.s0, aF.s0);
823 a8.s1, a9.s1, aA.s1, aB.s1, aC.s1, aD.s1, aE.s1, aF.s1);
826 a8.s2, a9.s2, aA.s2, aB.s2, aC.s2, aD.s2, aE.s2, aF.s2);
830 a8.s3, a9.s3, aA.s3, aB.s3, aC.s3, aD.s3, aE.s3, aF.s3);
834 a8.s4, a9.s4, aA.s4, aB.s4, aC.s4, aD.s4, aE.s4, aF.s4);
836 a8.s5, a9.s5, aA.s5, aB.s5, aC.s5, aD.s5, aE.s5, aF.s5);
838 a8.s6, a9.s6, aA.s6, aB.s6, aC.s6, aD.s6, aE.s6, aF.s6);
840 a8.s7, a9.s7, aA.s7, aB.s7, aC.s7, aD.s7, aE.s7, aF.s7);
844 a8.s8, a9.s8, aA.s8, aB.s8, aC.s8, aD.s8, aE.s8, aF.s8);
846 a8.s9, a9.s9, aA.s9, aB.s9, aC.s9, aD.s9, aE.s9, aF.s9);
848 a8.sA, a9.sA, aA.sA, aB.sA, aC.sA, aD.sA, aE.sA, aF.sA);
850 a8.sB, a9.sB, aA.sB, aB.sB, aC.sB, aD.sB, aE.sB, aF.sB);
852 a8.sC, a9.sC, aA.sC, aB.sC, aC.sC, aD.sC, aE.sC, aF.sC);
854 a8.sD, a9.sD, aA.sD, aB.sD, aC.sD, aD.sD, aE.sD, aF.sD);
856 a8.sE, a9.sE, aA.sE, aB.sE, aC.sE, aD.sE, aE.sE, aF.sE);
858 a8.sF, a9.sF, aA.sF, aB.sF, aC.sF, aD.sF, aE.sF, aF.sF);
862 #error "Not supported N0 value" 870 #undef OUTPUT_OFFSET_X 873 #endif // defined(TRANSPOSE) 874 #endif // defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT) 876 #if defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K) 878 #define CONCAT(a, b) a##b 880 #define ARM_DOT1(a, b, c) \ 884 #define ARM_DOT2(a, b, c) \ 886 c = fma(a.s0, b.s0, c); \ 887 c = fma(a.s1, b.s1, c); \ 889 #define ARM_DOT3(a, b, c) \ 892 c = fma((a.s2), (b.s2), c); \ 894 #define ARM_DOT4(a, b, c) \ 897 c = fma((a.s3), (b.s3), c); \ 899 #define ARM_DOT8(a, b, c) \ 901 ARM_DOT4((a.lo), (b.lo), c); \ 902 ARM_DOT4((a.hi), (b.hi), c); \ 904 #define ARM_DOT16(a, b, c) \ 906 ARM_DOT8((a.lo), (b.lo), c); \ 907 ARM_DOT8((a.hi), (b.hi), c); \ 911 #define ARM_DOT_K0XN0(k0, a, b, c) \ 913 CONCAT(ARM_DOT, k0) \ 914 ((a), (b##0), (c.s0)); \ 915 CONCAT(ARM_DOT, k0) \ 916 ((a), (b##1), (c.s1)); \ 918 #elif N0 == 3 // N0 == 3 919 #define ARM_DOT_K0XN0(k0, a, b, c) \ 921 CONCAT(ARM_DOT, k0) \ 922 ((a), (b##0), (c.s0)); \ 923 CONCAT(ARM_DOT, k0) \ 924 ((a), (b##1), (c.s1)); \ 925 CONCAT(ARM_DOT, k0) \ 926 ((a), (b##2), (c.s2)); \ 928 #elif N0 == 4 // N0 == 4 929 #define ARM_DOT_K0XN0(k0, a, b, c) \ 931 CONCAT(ARM_DOT, k0) \ 932 ((a), (b##0), (c.s0)); \ 933 CONCAT(ARM_DOT, k0) \ 934 ((a), (b##1), (c.s1)); \ 935 CONCAT(ARM_DOT, k0) \ 936 ((a), (b##2), (c.s2)); \ 937 CONCAT(ARM_DOT, k0) \ 938 ((a), (b##3), (c.s3)); \ 940 #elif N0 == 8 // N0 == 8 941 #define ARM_DOT_K0XN0(k0, a, b, c) \ 943 CONCAT(ARM_DOT, k0) \ 944 ((a), (b##0), (c.s0)); \ 945 CONCAT(ARM_DOT, k0) \ 946 ((a), (b##1), (c.s1)); \ 947 CONCAT(ARM_DOT, k0) \ 948 ((a), (b##2), (c.s2)); \ 949 CONCAT(ARM_DOT, k0) \ 950 ((a), (b##3), (c.s3)); \ 951 CONCAT(ARM_DOT, k0) \ 952 ((a), (b##4), (c.s4)); \ 953 CONCAT(ARM_DOT, k0) \ 954 ((a), (b##5), (c.s5)); \ 955 CONCAT(ARM_DOT, k0) \ 956 ((a), (b##6), (c.s6)); \ 957 CONCAT(ARM_DOT, k0) \ 958 ((a), (b##7), 
(c.s7)); \ 960 #elif N0 == 16 // N0 == 16 961 #define ARM_DOT_K0XN0(k0, a, b, c) \ 963 CONCAT(ARM_DOT, k0) \ 964 ((a), (b##0), (c.s0)); \ 965 CONCAT(ARM_DOT, k0) \ 966 ((a), (b##1), (c.s1)); \ 967 CONCAT(ARM_DOT, k0) \ 968 ((a), (b##2), (c.s2)); \ 969 CONCAT(ARM_DOT, k0) \ 970 ((a), (b##3), (c.s3)); \ 971 CONCAT(ARM_DOT, k0) \ 972 ((a), (b##4), (c.s4)); \ 973 CONCAT(ARM_DOT, k0) \ 974 ((a), (b##5), (c.s5)); \ 975 CONCAT(ARM_DOT, k0) \ 976 ((a), (b##6), (c.s6)); \ 977 CONCAT(ARM_DOT, k0) \ 978 ((a), (b##7), (c.s7)); \ 979 CONCAT(ARM_DOT, k0) \ 980 ((a), (b##8), (c.s8)); \ 981 CONCAT(ARM_DOT, k0) \ 982 ((a), (b##9), (c.s9)); \ 983 CONCAT(ARM_DOT, k0) \ 984 ((a), (b##A), (c.sA)); \ 985 CONCAT(ARM_DOT, k0) \ 986 ((a), (b##B), (c.sB)); \ 987 CONCAT(ARM_DOT, k0) \ 988 ((a), (b##C), (c.sC)); \ 989 CONCAT(ARM_DOT, k0) \ 990 ((a), (b##D), (c.sD)); \ 991 CONCAT(ARM_DOT, k0) \ 992 ((a), (b##E), (c.sE)); \ 993 CONCAT(ARM_DOT, k0) \ 994 ((a), (b##F), (c.sF)); \ 996 #else // N0 not supported 997 #error "N0 value not supported" 998 #endif // N0 conditions 1071 #
if defined(REINTERPRET_INPUT_AS_3D)
1073 uint lhs_cross_plane_pad
1075 #
if defined(REINTERPRET_OUTPUT_AS_3D)
1077 uint dst_cross_plane_pad
1082 #define RHS_BLOCK_SIZE ((K0) * (N0)) 1085 #if defined(RHS_INTERLEAVE) 1086 #define RHS_OFFSET_X (K0) 1087 #define RHS_STEP_X ((K0) * (H0)) 1088 #define RHS_STEP_LOOP (1) 1089 #else // defined(RHS_INTERLEAVE) 1090 #define RHS_OFFSET_X (RHS_BLOCK_SIZE) 1091 #define RHS_STEP_X (K0) 1092 #define RHS_STEP_LOOP (H0) 1093 #endif // defined(RHS_INTERLEAVE) 1095 uint x = get_global_id(0);
1096 uint y = get_global_id(1);
1097 uint z = get_global_id(2);
1099 #if defined(DUMMY_WORK_ITEMS) 1100 if((x * N0 >=
N) || (y * M0 >=
M))
1104 #endif // defined(DUMMY_WORK_ITEMS) 1110 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X *
sizeof(
DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1112 #if defined(MATRIX_B_DEPTH) 1114 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1115 #else // defined(MATRIX_B_DEPTH) 1116 rhs_offset += z * rhs_stride_z;
1117 #endif // defined(MATRIX_B_DEPTH) 1122 #if defined(REINTERPRET_INPUT_AS_3D) 1128 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1130 #else // defined(REINTERPRET_INPUT_AS_3D) 1133 lhs_offset += z * lhs_stride_z;
1135 #endif // defined(REINTERPRET_INPUT_AS_3D) 1141 for(; i <= (
K - K0); i += K0)
1159 ARM_DOT_K0XN0(K0, a0,
b, c0);
1161 ARM_DOT_K0XN0(K0, a1,
b, c1);
1164 ARM_DOT_K0XN0(K0, a2,
b, c2);
1167 ARM_DOT_K0XN0(K0, a3,
b, c3);
1170 ARM_DOT_K0XN0(K0, a4,
b, c4);
1173 ARM_DOT_K0XN0(K0, a5,
b, c5);
1176 ARM_DOT_K0XN0(K0, a6,
b, c6);
1179 ARM_DOT_K0XN0(K0, a7,
b, c7);
1183 rhs_offset += (N0 * RHS_STEP_X * RHS_STEP_LOOP) *
sizeof(
DATA_TYPE);
1196 ARM_DOT_K0XN0(1, a0,
b, c0);
1198 ARM_DOT_K0XN0(1, a1,
b, c1);
1201 ARM_DOT_K0XN0(1, a2,
b, c2);
1204 ARM_DOT_K0XN0(1, a3,
b, c3);
1207 ARM_DOT_K0XN0(1, a4,
b, c4);
1210 ARM_DOT_K0XN0(1, a5,
b, c5);
1213 ARM_DOT_K0XN0(1, a6,
b, c6);
1216 ARM_DOT_K0XN0(1, a7,
b, c7);
1227 #if defined(REINTERPRET_OUTPUT_AS_3D) 1234 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1236 #else // defined(REINTERPRET_OUTPUT_AS_3D) 1239 dst_addr += z * dst_stride_z;
1241 #endif // defined(REINTERPRET_OUTPUT_AS_3D) 1246 #endif // defined(ALPHA) 1250 #if defined(BROADCAST_BIAS) 1251 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(
DATA_TYPE));
1262 #else // defined(BROADCAST_BIAS) 1274 #endif // defined(BROADCAST_BIAS) 1275 #endif // defined(BETA) 1277 #if defined(ACTIVATION_TYPE) 1279 #endif // defined(ACTIVATION_TYPE) 1281 const bool cond_y = y == 0;
1282 const bool cond_x = ((x + 1) * N0 >=
N);
1285 STORE_BLOCK_BOUNDARY_AWARE(M0, N0,
DATA_TYPE, c, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
1287 #undef RHS_BLOCK_SIZE 1292 #if defined(OPENCL_IMAGE_SUPPORT) 1351 __read_only image2d_t rhs_img,
1362 #
if defined(REINTERPRET_INPUT_AS_3D)
1364 uint lhs_cross_plane_pad
1366 #
if defined(REINTERPRET_OUTPUT_AS_3D)
1368 uint dst_cross_plane_pad
1373 #define PIXEL_UNIT CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(K0) 1375 #define LEFTOVER_K (K % K0) 1378 #define RHS_BLOCK_SIZE (PIXEL_UNIT * (N0)) 1381 #if defined(RHS_INTERLEAVE) 1382 #define RHS_OFFSET_X (PIXEL_UNIT) 1383 #define RHS_STEP_X (PIXEL_UNIT * (H0)) 1384 #define RHS_STEP_LOOP (1) 1385 #else // defined(RHS_INTERLEAVE) 1386 #define RHS_OFFSET_X (RHS_BLOCK_SIZE) 1387 #define RHS_STEP_X PIXEL_UNIT 1388 #define RHS_STEP_LOOP (H0) 1389 #endif // defined(RHS_INTERLEAVE) 1391 uint x = get_global_id(0);
1392 uint y = get_global_id(1);
1393 uint z = get_global_id(2);
1395 #if defined(DUMMY_WORK_ITEMS) 1396 if((x * N0 >=
N) || (y * M0 >=
M))
1400 #endif // defined(DUMMY_WORK_ITEMS) 1405 #if defined(MATRIX_B_DEPTH) 1407 const uint z_rhs = (get_global_id(2) % MATRIX_B_DEPTH);
1408 #else // defined(MATRIX_B_DEPTH) 1409 const uint z_rhs = get_global_id(2);
1410 #endif // defined(MATRIX_B_DEPTH) 1413 uint x_rhs = (get_global_id(0) % H0) * (uint)RHS_OFFSET_X;
1414 const uint y_rhs = (get_global_id(0) / (uint)H0) + z_rhs * RHS_HEIGHT;
1419 #if defined(REINTERPRET_INPUT_AS_3D) 1425 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1427 #else // defined(REINTERPRET_INPUT_AS_3D) 1430 lhs_offset += z * lhs_stride_z;
1432 #endif // defined(REINTERPRET_INPUT_AS_3D) 1438 for(; i <= (K - K0); i += K0)
1448 ARM_DOT_K0XN0(K0, a0,
b, c0);
1450 ARM_DOT_K0XN0(K0, a1,
b, c1);
1453 ARM_DOT_K0XN0(K0, a2,
b, c2);
1456 ARM_DOT_K0XN0(K0, a3,
b, c3);
1459 ARM_DOT_K0XN0(K0, a4,
b, c4);
1462 ARM_DOT_K0XN0(K0, a5,
b, c5);
1465 ARM_DOT_K0XN0(K0, a6,
b, c6);
1468 ARM_DOT_K0XN0(K0, a7,
b, c7);
1472 x_rhs += N0 * RHS_STEP_X * RHS_STEP_LOOP;
1479 union UNION_VEC_TYPE
1486 union UNION_VEC_TYPE a0 = {.v = 0 };
1488 union UNION_VEC_TYPE a1 = {.v = 0 };
1491 union UNION_VEC_TYPE a2 = {.v = 0 };
1494 union UNION_VEC_TYPE a3 = {.v = 0 };
1497 union UNION_VEC_TYPE a4 = {.v = 0 };
1500 union UNION_VEC_TYPE a5 = {.v = 0 };
1503 union UNION_VEC_TYPE a6 = {.v = 0 };
1506 union UNION_VEC_TYPE a7 = {.v = 0 };
1515 for(
int k = 0; k < LEFTOVER_K; ++k)
1517 a0.s[k] = *(__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zlhs0);
1519 a1.s[k] = *(__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zlhs1);
1522 a2.s[k] = *(__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zlhs2);
1525 a3.s[k] = *(__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zlhs3);
1528 a4.s[k] = *(__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zlhs4);
1531 a5.s[k] = *(__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zlhs5);
1534 a6.s[k] = *(__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zlhs6);
1537 a7.s[k] = *(__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zlhs7);
1544 ARM_DOT_K0XN0(K0, a0.v,
b, c0);
1546 ARM_DOT_K0XN0(K0, a1.v,
b, c1);
1549 ARM_DOT_K0XN0(K0, a2.v,
b, c2);
1552 ARM_DOT_K0XN0(K0, a3.v,
b, c3);
1555 ARM_DOT_K0XN0(K0, a4.v,
b, c4);
1558 ARM_DOT_K0XN0(K0, a5.v,
b, c5);
1561 ARM_DOT_K0XN0(K0, a6.v,
b, c6);
1564 ARM_DOT_K0XN0(K0, a7.v,
b, c7);
1567 #endif // LEFTOVER_K != 0 1573 #if defined(REINTERPRET_OUTPUT_AS_3D) 1580 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1582 #else // defined(REINTERPRET_OUTPUT_AS_3D) 1585 dst_addr += z * dst_stride_z;
1587 #endif // defined(REINTERPRET_OUTPUT_AS_3D) 1592 #endif // defined(ALPHA) 1596 #if defined(BROADCAST_BIAS) 1597 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(
DATA_TYPE));
1608 #else // defined(BROADCAST_BIAS) 1620 #endif // defined(BROADCAST_BIAS) 1621 #endif // defined(BETA) 1623 #if defined(ACTIVATION_TYPE) 1625 #endif // defined(ACTIVATION_TYPE) 1627 const bool cond_y = y == 0;
1628 const bool cond_x = ((x + 1) * N0 >=
N);
1631 STORE_BLOCK_BOUNDARY_AWARE(M0, N0,
DATA_TYPE, c, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
1633 #undef RHS_BLOCK_SIZE 1639 #endif // defined(OPENCL_IMAGE_SUPPORT) 1641 #define VFMA(a, b, c) \ 1647 #define VFMA_M0xN0(i, a, b, c) \ 1649 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \ 1651 #elif M0 == 2 // M0 == 2 1652 #define VFMA_M0xN0(i, a, b, c) \ 1654 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \ 1655 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \ 1657 #elif M0 == 3 // M0 == 3 1658 #define VFMA_M0xN0(i, a, b, c) \ 1660 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \ 1661 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \ 1662 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \ 1664 #elif M0 == 4 // M0 == 4 1665 #define VFMA_M0xN0(i, a, b, c) \ 1667 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \ 1668 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \ 1669 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \ 1670 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \ 1672 #elif M0 == 5 // M0 == 5 1673 #define VFMA_M0xN0(i, a, b, c) \ 1675 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \ 1676 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \ 1677 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \ 1678 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \ 1679 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \ 1681 #elif M0 == 6 // M0 == 6 1682 #define VFMA_M0xN0(i, a, b, c) \ 1684 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \ 1685 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \ 1686 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \ 1687 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \ 1688 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \ 1689 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \ 1691 #elif M0 == 7 // M0 == 7 1692 #define VFMA_M0xN0(i, a, b, c) \ 
1694 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \ 1695 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \ 1696 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \ 1697 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \ 1698 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \ 1699 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \ 1700 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \ 1702 #elif M0 == 8 // M0 == 8 1703 #define VFMA_M0xN0(i, a, b, c) \ 1705 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \ 1706 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \ 1707 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \ 1708 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \ 1709 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \ 1710 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \ 1711 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \ 1712 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \ 1714 #else // M0 not supported 1715 #error "M0 not supported" 1716 #endif // M0 not supported 1788 #
if defined(REINTERPRET_INPUT_AS_3D)
1790 uint lhs_cross_plane_pad
1792 #
if defined(REINTERPRET_OUTPUT_AS_3D)
1794 uint dst_cross_plane_pad
1799 #define RHS_BLOCK_SIZE ((K0) * (N0)) 1802 #if defined(RHS_INTERLEAVE) 1803 #define RHS_OFFSET_X (N0) 1804 #define RHS_STEP_X ((N0) * (H0)) 1805 #define RHS_STEP_LOOP (1) 1806 #else // defined(RHS_INTERLEAVE) 1807 #define RHS_OFFSET_X (RHS_BLOCK_SIZE) 1808 #define RHS_STEP_X (N0) 1809 #define RHS_STEP_LOOP (H0) 1810 #endif // defined(RHS_INTERLEAVE) 1812 uint x = get_global_id(0);
1813 uint y = get_global_id(1);
1814 uint z = get_global_id(2);
1816 #if defined(DUMMY_WORK_ITEMS) 1817 if((x * N0 >=
N) || (y * M0 >=
M))
1821 #endif // defined(DUMMY_WORK_ITEMS) 1827 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X *
sizeof(
DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1829 #if defined(MATRIX_B_DEPTH) 1831 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1832 #else // defined(MATRIX_B_DEPTH) 1833 rhs_offset += z * rhs_stride_z;
1834 #endif // defined(MATRIX_B_DEPTH) 1839 #if defined(REINTERPRET_INPUT_AS_3D) 1846 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1848 #else // defined(REINTERPRET_INPUT_AS_3D) 1851 lhs_offset += z * lhs_stride_z;
1853 #endif // defined(REINTERPRET_INPUT_AS_3D) 1859 for(; i <= (K - K0); i += K0)
1877 VFMA_M0xN0(0, a, b0, c);
1879 VFMA_M0xN0(1, a, b0, c);
1882 VFMA_M0xN0(2, a, b0, c);
1886 VFMA_M0xN0(3, a, b0, c);
1890 VFMA_M0xN0(4, a, b0, c);
1892 VFMA_M0xN0(5, a, b0, c);
1894 VFMA_M0xN0(6, a, b0, c);
1896 VFMA_M0xN0(7, a, b0, c);
1900 VFMA_M0xN0(8, a, b0, c);
1902 VFMA_M0xN0(9, a, b0, c);
1904 VFMA_M0xN0(A, a, b0, c);
1906 VFMA_M0xN0(B, a, b0, c);
1908 VFMA_M0xN0(C, a, b0, c);
1910 VFMA_M0xN0(D, a, b0, c);
1912 VFMA_M0xN0(E, a, b0, c);
1914 VFMA_M0xN0(F, a, b0, c);
1918 rhs_offset += K0 * RHS_STEP_X * RHS_STEP_LOOP *
sizeof(
DATA_TYPE);
1926 a0 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
1929 a1 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
1933 a2 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
1937 a3 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
1941 a4 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
1945 a5 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
1949 a6 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
1953 a7 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
1960 VFMA_M0xN0(0, a, b0, c);
1963 rhs_offset += RHS_STEP_X *
sizeof(
DATA_TYPE);
1970 #if defined(REINTERPRET_OUTPUT_AS_3D) 1976 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1978 #else // defined(REINTERPRET_OUTPUT_AS_3D) 1981 dst_addr += z * dst_stride_z;
1983 #endif // defined(REINTERPRET_OUTPUT_AS_3D) 1988 #endif // defined(ALPHA) 1992 #if defined(BROADCAST_BIAS) 1993 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(
DATA_TYPE));
2004 #else // defined(BROADCAST_BIAS) 2016 #endif // defined(BROADCAST_BIAS) 2017 #endif // defined(BETA) 2019 #if defined(ACTIVATION_TYPE) 2021 #endif // defined(ACTIVATION_TYPE) 2023 const bool cond_y = y == 0;
2024 const bool cond_x = ((x + 1) * N0 >=
N);
2027 STORE_BLOCK_BOUNDARY_AWARE(M0, N0,
DATA_TYPE, c, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
2029 #undef RHS_BLOCK_SIZE 2034 #if defined(OPENCL_IMAGE_SUPPORT) 2093 __read_only image2d_t rhs_img,
2104 #
if defined(REINTERPRET_INPUT_AS_3D)
2106 uint lhs_cross_plane_pad
2108 #
if defined(REINTERPRET_OUTPUT_AS_3D)
2110 uint dst_cross_plane_pad
2115 #define PIXEL_UNIT CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(N0) 2118 #define RHS_BLOCK_SIZE ((K0) * (PIXEL_UNIT)) 2121 #if defined(RHS_INTERLEAVE) 2122 #define RHS_OFFSET_X (PIXEL_UNIT) 2123 #define RHS_STEP_X ((PIXEL_UNIT) * (H0)) 2124 #else // defined(RHS_INTERLEAVE) 2125 #define RHS_OFFSET_X (RHS_BLOCK_SIZE) 2126 #define RHS_STEP_X (PIXEL_UNIT) 2127 #endif // defined(RHS_INTERLEAVE) 2129 uint x = get_global_id(0);
2130 uint y = get_global_id(1);
2131 uint z = get_global_id(2);
2133 #if defined(DUMMY_WORK_ITEMS) 2134 if((x * N0 >=
N) || (y * M0 >=
M))
2138 #endif // defined(DUMMY_WORK_ITEMS) 2143 #if defined(MATRIX_B_DEPTH) 2145 const uint z_rhs = (z % MATRIX_B_DEPTH);
2146 #else // defined(MATRIX_B_DEPTH) 2147 const uint z_rhs = z;
2148 #endif // defined(MATRIX_B_DEPTH) 2151 uint x_rhs = (x % H0) * (uint)RHS_OFFSET_X;
2152 const uint y_rhs = (x / (uint)H0) + z_rhs * RHS_HEIGHT;
2157 #if defined(REINTERPRET_INPUT_AS_3D) 2164 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
2166 #else // defined(REINTERPRET_INPUT_AS_3D) 2169 lhs_offset += z * lhs_stride_z;
2171 #endif // defined(REINTERPRET_INPUT_AS_3D) 2177 for(; i <= (K - K0); i += K0)
2186 VFMA_M0xN0(0, a, b0, c);
2188 VFMA_M0xN0(1, a, b0, c);
2191 VFMA_M0xN0(2, a, b0, c);
2195 VFMA_M0xN0(3, a, b0, c);
2199 VFMA_M0xN0(4, a, b0, c);
2201 VFMA_M0xN0(5, a, b0, c);
2203 VFMA_M0xN0(6, a, b0, c);
2205 VFMA_M0xN0(7, a, b0, c);
2209 VFMA_M0xN0(8, a, b0, c);
2211 VFMA_M0xN0(9, a, b0, c);
2213 VFMA_M0xN0(A, a, b0, c);
2215 VFMA_M0xN0(B, a, b0, c);
2217 VFMA_M0xN0(C, a, b0, c);
2219 VFMA_M0xN0(D, a, b0, c);
2221 VFMA_M0xN0(E, a, b0, c);
2223 VFMA_M0xN0(F, a, b0, c);
2227 x_rhs += K0 * RHS_STEP_X * RHS_STEP_LOOP;
2235 a0 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
2238 a1 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
2242 a2 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
2246 a3 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
2250 a4 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
2254 a5 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
2258 a6 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
2262 a7 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
2269 VFMA_M0xN0(0, a, b0, c);
2272 x_rhs += RHS_STEP_X;
2279 #if defined(REINTERPRET_OUTPUT_AS_3D) 2285 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2287 #else // defined(REINTERPRET_OUTPUT_AS_3D) 2290 dst_addr += z * dst_stride_z;
2292 #endif // defined(REINTERPRET_OUTPUT_AS_3D) 2297 #endif // defined(ALPHA) 2301 #if defined(BROADCAST_BIAS) 2302 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(
DATA_TYPE));
2313 #else // defined(BROADCAST_BIAS) 2325 #endif // defined(BROADCAST_BIAS) 2326 #endif // defined(BETA) 2328 #if defined(ACTIVATION_TYPE) 2330 #endif // defined(ACTIVATION_TYPE) 2332 const bool cond_y = y == 0;
2333 const bool cond_x = ((x + 1) * N0 >=
N);
2336 STORE_BLOCK_BOUNDARY_AWARE(M0, N0,
DATA_TYPE, c, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
2338 #undef RHS_BLOCK_SIZE 2342 #endif // defined(OPENCL_IMAGE_SUPPORT) 2343 #endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K) 2345 #if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(DATA_TYPE_ACCUMULATOR) && defined(M) && defined(N) 2347 #if defined(MIXED_PRECISION) 2349 #define ARM_DOT_K0(a, b, c) \ 2354 #elif K0 == 3 // K0 == 3 2355 #define ARM_DOT_K0(a, b, c) \ 2361 #elif K0 == 4 // K0 == 4 2362 #define ARM_DOT_K0(a, b, c) \ 2369 #elif K0 == 8 // K0 == 8 2370 #define ARM_DOT_K0(a, b, c) \ 2381 #elif K0 == 16 // K0 == 16 2382 #define ARM_DOT_K0(a, b, c) \ 2401 #else // K0 not supported 2402 #error "K0 value not supported" 2403 #endif // K0 conditions 2404 #else // defined(MIXED_PRECISION) 2406 #define ARM_DOT_K0(a, b, c) \ 2408 c = fma(a.s0, b.s0, c); \ 2409 c = fma(a.s1, b.s1, c); \ 2411 #elif K0 == 3 // K0 == 3 2412 #define ARM_DOT_K0(a, b, c) \ 2414 c = fma(a.s0, b.s0, c); \ 2415 c = fma(a.s1, b.s1, c); \ 2416 c = fma(a.s2, b.s2, c); \ 2418 #elif K0 == 4 // K0 == 4 2419 #define ARM_DOT_K0(a, b, c) \ 2421 c = fma(a.s0, b.s0, c); \ 2422 c = fma(a.s1, b.s1, c); \ 2423 c = fma(a.s2, b.s2, c); \ 2424 c = fma(a.s3, b.s3, c); \ 2426 #elif K0 == 8 // K0 == 8 2427 #define ARM_DOT_K0(a, b, c) \ 2429 c = fma(a.s0, b.s0, c); \ 2430 c = fma(a.s1, b.s1, c); \ 2431 c = fma(a.s2, b.s2, c); \ 2432 c = fma(a.s3, b.s3, c); \ 2433 c = fma(a.s4, b.s4, c); \ 2434 c = fma(a.s5, b.s5, c); \ 2435 c = fma(a.s6, b.s6, c); \ 2436 c = fma(a.s7, b.s7, c); \ 2438 #elif K0 == 16 // K0 == 16 2439 #define ARM_DOT_K0(a, b, c) \ 2441 c = fma(a.s0, b.s0, c); \ 2442 c = fma(a.s1, b.s1, c); \ 2443 c = fma(a.s2, b.s2, c); \ 2444 c = fma(a.s3, b.s3, c); \ 2445 c = fma(a.s4, b.s4, c); \ 2446 c = fma(a.s5, b.s5, c); \ 2447 c = fma(a.s6, b.s6, c); \ 2448 c = fma(a.s7, b.s7, c); \ 2449 c = fma(a.s8, b.s8, c); \ 2450 c = fma(a.s9, b.s9, c); \ 2451 c = fma(a.sA, b.sA, 
c); \ 2452 c = fma(a.sB, b.sB, c); \ 2453 c = fma(a.sC, b.sC, c); \ 2454 c = fma(a.sD, b.sD, c); \ 2455 c = fma(a.sE, b.sE, c); \ 2456 c = fma(a.sF, b.sF, c); \ 2458 #else // K0 not supported 2459 #error "K0 value not supported" 2460 #endif // K0 conditions 2461 #endif // defined(MIXED_PRECISION) 2464 #define ARM_DOT_K0XN0(a, b, c) \ 2466 ARM_DOT_K0((a), (b##0), (c.s0)); \ 2467 ARM_DOT_K0((a), (b##1), (c.s1)); \ 2469 #elif N0 == 3 // N0 == 3 2470 #define ARM_DOT_K0XN0(a, b, c) \ 2472 ARM_DOT_K0((a), (b##0), (c.s0)); \ 2473 ARM_DOT_K0((a), (b##1), (c.s1)); \ 2474 ARM_DOT_K0((a), (b##2), (c.s2)); \ 2476 #elif N0 == 4 // N0 == 4 2477 #define ARM_DOT_K0XN0(a, b, c) \ 2479 ARM_DOT_K0((a), (b##0), (c.s0)); \ 2480 ARM_DOT_K0((a), (b##1), (c.s1)); \ 2481 ARM_DOT_K0((a), (b##2), (c.s2)); \ 2482 ARM_DOT_K0((a), (b##3), (c.s3)); \ 2484 #elif N0 == 8 // N0 == 8 2485 #define ARM_DOT_K0XN0(a, b, c) \ 2487 ARM_DOT_K0((a), (b##0), (c.s0)); \ 2488 ARM_DOT_K0((a), (b##1), (c.s1)); \ 2489 ARM_DOT_K0((a), (b##2), (c.s2)); \ 2490 ARM_DOT_K0((a), (b##3), (c.s3)); \ 2491 ARM_DOT_K0((a), (b##4), (c.s4)); \ 2492 ARM_DOT_K0((a), (b##5), (c.s5)); \ 2493 ARM_DOT_K0((a), (b##6), (c.s6)); \ 2494 ARM_DOT_K0((a), (b##7), (c.s7)); \ 2496 #elif N0 == 16 // N0 == 16 2497 #define ARM_DOT_K0XN0(a, b, c) \ 2499 ARM_DOT_K0((a), (b##0), (c.s0)); \ 2500 ARM_DOT_K0((a), (b##1), (c.s1)); \ 2501 ARM_DOT_K0((a), (b##2), (c.s2)); \ 2502 ARM_DOT_K0((a), (b##3), (c.s3)); \ 2503 ARM_DOT_K0((a), (b##4), (c.s4)); \ 2504 ARM_DOT_K0((a), (b##5), (c.s5)); \ 2505 ARM_DOT_K0((a), (b##6), (c.s6)); \ 2506 ARM_DOT_K0((a), (b##7), (c.s7)); \ 2507 ARM_DOT_K0((a), (b##8), (c.s8)); \ 2508 ARM_DOT_K0((a), (b##9), (c.s9)); \ 2509 ARM_DOT_K0((a), (b##A), (c.sA)); \ 2510 ARM_DOT_K0((a), (b##B), (c.sB)); \ 2511 ARM_DOT_K0((a), (b##C), (c.sC)); \ 2512 ARM_DOT_K0((a), (b##D), (c.sD)); \ 2513 ARM_DOT_K0((a), (b##E), (c.sE)); \ 2514 ARM_DOT_K0((a), (b##F), (c.sF)); \ 2516 #else // N0 not supported 2517 #error "N0 value not supported" 
2518 #endif // N0 conditions 2595 #
if defined(REINTERPRET_OUTPUT_AS_3D)
2597 uint dst_cross_plane_pad
2602 #define LHS_BLOCK_SIZE ((K0) * (M0)) 2604 #if defined(LHS_INTERLEAVE) 2605 #define LHS_OFFSET_X (K0) 2606 #define LHS_STEP_X ((K0) * (V0)) 2607 #define LHS_STEP_LOOP (1) 2608 #else // defined(INTERLEAVE) 2609 #define LHS_OFFSET_X (LHS_BLOCK_SIZE) 2610 #define LHS_STEP_X (K0) 2611 #define LHS_STEP_LOOP (V0) 2612 #endif // defined(INTERLEAVE) 2615 #define RHS_BLOCK_SIZE ((K0) * (N0)) 2618 #if defined(RHS_INTERLEAVE) 2619 #define RHS_OFFSET_X (K0) 2620 #define RHS_STEP_X ((K0) * (H0)) 2621 #define RHS_STEP_LOOP (1) 2622 #else // defined(RHS_INTERLEAVE) 2623 #define RHS_OFFSET_X (RHS_BLOCK_SIZE) 2624 #define RHS_STEP_X (K0) 2625 #define RHS_STEP_LOOP (H0) 2626 #endif // defined(RHS_INTERLEAVE) 2628 #if defined(DUMMY_WORK_ITEMS) 2629 if((get_global_id(0) * N0 >=
N) || (get_global_id(1) * M0 >=
M))
2633 #endif // defined(DUMMY_WORK_ITEMS) 2636 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X *
sizeof(
DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
2637 (get_global_id(2) * lhs_stride_z);
2640 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (get_global_id(0) % H0) * (uint)RHS_OFFSET_X *
sizeof(
DATA_TYPE) + (get_global_id(0) / (uint)H0) * rhs_stride_y;
2642 #if defined(MATRIX_B_DEPTH) 2644 rhs_addr += (get_global_id(2) % MATRIX_B_DEPTH) * rhs_stride_z;
2645 #else // defined(MATRIX_B_DEPTH) 2646 rhs_addr += get_global_id(2) * rhs_stride_z;
2647 #endif // defined(MATRIX_B_DEPTH) 2655 for(
int i = 0; i < k; i += K0)
2673 ARM_DOT_K0XN0(a0,
b, c0);
2675 ARM_DOT_K0XN0(a1,
b, c1);
2678 ARM_DOT_K0XN0(a2,
b, c2);
2681 ARM_DOT_K0XN0(a3,
b, c3);
2684 ARM_DOT_K0XN0(a4,
b, c4);
2687 ARM_DOT_K0XN0(a5,
b, c5);
2690 ARM_DOT_K0XN0(a6,
b, c6);
2693 ARM_DOT_K0XN0(a7,
b, c7);
2696 lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) *
sizeof(
DATA_TYPE);
2697 rhs_addr += (N0 * RHS_STEP_X * RHS_STEP_LOOP) *
sizeof(
DATA_TYPE);
2700 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(
DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
2704 #if defined(REINTERPRET_OUTPUT_AS_3D) 2707 CALCULATE_Z_OFFSET(M0, uint, zout, get_global_id(1) * (uint)M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
2710 dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
2712 #else // defined(REINTERPRET_OUTPUT_AS_3D) 2715 dst_addr += get_global_id(2) * dst_stride_z;
2717 #endif // defined(REINTERPRET_OUTPUT_AS_3D) 2722 #endif // defined(ALPHA) 2726 #if defined(BROADCAST_BIAS) 2727 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(
DATA_TYPE));
2736 #if defined(MIXED_PRECISION) 2739 #else // defined(MIXED_PRECISION) 2741 #endif // defined(MIXED_PRECISION) 2743 #else // defined(BROADCAST_BIAS) 2744 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(
DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
2754 #if defined(MIXED_PRECISION) 2755 CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
2757 #else // defined(MIXED_PRECISION) 2759 #endif // defined(MIXED_PRECISION) 2761 #endif // defined(BROADCAST_BIAS) 2762 #endif // defined(BETA) 2764 #if defined(ACTIVATION_TYPE) 2765 #if defined(MIXED_PRECISION) 2767 #else // defined(MIXED_PRECISION) 2769 #endif // defined(MIXED_PRECISION) 2770 #endif // defined(ACTIVATION_TYPE) 2772 const bool cond_y = ((get_global_id(1) + 1) * M0 >=
M);
2773 const bool cond_x = ((get_global_id(0) + 1) * N0 >=
N);
2776 #if defined(MIXED_PRECISION) 2778 STORE_BLOCK_BOUNDARY_AWARE(M0, N0,
DATA_TYPE, c_lp, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
2779 #else // defined(MIXED_PRECISION) 2780 STORE_BLOCK_BOUNDARY_AWARE(M0, N0,
DATA_TYPE, c, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
2781 #endif // defined(MIXED_PRECISION) 2783 #undef LHS_BLOCK_SIZE 2786 #undef RHS_BLOCK_SIZE 2789 #undef LHS_STEP_LOOP 2790 #undef RHS_STEP_LOOP 2793 #if defined(OPENCL_IMAGE_SUPPORT) 2856 __read_only image2d_t rhs_img,
2868 #
if defined(REINTERPRET_OUTPUT_AS_3D)
2870 uint dst_cross_plane_pad
2875 #define PIXEL_UNIT CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(K0) 2878 #define LHS_BLOCK_SIZE ((K0) * (M0)) 2880 #if defined(LHS_INTERLEAVE) 2881 #define LHS_OFFSET_X (K0) 2882 #define LHS_STEP_X ((K0) * (V0)) 2883 #define LHS_STEP_LOOP (1) 2884 #else // defined(INTERLEAVE) 2885 #define LHS_OFFSET_X (LHS_BLOCK_SIZE) 2886 #define LHS_STEP_X (K0) 2887 #define LHS_STEP_LOOP (V0) 2888 #endif // defined(INTERLEAVE) 2891 #define RHS_BLOCK_SIZE (PIXEL_UNIT * (N0)) 2894 #if defined(RHS_INTERLEAVE) 2895 #define RHS_OFFSET_X (PIXEL_UNIT) 2896 #define RHS_STEP_X (PIXEL_UNIT * (H0)) 2897 #define RHS_STEP_LOOP (1) 2898 #else // defined(RHS_INTERLEAVE) 2899 #define RHS_OFFSET_X (RHS_BLOCK_SIZE) 2900 #define RHS_STEP_X PIXEL_UNIT 2901 #define RHS_STEP_LOOP (H0) 2902 #endif // defined(RHS_INTERLEAVE) 2904 #if defined(DUMMY_WORK_ITEMS) 2905 if((get_global_id(0) * N0 >=
N) || (get_global_id(1) * M0 >=
M))
2909 #endif // defined(DUMMY_WORK_ITEMS) 2912 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X *
sizeof(
DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
2913 (get_global_id(2) * lhs_stride_z);
2915 #if defined(MATRIX_B_DEPTH) 2917 const uint z_rhs = (get_global_id(2) % MATRIX_B_DEPTH);
2918 #else // defined(MATRIX_B_DEPTH) 2919 const uint z_rhs = get_global_id(2);
2920 #endif // defined(MATRIX_B_DEPTH) 2923 uint x_rhs = (get_global_id(0) % H0) * (uint)RHS_OFFSET_X;
2924 const uint y_rhs = (get_global_id(0) / (uint)H0) + z_rhs * RHS_HEIGHT;
2932 for(
int i = 0; i <
K; i += K0)
2942 ARM_DOT_K0XN0(a0,
b, c0);
2944 ARM_DOT_K0XN0(a1,
b, c1);
2947 ARM_DOT_K0XN0(a2,
b, c2);
2950 ARM_DOT_K0XN0(a3,
b, c3);
2953 ARM_DOT_K0XN0(a4,
b, c4);
2956 ARM_DOT_K0XN0(a5,
b, c5);
2959 ARM_DOT_K0XN0(a6,
b, c6);
2962 ARM_DOT_K0XN0(a7,
b, c7);
2965 lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) *
sizeof(
DATA_TYPE);
2967 x_rhs += N0 * RHS_STEP_X * RHS_STEP_LOOP;
2970 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(
DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
2974 #if defined(REINTERPRET_OUTPUT_AS_3D) 2977 CALCULATE_Z_OFFSET(M0, uint, zout, get_global_id(1) * (uint)M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
2980 dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
2982 #else // defined(REINTERPRET_OUTPUT_AS_3D) 2985 dst_addr += get_global_id(2) * dst_stride_z;
2987 #endif // defined(REINTERPRET_OUTPUT_AS_3D) 2992 #endif // defined(ALPHA) 2996 #if defined(BROADCAST_BIAS) 2997 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(
DATA_TYPE));
3006 #if defined(MIXED_PRECISION) 3009 #else // defined(MIXED_PRECISION) 3011 #endif // defined(MIXED_PRECISION) 3013 #else // defined(BROADCAST_BIAS) 3014 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(
DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
3024 #if defined(MIXED_PRECISION) 3025 CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
3027 #else // defined(MIXED_PRECISION) 3029 #endif // defined(MIXED_PRECISION) 3031 #endif // defined(BROADCAST_BIAS) 3032 #endif // defined(BETA) 3034 #if defined(ACTIVATION_TYPE) 3035 #if defined(MIXED_PRECISION) 3037 #else // defined(MIXED_PRECISION) 3039 #endif // defined(MIXED_PRECISION) 3040 #endif // defined(ACTIVATION_TYPE) 3042 const bool cond_y = ((get_global_id(1) + 1) * M0 >=
M);
3043 const bool cond_x = ((get_global_id(0) + 1) * N0 >=
N);
3046 #if defined(MIXED_PRECISION) 3048 STORE_BLOCK_BOUNDARY_AWARE(M0, N0,
DATA_TYPE, c_lp, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
3049 #else // defined(MIXED_PRECISION) 3050 STORE_BLOCK_BOUNDARY_AWARE(M0, N0,
DATA_TYPE, c, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
3051 #endif // defined(MIXED_PRECISION) 3053 #undef LHS_BLOCK_SIZE 3056 #undef RHS_BLOCK_SIZE 3060 #undef LHS_STEP_LOOP 3061 #undef RHS_STEP_LOOP 3063 #endif // defined(OPENCL_IMAGE_SUPPORT) 3065 #if defined(LHS_TRANSPOSE) 3067 #define VTYPE(TYPE, SIZE) VEC_DATA_TYPE(TYPE, SIZE) 3069 #if defined(MIXED_PRECISION) 3071 #if(GPU_ARCH == GPU_ARCH_MIDGARD) 3072 #define ARM_VFMA(N0, a, b, c) c += (CONVERT(a, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))) * (CONVERT(b, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))); 3073 #else // GPU_ARCH == GPU_ARCH_MIDGARD 3074 #define ARM_VFMA(N0, a, b, c) c = fma((CONVERT(a, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))), (CONVERT(b, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))), (c)); 3075 #endif // GPU_ARCH == GPU_ARCH_MIDGARD 3077 #else // defined(MIXED_PRECISION 3079 #if(GPU_ARCH == GPU_ARCH_MIDGARD) 3080 #define ARM_VFMA(N0, a, b, c) c += (a) * (b); 3081 #else // GPU_ARCH == GPU_ARCH_MIDGARD 3082 #define ARM_VFMA(N0, a, b, c) c = fma((a), (b), (c)); 3083 #endif // GPU_ARCH == GPU_ARCH_MIDGARD 3085 #endif // defined(MIXED_PRECISION) 3087 #define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C) \ 3089 ARM_VFMA(N0, (VTYPE(TYPE, N0))(a), b, (C##0)); \ 3091 #define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C) \ 3093 ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s0), b, (C##0)); \ 3094 ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s1), b, (C##1)); \ 3096 #define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C) \ 3098 ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C); \ 3099 ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s2), b, (C##2)); \ 3101 #define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C) \ 3103 ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C); \ 3104 ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s3), b, (C##3)); \ 3106 #define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C) \ 3108 ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C); \ 3109 ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s4), b, (C##4)); \ 3110 ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s5), b, (C##5)); \ 3111 ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s6), b, (C##6)); \ 3112 ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s7), b, (C##7)); \ 3121 
#define ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, a, b, C) ARM_VVM_T_NT_##M0##xN0x1(N0, TYPE, a, b, C) 3123 #define ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, A, B, C) \ 3125 ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##0), (B##0), C); \ 3127 #define ARM_MM_T_NT_M0xN0x2(M0, N0, TYPE, A, B, C) \ 3129 ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, A, B, C); \ 3130 ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##1), (B##1), C); \ 3132 #define ARM_MM_T_NT_M0xN0x3(M0, N0, TYPE, A, B, C) \ 3134 ARM_MM_T_NT_M0xN0x2(M0, N0, TYPE, A, B, C); \ 3135 ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##2), (B##2), C); \ 3137 #define ARM_MM_T_NT_M0xN0x4(M0, N0, TYPE, A, B, C) \ 3139 ARM_MM_T_NT_M0xN0x3(M0, N0, TYPE, A, B, C); \ 3140 ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##3), (B##3), C); \ 3142 #define ARM_MM_T_NT_M0xN0x8(M0, N0, TYPE, A, B, C) \ 3144 ARM_MM_T_NT_M0xN0x4(M0, N0, TYPE, A, B, C); \ 3145 ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##4), (B##4), C); \ 3146 ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##5), (B##5), C); \ 3147 ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##6), (B##6), C); \ 3148 ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##7), (B##7), C); \ 3150 #define ARM_MM_T_NT_M0xN0x16(M0, N0, TYPE, A, B, C) \ 3152 ARM_MM_T_NT_M0xN0x8(M0, N0, TYPE, A, B, C); \ 3153 ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##8), (B##8), C); \ 3154 ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##9), (B##9), C); \ 3155 ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##A), (B##A), C); \ 3156 ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##B), (B##B), C); \ 3157 ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##C), (B##C), C); \ 3158 ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##D), (B##D), C); \ 3159 ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##E), (B##E), C); \ 3160 ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##F), (B##F), C); \ 3171 #define ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C) \ 3172 CONCAT(ARM_MM_T_NT_M0xN0x, K0) \ 3173 (M0, N0, TYPE, A, B, C) 3248 #
if defined(REINTERPRET_OUTPUT_AS_3D)
3250 uint dst_cross_plane_pad
3255 #define LHS_BLOCK_SIZE ((K0) * (M0)) 3257 #if defined(LHS_INTERLEAVE) 3258 #define LHS_OFFSET_X (M0) 3259 #define LHS_STEP_X ((M0) * (V0)) 3260 #define LHS_STEP_LOOP (1) 3261 #else // defined(INTERLEAVE) 3262 #define LHS_OFFSET_X (LHS_BLOCK_SIZE) 3263 #define LHS_STEP_X (M0) 3264 #define LHS_STEP_LOOP (V0) 3265 #endif // defined(INTERLEAVE) 3268 #define RHS_BLOCK_SIZE ((K0) * (N0)) 3271 #if defined(RHS_INTERLEAVE) 3272 #define RHS_OFFSET_X (N0) 3273 #define RHS_STEP_X ((N0) * (H0)) 3274 #else // defined(RHS_INTERLEAVE) 3275 #define RHS_OFFSET_X (RHS_BLOCK_SIZE) 3276 #define RHS_STEP_X (N0) 3277 #endif // defined(RHS_INTERLEAVE) 3279 const uint x = get_global_id(0);
3280 const uint y = get_global_id(1);
3281 const uint z = get_global_id(2);
3283 #if defined(DUMMY_WORK_ITEMS) 3284 if((x * N0 >=
N) || (y * M0 >=
M))
3288 #endif // defined(DUMMY_WORK_ITEMS) 3291 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (y % V0) * (uint)LHS_OFFSET_X *
sizeof(
DATA_TYPE) + (y / V0) * (uint)lhs_stride_y + (z * lhs_stride_z);
3294 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X *
sizeof(
DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
3296 #if defined(MATRIX_B_DEPTH) 3298 rhs_addr += (z % MATRIX_B_DEPTH) * rhs_stride_z;
3299 #else // defined(MATRIX_B_DEPTH) 3300 rhs_addr += z * rhs_stride_z;
3301 #endif // defined(MATRIX_B_DEPTH) 3311 for(
int i = 0; i < k; i += K0)
3318 a0 =
VLOAD(M0)(0, lhs);
3319 b0 =
VLOAD(N0)(0, rhs);
3327 a0 =
VLOAD(M0)(0, lhs);
3328 b0 =
VLOAD(N0)(0, rhs);
3337 a0 =
VLOAD(M0)(0, lhs);
3338 b0 =
VLOAD(N0)(0, rhs);
3347 a0 =
VLOAD(M0)(0, lhs);
3348 b0 =
VLOAD(N0)(0, rhs);
3357 a0 =
VLOAD(M0)(0, lhs);
3358 b0 =
VLOAD(N0)(0, rhs);
3365 a0 =
VLOAD(M0)(0, lhs);
3366 b0 =
VLOAD(N0)(0, rhs);
3373 a0 =
VLOAD(M0)(0, lhs);
3374 b0 =
VLOAD(N0)(0, rhs);
3381 a0 =
VLOAD(M0)(0, lhs);
3382 b0 =
VLOAD(N0)(0, rhs);
3391 a0 =
VLOAD(M0)(0, lhs);
3392 b0 =
VLOAD(N0)(0, rhs);
3399 a0 =
VLOAD(M0)(0, lhs);
3400 b0 =
VLOAD(N0)(0, rhs);
3407 a0 =
VLOAD(M0)(0, lhs);
3408 b0 =
VLOAD(N0)(0, rhs);
3415 a0 =
VLOAD(M0)(0, lhs);
3416 b0 =
VLOAD(N0)(0, rhs);
3423 a0 =
VLOAD(M0)(0, lhs);
3424 b0 =
VLOAD(N0)(0, rhs);
3431 a0 =
VLOAD(M0)(0, lhs);
3432 b0 =
VLOAD(N0)(0, rhs);
3439 a0 =
VLOAD(M0)(0, lhs);
3440 b0 =
VLOAD(N0)(0, rhs);
3447 a0 =
VLOAD(M0)(0, lhs);
3448 b0 =
VLOAD(N0)(0, rhs);
3456 #ifndef LHS_INTERLEAVE 3457 lhs += (M0 * K0 * (V0 - 1));
3458 #endif // LHS_INTERLEAVE 3460 #ifndef RHS_INTERLEAVE 3461 rhs += (N0 * K0 * (H0 - 1));
3462 #endif // RHS_INTERLEAVE 3465 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 *
sizeof(
DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
3469 #if defined(REINTERPRET_OUTPUT_AS_3D) 3472 CALCULATE_Z_OFFSET(M0, uint, zout, y * (uint)M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
3475 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3477 #else // defined(REINTERPRET_OUTPUT_AS_3D) 3480 dst_addr += z * dst_stride_z;
3482 #endif // defined(REINTERPRET_OUTPUT_AS_3D) 3487 #endif // defined(ALPHA) 3491 #if defined(BROADCAST_BIAS) 3492 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 *
sizeof(
DATA_TYPE));
3501 #if defined(MIXED_PRECISION) 3504 #else // defined(MIXED_PRECISION) 3506 #endif // defined(MIXED_PRECISION) 3508 #else // defined(BROADCAST_BIAS) 3509 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(
DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
3518 #if defined(MIXED_PRECISION) 3519 CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
3521 #else // defined(MIXED_PRECISION) 3523 #endif // defined(MIXED_PRECISION) 3525 #endif // defined(BROADCAST_BIAS) 3526 #endif // defined(BETA) 3528 #if defined(ACTIVATION_TYPE) 3529 #if defined(MIXED_PRECISION) 3531 #else // defined(MIXED_PRECISION) 3533 #endif // defined(MIXED_PRECISION) 3534 #endif // defined(ACTIVATION_TYPE) 3536 const bool cond_y = ((get_global_id(1) + 1) * M0 >=
M);
3537 const bool cond_x = ((get_global_id(0) + 1) * N0 >=
N);
3540 #if defined(MIXED_PRECISION) 3542 STORE_BLOCK_BOUNDARY_AWARE(M0, N0,
DATA_TYPE, c_lp, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
3543 #else // defined(MIXED_PRECISION) 3544 STORE_BLOCK_BOUNDARY_AWARE(M0, N0,
DATA_TYPE, c, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
3545 #endif // defined(MIXED_PRECISION) 3547 #undef LHS_BLOCK_SIZE 3550 #undef RHS_BLOCK_SIZE 3555 #if defined(OPENCL_IMAGE_SUPPORT) 3616 __read_only image2d_t rhs_img,
3628 #
if defined(REINTERPRET_OUTPUT_AS_3D)
3630 uint dst_cross_plane_pad
3635 #define PIXEL_UNIT CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(N0) 3638 #define LHS_BLOCK_SIZE ((K0) * (M0)) 3640 #if defined(LHS_INTERLEAVE) 3641 #define LHS_OFFSET_X (M0) 3642 #define LHS_STEP_X ((M0) * (V0)) 3643 #define LHS_STEP_LOOP (1) 3644 #else // defined(INTERLEAVE) 3645 #define LHS_OFFSET_X (LHS_BLOCK_SIZE) 3646 #define LHS_STEP_X (M0) 3647 #define LHS_STEP_LOOP (V0) 3648 #endif // defined(INTERLEAVE) 3651 #define RHS_BLOCK_SIZE ((K0) * (PIXEL_UNIT)) 3654 #if defined(RHS_INTERLEAVE) 3655 #define RHS_OFFSET_X (PIXEL_UNIT) 3656 #define RHS_STEP_X ((PIXEL_UNIT) * (H0)) 3657 #else // defined(RHS_INTERLEAVE) 3658 #define RHS_OFFSET_X (RHS_BLOCK_SIZE) 3659 #define RHS_STEP_X (PIXEL_UNIT) 3660 #endif // defined(RHS_INTERLEAVE) 3662 const uint x = get_global_id(0);
3663 const uint y = get_global_id(1);
3664 const uint z = get_global_id(2);
3666 #if defined(DUMMY_WORK_ITEMS) 3667 if((x * N0 >=
N) || (y * M0 >=
M))
3671 #endif // defined(DUMMY_WORK_ITEMS) 3674 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (y % V0) * (uint)LHS_OFFSET_X *
sizeof(
DATA_TYPE) + (y / V0) * (uint)lhs_stride_y + (z * lhs_stride_z);
3676 #if defined(MATRIX_B_DEPTH) 3678 const uint z_rhs = (z % MATRIX_B_DEPTH);
3679 #else // defined(MATRIX_B_DEPTH) 3680 const uint z_rhs = z;
3681 #endif // defined(MATRIX_B_DEPTH) 3684 uint x_rhs = (x % H0) * (uint)RHS_OFFSET_X;
3685 const uint y_rhs = (x / (uint)H0) + z_rhs * RHS_HEIGHT;
3694 for(
int i = 0; i <
K; i += K0)
3701 a0 =
VLOAD(M0)(0, lhs);
3709 a0 =
VLOAD(M0)(0, lhs);
3718 a0 =
VLOAD(M0)(0, lhs);
3727 a0 =
VLOAD(M0)(0, lhs);
3736 a0 =
VLOAD(M0)(0, lhs);
3743 a0 =
VLOAD(M0)(0, lhs);
3750 a0 =
VLOAD(M0)(0, lhs);
3757 a0 =
VLOAD(M0)(0, lhs);
3766 a0 =
VLOAD(M0)(0, lhs);
3773 a0 =
VLOAD(M0)(0, lhs);
3780 a0 =
VLOAD(M0)(0, lhs);
3787 a0 =
VLOAD(M0)(0, lhs);
3794 a0 =
VLOAD(M0)(0, lhs);
3801 a0 =
VLOAD(M0)(0, lhs);
3808 a0 =
VLOAD(M0)(0, lhs);
3815 a0 =
VLOAD(M0)(0, lhs);
3823 #ifndef LHS_INTERLEAVE 3824 lhs += (M0 * K0 * (V0 - 1));
3825 #endif // LHS_INTERLEAVE 3827 x_rhs += K0 * RHS_STEP_X;
3828 #ifndef RHS_INTERLEAVE 3829 x_rhs += (PIXEL_UNIT * K0 * (H0 - 1));
3830 #endif // RHS_INTERLEAVE 3833 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 *
sizeof(
DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
3837 #if defined(REINTERPRET_OUTPUT_AS_3D) 3840 CALCULATE_Z_OFFSET(M0, uint, zout, y * (uint)M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
3843 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3845 #else // defined(REINTERPRET_OUTPUT_AS_3D) 3848 dst_addr += z * dst_stride_z;
3850 #endif // defined(REINTERPRET_OUTPUT_AS_3D) 3855 #endif // defined(ALPHA) 3859 #if defined(BROADCAST_BIAS) 3860 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 *
sizeof(
DATA_TYPE));
3869 #if defined(MIXED_PRECISION) 3872 #else // defined(MIXED_PRECISION) 3874 #endif // defined(MIXED_PRECISION) 3876 #else // defined(BROADCAST_BIAS) 3877 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 *
sizeof(
DATA_TYPE)) + (y * (uint)M0 * bias_stride_y) + z * bias_stride_z;
3885 #if defined(MIXED_PRECISION) 3886 CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
3888 #else // defined(MIXED_PRECISION) 3890 #endif // defined(MIXED_PRECISION) 3892 #endif // defined(BROADCAST_BIAS) 3893 #endif // defined(BETA) 3895 #if defined(ACTIVATION_TYPE) 3896 #if defined(MIXED_PRECISION) 3898 #else // defined(MIXED_PRECISION) 3900 #endif // defined(MIXED_PRECISION) 3901 #endif // defined(ACTIVATION_TYPE) 3903 const bool cond_y = ((get_global_id(1) + 1) * M0 >=
M);
3904 const bool cond_x = ((get_global_id(0) + 1) * N0 >=
N);
3907 #if defined(MIXED_PRECISION) 3909 STORE_BLOCK_BOUNDARY_AWARE(M0, N0,
DATA_TYPE, c_lp, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
3910 #else // defined(MIXED_PRECISION) 3911 STORE_BLOCK_BOUNDARY_AWARE(M0, N0,
DATA_TYPE, c, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
3912 #endif // defined(MIXED_PRECISION) 3914 #undef LHS_BLOCK_SIZE 3917 #undef RHS_BLOCK_SIZE 3921 #undef LHS_STEP_LOOP 3922 #undef RHS_STEP_LOOP 3924 #endif // defined(OPENCL_IMAGE_SUPPORT) 3926 #endif // defined(LHS_TRANSPOSE) 3928 #endif // defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(K) && defined(DATA_TYPE) 3930 #if defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE) 3932 #define VFMA(a, b, c) \ 3938 #define RHS_VFMA_M0xN0(i, a, b, c) \ 3940 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \ 3942 #elif M0 == 2 // M0 == 2 3943 #define RHS_VFMA_M0xN0(i, a, b, c) \ 3945 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \ 3946 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \ 3948 #elif M0 == 3 // M0 == 3 3949 #define RHS_VFMA_M0xN0(i, a, b, c) \ 3951 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \ 3952 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \ 3953 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \ 3955 #elif M0 == 4 // M0 == 4 3956 #define RHS_VFMA_M0xN0(i, a, b, c) \ 3958 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \ 3959 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \ 3960 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \ 3961 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \ 3963 #elif M0 == 5 // M0 == 5 3964 #define RHS_VFMA_M0xN0(i, a, b, c) \ 3966 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \ 3967 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \ 3968 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \ 3969 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \ 3970 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \ 3972 #elif M0 == 6 // M0 == 6 3973 #define RHS_VFMA_M0xN0(i, a, b, c) \ 3975 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \ 3976 
VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \ 3977 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \ 3978 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \ 3979 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \ 3980 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \ 3982 #elif M0 == 7 // M0 == 7 3983 #define RHS_VFMA_M0xN0(i, a, b, c) \ 3985 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \ 3986 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \ 3987 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \ 3988 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \ 3989 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \ 3990 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \ 3991 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \ 3993 #elif M0 == 8 // M0 == 8 3994 #define RHS_VFMA_M0xN0(i, a, b, c) \ 3996 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \ 3997 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \ 3998 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \ 3999 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \ 4000 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \ 4001 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \ 4002 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \ 4003 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \ 4005 #else // M0 not supported 4006 #error "M0 not supported" 4007 #endif // M0 not supported 4078 #
if defined(REINTERPRET_INPUT_AS_3D)
4080 uint lhs_cross_plane_pad
4082 #
if defined(REINTERPRET_OUTPUT_AS_3D)
4084 uint dst_cross_plane_pad
4089 #define RHS_BLOCK_SIZE ((K0) * (N0)) 4092 #define RHS_OFFSET_X (RHS_BLOCK_SIZE) 4094 uint x = get_global_id(0);
4095 uint y = get_global_id(1);
4096 uint z = get_global_id(2);
4098 #if defined(DUMMY_WORK_ITEMS) 4099 if((x * N0 >=
N) || (y * M0 >=
M))
4103 #endif // defined(DUMMY_WORK_ITEMS) 4109 uint rhs_offset = rhs_offset_first_element_in_bytes + x * N0 *
sizeof(
DATA_TYPE);
4111 #if defined(MATRIX_B_DEPTH) 4113 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
4114 #else // defined(MATRIX_B_DEPTH) 4115 rhs_offset += z * rhs_stride_z;
4116 #endif // defined(MATRIX_B_DEPTH) 4121 #if defined(REINTERPRET_INPUT_AS_3D) 4127 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
4129 #else // defined(REINTERPRET_INPUT_AS_3D) 4132 lhs_offset += z * lhs_stride_z;
4134 #endif // defined(REINTERPRET_INPUT_AS_3D) 4140 for(; i <= (K - K0); i += K0)
4157 RHS_VFMA_M0xN0(0, a, b0, c);
4158 RHS_VFMA_M0xN0(1, a, b1, c);
4160 RHS_VFMA_M0xN0(2, a, b2, c);
4163 RHS_VFMA_M0xN0(3, a, b3, c);
4166 RHS_VFMA_M0xN0(4, a, b4, c);
4167 RHS_VFMA_M0xN0(5, a, b5, c);
4168 RHS_VFMA_M0xN0(6, a, b6, c);
4169 RHS_VFMA_M0xN0(7, a, b7, c);
4172 RHS_VFMA_M0xN0(8, a, b8, c);
4173 RHS_VFMA_M0xN0(9, a, b9, c);
4174 RHS_VFMA_M0xN0(A, a, bA, c);
4175 RHS_VFMA_M0xN0(B, a, bB, c);
4176 RHS_VFMA_M0xN0(C, a, bC, c);
4177 RHS_VFMA_M0xN0(D, a, bD, c);
4178 RHS_VFMA_M0xN0(E, a, bE, c);
4179 RHS_VFMA_M0xN0(F, a, bF, c);
4183 rhs_offset += K0 * rhs_stride_y;
4191 a0 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zlhs0));
4194 a1 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zlhs1));
4198 a2 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zlhs2));
4202 a3 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zlhs3));
4206 a4 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zlhs4));
4210 a5 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zlhs5));
4214 a6 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zlhs6));
4218 a7 = *((__global
DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zlhs7));
4222 b =
VLOAD(N0)(0, (__global
DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * rhs_stride_y));
4223 RHS_VFMA_M0xN0(0, a, b, c);
4226 rhs_offset += rhs_stride_y;
4233 #if defined(REINTERPRET_OUTPUT_AS_3D) 4239 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
4241 #else // defined(REINTERPRET_OUTPUT_AS_3D) 4244 dst_addr += z * dst_stride_z;
4246 #endif // defined(REINTERPRET_OUTPUT_AS_3D) 4251 #endif // defined(ALPHA) 4255 #if defined(BROADCAST_BIAS) 4256 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(
DATA_TYPE));
4267 #else // defined(BROADCAST_BIAS) 4279 #endif // defined(BROADCAST_BIAS) 4280 #endif // defined(BETA) 4282 #if defined(ACTIVATION_TYPE) 4284 #endif // defined(ACTIVATION_TYPE) 4286 const bool cond_y = y == 0;
4287 const bool cond_x = ((x + 1) * N0 >=
N);
4290 STORE_BLOCK_BOUNDARY_AWARE(M0, N0,
DATA_TYPE, c, dst_addr, dst_stride_y, zout,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
4292 #undef RHS_BLOCK_SIZE 4296 #endif // defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE) 4328 float4 alpha_ab = vload4(0, (__global
float *)dst.
ptr);
4331 float4 c = vload4(0, (__global
float *)src.
ptr);
4334 float4 out = alpha_ab + (float4)BETA * c;
4337 vstore4(out, 0, (__global
float *)dst.
ptr);
4340 #if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) 4370 half8 alpha_ab = vload8(0, (__global
half *)dst.
ptr);
4373 half8 c = vload8(0, (__global
half *)src.
ptr);
4376 half8 out = alpha_ab + (half8)BETA * c;
4379 vstore8(out, 0, (__global
half *)dst.
ptr);
4381 #endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) 4382 #endif // defined(BETA)
#define ACTIVATION_BLOCK(N, ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)
#define REPEAT_VAR_INIT_TO_CONST(N, TYPE, VAR, VAL)
half_float::half half
16-bit floating point type
#define IMAGE_DECLARATION(name)
#define ADD_BLOCK_BROADCAST(N, BASENAME, BIAS)
Structure to hold 3D tensor information.
SimpleTensor< float > src
#define STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define ADD_BLOCK(N, BASENAME, BIAS)
#define LOAD_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z)
#define LOAD_TEXTURE2D(M0, N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)
#define CALCULATE_Z_OFFSET(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y)
#define CONVERT_TO_TENSOR3D_STRUCT(name)
#define READ_IMAGE2D(data_type, n0, img, x_coord, y_coord)
#define CONVERT_BLOCK(M, N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)
#define SCALE_BLOCK(N, DATA_TYPE, BASENAME, SCALE)
__global uchar * ptr
Pointer to the starting position of the buffer.
#define TENSOR3D_DECLARATION(name)
#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0)
#define VEC_DATA_TYPE(type, size)