27 #if defined(M) && defined(N) && defined(K) && defined(H0) && defined(V0) && defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0) 91 #
if defined(REINTERPRET_OUTPUT_AS_3D)
97 int x = get_global_id(0) / H0;
98 int y = get_global_id(1) / V0;
99 int z = get_global_id(2);
102 const int offset_row_a = (get_global_id(1) % V0) * 4;
103 const int offset_row_b = (get_global_id(0) % H0) * 4;
107 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
108 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
110 #if defined(MATRIX_B_DEPTH) 112 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
113 #else // defined(MATRIX_B_DEPTH) 114 src1_addr_in_bytes += z * src1_stride_z;
115 #endif // defined(MATRIX_B_DEPTH) 117 __global
float *src_addr_a = (__global
float *)(src0_ptr + src0_addr_in_bytes);
118 __global
float *src_addr_b = (__global
float *)(src1_ptr + src1_addr_in_bytes);
121 __global
float *src_end_addr_b = src_addr_b + (src1_stride_y /
sizeof(float));
123 src_addr_a += offset_row_a;
124 src_addr_b += offset_row_b;
132 for(; src_addr_b <= (src_end_addr_b - (int)(8 * H0)); src_addr_a += 8 * V0, src_addr_b += 8 * H0)
135 float4 a0 = vload4(0, src_addr_a);
136 float4 b0 = vload4(0, src_addr_b);
138 c0 += (float4)a0.s0 * b0;
139 c1 += (float4)a0.s1 * b0;
140 c2 += (float4)a0.s2 * b0;
141 c3 += (float4)a0.s3 * b0;
144 a0 = vload4(0, src_addr_a + 4 * V0);
145 b0 = vload4(0, src_addr_b + 4 * H0);
147 c0 += (float4)a0.s0 * b0;
148 c1 += (float4)a0.s1 * b0;
149 c2 += (float4)a0.s2 * b0;
150 c3 += (float4)a0.s3 * b0;
153 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * V0, src_addr_b += 4 * H0)
156 float4 a0 = vload4(0, src_addr_a);
157 float4 b0 = vload4(0, src_addr_b);
159 c0 += (float4)a0.s0 * b0;
160 c1 += (float4)a0.s1 * b0;
161 c2 += (float4)a0.s2 * b0;
162 c3 += (float4)a0.s3 * b0;
169 __global uchar *dst_addr =
offset(&dst, 0, 0);
173 #if defined(REINTERPRET_OUTPUT_AS_3D) 190 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
191 zout = min(DEPTH_GEMM3D - 1, zout);
194 zout *= (cross_plane_pad * dst_stride_y);
198 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
199 #else // defined(REINTERPRET_OUTPUT_AS_3D) 201 dst_addr += z * dst_stride_z;
202 #endif // defined(REINTERPRET_OUTPUT_AS_3D) 207 #endif // defined(ALPHA) 213 #if defined(BROADCAST_BIAS) 214 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 *
sizeof(
float));
216 LOAD_BLOCK(1, 4,
float, bias, src2_addr, 0, src2_stride_y, zero);
225 #else // defined(BROADCAST_BIAS) 226 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 *
sizeof(
float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
229 LOAD_BLOCK(4, 4,
float, bias, src2_addr, 0, src2_stride_y, zero);
238 #endif // defined(BROADCAST_BIAS) 239 #endif // defined(BETA) 241 #if defined(ACTIVATION_TYPE) 243 #endif // defined(ACTIVATION_TYPE) 246 const bool cond_y = ((get_global_id(1) + 1) * 4 >=
M);
247 const bool cond_x = ((get_global_id(0) + 1) * 4 >=
N);
248 STORE_BLOCK_BOUNDARY_AWARE(4, 4,
float, c, dst_addr, dst_stride_y, zout.s,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
302 __kernel
void gemm_mm_interleaved_transposed_f32_bifrost(
IMAGE_DECLARATION(src0),
314 #
if defined(REINTERPRET_OUTPUT_AS_3D)
320 int x = get_global_id(0) / H0;
321 int y = get_global_id(1) / V0;
322 int z = get_global_id(2);
325 const int offset_row_a = (get_global_id(1) % V0) * 4;
326 const int offset_row_b = (get_global_id(0) % H0) * 4;
330 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
331 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
333 #if defined(MATRIX_B_DEPTH) 335 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
336 #else // defined(MATRIX_B_DEPTH) 337 src1_addr_in_bytes += z * src1_stride_z;
338 #endif // defined(MATRIX_B_DEPTH) 340 __global
float *src_addr_a = (__global
float *)(src0_ptr + src0_addr_in_bytes);
341 __global
float *src_addr_b = (__global
float *)(src1_ptr + src1_addr_in_bytes);
343 src_addr_a += offset_row_a;
344 src_addr_b += offset_row_b;
353 for(; i <= (int)(
K - 4); i += 4)
356 float4 a0 = vload4(0, src_addr_a);
357 float4 b0 = vload4(0, src_addr_b);
359 src_addr_a += 4 * V0;
360 src_addr_b += 4 * H0;
362 c0.s0 =
fma(a0.s0, b0.s0, c0.s0);
363 c0.s1 =
fma(a0.s0, b0.s1, c0.s1);
364 c0.s2 =
fma(a0.s0, b0.s2, c0.s2);
365 c0.s3 =
fma(a0.s0, b0.s3, c0.s3);
367 c1.s0 =
fma(a0.s1, b0.s0, c1.s0);
368 c1.s1 =
fma(a0.s1, b0.s1, c1.s1);
369 c1.s2 =
fma(a0.s1, b0.s2, c1.s2);
370 c1.s3 =
fma(a0.s1, b0.s3, c1.s3);
372 c2.s0 =
fma(a0.s2, b0.s0, c2.s0);
373 c2.s1 =
fma(a0.s2, b0.s1, c2.s1);
374 c2.s2 =
fma(a0.s2, b0.s2, c2.s2);
375 c2.s3 =
fma(a0.s2, b0.s3, c2.s3);
377 c3.s0 =
fma(a0.s3, b0.s0, c3.s0);
378 c3.s1 =
fma(a0.s3, b0.s1, c3.s1);
379 c3.s2 =
fma(a0.s3, b0.s2, c3.s2);
380 c3.s3 =
fma(a0.s3, b0.s3, c3.s3);
383 a0 = vload4(0, src_addr_a);
384 b0 = vload4(0, src_addr_b);
386 src_addr_a += 4 * V0;
387 src_addr_b += 4 * H0;
389 c0.s0 =
fma(a0.s0, b0.s0, c0.s0);
390 c0.s1 =
fma(a0.s0, b0.s1, c0.s1);
391 c0.s2 =
fma(a0.s0, b0.s2, c0.s2);
392 c0.s3 =
fma(a0.s0, b0.s3, c0.s3);
394 c1.s0 =
fma(a0.s1, b0.s0, c1.s0);
395 c1.s1 =
fma(a0.s1, b0.s1, c1.s1);
396 c1.s2 =
fma(a0.s1, b0.s2, c1.s2);
397 c1.s3 =
fma(a0.s1, b0.s3, c1.s3);
399 c2.s0 =
fma(a0.s2, b0.s0, c2.s0);
400 c2.s1 =
fma(a0.s2, b0.s1, c2.s1);
401 c2.s2 =
fma(a0.s2, b0.s2, c2.s2);
402 c2.s3 =
fma(a0.s2, b0.s3, c2.s3);
404 c3.s0 =
fma(a0.s3, b0.s0, c3.s0);
405 c3.s1 =
fma(a0.s3, b0.s1, c3.s1);
406 c3.s2 =
fma(a0.s3, b0.s2, c3.s2);
407 c3.s3 =
fma(a0.s3, b0.s3, c3.s3);
410 a0 = vload4(0, src_addr_a);
411 b0 = vload4(0, src_addr_b);
413 src_addr_a += 4 * V0;
414 src_addr_b += 4 * H0;
416 c0.s0 =
fma(a0.s0, b0.s0, c0.s0);
417 c0.s1 =
fma(a0.s0, b0.s1, c0.s1);
418 c0.s2 =
fma(a0.s0, b0.s2, c0.s2);
419 c0.s3 =
fma(a0.s0, b0.s3, c0.s3);
421 c1.s0 =
fma(a0.s1, b0.s0, c1.s0);
422 c1.s1 =
fma(a0.s1, b0.s1, c1.s1);
423 c1.s2 =
fma(a0.s1, b0.s2, c1.s2);
424 c1.s3 =
fma(a0.s1, b0.s3, c1.s3);
426 c2.s0 =
fma(a0.s2, b0.s0, c2.s0);
427 c2.s1 =
fma(a0.s2, b0.s1, c2.s1);
428 c2.s2 =
fma(a0.s2, b0.s2, c2.s2);
429 c2.s3 =
fma(a0.s2, b0.s3, c2.s3);
431 c3.s0 =
fma(a0.s3, b0.s0, c3.s0);
432 c3.s1 =
fma(a0.s3, b0.s1, c3.s1);
433 c3.s2 =
fma(a0.s3, b0.s2, c3.s2);
434 c3.s3 =
fma(a0.s3, b0.s3, c3.s3);
437 a0 = vload4(0, src_addr_a);
438 b0 = vload4(0, src_addr_b);
440 src_addr_a += 4 * V0;
441 src_addr_b += 4 * H0;
443 c0.s0 =
fma(a0.s0, b0.s0, c0.s0);
444 c0.s1 =
fma(a0.s0, b0.s1, c0.s1);
445 c0.s2 =
fma(a0.s0, b0.s2, c0.s2);
446 c0.s3 =
fma(a0.s0, b0.s3, c0.s3);
448 c1.s0 =
fma(a0.s1, b0.s0, c1.s0);
449 c1.s1 =
fma(a0.s1, b0.s1, c1.s1);
450 c1.s2 =
fma(a0.s1, b0.s2, c1.s2);
451 c1.s3 =
fma(a0.s1, b0.s3, c1.s3);
453 c2.s0 =
fma(a0.s2, b0.s0, c2.s0);
454 c2.s1 =
fma(a0.s2, b0.s1, c2.s1);
455 c2.s2 =
fma(a0.s2, b0.s2, c2.s2);
456 c2.s3 =
fma(a0.s2, b0.s3, c2.s3);
458 c3.s0 =
fma(a0.s3, b0.s0, c3.s0);
459 c3.s1 =
fma(a0.s3, b0.s1, c3.s1);
460 c3.s2 =
fma(a0.s3, b0.s2, c3.s2);
461 c3.s3 =
fma(a0.s3, b0.s3, c3.s3);
464 for(; i < (int)
K; ++i)
467 float4 a0 = vload4(0, src_addr_a);
468 float4 b0 = vload4(0, src_addr_b);
470 src_addr_a += 4 * V0;
471 src_addr_b += 4 * H0;
473 c0.s0 =
fma(a0.s0, b0.s0, c0.s0);
474 c0.s1 =
fma(a0.s0, b0.s1, c0.s1);
475 c0.s2 =
fma(a0.s0, b0.s2, c0.s2);
476 c0.s3 =
fma(a0.s0, b0.s3, c0.s3);
478 c1.s0 =
fma(a0.s1, b0.s0, c1.s0);
479 c1.s1 =
fma(a0.s1, b0.s1, c1.s1);
480 c1.s2 =
fma(a0.s1, b0.s2, c1.s2);
481 c1.s3 =
fma(a0.s1, b0.s3, c1.s3);
483 c2.s0 =
fma(a0.s2, b0.s0, c2.s0);
484 c2.s1 =
fma(a0.s2, b0.s1, c2.s1);
485 c2.s2 =
fma(a0.s2, b0.s2, c2.s2);
486 c2.s3 =
fma(a0.s2, b0.s3, c2.s3);
488 c3.s0 =
fma(a0.s3, b0.s0, c3.s0);
489 c3.s1 =
fma(a0.s3, b0.s1, c3.s1);
490 c3.s2 =
fma(a0.s3, b0.s2, c3.s2);
491 c3.s3 =
fma(a0.s3, b0.s3, c3.s3);
498 __global uchar *dst_addr =
offset(&dst, 0, 0);
502 #if defined(REINTERPRET_OUTPUT_AS_3D) 519 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
520 zout = min(DEPTH_GEMM3D - 1, zout);
523 zout *= (cross_plane_pad * dst_stride_y);
527 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
528 #else // defined(REINTERPRET_OUTPUT_AS_3D) 530 dst_addr += z * dst_stride_z;
531 #endif // defined(REINTERPRET_OUTPUT_AS_3D) 536 #endif // defined(ALPHA) 542 #if defined(BROADCAST_BIAS) 543 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 *
sizeof(
float));
545 LOAD_BLOCK(1, 4,
float, bias, src2_addr, 0, src2_stride_y, zero);
554 #else // defined(BROADCAST_BIAS) 555 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 *
sizeof(
float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
558 LOAD_BLOCK(4, 4,
float, bias, src2_addr, 0, src2_stride_y, zero);
567 #endif // defined(BROADCAST_BIAS) 568 #endif // defined(BETA) 570 #if defined(ACTIVATION_TYPE) 572 #endif // defined(ACTIVATION_TYPE) 575 const bool cond_y = ((get_global_id(1) + 1) * 4 >=
M);
576 const bool cond_x = ((get_global_id(0) + 1) * 4 >=
N);
577 STORE_BLOCK_BOUNDARY_AWARE(4, 4,
float, c, dst_addr, dst_stride_y, zout.s,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
580 #if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) 644 #
if defined(REINTERPRET_OUTPUT_AS_3D)
650 int x = get_global_id(0) / H0;
651 int y = get_global_id(1) / V0;
652 int z = get_global_id(2);
655 const int offset_row_a = (get_global_id(1) % V0) * 4;
656 const int offset_row_b = (get_global_id(0) % H0) * 8;
660 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
661 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
663 #if defined(MATRIX_B_DEPTH) 665 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
666 #else // defined(MATRIX_B_DEPTH) 667 src1_addr_in_bytes += z * src1_stride_z;
668 #endif // defined(MATRIX_B_DEPTH) 670 __global
half *src_addr_a = (__global
half *)(src0_ptr + src0_addr_in_bytes);
671 __global
half *src_addr_b = (__global
half *)(src1_ptr + src1_addr_in_bytes);
674 __global
half *src_end_addr_b = src_addr_b + (src1_stride_y /
sizeof(
half));
676 src_addr_a += offset_row_a;
677 src_addr_b += offset_row_b;
685 for(; src_addr_b <= (src_end_addr_b - (int)(16 * H0)); src_addr_a += 8 * V0, src_addr_b += 16 * H0)
688 half4 a0 = vload4(0, src_addr_a);
689 half8 b0 = vload8(0, src_addr_b);
691 c0 += (half8)a0.s0 * b0;
692 c1 += (half8)a0.s1 * b0;
693 c2 += (half8)a0.s2 * b0;
694 c3 += (half8)a0.s3 * b0;
697 a0 = vload4(0, src_addr_a + 4 * V0);
698 b0 = vload8(0, src_addr_b + 8 * H0);
700 c0 += (half8)a0.s0 * b0;
701 c1 += (half8)a0.s1 * b0;
702 c2 += (half8)a0.s2 * b0;
703 c3 += (half8)a0.s3 * b0;
706 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * V0, src_addr_b += 8 * H0)
709 half4 a0 = vload4(0, src_addr_a);
710 half8 b0 = vload8(0, src_addr_b);
712 c0 += (half8)a0.s0 * b0;
713 c1 += (half8)a0.s1 * b0;
714 c2 += (half8)a0.s2 * b0;
715 c3 += (half8)a0.s3 * b0;
722 __global uchar *dst_addr =
offset(&dst, 0, 0);
726 #if defined(REINTERPRET_OUTPUT_AS_3D) 743 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
744 zout = min(DEPTH_GEMM3D - 1, zout);
747 zout *= (cross_plane_pad * dst_stride_y);
751 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
752 #else // defined(REINTERPRET_OUTPUT_AS_3D) 754 dst_addr += z * dst_stride_z;
755 #endif // defined(REINTERPRET_OUTPUT_AS_3D) 760 #endif // defined(ALPHA) 766 #if defined(BROADCAST_BIAS) 767 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 *
sizeof(
half));
778 #else // defined(BROADCAST_BIAS) 780 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 *
sizeof(
half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
792 #endif // defined(BROADCAST_BIAS) 793 #endif // defined(BETA) 795 #if defined(ACTIVATION_TYPE) 797 #endif // defined(ACTIVATION_TYPE) 800 const bool cond_y = ((get_global_id(1) + 1) * 4 >=
M);
801 const bool cond_x = ((get_global_id(0) + 1) * 8 >=
N);
802 STORE_BLOCK_BOUNDARY_AWARE(4, 8,
half, c, dst_addr, dst_stride_y, zout.s,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
868 #
if defined(REINTERPRET_OUTPUT_AS_3D)
874 int x = get_global_id(0) / H0;
875 int y = get_global_id(1) / V0;
876 int z = get_global_id(2);
879 const int offset_row_a = (get_global_id(1) % V0) * 4;
880 const int offset_row_b = (get_global_id(0) % H0) * 8;
884 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
885 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
887 #if defined(MATRIX_B_DEPTH) 889 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
890 #else // defined(MATRIX_B_DEPTH) 891 src1_addr_in_bytes += z * src1_stride_z;
892 #endif // defined(MATRIX_B_DEPTH) 894 __global
half *src_addr_a = (__global
half *)(src0_ptr + src0_addr_in_bytes);
895 __global
half *src_addr_b = (__global
half *)(src1_ptr + src1_addr_in_bytes);
898 __global
half *src_end_addr_b = src_addr_b + (src1_stride_y /
sizeof(
half));
900 src_addr_a += offset_row_a;
901 src_addr_b += offset_row_b;
909 for(; src_addr_b <= (src_end_addr_b - (int)(16 * H0)); src_addr_a += 8 * V0, src_addr_b += 16 * H0)
912 float4 a0 = convert_float4(vload4(0, src_addr_a));
913 float8 b0 = convert_float8(vload8(0, src_addr_b));
915 c0 += (float8)a0.s0 * b0;
916 c1 += (float8)a0.s1 * b0;
917 c2 += (float8)a0.s2 * b0;
918 c3 += (float8)a0.s3 * b0;
921 a0 = convert_float4(vload4(0, src_addr_a + 4 * V0));
922 b0 = convert_float8(vload8(0, src_addr_b + 8 * H0));
924 c0 += (float8)a0.s0 * b0;
925 c1 += (float8)a0.s1 * b0;
926 c2 += (float8)a0.s2 * b0;
927 c3 += (float8)a0.s3 * b0;
930 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * V0, src_addr_b += 8 * H0)
933 float4 a0 = convert_float4(vload4(0, src_addr_a));
934 float8 b0 = convert_float8(vload8(0, src_addr_b));
936 c0 += (float8)a0.s0 * b0;
937 c1 += (float8)a0.s1 * b0;
938 c2 += (float8)a0.s2 * b0;
939 c3 += (float8)a0.s3 * b0;
946 __global uchar *dst_addr =
offset(&dst, 0, 0);
950 #if defined(REINTERPRET_OUTPUT_AS_3D) 967 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
968 zout = min(DEPTH_GEMM3D - 1, zout);
971 zout *= (cross_plane_pad * dst_stride_y);
975 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
976 #else // defined(REINTERPRET_OUTPUT_AS_3D) 978 dst_addr += z * dst_stride_z;
979 #endif // defined(REINTERPRET_OUTPUT_AS_3D) 984 #endif // defined(ALPHA) 989 #if defined(BROADCAST_BIAS) 990 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 *
sizeof(
half));
994 float8 bias_f0 = convert_float8(bias0);
1003 #else // defined(BROADCAST_BIAS) 1004 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 *
sizeof(
half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
1009 float8 bias_f0 = convert_float8(bias0);
1010 float8 bias_f1 = convert_float8(bias1);
1011 float8 bias_f2 = convert_float8(bias2);
1012 float8 bias_f3 = convert_float8(bias3);
1021 #endif // defined(BROADCAST_BIAS) 1022 #endif // defined(BETA) 1024 half8 c_h0 = convert_half8(c0);
1025 half8 c_h1 = convert_half8(c1);
1026 half8 c_h2 = convert_half8(c2);
1027 half8 c_h3 = convert_half8(c3);
1029 #if defined(ACTIVATION_TYPE) 1031 #endif // defined(ACTIVATION_TYPE) 1034 const bool cond_y = ((get_global_id(1) + 1) * 4 >=
M);
1035 const bool cond_x = ((get_global_id(0) + 1) * 8 >=
N);
1036 STORE_BLOCK_BOUNDARY_AWARE(4, 8,
half, c_h, dst_addr, dst_stride_y, zout.s,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
1089 __kernel
void gemm_mm_interleaved_transposed_f16_bifrost(
IMAGE_DECLARATION(src0),
1101 #
if defined(REINTERPRET_OUTPUT_AS_3D)
1103 uint cross_plane_pad
1107 int x = get_global_id(0) / H0;
1108 int y = get_global_id(1) / V0;
1109 int z = get_global_id(2);
1112 const int offset_row_a = (get_global_id(1) % V0) * 4;
1113 const int offset_row_b = (get_global_id(0) % H0) * 8;
1117 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
1118 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
1120 #if defined(MATRIX_B_DEPTH) 1122 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
1123 #else // defined(MATRIX_B_DEPTH) 1124 src1_addr_in_bytes += z * src1_stride_z;
1125 #endif // defined(MATRIX_B_DEPTH) 1127 __global
half *src_addr_a = (__global
half *)(src0_ptr + src0_addr_in_bytes);
1128 __global
half *src_addr_b = (__global
half *)(src1_ptr + src1_addr_in_bytes);
1130 src_addr_a += offset_row_a;
1131 src_addr_b += offset_row_b;
1140 for(; i <= (int)(
K - 4); i += 4)
1144 half8 a0 = vload8(0, src_addr_a);
1145 half8 b0 = vload8(0, src_addr_b);
1147 src_addr_a += 8 * V0;
1148 src_addr_b += 8 * H0;
1150 c0 =
fma((half8)a0.s0, b0, c0);
1151 c1 =
fma((half8)a0.s1, b0, c1);
1152 c2 =
fma((half8)a0.s2, b0, c2);
1153 c3 =
fma((half8)a0.s3, b0, c3);
1156 b0 = vload8(0, src_addr_b);
1158 src_addr_b += 8 * H0;
1160 c0 =
fma((half8)a0.s4, b0, c0);
1161 c1 =
fma((half8)a0.s5, b0, c1);
1162 c2 =
fma((half8)a0.s6, b0, c2);
1163 c3 =
fma((half8)a0.s7, b0, c3);
1166 a0 = vload8(0, src_addr_a);
1167 b0 = vload8(0, src_addr_b);
1169 src_addr_a += 8 * V0;
1170 src_addr_b += 8 * H0;
1172 c0 =
fma((half8)a0.s0, b0, c0);
1173 c1 =
fma((half8)a0.s1, b0, c1);
1174 c2 =
fma((half8)a0.s2, b0, c2);
1175 c3 =
fma((half8)a0.s3, b0, c3);
1178 b0 = vload8(0, src_addr_b);
1180 src_addr_b += 8 * H0;
1182 c0 =
fma((half8)a0.s4, b0, c0);
1183 c1 =
fma((half8)a0.s5, b0, c1);
1184 c2 =
fma((half8)a0.s6, b0, c2);
1185 c3 =
fma((half8)a0.s7, b0, c3);
1188 half4 a0 = vload4(0, src_addr_a);
1189 half8 b0 = vload8(0, src_addr_b);
1191 src_addr_a += 4 * V0;
1192 src_addr_b += 8 * H0;
1194 c0 =
fma((half8)a0.s0, b0, c0);
1195 c1 =
fma((half8)a0.s1, b0, c1);
1196 c2 =
fma((half8)a0.s2, b0, c2);
1197 c3 =
fma((half8)a0.s3, b0, c3);
1200 a0 = vload4(0, src_addr_a);
1201 b0 = vload8(0, src_addr_b);
1203 src_addr_a += 4 * V0;
1204 src_addr_b += 8 * H0;
1206 c0 =
fma((half8)a0.s0, b0, c0);
1207 c1 =
fma((half8)a0.s1, b0, c1);
1208 c2 =
fma((half8)a0.s2, b0, c2);
1209 c3 =
fma((half8)a0.s3, b0, c3);
1212 a0 = vload4(0, src_addr_a);
1213 b0 = vload8(0, src_addr_b);
1215 src_addr_a += 4 * V0;
1216 src_addr_b += 8 * H0;
1218 c0 =
fma((half8)a0.s0, b0, c0);
1219 c1 =
fma((half8)a0.s1, b0, c1);
1220 c2 =
fma((half8)a0.s2, b0, c2);
1221 c3 =
fma((half8)a0.s3, b0, c3);
1224 a0 = vload4(0, src_addr_a);
1225 b0 = vload8(0, src_addr_b);
1227 src_addr_a += 4 * V0;
1228 src_addr_b += 8 * H0;
1230 c0 =
fma((half8)a0.s0, b0, c0);
1231 c1 =
fma((half8)a0.s1, b0, c1);
1232 c2 =
fma((half8)a0.s2, b0, c2);
1233 c3 =
fma((half8)a0.s3, b0, c3);
1237 for(; i < (int)
K; ++i)
1240 half4 a0 = vload4(0, src_addr_a);
1241 half8 b0 = vload8(0, src_addr_b);
1243 src_addr_a += 4 * V0;
1244 src_addr_b += 8 * H0;
1246 c0 =
fma((half8)a0.s0, b0, c0);
1247 c1 =
fma((half8)a0.s1, b0, c1);
1248 c2 =
fma((half8)a0.s2, b0, c2);
1249 c3 =
fma((half8)a0.s3, b0, c3);
1256 __global uchar *dst_addr =
offset(&dst, 0, 0);
1260 #if defined(REINTERPRET_OUTPUT_AS_3D) 1277 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
1278 zout = min(DEPTH_GEMM3D - 1, zout);
1281 zout *= (cross_plane_pad * dst_stride_y);
1285 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1286 #else // defined(REINTERPRET_OUTPUT_AS_3D) 1288 dst_addr += z * dst_stride_z;
1289 #endif // defined(REINTERPRET_OUTPUT_AS_3D) 1294 #endif // defined(ALPHA) 1300 #if defined(BROADCAST_BIAS) 1301 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 *
sizeof(
half));
1312 #else // defined(BROADCAST_BIAS) 1313 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 *
sizeof(
half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
1325 #endif // defined(BROADCAST_BIAS) 1326 #endif // defined(BETA) 1328 #if defined(ACTIVATION_TYPE) 1330 #endif // defined(ACTIVATION_TYPE) 1333 const bool cond_y = ((get_global_id(1) + 1) * 4 >=
M);
1334 const bool cond_x = ((get_global_id(0) + 1) * 8 >=
N);
1335 STORE_BLOCK_BOUNDARY_AWARE(4, 8,
half, c, dst_addr, dst_stride_y, zout.s,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
1338 #endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) 1340 #endif // defined(M) && defined(N) && defined(K) && defined(H0) && defined(V0) && defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0) 1342 #if defined(N) && defined(K) && defined(M0) && defined(N0) && defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0) 1343 #if defined(DATA_TYPE) 1344 #define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, N0) 1409 #
if defined(REINTERPRET_INPUT_AS_3D)
1411 uint src_cross_plane_pad
1413 #
if defined(REINTERPRET_OUTPUT_AS_3D)
1415 uint dst_cross_plane_pad
1419 int idx = get_global_id(0) * N0;
1422 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1430 #if defined(REINTERPRET_INPUT_AS_3D) 1448 zin = min(DEPTH_GEMM3D - 1, zin);
1451 zin *= (src_cross_plane_pad * src0_stride_y);
1455 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
1457 #else // defined(REINTERPRET_INPUT_AS_3D) 1460 src_addr.s0 += get_global_id(2) * src0_stride_z;
1462 #endif // defined(REINTERPRET_INPUT_AS_3D) 1464 #if defined(MATRIX_B_DEPTH) 1466 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1467 #else // defined(MATRIX_B_DEPTH) 1468 src_addr.s1 += get_global_id(2) * src1_stride_z;
1469 #endif // defined(MATRIX_B_DEPTH) 1471 int end_row_vec_a = src_addr.s0 + (
K *
sizeof(
DATA_TYPE));
1473 VECTOR_TYPE acc0 = 0.0f;
1475 VECTOR_TYPE acc1 = 0.0f;
1478 VECTOR_TYPE acc2 = 0.0f;
1481 VECTOR_TYPE acc3 = 0.0f;
1484 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)
sizeof(
DATA_TYPE)); src_addr += (int2)(2 *
sizeof(
DATA_TYPE), 2 * src1_stride_y))
1486 #if defined(REINTERPRET_INPUT_AS_3D) 1489 #else // defined(REINTERPRET_INPUT_AS_3D) 1492 a0 = vload2(0, (__global
DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1495 a1 = vload2(0, (__global
DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1499 a2 = vload2(0, (__global
DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1503 a3 = vload2(0, (__global
DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1505 #endif // defined(REINTERPRET_INPUT_AS_3D) 1508 VECTOR_TYPE b0 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(src1_ptr + src_addr.s1));
1509 VECTOR_TYPE b1 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
1512 acc0 += b0 * (VECTOR_TYPE)a0.s0;
1513 acc0 += b1 * (VECTOR_TYPE)a0.s1;
1515 acc1 += b0 * (VECTOR_TYPE)a1.s0;
1516 acc1 += b1 * (VECTOR_TYPE)a1.s1;
1519 acc2 += b0 * (VECTOR_TYPE)a2.s0;
1520 acc2 += b1 * (VECTOR_TYPE)a2.s1;
1523 acc3 += b0 * (VECTOR_TYPE)a3.s0;
1524 acc3 += b1 * (VECTOR_TYPE)a3.s1;
1528 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(
sizeof(
DATA_TYPE), src1_stride_y))
1530 #if defined(REINTERPRET_INPUT_AS_3D) 1532 DATA_TYPE a0 = *((__global
DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
1534 DATA_TYPE a1 = *((__global
DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
1537 DATA_TYPE a2 = *((__global
DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
1540 DATA_TYPE a3 = *((__global
DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
1542 #else // defined(REINTERPRET_INPUT_AS_3D) 1544 DATA_TYPE a0 = *((__global
DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1546 DATA_TYPE a1 = *((__global
DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1549 DATA_TYPE a2 = *((__global
DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1552 DATA_TYPE a3 = *((__global
DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1554 #endif // defined(REINTERPRET_INPUT_AS_3D) 1557 VECTOR_TYPE b0 =
VLOAD(N0)(0, (__global
DATA_TYPE *)(src1_ptr + src_addr.s1));
1560 acc0 += b0 * (VECTOR_TYPE)a0;
1562 acc1 += b0 * (VECTOR_TYPE)a1;
1565 acc2 += b0 * (VECTOR_TYPE)a2;
1568 acc3 += b0 * (VECTOR_TYPE)a3;
1572 int z = get_global_id(2);
1575 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(
DATA_TYPE)) + (
COMPUTE_M0_START_ROW(get_global_id(1), M0,
1581 #if defined(REINTERPRET_OUTPUT_AS_3D) 1600 zout = min(DEPTH_GEMM3D - 1, zout);
1603 zout *= (dst_cross_plane_pad * dst_stride_y);
1607 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1608 #else // defined(REINTERPRET_OUTPUT_AS_3D) 1610 dst_addr += z * dst_stride_z;
1611 #endif // defined(REINTERPRET_OUTPUT_AS_3D) 1616 #endif // defined(ALPHA) 1622 #if defined(BROADCAST_BIAS) 1623 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(
DATA_TYPE));
1634 #else // defined(BROADCAST_BIAS) 1635 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 *
sizeof(
DATA_TYPE)) + (
COMPUTE_M0_START_ROW(get_global_id(1), M0,
1638 + z * src2_stride_z;
1649 #endif // defined(BROADCAST_BIAS) 1650 #endif // defined(BETA) 1652 #if defined(ACTIVATION_TYPE) 1654 #endif // defined(ACTIVATION_TYPE) 1657 const bool cond_y = get_global_id(1) == 0;
1658 const bool cond_x = ((get_global_id(0) + 1) * N0 >=
N);
1659 STORE_BLOCK_BOUNDARY_AWARE(M0, N0,
DATA_TYPE, acc, dst_addr, dst_stride_y, zout.s,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
1661 #endif // defined(DATA_TYPE) 1727 #
if defined(REINTERPRET_INPUT_AS_3D)
1729 uint src_cross_plane_pad
1731 #
if defined(REINTERPRET_OUTPUT_AS_3D)
1733 uint dst_cross_plane_pad
1737 int idx = get_global_id(0) * N0;
1740 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1746 src_addr.s1 += idx *
sizeof(float);
1748 #if defined(REINTERPRET_INPUT_AS_3D) 1766 zin = min(DEPTH_GEMM3D - 1, zin);
1769 zin *= (src_cross_plane_pad * src0_stride_y);
1773 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
1775 #else // defined(REINTERPRET_INPUT_AS_3D) 1778 src_addr.s0 += get_global_id(2) * src0_stride_z;
1780 #endif // defined(REINTERPRET_INPUT_AS_3D) 1782 #if defined(MATRIX_B_DEPTH) 1784 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1785 #else // defined(MATRIX_B_DEPTH) 1786 src_addr.s1 += get_global_id(2) * src1_stride_z;
1787 #endif // defined(MATRIX_B_DEPTH) 1806 for(; i <= ((int)
K - 4); i += 4)
1808 #if defined(REINTERPRET_INPUT_AS_3D) 1810 LOAD_BLOCK(M0, 4,
float, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
1811 #else // defined(REINTERPRET_INPUT_AS_3D) 1813 float4 a0 = vload4(0, (__global
float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1815 float4 a1 = vload4(0, (__global
float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1818 float4 a2 = vload4(0, (__global
float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1821 float4 a3 = vload4(0, (__global
float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1823 #endif // defined(REINTERPRET_INPUT_AS_3D) 1825 float4 b0 = vload4(0, (__global
float *)(src1_ptr + src_addr.s1));
1826 src_addr.s1 += src1_stride_y;
1829 acc0.s0 =
fma(a0.s0, b0.s0, acc0.s0);
1830 acc0.s1 =
fma(a0.s0, b0.s1, acc0.s1);
1831 acc0.s2 =
fma(a0.s0, b0.s2, acc0.s2);
1832 acc0.s3 =
fma(a0.s0, b0.s3, acc0.s3);
1836 acc1.s0 =
fma(a1.s0, b0.s0, acc1.s0);
1837 acc1.s1 =
fma(a1.s0, b0.s1, acc1.s1);
1838 acc1.s2 =
fma(a1.s0, b0.s2, acc1.s2);
1839 acc1.s3 =
fma(a1.s0, b0.s3, acc1.s3);
1844 acc2.s0 =
fma(a2.s0, b0.s0, acc2.s0);
1845 acc2.s1 =
fma(a2.s0, b0.s1, acc2.s1);
1846 acc2.s2 =
fma(a2.s0, b0.s2, acc2.s2);
1847 acc2.s3 =
fma(a2.s0, b0.s3, acc2.s3);
1852 acc3.s0 =
fma(a3.s0, b0.s0, acc3.s0);
1853 acc3.s1 =
fma(a3.s0, b0.s1, acc3.s1);
1854 acc3.s2 =
fma(a3.s0, b0.s2, acc3.s2);
1855 acc3.s3 =
fma(a3.s0, b0.s3, acc3.s3);
1859 b0 = vload4(0, (__global
float *)(src1_ptr + src_addr.s1));
1860 src_addr.s1 += src1_stride_y;
1863 acc0.s0 =
fma(a0.s1, b0.s0, acc0.s0);
1864 acc0.s1 =
fma(a0.s1, b0.s1, acc0.s1);
1865 acc0.s2 =
fma(a0.s1, b0.s2, acc0.s2);
1866 acc0.s3 =
fma(a0.s1, b0.s3, acc0.s3);
1870 acc1.s0 =
fma(a1.s1, b0.s0, acc1.s0);
1871 acc1.s1 =
fma(a1.s1, b0.s1, acc1.s1);
1872 acc1.s2 =
fma(a1.s1, b0.s2, acc1.s2);
1873 acc1.s3 =
fma(a1.s1, b0.s3, acc1.s3);
1878 acc2.s0 =
fma(a2.s1, b0.s0, acc2.s0);
1879 acc2.s1 =
fma(a2.s1, b0.s1, acc2.s1);
1880 acc2.s2 =
fma(a2.s1, b0.s2, acc2.s2);
1881 acc2.s3 =
fma(a2.s1, b0.s3, acc2.s3);
1886 acc3.s0 =
fma(a3.s1, b0.s0, acc3.s0);
1887 acc3.s1 =
fma(a3.s1, b0.s1, acc3.s1);
1888 acc3.s2 =
fma(a3.s1, b0.s2, acc3.s2);
1889 acc3.s3 =
fma(a3.s1, b0.s3, acc3.s3);
1893 b0 = vload4(0, (__global
float *)(src1_ptr + src_addr.s1));
1894 src_addr.s1 += src1_stride_y;
1897 acc0.s0 =
fma(a0.s2, b0.s0, acc0.s0);
1898 acc0.s1 =
fma(a0.s2, b0.s1, acc0.s1);
1899 acc0.s2 =
fma(a0.s2, b0.s2, acc0.s2);
1900 acc0.s3 =
fma(a0.s2, b0.s3, acc0.s3);
1904 acc1.s0 =
fma(a1.s2, b0.s0, acc1.s0);
1905 acc1.s1 =
fma(a1.s2, b0.s1, acc1.s1);
1906 acc1.s2 =
fma(a1.s2, b0.s2, acc1.s2);
1907 acc1.s3 =
fma(a1.s2, b0.s3, acc1.s3);
1912 acc2.s0 =
fma(a2.s2, b0.s0, acc2.s0);
1913 acc2.s1 =
fma(a2.s2, b0.s1, acc2.s1);
1914 acc2.s2 =
fma(a2.s2, b0.s2, acc2.s2);
1915 acc2.s3 =
fma(a2.s2, b0.s3, acc2.s3);
1920 acc3.s0 =
fma(a3.s2, b0.s0, acc3.s0);
1921 acc3.s1 =
fma(a3.s2, b0.s1, acc3.s1);
1922 acc3.s2 =
fma(a3.s2, b0.s2, acc3.s2);
1923 acc3.s3 =
fma(a3.s2, b0.s3, acc3.s3);
1927 b0 = vload4(0, (__global
float *)(src1_ptr + src_addr.s1));
1928 src_addr.s1 += src1_stride_y;
1931 acc0.s0 =
fma(a0.s3, b0.s0, acc0.s0);
1932 acc0.s1 =
fma(a0.s3, b0.s1, acc0.s1);
1933 acc0.s2 =
fma(a0.s3, b0.s2, acc0.s2);
1934 acc0.s3 =
fma(a0.s3, b0.s3, acc0.s3);
1938 acc1.s0 =
fma(a1.s3, b0.s0, acc1.s0);
1939 acc1.s1 =
fma(a1.s3, b0.s1, acc1.s1);
1940 acc1.s2 =
fma(a1.s3, b0.s2, acc1.s2);
1941 acc1.s3 =
fma(a1.s3, b0.s3, acc1.s3);
1946 acc2.s0 =
fma(a2.s3, b0.s0, acc2.s0);
1947 acc2.s1 =
fma(a2.s3, b0.s1, acc2.s1);
1948 acc2.s2 =
fma(a2.s3, b0.s2, acc2.s2);
1949 acc2.s3 =
fma(a2.s3, b0.s3, acc2.s3);
1954 acc3.s0 =
fma(a3.s3, b0.s0, acc3.s0);
1955 acc3.s1 =
fma(a3.s3, b0.s1, acc3.s1);
1956 acc3.s2 =
fma(a3.s3, b0.s2, acc3.s2);
1957 acc3.s3 =
fma(a3.s3, b0.s3, acc3.s3);
1960 src_addr.s0 += 4 *
sizeof(float);
1963 for(; i < (int)
K; ++i)
1965 #if defined(REINTERPRET_INPUT_AS_3D) 1967 float a0 = *((__global
float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
1969 float a1 = *((__global
float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
1972 float a2 = *((__global
float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
1975 float a3 = *((__global
float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
1977 #else // defined(REINTERPRET_INPUT_AS_3D) 1979 float a0 = *((__global
float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1981 float a1 = *((__global
float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1984 float a2 = *((__global
float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1987 float a3 = *((__global
float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1989 #endif // defined(REINTERPRET_INPUT_AS_3D) 1992 float4 b0 = vload4(0, (__global
float *)(src1_ptr + src_addr.s1));
1993 src_addr.s1 += src1_stride_y;
1996 acc0.s0 =
fma(a0, b0.s0, acc0.s0);
1997 acc0.s1 =
fma(a0, b0.s1, acc0.s1);
1998 acc0.s2 =
fma(a0, b0.s2, acc0.s2);
1999 acc0.s3 =
fma(a0, b0.s3, acc0.s3);
2001 acc1.s0 =
fma(a1, b0.s0, acc1.s0);
2002 acc1.s1 =
fma(a1, b0.s1, acc1.s1);
2003 acc1.s2 =
fma(a1, b0.s2, acc1.s2);
2004 acc1.s3 =
fma(a1, b0.s3, acc1.s3);
2007 acc2.s0 =
fma(a2, b0.s0, acc2.s0);
2008 acc2.s1 =
fma(a2, b0.s1, acc2.s1);
2009 acc2.s2 =
fma(a2, b0.s2, acc2.s2);
2010 acc2.s3 =
fma(a2, b0.s3, acc2.s3);
2013 acc3.s0 =
fma(a3, b0.s0, acc3.s0);
2014 acc3.s1 =
fma(a3, b0.s1, acc3.s1);
2015 acc3.s2 =
fma(a3, b0.s2, acc3.s2);
2016 acc3.s3 =
fma(a3, b0.s3, acc3.s3);
2019 src_addr.s0 +=
sizeof(float);
2022 int z = get_global_id(2);
2025 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 *
sizeof(
float)) + (
COMPUTE_M0_START_ROW(get_global_id(1), M0,
2030 #if defined(REINTERPRET_OUTPUT_AS_3D) 2048 zout = min(DEPTH_GEMM3D - 1, zout);
2051 zout *= (dst_cross_plane_pad * dst_stride_y);
2055 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2056 #else // defined(REINTERPRET_OUTPUT_AS_3D) 2058 dst_addr += z * dst_stride_z;
2059 #endif // defined(REINTERPRET_OUTPUT_AS_3D) 2064 #endif // defined(ALPHA) 2070 #if defined(BROADCAST_BIAS) 2071 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 *
sizeof(
float));
2073 LOAD_BLOCK(1, 4,
float, bias, src2_addr, 0, src2_stride_y, zero);
2082 #else // defined(BROADCAST_BIAS) 2083 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 *
sizeof(
float)) + (
COMPUTE_M0_START_ROW(get_global_id(1), M0,
2086 + z * src2_stride_z;
2088 LOAD_BLOCK(M0, 4,
float, bias, src2_addr, 0, src2_stride_y, zero);
2097 #endif // defined(BROADCAST_BIAS) 2098 #endif // defined(BETA) 2100 #if defined(ACTIVATION_TYPE) 2102 #endif // defined(ACTIVATION_TYPE) 2105 const bool cond_y = get_global_id(1) == 0;
2106 const bool cond_x = ((get_global_id(0) + 1) * 4 >=
N);
2107 STORE_BLOCK_BOUNDARY_AWARE(M0, 4,
float, acc, dst_addr, dst_stride_y, zout.s,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
2175 #
if defined(REINTERPRET_INPUT_AS_3D)
2177 uint src_cross_plane_pad
2179 #
if defined(REINTERPRET_OUTPUT_AS_3D)
2181 uint dst_cross_plane_pad
2186 int idx = get_global_id(0) * N0;
2189 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
2195 src_addr.s1 += idx *
sizeof(float);
2197 #if defined(REINTERPRET_INPUT_AS_3D) 2215 zin = min(DEPTH_GEMM3D - 1, zin);
2218 zin *= (src_cross_plane_pad * src0_stride_y);
2222 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
2224 #else // defined(REINTERPRET_INPUT_AS_3D) 2227 src_addr.s0 += get_global_id(2) * src0_stride_z;
2229 #endif // defined(REINTERPRET_INPUT_AS_3D) 2231 #if defined(MATRIX_B_DEPTH) 2233 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2234 #else // defined(MATRIX_B_DEPTH) 2235 src_addr.s1 += get_global_id(2) * src1_stride_z;
2236 #endif // defined(MATRIX_B_DEPTH) 2252 for(; i <= ((int)
K - 8); i += 8)
2254 #if defined(REINTERPRET_INPUT_AS_3D) 2256 float8 a0 = vload8(0, (__global
float *)(src0_ptr + src_addr.s0 + zin.s0));
2257 #else // defined(REINTERPRET_INPUT_AS_3D) 2259 float8 a0 = vload8(0, (__global
float *)(src0_ptr + src_addr.s0));
2260 #endif // defined(REINTERPRET_INPUT_AS_3D) 2263 float2 b0 = vload2(0, (__global
float *)(src1_ptr + src_addr.s1));
2264 src_addr.s1 += src1_stride_y;
2265 float2 b1 = vload2(0, (__global
float *)(src1_ptr + src_addr.s1));
2266 src_addr.s1 += src1_stride_y;
2267 float2 b2 = vload2(0, (__global
float *)(src1_ptr + src_addr.s1));
2268 src_addr.s1 += src1_stride_y;
2269 float2 b3 = vload2(0, (__global
float *)(src1_ptr + src_addr.s1));
2270 src_addr.s1 += src1_stride_y;
2271 float2 b4 = vload2(0, (__global
float *)(src1_ptr + src_addr.s1));
2272 src_addr.s1 += src1_stride_y;
2273 float2 b5 = vload2(0, (__global
float *)(src1_ptr + src_addr.s1));
2274 src_addr.s1 += src1_stride_y;
2275 float2 b6 = vload2(0, (__global
float *)(src1_ptr + src_addr.s1));
2276 src_addr.s1 += src1_stride_y;
2277 float2 b7 = vload2(0, (__global
float *)(src1_ptr + src_addr.s1));
2278 src_addr.s1 += src1_stride_y;
2281 acc0.s0 =
fma(a0.s0, b0.s0, acc0.s0);
2282 acc0.s0 =
fma(a0.s1, b1.s0, acc0.s0);
2283 acc0.s0 =
fma(a0.s2, b2.s0, acc0.s0);
2284 acc0.s0 =
fma(a0.s3, b3.s0, acc0.s0);
2285 acc0.s0 =
fma(a0.s4, b4.s0, acc0.s0);
2286 acc0.s0 =
fma(a0.s5, b5.s0, acc0.s0);
2287 acc0.s0 =
fma(a0.s6, b6.s0, acc0.s0);
2288 acc0.s0 =
fma(a0.s7, b7.s0, acc0.s0);
2290 acc0.s1 =
fma(a0.s0, b0.s1, acc0.s1);
2291 acc0.s1 =
fma(a0.s1, b1.s1, acc0.s1);
2292 acc0.s1 =
fma(a0.s2, b2.s1, acc0.s1);
2293 acc0.s1 =
fma(a0.s3, b3.s1, acc0.s1);
2294 acc0.s1 =
fma(a0.s4, b4.s1, acc0.s1);
2295 acc0.s1 =
fma(a0.s5, b5.s1, acc0.s1);
2296 acc0.s1 =
fma(a0.s6, b6.s1, acc0.s1);
2297 acc0.s1 =
fma(a0.s7, b7.s1, acc0.s1);
2300 #if defined(REINTERPRET_INPUT_AS_3D) 2301 a0 = vload8(0, (__global
float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2302 #else // defined(REINTERPRET_INPUT_AS_3D) 2303 a0 = vload8(0, (__global
float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2304 #endif // defined(REINTERPRET_INPUT_AS_3D) 2305 acc1.s0 =
fma(a0.s0, b0.s0, acc1.s0);
2306 acc1.s0 =
fma(a0.s1, b1.s0, acc1.s0);
2307 acc1.s0 =
fma(a0.s2, b2.s0, acc1.s0);
2308 acc1.s0 =
fma(a0.s3, b3.s0, acc1.s0);
2309 acc1.s0 =
fma(a0.s4, b4.s0, acc1.s0);
2310 acc1.s0 =
fma(a0.s5, b5.s0, acc1.s0);
2311 acc1.s0 =
fma(a0.s6, b6.s0, acc1.s0);
2312 acc1.s0 =
fma(a0.s7, b7.s0, acc1.s0);
2314 acc1.s1 =
fma(a0.s0, b0.s1, acc1.s1);
2315 acc1.s1 =
fma(a0.s1, b1.s1, acc1.s1);
2316 acc1.s1 =
fma(a0.s2, b2.s1, acc1.s1);
2317 acc1.s1 =
fma(a0.s3, b3.s1, acc1.s1);
2318 acc1.s1 =
fma(a0.s4, b4.s1, acc1.s1);
2319 acc1.s1 =
fma(a0.s5, b5.s1, acc1.s1);
2320 acc1.s1 =
fma(a0.s6, b6.s1, acc1.s1);
2321 acc1.s1 =
fma(a0.s7, b7.s1, acc1.s1);
2324 #if defined(REINTERPRET_INPUT_AS_3D) 2325 a0 = vload8(0, (__global
float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2326 #else // defined(REINTERPRET_INPUT_AS_3D) 2327 a0 = vload8(0, (__global
float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2328 #endif // defined(REINTERPRET_INPUT_AS_3D) 2329 acc2.s0 =
fma(a0.s0, b0.s0, acc2.s0);
2330 acc2.s0 =
fma(a0.s1, b1.s0, acc2.s0);
2331 acc2.s0 =
fma(a0.s2, b2.s0, acc2.s0);
2332 acc2.s0 =
fma(a0.s3, b3.s0, acc2.s0);
2333 acc2.s0 =
fma(a0.s4, b4.s0, acc2.s0);
2334 acc2.s0 =
fma(a0.s5, b5.s0, acc2.s0);
2335 acc2.s0 =
fma(a0.s6, b6.s0, acc2.s0);
2336 acc2.s0 =
fma(a0.s7, b7.s0, acc2.s0);
2338 acc2.s1 =
fma(a0.s0, b0.s1, acc2.s1);
2339 acc2.s1 =
fma(a0.s1, b1.s1, acc2.s1);
2340 acc2.s1 =
fma(a0.s2, b2.s1, acc2.s1);
2341 acc2.s1 =
fma(a0.s3, b3.s1, acc2.s1);
2342 acc2.s1 =
fma(a0.s4, b4.s1, acc2.s1);
2343 acc2.s1 =
fma(a0.s5, b5.s1, acc2.s1);
2344 acc2.s1 =
fma(a0.s6, b6.s1, acc2.s1);
2345 acc2.s1 =
fma(a0.s7, b7.s1, acc2.s1);
2348 #if defined(REINTERPRET_INPUT_AS_3D) 2349 a0 = vload8(0, (__global
float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2350 #else // defined(REINTERPRET_INPUT_AS_3D) 2351 a0 = vload8(0, (__global
float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2352 #endif // defined(REINTERPRET_INPUT_AS_3D) 2353 acc3.s0 =
fma(a0.s0, b0.s0, acc3.s0);
2354 acc3.s0 =
fma(a0.s1, b1.s0, acc3.s0);
2355 acc3.s0 =
fma(a0.s2, b2.s0, acc3.s0);
2356 acc3.s0 =
fma(a0.s3, b3.s0, acc3.s0);
2357 acc3.s0 =
fma(a0.s4, b4.s0, acc3.s0);
2358 acc3.s0 =
fma(a0.s5, b5.s0, acc3.s0);
2359 acc3.s0 =
fma(a0.s6, b6.s0, acc3.s0);
2360 acc3.s0 =
fma(a0.s7, b7.s0, acc3.s0);
2362 acc3.s1 =
fma(a0.s0, b0.s1, acc3.s1);
2363 acc3.s1 =
fma(a0.s1, b1.s1, acc3.s1);
2364 acc3.s1 =
fma(a0.s2, b2.s1, acc3.s1);
2365 acc3.s1 =
fma(a0.s3, b3.s1, acc3.s1);
2366 acc3.s1 =
fma(a0.s4, b4.s1, acc3.s1);
2367 acc3.s1 =
fma(a0.s5, b5.s1, acc3.s1);
2368 acc3.s1 =
fma(a0.s6, b6.s1, acc3.s1);
2369 acc3.s1 =
fma(a0.s7, b7.s1, acc3.s1);
2372 src_addr.s0 +=
sizeof(float) * 8;
2375 for(; i < (int)
K; ++i)
2377 #if defined(REINTERPRET_INPUT_AS_3D) 2379 float a0 = *((__global
float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
2381 float a1 = *((__global
float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2384 float a2 = *((__global
float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2387 float a3 = *((__global
float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2389 #else // defined(REINTERPRET_INPUT_AS_3D) 2391 float a0 = *((__global
float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2393 float a1 = *((__global
float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2396 float a2 = *((__global
float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2399 float a3 = *((__global
float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2401 #endif // defined(REINTERPRET_INPUT_AS_3D) 2404 float2 b0 = vload2(0, (__global
float *)(src1_ptr + src_addr.s1));
2405 src_addr.s1 += src1_stride_y;
2408 acc0.s0 =
fma(a0, b0.s0, acc0.s0);
2409 acc0.s1 =
fma(a0, b0.s1, acc0.s1);
2411 acc1.s0 =
fma(a1, b0.s0, acc1.s0);
2412 acc1.s1 =
fma(a1, b0.s1, acc1.s1);
2415 acc2.s0 =
fma(a2, b0.s0, acc2.s0);
2416 acc2.s1 =
fma(a2, b0.s1, acc2.s1);
2419 acc3.s0 =
fma(a3, b0.s0, acc3.s0);
2420 acc3.s1 =
fma(a3, b0.s1, acc3.s1);
2423 src_addr.s0 +=
sizeof(float);
2426 int z = get_global_id(2);
2429 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 *
sizeof(
float)) + (
COMPUTE_M0_START_ROW(get_global_id(1), M0,
2434 #if defined(REINTERPRET_OUTPUT_AS_3D) 2453 zout = min(DEPTH_GEMM3D - 1, zout);
2456 zout *= (dst_cross_plane_pad * dst_stride_y);
2460 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2461 #else // defined(REINTERPRET_OUTPUT_AS_3D) 2463 dst_addr += z * dst_stride_z;
2464 #endif // defined(REINTERPRET_OUTPUT_AS_3D) 2469 #endif // defined(ALPHA) 2475 #if defined(BROADCAST_BIAS) 2476 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 *
sizeof(
float));
2478 LOAD_BLOCK(1, 2,
float, bias, src2_addr, 0, src2_stride_y, zero);
2487 #else // defined(BROADCAST_BIAS) 2488 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 *
sizeof(
float)) + (
COMPUTE_M0_START_ROW(get_global_id(1), M0,
2491 + z * src2_stride_z;
2493 LOAD_BLOCK(M0, 2,
float, bias, src2_addr, 0, src2_stride_y, zero);
2502 #endif // defined(BROADCAST_BIAS) 2503 #endif // defined(BETA) 2505 #if defined(ACTIVATION_TYPE) 2507 #endif // defined(ACTIVATION_TYPE) 2510 const bool cond_y = get_global_id(1) == 0;
2511 const bool cond_x = ((get_global_id(0) + 1) * 2 >=
N);
2512 STORE_BLOCK_BOUNDARY_AWARE(M0, 2,
float, acc, dst_addr, dst_stride_y, zout.s,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
2515 #if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) 2580 #
if defined(REINTERPRET_INPUT_AS_3D)
2582 uint src_cross_plane_pad
2584 #
if defined(REINTERPRET_OUTPUT_AS_3D)
2586 uint dst_cross_plane_pad
2590 int idx = get_global_id(0) * N0;
2593 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
2599 src_addr.s1 += idx *
sizeof(
half);
2601 #if defined(REINTERPRET_INPUT_AS_3D) 2619 zin = min(DEPTH_GEMM3D - 1, zin);
2622 zin *= (src_cross_plane_pad * src0_stride_y);
2626 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
2628 #else // defined(REINTERPRET_INPUT_AS_3D) 2631 src_addr.s0 += get_global_id(2) * src0_stride_z;
2633 #endif // defined(REINTERPRET_INPUT_AS_3D) 2635 #if defined(MATRIX_B_DEPTH) 2637 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2638 #else // defined(MATRIX_B_DEPTH) 2639 src_addr.s1 += get_global_id(2) * src1_stride_z;
2640 #endif // defined(MATRIX_B_DEPTH) 2654 for(; i <= ((int)
K - 4); i += 4)
2656 #if defined(REINTERPRET_INPUT_AS_3D) 2658 LOAD_BLOCK(M0, 4,
half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
2659 #else // defined(REINTERPRET_INPUT_AS_3D) 2661 half4 a0 = vload4(0, (__global
half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2663 half4 a1 = vload4(0, (__global
half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2666 half4 a2 = vload4(0, (__global
half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2669 half4 a3 = vload4(0, (__global
half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2671 #endif // defined(REINTERPRET_INPUT_AS_3D) 2674 float8 b0 = convert_float8(vload8(0, (__global
half *)(src1_ptr + src_addr.s1)));
2675 src_addr.s1 += src1_stride_y;
2678 acc0 =
fma(b0, (float8)a0.s0, acc0);
2680 acc1 =
fma(b0, (float8)a1.s0, acc1);
2683 acc2 =
fma(b0, (float8)a2.s0, acc2);
2686 acc3 =
fma(b0, (float8)a3.s0, acc3);
2689 b0 = convert_float8(vload8(0, (__global
half *)(src1_ptr + src_addr.s1)));
2690 src_addr.s1 += src1_stride_y;
2691 acc0 =
fma(b0, (float8)a0.s1, acc0);
2693 acc1 =
fma(b0, (float8)a1.s1, acc1);
2696 acc2 =
fma(b0, (float8)a2.s1, acc2);
2699 acc3 =
fma(b0, (float8)a3.s1, acc3);
2702 b0 = convert_float8(vload8(0, (__global
half *)(src1_ptr + src_addr.s1)));
2703 src_addr.s1 += src1_stride_y;
2704 acc0 =
fma(b0, (float8)a0.s2, acc0);
2706 acc1 =
fma(b0, (float8)a1.s2, acc1);
2709 acc2 =
fma(b0, (float8)a2.s2, acc2);
2712 acc3 =
fma(b0, (float8)a3.s2, acc3);
2715 b0 = convert_float8(vload8(0, (__global
half *)(src1_ptr + src_addr.s1)));
2716 src_addr.s1 += src1_stride_y;
2717 acc0 =
fma(b0, (float8)a0.s3, acc0);
2719 acc1 =
fma(b0, (float8)a1.s3, acc1);
2722 acc2 =
fma(b0, (float8)a2.s3, acc2);
2725 acc3 =
fma(b0, (float8)a3.s3, acc3);
2728 src_addr.s0 += 4 *
sizeof(
half);
2731 for(; i < (int)
K; ++i)
2733 #if defined(REINTERPRET_INPUT_AS_3D) 2735 half a0 = *((__global
half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
2737 half a1 = *((__global
half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2740 half a2 = *((__global
half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2743 half a3 = *((__global
half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2745 #else // defined(REINTERPRET_INPUT_AS_3D) 2747 half a0 = *((__global
half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2749 half a1 = *((__global
half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2752 half a2 = *((__global
half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2755 half a3 = *((__global
half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2757 #endif // defined(REINTERPRET_INPUT_AS_3D) 2760 float8 b0 = convert_float8(vload8(0, (__global
half *)(src1_ptr + src_addr.s1)));
2762 src_addr += (int2)(
sizeof(
half), src1_stride_y);
2765 acc0 =
fma(b0, (float8)a0, acc0);
2767 acc1 =
fma(b0, (float8)a1, acc1);
2770 acc2 =
fma(b0, (float8)a2, acc2);
2773 acc3 =
fma(b0, (float8)a3, acc3);
2777 int z = get_global_id(2);
2780 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 *
sizeof(
half)) + (
COMPUTE_M0_START_ROW(get_global_id(1), M0,
PARTIAL_STORE_M0) * dst_stride_y);
2784 #if defined(REINTERPRET_OUTPUT_AS_3D) 2803 zout = min(DEPTH_GEMM3D - 1, zout);
2806 zout *= (dst_cross_plane_pad * dst_stride_y);
2810 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2811 #else // defined(REINTERPRET_OUTPUT_AS_3D) 2813 dst_addr += z * dst_stride_z;
2814 #endif // defined(REINTERPRET_OUTPUT_AS_3D) 2819 #endif // defined(ALPHA) 2824 #if defined(BROADCAST_BIAS) 2825 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 *
sizeof(
half));
2829 float8 bias_f0 = convert_float8(bias0);
2838 #else // defined(BROADCAST_BIAS) 2839 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 *
sizeof(
half)) + (
COMPUTE_M0_START_ROW(get_global_id(1), M0,
2842 + z * src2_stride_z;
2844 LOAD_BLOCK(M0, 8,
half, bias, src2_addr, 0, src2_stride_y, zero);
2846 float8 bias_f0 = convert_float8(bias0);
2848 float8 bias_f1 = convert_float8(bias1);
2851 float8 bias_f2 = convert_float8(bias2);
2854 float8 bias_f3 = convert_float8(bias3);
2864 #endif // defined(BROADCAST_BIAS) 2865 #endif // defined(BETA) 2867 half8 acc_h0 = convert_half8(acc0);
2869 half8 acc_h1 = convert_half8(acc1);
2872 half8 acc_h2 = convert_half8(acc2);
2875 half8 acc_h3 = convert_half8(acc3);
2878 #if defined(ACTIVATION_TYPE) 2880 #endif // defined(ACTIVATION_TYPE) 2883 const bool cond_y = get_global_id(1) == 0;
2884 const bool cond_x = ((get_global_id(0) + 1) * 8 >=
N);
2885 STORE_BLOCK_BOUNDARY_AWARE(M0, 8,
half, acc_h, dst_addr, dst_stride_y, zout.s,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
2952 #
if defined(REINTERPRET_INPUT_AS_3D)
2954 uint src_cross_plane_pad
2956 #
if defined(REINTERPRET_OUTPUT_AS_3D)
2958 uint dst_cross_plane_pad
2962 int idx = get_global_id(0) * N0;
2965 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
2971 src_addr.s1 += idx *
sizeof(
half);
2973 #if defined(REINTERPRET_INPUT_AS_3D) 2991 zin = min(DEPTH_GEMM3D - 1, zin);
2994 zin *= (src_cross_plane_pad * src0_stride_y);
2998 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
3000 #else // defined(REINTERPRET_INPUT_AS_3D) 3003 src_addr.s0 += get_global_id(2) * src0_stride_z;
3005 #endif // defined(REINTERPRET_INPUT_AS_3D) 3007 #if defined(MATRIX_B_DEPTH) 3009 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
3010 #else // defined(MATRIX_B_DEPTH) 3011 src_addr.s1 += get_global_id(2) * src1_stride_z;
3012 #endif // defined(MATRIX_B_DEPTH) 3026 for(; i <= ((int)
K - 4); i += 4)
3028 #if defined(REINTERPRET_INPUT_AS_3D) 3030 LOAD_BLOCK(M0, 4,
half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
3031 #else // defined(REINTERPRET_INPUT_AS_3D) 3033 half4 a0 = vload4(0, (__global
half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
3035 half4 a1 = vload4(0, (__global
half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
3038 half4 a2 = vload4(0, (__global
half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
3041 half4 a3 = vload4(0, (__global
half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
3043 #endif // defined(REINTERPRET_INPUT_AS_3D) 3046 half8 b0 = vload8(0, (__global
half *)(src1_ptr + src_addr.s1));
3047 src_addr.s1 += src1_stride_y;
3050 acc0 =
fma(b0, (half8)a0.s0, acc0);
3052 acc1 =
fma(b0, (half8)a1.s0, acc1);
3055 acc2 =
fma(b0, (half8)a2.s0, acc2);
3058 acc3 =
fma(b0, (half8)a3.s0, acc3);
3061 b0 = vload8(0, (__global
half *)(src1_ptr + src_addr.s1));
3062 src_addr.s1 += src1_stride_y;
3063 acc0 =
fma(b0, (half8)a0.s1, acc0);
3065 acc1 =
fma(b0, (half8)a1.s1, acc1);
3068 acc2 =
fma(b0, (half8)a2.s1, acc2);
3071 acc3 =
fma(b0, (half8)a3.s1, acc3);
3074 b0 = vload8(0, (__global
half *)(src1_ptr + src_addr.s1));
3075 src_addr.s1 += src1_stride_y;
3076 acc0 =
fma(b0, (half8)a0.s2, acc0);
3078 acc1 =
fma(b0, (half8)a1.s2, acc1);
3081 acc2 =
fma(b0, (half8)a2.s2, acc2);
3084 acc3 =
fma(b0, (half8)a3.s2, acc3);
3087 b0 = vload8(0, (__global
half *)(src1_ptr + src_addr.s1));
3088 src_addr.s1 += src1_stride_y;
3089 acc0 =
fma(b0, (half8)a0.s3, acc0);
3091 acc1 =
fma(b0, (half8)a1.s3, acc1);
3094 acc2 =
fma(b0, (half8)a2.s3, acc2);
3097 acc3 =
fma(b0, (half8)a3.s3, acc3);
3100 src_addr.s0 += 4 *
sizeof(
half);
3103 for(; i < (int)
K; ++i)
3105 #if defined(REINTERPRET_INPUT_AS_3D) 3107 half a0 = *((__global
half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
3109 half a1 = *((__global
half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
3112 half a2 = *((__global
half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
3115 half a3 = *((__global
half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
3117 #else // defined(REINTERPRET_INPUT_AS_3D) 3119 half a0 = *((__global
half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
3121 half a1 = *((__global
half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
3124 half a2 = *((__global
half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
3127 half a3 = *((__global
half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
3129 #endif // defined(REINTERPRET_INPUT_AS_3D) 3132 half8 b0 = vload8(0, (__global
half *)(src1_ptr + src_addr.s1));
3134 src_addr += (int2)(
sizeof(
half), src1_stride_y);
3137 acc0 =
fma(b0, (half8)a0, acc0);
3139 acc1 =
fma(b0, (half8)a1, acc1);
3142 acc2 =
fma(b0, (half8)a2, acc2);
3145 acc3 =
fma(b0, (half8)a3, acc3);
3149 int z = get_global_id(2);
3152 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 *
sizeof(
half)) + (
COMPUTE_M0_START_ROW(get_global_id(1), M0,
PARTIAL_STORE_M0) * dst_stride_y);
3156 #if defined(REINTERPRET_OUTPUT_AS_3D) 3175 zout = min(DEPTH_GEMM3D - 1, zout);
3178 zout *= (dst_cross_plane_pad * dst_stride_y);
3182 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3183 #else // defined(REINTERPRET_OUTPUT_AS_3D) 3185 dst_addr += z * dst_stride_z;
3186 #endif // defined(REINTERPRET_OUTPUT_AS_3D) 3191 #endif // defined(ALPHA) 3197 #if defined(BROADCAST_BIAS) 3198 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 *
sizeof(
half));
3209 #else // defined(BROADCAST_BIAS) 3210 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 *
sizeof(
half)) + (
COMPUTE_M0_START_ROW(get_global_id(1), M0,
3213 + z * src2_stride_z;
3215 LOAD_BLOCK(M0, 8,
half, bias, src2_addr, 0, src2_stride_y, zero);
3224 #endif // defined(BROADCAST_BIAS) 3225 #endif // defined(BETA) 3227 #if defined(ACTIVATION_TYPE) 3229 #endif // defined(ACTIVATION_TYPE) 3232 const bool cond_y = get_global_id(1) == 0;
3233 const bool cond_x = ((get_global_id(0) + 1) * 8 >=
N);
3234 STORE_BLOCK_BOUNDARY_AWARE(M0, 8,
half, acc, dst_addr, dst_stride_y, zout.s,
PARTIAL_STORE_M0,
PARTIAL_STORE_N0, cond_y, cond_x);
3236 #endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) 3238 #endif // defined(N) && defined(K) && defined(M0) && defined(N0) && defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
#define ACTIVATION_BLOCK(N, ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)
__global uchar * offset(const Image *img, int x, int y)
Get the pointer position of a Image.
#define REPEAT_VAR_INIT_TO_CONST(N, TYPE, VAR, VAL)
#define CONVERT_TO_IMAGE_STRUCT(name)
half_float::half half
16-bit floating point type
#define IMAGE_DECLARATION(name)
#define ADD_BLOCK_BROADCAST(N, BASENAME, BIAS)
#define ADD_BLOCK(N, BASENAME, BIAS)
#define LOAD_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z)
Structure to hold Image information.
T fma(T x, T y, T z)
Computes (x*y) + z as if to infinite precision and rounded only once to fit the result type...
#define SCALE_BLOCK(N, DATA_TYPE, BASENAME, SCALE)
#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0)
#define VEC_DATA_TYPE(type, size)