Compute Library
 22.08
winograd_input_transform.cl
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2018-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "helpers.h"
25 #include "tile_helpers.h"
26 
27 #define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact) /* Fills out.s0-s7 from tmp.s0-s7; comm_fact.s0-s6 is scratch storage and is overwritten. Presumably one row of the Winograd input transform for a 4x4 output tile and 5x5 filter, going by the macro name - TODO confirm. */ \
28  ({ \
29  comm_fact.s0 = tmp.s2 - 4.25f * tmp.s4 + tmp.s6; /* even-index terms shared by out.s1 and out.s2 */ \
30  comm_fact.s1 = tmp.s1 - 4.25f * tmp.s3 + tmp.s5; /* odd-index terms shared by out.s1 and out.s2 */ \
31  comm_fact.s2 = 2.5f * tmp.s3; /* reused by comm_fact.s3 and comm_fact.s6 below */ \
32  comm_fact.s3 = 0.5f * tmp.s1 + 2.f * tmp.s5 - comm_fact.s2; /* odd-index terms shared by out.s3 and out.s4 */ \
33  comm_fact.s4 = 0.25f * tmp.s2 - 1.25f * tmp.s4 + tmp.s6; /* even-index terms shared by out.s3 and out.s4 */ \
34  comm_fact.s5 = 4.f * tmp.s2 + tmp.s6 - 5.f * tmp.s4; /* even-index terms shared by out.s5 and out.s6 */ \
35  comm_fact.s6 = 2.f * tmp.s1 + 0.5f * tmp.s5 - comm_fact.s2; /* odd-index terms shared by out.s5 and out.s6 */ \
36  \
37  out.s0 = tmp.s0 - tmp.s6 + 5.25f * tmp.s4 - 5.25f * tmp.s2; /* first output: computed directly from tmp, no shared term */ \
38  out.s1 = comm_fact.s0 + comm_fact.s1; /* outputs 1..6 come in sum/difference pairs of the shared terms */ \
39  out.s2 = comm_fact.s0 - comm_fact.s1; \
40  out.s3 = comm_fact.s3 + comm_fact.s4; \
41  out.s4 = comm_fact.s4 - comm_fact.s3; \
42  out.s5 = comm_fact.s5 + comm_fact.s6; \
43  out.s6 = comm_fact.s5 - comm_fact.s6; \
44  out.s7 = tmp.s7 - tmp.s1 + 5.25f * tmp.s3 - 5.25f * tmp.s5; /* last output: computed directly from tmp, no shared term */ \
45  })
46 
47 #define OUTPUT_ROW_2x2_7x7(out, tmp, comm_fact) /* Fills out.s0-s7 from tmp.s0-s7; comm_fact.s0-s5 is scratch storage and is overwritten. Presumably one row of the Winograd input transform for a 2x2 output tile and 7x7 filter, going by the macro name - TODO confirm. */ \
48  ({ \
49  comm_fact.s0 = 36.0f * tmp.s2 - 13.0f * tmp.s4 + tmp.s6; /* even-index terms shared by out.s1 and out.s2 */ \
50  comm_fact.s1 = 36.0f * tmp.s1 - 13.0f * tmp.s3 + 1.0f * tmp.s5; /* odd-index terms shared by out.s1 and out.s2 */ \
51  comm_fact.s2 = 9.0f * tmp.s2 - 10.0f * tmp.s4 + tmp.s6; /* even-index terms shared by out.s3 and out.s4 */ \
52  comm_fact.s3 = 18.0f * tmp.s1 - 20.0f * tmp.s3 + 2.0f * tmp.s5; /* odd-index terms shared by out.s3 and out.s4 */ \
53  comm_fact.s4 = 4.0f * tmp.s2 - 5.0f * tmp.s4 + tmp.s6; /* even-index terms shared by out.s5 and out.s6 */ \
54  comm_fact.s5 = 12.0f * tmp.s1 - 15.0f * tmp.s3 + 3.0f * tmp.s5; /* odd-index terms shared by out.s5 and out.s6 */ \
55  out.s0 = -36.0f * tmp.s0 + 49.0f * tmp.s2 + -14.0f * tmp.s4 + tmp.s6; /* first output: computed directly from tmp, no shared term */ \
56  out.s1 = comm_fact.s0 - comm_fact.s1; /* outputs 1..6 come in difference/sum pairs of the shared terms */ \
57  out.s2 = comm_fact.s0 + comm_fact.s1; \
58  out.s3 = comm_fact.s2 - comm_fact.s3; \
59  out.s4 = comm_fact.s2 + comm_fact.s3; \
60  out.s5 = comm_fact.s4 - comm_fact.s5; \
61  out.s6 = comm_fact.s4 + comm_fact.s5; \
62  out.s7 = -36.0f * tmp.s1 + 0.0f * tmp.s2 + 49.0f * tmp.s3 - 14.0f * tmp.s5 + tmp.s7; /* last output: computed directly from tmp; the 0.0f * tmp.s2 term adds nothing for finite tmp.s2 */ \
63  })
64 
65 #if defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
66 /** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3 and the output tile is 2x2/2x1 or 1x2
67  *
68  * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
69  * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
70  * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
71  * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
72  * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
73  * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
74  * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
75  *
76  * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
77  * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
78  * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
79  * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
80  * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
81  * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
82  * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
83  * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
84  * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
85  * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
86  * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
87  * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
88  * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
89  * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
90  * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
91  * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
92  * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
93  * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
94  */
95 __kernel void winograd_input_transform_2x2_3x3_stepz1_nchw( // NOTE(review): this listing jumps from line 95 to 98, so the src/dst tensor parameter declarations documented above are not shown here - confirm against the upstream file
98  uint src_stride_w,
99  uint dst_stride_w)
100 {
101  const int x = get_global_id(0); // tile index along X
102  const int y = get_global_id(1); // tile index along Y
103 #if defined(SRC_DEPTH) // with -DSRC_DEPTH, global id 2 is split into two indices
104  const int z = get_global_id(2) % SRC_DEPTH; // scaled by src_stride_z when addressing the source
105  const int b = get_global_id(2) / SRC_DEPTH; // scaled by src_stride_w / dst_stride_w (W dimension)
106 #else /* defined(SRC_DEPTH) */
107  const int z = get_global_id(2);
108 #endif /* defined(SRC_DEPTH) */
109 
110  // Compute input address
111 #if defined(SRC_DEPTH)
112  __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
113 #else /* defined(SRC_DEPTH) */
114  __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
115 #endif /* defined(SRC_DEPTH) */
116 
117  src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y); // step back by the padding: PAD_LEFT elements and PAD_TOP rows
118 
119 #if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) // load a single row of four consecutive elements
120  VEC_DATA_TYPE(DATA_TYPE, 4)
121  in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr));
122 #elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL): gather four elements spaced src_stride_y apart into one vector
123  VEC_DATA_TYPE(DATA_TYPE, 4)
124  in_row0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
125  *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
126  *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
127  *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
128 #else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL): load four rows of four elements
129  VEC_DATA_TYPE(DATA_TYPE, 4)
130  in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
131  VEC_DATA_TYPE(DATA_TYPE, 4)
132  in_row1 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
133  VEC_DATA_TYPE(DATA_TYPE, 4)
134  in_row2 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
135  VEC_DATA_TYPE(DATA_TYPE, 4)
136  in_row3 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
137 #endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
138 
139  VEC_DATA_TYPE(DATA_TYPE, 4)
140  tmp0 = in_row0;
141 
142 #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
143  tmp0 -= in_row2; // two-dimensional case only: combine rows 0 and 2 before the per-row step
144 #endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
145 
146  DATA_TYPE out00 = tmp0.s0 - tmp0.s2; // per-row step: [s0 - s2, s1 + s2, s2 - s1, s1 - s3]
147  DATA_TYPE out01 = tmp0.s1 + tmp0.s2;
148  DATA_TYPE out02 = tmp0.s2 - tmp0.s1;
149  DATA_TYPE out03 = tmp0.s1 - tmp0.s3;
150 
151 #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
152  VEC_DATA_TYPE(DATA_TYPE, 4)
153  tmp1 = in_row1 + in_row2; // remaining row combinations; each gets the same per-row step as tmp0
154  VEC_DATA_TYPE(DATA_TYPE, 4)
155  tmp2 = in_row2 - in_row1;
156  VEC_DATA_TYPE(DATA_TYPE, 4)
157  tmp3 = in_row1 - in_row3;
158 
159  DATA_TYPE out10 = tmp1.s0 - tmp1.s2;
160  DATA_TYPE out11 = tmp1.s1 + tmp1.s2;
161  DATA_TYPE out12 = tmp1.s2 - tmp1.s1;
162  DATA_TYPE out13 = tmp1.s1 - tmp1.s3;
163 
164  DATA_TYPE out20 = tmp2.s0 - tmp2.s2;
165  DATA_TYPE out21 = tmp2.s1 + tmp2.s2;
166  DATA_TYPE out22 = tmp2.s2 - tmp2.s1;
167  DATA_TYPE out23 = tmp2.s1 - tmp2.s3;
168 
169  DATA_TYPE out30 = tmp3.s0 - tmp3.s2;
170  DATA_TYPE out31 = tmp3.s1 + tmp3.s2;
171  DATA_TYPE out32 = tmp3.s2 - tmp3.s1;
172  DATA_TYPE out33 = tmp3.s1 - tmp3.s3;
173 #endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
174 
175 #if defined(SRC_DEPTH) // destination byte offset: z elements + linear tile index (x + y * NUM_TILES_X) * dst_stride_y, plus b * dst_stride_w when SRC_DEPTH is defined
176  __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
177 #else /* defined(SRC_DEPTH) */
178  __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
179 #endif /* defined(SRC_DEPTH) */
180 
181  *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z)) = out00; // transformed value k is stored at byte offset k * dst_stride_z
182  *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z)) = out01;
183  *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z)) = out02;
184  *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z)) = out03;
185 
186 #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
187  *((__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z)) = out10; // two-dimensional case: twelve more values, offsets 4..15
188  *((__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z)) = out11;
189  *((__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z)) = out12;
190  *((__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z)) = out13;
191  *((__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z)) = out20;
192  *((__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z)) = out21;
193  *((__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z)) = out22;
194  *((__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z)) = out23;
195  *((__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z)) = out30;
196  *((__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z)) = out31;
197  *((__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z)) = out32;
198  *((__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z)) = out33;
199 #endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
200 }
201 
202 /** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3, the output tile is 2x2/2x1 or 1x2 and the number of channels is multiple of 2
203  *
204  * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
205  * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
206  * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
207  * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
208  * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
209  * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
210  * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
211  *
212  * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
213  * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
214  * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
215  * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
216  * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
217  * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
218  * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
219  * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
220  * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
221  * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
222  * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
223  * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
224  * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
225  * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
226  * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
227  * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
228  * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
229  * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
230  */
231 __kernel void winograd_input_transform_2x2_3x3_stepz2_nchw( // NOTE(review): this listing jumps from line 231 to 234, so the src/dst tensor parameter declarations documented above are not shown here - confirm against the upstream file
234  uint src_stride_w,
235  uint dst_stride_w)
236 {
237  const int x = get_global_id(0); // tile index along X
238  const int y = get_global_id(1); // tile index along Y
239 #if defined(SRC_DEPTH) // with -DSRC_DEPTH, the doubled global id 2 is split into two indices
240  const int z = (get_global_id(2) * 2) % SRC_DEPTH; // factor 2: each work-item handles Z planes z and z + 1
241  const int b = (get_global_id(2) * 2) / SRC_DEPTH; // scaled by src_stride_w / dst_stride_w (W dimension)
242 #else /* defined(SRC_DEPTH) */
243  const int z = get_global_id(2) * 2; // factor 2: each work-item handles Z planes z and z + 1
244 #endif /* defined(SRC_DEPTH) */
245 
246  // Compute input address
247 #if defined(SRC_DEPTH)
248  __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
249 #else /* defined(SRC_DEPTH) */
250  __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
251 #endif /* defined(SRC_DEPTH) */
252  src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y); // step back by the padding: PAD_LEFT elements and PAD_TOP rows
253 
254 #if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) // first plane: load a single row of four consecutive elements
255  VEC_DATA_TYPE(DATA_TYPE, 4)
256  in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr));
257 #elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL): gather four elements spaced src_stride_y apart into one vector
258  VEC_DATA_TYPE(DATA_TYPE, 4)
259  in_row0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
260  *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
261  *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
262  *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
263 #else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL): load four rows of four elements
264  VEC_DATA_TYPE(DATA_TYPE, 4)
265  in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
266  VEC_DATA_TYPE(DATA_TYPE, 4)
267  in_row1 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
268  VEC_DATA_TYPE(DATA_TYPE, 4)
269  in_row2 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
270  VEC_DATA_TYPE(DATA_TYPE, 4)
271  in_row3 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
272 #endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
273 
274  src_addr += src_stride_z; // move to the next Z plane and repeat the same loads into in_row4..in_row7
275 #if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
276  VEC_DATA_TYPE(DATA_TYPE, 4)
277  in_row4 = vload4(0, (__global DATA_TYPE *)(src_addr));
278 #elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
279  VEC_DATA_TYPE(DATA_TYPE, 4)
280  in_row4 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
281  *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
282  *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
283  *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
284 #else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
285  VEC_DATA_TYPE(DATA_TYPE, 4)
286  in_row4 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
287  VEC_DATA_TYPE(DATA_TYPE, 4)
288  in_row5 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
289  VEC_DATA_TYPE(DATA_TYPE, 4)
290  in_row6 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
291  VEC_DATA_TYPE(DATA_TYPE, 4)
292  in_row7 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
293 #endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
294 
295  VEC_DATA_TYPE(DATA_TYPE, 4)
296  tmp0 = in_row0; // first plane
297  VEC_DATA_TYPE(DATA_TYPE, 4)
298  tmp4 = in_row4; // second plane
299 
300 #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
301  tmp0 -= in_row2; // two-dimensional case only: combine rows 0 and 2 of each plane before the per-row step
302  tmp4 -= in_row6;
303 #endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
304 
305  VEC_DATA_TYPE(DATA_TYPE, 2)
306  out00 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s0 - tmp0.s2, tmp4.s0 - tmp4.s2); // each 2-vector pairs the first-plane result with the second-plane result; per-row step is [s0 - s2, s1 + s2, s2 - s1, s1 - s3]
307  VEC_DATA_TYPE(DATA_TYPE, 2)
308  out01 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s1 + tmp0.s2, tmp4.s1 + tmp4.s2);
309  VEC_DATA_TYPE(DATA_TYPE, 2)
310  out02 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s2 - tmp0.s1, tmp4.s2 - tmp4.s1);
311  VEC_DATA_TYPE(DATA_TYPE, 2)
312  out03 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s1 - tmp0.s3, tmp4.s1 - tmp4.s3);
313 
314 #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
315  VEC_DATA_TYPE(DATA_TYPE, 4)
316  tmp1 = in_row1 + in_row2; // remaining row combinations for the first plane
317  VEC_DATA_TYPE(DATA_TYPE, 4)
318  tmp2 = in_row2 - in_row1;
319  VEC_DATA_TYPE(DATA_TYPE, 4)
320  tmp3 = in_row1 - in_row3;
321 
322  VEC_DATA_TYPE(DATA_TYPE, 4)
323  tmp5 = in_row5 + in_row6; // same row combinations for the second plane
324  VEC_DATA_TYPE(DATA_TYPE, 4)
325  tmp6 = in_row6 - in_row5;
326  VEC_DATA_TYPE(DATA_TYPE, 4)
327  tmp7 = in_row5 - in_row7;
328 
329  VEC_DATA_TYPE(DATA_TYPE, 2)
330  out10 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s0 - tmp1.s2, tmp5.s0 - tmp5.s2);
331  VEC_DATA_TYPE(DATA_TYPE, 2)
332  out11 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s1 + tmp1.s2, tmp5.s1 + tmp5.s2);
333  VEC_DATA_TYPE(DATA_TYPE, 2)
334  out12 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s2 - tmp1.s1, tmp5.s2 - tmp5.s1);
335  VEC_DATA_TYPE(DATA_TYPE, 2)
336  out13 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s1 - tmp1.s3, tmp5.s1 - tmp5.s3);
337 
338  VEC_DATA_TYPE(DATA_TYPE, 2)
339  out20 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s0 - tmp2.s2, tmp6.s0 - tmp6.s2);
340  VEC_DATA_TYPE(DATA_TYPE, 2)
341  out21 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s1 + tmp2.s2, tmp6.s1 + tmp6.s2);
342  VEC_DATA_TYPE(DATA_TYPE, 2)
343  out22 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s2 - tmp2.s1, tmp6.s2 - tmp6.s1);
344  VEC_DATA_TYPE(DATA_TYPE, 2)
345  out23 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s1 - tmp2.s3, tmp6.s1 - tmp6.s3);
346 
347  VEC_DATA_TYPE(DATA_TYPE, 2)
348  out30 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s0 - tmp3.s2, tmp7.s0 - tmp7.s2);
349  VEC_DATA_TYPE(DATA_TYPE, 2)
350  out31 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s1 + tmp3.s2, tmp7.s1 + tmp7.s2);
351  VEC_DATA_TYPE(DATA_TYPE, 2)
352  out32 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s2 - tmp3.s1, tmp7.s2 - tmp7.s1);
353  VEC_DATA_TYPE(DATA_TYPE, 2)
354  out33 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s1 - tmp3.s3, tmp7.s1 - tmp7.s3);
355 #endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
356 
357 #if defined(SRC_DEPTH) // destination byte offset: z elements + linear tile index (x + y * NUM_TILES_X) * dst_stride_y, plus b * dst_stride_w when SRC_DEPTH is defined
358  __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
359 #else /* defined(SRC_DEPTH) */
360  __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
361 #endif /* defined(SRC_DEPTH) */
362 
363  vstore2(out00, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z)); // each store writes two adjacent elements (one per plane); transformed value k goes to byte offset k * dst_stride_z
364  vstore2(out01, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z));
365  vstore2(out02, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z));
366  vstore2(out03, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z));
367 
368 #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
369  vstore2(out10, 0, (__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z)); // two-dimensional case: twelve more stores, offsets 4..15
370  vstore2(out11, 0, (__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z));
371  vstore2(out12, 0, (__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z));
372  vstore2(out13, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z));
373  vstore2(out20, 0, (__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z));
374  vstore2(out21, 0, (__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z));
375  vstore2(out22, 0, (__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z));
376  vstore2(out23, 0, (__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z));
377  vstore2(out30, 0, (__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z));
378  vstore2(out31, 0, (__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z));
379  vstore2(out32, 0, (__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z));
380  vstore2(out33, 0, (__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z));
381 #endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
382 }
383 
384 /** This OpenCL kernel computes the input transform when the output tile is 4x4/4x1 or 1x4, the filter size 3x3/3x1 or 1x3 and the data layout is NCHW
385  *
386  * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
387  * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
388  * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
389  * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
390  * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
391  * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
392  * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
393  *
394  * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
395  * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
396  * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
397  * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
398  * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
399  * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
400  * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
401  * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
402  * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
403  * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
404  * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
405  * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
406  * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
407  * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
408  * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
409  * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
410  * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
411  * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
412  */
413 __kernel void winograd_input_transform_4x4_3x3_stepz1_nchw(
416  uint src_stride_w,
417  uint dst_stride_w)
418 {
419  const int x = get_global_id(0);
420  const int y = get_global_id(1);
421 #if defined(SRC_DEPTH)
422  const int z = get_global_id(2) % SRC_DEPTH;
423  const int b = get_global_id(2) / SRC_DEPTH;
424 #else /* defined(SRC_DEPTH) */
425  const int z = get_global_id(2);
426 #endif /* defined(SRC_DEPTH) */
427 
428  // Compute input address
429 #if defined(SRC_DEPTH)
430  __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
431 #else /* defined(SRC_DEPTH) */
432  __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
433 #endif /* defined(SRC_DEPTH) */
434 
435  src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);
436 
437 #if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
438  // Row0
439  VEC_DATA_TYPE(DATA_TYPE, 4)
440  d00 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
441  *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
442  *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
443  *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
444  VEC_DATA_TYPE(DATA_TYPE, 2)
445  d01 = (VEC_DATA_TYPE(DATA_TYPE, 2))(*((__global DATA_TYPE *)(src_addr + 4 * src_stride_y)),
446  *((__global DATA_TYPE *)(src_addr + 5 * src_stride_y)));
447 #else // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
448  // Row0
449  VEC_DATA_TYPE(DATA_TYPE, 4)
450  d00 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
451  VEC_DATA_TYPE(DATA_TYPE, 2)
452  d01 = vload2(2, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
453 #endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
454 
455  DATA_TYPE out0 = 0.0f;
456  DATA_TYPE out1 = 0.0f;
457  DATA_TYPE out2 = 0.0f;
458  DATA_TYPE out3 = 0.0f;
459  DATA_TYPE out4 = 0.0f;
460  DATA_TYPE out5 = 0.0f;
461 
462  // Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
463  out0 += 16.0f * d00.s0 - 20.0f * d00.s2 + 4.0f * d01.s0;
464  out1 += -16.0f * d00.s1 - 16.0f * d00.s2 + 4.0f * d00.s3 + 4.0f * d01.s0;
465  out2 += 16.0f * d00.s1 - 16.0f * d00.s2 - 4.0f * d00.s3 + 4.0f * d01.s0;
466  out3 += -8.0f * d00.s1 - 4.0f * d00.s2 + 8.0f * d00.s3 + 4.0f * d01.s0;
467  out4 += 8.0f * d00.s1 - 4.0f * d00.s2 - 8.0f * d00.s3 + 4.0f * d01.s0;
468  out5 += 16.0f * d00.s1 - 20.0f * d00.s3 + 4.0f * d01.s1;
469 
470 #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
471  // Row4
472  VEC_DATA_TYPE(DATA_TYPE, 4)
473  d40 = vload4(0, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
474  VEC_DATA_TYPE(DATA_TYPE, 2)
475  d41 = vload2(2, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
476 
477  // k0, k1, k2, k3, k4, k5 are common terms for row0, row1, row2, row3 and row4
478  DATA_TYPE k0 = d41.s0;
479  DATA_TYPE k1 = d41.s0;
480  DATA_TYPE k2 = d41.s0;
481  DATA_TYPE k3 = d41.s0;
482  DATA_TYPE k4 = d41.s0;
483  DATA_TYPE k5 = 0.0f;
484 
485  k0 += 4.0f * d40.s0 - 5.0f * d40.s2;
486  k1 += -4.0f * d40.s1 - 4.0f * d40.s2 + d40.s3;
487  k2 += 4.0f * d40.s1 - 4.0f * d40.s2 - d40.s3;
488  k3 += -2.0f * d40.s1 + 2.0f * d40.s3 - d40.s2;
489  k4 += 2.0f * d40.s1 - 2.0f * d40.s3 - d40.s2;
490  k5 += 4.0f * d40.s1 - 5.0f * d40.s3 + d41.s1;
491 
492  out0 += k0;
493  out1 += k1;
494  out2 += k2;
495  out3 += k3;
496  out4 += k4;
497  out5 += k5;
498 
499  // Row2
500  VEC_DATA_TYPE(DATA_TYPE, 4)
501  d20 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
502  VEC_DATA_TYPE(DATA_TYPE, 2)
503  d21 = vload2(2, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
504 
505  out0 += -20.0f * d20.s0 + 25.0f * d20.s2 - 5.0f * d21.s0;
506  out1 += +20.0f * d20.s1 + 20.0f * d20.s2 - 5.0f * d20.s3 - 5.0f * d21.s0;
507  out2 += -20.0f * d20.s1 + 20.0f * d20.s2 + 5.0f * d20.s3 - 5.0f * d21.s0;
508  out3 += +10.0f * d20.s1 + 5.0f * d20.s2 - 10.0f * d20.s3 - 5.0f * d21.s0;
509  out4 += -10.0f * d20.s1 + 5.0f * d20.s2 + 10.0f * d20.s3 - 5.0f * d21.s0;
510  out5 += -20.0f * d20.s1 + 25.0f * d20.s3 - 5.0f * d21.s1;
511 #endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
512 
513  // Compute destination address
514 #if defined(SRC_DEPTH)
515  __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w);
516 #else /* defined(SRC_DEPTH) */
517  __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y);
518 #endif /* defined(SRC_DEPTH) */
519 
520  uint dst_plane_stride = dst_stride_z / sizeof(DATA_TYPE);
521 
522  *(dst_addr) = out0;
523  dst_addr += dst_plane_stride;
524  *(dst_addr) = out1;
525  dst_addr += dst_plane_stride;
526  *(dst_addr) = out2;
527  dst_addr += dst_plane_stride;
528  *(dst_addr) = out3;
529  dst_addr += dst_plane_stride;
530  *(dst_addr) = out4;
531  dst_addr += dst_plane_stride;
532  *(dst_addr) = out5;
533  dst_addr += dst_plane_stride;
534 
535 #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
536  DATA_TYPE out6 = k0;
537  DATA_TYPE out7 = k1;
538  DATA_TYPE out8 = k2;
539  DATA_TYPE out9 = k3;
540  DATA_TYPE out10 = k4;
541  DATA_TYPE out11 = k5;
542  DATA_TYPE out12 = k0;
543  DATA_TYPE out13 = k1;
544  DATA_TYPE out14 = k2;
545  DATA_TYPE out15 = k3;
546  DATA_TYPE out16 = k4;
547  DATA_TYPE out17 = k5;
548  DATA_TYPE out18 = k0;
549  DATA_TYPE out19 = k1;
550  DATA_TYPE out20 = k2;
551  DATA_TYPE out21 = k3;
552  DATA_TYPE out22 = k4;
553  DATA_TYPE out23 = k5;
554  DATA_TYPE out24 = k0;
555  DATA_TYPE out25 = k1;
556  DATA_TYPE out26 = k2;
557  DATA_TYPE out27 = k3;
558  DATA_TYPE out28 = k4;
559  DATA_TYPE out29 = k5;
560 
561  // Row1
562  VEC_DATA_TYPE(DATA_TYPE, 4)
563  d10 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
564  VEC_DATA_TYPE(DATA_TYPE, 2)
565  d11 = vload2(2, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
566 
567  // Row3
568  VEC_DATA_TYPE(DATA_TYPE, 4)
569  d30 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
570  VEC_DATA_TYPE(DATA_TYPE, 2)
571  d31 = vload2(2, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
572 
573  // Compute common parts for the channels between [6, 29]
574  // Channels [6, 11]: [out10, out11, out12, out13, out14, out15]
575  // Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
576  DATA_TYPE part0 = -16.0f * d20.s0 + 20.0f * d20.s2 - 4.0f * d21.s0;
577  DATA_TYPE part1 = 16.0f * d10.s0 - 20.0f * d10.s2 + 4.0f * d11.s0 - 4.0f * d30.s0 + 5.0f * d30.s2 - d31.s0;
578  DATA_TYPE part2 = 16.0f * d20.s2 - 4.0f * d21.s0;
579  DATA_TYPE part3 = 16.0f * d20.s1 - 4.0f * d20.s3;
580  DATA_TYPE part4 = 16.0f * d10.s2 - 4.0f * d11.s0 - 4.0f * d30.s2 + d31.s0;
581  DATA_TYPE part5 = 16.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + d30.s3;
582  DATA_TYPE part6 = 4.0f * d20.s2 - 4.0f * d21.s0;
583  DATA_TYPE part7 = 8.0f * d10.s1 - 8.0f * d10.s3 - 2.0f * d30.s1 + 2.0f * d30.s3;
584  DATA_TYPE part8 = 4.0f * d10.s2 - 4.0f * d11.s0 - d30.s2 + d31.s0;
585  DATA_TYPE part9 = 8.0f * d20.s1 - 8.0f * d20.s3;
586  DATA_TYPE part10 = -16.0f * d20.s1 + 20.0f * d20.s3 - 4.0f * d21.s1;
587  DATA_TYPE part11 = -16.0f * d10.s1 + 20.0f * d10.s3 - 4.0f * d11.s1 + 4.0f * d30.s1 - 5.0f * d30.s3 + d31.s1;
588 
589  // Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
590  // Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
591  DATA_TYPE part12 = 8.0f * d10.s0 - 10.0f * d10.s2 + 2.0f * d11.s0 - 8.0f * d30.s0 + 10.0f * d30.s2 - 2.0f * d31.s0;
592  DATA_TYPE part13 = part0 * 0.25f; // -4.0f * d20.s0 + 5.0f * d20.s2 - d21.s0
593  DATA_TYPE part14 = part2 * 0.25f; // 4.0f * d20.s2 - d21.s0
594  DATA_TYPE part15 = 8.0f * d10.s1 - 2.0f * d10.s3 - 8.0f * d30.s1 + 2.0f * d30.s3;
595  DATA_TYPE part16 = 8.0f * d10.s2 - 2.0f * d11.s0 - 8.0f * d30.s2 + 2.0f * d31.s0;
596  DATA_TYPE part17 = part3 * 0.25f; // 4.0f * d20.s1 - d20.s3
597  DATA_TYPE part18 = part6 * 0.25f; // d20.s2 - d21.s0
598  DATA_TYPE part19 = 4.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + 4.0f * d30.s3;
599  DATA_TYPE part20 = 2.0f * d10.s2 - 2.0f * d11.s0 - 2.0f * d30.s2 + 2.0f * d31.s0;
600  DATA_TYPE part21 = part9 * 0.25f; // 2.0f * (d20.s1 - d20.s3)
601  DATA_TYPE part22 = part10 * 0.25f; // - 4.0f * d20.s1 + 5.0f * d20.s3 - d21.s1
602  DATA_TYPE part23 = part11 * 0.5f + 6.0f * d30.s1 - 7.5f * d30.s3 + 1.5f * d31.s1; // - 8.0f * d10.s1 + 10.0f * d10.s3 - 2.0f * d11.s1 + 8.0f * d30.s1 - 10.0f * d30.s3 + 2.0f * d31.s1;
603 
604  out6 += part0 - part1;
605  out12 += part0 + part1;
606  out7 += part2 + part3 + part4 + part5;
607  out8 += part2 - part3 + part4 - part5;
608  out13 += part2 + part3 - part4 - part5;
609  out14 += part2 - part3 - part4 + part5;
610  out9 += part6 + part7 + part8 + part9;
611  out10 += part6 - part7 + part8 - part9;
612  out15 += part6 - part7 - part8 + part9;
613  out16 += part6 + part7 - part8 - part9;
614  out11 += part10 + part11;
615  out17 += part10 - part11;
616 
617  out18 += part13 - part12;
618  out24 += part13 + part12;
619  out19 += part14 + part15 + part16 + part17;
620  out20 += part14 - part15 + part16 - part17;
621  out25 += part14 - part15 - part16 + part17;
622  out26 += part14 + part15 - part16 - part17;
623  out21 += part18 + part19 + part20 + part21;
624  out22 += part18 - part19 + part20 - part21;
625  out27 += part18 - part19 - part20 + part21;
626  out28 += part18 + part19 - part20 - part21;
627  out23 += part22 + part23;
628  out29 += part22 - part23;
629 
630  *(dst_addr) = out6;
631  dst_addr += dst_plane_stride;
632  *(dst_addr) = out7;
633  dst_addr += dst_plane_stride;
634  *(dst_addr) = out8;
635  dst_addr += dst_plane_stride;
636  *(dst_addr) = out9;
637  dst_addr += dst_plane_stride;
638  *(dst_addr) = out10;
639  dst_addr += dst_plane_stride;
640  *(dst_addr) = out11;
641  dst_addr += dst_plane_stride;
642  *(dst_addr) = out12;
643  dst_addr += dst_plane_stride;
644  *(dst_addr) = out13;
645  dst_addr += dst_plane_stride;
646  *(dst_addr) = out14;
647  dst_addr += dst_plane_stride;
648  *(dst_addr) = out15;
649  dst_addr += dst_plane_stride;
650  *(dst_addr) = out16;
651  dst_addr += dst_plane_stride;
652  *(dst_addr) = out17;
653  dst_addr += dst_plane_stride;
654 
655  *(dst_addr) = out18;
656  dst_addr += dst_plane_stride;
657  *(dst_addr) = out19;
658  dst_addr += dst_plane_stride;
659  *(dst_addr) = out20;
660  dst_addr += dst_plane_stride;
661  *(dst_addr) = out21;
662  dst_addr += dst_plane_stride;
663  *(dst_addr) = out22;
664  dst_addr += dst_plane_stride;
665  *(dst_addr) = out23;
666  dst_addr += dst_plane_stride;
667  *(dst_addr) = out24;
668  dst_addr += dst_plane_stride;
669  *(dst_addr) = out25;
670  dst_addr += dst_plane_stride;
671  *(dst_addr) = out26;
672  dst_addr += dst_plane_stride;
673  *(dst_addr) = out27;
674  dst_addr += dst_plane_stride;
675  *(dst_addr) = out28;
676  dst_addr += dst_plane_stride;
677  *(dst_addr) = out29;
678  dst_addr += dst_plane_stride;
679 
680  // Row5
681  VEC_DATA_TYPE(DATA_TYPE, 4)
682  d50 = vload4(0, (__global DATA_TYPE *)(src_addr + 5 * src_stride_y));
683  VEC_DATA_TYPE(DATA_TYPE, 2)
684  d51 = vload2(2, (__global DATA_TYPE *)(src_addr + 5 * src_stride_y));
685 
686  // Channels [30, 35]
687  out0 = 16.0f * d10.s0 - 20.0f * d10.s2 - 20.0f * d30.s0 + 25.0f * d30.s2 + 4.0f * d50.s0 - 5.0f * d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
688  out1 = -16.0f * d10.s1 - 16.0f * d10.s2 + 4.0f * d10.s3 + 20.0f * d30.s1 + 20.0f * d30.s2 - 5.0f * d30.s3 - 4.0f * d50.s1 - 4.0f * d50.s2 + d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
689  out2 = 16.0f * d10.s1 - 16.0f * d10.s2 - 4.0f * d10.s3 - 20.0f * d30.s1 + 20.0f * d30.s2 + 5.0f * d30.s3 + 4.0f * d50.s1 - 4.0f * d50.s2 - d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
690  out3 = -8.0f * d10.s1 - 4.0f * d10.s2 + 8.0f * d10.s3 + 10.0f * d30.s1 - 10.0f * d30.s3 + 5.0f * d30.s2 - 2.0f * d50.s1 + 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
691  out4 = 8.0f * d10.s1 - 4.0f * d10.s2 - 8.0f * d10.s3 - 10.0f * d30.s1 + 5.0f * d30.s2 + 10.0f * d30.s3 + 2.0f * d50.s1 - 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
692  out5 = 16.0f * d10.s1 - 20.0f * d10.s3 + 4.0f * d11.s1 - 20.0f * d30.s1 + 25.0f * d30.s3 - 5.0f * d31.s1 + 4.0f * d50.s1 - 5.0f * d50.s3 + d51.s1;
693 
694  *(dst_addr) = out0;
695  dst_addr += dst_plane_stride;
696  *(dst_addr) = out1;
697  dst_addr += dst_plane_stride;
698  *(dst_addr) = out2;
699  dst_addr += dst_plane_stride;
700  *(dst_addr) = out3;
701  dst_addr += dst_plane_stride;
702  *(dst_addr) = out4;
703  dst_addr += dst_plane_stride;
704  *(dst_addr) = out5;
705  dst_addr += dst_plane_stride;
706 #endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
707 }
708 
/* Scatters the 8 lanes of @p vec to 8 consecutive destination planes, starting at plane 8 * @p row.
 * @p base is a byte pointer and @p plane_stride a stride in bytes, so the offset is applied before the cast to DATA_TYPE. */
#define WINOGRAD_STORE_ROW_8(base, plane_stride, row, vec)                                \
    ({                                                                                    \
        *((__global DATA_TYPE *)((base) + (8 * (row) + 0) * (plane_stride))) = (vec).s0; \
        *((__global DATA_TYPE *)((base) + (8 * (row) + 1) * (plane_stride))) = (vec).s1; \
        *((__global DATA_TYPE *)((base) + (8 * (row) + 2) * (plane_stride))) = (vec).s2; \
        *((__global DATA_TYPE *)((base) + (8 * (row) + 3) * (plane_stride))) = (vec).s3; \
        *((__global DATA_TYPE *)((base) + (8 * (row) + 4) * (plane_stride))) = (vec).s4; \
        *((__global DATA_TYPE *)((base) + (8 * (row) + 5) * (plane_stride))) = (vec).s5; \
        *((__global DATA_TYPE *)((base) + (8 * (row) + 6) * (plane_stride))) = (vec).s6; \
        *((__global DATA_TYPE *)((base) + (8 * (row) + 7) * (plane_stride))) = (vec).s7; \
    })

/** This OpenCL kernel computes the input transform when the kernel size is 5x5/5x1 or 1x5 and the output tile is 4x4/4x1 or 1x4 when the data layout is NCHW
 *
 * Each work-item reads one input tile (8 rows of 8 elements in the 2D case, a single row or column of 8 elements in the
 * 1D cases), transforms it and writes one element to each of 64 (2D) or 8 (1D) consecutive destination planes,
 * stepping by dst_stride_z.
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note If this kernel is used to perform Winograd input transform 5x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd input transform 1x5, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note If -DSRC_DEPTH is passed at compile time, get_global_id(2) is split into an index along Z (id % SRC_DEPTH) and an index along W (id / SRC_DEPTH);
 *       the W index is applied through @p src_stride_w and @p dst_stride_w. Without -DSRC_DEPTH the two W strides are not used.
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_4x4_5x5_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    const int x = get_global_id(0);
    const int y = get_global_id(1);
#if defined(SRC_DEPTH)
    const int z = get_global_id(2) % SRC_DEPTH;
    const int b = get_global_id(2) / SRC_DEPTH;
#else  /* defined(SRC_DEPTH) */
    const int z = get_global_id(2);
#endif /* defined(SRC_DEPTH) */

    // Compute input address: top-left element of the tile handled by this work-item
#if defined(SRC_DEPTH)
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
#else  /* defined(SRC_DEPTH) */
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
#endif /* defined(SRC_DEPTH) */
    // Step back by the padding so that the tile starts PAD_LEFT elements to the left and PAD_TOP rows above
    src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);

    // Load input tile
#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    // 1D horizontal case: a single row of 8 contiguous elements
    const VEC_DATA_TYPE(DATA_TYPE, 8) in_row0 = vload8(0, (__global DATA_TYPE *)(src_addr));
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    // 1D vertical case: gather one element from each of 8 consecutive rows into a single vector
    const VEC_DATA_TYPE(DATA_TYPE, 8) in_row0 = (VEC_DATA_TYPE(DATA_TYPE, 8))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
                                                                              *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
                                                                              *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
                                                                              *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)),
                                                                              *((__global DATA_TYPE *)(src_addr + 4 * src_stride_y)),
                                                                              *((__global DATA_TYPE *)(src_addr + 5 * src_stride_y)),
                                                                              *((__global DATA_TYPE *)(src_addr + 6 * src_stride_y)),
                                                                              *((__global DATA_TYPE *)(src_addr + 7 * src_stride_y)));
#else  // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // 2D case: 8 rows of 8 contiguous elements
    const VEC_DATA_TYPE(DATA_TYPE, 8) in_row0 = vload8(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 8) in_row1 = vload8(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 8) in_row2 = vload8(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 8) in_row3 = vload8(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 8) in_row4 = vload8(0, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 8) in_row5 = vload8(0, (__global DATA_TYPE *)(src_addr + 5 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 8) in_row6 = vload8(0, (__global DATA_TYPE *)(src_addr + 6 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 8) in_row7 = vload8(0, (__global DATA_TYPE *)(src_addr + 7 * src_stride_y));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Calculate common factors for intermediate tensor
    VEC_DATA_TYPE(DATA_TYPE, 8)
    tmp0 = in_row0;
    VEC_DATA_TYPE(DATA_TYPE, 8)
    comm_fact0 = 0.0f;

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // 2D case only: combine the 8 input rows into the 8 intermediate rows tmp0..tmp7.
    // comm_fact0/1/2 are scratch vectors that are overwritten between steps, so the statement order matters.
    comm_fact0 += in_row2 + in_row6 - (DATA_TYPE)4.25f * in_row4;
    tmp0 += -in_row6 + (DATA_TYPE)5.25f * in_row4 - (DATA_TYPE)5.25f * in_row2;

    VEC_DATA_TYPE(DATA_TYPE, 8)
    comm_fact1 = in_row1 + in_row5 - (DATA_TYPE)4.25f * in_row3;
    VEC_DATA_TYPE(DATA_TYPE, 8)
    comm_fact2 = (DATA_TYPE)0.25f * in_row2 - (DATA_TYPE)1.25f * in_row4 + in_row6;

    const VEC_DATA_TYPE(DATA_TYPE, 8) tmp1 = comm_fact0 + comm_fact1;
    const VEC_DATA_TYPE(DATA_TYPE, 8) tmp2 = comm_fact0 - comm_fact1;

    comm_fact0 = (DATA_TYPE)2.5f * in_row3;
    comm_fact1 = (DATA_TYPE)0.5f * in_row1 - comm_fact0 + (DATA_TYPE)2.0f * in_row5;

    const VEC_DATA_TYPE(DATA_TYPE, 8) tmp3 = comm_fact1 + comm_fact2;
    const VEC_DATA_TYPE(DATA_TYPE, 8) tmp4 = comm_fact2 - comm_fact1;

    comm_fact1 = (DATA_TYPE)2.0f * in_row1 - comm_fact0 + (DATA_TYPE)0.5f * in_row5;
    comm_fact2 = (DATA_TYPE)4.0f * in_row2 - (DATA_TYPE)5.0f * in_row4 + in_row6;

    const VEC_DATA_TYPE(DATA_TYPE, 8) tmp5 = comm_fact1 + comm_fact2;
    const VEC_DATA_TYPE(DATA_TYPE, 8) tmp6 = comm_fact2 - comm_fact1;
    const VEC_DATA_TYPE(DATA_TYPE, 8) tmp7 = in_row7 - in_row1 + (DATA_TYPE)5.25f * in_row3 - (DATA_TYPE)5.25f * in_row5;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Calculate output rows (reuse comm_fact0 vector)
    VEC_DATA_TYPE(DATA_TYPE, 8)
    out0;

    OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    VEC_DATA_TYPE(DATA_TYPE, 8)
    out1, out2, out3, out4, out5, out6, out7;

    OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Store values across the channels
#if defined(SRC_DEPTH)
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
#else  /* defined(SRC_DEPTH) */
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
#endif /* defined(SRC_DEPTH) */

    // Planes [0, 7]: the only ones written in the 1D (horizontal/vertical) cases
    WINOGRAD_STORE_ROW_8(dst_addr, dst_stride_z, 0, out0);

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // Planes [8, 63]: remaining seven rows of the 8x8 transformed tile
    WINOGRAD_STORE_ROW_8(dst_addr, dst_stride_z, 1, out1);
    WINOGRAD_STORE_ROW_8(dst_addr, dst_stride_z, 2, out2);
    WINOGRAD_STORE_ROW_8(dst_addr, dst_stride_z, 3, out3);
    WINOGRAD_STORE_ROW_8(dst_addr, dst_stride_z, 4, out4);
    WINOGRAD_STORE_ROW_8(dst_addr, dst_stride_z, 5, out5);
    WINOGRAD_STORE_ROW_8(dst_addr, dst_stride_z, 6, out6);
    WINOGRAD_STORE_ROW_8(dst_addr, dst_stride_z, 7, out7);
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
}

#undef WINOGRAD_STORE_ROW_8
910 
911 #if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 2x1
 *
 * It forwards all of its arguments, unchanged, to winograd_input_transform_2x2_3x3_stepz1_nchw.
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_2x1_3x1_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // The argument order below follows the order in which the tensor declarations expand their fields
    winograd_input_transform_2x2_3x3_stepz1_nchw(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes,
                                                 src_stride_w,
                                                 dst_stride_w);
}
965 
/** This OpenCL kernel computes the input transform when the kernel size is 3x1, the output tile is 2x1 and the number of channels is multiple of 2
 *
 * It forwards all of its arguments, unchanged, to winograd_input_transform_2x2_3x3_stepz2_nchw.
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_2x1_3x1_stepz2_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // The argument order below follows the order in which the tensor declarations expand their fields
    winograd_input_transform_2x2_3x3_stepz2_nchw(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes,
                                                 src_stride_w,
                                                 dst_stride_w);
}
1019 
/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 4x1
 *
 * It forwards all of its arguments, unchanged, to winograd_input_transform_4x4_3x3_stepz1_nchw.
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_4x1_3x1_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // The argument order below follows the order in which the tensor declarations expand their fields
    winograd_input_transform_4x4_3x3_stepz1_nchw(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes,
                                                 src_stride_w,
                                                 dst_stride_w);
}
1073 
1074 /** This OpenCL kernel computes the input transform when the kernel size is 5x1 and the output tile is 4x1 when the data layout is NCHW
1075  *
1076  * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
1077  * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
1078  * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
1079  * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
1080  * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
1081  * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
1082  *
1083  * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
1084  * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
1085  * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1086  * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
1087  * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1088  * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1089  * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1090  * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
1091  * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1092  * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1093  * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1094  * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1095  * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1096  * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1097  * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
1098  * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1099  * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1100  * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
1101  */
__kernel void winograd_input_transform_4x1_5x1_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Thin entry point, compiled only when WINOGRAD_INPUT_TRANSFORM_HORIZONTAL is defined.
    // The src_* / dst_* identifiers forwarded below are expected to be declared by
    // TENSOR3D_DECLARATION (helpers.h) - confirm against the macro definition.
    // Every argument is handed, unchanged, to the 4x4/5x5 routine.
    winograd_input_transform_4x4_5x5_stepz1_nchw(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes,
                                                 src_stride_w,
                                                 dst_stride_w);
}
1127 #endif // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
1128 
1129 #if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x2
 *
 * Thin entry point: every argument is forwarded, unchanged, to winograd_input_transform_2x2_3x3_stepz1_nchw.
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=5
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP: e.g. -DPAD_LEFT=1 and -DPAD_TOP=0
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_1x2_1x3_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // The src_* / dst_* identifiers are expected to come from TENSOR3D_DECLARATION (helpers.h) - confirm there.
    // NOTE(review): presumably the 2x2/3x3 routine specialises on WINOGRAD_INPUT_TRANSFORM_VERTICAL - verify in its definition.
    winograd_input_transform_2x2_3x3_stepz1_nchw(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes,
                                                 src_stride_w,
                                                 dst_stride_w);
}
1183 
/** This OpenCL kernel computes the input transform when the kernel size is 1x3, the output tile is 1x2 and the number of channels is multiple of 2
 *
 * Thin entry point: every argument is forwarded, unchanged, to winograd_input_transform_2x2_3x3_stepz2_nchw.
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=5
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP: e.g. -DPAD_LEFT=1 and -DPAD_TOP=0
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_1x2_1x3_stepz2_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // The src_* / dst_* identifiers are expected to come from TENSOR3D_DECLARATION (helpers.h) - confirm there.
    // NOTE(review): presumably the 2x2/3x3 stepz2 routine specialises on WINOGRAD_INPUT_TRANSFORM_VERTICAL - verify in its definition.
    winograd_input_transform_2x2_3x3_stepz2_nchw(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes,
                                                 src_stride_w,
                                                 dst_stride_w);
}
1237 
/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x4
 *
 * Thin entry point: every argument is forwarded, unchanged, to winograd_input_transform_4x4_3x3_stepz1_nchw.
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=5
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP: e.g. -DPAD_LEFT=1 and -DPAD_TOP=0
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_1x4_1x3_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // The src_* / dst_* identifiers are expected to come from TENSOR3D_DECLARATION (helpers.h) - confirm there.
    // NOTE(review): presumably the 4x4/3x3 routine specialises on WINOGRAD_INPUT_TRANSFORM_VERTICAL - verify in its definition.
    winograd_input_transform_4x4_3x3_stepz1_nchw(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes,
                                                 src_stride_w,
                                                 dst_stride_w);
}
1291 
/** This OpenCL kernel computes the input transform when the kernel size is 1x5 and the output tile is 1x4
 *
 * Thin entry point: every argument is forwarded, unchanged, to winograd_input_transform_4x4_5x5_stepz1_nchw.
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=5
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP: e.g. -DPAD_LEFT=1 and -DPAD_TOP=0
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_1x4_1x5_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // The src_* / dst_* identifiers are expected to come from TENSOR3D_DECLARATION (helpers.h) - confirm there.
    // NOTE(review): presumably the 4x4/5x5 routine specialises on WINOGRAD_INPUT_TRANSFORM_VERTICAL - verify in its definition.
    winograd_input_transform_4x4_5x5_stepz1_nchw(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes,
                                                 src_stride_w,
                                                 dst_stride_w);
}
1345 #endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
1346 #endif // defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
SimpleTensor< float > b
Definition: DFT.cpp:157
SimpleTensor< float > src
Definition: DFT.cpp:155
#define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact)
#define TENSOR3D_DECLARATION(name)
Definition: helpers.h:813
#define VEC_DATA_TYPE(type, size)
Definition: helpers.h:728