// gemm_mm_reshaped_only_rhs_nt_mmul:
// GEMM kernel whose inner product is accumulated with the Arm(R)
// matrix-multiply (MMUL) extension — see arm_matrix_multiply() in the K loop.
// The RHS matrix is pre-reshaped (blocked) and read from a buffer; "nt"
// denotes a non-transposed LHS.
// NOTE(review): this chunk is an elided extraction of the original file — the
// kernel's parameter list, braces, LOOP_UNROLLING bodies and the opening
// ALPHA/BETA/BATCHED_RHS preprocessor lines are not visible here, and the
// integer at the start of each line is residue of the original file's line
// numbering, not part of the logic.
28 #if defined(GEMM_MM_RESHAPED_ONLY_RHS_NT_MMUL)
79 __kernel
void gemm_mm_reshaped_only_rhs_nt_mmul(
// Number of work-items cooperating on one MMUL block.
90 #define MMUL_BLOCK_SIZE (MMUL_N0 * MMUL_K0)
// Global work-item coordinates: x0 spans N blocks, y0 spans M blocks,
// z spans batches.
92 uint x0 = get_global_id(0);
93 uint y0 = get_global_id(1);
94 uint z = get_global_id(2);
// Position of this work-item inside its MMUL block.
// NOTE(review): block_x is thread_id % MMUL_N0 while block_y divides by
// MMUL_M0 — confirm MMUL_N0 == MMUL_M0 or that the asymmetry is intended.
101 uint block_x = thread_id % MMUL_N0;
102 uint block_y = (thread_id / MMUL_M0);
// Destination tile origin, clamped so edge tiles stay inside the output.
// NOTE(review): dst_x clamps to N - 1 but dst_y clamps to M - M0 — verify
// this asymmetry is deliberate (leftover handling differs per axis).
105 uint dst_x = min(block_x * N0 + block_id * MMUL_N0 * N0, (uint)(
N - 1));
106 uint dst_y = min(block_y * M0 + y0 * M0 * MMUL_M0, (uint)(
M - M0));
// Starting coordinates into the LHS and the reshaped RHS.
112 uint lhs_x = block_x;
116 uint rhs_x = block_y * N0 * MMUL_N0 + block_x * N0;
117 uint rhs_y = block_id;
// LHS byte offset; with REINTERPRET_INPUT_AS_3D the batch index z is folded
// into the row coordinate instead of using a separate z-stride.
120 #ifdef REINTERPRET_INPUT_AS_3D
121 lhs_offset_first_element_in_bytes += lhs_x *
sizeof(DATA_TYPE) + (lhs_y + z *
M) * lhs_stride_y;
122 #else // REINTERPRET_INPUT_AS_3D
123 lhs_offset_first_element_in_bytes += lhs_x *
sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
124 #endif // REINTERPRET_INPUT_AS_3D
// RHS byte offset. The two statements below are the BATCHED_RHS (with
// z-stride) and non-batched branches; the opening "#if defined(BATCHED_RHS)"
// and "#else" lines are elided from this extraction — only the closing
// #endif is visible.
127 rhs_offset_first_element_in_bytes += rhs_x *
sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
129 rhs_offset_first_element_in_bytes += rhs_x *
sizeof(DATA_TYPE) + rhs_y * rhs_stride_y;
130 #endif // BATCHED_RHS
// Destination byte offset; same 3D-reinterpretation scheme as the LHS.
132 #ifdef REINTERPRET_OUTPUT_AS_3D
133 dst_offset_first_element_in_bytes += dst_x *
sizeof(DATA_TYPE) + (dst_y + z *
M) * dst_stride_y;
134 #else // REINTERPRET_OUTPUT_AS_3D
135 dst_offset_first_element_in_bytes += dst_x *
sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
136 #endif // REINTERPRET_OUTPUT_AS_3D
// The accumulator tile is always fp32; it is narrowed to DATA_TYPE after the
// K loop when HALF_PRECISION is defined (see below).
143 TILE(
float, M0, N0, c_f32);
145 #if !defined(HALF_PRECISION)
147 #endif // !defined(HALF_PRECISION)
// Main reduction loop: MMUL_K0 elements of K per iteration. The bound
// (k <= K - MMUL_K0) assumes K is a multiple of MMUL_K0 or that leftovers
// are handled elsewhere — TODO confirm against the host-side validation.
154 for(
int k = 0; k <=
K - MMUL_K0; k += MMUL_K0)
156 TILE(DATA_TYPE, M0, 1, a);
157 TILE(DATA_TYPE, 1, N0,
b);
// Load an M0x1 LHS fragment and a 1xN0 RHS fragment. The RHS load passes a
// stride of 0 — presumably because the reshaped RHS row is contiguous;
// verify against the reshape layout.
160 T_LOAD(DATA_TYPE, M0, 1, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
161 T_LOAD(DATA_TYPE, 1, N0, BUFFER, rhs, 0, 0, 1, 0,
b);
// Cooperative multiply-accumulate via the MMUL extension.
167 c_f32[m0].s[n0] = arm_matrix_multiply(a[m0].s[0],
b[0].s[n0], c_f32[m0].s[n0]);
// Advance the LHS by MMUL_K0 columns and the RHS by one reshaped block row.
171 lhs_offset_first_element_in_bytes += MMUL_K0 *
sizeof(DATA_TYPE);
172 rhs_offset_first_element_in_bytes += MMUL_K0 * MMUL_N0 * N0 *
sizeof(DATA_TYPE);
// Early-exit guards for work-items whose tile lies entirely outside the
// output; the guarded bodies are elided in this extraction.
175 if(block_x * N0 + block_id * MMUL_N0 * N0 >=
N)
180 if(block_y * M0 + y0 * M0 * MMUL_M0 >=
M)
// For half-precision output, narrow the fp32 accumulator into tile 'c'.
185 #if defined(HALF_PRECISION)
186 TILE(DATA_TYPE, M0, N0, c);
193 c[m0].s[n0] = c_f32[m0].s[n0];
201 #endif // defined(ALPHA)
// Bias (BETA path). BROADCAST_BIAS: a single 1xN0 row shared by all output
// rows; otherwise a full M0xN0 tile is loaded with per-row edge guards.
205 #if defined(BROADCAST_BIAS)
206 bia_offset_first_element_in_bytes += dst_x *
sizeof(DATA_TYPE);
208 TILE(DATA_TYPE, 1, N0, bias0);
// Fast path: full N0-wide vector load when no right-edge leftover applies.
210 if(dst_x + N0 <=
N || N0_LEFTOVER == 0)
212 bias0[0].v =
VLOAD(N0)(0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes));
217 (bias0[0].v, 0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes));
226 #else // defined(BROADCAST_BIAS)
227 TILE(DATA_TYPE, M0, N0, bias0);
229 bia_offset_first_element_in_bytes += dst_x *
sizeof(DATA_TYPE) + dst_y * bia_stride_y + z * bia_stride_z;
231 if(dst_x + N0 <=
N || N0_LEFTOVER == 0)
// Guard each row against the bottom edge of the matrix.
235 if(dst_y + m0 <
M || M0_LEFTOVER == 0)
237 bias0[m0].v =
VLOAD(N0)(0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes + m0 * bia_stride_y));
245 if(dst_y + m0 <
M || M0_LEFTOVER == 0)
248 (bias0[m0].v, 0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes + m0 * bia_stride_y));
// c += bias0
258 T_ADD(DATA_TYPE, M0, N0, c, bias0, c);
260 #endif // defined(BROADCAST_BIAS)
261 #endif // defined(BETA)
// Fused activation applied in place on the output tile.
263 T_ACTIVATION(DATA_TYPE, M0, N0, ACTIVATION_TYPE, A_VAL, B_VAL, c, c);
// Store the tile: vector path when no right-edge leftover, otherwise the
// (elided) leftover path; both guard each row against the bottom edge.
266 if(dst_x + N0 <=
N || N0_LEFTOVER == 0)
270 if(dst_y + m0 <
M || M0_LEFTOVER == 0)
273 (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
281 if(dst_y + m0 <
M || M0_LEFTOVER == 0)
284 (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
// NOTE(review): the macro defined above is MMUL_BLOCK_SIZE, not
// RHS_BLOCK_SIZE — this #undef looks stale and MMUL_BLOCK_SIZE leaks past
// the end of this kernel.
289 #undef RHS_BLOCK_SIZE
293 #endif // defined(GEMM_MM_RESHAPED_ONLY_RHS_NT_MMUL)
// gemm_mm_reshaped_only_rhs_nt_mmul_texture:
// Texture variant of the MMUL GEMM kernel above: the reshaped RHS is read
// through an OpenCL image (see the T_LOAD ... IMAGE in the K loop) and is
// advanced by image coordinates (rhs_x/rhs_y) rather than a byte offset.
// NOTE(review): this chunk is an elided extraction — the kernel's parameter
// list, braces, LOOP_UNROLLING bodies and the opening ALPHA/BETA/BATCHED_RHS
// preprocessor lines are missing, and the integer at the start of each line
// is residue of the original file's line numbering.
295 #if defined(GEMM_MM_RESHAPED_ONLY_RHS_NT_MMUL_TEXTURE)
346 __kernel
void gemm_mm_reshaped_only_rhs_nt_mmul_texture(
// Number of work-items cooperating on one MMUL block.
357 #define MMUL_BLOCK_SIZE (MMUL_N0 * MMUL_K0)
// Global work-item coordinates: x0 spans N blocks, y0 spans M blocks,
// z spans batches.
359 uint x0 = get_global_id(0);
360 uint y0 = get_global_id(1);
361 uint z = get_global_id(2);
// Position of this work-item inside its MMUL block.
// NOTE(review): block_x uses MMUL_N0 while block_y divides by MMUL_M0 —
// confirm MMUL_N0 == MMUL_M0 or that the asymmetry is intended.
368 uint block_x = thread_id % MMUL_N0;
369 uint block_y = (thread_id / MMUL_M0);
// Destination tile origin, clamped so edge tiles stay inside the output.
// NOTE(review): dst_x clamps to N - 1 but dst_y clamps to M - M0 — verify
// this asymmetry is deliberate.
372 uint dst_x = min(block_x * N0 + block_id * MMUL_N0 * N0, (uint)(
N - 1));
373 uint dst_y = min(block_y * M0 + y0 * M0 * MMUL_M0, (uint)(
M - M0));
// Starting coordinates: lhs_x indexes into the LHS buffer, rhs_x/rhs_y are
// image coordinates into the reshaped RHS texture.
379 uint lhs_x = block_x;
383 uint rhs_x = block_y * N0 * MMUL_N0 + block_x * N0;
// BATCHED_RHS branch folds the batch index z into the image row via rhs_h;
// the opening "#if defined(BATCHED_RHS)" / "#else" lines are elided from
// this extraction — only the closing #endif is visible.
386 uint rhs_y = block_id + z * rhs_h;
388 uint rhs_y = block_id;
389 #endif // BATCHED_RHS
// LHS byte offset; with REINTERPRET_INPUT_AS_3D the batch index z is folded
// into the row coordinate instead of using a separate z-stride.
392 #ifdef REINTERPRET_INPUT_AS_3D
393 lhs_offset_first_element_in_bytes += lhs_x *
sizeof(DATA_TYPE) + (lhs_y + z *
M) * lhs_stride_y;
394 #else // REINTERPRET_INPUT_AS_3D
395 lhs_offset_first_element_in_bytes += lhs_x *
sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
396 #endif // REINTERPRET_INPUT_AS_3D
// Destination byte offset; same 3D-reinterpretation scheme as the LHS.
398 #ifdef REINTERPRET_OUTPUT_AS_3D
399 dst_offset_first_element_in_bytes += dst_x *
sizeof(DATA_TYPE) + (dst_y + z *
M) * dst_stride_y;
400 #else // REINTERPRET_OUTPUT_AS_3D
401 dst_offset_first_element_in_bytes += dst_x *
sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
402 #endif // REINTERPRET_OUTPUT_AS_3D
// The accumulator tile is always fp32; it is narrowed to DATA_TYPE after the
// K loop when HALF_PRECISION is defined (see below).
406 TILE(
float, M0, N0, c_f32);
408 #if !defined(HALF_PRECISION)
410 #endif // !defined(HALF_PRECISION)
// Main reduction loop: MMUL_K0 elements of K per iteration. The bound
// (k <= K - MMUL_K0) assumes K is a multiple of MMUL_K0 or that leftovers
// are handled elsewhere — TODO confirm against the host-side validation.
417 for(
int k = 0; k <=
K - MMUL_K0; k += MMUL_K0)
419 TILE(DATA_TYPE, M0, 1, a);
420 TILE(DATA_TYPE, 1, N0,
b);
// LHS comes from a buffer; the RHS fragment is sampled from the image at
// (rhs_x, rhs_y) — this is the key difference from the buffer kernel above.
423 T_LOAD(DATA_TYPE, M0, 1, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
424 T_LOAD(DATA_TYPE, 1, N0, IMAGE, rhs, rhs_x, rhs_y, 1, rhs_stride_y,
b);
// Cooperative multiply-accumulate via the MMUL extension.
430 c_f32[m0].s[n0] = arm_matrix_multiply(a[m0].s[0],
b[0].s[n0], c_f32[m0].s[n0]);
// Advance the LHS byte offset and the RHS image x-coordinate by one
// reshaped block.
434 lhs_offset_first_element_in_bytes += MMUL_K0 *
sizeof(DATA_TYPE);
435 rhs_x += MMUL_K0 * MMUL_N0 * N0;
// Early-exit guards for work-items whose tile lies entirely outside the
// output; the guarded bodies are elided in this extraction.
438 if(block_x * N0 + block_id * MMUL_N0 * N0 >=
N)
443 if(block_y * M0 + y0 * M0 * MMUL_M0 >=
M)
// For half-precision output, narrow the fp32 accumulator into tile 'c'.
448 #if defined(HALF_PRECISION)
449 TILE(DATA_TYPE, M0, N0, c);
456 c[m0].s[n0] = c_f32[m0].s[n0];
464 #endif // defined(ALPHA)
// Bias (BETA path). BROADCAST_BIAS: a single 1xN0 row shared by all output
// rows; otherwise a full M0xN0 tile is loaded with per-row edge guards.
468 #if defined(BROADCAST_BIAS)
469 bia_offset_first_element_in_bytes += dst_x *
sizeof(DATA_TYPE);
471 TILE(DATA_TYPE, 1, N0, bias0);
// Fast path: full N0-wide vector load when no right-edge leftover applies.
473 if(dst_x + N0 <=
N || N0_LEFTOVER == 0)
475 bias0[0].v =
VLOAD(N0)(0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes));
480 (bias0[0].v, 0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes));
489 #else // defined(BROADCAST_BIAS)
490 TILE(DATA_TYPE, M0, N0, bias0);
492 bia_offset_first_element_in_bytes += dst_x *
sizeof(DATA_TYPE) + dst_y * bia_stride_y + z * bia_stride_z;
494 if(dst_x + N0 <=
N || N0_LEFTOVER == 0)
// Guard each row against the bottom edge of the matrix.
498 if(dst_y + m0 <
M || M0_LEFTOVER == 0)
500 bias0[m0].v =
VLOAD(N0)(0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes + m0 * bia_stride_y));
508 if(dst_y + m0 <
M || M0_LEFTOVER == 0)
511 (bias0[m0].v, 0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes + m0 * bia_stride_y));
// c += bias0
521 T_ADD(DATA_TYPE, M0, N0, c, bias0, c);
523 #endif // defined(BROADCAST_BIAS)
524 #endif // defined(BETA)
// Fused activation applied in place on the output tile.
526 T_ACTIVATION(DATA_TYPE, M0, N0, ACTIVATION_TYPE, A_VAL, B_VAL, c, c);
// Store the tile: vector path when no right-edge leftover, otherwise the
// (elided) leftover path; both guard each row against the bottom edge.
529 if(dst_x + N0 <=
N || N0_LEFTOVER == 0)
533 if(dst_y + m0 <
M || M0_LEFTOVER == 0)
536 (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
544 if(dst_y + m0 <
M || M0_LEFTOVER == 0)
547 (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
// NOTE(review): the macro defined above is MMUL_BLOCK_SIZE, not
// RHS_BLOCK_SIZE — this #undef looks stale and MMUL_BLOCK_SIZE leaks past
// the end of this kernel.
552 #undef RHS_BLOCK_SIZE
556 #endif // defined(GEMM_MM_RESHAPED_ONLY_RHS_NT_MMUL_TEXTURE)