26 #if defined(DATA_TYPE) && defined(VEC_SIZE) && defined(NUM_GROUPS) && defined(K) && defined(SRC_DIM_Z) 29 #if VEC_SIZE != 4 && VEC_SIZE != 8 && VEC_SIZE != 16 30 #error "Only vector sizes 4, 8 and 16 are supported" 31 #endif // VEC_SIZE != 4 && VEC_SIZE != 8 && VEC_SIZE != 16 33 #define TYPE VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) 35 #define DIV_MOD_UINT(x, y, div_res, mod_res) \ 37 div_res = (uint)((x) * (float)(1.0f / (float)(y))); \ 38 uint r = div_res * (y); \ 74 uint curr_channel = 0;
80 DIV_MOD_UINT(get_global_id(2), SRC_DIM_Z, batch_id, curr_channel);
83 DIV_MOD_UINT(curr_channel,
K, group_id, channel_id);
85 const uint x = get_global_id(0) *
VEC_SIZE;
86 const uint y = get_global_id(1) * 2;
87 const uint z = channel_id * NUM_GROUPS + group_id;
90 const __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x *
sizeof(
DATA_TYPE) + y * src_stride_y + curr_channel * src_stride_z + batch_id * src_stride_w;
95 __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + x *
sizeof(
DATA_TYPE) + y * dst_stride_y + z * dst_stride_z + batch_id * dst_stride_w;
97 (u0, 0, (__global
DATA_TYPE *)(output_ptr + 0 * dst_stride_y));
99 (u1, 0, (__global
DATA_TYPE *)(output_ptr + 1 * dst_stride_y));
102 #if VEC_SIZE == 4 && defined(LAST_ACCESSED) 137 const uint curr_channel = min((uint)(get_global_id(0) *
VEC_SIZE), (uint)LAST_ACCESSED);
138 uint channel_id0 = 0;
139 uint channel_id1 = 0;
140 uint channel_id2 = 0;
141 uint channel_id3 = 0;
150 DIV_MOD_UINT(get_global_id(2), (uint)SRC_DIM_Z, batch_id, y);
153 DIV_MOD_UINT(curr_channel + (uint)0,
K, group_id0, channel_id0);
154 DIV_MOD_UINT(curr_channel + (uint)1,
K, group_id1, channel_id1);
155 DIV_MOD_UINT(curr_channel + (uint)2,
K, group_id2, channel_id2);
156 DIV_MOD_UINT(curr_channel + (uint)3,
K, group_id3, channel_id3);
158 const uint x = get_global_id(1) * 2;
159 const uint z0 = channel_id0 * (uint)NUM_GROUPS + group_id0;
160 const uint z1 = channel_id1 * (uint)NUM_GROUPS + group_id1;
161 const uint z2 = channel_id2 * (uint)NUM_GROUPS + group_id2;
162 const uint z3 = channel_id3 * (uint)NUM_GROUPS + group_id3;
165 const __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + curr_channel *
sizeof(
DATA_TYPE) + x * src_stride_y + y * src_stride_z + batch_id * src_stride_w;
170 __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + x * dst_stride_y + y * dst_stride_z + batch_id * dst_stride_w;
171 *((__global
DATA_TYPE *)(output_ptr + (uint)0 * dst_stride_y + z0 *
sizeof(
DATA_TYPE))) = u0.s0;
172 *((__global
DATA_TYPE *)(output_ptr + (uint)0 * dst_stride_y + z1 *
sizeof(
DATA_TYPE))) = u0.s1;
173 *((__global
DATA_TYPE *)(output_ptr + (uint)0 * dst_stride_y + z2 *
sizeof(
DATA_TYPE))) = u0.s2;
174 *((__global
DATA_TYPE *)(output_ptr + (uint)0 * dst_stride_y + z3 *
sizeof(
DATA_TYPE))) = u0.s3;
175 *((__global
DATA_TYPE *)(output_ptr + (uint)1 * dst_stride_y + z0 *
sizeof(
DATA_TYPE))) = u1.s0;
176 *((__global
DATA_TYPE *)(output_ptr + (uint)1 * dst_stride_y + z1 *
sizeof(
DATA_TYPE))) = u1.s1;
177 *((__global
DATA_TYPE *)(output_ptr + (uint)1 * dst_stride_y + z2 *
sizeof(
DATA_TYPE))) = u1.s2;
178 *((__global
DATA_TYPE *)(output_ptr + (uint)1 * dst_stride_y + z3 *
sizeof(
DATA_TYPE))) = u1.s3;
180 #endif // VEC_SIZE == 4 && defined(LAST_ACCESSED) 181 #endif // defined(DATA_TYPE) && defined(VEC_SIZE) && defined(NUM_GROUPS) && defined(K) && defined(SRC_DIM_Z)
SimpleTensor< float > src
#define TENSOR4D_DECLARATION(name)