24 #ifndef ACL_SRC_CPU_KERNELS_DIRECTCONV2D_IMPL_H
25 #define ACL_SRC_CPU_KERNELS_DIRECTCONV2D_IMPL_H
// Copies one kernel-sized window of the input volume into a linear output buffer.
// T is the element type; has_pads is a compile-time switch that enables the
// out-of-bounds checks below (when it is false those checks fold away).
// NOTE(review): the signature and the braces/else lines are not visible in this excerpt.
// This is presumably the routine invoked as linearize_volume_nchw<T, has_pads>(...) further
// down (it is the only one that uses input_stride_x) - confirm against the full file.
46 template <
typename T,
bool has_pads>
// kernel_size2: number of elements in one kernel plane; it is also the distance in the
// output between the same (x, y) position of two consecutive depth slices.
// x_e / y_e: exclusive upper bounds of the dilated window (the loops test x < x_e, y < y_e).
64 const int kernel_size2 = kernel_width * kernel_height;
65 const int x_e = top_left_x + kernel_width * dilation_x;
66 const int y_e = top_left_y + kernel_height * dilation_y;
// Main loop: handles three depth slices per iteration. Each visited (x, y) writes three
// output elements, at out_ptr, out_ptr + kernel_size2 and out_ptr + 2 * kernel_size2.
73 for (; d <= (kernel_depth - 3); d += 3)
75 for (
int y = top_left_y; y < y_e; y += dilation_y)
// Row lies outside [0, input_h): the whole row of the window is filled with pad_value
// for all three slices.
77 if ((y < 0 || y >= input_h) && has_pads)
80 for (
int x = top_left_x; x < x_e; x += dilation_x, ++out_ptr)
82 *(out_ptr + 0 * kernel_size2) = pad_value;
83 *(out_ptr + 1 * kernel_size2) = pad_value;
84 *(out_ptr + 2 * kernel_size2) = pad_value;
// Row is inside the input (or has_pads is false): walk the columns one by one.
89 for (
int x = top_left_x; x < x_e; x += dilation_x, ++out_ptr)
// Column lies outside [0, input_w): pad this single position in all three slices.
91 if ((x < 0 || x >= input_w) && has_pads)
93 *(out_ptr + 0 * kernel_size2) = pad_value;
94 *(out_ptr + 1 * kernel_size2) = pad_value;
95 *(out_ptr + 2 * kernel_size2) = pad_value;
// In-bounds element: in_ptr is offset in bytes (strides multiply d, y and x) and then
// reinterpreted as const T* to read the value for slices d, d + 1 and d + 2.
99 *(out_ptr + 0 * kernel_size2) = *(
reinterpret_cast<const T *
>(
100 in_ptr + ((d + 0) * input_stride_z + y * input_stride_y + x * input_stride_x)));
101 *(out_ptr + 1 * kernel_size2) = *(
reinterpret_cast<const T *
>(
102 in_ptr + ((d + 1) * input_stride_z + y * input_stride_y + x * input_stride_x)));
103 *(out_ptr + 2 * kernel_size2) = *(
reinterpret_cast<const T *
>(
104 in_ptr + ((d + 2) * input_stride_z + y * input_stride_y + x * input_stride_x)));
// The x loops advanced out_ptr by one element per visited position; this skips the two
// further planes that were written through the + kernel_size2 offsets.
// NOTE(review): presumably executed once per iteration of the d loop - braces not visible.
109 out_ptr += 2 * kernel_size2;
// Remainder loop: the depth slices left over by the 3-at-a-time loop, one per iteration.
113 for (; d < kernel_depth; d++)
115 for (
int y = top_left_y; y < y_e; y += dilation_y)
117 if ((y < 0 || y >= input_h) && has_pads)
// Out-of-range row: fill kernel_width elements in one call and step past them.
// NOTE(review): memset stores pad_value converted to unsigned char into every byte, so
// for sizeof(T) > 1 this matches the element-wise stores used above only when the pad
// value is byte-replicable (e.g. 0) - confirm the values callers pass.
120 memset(
static_cast<void *
>(out_ptr), pad_value, kernel_width *
sizeof(T));
121 out_ptr += kernel_width;
125 for (
int x = top_left_x; x < x_e; x += dilation_x, ++out_ptr)
127 if ((x < 0 || x >= input_w) && has_pads)
129 *out_ptr = pad_value;
// In-bounds element of slice d.
133 *out_ptr = *(
reinterpret_cast<const T *
>(
134 in_ptr + (d * input_stride_z + y * input_stride_y + x * input_stride_x)));
// Writes the value 1 at the current output position.
// NOTE(review): the guarding condition is not visible here; presumably this is the bias
// slot written only when the caller's has_bias argument is set - confirm.
144 *out_ptr =
static_cast<T
>(1);
// Copies one kernel-sized window into a linear output buffer for a layout in which the
// input_c channel values of a pixel are copied as one contiguous chunk.
// In this routine x is multiplied by input_stride_y and y by input_stride_z.
// NOTE(review): the signature and the braces/else lines are not visible in this excerpt.
// This is presumably the linearize_volume_nhwc overload without right padding that is
// called further down - confirm against the full file.
148 template <
typename T,
bool has_pads>
// end_x / end_y: exclusive upper bounds of the dilated window.
// pad_quant: number of elements in one full window row (kernel_width pixels of input_c).
// element_size: sizeof(T) as an int so it can be combined with the int strides.
165 const int end_x = start_x + kernel_width * dilation_x;
166 const int end_y = start_y + kernel_height * dilation_y;
167 const int pad_quant = kernel_width * input_c;
168 const int element_size =
static_cast<int>(
sizeof(T));
// Fast path: the window starts inside the input, its exclusive end is strictly below the
// input extent in both directions, there is no dilation along x, and the byte stride
// between neighbouring pixels equals input_c * element_size (no gap between pixels).
// NOTE(review): because end_x / end_y are exclusive and compared with '<', a window that
// ends exactly on the last row or column does not take this path and falls through to the
// general loop below - looks deliberately conservative, confirm.
169 if ((start_y >= 0) && (end_y < input_h) && (start_x >= 0) && (end_x < input_w) && (dilation_x == 1) &&
170 (input_stride_y == input_c * element_size))
172 for (
int y = start_y; y < end_y; y += dilation_y)
// One copy per window row: kernel_width pixels of input_c elements each.
175 memcpy(out_ptr,
reinterpret_cast<const T *
>(in_ptr + (y * input_stride_z + start_x * input_stride_y)),
176 input_c * kernel_width * element_size);
177 out_ptr += input_c * kernel_width;
// General path: rows are classified one at a time.
182 for (
int y = start_y; y < end_y; y += dilation_y)
// Row outside [0, input_h): fill the whole window row and step past it.
// The bounds checks visible in this routine are not combined with has_pads.
// NOTE(review): memset fills bytes with pad_value converted to unsigned char - for
// sizeof(T) > 1 this is only exact for byte-replicable pad values; confirm callers.
184 if (y < 0 || y >= input_h)
186 memset(
static_cast<void *
>(out_ptr), pad_value, pad_quant * element_size);
187 out_ptr += pad_quant;
// Row is inside the input but cannot be copied in one piece (x dilation, window crossing
// or touching the left/right edge, or gaps between pixels): handle it pixel by pixel.
189 else if (dilation_x > 1 || start_x < 0 || end_x >= input_w || input_stride_y != input_c * element_size)
191 for (
int x = start_x; x < end_x; x += dilation_x)
// Pixel outside [0, input_w): write input_c elements of padding.
193 if (x < 0 || x >= input_w)
195 memset(
static_cast<void *
>(out_ptr), pad_value, input_c * element_size);
// Pixel inside the input: copy its input_c channel values.
200 memcpy(out_ptr,
reinterpret_cast<const T *
>(in_ptr + (y * input_stride_z + x * input_stride_y)),
201 input_c * element_size);
// Row is inside the input and contiguous: copy the whole window row at once.
209 memcpy(out_ptr,
reinterpret_cast<const T *
>(in_ptr + (y * input_stride_z + start_x * input_stride_y)),
210 input_c * kernel_width * element_size);
211 out_ptr += input_c * kernel_width;
// Writes the value 1 at the current output position.
// NOTE(review): the guarding condition is not visible here; presumably the bias slot
// written only when has_bias is set - confirm.
218 *out_ptr =
static_cast<T
>(1);
// Variant of the per-pixel window copy in which every pixel occupies input_c + pad_right
// elements in the output: input_c copied channel values followed by pad_right extra slots.
// In this routine x is multiplied by input_stride_y and y by input_stride_z.
// NOTE(review): the signature and the braces/else lines are not visible in this excerpt,
// and two memcpy calls below are missing their final argument in this view. This is
// presumably the linearize_volume_nhwc overload selected when the caller's
// input_pad_right is > 0 - confirm against the full file.
222 template <
typename T,
bool has_pads>
// end_x / end_y: exclusive upper bounds of the dilated window.
// pad_quant: number of output elements in one full window row, including the pad_right
// slots of every pixel.
// channel_chunk_size: size in bytes of the input_c channel values of one pixel.
240 const int end_x = start_x + kernel_width * dilation_x;
241 const int end_y = start_y + kernel_height * dilation_y;
242 const int pad_quant = kernel_width * (input_c + pad_right);
243 const int element_size =
static_cast<int>(
sizeof(T));
244 const int channel_chunk_size = input_c * element_size;
// Fast path: window strictly inside the input, no dilation along x, and no gap between
// neighbouring pixels in the input.
// NOTE(review): end_x / end_y are exclusive and compared with '<', so a window ending
// exactly on the last row or column falls through to the general loop - confirm intent.
246 if ((start_y >= 0) && (end_y < input_h) && (start_x >= 0) && (end_x < input_w) && (dilation_x == 1) &&
247 (input_stride_y == channel_chunk_size))
249 for (
int y = start_y; y < end_y; y += dilation_y)
// Start of this window row in the input. Pixels are copied one by one because the output
// step per pixel (input_c + pad_right) differs from the number of elements copied.
251 const uint8_t *offset_ptr = in_ptr + (y * input_stride_z + start_x * input_stride_y);
252 for (
int e = 0; e < kernel_width; e++)
// Copies only channel_chunk_size bytes; the pad_right slots are not written by the
// statements visible here.
254 memcpy(out_ptr,
reinterpret_cast<const T *
>(offset_ptr + e * channel_chunk_size), channel_chunk_size);
255 out_ptr += input_c + pad_right;
// General path: rows are classified one at a time.
261 for (
int y = start_y; y < end_y; y += dilation_y)
// Row outside [0, input_h): fill the whole window row, pad_right slots included.
// NOTE(review): memset fills bytes with pad_value converted to unsigned char - for
// sizeof(T) > 1 this is only exact for byte-replicable pad values; confirm callers.
263 if (y < 0 || y >= input_h)
265 memset(
static_cast<void *
>(out_ptr), pad_value, pad_quant * element_size);
266 out_ptr += pad_quant;
// Row inside the input that needs per-pixel bounds checks or strided access.
268 else if (dilation_x > 1 || start_x < 0 || end_x >= input_w || input_stride_y != channel_chunk_size)
270 for (
int x = start_x; x < end_x; x += dilation_x)
// Pixel outside [0, input_w): pad all input_c + pad_right slots of this pixel.
272 if (x < 0 || x >= input_w)
274 memset(
static_cast<void *
>(out_ptr), pad_value, (input_c + pad_right) * element_size);
275 out_ptr += input_c + pad_right;
// Pixel inside the input: copy its channel values, then step one padded pixel forward.
// (The size argument of this memcpy is on a line not present in this excerpt.)
279 memcpy(out_ptr,
reinterpret_cast<const T *
>(in_ptr + (y * input_stride_z + x * input_stride_y)),
281 out_ptr += input_c + pad_right;
// Row inside the input with contiguous pixels: same per-pixel copy as the fast path.
287 const uint8_t *offset_ptr = in_ptr + (y * input_stride_z + start_x * input_stride_y);
288 for (
int e = 0; e < kernel_width; e++)
// (The size argument of this memcpy is on a line not present in this excerpt.)
290 memcpy(out_ptr,
reinterpret_cast<const T *
>(offset_ptr + e * channel_chunk_size),
292 out_ptr += input_c + pad_right;
// Writes the value 1 at the current output position.
// NOTE(review): the guarding condition is not visible here; presumably the bias slot
// written only when has_bias is set - confirm.
300 *out_ptr =
static_cast<T
>(1);
// Driver that, for each output position, computes the top-left corner of the kernel
// window in the input and calls one of the linearize_volume_* routines to fill the
// corresponding part of the destination tensor.
// T is the element type, has_pads is forwarded to the linearize routines, and is_nchw is a
// compile-time layout flag (presumably what selects between the nchw and nhwc calls below
// - the selecting statement is not visible in this excerpt).
// NOTE(review): the function name, most of the parameter list, the iteration scaffolding
// and several initialisers are missing from this view; the name run_im2col used to label
// this block is an assumption - confirm against the full file.
304 template <
typename T,
bool has_pads,
bool is_nchw>
310 std::pair<unsigned int, unsigned int> convolved_dims,
311 const Size2D &kernel_dims,
313 uint32_t input_pad_right,
// Byte strides of the source tensor along its first three dimensions.
323 const int input_stride_x =
src->info()->strides_in_bytes().x();
324 const int input_stride_y =
src->info()->strides_in_bytes().y();
325 const int input_stride_z =
src->info()->strides_in_bytes().z();
// Left padding and the (x, y) convolution strides taken from conv_info.
326 const int pad_left =
conv_info.pad_left();
328 const int stride_x =
conv_info.stride().first;
329 const int stride_y =
conv_info.stride().second;
330 const int pad_value =
333 const auto kernel_width = kernel_dims.
width;
334 const auto kernel_height = kernel_dims.
height;
// Working copy of the execution window passed in by the caller.
336 Window window_in_out(window);
// Top-left corner of the kernel window in input coordinates for the current position id:
// output coordinate times stride, shifted back by the leading padding. The result is
// negative when the window starts inside the padded border.
350 const int start_w =
id[
width_idx] * stride_x - pad_left;
351 const int start_h =
id[
height_idx] * stride_y - pad_top;
// Raw byte pointer to the input at the current iterator position.
354 const uint8_t *
const input_ptr = in.
ptr();
357 dst->info()->strides_in_bytes().y());
// Call taking all three input strides; the dilation factors come from dilation.x()/y().
362 linearize_volume_nchw<T, has_pads>(
363 input_ptr, output_ptr,
has_bias, start_w, start_h, kernel_width, kernel_height, input_c, input_w,
364 input_h, input_stride_x, input_stride_y, input_stride_z, pad_value, dilation.
x(), dilation.
y());
// When right padding is requested, the nhwc call receives additional trailing arguments
// (they are on a line not present in this excerpt).
368 if (input_pad_right > 0)
370 linearize_volume_nhwc<T, has_pads>(input_ptr, output_ptr,
has_bias, start_w, start_h, kernel_width,
371 kernel_height, input_w, input_h, input_c, input_stride_y,
372 input_stride_z, pad_value, dilation.
x(), dilation.
y(),
// Call without the extra padding arguments (presumably the else branch of the test
// above - the else line is not visible).
377 linearize_volume_nhwc<T, has_pads>(input_ptr, output_ptr,
has_bias, start_w, start_h, kernel_width,
378 kernel_height, input_w, input_h, input_c, input_stride_y,
379 input_stride_z, pad_value, dilation.
x(), dilation.
y());
389 #endif // ACL_SRC_CPU_KERNELS_DIRECTCONV2D_IMPL_H