28 #if defined(NUM_TILES_X) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
29 #if defined(VEC_SIZE) && VEC_SIZE == 2
63 __kernel
void winograd_output_transform_2x2_3x3_nchw(
73 #if defined(SRC_DEPTH)
82 DATA_TYPE d00 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z));
83 DATA_TYPE d01 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z));
84 DATA_TYPE d02 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z));
85 DATA_TYPE d03 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_z));
87 #if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
92 float out00 = d00 + d01 + d02;
93 float out01 = d01 - d02 - d03;
94 #else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
96 DATA_TYPE d10 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z));
97 DATA_TYPE d11 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_z));
98 DATA_TYPE d12 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z));
99 DATA_TYPE d13 = *((__global DATA_TYPE *)(src_addr + 7 * src_stride_z));
101 DATA_TYPE d20 = *((__global DATA_TYPE *)(src_addr + 8 * src_stride_z));
102 DATA_TYPE d21 = *((__global DATA_TYPE *)(src_addr + 9 * src_stride_z));
103 DATA_TYPE d22 = *((__global DATA_TYPE *)(src_addr + 10 * src_stride_z));
104 DATA_TYPE d23 = *((__global DATA_TYPE *)(src_addr + 11 * src_stride_z));
106 DATA_TYPE d30 = *((__global DATA_TYPE *)(src_addr + 12 * src_stride_z));
107 DATA_TYPE d31 = *((__global DATA_TYPE *)(src_addr + 13 * src_stride_z));
108 DATA_TYPE d32 = *((__global DATA_TYPE *)(src_addr + 14 * src_stride_z));
109 DATA_TYPE d33 = *((__global DATA_TYPE *)(src_addr + 15 * src_stride_z));
112 float k0 = d01 + d11 + d21;
113 float k1 = d02 + d12 + d22;
114 float k2 = d11 - d21 - d31;
115 float k3 = d12 - d22 - d32;
127 out00 += d00 + d20 + k0 + k1;
128 out01 += k0 - k1 - (d03 + d23);
129 out10 += -d20 - d30 + k2 + k3;
130 out11 += k2 - k3 + d23 + d33;
131 #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
133 int y_in = get_global_id(1);
134 int x_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
135 int y_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
136 int z_out = get_global_id(0);
137 #if defined(SRC_DEPTH)
138 int batch = get_global_id(2) / SRC_DEPTH;
141 #if defined(HAS_BIAS)
149 #endif // defined(HAS_BIAS)
152 #if defined(SRC_DEPTH)
153 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out *
sizeof(DATA_TYPE) + y_out * dst_stride_y + z_out * dst_stride_z +
batch * dst_stride_w;
155 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out *
sizeof(DATA_TYPE) + y_out * dst_stride_y + z_out * dst_stride_z;
159 #if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
162 *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y)) = out0_dt.s0;
163 *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y)) = out0_dt.s1;
164 #else // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
165 vstore2(
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE,
VEC_SIZE,
CONVERT((
VEC_DATA_TYPE(
float, 2))(out00, out01),
VEC_DATA_TYPE(DATA_TYPE, 2)), A_VAL, B_VAL), 0,
166 (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
167 #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
169 #if !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
170 #if defined(HAS_BIAS)
172 out10 += (DATA_TYPE)
b;
173 out11 += (DATA_TYPE)
b;
174 #endif // defined(HAS_BIAS)
175 vstore2(
ACTIVATION(ACTIVATION_TYPE, DATA_TYPE,
VEC_SIZE,
CONVERT((
VEC_DATA_TYPE(
float, 2))(out10, out11),
VEC_DATA_TYPE(DATA_TYPE, 2)), A_VAL, B_VAL), 0,
176 (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
177 #endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
179 #endif // defined(VEC_SIZE) && VEC_SIZE == 2
181 #if defined(VEC_SIZE) && VEC_SIZE == 4
212 __kernel
void winograd_output_transform_4x4_3x3_nchw(
215 #
if defined(HAS_BIAS)
222 #if defined(SRC_DEPTH)
231 DATA_TYPE d00 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z));
232 DATA_TYPE d01 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z));
233 DATA_TYPE d02 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z));
234 DATA_TYPE d03 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_z));
235 DATA_TYPE d04 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z));
236 DATA_TYPE d05 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_z));
238 #if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
240 float out00 = d00 + d01 + d02 + d03 + d04;
241 float out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04;
242 float out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04;
243 float out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05;
244 #else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
246 DATA_TYPE d10 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z));
247 DATA_TYPE d11 = *((__global DATA_TYPE *)(src_addr + 7 * src_stride_z));
248 DATA_TYPE d12 = *((__global DATA_TYPE *)(src_addr + 8 * src_stride_z));
249 DATA_TYPE d13 = *((__global DATA_TYPE *)(src_addr + 9 * src_stride_z));
250 DATA_TYPE d14 = *((__global DATA_TYPE *)(src_addr + 10 * src_stride_z));
251 DATA_TYPE d15 = *((__global DATA_TYPE *)(src_addr + 11 * src_stride_z));
253 DATA_TYPE d20 = *((__global DATA_TYPE *)(src_addr + 12 * src_stride_z));
254 DATA_TYPE d21 = *((__global DATA_TYPE *)(src_addr + 13 * src_stride_z));
255 DATA_TYPE d22 = *((__global DATA_TYPE *)(src_addr + 14 * src_stride_z));
256 DATA_TYPE d23 = *((__global DATA_TYPE *)(src_addr + 15 * src_stride_z));
257 DATA_TYPE d24 = *((__global DATA_TYPE *)(src_addr + 16 * src_stride_z));
258 DATA_TYPE d25 = *((__global DATA_TYPE *)(src_addr + 17 * src_stride_z));
260 DATA_TYPE d30 = *((__global DATA_TYPE *)(src_addr + 18 * src_stride_z));
261 DATA_TYPE d31 = *((__global DATA_TYPE *)(src_addr + 19 * src_stride_z));
262 DATA_TYPE d32 = *((__global DATA_TYPE *)(src_addr + 20 * src_stride_z));
263 DATA_TYPE d33 = *((__global DATA_TYPE *)(src_addr + 21 * src_stride_z));
264 DATA_TYPE d34 = *((__global DATA_TYPE *)(src_addr + 22 * src_stride_z));
265 DATA_TYPE d35 = *((__global DATA_TYPE *)(src_addr + 23 * src_stride_z));
267 DATA_TYPE d40 = *((__global DATA_TYPE *)(src_addr + 24 * src_stride_z));
268 DATA_TYPE d41 = *((__global DATA_TYPE *)(src_addr + 25 * src_stride_z));
269 DATA_TYPE d42 = *((__global DATA_TYPE *)(src_addr + 26 * src_stride_z));
270 DATA_TYPE d43 = *((__global DATA_TYPE *)(src_addr + 27 * src_stride_z));
271 DATA_TYPE d44 = *((__global DATA_TYPE *)(src_addr + 28 * src_stride_z));
272 DATA_TYPE d45 = *((__global DATA_TYPE *)(src_addr + 29 * src_stride_z));
274 DATA_TYPE d50 = *((__global DATA_TYPE *)(src_addr + 30 * src_stride_z));
275 DATA_TYPE d51 = *((__global DATA_TYPE *)(src_addr + 31 * src_stride_z));
276 DATA_TYPE d52 = *((__global DATA_TYPE *)(src_addr + 32 * src_stride_z));
277 DATA_TYPE d53 = *((__global DATA_TYPE *)(src_addr + 33 * src_stride_z));
278 DATA_TYPE d54 = *((__global DATA_TYPE *)(src_addr + 34 * src_stride_z));
279 DATA_TYPE d55 = *((__global DATA_TYPE *)(src_addr + 35 * src_stride_z));
282 float out00 = (float)d01 + (
float)d21 + (float)d41 + (
float)d11 + (float)d31;
283 float out01 = (float)d01 + (
float)d21 + (float)d41 + (
float)d11 + (float)d31;
284 float out02 = (float)d01 + (
float)d21 + (float)d41 + (
float)d11 + (float)d31;
285 float out03 = (float)d01 + d21 + (
float)d41 + (float)d11 + (
float)d31;
287 float k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44;
288 float k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44;
290 out00 += k0 + d00 + d02 + d10 + d12 + d20 + d22 + d30 + d32 + d40 + d42;
291 out01 += k1 - d02 - d12 - d22 - d32 - d42;
292 out02 += 4.0f * k0 + d02 + d12 + d22 + d32 + d42;
293 out03 += 4.0f * k1 - d02 - d12 - d22 - d32 - d42 + d05 + d15 + d25 + d35 + d45;
296 float out10 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
297 float out11 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
298 float out12 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
299 float out13 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
301 k0 = d13 + d14 - d23 - d24 + 2.0f * d33 + 2.0f * d34 - 2.0f * d43 - 2.0f * d44;
302 k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 4.0f * d33 - 4.0f * d34 - 4.0f * d43 + 4.0f * d44;
304 out10 += k0 + d10 + d12 - d20 - d22 + 2.0f * d30 + 2.0f * d32 - 2.0f * d40 - 2.0f * d42;
305 out11 += k1 - d12 + d22 - 2.0f * d32 + 2.0f * d42;
306 out12 += 4.0f * k0 + d12 - d22 + 2.0f * d32 - 2.0f * d42;
307 out13 += 4.0f * k1 - d12 + d15 + d22 - d25 - 2.0f * d32 + 2.0f * d35 + 2.0f * d42 - 2.0f * d45;
310 float out20 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
311 float out21 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
312 float out22 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
313 float out23 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
315 k0 = d13 + d14 + d23 + d24 + 4.0f * d33 + 4.0f * d34 + 4.0f * d43 + 4.0f * d44;
316 k1 = 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 8.0f * d33 - 8.0f * d34 + 8.0f * d43 - 8.0f * d44;
318 out20 += k0 + d10 + d12 + d20 + d22 + 4.0f * d30 + 4.0f * d32 + 4.0f * d40 + 4.0f * d42;
319 out21 += k1 - d12 - d22 - 4.0f * d32 - 4.0f * d42;
320 out22 += 4.0f * k0 + d12 + d22 + 4.0f * d32 + 4.0f * d42;
321 out23 += 4.0f * k1 - d12 + d15 - d22 + d25 - 4.0f * d32 + 4.0f * d35 - 4.0f * d42 + 4.0f * d45;
324 float out30 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
325 float out31 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
326 float out32 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
327 float out33 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
329 k0 = d13 + d14 - d23 - d24 + 8.0f * d33 + 8.0f * d34 - 8.0f * d43 - 8.0f * d44 + d53 + d54;
330 k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 16.0f * d33 - 16.0f * d34 - 16.0f * d43 + 16.0f * d44 + 2.0f * d53 - 2.0f * d54;
332 out30 += k0 + d10 + d12 - d20 - d22 + 8.0f * d30 + 8.0f * d32 - 8.0f * d40 - 8.0f * d42 + d50 + d52;
333 out31 += k1 - d12 + d22 - 8.0f * d32 + 8.0f * d42 - d52;
334 out32 += 4.0f * k0 + d12 - d22 + 8.0f * d32 - 8.0f * d42 + d52;
335 out33 += 4.0f * k1 - d12 + d15 + d22 - d25 - 8.0f * d32 + 8.0f * d35 + 8.0f * d42 - 8.0f * d45 - d52 + d55;
336 #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
338 int y_in = get_global_id(1);
339 int x_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
340 int y_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
341 int z_out = get_global_id(0);
342 #if defined(SRC_DEPTH)
343 int batch = get_global_id(2) / SRC_DEPTH;
346 #if defined(HAS_BIAS)
356 #endif // defined(HAS_BIAS)
359 #if defined(SRC_DEPTH)
360 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out *
sizeof(DATA_TYPE) + y_out * dst_stride_y + z_out * dst_stride_z +
batch * dst_stride_w;
362 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out *
sizeof(DATA_TYPE) + y_out * dst_stride_y + z_out * dst_stride_z;
367 out0_dt =
CONVERT(
ACTIVATION(ACTIVATION_TYPE,
float,
VEC_SIZE, (
VEC_DATA_TYPE(
float, 4))(out00, out01, out02, out03), A_VAL, B_VAL),
VEC_DATA_TYPE(DATA_TYPE, 4));
369 #if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
370 *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y)) = out0_dt.s0;
371 *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y)) = out0_dt.s1;
372 *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y)) = out0_dt.s2;
373 *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y)) = out0_dt.s3;
374 #else // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
375 vstore4(out0_dt, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
376 #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
378 #if !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
379 #if defined(HAS_BIAS)
395 #endif // defined(HAS_BIAS)
396 vstore4(
CONVERT(
ACTIVATION(ACTIVATION_TYPE,
float,
VEC_SIZE, (
VEC_DATA_TYPE(
float, 4))(out10, out11, out12, out13), A_VAL, B_VAL),
VEC_DATA_TYPE(DATA_TYPE, 4)), 0,
397 (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
398 vstore4(
CONVERT(
ACTIVATION(ACTIVATION_TYPE,
float,
VEC_SIZE, (
VEC_DATA_TYPE(
float, 4))(out20, out21, out22, out23), A_VAL, B_VAL),
VEC_DATA_TYPE(DATA_TYPE, 4)), 0,
399 (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
400 vstore4(
CONVERT(
ACTIVATION(ACTIVATION_TYPE,
float,
VEC_SIZE, (
VEC_DATA_TYPE(
float, 4))(out30, out31, out32, out33), A_VAL, B_VAL),
VEC_DATA_TYPE(DATA_TYPE, 4)), 0,
401 (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
402 #endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
405 #define COMPUTE_TMP_COL(col, d0, d1, d2, d3, d4, d5, d6, d7, comm_fact) \
407 comm_fact.s0 = d1 + d2; \
408 comm_fact.s1 = d3 + d4; \
409 comm_fact.s2 = d5 + d6; \
411 col.s0 = comm_fact.s0 + comm_fact.s1 + 8.f * comm_fact.s2 + d0; \
412 col.s2 = comm_fact.s0 + 4.f * comm_fact.s1 + 2.f * comm_fact.s2; \
414 comm_fact.s0 = d1 - d2; \
415 comm_fact.s1 = d3 - d4; \
416 comm_fact.s2 = d5 - d6; \
418 col.s1 = comm_fact.s0 + 2.f * comm_fact.s1 + 4.f * comm_fact.s2; \
419 col.s3 = comm_fact.s0 + 8.f * comm_fact.s1 + comm_fact.s2 + d7; \
452 __kernel
void winograd_output_transform_4x4_5x5_nchw(
455 #
if defined(HAS_BIAS)
462 #if defined(SRC_DEPTH)
472 int y_in = get_global_id(1);
473 int x_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
474 int y_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
475 int z_out = get_global_id(0);
476 #if defined(SRC_DEPTH)
477 int batch = get_global_id(2) / SRC_DEPTH;
480 #if defined(SRC_DEPTH)
481 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out *
sizeof(DATA_TYPE) + y_out * dst_stride_y + z_out * dst_stride_z +
batch * dst_stride_w;
484 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out *
sizeof(DATA_TYPE) + y_out * dst_stride_y + z_out * dst_stride_z;
488 DATA_TYPE d00 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z));
489 DATA_TYPE d01 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z));
490 DATA_TYPE d02 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z));
491 DATA_TYPE d03 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_z));
492 DATA_TYPE d04 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z));
493 DATA_TYPE d05 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_z));
494 DATA_TYPE d06 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z));
495 DATA_TYPE d07 = *((__global DATA_TYPE *)(src_addr + 7 * src_stride_z));
497 #if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
499 float out00 = d00 + d01 + d02 + d03 + d04 + 8.0f * d05 + 8.0f * d06;
500 float out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04 + 4.0f * d05 - 4.0f * d06;
501 float out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04 + 2.0f * d05 + 2.0f * d06;
502 float out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05 - d06 + d07;
504 #if defined(HAS_BIAS)
510 out00 += (DATA_TYPE)
b;
511 out01 += (DATA_TYPE)
b;
512 out02 += (DATA_TYPE)
b;
513 out03 += (DATA_TYPE)
b;
514 #endif // defined(HAS_BIAS)
517 #if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
522 *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y)) = out0_dt.s0;
523 *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y)) = out0_dt.s1;
524 *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y)) = out0_dt.s2;
525 *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y)) = out0_dt.s3;
526 #else // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
527 vstore4(
CONVERT(
ACTIVATION(ACTIVATION_TYPE,
float,
VEC_SIZE, (
VEC_DATA_TYPE(
float, 4))(out00, out01, out02, out03), A_VAL, B_VAL),
VEC_DATA_TYPE(DATA_TYPE, 4)),
528 0, (__global DATA_TYPE *)(dst_addr));
529 #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
531 #else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
533 DATA_TYPE d10 = *((__global DATA_TYPE *)(src_addr + 8 * src_stride_z));
534 DATA_TYPE d11 = *((__global DATA_TYPE *)(src_addr + 9 * src_stride_z));
535 DATA_TYPE d12 = *((__global DATA_TYPE *)(src_addr + 10 * src_stride_z));
536 DATA_TYPE d13 = *((__global DATA_TYPE *)(src_addr + 11 * src_stride_z));
537 DATA_TYPE d14 = *((__global DATA_TYPE *)(src_addr + 12 * src_stride_z));
538 DATA_TYPE d15 = *((__global DATA_TYPE *)(src_addr + 13 * src_stride_z));
539 DATA_TYPE d16 = *((__global DATA_TYPE *)(src_addr + 14 * src_stride_z));
540 DATA_TYPE d17 = *((__global DATA_TYPE *)(src_addr + 15 * src_stride_z));
542 DATA_TYPE d20 = *((__global DATA_TYPE *)(src_addr + 16 * src_stride_z));
543 DATA_TYPE d21 = *((__global DATA_TYPE *)(src_addr + 17 * src_stride_z));
544 DATA_TYPE d22 = *((__global DATA_TYPE *)(src_addr + 18 * src_stride_z));
545 DATA_TYPE d23 = *((__global DATA_TYPE *)(src_addr + 19 * src_stride_z));
546 DATA_TYPE d24 = *((__global DATA_TYPE *)(src_addr + 20 * src_stride_z));
547 DATA_TYPE d25 = *((__global DATA_TYPE *)(src_addr + 21 * src_stride_z));
548 DATA_TYPE d26 = *((__global DATA_TYPE *)(src_addr + 22 * src_stride_z));
549 DATA_TYPE d27 = *((__global DATA_TYPE *)(src_addr + 23 * src_stride_z));
551 DATA_TYPE d30 = *((__global DATA_TYPE *)(src_addr + 24 * src_stride_z));
552 DATA_TYPE d31 = *((__global DATA_TYPE *)(src_addr + 25 * src_stride_z));
553 DATA_TYPE d32 = *((__global DATA_TYPE *)(src_addr + 26 * src_stride_z));
554 DATA_TYPE d33 = *((__global DATA_TYPE *)(src_addr + 27 * src_stride_z));
555 DATA_TYPE d34 = *((__global DATA_TYPE *)(src_addr + 28 * src_stride_z));
556 DATA_TYPE d35 = *((__global DATA_TYPE *)(src_addr + 29 * src_stride_z));
557 DATA_TYPE d36 = *((__global DATA_TYPE *)(src_addr + 30 * src_stride_z));
558 DATA_TYPE d37 = *((__global DATA_TYPE *)(src_addr + 31 * src_stride_z));
560 DATA_TYPE d40 = *((__global DATA_TYPE *)(src_addr + 32 * src_stride_z));
561 DATA_TYPE d41 = *((__global DATA_TYPE *)(src_addr + 33 * src_stride_z));
562 DATA_TYPE d42 = *((__global DATA_TYPE *)(src_addr + 34 * src_stride_z));
563 DATA_TYPE d43 = *((__global DATA_TYPE *)(src_addr + 35 * src_stride_z));
564 DATA_TYPE d44 = *((__global DATA_TYPE *)(src_addr + 36 * src_stride_z));
565 DATA_TYPE d45 = *((__global DATA_TYPE *)(src_addr + 37 * src_stride_z));
566 DATA_TYPE d46 = *((__global DATA_TYPE *)(src_addr + 38 * src_stride_z));
567 DATA_TYPE d47 = *((__global DATA_TYPE *)(src_addr + 39 * src_stride_z));
569 DATA_TYPE d50 = *((__global DATA_TYPE *)(src_addr + 40 * src_stride_z));
570 DATA_TYPE d51 = *((__global DATA_TYPE *)(src_addr + 41 * src_stride_z));
571 DATA_TYPE d52 = *((__global DATA_TYPE *)(src_addr + 42 * src_stride_z));
572 DATA_TYPE d53 = *((__global DATA_TYPE *)(src_addr + 43 * src_stride_z));
573 DATA_TYPE d54 = *((__global DATA_TYPE *)(src_addr + 44 * src_stride_z));
574 DATA_TYPE d55 = *((__global DATA_TYPE *)(src_addr + 45 * src_stride_z));
575 DATA_TYPE d56 = *((__global DATA_TYPE *)(src_addr + 46 * src_stride_z));
576 DATA_TYPE d57 = *((__global DATA_TYPE *)(src_addr + 47 * src_stride_z));
578 DATA_TYPE d60 = *((__global DATA_TYPE *)(src_addr + 48 * src_stride_z));
579 DATA_TYPE d61 = *((__global DATA_TYPE *)(src_addr + 49 * src_stride_z));
580 DATA_TYPE d62 = *((__global DATA_TYPE *)(src_addr + 50 * src_stride_z));
581 DATA_TYPE d63 = *((__global DATA_TYPE *)(src_addr + 51 * src_stride_z));
582 DATA_TYPE d64 = *((__global DATA_TYPE *)(src_addr + 52 * src_stride_z));
583 DATA_TYPE d65 = *((__global DATA_TYPE *)(src_addr + 53 * src_stride_z));
584 DATA_TYPE d66 = *((__global DATA_TYPE *)(src_addr + 54 * src_stride_z));
585 DATA_TYPE d67 = *((__global DATA_TYPE *)(src_addr + 55 * src_stride_z));
587 DATA_TYPE d70 = *((__global DATA_TYPE *)(src_addr + 56 * src_stride_z));
588 DATA_TYPE d71 = *((__global DATA_TYPE *)(src_addr + 57 * src_stride_z));
589 DATA_TYPE d72 = *((__global DATA_TYPE *)(src_addr + 58 * src_stride_z));
590 DATA_TYPE d73 = *((__global DATA_TYPE *)(src_addr + 59 * src_stride_z));
591 DATA_TYPE d74 = *((__global DATA_TYPE *)(src_addr + 60 * src_stride_z));
592 DATA_TYPE d75 = *((__global DATA_TYPE *)(src_addr + 61 * src_stride_z));
593 DATA_TYPE d76 = *((__global DATA_TYPE *)(src_addr + 62 * src_stride_z));
594 DATA_TYPE d77 = *((__global DATA_TYPE *)(src_addr + 63 * src_stride_z));
598 comm_fact0, comm_fact1, comm_fact2;
600 tmp_col0, tmp_col1, tmp_col2, tmp_col3, tmp_col4, tmp_col5, tmp_col6, tmp_col7;
602 COMPUTE_TMP_COL(tmp_col0, d00, d10, d20, d30, d40, d50, d60, d70, comm_fact0);
603 COMPUTE_TMP_COL(tmp_col1, d01, d11, d21, d31, d41, d51, d61, d71, comm_fact0);
604 COMPUTE_TMP_COL(tmp_col2, d02, d12, d22, d32, d42, d52, d62, d72, comm_fact0);
605 COMPUTE_TMP_COL(tmp_col3, d03, d13, d23, d33, d43, d53, d63, d73, comm_fact0);
606 COMPUTE_TMP_COL(tmp_col4, d04, d14, d24, d34, d44, d54, d64, d74, comm_fact0);
607 COMPUTE_TMP_COL(tmp_col5, d05, d15, d25, d35, d45, d55, d65, d75, comm_fact0);
608 COMPUTE_TMP_COL(tmp_col6, d06, d16, d26, d36, d46, d56, d66, d76, comm_fact0);
609 COMPUTE_TMP_COL(tmp_col7, d07, d17, d27, d37, d47, d57, d67, d77, comm_fact0);
612 comm_fact0 = tmp_col1 + tmp_col2;
613 comm_fact1 = tmp_col3 + tmp_col4;
614 comm_fact2 = tmp_col5 + tmp_col6;
617 out_col0 = comm_fact0 + comm_fact1 + (
float)8.f * comm_fact2 + tmp_col0;
619 out_col2 = comm_fact0 + (
float)4.f * comm_fact1 + (
float)2.f * comm_fact2;
621 comm_fact0 = tmp_col1 - tmp_col2;
622 comm_fact1 = tmp_col3 - tmp_col4;
623 comm_fact2 = tmp_col5 - tmp_col6;
626 out_col1 = comm_fact0 + (
float)2.f * comm_fact1 + (
float)4.f * comm_fact2;
628 out_col3 = comm_fact0 + (
float)8.f * comm_fact1 + comm_fact2 + tmp_col7;
630 #if defined(HAS_BIAS)
640 #endif // defined(HAS_BIAS)
643 vstore4(
CONVERT(
ACTIVATION(ACTIVATION_TYPE,
float,
VEC_SIZE, (
VEC_DATA_TYPE(
float, 4))(out_col0.s0, out_col1.s0, out_col2.s0, out_col3.s0), A_VAL, B_VAL),
645 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
646 vstore4(
CONVERT(
ACTIVATION(ACTIVATION_TYPE,
float,
VEC_SIZE, (
VEC_DATA_TYPE(
float, 4))(out_col0.s1, out_col1.s1, out_col2.s1, out_col3.s1), A_VAL, B_VAL),
648 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
649 vstore4(
CONVERT(
ACTIVATION(ACTIVATION_TYPE,
float,
VEC_SIZE, (
VEC_DATA_TYPE(
float, 4))(out_col0.s2, out_col1.s2, out_col2.s2, out_col3.s2), A_VAL, B_VAL),
651 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
652 vstore4(
CONVERT(
ACTIVATION(ACTIVATION_TYPE,
float,
VEC_SIZE, (
VEC_DATA_TYPE(
float, 4))(out_col0.s3, out_col1.s3, out_col2.s3, out_col3.s3), A_VAL, B_VAL),
654 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
655 #endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
657 #endif // defined(VEC_SIZE) && VEC_SIZE == 4
659 #if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL)
660 #if defined(VEC_SIZE) && VEC_SIZE == 2
690 __kernel
void winograd_output_transform_2x1_3x1_nchw(
693 #
if defined(HAS_BIAS)
699 winograd_output_transform_2x2_3x3_nchw(src_ptr,
708 src_offset_first_element_in_bytes,
718 dst_offset_first_element_in_bytes
719 #
if defined(HAS_BIAS)
724 bias_offset_first_element_in_bytes
729 #endif // defined(VEC_SIZE) && VEC_SIZE == 2
731 #if defined(VEC_SIZE) && VEC_SIZE == 4
761 __kernel
void winograd_output_transform_4x1_3x1_nchw(
764 #
if defined(HAS_BIAS)
770 winograd_output_transform_4x4_3x3_nchw(src_ptr,
779 src_offset_first_element_in_bytes,
789 dst_offset_first_element_in_bytes
790 #
if defined(HAS_BIAS)
795 bias_offset_first_element_in_bytes
829 __kernel
void winograd_output_transform_4x1_5x1_nchw(
832 #
if defined(HAS_BIAS)
838 winograd_output_transform_4x4_5x5_nchw(src_ptr,
847 src_offset_first_element_in_bytes,
857 dst_offset_first_element_in_bytes
858 #
if defined(HAS_BIAS)
863 bias_offset_first_element_in_bytes
868 #endif // defined(VEC_SIZE) && VEC_SIZE == 4
869 #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL)
871 #if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
872 #if defined(VEC_SIZE) && VEC_SIZE == 2
902 __kernel
void winograd_output_transform_1x2_1x3_nchw(
905 #
if defined(HAS_BIAS)
911 winograd_output_transform_2x2_3x3_nchw(src_ptr,
920 src_offset_first_element_in_bytes,
930 dst_offset_first_element_in_bytes
931 #
if defined(HAS_BIAS)
936 bias_offset_first_element_in_bytes
941 #endif // defined(VEC_SIZE) && VEC_SIZE == 2
943 #if defined(VEC_SIZE) && VEC_SIZE == 4
973 __kernel
void winograd_output_transform_1x4_1x3_nchw(
976 #
if defined(HAS_BIAS)
982 winograd_output_transform_4x4_3x3_nchw(src_ptr,
991 src_offset_first_element_in_bytes,
1001 dst_offset_first_element_in_bytes
1002 #
if defined(HAS_BIAS)
1007 bias_offset_first_element_in_bytes
1041 __kernel
void winograd_output_transform_1x4_1x5_nchw(
1044 #
if defined(HAS_BIAS)
1050 winograd_output_transform_4x4_5x5_nchw(src_ptr,
1059 src_offset_first_element_in_bytes,
1069 dst_offset_first_element_in_bytes
1070 #
if defined(HAS_BIAS)
1075 bias_offset_first_element_in_bytes
1080 #endif // defined(VEC_SIZE) && VEC_SIZE == 4
1081 #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
1082 #endif // defined(NUM_TILES_X) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)