Compute Library
 22.08
winograd_filter_transform.cl
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2018-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "helpers.h"
25 
26 #if defined(SRC_DIM_Z)
27 /** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NCHW and the output tile is 2x2/2x1/1x2
28  *
29  * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
30  * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
31  * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
32  * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
33  *
34  * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
35  * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
36  * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
37  * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
38  * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
39  * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
40  * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
41  * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
42  * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
43  * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
44  * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
45  * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
46  * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
47  * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
48  * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
49  * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
50  * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
51  * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
52  */
53 __kernel void winograd_filter_transform_2x2_3x3_nchw(
56 {
57  Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
58 
59  const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
60 
61  // Load the values from the input tensor
62 #if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
63  VEC_DATA_TYPE(DATA_TYPE, 3)
64  w0 = vload3(0, (__global DATA_TYPE *)(src_addr));
65 #elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
66  VEC_DATA_TYPE(DATA_TYPE, 3)
67  w0 = (VEC_DATA_TYPE(DATA_TYPE, 3))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
68  *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
69  *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)));
70 #else // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
71  VEC_DATA_TYPE(DATA_TYPE, 3)
72  w0 = vload3(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
73  VEC_DATA_TYPE(DATA_TYPE, 3)
74  w1 = vload3(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
75  VEC_DATA_TYPE(DATA_TYPE, 3)
76  w2 = vload3(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
77 #endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
78 
79  // Row 0
80  VEC_DATA_TYPE(DATA_TYPE, 4)
81  out0 = 0.0f;
82  out0.s0 = (w0.s0);
83  out0.s1 = (w0.s0 + w0.s1 + w0.s2) * 0.5f;
84  out0.s2 = (w0.s0 + w0.s2 - w0.s1) * 0.5f;
85  out0.s3 = (w0.s2);
86 
87 #if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
88  // Row 1
89  VEC_DATA_TYPE(DATA_TYPE, 4)
90  out1 = 0.0f;
91  out1.s0 = (w0.s0 + w1.s0 + w2.s0) * 0.5f;
92  out1.s1 = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) * 0.25f;
93  out1.s2 = (w0.s0 + w1.s0 + w2.s0 + w0.s2 + w1.s2 + w2.s2 - w0.s1 - w1.s1 - w2.s1) * 0.25f;
94  out1.s3 = (w0.s2 + w1.s2 + w2.s2) * 0.5f;
95 
96  // Row 2
97  VEC_DATA_TYPE(DATA_TYPE, 4)
98  out2 = 0.0f;
99  out2.s0 = (w0.s0 + w2.s0 - w1.s0) * 0.5f;
100  out2.s1 = (w0.s0 + w2.s0 + w0.s1 + w2.s1 + w0.s2 + w2.s2 - w1.s0 - w1.s1 - w1.s2) * 0.25f;
101  out2.s2 = (w0.s0 + w2.s0 + w1.s1 + w0.s2 + w2.s2 - w1.s0 - w0.s1 - w2.s1 - w1.s2) * 0.25f;
102  out2.s3 = (w0.s2 + w2.s2 - w1.s2) * 0.5f;
103 
104  // Row 3
105  VEC_DATA_TYPE(DATA_TYPE, 4)
106  out3 = 0.0f;
107  out3.s0 = (w2.s0);
108  out3.s1 = (w2.s0 + w2.s1 + w2.s2) * 0.5f;
109  out3.s2 = (w2.s0 + w2.s2 - w2.s1) * 0.5f;
110  out3.s3 = (w2.s2);
111 #endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
112 
113  int z = get_global_id(2);
114  int x0 = z / SRC_DIM_Z; // idx filter
115  int y0 = z % SRC_DIM_Z; // idx channel
116 
117  // Get output address
118  __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
119 
120  // Store the values across the channels
121  // 16 channels for 3x3 kernels
122  // 4 channels for 3x1 or 1x3 kernels
123  *(__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z) = out0.s0;
124  *(__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z) = out0.s1;
125  *(__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z) = out0.s2;
126  *(__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z) = out0.s3;
127 
128 #if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
129  *(__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z) = out1.s0;
130  *(__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z) = out1.s1;
131  *(__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z) = out1.s2;
132  *(__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z) = out1.s3;
133  *(__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z) = out2.s0;
134  *(__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z) = out2.s1;
135  *(__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z) = out2.s2;
136  *(__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z) = out2.s3;
137  *(__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z) = out3.s0;
138  *(__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z) = out3.s1;
139  *(__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z) = out3.s2;
140  *(__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z) = out3.s3;
141 #endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
142 }
143 
144 /** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NCHW and the output tile is 4x4/4x1/1x4
145  *
146  * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
147  * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
148  * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
149  * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
150  *
151  * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
152  * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
153  * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
154  * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
155  * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
156  * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
157  * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
158  * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
159  * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
160  * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
161  * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
162  * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
163  * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
164  * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
165  * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
166  * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
167  * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
168  * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
169  */
170 __kernel void winograd_filter_transform_4x4_3x3_nchw(
173 {
174  Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
175 
176  const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
177 
178  // Load the values from the input tensor
179 #if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
180  VEC_DATA_TYPE(DATA_TYPE, 3)
181  w0 = vload3(0, (__global DATA_TYPE *)(src_addr));
182 #elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
183  VEC_DATA_TYPE(DATA_TYPE, 3)
184  w0 = (VEC_DATA_TYPE(DATA_TYPE, 3))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
185  *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
186  *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)));
187 #else // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
188  VEC_DATA_TYPE(DATA_TYPE, 3)
189  w0 = vload3(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
190  VEC_DATA_TYPE(DATA_TYPE, 3)
191  w1 = vload3(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
192  VEC_DATA_TYPE(DATA_TYPE, 3)
193  w2 = vload3(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
194 #endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
195 
196  // Row 0
197  VEC_DATA_TYPE(DATA_TYPE, 8)
198  out0 = 0.0f;
199  out0.s0 = (w0.s0) / 16.f;
200  out0.s1 = (-w0.s0 - w0.s1 - w0.s2) / 24.f;
201  out0.s2 = (-w0.s0 + w0.s1 - w0.s2) / 24.f;
202  out0.s3 = (w0.s0 + 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
203  out0.s4 = (w0.s0 - 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
204  out0.s5 = (w0.s2) / 4.f;
205 
206 #if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
207  // Row 1
208  VEC_DATA_TYPE(DATA_TYPE, 8)
209  out1 = 0.0f;
210  out1.s0 = (-w0.s0 - w1.s0 - w2.s0) / 24.f;
211  out1.s1 = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
212  out1.s2 = (w0.s0 + w1.s0 + w2.s0 - w0.s1 - w1.s1 - w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
213  out1.s3 = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (-w0.s1 - w1.s1 - w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
214  out1.s4 = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (w0.s1 + w1.s1 + w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
215  out1.s5 = (-w0.s2 - w1.s2 - w2.s2) / 6.f;
216 
217  // Row 2
218  VEC_DATA_TYPE(DATA_TYPE, 8)
219  out2 = 0.0f;
220  out2.s0 = (-w0.s0 + w1.s0 - w2.s0) / 24.f;
221  out2.s1 = (w0.s0 - w1.s0 + w2.s0 + w0.s1 - w1.s1 + w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
222  out2.s2 = (w0.s0 - w1.s0 + w2.s0 - w0.s1 + w1.s1 - w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
223  out2.s3 = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (-w0.s1 + w1.s1 - w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
224  out2.s4 = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (w0.s1 - w1.s1 + w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
225  out2.s5 = (-w0.s2 + w1.s2 - w2.s2) / 6.f;
226 
227  // Row 3
228  VEC_DATA_TYPE(DATA_TYPE, 8)
229  out3 = 0.0f;
230  out3.s0 = (w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
231  out3.s1 = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 - 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
232  out3.s2 = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 + 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
233  out3.s3 = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 + 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
234  out3.s4 = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 - 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
235  out3.s5 = (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2) / 24.f;
236 
237  // Row 4
238  VEC_DATA_TYPE(DATA_TYPE, 8)
239  out4 = 0.0f;
240  out4.s0 = (w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
241  out4.s1 = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 + 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
242  out4.s2 = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 - 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
243  out4.s3 = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 - 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
244  out4.s4 = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 + 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
245  out4.s5 = (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2) / 24.f;
246 
247  // Row 5
248  VEC_DATA_TYPE(DATA_TYPE, 8)
249  out5 = 0.0f;
250  out5.s0 = (w2.s0) / 4.f;
251  out5.s1 = (-w2.s0 - w2.s1 - w2.s2) / 6.f;
252  out5.s2 = (-w2.s0 + w2.s1 - w2.s2) / 6.f;
253  out5.s3 = (w2.s0 + 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
254  out5.s4 = (w2.s0 - 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
255  out5.s5 = (w2.s2);
256 #endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
257 
258  int z = get_global_id(2);
259  int x0 = z / SRC_DIM_Z; // idx filter
260  int y0 = z % SRC_DIM_Z; // idx channel
261 
262  // Get output address
263  __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
264 
265  // Store the values across the channels
266  // 36 channels for 3x3 kernels
267  // 6 channels for 3x1 or 1x3 kernels
268  *(__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z) = out0.s0;
269  *(__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z) = out0.s1;
270  *(__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z) = out0.s2;
271  *(__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z) = out0.s3;
272  *(__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z) = out0.s4;
273  *(__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z) = out0.s5;
274 
275 #if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
276  *(__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z) = out1.s0;
277  *(__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z) = out1.s1;
278  *(__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z) = out1.s2;
279  *(__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z) = out1.s3;
280  *(__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z) = out1.s4;
281  *(__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z) = out1.s5;
282  *(__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z) = out2.s0;
283  *(__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z) = out2.s1;
284  *(__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z) = out2.s2;
285  *(__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z) = out2.s3;
286  *(__global DATA_TYPE *)(dst_addr + 16 * dst_stride_z) = out2.s4;
287  *(__global DATA_TYPE *)(dst_addr + 17 * dst_stride_z) = out2.s5;
288  *(__global DATA_TYPE *)(dst_addr + 18 * dst_stride_z) = out3.s0;
289  *(__global DATA_TYPE *)(dst_addr + 19 * dst_stride_z) = out3.s1;
290  *(__global DATA_TYPE *)(dst_addr + 20 * dst_stride_z) = out3.s2;
291  *(__global DATA_TYPE *)(dst_addr + 21 * dst_stride_z) = out3.s3;
292  *(__global DATA_TYPE *)(dst_addr + 22 * dst_stride_z) = out3.s4;
293  *(__global DATA_TYPE *)(dst_addr + 23 * dst_stride_z) = out3.s5;
294  *(__global DATA_TYPE *)(dst_addr + 24 * dst_stride_z) = out4.s0;
295  *(__global DATA_TYPE *)(dst_addr + 25 * dst_stride_z) = out4.s1;
296  *(__global DATA_TYPE *)(dst_addr + 26 * dst_stride_z) = out4.s2;
297  *(__global DATA_TYPE *)(dst_addr + 27 * dst_stride_z) = out4.s3;
298  *(__global DATA_TYPE *)(dst_addr + 28 * dst_stride_z) = out4.s4;
299  *(__global DATA_TYPE *)(dst_addr + 29 * dst_stride_z) = out4.s5;
300  *(__global DATA_TYPE *)(dst_addr + 30 * dst_stride_z) = out5.s0;
301  *(__global DATA_TYPE *)(dst_addr + 31 * dst_stride_z) = out5.s1;
302  *(__global DATA_TYPE *)(dst_addr + 32 * dst_stride_z) = out5.s2;
303  *(__global DATA_TYPE *)(dst_addr + 33 * dst_stride_z) = out5.s3;
304  *(__global DATA_TYPE *)(dst_addr + 34 * dst_stride_z) = out5.s4;
305  *(__global DATA_TYPE *)(dst_addr + 35 * dst_stride_z) = out5.s5;
306 #endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
307 }
308 
309 /** This OpenCL kernel performs Winograd filter transform 5x5/5x1 or 1x5 when the data layout is NCHW and the output tile is 4x4/4x1 or 1x4
310  *
311  * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
312  *
313  * @note If this kernel is used to perform Winograd filter transform 5x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
314  * @note If this kernel is used to perform Winograd filter transform 1x5, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
315  * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
316  *
317  * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
318  * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
319  * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
320  * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
321  * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
322  * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
323  * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
324  * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
325  * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
326  * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
327  * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
328  * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
329  * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
330  * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
331  * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
332  * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
333  * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
334  * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
335  */
336 __kernel void winograd_filter_transform_4x4_5x5_nchw(
339 {
340  Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
341 
342  const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
343 
344  // Load the values from the input tensor
345 #if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
346  VEC_DATA_TYPE(DATA_TYPE, 4)
347  w00 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
348  DATA_TYPE w01 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_y) + 4);
349 #elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
350  VEC_DATA_TYPE(DATA_TYPE, 4)
351  w00 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
352  *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
353  *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
354  *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
355  DATA_TYPE w01 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
356 #else // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
357  VEC_DATA_TYPE(DATA_TYPE, 4)
358  w00 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
359  DATA_TYPE w01 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_y) + 4);
360  VEC_DATA_TYPE(DATA_TYPE, 4)
361  w10 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
362  DATA_TYPE w11 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y) + 4);
363  VEC_DATA_TYPE(DATA_TYPE, 4)
364  w20 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
365  DATA_TYPE w21 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y) + 4);
366  VEC_DATA_TYPE(DATA_TYPE, 4)
367  w30 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
368  DATA_TYPE w31 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y) + 4);
369  VEC_DATA_TYPE(DATA_TYPE, 4)
370  w40 = vload4(0, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
371  DATA_TYPE w41 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_y) + 4);
372 #endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
373 
374  // Transform the input tile
375 
376  // Row 0
377  VEC_DATA_TYPE(DATA_TYPE, 8)
378  out0 = 0.0f;
379  out0.s0 = w00.s0;
380  out0.s1 = -2.f * (w00.s0 + w00.s1 + w00.s2 + w00.s3 + w01) / 9.f;
381  out0.s2 = -2.f * (w00.s0 - w00.s1 + w00.s2 - w00.s3 + w01) / 9.f;
382  out0.s3 = (w00.s0 + 2.f * w00.s1 + 4.f * w00.s2 + 8.f * w00.s3 + 16.f * w01) / 90.f;
383  out0.s4 = (w00.s0 - 2.f * w00.s1 + 4.f * w00.s2 - 8.f * w00.s3 + 16.f * w01) / 90.f;
384  out0.s5 = (16.f * w00.s0 + 8.f * w00.s1 + 4.f * w00.s2 + 2.f * w00.s3 + w01) / 180.f;
385  out0.s6 = (16.f * w00.s0 - 8.f * w00.s1 + 4.f * w00.s2 - 2.f * w00.s3 + w01) / 180.f;
386  out0.s7 = w01;
387 
388 #if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
389  // Row 1
390  VEC_DATA_TYPE(DATA_TYPE, 8)
391  out1 = 0.0f;
392  out1.s0 = -2.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) / 9.f;
393  out1.s1 = 4.f * ((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) +
394  (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 81.f;
395  out1.s2 = 4.f * ((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) -
396  (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 81.f;
397  out1.s3 = -((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + 2.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + 8.f *
398  (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + 16.f * (w01 + w11 + w21 + w31 + w41)) / 405.f;
399  out1.s4 = -((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - 2.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - 8.f *
400  (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + 16.f * (w01 + w11 + w21 + w31 + w41)) / 405.f;
401  out1.s5 = -(16.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + 8.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + 2.f *
402  (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 810.f;
403  out1.s6 = -(16.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - 8.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - 2.f *
404  (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 810.f;
405  out1.s7 = -2.f * (w01 + w11 + w21 + w31 + w41) / 9.f;
406 
407  // Row 2
408  VEC_DATA_TYPE(DATA_TYPE, 8)
409  out2 = 0.0f;
410  out2.s0 = -2.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) / 9.f;
411  out2.s1 = 4.f * ((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) +
412  (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 81.f;
413  out2.s2 = 4.f * ((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) -
414  (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 81.f;
415  out2.s3 = -((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + 2.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + 8.f *
416  (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + 16.f * (w01 - w11 + w21 - w31 + w41)) / 405.f;
417  out2.s4 = -((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - 2.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - 8.f *
418  (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + 16.f * (w01 - w11 + w21 - w31 + w41)) / 405.f;
419  out2.s5 = -(16.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + 8.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + 2.f *
420  (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 810.f;
421  out2.s6 = -(16.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - 8.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - 2.f *
422  (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 810.f;
423  out2.s7 = -2.f * (w01 - w11 + w21 - w31 + w41) / 9.f;
424 
425  // Row 3
426  VEC_DATA_TYPE(DATA_TYPE, 8)
427  out3 = 0.0f;
428  out3.s0 = (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) / 90.f;
429  out3.s1 = -((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) +
430  (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
431  (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 405.f;
432  out3.s2 = -((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) +
433  (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
434  (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 405.f;
435  out3.s3 = ((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + 2.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
436  (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + 8.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
437  (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 8100.f;
438  out3.s4 = ((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - 2.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
439  (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - 8.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
440  (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 8100.f;
441  out3.s5 = (16.f * (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + 8.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
442  (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + 2.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
443  (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 16200.f;
444  out3.s6 = (16.f * (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - 8.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
445  (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - 2.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
446  (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 16200.f;
447  out3.s7 = (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) / 90.f;
448 
449  // Row 4
450  VEC_DATA_TYPE(DATA_TYPE, 8)
451  out4 = 0.0f;
452  out4.s0 = (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) / 90.f;
453  out4.s1 = -((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) +
454  (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
455  (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 405.f;
456  out4.s2 = -((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) +
457  (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
458  (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 405.f;
459  out4.s3 = ((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + 2.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
460  (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + 8.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
461  (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 8100.f;
462  out4.s4 = ((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - 2.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
463  (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - 8.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
464  (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 8100.f;
465  out4.s5 = (16.f * (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + 8.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
466  (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + 2.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
467  (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 16200.f;
468  out4.s6 = (16.f * (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - 8.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
469  (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - 2.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
470  (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 16200.f;
471  out4.s7 = (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) / 90.f;
472 
473  // Row 5
474  VEC_DATA_TYPE(DATA_TYPE, 8)
475  out5 = 0.0f;
476  out5.s0 = (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) / 180.f;
477  out5.s1 = -((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) +
478  (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
479  (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 810.f;
480  out5.s2 = -((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) +
481  (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
482  (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 810.f;
483  out5.s3 = ((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + 2.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
484  (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + 8.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + 16.f *
485  (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 16200.f;
486  out5.s4 = ((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - 2.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
487  (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - 8.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + 16.f *
488  (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 16200.f;
489  out5.s5 = (16.f * (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + 8.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
490  (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + 2.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
491  (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 32400.f;
492  out5.s6 = (16.f * (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - 8.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
493  (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - 2.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
494  (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 32400.f;
495  out5.s7 = (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) / 180.f;
496 
497  // Row 6
498  VEC_DATA_TYPE(DATA_TYPE, 8)
499  out6 = 0.0f;
500  out6.s0 = (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) / 180.f;
501  out6.s1 = -((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) +
502  (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
503  (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 810.f;
504  out6.s2 = -((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) +
505  (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
506  (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 810.f;
507  out6.s3 = ((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + 2.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
508  (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + 8.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + 16.f *
509  (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 16200.f;
510  out6.s4 = ((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - 2.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
511  (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - 8.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + 16.f *
512  (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 16200.f;
513  out6.s5 = (16.f * (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + 8.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
514  (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + 2.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
515  (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 32400.f;
516  out6.s6 = (16.f * (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - 8.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
517  (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - 2.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
518  (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 32400.f;
519  out6.s7 = (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) / 180.f;
520 
521  // Row 7
522  VEC_DATA_TYPE(DATA_TYPE, 8)
523  out7 = 0.0f;
524  out7.s0 = w40.s0;
525  out7.s1 = -2.f * (w40.s0 + w40.s1 + w40.s2 + w40.s3 + w41) / 9.f;
526  out7.s2 = -2.f * (w40.s0 - w40.s1 + w40.s2 - w40.s3 + w41) / 9.f;
527  out7.s3 = (w40.s0 + 2.f * w40.s1 + 4.f * w40.s2 + 8.f * w40.s3 + 16.f * w41) / 90.f;
528  out7.s4 = (w40.s0 - 2.f * w40.s1 + 4.f * w40.s2 - 8.f * w40.s3 + 16.f * w41) / 90.f;
529  out7.s5 = (16.f * w40.s0 + 8.f * w40.s1 + 4.f * w40.s2 + 2.f * w40.s3 + w41) / 180.f;
530  out7.s6 = (16.f * w40.s0 - 8.f * w40.s1 + 4.f * w40.s2 - 2.f * w40.s3 + w41) / 180.f;
531  out7.s7 = w41;
532 #endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
533 
534  int z = get_global_id(2);
535  int x0 = z / SRC_DIM_Z; // idx filter
536  int y0 = z % SRC_DIM_Z; // idx channel
537 
538  // Get output address
539  __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * sizeof(DATA_TYPE) + y0 * dst_stride_y;
540 
541  // Store the values across the channels
542  *(__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z) = out0.s0;
543  *(__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z) = out0.s1;
544  *(__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z) = out0.s2;
545  *(__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z) = out0.s3;
546  *(__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z) = out0.s4;
547  *(__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z) = out0.s5;
548  *(__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z) = out0.s6;
549  *(__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z) = out0.s7;
550 
551 #if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
552  *(__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z) = out1.s0;
553  *(__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z) = out1.s1;
554  *(__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z) = out1.s2;
555  *(__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z) = out1.s3;
556  *(__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z) = out1.s4;
557  *(__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z) = out1.s5;
558  *(__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z) = out1.s6;
559  *(__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z) = out1.s7;
560  *(__global DATA_TYPE *)(dst_addr + 16 * dst_stride_z) = out2.s0;
561  *(__global DATA_TYPE *)(dst_addr + 17 * dst_stride_z) = out2.s1;
562  *(__global DATA_TYPE *)(dst_addr + 18 * dst_stride_z) = out2.s2;
563  *(__global DATA_TYPE *)(dst_addr + 19 * dst_stride_z) = out2.s3;
564  *(__global DATA_TYPE *)(dst_addr + 20 * dst_stride_z) = out2.s4;
565  *(__global DATA_TYPE *)(dst_addr + 21 * dst_stride_z) = out2.s5;
566  *(__global DATA_TYPE *)(dst_addr + 22 * dst_stride_z) = out2.s6;
567  *(__global DATA_TYPE *)(dst_addr + 23 * dst_stride_z) = out2.s7;
568  *(__global DATA_TYPE *)(dst_addr + 24 * dst_stride_z) = out3.s0;
569  *(__global DATA_TYPE *)(dst_addr + 25 * dst_stride_z) = out3.s1;
570  *(__global DATA_TYPE *)(dst_addr + 26 * dst_stride_z) = out3.s2;
571  *(__global DATA_TYPE *)(dst_addr + 27 * dst_stride_z) = out3.s3;
572  *(__global DATA_TYPE *)(dst_addr + 28 * dst_stride_z) = out3.s4;
573  *(__global DATA_TYPE *)(dst_addr + 29 * dst_stride_z) = out3.s5;
574  *(__global DATA_TYPE *)(dst_addr + 30 * dst_stride_z) = out3.s6;
575  *(__global DATA_TYPE *)(dst_addr + 31 * dst_stride_z) = out3.s7;
576  *(__global DATA_TYPE *)(dst_addr + 32 * dst_stride_z) = out4.s0;
577  *(__global DATA_TYPE *)(dst_addr + 33 * dst_stride_z) = out4.s1;
578  *(__global DATA_TYPE *)(dst_addr + 34 * dst_stride_z) = out4.s2;
579  *(__global DATA_TYPE *)(dst_addr + 35 * dst_stride_z) = out4.s3;
580  *(__global DATA_TYPE *)(dst_addr + 36 * dst_stride_z) = out4.s4;
581  *(__global DATA_TYPE *)(dst_addr + 37 * dst_stride_z) = out4.s5;
582  *(__global DATA_TYPE *)(dst_addr + 38 * dst_stride_z) = out4.s6;
583  *(__global DATA_TYPE *)(dst_addr + 39 * dst_stride_z) = out4.s7;
584  *(__global DATA_TYPE *)(dst_addr + 40 * dst_stride_z) = out5.s0;
585  *(__global DATA_TYPE *)(dst_addr + 41 * dst_stride_z) = out5.s1;
586  *(__global DATA_TYPE *)(dst_addr + 42 * dst_stride_z) = out5.s2;
587  *(__global DATA_TYPE *)(dst_addr + 43 * dst_stride_z) = out5.s3;
588  *(__global DATA_TYPE *)(dst_addr + 44 * dst_stride_z) = out5.s4;
589  *(__global DATA_TYPE *)(dst_addr + 45 * dst_stride_z) = out5.s5;
590  *(__global DATA_TYPE *)(dst_addr + 46 * dst_stride_z) = out5.s6;
591  *(__global DATA_TYPE *)(dst_addr + 47 * dst_stride_z) = out5.s7;
592  *(__global DATA_TYPE *)(dst_addr + 48 * dst_stride_z) = out6.s0;
593  *(__global DATA_TYPE *)(dst_addr + 49 * dst_stride_z) = out6.s1;
594  *(__global DATA_TYPE *)(dst_addr + 50 * dst_stride_z) = out6.s2;
595  *(__global DATA_TYPE *)(dst_addr + 51 * dst_stride_z) = out6.s3;
596  *(__global DATA_TYPE *)(dst_addr + 52 * dst_stride_z) = out6.s4;
597  *(__global DATA_TYPE *)(dst_addr + 53 * dst_stride_z) = out6.s5;
598  *(__global DATA_TYPE *)(dst_addr + 54 * dst_stride_z) = out6.s6;
599  *(__global DATA_TYPE *)(dst_addr + 55 * dst_stride_z) = out6.s7;
600  *(__global DATA_TYPE *)(dst_addr + 56 * dst_stride_z) = out7.s0;
601  *(__global DATA_TYPE *)(dst_addr + 57 * dst_stride_z) = out7.s1;
602  *(__global DATA_TYPE *)(dst_addr + 58 * dst_stride_z) = out7.s2;
603  *(__global DATA_TYPE *)(dst_addr + 59 * dst_stride_z) = out7.s3;
604  *(__global DATA_TYPE *)(dst_addr + 60 * dst_stride_z) = out7.s4;
605  *(__global DATA_TYPE *)(dst_addr + 61 * dst_stride_z) = out7.s5;
606  *(__global DATA_TYPE *)(dst_addr + 62 * dst_stride_z) = out7.s6;
607  *(__global DATA_TYPE *)(dst_addr + 63 * dst_stride_z) = out7.s7;
608 #endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
609 }
610 
611 #endif // defined(SRC_DIM_Z)
612 
613 #if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
614 /** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NCHW and the output tile is 2x1
615  *
616  * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
617  * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
618  * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note This kernel is a thin wrapper: it forwards every argument unchanged to winograd_filter_transform_2x2_3x3_nchw.
619  *
620  * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
621  * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
622  * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
623  * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
624  * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
625  * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
626  * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
627  * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
628  * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
629  * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
630  * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
631  * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
632  * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
633  * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
634  * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
635  * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
636  * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
637  * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
638  */
639 __kernel void winograd_filter_transform_2x1_3x1_nchw(
642 {
643  winograd_filter_transform_2x2_3x3_nchw(src_ptr,
644  src_stride_x,
645  src_step_x,
646  src_stride_y,
647  src_step_y,
648  src_stride_z,
649  src_step_z,
650  src_stride_w,
651  src_step_w,
652  src_offset_first_element_in_bytes,
653  dst_ptr,
654  dst_stride_x,
655  dst_step_x,
656  dst_stride_y,
657  dst_step_y,
658  dst_stride_z,
659  dst_step_z,
660  dst_offset_first_element_in_bytes);
661 }
662 
663 /** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NCHW and the output tile is 4x1
664  *
665  * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
666  * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
667  * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note This kernel is a thin wrapper: it forwards every argument unchanged to winograd_filter_transform_4x4_3x3_nchw.
668  *
669  * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
670  * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
671  * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
672  * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
673  * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
674  * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
675  * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
676  * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
677  * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
678  * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
679  * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
680  * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
681  * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
682  * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
683  * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
684  * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
685  * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
686  * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
687  */
688 __kernel void winograd_filter_transform_4x1_3x1_nchw(
691 {
692  winograd_filter_transform_4x4_3x3_nchw(src_ptr,
693  src_stride_x,
694  src_step_x,
695  src_stride_y,
696  src_step_y,
697  src_stride_z,
698  src_step_z,
699  src_stride_w,
700  src_step_w,
701  src_offset_first_element_in_bytes,
702  dst_ptr,
703  dst_stride_x,
704  dst_step_x,
705  dst_stride_y,
706  dst_step_y,
707  dst_stride_z,
708  dst_step_z,
709  dst_offset_first_element_in_bytes);
710 }
711 
712 /** This OpenCL kernel performs Winograd filter transform 5x1 when the data layout is NCHW and the output tile is 4x1
713  *
714  * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
715  * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
716  * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note This kernel is a thin wrapper: it forwards every argument unchanged to winograd_filter_transform_4x4_5x5_nchw.
717  *
718  * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
719  * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
720  * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
721  * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
722  * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
723  * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
724  * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
725  * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
726  * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
727  * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
728  * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
729  * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
730  * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
731  * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
732  * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
733  * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
734  * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
735  * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
736  */
737 __kernel void winograd_filter_transform_4x1_5x1_nchw(
740 {
741  winograd_filter_transform_4x4_5x5_nchw(src_ptr,
742  src_stride_x,
743  src_step_x,
744  src_stride_y,
745  src_step_y,
746  src_stride_z,
747  src_step_z,
748  src_stride_w,
749  src_step_w,
750  src_offset_first_element_in_bytes,
751  dst_ptr,
752  dst_stride_x,
753  dst_step_x,
754  dst_stride_y,
755  dst_step_y,
756  dst_stride_z,
757  dst_step_z,
758  dst_offset_first_element_in_bytes);
759 }
760 
761 #endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
762 
763 #if defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
764 /** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NCHW and the output tile is 1x2
765  *
766  * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
767  * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
768  * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note This kernel is a thin wrapper: it forwards every argument unchanged to winograd_filter_transform_2x2_3x3_nchw.
769  *
770  * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
771  * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
772  * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
773  * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
774  * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
775  * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
776  * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
777  * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
778  * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
779  * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
780  * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
781  * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
782  * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
783  * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
784  * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
785  * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
786  * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
787  * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
788  */
789 __kernel void winograd_filter_transform_1x2_1x3_nchw(
792 {
793  winograd_filter_transform_2x2_3x3_nchw(src_ptr,
794  src_stride_x,
795  src_step_x,
796  src_stride_y,
797  src_step_y,
798  src_stride_z,
799  src_step_z,
800  src_stride_w,
801  src_step_w,
802  src_offset_first_element_in_bytes,
803  dst_ptr,
804  dst_stride_x,
805  dst_step_x,
806  dst_stride_y,
807  dst_step_y,
808  dst_stride_z,
809  dst_step_z,
810  dst_offset_first_element_in_bytes);
811 }
812 
813 /** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NCHW and the output tile is 1x4
814  *
815  * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
816  * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
817  * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note This kernel is a thin wrapper: it forwards every argument unchanged to winograd_filter_transform_4x4_3x3_nchw.
818  *
819  * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
820  * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
821  * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
822  * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
823  * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
824  * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
825  * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
826  * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
827  * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
828  * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
829  * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
830  * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
831  * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
832  * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
833  * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
834  * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
835  * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
836  * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
837  */
838 __kernel void winograd_filter_transform_1x4_1x3_nchw(
841 {
842  winograd_filter_transform_4x4_3x3_nchw(src_ptr,
843  src_stride_x,
844  src_step_x,
845  src_stride_y,
846  src_step_y,
847  src_stride_z,
848  src_step_z,
849  src_stride_w,
850  src_step_w,
851  src_offset_first_element_in_bytes,
852  dst_ptr,
853  dst_stride_x,
854  dst_step_x,
855  dst_stride_y,
856  dst_step_y,
857  dst_stride_z,
858  dst_step_z,
859  dst_offset_first_element_in_bytes);
860 }
861 
862 /** This OpenCL kernel performs Winograd filter transform 1x5 when the data layout is NCHW and the output tile is 1x4
863  *
864  * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
865  * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
866  * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note This kernel is a thin wrapper: it forwards every argument unchanged to winograd_filter_transform_4x4_5x5_nchw.
867  *
868  * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
869  * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
870  * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
871  * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
872  * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
873  * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
874  * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
875  * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
876  * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
877  * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
878  * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
879  * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
880  * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
881  * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
882  * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
883  * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
884  * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
885  * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
886  */
887 __kernel void winograd_filter_transform_1x4_1x5_nchw(
890 {
891  winograd_filter_transform_4x4_5x5_nchw(src_ptr,
892  src_stride_x,
893  src_step_x,
894  src_stride_y,
895  src_step_y,
896  src_stride_z,
897  src_step_z,
898  src_stride_w,
899  src_step_w,
900  src_offset_first_element_in_bytes,
901  dst_ptr,
902  dst_stride_x,
903  dst_step_x,
904  dst_stride_y,
905  dst_step_y,
906  dst_stride_z,
907  dst_step_z,
908  dst_offset_first_element_in_bytes);
909 }
910 
911 #endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
SimpleTensor< float > src
Definition: DFT.cpp:155
Structure to hold 4D tensor information.
Definition: helpers.h:916
__global const uchar * tensor4D_offset(const Tensor4D *tensor, int x, int y, int z, int w)
Get the pointer position of a Tensor4D.
Definition: helpers.h:1109
#define CONVERT_TO_TENSOR4D_STRUCT(name, mod_size)
Definition: helpers.h:877
#define TENSOR4D_DECLARATION(name)
Definition: helpers.h:823
#define TENSOR3D_DECLARATION(name)
Definition: helpers.h:813
#define VEC_DATA_TYPE(type, size)
Definition: helpers.h:728