28 #if defined(NUM_TILES_X) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H) 29 #if defined(VEC_SIZE) && VEC_SIZE == 2 63 __kernel
void winograd_output_transform_2x2_3x3_nchw(
73 #if defined(SRC_DEPTH) 87 #if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 92 float out00 = d00 + d01 + d02;
93 float out01 = d01 - d02 - d03;
94 #else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 112 float k0 = d01 + d11 + d21;
113 float k1 = d02 + d12 + d22;
114 float k2 = d11 - d21 - d31;
115 float k3 = d12 - d22 - d32;
127 out00 += d00 + d20 + k0 + k1;
128 out01 += k0 - k1 - (d03 + d23);
129 out10 += -d20 - d30 + k2 + k3;
130 out11 += k2 - k3 + d23 + d33;
131 #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 133 int y_in = get_global_id(1);
134 int x_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
135 int y_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
136 int z_out = get_global_id(0);
137 #if defined(SRC_DEPTH) 138 int batch = get_global_id(2) / SRC_DEPTH;
141 #if defined(HAS_BIAS) 149 #endif // defined(HAS_BIAS) 152 #if defined(SRC_DEPTH) 153 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out *
sizeof(
DATA_TYPE) + y_out * dst_stride_y + z_out * dst_stride_z + batch * dst_stride_w;
155 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out *
sizeof(
DATA_TYPE) + y_out * dst_stride_y + z_out * dst_stride_z;
159 #if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 162 *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y)) = out0_dt.s0;
163 *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y)) = out0_dt.s1;
164 #else // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 165 vstore2(
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE,
CONVERT((
VEC_DATA_TYPE(
float, 2))(out00, out01),
VEC_DATA_TYPE(DATA_TYPE, 2)), A_VAL, B_VAL), 0,
166 (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
167 #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 169 #if !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 170 #if defined(HAS_BIAS) 174 #endif // defined(HAS_BIAS) 175 vstore2(
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE,
CONVERT((
VEC_DATA_TYPE(
float, 2))(out10, out11),
VEC_DATA_TYPE(DATA_TYPE, 2)), A_VAL, B_VAL), 0,
176 (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
177 #endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 180 #define COMPUTE_TMP_COL_2x2_7x7(col, d0, d1, d2, d3, d4, d5, d6, d7) \ 182 col.s0 = d0 + d1 + d2 + d3 + d4 + d5 + d6; \ 183 col.s1 = -d1 + d2 - 2 * d3 + 2 * d4 + -3 * d5 + 3 * d6 + d7; \ 218 __kernel
void winograd_output_transform_2x2_7x7_nhwc(
221 #
if defined(HAS_BIAS)
227 #if defined(SRC_DEPTH) 235 int y_in = get_global_id(1);
236 int x_out = get_global_id(0);
237 int y_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
238 int z_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
239 #if defined(SRC_DEPTH) 240 int batch = get_global_id(2) / SRC_DEPTH;
243 __global
unsigned char *dst_base_ptr = dst_ptr + dst_offset_first_element_in_bytes + x_out *
sizeof(
DATA_TYPE);
245 #if defined(SRC_DEPTH) 246 dst_base_ptr += batch * dst_stride_w;
247 #endif // defined(SRC_DEPTH) 250 DATA_TYPE d00 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z));
251 DATA_TYPE d01 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z));
252 DATA_TYPE d02 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z));
253 DATA_TYPE d03 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_z));
254 DATA_TYPE d04 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z));
255 DATA_TYPE d05 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_z));
256 DATA_TYPE d06 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z));
257 DATA_TYPE d07 = *((__global DATA_TYPE *)(src_addr + 7 * src_stride_z));
259 #if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 261 float out00 = d00 + d01 + d02 + d03 + d04 + d05 + d06;
262 float out01 = -d01 + d02 - 2.f * d03 + 2.0f * d04 - 3.0f * d05 + 3.0f * d06 + d07;
264 #if defined(HAS_BIAS) 268 float b = (float) * ((__global DATA_TYPE *)(
vector_offset(&bias, x_out)));
272 #endif // defined(HAS_BIAS) 275 #if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 277 dst_base_ptr += y_out * dst_stride_y;
279 int2 offset_z = min((int2)z_out + (int2)(0, 1), (int2)((
int)DST_HEIGHT - 1)) * (int2)dst_stride_z;
286 *(__global DATA_TYPE *)(dst_base_ptr + offset_z.s1) = out0_dt.s1;
287 *(__global DATA_TYPE *)(dst_base_ptr + offset_z.s0) = out0_dt.s0;
288 #else // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 290 dst_base_ptr += z_out * dst_stride_z;
292 int2 offset_y = min((int2)y_out + (int2)(0, 1), (int2)((
int)DST_WIDTH - 1)) * (int2)dst_stride_y;
300 *(__global DATA_TYPE *)(dst_base_ptr + offset_y.s1) = out0_dt.s1;
301 *(__global DATA_TYPE *)(dst_base_ptr + offset_y.s0) = out0_dt.s0;
302 #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 304 #else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 306 DATA_TYPE d10 = *((__global DATA_TYPE *)(src_addr + 8 * src_stride_z));
307 DATA_TYPE d11 = *((__global DATA_TYPE *)(src_addr + 9 * src_stride_z));
308 DATA_TYPE d12 = *((__global DATA_TYPE *)(src_addr + 10 * src_stride_z));
309 DATA_TYPE d13 = *((__global DATA_TYPE *)(src_addr + 11 * src_stride_z));
310 DATA_TYPE d14 = *((__global DATA_TYPE *)(src_addr + 12 * src_stride_z));
311 DATA_TYPE d15 = *((__global DATA_TYPE *)(src_addr + 13 * src_stride_z));
312 DATA_TYPE d16 = *((__global DATA_TYPE *)(src_addr + 14 * src_stride_z));
313 DATA_TYPE d17 = *((__global DATA_TYPE *)(src_addr + 15 * src_stride_z));
315 DATA_TYPE d20 = *((__global DATA_TYPE *)(src_addr + 16 * src_stride_z));
316 DATA_TYPE d21 = *((__global DATA_TYPE *)(src_addr + 17 * src_stride_z));
317 DATA_TYPE d22 = *((__global DATA_TYPE *)(src_addr + 18 * src_stride_z));
318 DATA_TYPE d23 = *((__global DATA_TYPE *)(src_addr + 19 * src_stride_z));
319 DATA_TYPE d24 = *((__global DATA_TYPE *)(src_addr + 20 * src_stride_z));
320 DATA_TYPE d25 = *((__global DATA_TYPE *)(src_addr + 21 * src_stride_z));
321 DATA_TYPE d26 = *((__global DATA_TYPE *)(src_addr + 22 * src_stride_z));
322 DATA_TYPE d27 = *((__global DATA_TYPE *)(src_addr + 23 * src_stride_z));
324 DATA_TYPE d30 = *((__global DATA_TYPE *)(src_addr + 24 * src_stride_z));
325 DATA_TYPE d31 = *((__global DATA_TYPE *)(src_addr + 25 * src_stride_z));
326 DATA_TYPE d32 = *((__global DATA_TYPE *)(src_addr + 26 * src_stride_z));
327 DATA_TYPE d33 = *((__global DATA_TYPE *)(src_addr + 27 * src_stride_z));
328 DATA_TYPE d34 = *((__global DATA_TYPE *)(src_addr + 28 * src_stride_z));
329 DATA_TYPE d35 = *((__global DATA_TYPE *)(src_addr + 29 * src_stride_z));
330 DATA_TYPE d36 = *((__global DATA_TYPE *)(src_addr + 30 * src_stride_z));
331 DATA_TYPE d37 = *((__global DATA_TYPE *)(src_addr + 31 * src_stride_z));
333 DATA_TYPE d40 = *((__global DATA_TYPE *)(src_addr + 32 * src_stride_z));
334 DATA_TYPE d41 = *((__global DATA_TYPE *)(src_addr + 33 * src_stride_z));
335 DATA_TYPE d42 = *((__global DATA_TYPE *)(src_addr + 34 * src_stride_z));
336 DATA_TYPE d43 = *((__global DATA_TYPE *)(src_addr + 35 * src_stride_z));
337 DATA_TYPE d44 = *((__global DATA_TYPE *)(src_addr + 36 * src_stride_z));
338 DATA_TYPE d45 = *((__global DATA_TYPE *)(src_addr + 37 * src_stride_z));
339 DATA_TYPE d46 = *((__global DATA_TYPE *)(src_addr + 38 * src_stride_z));
340 DATA_TYPE d47 = *((__global DATA_TYPE *)(src_addr + 39 * src_stride_z));
342 DATA_TYPE d50 = *((__global DATA_TYPE *)(src_addr + 40 * src_stride_z));
343 DATA_TYPE d51 = *((__global DATA_TYPE *)(src_addr + 41 * src_stride_z));
344 DATA_TYPE d52 = *((__global DATA_TYPE *)(src_addr + 42 * src_stride_z));
345 DATA_TYPE d53 = *((__global DATA_TYPE *)(src_addr + 43 * src_stride_z));
346 DATA_TYPE d54 = *((__global DATA_TYPE *)(src_addr + 44 * src_stride_z));
347 DATA_TYPE d55 = *((__global DATA_TYPE *)(src_addr + 45 * src_stride_z));
348 DATA_TYPE d56 = *((__global DATA_TYPE *)(src_addr + 46 * src_stride_z));
349 DATA_TYPE d57 = *((__global DATA_TYPE *)(src_addr + 47 * src_stride_z));
351 DATA_TYPE d60 = *((__global DATA_TYPE *)(src_addr + 48 * src_stride_z));
352 DATA_TYPE d61 = *((__global DATA_TYPE *)(src_addr + 49 * src_stride_z));
353 DATA_TYPE d62 = *((__global DATA_TYPE *)(src_addr + 50 * src_stride_z));
354 DATA_TYPE d63 = *((__global DATA_TYPE *)(src_addr + 51 * src_stride_z));
355 DATA_TYPE d64 = *((__global DATA_TYPE *)(src_addr + 52 * src_stride_z));
356 DATA_TYPE d65 = *((__global DATA_TYPE *)(src_addr + 53 * src_stride_z));
357 DATA_TYPE d66 = *((__global DATA_TYPE *)(src_addr + 54 * src_stride_z));
358 DATA_TYPE d67 = *((__global DATA_TYPE *)(src_addr + 55 * src_stride_z));
360 DATA_TYPE d70 = *((__global DATA_TYPE *)(src_addr + 56 * src_stride_z));
361 DATA_TYPE d71 = *((__global DATA_TYPE *)(src_addr + 57 * src_stride_z));
362 DATA_TYPE d72 = *((__global DATA_TYPE *)(src_addr + 58 * src_stride_z));
363 DATA_TYPE d73 = *((__global DATA_TYPE *)(src_addr + 59 * src_stride_z));
364 DATA_TYPE d74 = *((__global DATA_TYPE *)(src_addr + 60 * src_stride_z));
365 DATA_TYPE d75 = *((__global DATA_TYPE *)(src_addr + 61 * src_stride_z));
366 DATA_TYPE d76 = *((__global DATA_TYPE *)(src_addr + 62 * src_stride_z));
367 DATA_TYPE d77 = *((__global DATA_TYPE *)(src_addr + 63 * src_stride_z));
371 tmp_col0, tmp_col1, tmp_col2, tmp_col3, tmp_col4, tmp_col5, tmp_col6, tmp_col7;
373 COMPUTE_TMP_COL_2x2_7x7(tmp_col0, d00, d10, d20, d30, d40, d50, d60, d70);
374 COMPUTE_TMP_COL_2x2_7x7(tmp_col1, d01, d11, d21, d31, d41, d51, d61, d71);
375 COMPUTE_TMP_COL_2x2_7x7(tmp_col2, d02, d12, d22, d32, d42, d52, d62, d72);
376 COMPUTE_TMP_COL_2x2_7x7(tmp_col3, d03, d13, d23, d33, d43, d53, d63, d73);
377 COMPUTE_TMP_COL_2x2_7x7(tmp_col4, d04, d14, d24, d34, d44, d54, d64, d74);
378 COMPUTE_TMP_COL_2x2_7x7(tmp_col5, d05, d15, d25, d35, d45, d55, d65, d75);
379 COMPUTE_TMP_COL_2x2_7x7(tmp_col6, d06, d16, d26, d36, d46, d56, d66, d76);
380 COMPUTE_TMP_COL_2x2_7x7(tmp_col7, d07, d17, d27, d37, d47, d57, d67, d77);
384 out_col0 = tmp_col0 + tmp_col1 + tmp_col2 + tmp_col3 + tmp_col4 + tmp_col5 + tmp_col6;
386 out_col1 = -tmp_col1 + tmp_col2 - 2 * tmp_col3 + 2 * tmp_col4 - 3 * tmp_col5 + 3 * tmp_col6 + tmp_col7;
388 #if defined(HAS_BIAS) 392 DATA_TYPE b = (float) * ((__global DATA_TYPE *)(
vector_offset(&bias, x_out)));
397 #endif // defined(HAS_BIAS) 399 int2 offset_y = min((int2)y_out + (int2)(0, 1), (int2)((
int)DST_WIDTH - 1)) * (int2)dst_stride_y;
400 int2 offset_z = min((int2)z_out + (int2)(0, 1), (int2)((
int)DST_HEIGHT - 1)) * (int2)dst_stride_z;
410 *(__global DATA_TYPE *)(dst_base_ptr + offset_y.s1 + offset_z.s1) = out_col1_dt.s1;
411 *(__global DATA_TYPE *)(dst_base_ptr + offset_y.s1 + offset_z.s0) = out_col1_dt.s0;
412 *(__global DATA_TYPE *)(dst_base_ptr + offset_y.s0 + offset_z.s1) = out_col0_dt.s1;
413 *(__global DATA_TYPE *)(dst_base_ptr + offset_y.s0 + offset_z.s0) = out_col0_dt.s0;
415 #endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 417 #endif // defined(VEC_SIZE) && VEC_SIZE == 2 419 #if defined(VEC_SIZE) && VEC_SIZE == 4 450 __kernel
void winograd_output_transform_4x4_3x3_nchw(
453 #
if defined(HAS_BIAS)
460 #if defined(SRC_DEPTH) 469 DATA_TYPE d00 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z));
470 DATA_TYPE d01 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z));
471 DATA_TYPE d02 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z));
472 DATA_TYPE d03 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_z));
473 DATA_TYPE d04 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z));
474 DATA_TYPE d05 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_z));
476 #if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 478 float out00 = d00 + d01 + d02 + d03 + d04;
479 float out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04;
480 float out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04;
481 float out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05;
482 #else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 484 DATA_TYPE d10 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z));
485 DATA_TYPE d11 = *((__global DATA_TYPE *)(src_addr + 7 * src_stride_z));
486 DATA_TYPE d12 = *((__global DATA_TYPE *)(src_addr + 8 * src_stride_z));
487 DATA_TYPE d13 = *((__global DATA_TYPE *)(src_addr + 9 * src_stride_z));
488 DATA_TYPE d14 = *((__global DATA_TYPE *)(src_addr + 10 * src_stride_z));
489 DATA_TYPE d15 = *((__global DATA_TYPE *)(src_addr + 11 * src_stride_z));
491 DATA_TYPE d20 = *((__global DATA_TYPE *)(src_addr + 12 * src_stride_z));
492 DATA_TYPE d21 = *((__global DATA_TYPE *)(src_addr + 13 * src_stride_z));
493 DATA_TYPE d22 = *((__global DATA_TYPE *)(src_addr + 14 * src_stride_z));
494 DATA_TYPE d23 = *((__global DATA_TYPE *)(src_addr + 15 * src_stride_z));
495 DATA_TYPE d24 = *((__global DATA_TYPE *)(src_addr + 16 * src_stride_z));
496 DATA_TYPE d25 = *((__global DATA_TYPE *)(src_addr + 17 * src_stride_z));
498 DATA_TYPE d30 = *((__global DATA_TYPE *)(src_addr + 18 * src_stride_z));
499 DATA_TYPE d31 = *((__global DATA_TYPE *)(src_addr + 19 * src_stride_z));
500 DATA_TYPE d32 = *((__global DATA_TYPE *)(src_addr + 20 * src_stride_z));
501 DATA_TYPE d33 = *((__global DATA_TYPE *)(src_addr + 21 * src_stride_z));
502 DATA_TYPE d34 = *((__global DATA_TYPE *)(src_addr + 22 * src_stride_z));
503 DATA_TYPE d35 = *((__global DATA_TYPE *)(src_addr + 23 * src_stride_z));
505 DATA_TYPE d40 = *((__global DATA_TYPE *)(src_addr + 24 * src_stride_z));
506 DATA_TYPE d41 = *((__global DATA_TYPE *)(src_addr + 25 * src_stride_z));
507 DATA_TYPE d42 = *((__global DATA_TYPE *)(src_addr + 26 * src_stride_z));
508 DATA_TYPE d43 = *((__global DATA_TYPE *)(src_addr + 27 * src_stride_z));
509 DATA_TYPE d44 = *((__global DATA_TYPE *)(src_addr + 28 * src_stride_z));
510 DATA_TYPE d45 = *((__global DATA_TYPE *)(src_addr + 29 * src_stride_z));
512 DATA_TYPE d50 = *((__global DATA_TYPE *)(src_addr + 30 * src_stride_z));
513 DATA_TYPE d51 = *((__global DATA_TYPE *)(src_addr + 31 * src_stride_z));
514 DATA_TYPE d52 = *((__global DATA_TYPE *)(src_addr + 32 * src_stride_z));
515 DATA_TYPE d53 = *((__global DATA_TYPE *)(src_addr + 33 * src_stride_z));
516 DATA_TYPE d54 = *((__global DATA_TYPE *)(src_addr + 34 * src_stride_z));
517 DATA_TYPE d55 = *((__global DATA_TYPE *)(src_addr + 35 * src_stride_z));
520 float out00 = (float)d01 + (
float)d21 + (float)d41 + (
float)d11 + (float)d31;
521 float out01 = (float)d01 + (
float)d21 + (float)d41 + (
float)d11 + (float)d31;
522 float out02 = (float)d01 + (
float)d21 + (float)d41 + (
float)d11 + (float)d31;
523 float out03 = (float)d01 + d21 + (
float)d41 + (float)d11 + (
float)d31;
525 float k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44;
526 float k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44;
528 out00 += k0 + d00 + d02 + d10 + d12 + d20 + d22 + d30 + d32 + d40 + d42;
529 out01 += k1 - d02 - d12 - d22 - d32 - d42;
530 out02 += 4.0f * k0 + d02 + d12 + d22 + d32 + d42;
531 out03 += 4.0f * k1 - d02 - d12 - d22 - d32 - d42 + d05 + d15 + d25 + d35 + d45;
534 float out10 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
535 float out11 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
536 float out12 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
537 float out13 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
539 k0 = d13 + d14 - d23 - d24 + 2.0f * d33 + 2.0f * d34 - 2.0f * d43 - 2.0f * d44;
540 k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 4.0f * d33 - 4.0f * d34 - 4.0f * d43 + 4.0f * d44;
542 out10 += k0 + d10 + d12 - d20 - d22 + 2.0f * d30 + 2.0f * d32 - 2.0f * d40 - 2.0f * d42;
543 out11 += k1 - d12 + d22 - 2.0f * d32 + 2.0f * d42;
544 out12 += 4.0f * k0 + d12 - d22 + 2.0f * d32 - 2.0f * d42;
545 out13 += 4.0f * k1 - d12 + d15 + d22 - d25 - 2.0f * d32 + 2.0f * d35 + 2.0f * d42 - 2.0f * d45;
548 float out20 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
549 float out21 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
550 float out22 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
551 float out23 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
553 k0 = d13 + d14 + d23 + d24 + 4.0f * d33 + 4.0f * d34 + 4.0f * d43 + 4.0f * d44;
554 k1 = 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 8.0f * d33 - 8.0f * d34 + 8.0f * d43 - 8.0f * d44;
556 out20 += k0 + d10 + d12 + d20 + d22 + 4.0f * d30 + 4.0f * d32 + 4.0f * d40 + 4.0f * d42;
557 out21 += k1 - d12 - d22 - 4.0f * d32 - 4.0f * d42;
558 out22 += 4.0f * k0 + d12 + d22 + 4.0f * d32 + 4.0f * d42;
559 out23 += 4.0f * k1 - d12 + d15 - d22 + d25 - 4.0f * d32 + 4.0f * d35 - 4.0f * d42 + 4.0f * d45;
562 float out30 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
563 float out31 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
564 float out32 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
565 float out33 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
567 k0 = d13 + d14 - d23 - d24 + 8.0f * d33 + 8.0f * d34 - 8.0f * d43 - 8.0f * d44 + d53 + d54;
568 k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 16.0f * d33 - 16.0f * d34 - 16.0f * d43 + 16.0f * d44 + 2.0f * d53 - 2.0f * d54;
570 out30 += k0 + d10 + d12 - d20 - d22 + 8.0f * d30 + 8.0f * d32 - 8.0f * d40 - 8.0f * d42 + d50 + d52;
571 out31 += k1 - d12 + d22 - 8.0f * d32 + 8.0f * d42 - d52;
572 out32 += 4.0f * k0 + d12 - d22 + 8.0f * d32 - 8.0f * d42 + d52;
573 out33 += 4.0f * k1 - d12 + d15 + d22 - d25 - 8.0f * d32 + 8.0f * d35 + 8.0f * d42 - 8.0f * d45 - d52 + d55;
574 #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 576 int y_in = get_global_id(1);
577 int x_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
578 int y_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
579 int z_out = get_global_id(0);
580 #if defined(SRC_DEPTH) 581 int batch = get_global_id(2) / SRC_DEPTH;
584 #if defined(HAS_BIAS) 588 float b = (float) * ((__global DATA_TYPE *)(
vector_offset(&bias, z_out)));
594 #endif // defined(HAS_BIAS) 597 #if defined(SRC_DEPTH) 598 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out *
sizeof(
DATA_TYPE) + y_out * dst_stride_y + z_out * dst_stride_z + batch * dst_stride_w;
600 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out *
sizeof(
DATA_TYPE) + y_out * dst_stride_y + z_out * dst_stride_z;
604 #if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 606 out0_dt =
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE,
CONVERT((
VEC_DATA_TYPE(
float, 4))(out00, out01, out02, out03),
VEC_DATA_TYPE(DATA_TYPE, 4)), A_VAL,
608 *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y)) = out0_dt.s0;
609 *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y)) = out0_dt.s1;
610 *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y)) = out0_dt.s2;
611 *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y)) = out0_dt.s3;
612 #else // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 613 vstore4(
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE,
CONVERT((
VEC_DATA_TYPE(
float, 4))(out00, out01, out02, out03),
VEC_DATA_TYPE(DATA_TYPE, 4)), A_VAL, B_VAL), 0,
614 (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
615 #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 617 #if !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 618 #if defined(HAS_BIAS) 634 #endif // defined(HAS_BIAS) 635 vstore4(
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE,
CONVERT((
VEC_DATA_TYPE(
float, 4))(out10, out11, out12, out13),
VEC_DATA_TYPE(DATA_TYPE, 4)), A_VAL, B_VAL), 0,
636 (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
637 vstore4(
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE,
CONVERT((
VEC_DATA_TYPE(
float, 4))(out20, out21, out22, out23),
VEC_DATA_TYPE(DATA_TYPE, 4)), A_VAL, B_VAL), 0,
638 (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
639 vstore4(
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE,
CONVERT((
VEC_DATA_TYPE(
float, 4))(out30, out31, out32, out33),
VEC_DATA_TYPE(DATA_TYPE, 4)), A_VAL, B_VAL), 0,
640 (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
641 #endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 677 __kernel
void winograd_output_transform_4x4_3x3_nhwc(
680 #
if defined(HAS_BIAS)
686 #if defined(SRC_DEPTH) 695 DATA_TYPE d00 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z));
696 DATA_TYPE d01 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z));
697 DATA_TYPE d02 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z));
698 DATA_TYPE d03 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_z));
699 DATA_TYPE d04 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z));
700 DATA_TYPE d05 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_z));
702 #if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 704 float out00 = d00 + d01 + d02 + d03 + d04;
705 float out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04;
706 float out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04;
707 float out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05;
708 #else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 710 DATA_TYPE d10 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z));
711 DATA_TYPE d11 = *((__global DATA_TYPE *)(src_addr + 7 * src_stride_z));
712 DATA_TYPE d12 = *((__global DATA_TYPE *)(src_addr + 8 * src_stride_z));
713 DATA_TYPE d13 = *((__global DATA_TYPE *)(src_addr + 9 * src_stride_z));
714 DATA_TYPE d14 = *((__global DATA_TYPE *)(src_addr + 10 * src_stride_z));
715 DATA_TYPE d15 = *((__global DATA_TYPE *)(src_addr + 11 * src_stride_z));
717 DATA_TYPE d20 = *((__global DATA_TYPE *)(src_addr + 12 * src_stride_z));
718 DATA_TYPE d21 = *((__global DATA_TYPE *)(src_addr + 13 * src_stride_z));
719 DATA_TYPE d22 = *((__global DATA_TYPE *)(src_addr + 14 * src_stride_z));
720 DATA_TYPE d23 = *((__global DATA_TYPE *)(src_addr + 15 * src_stride_z));
721 DATA_TYPE d24 = *((__global DATA_TYPE *)(src_addr + 16 * src_stride_z));
722 DATA_TYPE d25 = *((__global DATA_TYPE *)(src_addr + 17 * src_stride_z));
724 DATA_TYPE d30 = *((__global DATA_TYPE *)(src_addr + 18 * src_stride_z));
725 DATA_TYPE d31 = *((__global DATA_TYPE *)(src_addr + 19 * src_stride_z));
726 DATA_TYPE d32 = *((__global DATA_TYPE *)(src_addr + 20 * src_stride_z));
727 DATA_TYPE d33 = *((__global DATA_TYPE *)(src_addr + 21 * src_stride_z));
728 DATA_TYPE d34 = *((__global DATA_TYPE *)(src_addr + 22 * src_stride_z));
729 DATA_TYPE d35 = *((__global DATA_TYPE *)(src_addr + 23 * src_stride_z));
731 DATA_TYPE d40 = *((__global DATA_TYPE *)(src_addr + 24 * src_stride_z));
732 DATA_TYPE d41 = *((__global DATA_TYPE *)(src_addr + 25 * src_stride_z));
733 DATA_TYPE d42 = *((__global DATA_TYPE *)(src_addr + 26 * src_stride_z));
734 DATA_TYPE d43 = *((__global DATA_TYPE *)(src_addr + 27 * src_stride_z));
735 DATA_TYPE d44 = *((__global DATA_TYPE *)(src_addr + 28 * src_stride_z));
736 DATA_TYPE d45 = *((__global DATA_TYPE *)(src_addr + 29 * src_stride_z));
738 DATA_TYPE d50 = *((__global DATA_TYPE *)(src_addr + 30 * src_stride_z));
739 DATA_TYPE d51 = *((__global DATA_TYPE *)(src_addr + 31 * src_stride_z));
740 DATA_TYPE d52 = *((__global DATA_TYPE *)(src_addr + 32 * src_stride_z));
741 DATA_TYPE d53 = *((__global DATA_TYPE *)(src_addr + 33 * src_stride_z));
742 DATA_TYPE d54 = *((__global DATA_TYPE *)(src_addr + 34 * src_stride_z));
743 DATA_TYPE d55 = *((__global DATA_TYPE *)(src_addr + 35 * src_stride_z));
746 float out00 = d01 + d21 + d41 + d11 + d31;
747 float out01 = d01 + d21 + d41 + d11 + d31;
748 float out02 = d01 + d21 + d41 + d11 + d31;
749 float out03 = d01 + d21 + d41 + d11 + d31;
751 float k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44;
752 float k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44;
754 out00 += k0 + d00 + d02 + d10 + d12 + d20 + d22 + d30 + d32 + d40 + d42;
755 out01 += k1 - d02 - d12 - d22 - d32 - d42;
756 out02 += 4.0f * k0 + d02 + d12 + d22 + d32 + d42;
757 out03 += 4.0f * k1 - d02 - d12 - d22 - d32 - d42 + d05 + d15 + d25 + d35 + d45;
760 float out10 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
761 float out11 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
762 float out12 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
763 float out13 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
765 k0 = d13 + d14 - d23 - d24 + 2.0f * d33 + 2.0f * d34 - 2.0f * d43 - 2.0f * d44;
766 k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 4.0f * d33 - 4.0f * d34 - 4.0f * d43 + 4.0f * d44;
768 out10 += k0 + d10 + d12 - d20 - d22 + 2.0f * d30 + 2.0f * d32 - 2.0f * d40 - 2.0f * d42;
769 out11 += k1 - d12 + d22 - 2.0f * d32 + 2.0f * d42;
770 out12 += 4.0f * k0 + d12 - d22 + 2.0f * d32 - 2.0f * d42;
771 out13 += 4.0f * k1 - d12 + d15 + d22 - d25 - 2.0f * d32 + 2.0f * d35 + 2.0f * d42 - 2.0f * d45;
774 float out20 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
775 float out21 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
776 float out22 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
777 float out23 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
779 k0 = d13 + d14 + d23 + d24 + 4.0f * d33 + 4.0f * d34 + 4.0f * d43 + 4.0f * d44;
780 k1 = 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 8.0f * d33 - 8.0f * d34 + 8.0f * d43 - 8.0f * d44;
782 out20 += k0 + d10 + d12 + d20 + d22 + 4.0f * d30 + 4.0f * d32 + 4.0f * d40 + 4.0f * d42;
783 out21 += k1 - d12 - d22 - 4.0f * d32 - 4.0f * d42;
784 out22 += 4.0f * k0 + d12 + d22 + 4.0f * d32 + 4.0f * d42;
785 out23 += 4.0f * k1 - d12 + d15 - d22 + d25 - 4.0f * d32 + 4.0f * d35 - 4.0f * d42 + 4.0f * d45;
788 float out30 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
789 float out31 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
790 float out32 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
791 float out33 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
793 k0 = d13 + d14 - d23 - d24 + 8.0f * d33 + 8.0f * d34 - 8.0f * d43 - 8.0f * d44 + d53 + d54;
794 k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 16.0f * d33 - 16.0f * d34 - 16.0f * d43 + 16.0f * d44 + 2.0f * d53 - 2.0f * d54;
796 out30 += k0 + d10 + d12 - d20 - d22 + 8.0f * d30 + 8.0f * d32 - 8.0f * d40 - 8.0f * d42 + d50 + d52;
797 out31 += k1 - d12 + d22 - 8.0f * d32 + 8.0f * d42 - d52;
798 out32 += 4.0f * k0 + d12 - d22 + 8.0f * d32 - 8.0f * d42 + d52;
799 out33 += 4.0f * k1 - d12 + d15 + d22 - d25 - 8.0f * d32 + 8.0f * d35 + 8.0f * d42 - 8.0f * d45 - d52 + d55;
800 #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 802 int y_in = get_global_id(1);
803 int x_out = get_global_id(0);
804 int y_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
805 int z_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
806 #if defined(SRC_DEPTH) 807 int batch = get_global_id(2) / SRC_DEPTH;
810 #if defined(HAS_BIAS) 820 #if !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) & !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 835 #endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) & !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 837 #endif // defined(HAS_BIAS) 839 __global
unsigned char *dst_base_ptr = dst_ptr + dst_offset_first_element_in_bytes + x_out *
sizeof(
DATA_TYPE);
841 #if defined(SRC_DEPTH) 842 dst_base_ptr += batch * dst_stride_w;
843 #endif // defined(SRC_DEPTH) 845 #if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 847 dst_base_ptr += y_out * dst_stride_y;
849 int4 offset_z = min((int4)z_out + (int4)(0, 1, 2, 3), (int4)((
int)DST_HEIGHT - 1)) * (int4)dst_stride_z;
853 out0_dt =
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE,
CONVERT((
VEC_DATA_TYPE(
float, 4))(out00, out01, out02, out03),
VEC_DATA_TYPE(DATA_TYPE, 4)), A_VAL,
858 *((__global DATA_TYPE *)(dst_base_ptr + offset_z.s3)) = out0_dt.s3;
859 *((__global DATA_TYPE *)(dst_base_ptr + offset_z.s2)) = out0_dt.s2;
860 *((__global DATA_TYPE *)(dst_base_ptr + offset_z.s1)) = out0_dt.s1;
861 *((__global DATA_TYPE *)(dst_base_ptr + offset_z.s0)) = out0_dt.s0;
863 #elif defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) 865 dst_base_ptr += z_out * dst_stride_z;
867 int4 offset_y = min((int4)y_out + (int4)(0, 1, 2, 3), (int4)((
int)DST_WIDTH - 1)) * (int4)dst_stride_y;
870 out0_dt =
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE,
CONVERT((
VEC_DATA_TYPE(
float, 4))(out00, out01, out02, out03),
VEC_DATA_TYPE(DATA_TYPE, 4)),
875 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s3)) = out0_dt.s3;
876 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s2)) = out0_dt.s2;
877 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s1)) = out0_dt.s1;
878 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s0)) = out0_dt.s0;
880 #else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) 882 int4 offset_y = min((int4)y_out + (int4)(0, 1, 2, 3), (int4)((
int)DST_WIDTH - 1)) * (int4)dst_stride_y;
883 int4 offset_z = min((int4)z_out + (int4)(0, 1, 2, 3), (int4)((
int)DST_HEIGHT - 1)) * (int4)dst_stride_z;
887 out0_dt =
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE,
CONVERT((
VEC_DATA_TYPE(
float, 4))(out00, out01, out02, out03),
VEC_DATA_TYPE(DATA_TYPE, 4)), A_VAL, B_VAL);
889 out1_dt =
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE,
CONVERT((
VEC_DATA_TYPE(
float, 4))(out10, out11, out12, out13),
VEC_DATA_TYPE(DATA_TYPE, 4)), A_VAL, B_VAL);
891 out2_dt =
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE,
CONVERT((
VEC_DATA_TYPE(
float, 4))(out20, out21, out22, out23),
VEC_DATA_TYPE(DATA_TYPE, 4)), A_VAL, B_VAL);
899 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s3 + offset_z.s3)) = out3_dt.s3;
900 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s2 + offset_z.s3)) = out3_dt.s2;
901 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s1 + offset_z.s3)) = out3_dt.s1;
902 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s0 + offset_z.s3)) = out3_dt.s0;
903 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s3 + offset_z.s2)) = out2_dt.s3;
904 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s2 + offset_z.s2)) = out2_dt.s2;
905 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s1 + offset_z.s2)) = out2_dt.s1;
906 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s0 + offset_z.s2)) = out2_dt.s0;
907 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s3 + offset_z.s1)) = out1_dt.s3;
908 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s2 + offset_z.s1)) = out1_dt.s2;
909 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s1 + offset_z.s1)) = out1_dt.s1;
910 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s0 + offset_z.s1)) = out1_dt.s0;
911 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s3 + offset_z.s0)) = out0_dt.s3;
912 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s2 + offset_z.s0)) = out0_dt.s2;
913 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s1 + offset_z.s0)) = out0_dt.s1;
914 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s0 + offset_z.s0)) = out0_dt.s0;
915 #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) 918 #define COMPUTE_TMP_COL(col, d0, d1, d2, d3, d4, d5, d6, d7, comm_fact) \ 920 comm_fact.s0 = d1 + d2; \ 921 comm_fact.s1 = d3 + d4; \ 922 comm_fact.s2 = d5 + d6; \ 924 col.s0 = comm_fact.s0 + comm_fact.s1 + 8.f * comm_fact.s2 + d0; \ 925 col.s2 = comm_fact.s0 + 4.f * comm_fact.s1 + 2.f * comm_fact.s2; \ 927 comm_fact.s0 = d1 - d2; \ 928 comm_fact.s1 = d3 - d4; \ 929 comm_fact.s2 = d5 - d6; \ 931 col.s1 = comm_fact.s0 + 2.f * comm_fact.s1 + 4.f * comm_fact.s2; \ 932 col.s3 = comm_fact.s0 + 8.f * comm_fact.s1 + comm_fact.s2 + d7; \ 965 __kernel
// Output transform producing a 4-wide result per work-item: either one 4-element
// row/column (when WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL or _VERTICAL is defined) or a
// 4x4 block (otherwise). Reads its inputs at multiples of src_stride_z from src_addr and
// writes through dst_addr.
// NOTE(review): this listing omits original lines (the parameter list, the definition of
// src_addr/bias, some declarations and closing tokens); the comments below describe only
// what is visible here.
void winograd_output_transform_4x4_5x5_nchw(
968 #
if defined(HAS_BIAS)
// Work-item mapping: global id 1 is a linear tile index that is split into a tile
// column (% NUM_TILES_X) and tile row (/ NUM_TILES_X) and scaled by the output tile
// size; global id 0 selects z_out.
975 #if defined(SRC_DEPTH) 985 int y_in = get_global_id(1);
986 int x_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
987 int y_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
988 int z_out = get_global_id(0);
// With SRC_DEPTH defined, global id 2 is divided by SRC_DEPTH to obtain the batch index.
989 #if defined(SRC_DEPTH) 990 int batch = get_global_id(2) / SRC_DEPTH;
// Destination byte address of element (x_out, y_out, z_out[, batch]); the batch term is
// only added in the SRC_DEPTH variant.
993 #if defined(SRC_DEPTH) 994 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out *
sizeof(
DATA_TYPE) + y_out * dst_stride_y + z_out * dst_stride_z + batch * dst_stride_w;
997 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out *
sizeof(
DATA_TYPE) + y_out * dst_stride_y + z_out * dst_stride_z;
// First eight inputs, one per src_stride_z step.
1001 DATA_TYPE d00 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z));
1002 DATA_TYPE d01 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z));
1003 DATA_TYPE d02 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z));
1004 DATA_TYPE d03 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_z));
1005 DATA_TYPE d04 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z));
1006 DATA_TYPE d05 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_z));
1007 DATA_TYPE d06 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z));
1008 DATA_TYPE d07 = *((__global DATA_TYPE *)(src_addr + 7 * src_stride_z));
// 1-D path: the eight inputs are reduced to four float outputs.
1010 #if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 1012 float out00 = d00 + d01 + d02 + d03 + d04 + 8.0f * d05 + 8.0f * d06;
1013 float out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04 + 4.0f * d05 - 4.0f * d06;
1014 float out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04 + 2.0f * d05 + 2.0f * d06;
1015 float out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05 - d06 + d07;
// Bias is read at index z_out and widened to float (the lines that apply it are omitted
// from this listing).
1017 #if defined(HAS_BIAS) 1021 float b = (float) * ((__global DATA_TYPE *)(
vector_offset(&bias, z_out)));
// Store: the VERTICAL variant writes the four activated values as scalars on four
// consecutive dst_stride_y steps; otherwise they are written with a single vstore4 at
// dst_addr.
1027 #endif // defined(HAS_BIAS) 1030 #if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 1032 out0_dt =
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE,
CONVERT((
VEC_DATA_TYPE(
float, 4))(out00, out01, out02, out03),
VEC_DATA_TYPE(DATA_TYPE, 4)), A_VAL,
1034 *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y)) = out0_dt.s0;
1035 *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y)) = out0_dt.s1;
1036 *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y)) = out0_dt.s2;
1037 *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y)) = out0_dt.s3;
1038 #else // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 1039 vstore4(
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE,
CONVERT((
VEC_DATA_TYPE(
float, 4))(out00, out01, out02, out03),
VEC_DATA_TYPE(DATA_TYPE, 4)), A_VAL, B_VAL), 0,
1040 (__global DATA_TYPE *)(dst_addr));
// 2-D path: the remaining 56 inputs are loaded so that dRC sits at
// src_addr + (8 * R + C) * src_stride_z, giving an 8x8 set d00..d77.
1041 #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 1043 #else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 1045 DATA_TYPE d10 = *((__global DATA_TYPE *)(src_addr + 8 * src_stride_z));
1046 DATA_TYPE d11 = *((__global DATA_TYPE *)(src_addr + 9 * src_stride_z));
1047 DATA_TYPE d12 = *((__global DATA_TYPE *)(src_addr + 10 * src_stride_z));
1048 DATA_TYPE d13 = *((__global DATA_TYPE *)(src_addr + 11 * src_stride_z));
1049 DATA_TYPE d14 = *((__global DATA_TYPE *)(src_addr + 12 * src_stride_z));
1050 DATA_TYPE d15 = *((__global DATA_TYPE *)(src_addr + 13 * src_stride_z));
1051 DATA_TYPE d16 = *((__global DATA_TYPE *)(src_addr + 14 * src_stride_z));
1052 DATA_TYPE d17 = *((__global DATA_TYPE *)(src_addr + 15 * src_stride_z));
1054 DATA_TYPE d20 = *((__global DATA_TYPE *)(src_addr + 16 * src_stride_z));
1055 DATA_TYPE d21 = *((__global DATA_TYPE *)(src_addr + 17 * src_stride_z));
1056 DATA_TYPE d22 = *((__global DATA_TYPE *)(src_addr + 18 * src_stride_z));
1057 DATA_TYPE d23 = *((__global DATA_TYPE *)(src_addr + 19 * src_stride_z));
1058 DATA_TYPE d24 = *((__global DATA_TYPE *)(src_addr + 20 * src_stride_z));
1059 DATA_TYPE d25 = *((__global DATA_TYPE *)(src_addr + 21 * src_stride_z));
1060 DATA_TYPE d26 = *((__global DATA_TYPE *)(src_addr + 22 * src_stride_z));
1061 DATA_TYPE d27 = *((__global DATA_TYPE *)(src_addr + 23 * src_stride_z));
1063 DATA_TYPE d30 = *((__global DATA_TYPE *)(src_addr + 24 * src_stride_z));
1064 DATA_TYPE d31 = *((__global DATA_TYPE *)(src_addr + 25 * src_stride_z));
1065 DATA_TYPE d32 = *((__global DATA_TYPE *)(src_addr + 26 * src_stride_z));
1066 DATA_TYPE d33 = *((__global DATA_TYPE *)(src_addr + 27 * src_stride_z));
1067 DATA_TYPE d34 = *((__global DATA_TYPE *)(src_addr + 28 * src_stride_z));
1068 DATA_TYPE d35 = *((__global DATA_TYPE *)(src_addr + 29 * src_stride_z));
1069 DATA_TYPE d36 = *((__global DATA_TYPE *)(src_addr + 30 * src_stride_z));
1070 DATA_TYPE d37 = *((__global DATA_TYPE *)(src_addr + 31 * src_stride_z));
1072 DATA_TYPE d40 = *((__global DATA_TYPE *)(src_addr + 32 * src_stride_z));
1073 DATA_TYPE d41 = *((__global DATA_TYPE *)(src_addr + 33 * src_stride_z));
1074 DATA_TYPE d42 = *((__global DATA_TYPE *)(src_addr + 34 * src_stride_z));
1075 DATA_TYPE d43 = *((__global DATA_TYPE *)(src_addr + 35 * src_stride_z));
1076 DATA_TYPE d44 = *((__global DATA_TYPE *)(src_addr + 36 * src_stride_z));
1077 DATA_TYPE d45 = *((__global DATA_TYPE *)(src_addr + 37 * src_stride_z));
1078 DATA_TYPE d46 = *((__global DATA_TYPE *)(src_addr + 38 * src_stride_z));
1079 DATA_TYPE d47 = *((__global DATA_TYPE *)(src_addr + 39 * src_stride_z));
1081 DATA_TYPE d50 = *((__global DATA_TYPE *)(src_addr + 40 * src_stride_z));
1082 DATA_TYPE d51 = *((__global DATA_TYPE *)(src_addr + 41 * src_stride_z));
1083 DATA_TYPE d52 = *((__global DATA_TYPE *)(src_addr + 42 * src_stride_z));
1084 DATA_TYPE d53 = *((__global DATA_TYPE *)(src_addr + 43 * src_stride_z));
1085 DATA_TYPE d54 = *((__global DATA_TYPE *)(src_addr + 44 * src_stride_z));
1086 DATA_TYPE d55 = *((__global DATA_TYPE *)(src_addr + 45 * src_stride_z));
1087 DATA_TYPE d56 = *((__global DATA_TYPE *)(src_addr + 46 * src_stride_z));
1088 DATA_TYPE d57 = *((__global DATA_TYPE *)(src_addr + 47 * src_stride_z));
1090 DATA_TYPE d60 = *((__global DATA_TYPE *)(src_addr + 48 * src_stride_z));
1091 DATA_TYPE d61 = *((__global DATA_TYPE *)(src_addr + 49 * src_stride_z));
1092 DATA_TYPE d62 = *((__global DATA_TYPE *)(src_addr + 50 * src_stride_z));
1093 DATA_TYPE d63 = *((__global DATA_TYPE *)(src_addr + 51 * src_stride_z));
1094 DATA_TYPE d64 = *((__global DATA_TYPE *)(src_addr + 52 * src_stride_z));
1095 DATA_TYPE d65 = *((__global DATA_TYPE *)(src_addr + 53 * src_stride_z));
1096 DATA_TYPE d66 = *((__global DATA_TYPE *)(src_addr + 54 * src_stride_z));
1097 DATA_TYPE d67 = *((__global DATA_TYPE *)(src_addr + 55 * src_stride_z));
1099 DATA_TYPE d70 = *((__global DATA_TYPE *)(src_addr + 56 * src_stride_z));
1100 DATA_TYPE d71 = *((__global DATA_TYPE *)(src_addr + 57 * src_stride_z));
1101 DATA_TYPE d72 = *((__global DATA_TYPE *)(src_addr + 58 * src_stride_z));
1102 DATA_TYPE d73 = *((__global DATA_TYPE *)(src_addr + 59 * src_stride_z));
1103 DATA_TYPE d74 = *((__global DATA_TYPE *)(src_addr + 60 * src_stride_z));
1104 DATA_TYPE d75 = *((__global DATA_TYPE *)(src_addr + 61 * src_stride_z));
1105 DATA_TYPE d76 = *((__global DATA_TYPE *)(src_addr + 62 * src_stride_z));
1106 DATA_TYPE d77 = *((__global DATA_TYPE *)(src_addr + 63 * src_stride_z));
// NOTE(review): the type specifier lines of the two declarations below are omitted from
// this listing.
1110 comm_fact0, comm_fact1, comm_fact2;
1112 tmp_col0, tmp_col1, tmp_col2, tmp_col3, tmp_col4, tmp_col5, tmp_col6, tmp_col7;
// First pass: for each C in 0..7, reduce d0C..d7C (first index varying) to the
// 4-component tmp_colC. comm_fact0 is passed as the scratch argument every time.
1114 COMPUTE_TMP_COL(tmp_col0, d00, d10, d20, d30, d40, d50, d60, d70, comm_fact0);
1115 COMPUTE_TMP_COL(tmp_col1, d01, d11, d21, d31, d41, d51, d61, d71, comm_fact0);
1116 COMPUTE_TMP_COL(tmp_col2, d02, d12, d22, d32, d42, d52, d62, d72, comm_fact0);
1117 COMPUTE_TMP_COL(tmp_col3, d03, d13, d23, d33, d43, d53, d63, d73, comm_fact0);
1118 COMPUTE_TMP_COL(tmp_col4, d04, d14, d24, d34, d44, d54, d64, d74, comm_fact0);
1119 COMPUTE_TMP_COL(tmp_col5, d05, d15, d25, d35, d45, d55, d65, d75, comm_fact0);
1120 COMPUTE_TMP_COL(tmp_col6, d06, d16, d26, d36, d46, d56, d66, d76, comm_fact0);
1121 COMPUTE_TMP_COL(tmp_col7, d07, d17, d27, d37, d47, d57, d67, d77, comm_fact0);
// Second pass: combine the eight tmp_col vectors with the same 1/2/4/8 weights, first
// from pairwise sums (out_col0, out_col2), then from pairwise differences (out_col1,
// out_col3). tmp_col0 only feeds out_col0 and tmp_col7 only feeds out_col3.
1124 comm_fact0 = tmp_col1 + tmp_col2;
1125 comm_fact1 = tmp_col3 + tmp_col4;
1126 comm_fact2 = tmp_col5 + tmp_col6;
1129 out_col0 = comm_fact0 + comm_fact1 + (
float)8.f * comm_fact2 + tmp_col0;
1131 out_col2 = comm_fact0 + (
float)4.f * comm_fact1 + (
float)2.f * comm_fact2;
1133 comm_fact0 = tmp_col1 - tmp_col2;
1134 comm_fact1 = tmp_col3 - tmp_col4;
1135 comm_fact2 = tmp_col5 - tmp_col6;
1138 out_col1 = comm_fact0 + (
float)2.f * comm_fact1 + (
float)4.f * comm_fact2;
1140 out_col3 = comm_fact0 + (
float)8.f * comm_fact1 + comm_fact2 + tmp_col7;
// Bias is read at index z_out (the lines that apply it are omitted from this listing).
1142 #if defined(HAS_BIAS) 1146 float b = (float) * ((__global DATA_TYPE *)(
vector_offset(&bias, z_out)));
// Stores: output row r (dst_addr + r * dst_stride_y) receives component .s<r> of
// out_col0..out_col3, in that order, after ACTIVATION.
1152 #endif // defined(HAS_BIAS) 1155 vstore4(
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, (
VEC_DATA_TYPE(DATA_TYPE, 4))(out_col0.s0, out_col1.s0, out_col2.s0, out_col3.s0), A_VAL, B_VAL), 0,
1156 (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
1157 vstore4(
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, (
VEC_DATA_TYPE(DATA_TYPE, 4))(out_col0.s1, out_col1.s1, out_col2.s1, out_col3.s1), A_VAL, B_VAL), 0,
1158 (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
1159 vstore4(
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, (
VEC_DATA_TYPE(DATA_TYPE, 4))(out_col0.s2, out_col1.s2, out_col2.s2, out_col3.s2), A_VAL, B_VAL), 0,
1160 (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
1161 vstore4(
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, (
VEC_DATA_TYPE(DATA_TYPE, 4))(out_col0.s3, out_col1.s3, out_col2.s3, out_col3.s3), A_VAL, B_VAL), 0,
1162 (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
1163 #endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 1198 __kernel
// Output transform with the same arithmetic as the visible 1-D / 2-D reductions below,
// but addressing the destination with global id 0 as the innermost (x) element index
// and writing every result as an individual scalar at clamped y/z offsets.
// NOTE(review): this listing omits original lines (the parameter list, the definition of
// src_addr/bias, vector declarations, the *_dt conversions and closing tokens); the
// comments below describe only what is visible here.
void winograd_output_transform_4x4_5x5_nhwc(
1201 #
if defined(HAS_BIAS)
// Work-item mapping: global id 0 is x_out; global id 1 is a linear tile index split into
// tile column (-> y_out, scaled by OUTPUT_TILE_W) and tile row (-> z_out, scaled by
// OUTPUT_TILE_H).
1207 #if defined(SRC_DEPTH) 1215 int y_in = get_global_id(1);
1216 int x_out = get_global_id(0);
1217 int y_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
1218 int z_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
// With SRC_DEPTH defined, global id 2 is divided by SRC_DEPTH to obtain the batch index.
1219 #if defined(SRC_DEPTH) 1220 int batch = get_global_id(2) / SRC_DEPTH;
// Base pointer covers only the x term (and the batch term when SRC_DEPTH is defined);
// y/z offsets are added per store further down.
1223 __global
unsigned char *dst_base_ptr = dst_ptr + dst_offset_first_element_in_bytes + x_out *
sizeof(
DATA_TYPE);
1225 #if defined(SRC_DEPTH) 1226 dst_base_ptr += batch * dst_stride_w;
// First eight inputs, one per src_stride_z step.
1227 #endif // defined(SRC_DEPTH) 1230 DATA_TYPE d00 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z));
1231 DATA_TYPE d01 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z));
1232 DATA_TYPE d02 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z));
1233 DATA_TYPE d03 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_z));
1234 DATA_TYPE d04 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z));
1235 DATA_TYPE d05 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_z));
1236 DATA_TYPE d06 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z));
1237 DATA_TYPE d07 = *((__global DATA_TYPE *)(src_addr + 7 * src_stride_z));
// 1-D path: the eight inputs are reduced to four float outputs.
1239 #if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 1241 float out00 = d00 + d01 + d02 + d03 + d04 + 8.0f * d05 + 8.0f * d06;
1242 float out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04 + 4.0f * d05 - 4.0f * d06;
1243 float out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04 + 2.0f * d05 + 2.0f * d06;
1244 float out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05 - d06 + d07;
// Bias is read at index x_out and widened to float (the lines that apply it are omitted
// from this listing).
1246 #if defined(HAS_BIAS) 1250 float b = (float) * ((__global DATA_TYPE *)(
vector_offset(&bias, x_out)));
// VERTICAL variant: y is fixed at y_out and the four results go to z_out + 0..3, each
// index clamped to DST_HEIGHT - 1 before being scaled by dst_stride_z.
1256 #endif // defined(HAS_BIAS) 1259 #if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 1261 dst_base_ptr += y_out * dst_stride_y;
1263 int4 offset_z = min((int4)z_out + (int4)(0, 1, 2, 3), (int4)((
int)DST_HEIGHT - 1)) * (int4)dst_stride_z;
// Stores run from component 3 down to 0, so when clamping makes several components
// alias one address the lowest-indexed component is the one written last.
1266 out0_dt =
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE,
CONVERT((
VEC_DATA_TYPE(
float, 4))(out00, out01, out02, out03),
VEC_DATA_TYPE(DATA_TYPE, 4)), A_VAL,
1271 *((__global DATA_TYPE *)(dst_base_ptr + offset_z.s3)) = out0_dt.s3;
1272 *((__global DATA_TYPE *)(dst_base_ptr + offset_z.s2)) = out0_dt.s2;
1273 *((__global DATA_TYPE *)(dst_base_ptr + offset_z.s1)) = out0_dt.s1;
1274 *((__global DATA_TYPE *)(dst_base_ptr + offset_z.s0)) = out0_dt.s0;
// Non-VERTICAL 1-D variant: z is fixed at z_out and the four results go to y_out + 0..3,
// each index clamped to DST_WIDTH - 1 before being scaled by dst_stride_y.
1275 #else // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 1277 dst_base_ptr += z_out * dst_stride_z;
1279 int4 offset_y = min((int4)y_out + (int4)(0, 1, 2, 3), (int4)((
int)DST_WIDTH - 1)) * (int4)dst_stride_y;
1282 out0_dt =
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE,
CONVERT((
VEC_DATA_TYPE(
float, 4))(out00, out01, out02, out03),
VEC_DATA_TYPE(DATA_TYPE, 4)), A_VAL,
1287 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s3)) = out0_dt.s3;
1288 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s2)) = out0_dt.s2;
1289 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s1)) = out0_dt.s1;
1290 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s0)) = out0_dt.s0;
// 2-D path: the remaining 56 inputs are loaded so that dRC sits at
// src_addr + (8 * R + C) * src_stride_z, giving an 8x8 set d00..d77.
1291 #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 1293 #else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 1295 DATA_TYPE d10 = *((__global DATA_TYPE *)(src_addr + 8 * src_stride_z));
1296 DATA_TYPE d11 = *((__global DATA_TYPE *)(src_addr + 9 * src_stride_z));
1297 DATA_TYPE d12 = *((__global DATA_TYPE *)(src_addr + 10 * src_stride_z));
1298 DATA_TYPE d13 = *((__global DATA_TYPE *)(src_addr + 11 * src_stride_z));
1299 DATA_TYPE d14 = *((__global DATA_TYPE *)(src_addr + 12 * src_stride_z));
1300 DATA_TYPE d15 = *((__global DATA_TYPE *)(src_addr + 13 * src_stride_z));
1301 DATA_TYPE d16 = *((__global DATA_TYPE *)(src_addr + 14 * src_stride_z));
1302 DATA_TYPE d17 = *((__global DATA_TYPE *)(src_addr + 15 * src_stride_z));
1304 DATA_TYPE d20 = *((__global DATA_TYPE *)(src_addr + 16 * src_stride_z));
1305 DATA_TYPE d21 = *((__global DATA_TYPE *)(src_addr + 17 * src_stride_z));
1306 DATA_TYPE d22 = *((__global DATA_TYPE *)(src_addr + 18 * src_stride_z));
1307 DATA_TYPE d23 = *((__global DATA_TYPE *)(src_addr + 19 * src_stride_z));
1308 DATA_TYPE d24 = *((__global DATA_TYPE *)(src_addr + 20 * src_stride_z));
1309 DATA_TYPE d25 = *((__global DATA_TYPE *)(src_addr + 21 * src_stride_z));
1310 DATA_TYPE d26 = *((__global DATA_TYPE *)(src_addr + 22 * src_stride_z));
1311 DATA_TYPE d27 = *((__global DATA_TYPE *)(src_addr + 23 * src_stride_z));
1313 DATA_TYPE d30 = *((__global DATA_TYPE *)(src_addr + 24 * src_stride_z));
1314 DATA_TYPE d31 = *((__global DATA_TYPE *)(src_addr + 25 * src_stride_z));
1315 DATA_TYPE d32 = *((__global DATA_TYPE *)(src_addr + 26 * src_stride_z));
1316 DATA_TYPE d33 = *((__global DATA_TYPE *)(src_addr + 27 * src_stride_z));
1317 DATA_TYPE d34 = *((__global DATA_TYPE *)(src_addr + 28 * src_stride_z));
1318 DATA_TYPE d35 = *((__global DATA_TYPE *)(src_addr + 29 * src_stride_z));
1319 DATA_TYPE d36 = *((__global DATA_TYPE *)(src_addr + 30 * src_stride_z));
1320 DATA_TYPE d37 = *((__global DATA_TYPE *)(src_addr + 31 * src_stride_z));
1322 DATA_TYPE d40 = *((__global DATA_TYPE *)(src_addr + 32 * src_stride_z));
1323 DATA_TYPE d41 = *((__global DATA_TYPE *)(src_addr + 33 * src_stride_z));
1324 DATA_TYPE d42 = *((__global DATA_TYPE *)(src_addr + 34 * src_stride_z));
1325 DATA_TYPE d43 = *((__global DATA_TYPE *)(src_addr + 35 * src_stride_z));
1326 DATA_TYPE d44 = *((__global DATA_TYPE *)(src_addr + 36 * src_stride_z));
1327 DATA_TYPE d45 = *((__global DATA_TYPE *)(src_addr + 37 * src_stride_z));
1328 DATA_TYPE d46 = *((__global DATA_TYPE *)(src_addr + 38 * src_stride_z));
1329 DATA_TYPE d47 = *((__global DATA_TYPE *)(src_addr + 39 * src_stride_z));
1331 DATA_TYPE d50 = *((__global DATA_TYPE *)(src_addr + 40 * src_stride_z));
1332 DATA_TYPE d51 = *((__global DATA_TYPE *)(src_addr + 41 * src_stride_z));
1333 DATA_TYPE d52 = *((__global DATA_TYPE *)(src_addr + 42 * src_stride_z));
1334 DATA_TYPE d53 = *((__global DATA_TYPE *)(src_addr + 43 * src_stride_z));
1335 DATA_TYPE d54 = *((__global DATA_TYPE *)(src_addr + 44 * src_stride_z));
1336 DATA_TYPE d55 = *((__global DATA_TYPE *)(src_addr + 45 * src_stride_z));
1337 DATA_TYPE d56 = *((__global DATA_TYPE *)(src_addr + 46 * src_stride_z));
1338 DATA_TYPE d57 = *((__global DATA_TYPE *)(src_addr + 47 * src_stride_z));
1340 DATA_TYPE d60 = *((__global DATA_TYPE *)(src_addr + 48 * src_stride_z));
1341 DATA_TYPE d61 = *((__global DATA_TYPE *)(src_addr + 49 * src_stride_z));
1342 DATA_TYPE d62 = *((__global DATA_TYPE *)(src_addr + 50 * src_stride_z));
1343 DATA_TYPE d63 = *((__global DATA_TYPE *)(src_addr + 51 * src_stride_z));
1344 DATA_TYPE d64 = *((__global DATA_TYPE *)(src_addr + 52 * src_stride_z));
1345 DATA_TYPE d65 = *((__global DATA_TYPE *)(src_addr + 53 * src_stride_z));
1346 DATA_TYPE d66 = *((__global DATA_TYPE *)(src_addr + 54 * src_stride_z));
1347 DATA_TYPE d67 = *((__global DATA_TYPE *)(src_addr + 55 * src_stride_z));
1349 DATA_TYPE d70 = *((__global DATA_TYPE *)(src_addr + 56 * src_stride_z));
1350 DATA_TYPE d71 = *((__global DATA_TYPE *)(src_addr + 57 * src_stride_z));
1351 DATA_TYPE d72 = *((__global DATA_TYPE *)(src_addr + 58 * src_stride_z));
1352 DATA_TYPE d73 = *((__global DATA_TYPE *)(src_addr + 59 * src_stride_z));
1353 DATA_TYPE d74 = *((__global DATA_TYPE *)(src_addr + 60 * src_stride_z));
1354 DATA_TYPE d75 = *((__global DATA_TYPE *)(src_addr + 61 * src_stride_z));
1355 DATA_TYPE d76 = *((__global DATA_TYPE *)(src_addr + 62 * src_stride_z));
1356 DATA_TYPE d77 = *((__global DATA_TYPE *)(src_addr + 63 * src_stride_z));
// NOTE(review): the type specifier lines of the two declarations below are omitted from
// this listing.
1360 comm_fact0, comm_fact1, comm_fact2;
1362 tmp_col0, tmp_col1, tmp_col2, tmp_col3, tmp_col4, tmp_col5, tmp_col6, tmp_col7;
// First pass: for each C in 0..7, reduce d0C..d7C (first index varying) to the
// 4-component tmp_colC. comm_fact0 is passed as the scratch argument every time.
1364 COMPUTE_TMP_COL(tmp_col0, d00, d10, d20, d30, d40, d50, d60, d70, comm_fact0);
1365 COMPUTE_TMP_COL(tmp_col1, d01, d11, d21, d31, d41, d51, d61, d71, comm_fact0);
1366 COMPUTE_TMP_COL(tmp_col2, d02, d12, d22, d32, d42, d52, d62, d72, comm_fact0);
1367 COMPUTE_TMP_COL(tmp_col3, d03, d13, d23, d33, d43, d53, d63, d73, comm_fact0);
1368 COMPUTE_TMP_COL(tmp_col4, d04, d14, d24, d34, d44, d54, d64, d74, comm_fact0);
1369 COMPUTE_TMP_COL(tmp_col5, d05, d15, d25, d35, d45, d55, d65, d75, comm_fact0);
1370 COMPUTE_TMP_COL(tmp_col6, d06, d16, d26, d36, d46, d56, d66, d76, comm_fact0);
1371 COMPUTE_TMP_COL(tmp_col7, d07, d17, d27, d37, d47, d57, d67, d77, comm_fact0);
// Second pass: combine the eight tmp_col vectors with the same 1/2/4/8 weights, first
// from pairwise sums (out_col0, out_col2), then from pairwise differences (out_col1,
// out_col3). tmp_col0 only feeds out_col0 and tmp_col7 only feeds out_col3.
1374 comm_fact0 = tmp_col1 + tmp_col2;
1375 comm_fact1 = tmp_col3 + tmp_col4;
1376 comm_fact2 = tmp_col5 + tmp_col6;
1379 out_col0 = comm_fact0 + comm_fact1 + 8.f * comm_fact2 + tmp_col0;
1381 out_col2 = comm_fact0 + 4.f * comm_fact1 + 2.f * comm_fact2;
1383 comm_fact0 = tmp_col1 - tmp_col2;
1384 comm_fact1 = tmp_col3 - tmp_col4;
1385 comm_fact2 = tmp_col5 - tmp_col6;
1388 out_col1 = comm_fact0 + 2.f * comm_fact1 + 4.f * comm_fact2;
1390 out_col3 = comm_fact0 + 8.f * comm_fact1 + comm_fact2 + tmp_col7;
// Bias is read at index x_out.
// NOTE(review): b is declared DATA_TYPE here even though the value is cast to float and
// the other bias loads visible in this listing declare `float b` - confirm whether the
// narrower type is intentional.
1392 #if defined(HAS_BIAS) 1396 DATA_TYPE b = (float) * ((__global DATA_TYPE *)(
vector_offset(&bias, x_out)));
// Per-component byte offsets: y indices clamped to DST_WIDTH - 1, z indices clamped to
// DST_HEIGHT - 1, then scaled by the matching stride.
1402 #endif // defined(HAS_BIAS) 1404 int4 offset_y = min((int4)y_out + (int4)(0, 1, 2, 3), (int4)((
int)DST_WIDTH - 1)) * (int4)dst_stride_y;
1405 int4 offset_z = min((int4)z_out + (int4)(0, 1, 2, 3), (int4)((
int)DST_HEIGHT - 1)) * (int4)dst_stride_z;
// 16 scalar stores: out_col<C>_dt.s<R> goes to (offset_y.s<C>, offset_z.s<R>). They are
// issued from (3,3) down to (0,0), so if clamping makes addresses alias, the store with
// the lowest indices is performed last.
// NOTE(review): the definitions of out_col0_dt..out_col3_dt are omitted from this listing.
1419 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s3 + offset_z.s3)) = out_col3_dt.s3;
1420 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s2 + offset_z.s3)) = out_col2_dt.s3;
1421 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s1 + offset_z.s3)) = out_col1_dt.s3;
1422 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s0 + offset_z.s3)) = out_col0_dt.s3;
1423 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s3 + offset_z.s2)) = out_col3_dt.s2;
1424 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s2 + offset_z.s2)) = out_col2_dt.s2;
1425 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s1 + offset_z.s2)) = out_col1_dt.s2;
1426 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s0 + offset_z.s2)) = out_col0_dt.s2;
1427 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s3 + offset_z.s1)) = out_col3_dt.s1;
1428 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s2 + offset_z.s1)) = out_col2_dt.s1;
1429 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s1 + offset_z.s1)) = out_col1_dt.s1;
1430 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s0 + offset_z.s1)) = out_col0_dt.s1;
1431 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s3 + offset_z.s0)) = out_col3_dt.s0;
1432 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s2 + offset_z.s0)) = out_col2_dt.s0;
1433 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s1 + offset_z.s0)) = out_col1_dt.s0;
1434 *((__global DATA_TYPE *)(dst_base_ptr + offset_y.s0 + offset_z.s0)) = out_col0_dt.s0;
1435 #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 1437 #endif // defined(VEC_SIZE) && VEC_SIZE == 4 1439 #if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) 1440 #if defined(VEC_SIZE) && VEC_SIZE == 2 1470 __kernel
// Kernel wrapper located in the WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL / VEC_SIZE == 2
// section: the visible body is a single call that forwards this kernel's arguments to
// winograd_output_transform_2x2_3x3_nchw. The bias arguments are only part of the
// signature and of the call when HAS_BIAS is defined.
// NOTE(review): the parameter list and most forwarded arguments are omitted from this
// listing.
void winograd_output_transform_2x1_3x1_nchw(
1473 #
if defined(HAS_BIAS)
1479 winograd_output_transform_2x2_3x3_nchw(src_ptr,
1488 src_offset_first_element_in_bytes,
1498 dst_offset_first_element_in_bytes
1499 #
if defined(HAS_BIAS)
1504 bias_offset_first_element_in_bytes
1540 __kernel
// Kernel wrapper located in the WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL / VEC_SIZE == 2
// section: the visible body is a single call that forwards this kernel's arguments to
// winograd_output_transform_2x2_7x7_nhwc. The bias arguments are only part of the
// signature and of the call when HAS_BIAS is defined.
// NOTE(review): the parameter list and most forwarded arguments are omitted from this
// listing.
void winograd_output_transform_2x1_7x1_nhwc(
1543 #
if defined(HAS_BIAS)
1548 winograd_output_transform_2x2_7x7_nhwc(src_ptr,
1557 src_offset_first_element_in_bytes,
1567 dst_offset_first_element_in_bytes,
1568 #
if defined(HAS_BIAS)
1572 bias_offset_first_element_in_bytes,
1576 #endif // defined(VEC_SIZE) && VEC_SIZE == 2 1578 #if defined(VEC_SIZE) && VEC_SIZE == 4 1608 __kernel
// Kernel wrapper located in the WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL / VEC_SIZE == 4
// section: the visible body is a single call that forwards this kernel's arguments to
// winograd_output_transform_4x4_3x3_nchw. The bias arguments are only part of the
// signature and of the call when HAS_BIAS is defined.
// NOTE(review): the parameter list and most forwarded arguments are omitted from this
// listing.
void winograd_output_transform_4x1_3x1_nchw(
1611 #
if defined(HAS_BIAS)
1617 winograd_output_transform_4x4_3x3_nchw(src_ptr,
1626 src_offset_first_element_in_bytes,
1636 dst_offset_first_element_in_bytes
1637 #
if defined(HAS_BIAS)
1642 bias_offset_first_element_in_bytes
1676 __kernel
// Kernel wrapper located in the WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL / VEC_SIZE == 4
// section: the visible body is a single call that forwards this kernel's arguments to
// winograd_output_transform_4x4_5x5_nchw. The bias arguments are only part of the
// signature and of the call when HAS_BIAS is defined.
// NOTE(review): the parameter list and most forwarded arguments are omitted from this
// listing.
void winograd_output_transform_4x1_5x1_nchw(
1679 #
if defined(HAS_BIAS)
1685 winograd_output_transform_4x4_5x5_nchw(src_ptr,
1694 src_offset_first_element_in_bytes,
1704 dst_offset_first_element_in_bytes
1705 #
if defined(HAS_BIAS)
1710 bias_offset_first_element_in_bytes
1746 __kernel
// Kernel wrapper located in the WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL / VEC_SIZE == 4
// section: the visible body is a single call that forwards this kernel's arguments to
// winograd_output_transform_4x4_3x3_nhwc. The bias arguments are only part of the
// signature and of the call when HAS_BIAS is defined.
// NOTE(review): the parameter list and most forwarded arguments are omitted from this
// listing.
void winograd_output_transform_4x1_3x1_nhwc(
1749 #
if defined(HAS_BIAS)
1754 winograd_output_transform_4x4_3x3_nhwc(src_ptr,
1763 src_offset_first_element_in_bytes,
1773 dst_offset_first_element_in_bytes,
1774 #
if defined(HAS_BIAS)
1778 bias_offset_first_element_in_bytes,
1814 __kernel
// Kernel wrapper located in the WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL / VEC_SIZE == 4
// section: the visible body is a single call that forwards this kernel's arguments to
// winograd_output_transform_4x4_5x5_nhwc. The bias arguments are only part of the
// signature and of the call when HAS_BIAS is defined.
// NOTE(review): the parameter list and most forwarded arguments are omitted from this
// listing.
void winograd_output_transform_4x1_5x1_nhwc(
1817 #
if defined(HAS_BIAS)
1822 winograd_output_transform_4x4_5x5_nhwc(src_ptr,
1831 src_offset_first_element_in_bytes,
1841 dst_offset_first_element_in_bytes,
1842 #
if defined(HAS_BIAS)
1846 bias_offset_first_element_in_bytes,
1850 #endif // defined(VEC_SIZE) && VEC_SIZE == 4 1851 #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) 1853 #if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 1854 #if defined(VEC_SIZE) && VEC_SIZE == 2 1884 __kernel
// Kernel wrapper located in the WINOGRAD_OUTPUT_TRANSFORM_VERTICAL / VEC_SIZE == 2
// section: the visible body is a single call that forwards this kernel's arguments to
// winograd_output_transform_2x2_3x3_nchw. The bias arguments are only part of the
// signature and of the call when HAS_BIAS is defined.
// NOTE(review): the parameter list and most forwarded arguments are omitted from this
// listing.
void winograd_output_transform_1x2_1x3_nchw(
1887 #
if defined(HAS_BIAS)
1893 winograd_output_transform_2x2_3x3_nchw(src_ptr,
1902 src_offset_first_element_in_bytes,
1912 dst_offset_first_element_in_bytes
1913 #
if defined(HAS_BIAS)
1918 bias_offset_first_element_in_bytes
1954 __kernel
// Kernel wrapper located in the WINOGRAD_OUTPUT_TRANSFORM_VERTICAL / VEC_SIZE == 2
// section: the visible body is a single call that forwards this kernel's arguments to
// winograd_output_transform_2x2_7x7_nhwc. The bias arguments are only part of the
// signature and of the call when HAS_BIAS is defined.
// NOTE(review): the parameter list and most forwarded arguments are omitted from this
// listing.
void winograd_output_transform_1x2_1x7_nhwc(
1957 #
if defined(HAS_BIAS)
1962 winograd_output_transform_2x2_7x7_nhwc(src_ptr,
1971 src_offset_first_element_in_bytes,
1981 dst_offset_first_element_in_bytes,
1982 #
if defined(HAS_BIAS)
1986 bias_offset_first_element_in_bytes,
1990 #endif // defined(VEC_SIZE) && VEC_SIZE == 2 1992 #if defined(VEC_SIZE) && VEC_SIZE == 4 2022 __kernel
// Kernel wrapper located in the WINOGRAD_OUTPUT_TRANSFORM_VERTICAL / VEC_SIZE == 4
// section: the visible body is a single call that forwards this kernel's arguments to
// winograd_output_transform_4x4_3x3_nchw. The bias arguments are only part of the
// signature and of the call when HAS_BIAS is defined.
// NOTE(review): the parameter list and most forwarded arguments are omitted from this
// listing.
void winograd_output_transform_1x4_1x3_nchw(
2025 #
if defined(HAS_BIAS)
2031 winograd_output_transform_4x4_3x3_nchw(src_ptr,
2040 src_offset_first_element_in_bytes,
2050 dst_offset_first_element_in_bytes
2051 #
if defined(HAS_BIAS)
2056 bias_offset_first_element_in_bytes
2090 __kernel
// Kernel wrapper located in the WINOGRAD_OUTPUT_TRANSFORM_VERTICAL / VEC_SIZE == 4
// section: the visible body is a single call that forwards this kernel's arguments to
// winograd_output_transform_4x4_5x5_nchw. The bias arguments are only part of the
// signature and of the call when HAS_BIAS is defined.
// NOTE(review): the parameter list and most forwarded arguments are omitted from this
// listing.
void winograd_output_transform_1x4_1x5_nchw(
2093 #
if defined(HAS_BIAS)
2099 winograd_output_transform_4x4_5x5_nchw(src_ptr,
2108 src_offset_first_element_in_bytes,
2118 dst_offset_first_element_in_bytes
2119 #
if defined(HAS_BIAS)
2124 bias_offset_first_element_in_bytes
2160 __kernel
// Kernel wrapper located in the WINOGRAD_OUTPUT_TRANSFORM_VERTICAL / VEC_SIZE == 4
// section: the visible body is a single call that forwards this kernel's arguments to
// winograd_output_transform_4x4_3x3_nhwc. The bias arguments are only part of the
// signature and of the call when HAS_BIAS is defined.
// NOTE(review): the parameter list and most forwarded arguments are omitted from this
// listing.
void winograd_output_transform_1x4_1x3_nhwc(
2163 #
if defined(HAS_BIAS)
2168 winograd_output_transform_4x4_3x3_nhwc(src_ptr,
2177 src_offset_first_element_in_bytes,
2187 dst_offset_first_element_in_bytes,
2188 #
if defined(HAS_BIAS)
2192 bias_offset_first_element_in_bytes,
2228 __kernel
// Kernel wrapper located in the WINOGRAD_OUTPUT_TRANSFORM_VERTICAL / VEC_SIZE == 4
// section: the visible body is a single call that forwards this kernel's arguments to
// winograd_output_transform_4x4_5x5_nhwc. The bias arguments are only part of the
// signature and of the call when HAS_BIAS is defined.
// NOTE(review): the parameter list and most forwarded arguments are omitted from this
// listing.
void winograd_output_transform_1x4_1x5_nhwc(
2231 #
if defined(HAS_BIAS)
2236 winograd_output_transform_4x4_5x5_nhwc(src_ptr,
2245 src_offset_first_element_in_bytes,
2255 dst_offset_first_element_in_bytes,
2256 #
if defined(HAS_BIAS)
2260 bias_offset_first_element_in_bytes,
2264 #endif // defined(VEC_SIZE) && VEC_SIZE == 4 2265 #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) 2266 #endif // defined(NUM_TILES_X) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H) Structure to hold Vector information.
Structure to hold 3D tensor information.
SimpleTensor< float > src
Structure to hold 4D tensor information.
#define VECTOR_DECLARATION(name)
__global const uchar * tensor4D_offset(const Tensor4D *tensor, int x, int y, int z, int w)
Get the pointer position of a Tensor4D.
#define CONVERT_TO_TENSOR4D_STRUCT(name, mod_size)
#define CONVERT_TO_TENSOR3D_STRUCT(name)
__global const uchar * vector_offset(const Vector *vec, int x)
Get the pointer position of a Vector.
#define TENSOR4D_DECLARATION(name)
#define CONVERT_TO_VECTOR_STRUCT_NO_STEP(name)
#define ACTIVATION(op, DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL)
__global const uchar * tensor3D_offset(const Tensor3D *tensor, int x, int y, int z)
Get the pointer position of a Tensor3D.
#define VEC_DATA_TYPE(type, size)