50 bool have_zero_x_internal_padding(ITensorInfo *
src,
const ITensorInfo *weights)
52 return (
src->padding().left == 0 && weights->padding().left == 0 &&
src->padding().right == 0 &&
53 weights->padding().right == 0);
64 using tag_type =
typename vtype::tag_type;
67 const int element_size =
src->info()->element_size();
68 const int input_stride_w =
src->info()->strides_in_bytes().y() / element_size;
69 const int input_stride_h =
src->info()->strides_in_bytes().z() / element_size;
70 const int input_stride_n =
src->info()->strides_in_bytes()[3] / element_size;
71 const int input_dim_w =
src->info()->dimension(1);
72 const int input_dim_h =
src->info()->dimension(2);
74 const int output_stride_c =
dst->info()->strides_in_bytes().x();
81 const int conv_pad_top =
conv_info.pad_top();
82 const int conv_pad_left =
conv_info.pad_left();
83 const int conv_stride_w = std::get<0>(
conv_info.stride());
84 const int conv_stride_h = std::get<1>(
conv_info.stride());
87 Window window_out = window;
99 constexpr
int num_elems_read_per_iteration = 16 /
sizeof(T);
102 if (have_zero_x_internal_padding(
src->info(), weights->
info()))
129 const int in_w_start_t =
static_cast<int>(
id.y()) * conv_stride_w - conv_pad_left;
130 const int in_h_start_t =
static_cast<int>(
id.z()) * conv_stride_h - conv_pad_top;
131 const int in_w_end_t = in_w_start_t + kernel_dim_w;
132 const int in_h_end_t = in_h_start_t + kernel_dim_h;
135 const int in_w_start = std::max(in_w_start_t, 0);
136 const int in_h_start = std::max(in_h_start_t, 0);
137 const int in_w_end = std::min(in_w_end_t, input_dim_w);
138 const int in_h_end = std::min(in_h_end_t, input_dim_h);
141 const int index_wc_start = (in_w_start - in_w_start_t) * kernel_stride_w;
142 const int index_h_start = in_h_start - in_h_start_t;
143 const int index_wc_end = (kernel_dim_w - (in_w_end_t - in_w_end)) * kernel_stride_w;
144 const int index_h_end = kernel_dim_h - (in_h_end_t - in_h_end);
155 const T *in_ptr_row =
156 reinterpret_cast<const T *
>(
src->buffer() +
src->info()->offset_first_element_in_bytes()) +
157 id[3] * input_stride_n + in_w_start * input_stride_w + in_h_start * input_stride_h;
158 const T *weights_ptr_row =
159 reinterpret_cast<const T *
>(wei.
ptr()) + index_h_start * kernel_stride_h;
160 uint8_t *out_ptr = out.
ptr() + id_w[3] * output_stride_c;
162 T out_temp =
static_cast<T
>(0);
163 for (
int index_h = index_h_start; index_h < index_h_end;
164 ++index_h, in_ptr_row += input_stride_h, weights_ptr_row += kernel_stride_h)
166 const T *in_ptr_mover = in_ptr_row;
167 int index_wc = index_wc_start;
168 vector_type out_temp_vec =
wrapper::vdup_n(
static_cast<T
>(0), tag_type());
169 for (; index_wc <= index_wc_end - num_elems_read_per_iteration;
170 index_wc += num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration)
176 out_temp +=
vreduce(out_temp_vec);
177 for (; index_wc < index_wc_end; ++index_wc, ++in_ptr_mover)
179 const auto src_val = *(in_ptr_mover);
180 const auto w_val = *(weights_ptr_row + index_wc);
181 out_temp += src_val * w_val;
184 *(
reinterpret_cast<T *
>(out_ptr)) = out_temp;
197 const int in_w_start_t =
static_cast<int>(
id.y()) * conv_stride_w - conv_pad_left;
198 const int in_h_start_t =
static_cast<int>(
id.z()) * conv_stride_h - conv_pad_top;
199 const int in_w_end_t = in_w_start_t + kernel_dim_w;
200 const int in_h_end_t = in_h_start_t + kernel_dim_h;
203 const int in_w_start = std::max(in_w_start_t, 0);
204 const int in_h_start = std::max(in_h_start_t, 0);
205 const int in_w_end = std::min(in_w_end_t, input_dim_w);
206 const int in_h_end = std::min(in_h_end_t, input_dim_h);
209 const int wei_w_start = in_w_start - in_w_start_t;
210 const int wei_h_start = in_h_start - in_h_start_t;
211 const int wei_w_end = kernel_dim_w - (in_w_end_t - in_w_end);
212 const int wei_h_end = kernel_dim_h - (in_h_end_t - in_h_end);
215 const T *
const in_ptr_start =
216 reinterpret_cast<const T *
>(
src->buffer() +
src->info()->offset_first_element_in_bytes()) +
217 id[3] * input_stride_n;
223 const T *
const weights_ptr_start =
reinterpret_cast<const T *
>(wei.
ptr());
224 uint8_t *out_ptr = out.
ptr() + id_w[3] * output_stride_c;
226 T out_temp =
static_cast<T
>(0);
227 for (
int index_wei_h = wei_h_start, index_in_h = in_h_start; index_wei_h < wei_h_end;
228 ++index_wei_h, ++index_in_h)
230 const T *
const in_ptr_row = in_ptr_start + index_in_h * input_stride_h;
231 const T *
const weights_ptr_row = weights_ptr_start + index_wei_h * kernel_stride_h;
232 for (
int index_wei_w = wei_w_start, index_in_w = in_w_start; index_wei_w < wei_w_end;
233 ++index_wei_w, ++index_in_w)
235 const T *in_ptr_mover = in_ptr_row + index_in_w * input_stride_w;
236 const T *weights_ptr_mover = weights_ptr_row + index_wei_w * kernel_stride_w;
238 vector_type out_temp_vec =
wrapper::vdup_n(
static_cast<T
>(0), tag_type());
239 for (; index_c <= index_c_end - num_elems_read_per_iteration;
240 index_c += num_elems_read_per_iteration,
241 in_ptr_mover += num_elems_read_per_iteration,
242 weights_ptr_mover += num_elems_read_per_iteration)
248 out_temp +=
vreduce(out_temp_vec);
249 for (; index_c < index_c_end; ++index_c, ++in_ptr_mover, ++weights_ptr_mover)
251 const auto src_val = *(in_ptr_mover);
252 const auto w_val = *(weights_ptr_mover);
253 out_temp += src_val * w_val;
257 *(
reinterpret_cast<T *
>(out_ptr)) = out_temp;