// NOTE(review): This chunk is a garbled extraction of an OpenCL-C GEMMLowp kernel
// (quantized matrix multiply using the Arm matrix-multiply built-in). The original
// file's line numbers are fused into the text, many lines are missing, and several
// statements/preprocessor directives are split mid-token. Code tokens below are
// preserved byte-for-byte; only comments have been added. Consult the full source
// file before editing any logic here.
27 #if defined(GEMMLOWP_MM_RESHAPED_ONLY_RHS_MMUL)
// Kernel entry point: quantized GEMM where only the RHS matrix is reshaped,
// mapped onto the MMUL (cooperative matrix-multiply) instruction.
// Parameter list is not visible in this chunk.
90 __kernel
void gemmlowp_mm_reshaped_only_rhs_mmul(
100 #
if defined(A_OFFSET)
104 #
if defined(B_OFFSET)
// One MMUL block is an MMUL_M0 x MMUL_N0 group of work-items; for int8 types
// each MMUL operand is a length-4 vector (VEC_SIZE).
110 #define MMUL_BLOCK_SIZE (MMUL_N0 * MMUL_M0)
111 #define VEC_SIZE 4 // For int8 types input to mmul instruction is a length 4 vector
// Global work-item coordinates (0: N dimension, 1: M dimension, 2: batch).
113 uint x0 = get_global_id(0);
114 uint y0 = get_global_id(1);
115 uint z = get_global_id(2);
// Position of this work-item inside its MMUL block.
122 uint block_x = thread_id % MMUL_N0;
123 uint block_y = (thread_id / MMUL_M0);
// Output tile origin, clamped via min() so edge tiles stay within N-1 / M-M0.
126 uint dst_x = min(block_x * N0 + block_id * MMUL_N0 * N0, (uint)(
N - 1));
127 uint dst_y = min(block_y * M0 + y0 * M0 * MMUL_M0, (uint)(
M - M0));
// Starting coordinates into the reshaped RHS matrix for this work-item.
132 uint rhs_x =
VEC_SIZE * N0 * block_y;
133 uint rhs_y = 4 * block_id + block_x;
// Advance the base byte offsets of LHS/RHS/DST to this work-item's tile:
// x * element size + y * row stride + z * batch stride.
136 lhs_offset_first_element_in_bytes += lhs_x *
sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
137 rhs_offset_first_element_in_bytes += rhs_x *
sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
138 dst_offset_first_element_in_bytes += dst_x *
sizeof(OUT_DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
// Accumulator tile (M0 x N0) in the wider accumulation type.
140 TILE(ACC_DATA_TYPE, M0, N0, c);
// Main reduction loop over the K dimension in steps of MMUL_K0.
146 for(
int k = 0; k <=
K - MMUL_K0; k += MMUL_K0)
// Load an M0 x VEC_SIZE tile of the LHS for this K-slice (RHS tile load is
// in the lines missing from this chunk).
149 T_LOAD(DATA_TYPE, M0,
VEC_SIZE, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
// Pack 4 scalars into the length-4 vectors consumed by arm_matrix_multiply,
// which multiply-accumulates into the scalar accumulator c[m0].s[n0].
158 VEC_TYPE vec_a = (VEC_TYPE)(a[m0].s[0], a[m0].s[1], a[m0].s[2], a[m0].s[3]);
159 VEC_TYPE vec_b = (VEC_TYPE)(
b[n0].s[0],
b[n0].s[1],
b[n0].s[2],
b[n0].s[3]);
160 c[m0].s[n0] = arm_matrix_multiply(vec_a, vec_b, c[m0].s[n0]);
// Step the LHS/RHS base offsets to the next K-slice.
164 lhs_offset_first_element_in_bytes += MMUL_K0 *
sizeof(DATA_TYPE);
165 rhs_offset_first_element_in_bytes += MMUL_K0 * N0 *
sizeof(DATA_TYPE);
// Work-items whose tile lies entirely outside the output bounds hit these
// branches (bodies not visible here — presumably early returns; confirm in
// the full file).
168 if(block_x * N0 + block_id * MMUL_N0 * N0 >=
N)
173 if(block_y * M0 + y0 * M0 * MMUL_M0 >=
M)
// ---- Fused fixed-point output stage: apply zero-point corrections and bias,
// requantize to OUT_DATA_TYPE, clamp, optionally activate, then store. ----
178 #if defined(FUSED_OUTPUT_STAGE_FIXED_POINT)
180 TILE(
int, M0, N0, offset_s32);
186 #
if defined(A_OFFSET)
// Per-column correction: load from sum_col and scale by A_OFFSET
// (zero-point compensation term).
188 TILE(
int, 1, N0, a_offset_s32);
190 T_LOAD(
int, 1, N0, BUFFER, sum_col, dst_x, z, 1, sum_col_stride_z, a_offset_s32);
192 a_offset_s32[0].v *= A_OFFSET;
195 #endif // defined(A_OFFSET)
197 #if defined(B_OFFSET)
// Per-row correction: load from sum_row, scale by B_OFFSET, add into the
// per-row offset accumulator.
199 TILE(
int, M0, 1, b_offset_s32);
201 T_LOAD(
int, M0, 1, BUFFER, sum_row, dst_y, z *
M, 1, 4, b_offset_s32);
205 offset_s32[m0].v += b_offset_s32[m0].v *B_OFFSET;
210 #
if defined(ADD_BIAS)
211 #if defined(BROADCAST_BIAS)
// Broadcast bias: a single bias row shared across all output rows of the batch.
212 bia_offset_first_element_in_bytes += dst_x *
sizeof(ACC_DATA_TYPE) + z * bia_stride_y;
216 T_LOAD(
int, M0, N0, BUFFER, bia, dst_x, dst_y, 1, 1,
bias);
218 T_ADD(ACC_DATA_TYPE, M0, N0, offset_s32,
bias, offset_s32);
220 #else // defined(BROADCAST_BIAS)
221 bia_offset_first_element_in_bytes += dst_x *
sizeof(ACC_DATA_TYPE);
// Full-width vector load when the tile fits within N; the partial-load
// fallback's macro name was lost in extraction (fragment below).
225 if(dst_x + N0 <=
N || N0_LEFTOVER == 0)
227 bias[0].v =
VLOAD(N0)(0, (ACC_DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes));
232 (
bias[0].v, 0, (ACC_DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes));
237 #endif // defined(BROADCAST_BIAS)
238 #endif // defined(ADD_BIAS)
// Add the accumulated corrections, then requantize the tile to the narrow
// output type using a per-tensor multiplier/shift/offset.
240 T_ADD(ACC_DATA_TYPE, M0, N0, c, offset_s32, c);
241 TILE(OUT_DATA_TYPE, M0, N0, c_lp);
242 T_QUANTIZE8(ACC_DATA_TYPE, OUT_DATA_TYPE, PER_TENSOR, M0, N0, RESULT_OFFSET, RESULT_SHIFT, RESULT_MULTIPLIER, c, 0, 0, c_lp);
// Clamp the quantized tile to [MIN_BOUND, MAX_BOUND] when those are defined.
244 #if defined(MIN_BOUND)
247 c_lp[i].v = max(c_lp[i].v, (
VEC_DATA_TYPE(OUT_DATA_TYPE, N0))MIN_BOUND);
250 #
if defined(MAX_BOUND)
253 c_lp[i].v = min(c_lp[i].v, (
VEC_DATA_TYPE(OUT_DATA_TYPE, N0))MAX_BOUND);
// Fused activation. NOTE(review): as written it operates on accumulator tile
// `c`, not the quantized tile `c_lp` — verify against the full source.
257 T_ACTIVATION(DATA_TYPE, M0, N0, ACTIVATION_TYPE, A_VAL, B_VAL, c, c);
// Store the quantized tile: full vector stores when the tile fits in N/M,
// per-row leftover handling otherwise (store macro names lost in extraction;
// only their argument lists remain below).
259 if(dst_x + N0 <=
N || N0_LEFTOVER == 0)
263 if(dst_y + m0 <
M || M0_LEFTOVER == 0)
266 (c_lp[m0].v, 0, (__global OUT_DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
274 if(dst_y + m0 <
M || M0_LEFTOVER == 0)
277 (c_lp[m0].v, 0, (__global OUT_DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
282 #else // FUSED_OUTPUT_STAGE_FIXED_POINT
// No fused output stage: store the raw accumulator tile `c` with the same
// full/partial edge handling as above.
284 if(dst_x + N0 <=
N || N0_LEFTOVER == 0)
288 if(dst_y + m0 <
M || M0_LEFTOVER == 0)
291 (c[m0].v, 0, (__global OUT_DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
299 if(dst_y + m0 <
M || M0_LEFTOVER == 0)
302 (c[m0].v, 0, (__global OUT_DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
306 #endif // FUSED_OUTPUT_STAGE_FIXED_POINT
309 #endif // defined(GEMMLOWP_MM_RESHAPED_ONLY_RHS_MMUL)