24 #ifndef ACL_SRC_CPU_KERNELS_DIRECTCONV2D_NCHW_IMPL_H
25 #define ACL_SRC_CPU_KERNELS_DIRECTCONV2D_NCHW_IMPL_H
55 using tag_type =
typename vtype::tag_type;
58 const int element_size =
src->info()->element_size();
59 const int input_stride_w =
src->info()->strides_in_bytes()[0] / element_size;
60 const int input_stride_h =
src->info()->strides_in_bytes()[1] / element_size;
61 const int input_stride_c =
src->info()->strides_in_bytes()[2] / element_size;
62 const int input_stride_n =
src->info()->strides_in_bytes()[3] / element_size;
64 const int input_dim_w =
src->info()->dimension(0);
65 const int input_dim_h =
src->info()->dimension(1);
67 const int output_stride_c =
dst->info()->strides_in_bytes()[2];
76 const int conv_pad_top =
conv_info.pad_top();
77 const int conv_pad_left =
conv_info.pad_left();
78 const int conv_stride_w = std::get<0>(
conv_info.stride());
79 const int conv_stride_h = std::get<1>(
conv_info.stride());
82 Window window_out = window;
94 constexpr
int num_elems_read_per_iteration = 16 /
sizeof(T);
101 const int in_w_start_t =
static_cast<int>(
id.x()) * conv_stride_w - conv_pad_left;
102 const int in_h_start_t =
static_cast<int>(
id.y()) * conv_stride_h - conv_pad_top;
103 const int in_w_end_t = in_w_start_t + kernel_dim_w;
104 const int in_h_end_t = in_h_start_t + kernel_dim_h;
107 const int in_w_start = std::max(in_w_start_t, 0);
108 const int in_h_start = std::max(in_h_start_t, 0);
109 const int in_w_end = std::min(in_w_end_t, input_dim_w);
110 const int in_h_end = std::min(in_h_end_t, input_dim_h);
113 const int wei_w_start = in_w_start - in_w_start_t;
114 const int wei_h_start = in_h_start - in_h_start_t;
115 const int wei_h_end = kernel_dim_h - (in_h_end_t - in_h_end);
118 const T *
const in_ptr_start =
119 reinterpret_cast<const T *
>(
src->buffer() +
src->info()->offset_first_element_in_bytes()) +
120 id[3] * input_stride_n;
125 const T *
const weights_ptr_start =
reinterpret_cast<const T *
>(wei.
ptr());
126 uint8_t *out_ptr = out.
ptr() + id_w[3] * output_stride_c;
127 T out_temp =
static_cast<T
>(0);
129 for (
int index_wei_c = 0, index_in_c = 0; index_wei_c < index_c_end; ++index_wei_c, ++index_in_c)
131 const T *
const in_ptr_row_0 = in_ptr_start + index_in_c * input_stride_c;
132 const T *
const weights_ptr_row_0 = weights_ptr_start + index_wei_c * kernel_stride_c;
133 for (
int index_wei_h = wei_h_start, index_in_h = in_h_start; index_wei_h < wei_h_end;
134 ++index_wei_h, ++index_in_h)
136 const T *in_ptr_row = in_ptr_row_0 + index_in_h * input_stride_h;
137 const T *weights_ptr_row = weights_ptr_row_0 + index_wei_h * kernel_stride_h;
138 int index_w = in_w_start;
139 int index_wei_w = wei_w_start;
140 vector_type out_temp_vec =
wrapper::vdup_n(
static_cast<T
>(0), tag_type());
141 for (; index_w <= ((in_w_end - num_elems_read_per_iteration));
142 index_w += num_elems_read_per_iteration, index_wei_w += num_elems_read_per_iteration)
144 const auto src_vec =
wrapper::vloadq(in_ptr_row + index_w * input_stride_w);
145 const auto w_vec =
wrapper::vloadq(weights_ptr_row + index_wei_w * kernel_stride_w);
148 out_temp +=
vreduce(out_temp_vec);
149 for (; index_w < in_w_end; ++index_w, ++index_wei_w)
151 const auto src_val = *(in_ptr_row + index_w * input_stride_w);
152 const auto w_val = *(weights_ptr_row + index_wei_w * kernel_stride_w);
153 out_temp += src_val * w_val;
157 *(
reinterpret_cast<T *
>(out_ptr)) = out_temp;
166 #endif // ACL_SRC_CPU_KERNELS_DIRECTCONV2D_NCHW_IMPL_H