24 #ifndef SRC_CORE_NEON_KERNELS_CONV3D_LIST_H
25 #define SRC_CORE_NEON_KERNELS_CONV3D_LIST_H
// Tag type exposed by the vector traits type `vtype`; an instance of it is
// passed to wrapper::vdup_n when the vector accumulator is zero-initialised.
53 using tag_type =
typename vtype::tag_type;
// Number of T elements handled per vectorised step: 16 bytes divided by the
// size of one element (presumably one 128-bit register's worth - TODO confirm).
54 constexpr
int num_elems_read_per_iteration = 16 /
sizeof(T);
// Size of one source element in bytes; used to turn byte strides into
// element strides so they can be applied directly to `const T *` pointers.
57 const int element_size =
src->info()->element_size();
// Source strides expressed in elements (byte stride / element size).
// Stride indices 1..4 are named width, height, depth and batch here;
// index 0 is presumably the channel axis - TODO confirm against the layout.
58 const int input_stride_w =
src->info()->strides_in_bytes().y() / element_size;
59 const int input_stride_h =
src->info()->strides_in_bytes().z() / element_size;
60 const int input_stride_d =
src->info()->strides_in_bytes()[3] / element_size;
61 const int input_stride_n =
src->info()->strides_in_bytes()[4] / element_size;
// Source extents along dimensions 1, 2 and 3 (width, height, depth); these
// are the upper bounds used when the receptive field is clamped.
62 const int input_dim_w =
src->info()->dimension(1);
63 const int input_dim_h =
src->info()->dimension(2);
64 const int input_dim_d =
src->info()->dimension(3);
// Leading padding amounts from conv_info; each is subtracted from the scaled
// output coordinate to find where the receptive field starts in the input.
75 const int conv_pad_top =
conv_info.padding.top;
76 const int conv_pad_left =
conv_info.padding.left;
77 const int conv_pad_front =
conv_info.padding.front;
// Convolution strides along width, height and depth; each multiplies the
// corresponding output coordinate when locating the receptive field.
78 const int conv_stride_w =
conv_info.stride.width;
79 const int conv_stride_h =
conv_info.stride.height;
80 const int conv_stride_d =
conv_info.stride.depth;
// Local copy of the incoming execution window
// (presumably adjusted before iterating over the output - TODO confirm).
83 Window window_out = window;
// Bias pointer starts out null; the final store only adds a bias term when
// this pointer is non-null.
96 const T *biases_ptr =
nullptr;
// Only touch the bias tensor when one was actually provided.
97 if (biases !=
nullptr)
// Unclamped start of the receptive field in input coordinates:
// output coordinate * stride - leading padding. The result can be negative
// when the field begins inside the padded border.
106 const int in_w_start_t =
static_cast<int>(
id.y()) * conv_stride_w - conv_pad_left;
107 const int in_h_start_t =
static_cast<int>(
id.z()) * conv_stride_h - conv_pad_top;
108 const int in_d_start_t =
static_cast<int>(
id[3]) * conv_stride_d - conv_pad_front;
// Unclamped (exclusive) end of the receptive field: start + kernel extent.
109 const int in_w_end_t = in_w_start_t + kernel_dim_w;
110 const int in_h_end_t = in_h_start_t + kernel_dim_h;
111 const int in_d_end_t = in_d_start_t + kernel_dim_d;
// Clamp the receptive field to the valid input range [0, input_dim).
114 const int in_w_start = std::max(in_w_start_t, 0);
115 const int in_h_start = std::max(in_h_start_t, 0);
116 const int in_d_start = std::max(in_d_start_t, 0);
117 const int in_w_end = std::min(in_w_end_t, input_dim_w);
118 const int in_h_end = std::min(in_h_end_t, input_dim_h);
119 const int in_d_end = std::min(in_d_end_t, input_dim_d);
// First kernel tap to use along each axis: the number of positions that were
// clipped off the low side of the receptive field by the clamp above.
122 const int wei_w_start = in_w_start - in_w_start_t;
123 const int wei_h_start = in_h_start - in_h_start_t;
124 const int wei_d_start = in_d_start - in_d_start_t;
// One-past-last kernel tap along each axis: the kernel extent minus the
// number of positions clipped off the high side of the receptive field.
125 const int wei_w_end = kernel_dim_w - (in_w_end_t - in_w_end);
126 const int wei_h_end = kernel_dim_h - (in_h_end_t - in_h_end);
127 const int wei_d_end = kernel_dim_d - (in_d_end_t - in_d_end);
// Typed pointer to the first source element, advanced to the batch selected
// by id[4]. The byte offset is applied before the cast; the batch stride is
// in elements and is therefore added after it.
131 const T *
const in_ptr_start =
132 reinterpret_cast<const T *
>(
src->buffer() +
src->info()->offset_first_element_in_bytes()) +
133 id[4] * input_stride_n;
// Typed view of the weights at the current position of the `wei` iterator.
142 const auto weights_ptr_start =
reinterpret_cast<const T *
>(wei.
ptr());
// Scalar accumulator for the output value being computed, starting at zero.
143 T out_temp =
static_cast<T
>(0);
// Typed view of the output at the current position of the `out` iterator.
144 T *out_ptr =
reinterpret_cast<T *
>(out.
ptr());
// Walk the valid kernel taps along depth; the input depth index advances in
// lockstep so that taps falling in the padding are never read.
145 for (
int index_wei_d = wei_d_start, index_in_d = in_d_start; index_wei_d < wei_d_end;
146 ++index_wei_d, ++index_in_d)
// Input and weight base pointers for the current depth slice.
148 const auto in_ptr_d = in_ptr_start + index_in_d * input_stride_d;
149 const auto weights_ptr_d = weights_ptr_start + index_wei_d * kernel_stride_d;
// Same lockstep walk along height.
150 for (
int index_wei_h = wei_h_start, index_in_h = in_h_start; index_wei_h < wei_h_end;
151 ++index_wei_h, ++index_in_h)
// Input and weight base pointers for the current row.
153 const T *
const in_ptr_row = in_ptr_d + index_in_h * input_stride_h;
154 const T *
const weights_ptr_row = weights_ptr_d + index_wei_h * kernel_stride_h;
// Same lockstep walk along width.
155 for (
int index_wei_w = wei_w_start, index_in_w = in_w_start; index_wei_w < wei_w_end;
156 ++index_wei_w, ++index_in_w)
// Moving pointers for this (d, h, w) tap; both are advanced by the
// input-channel loops below.
158 const T *in_ptr_mover = in_ptr_row + index_in_w * input_stride_w;
159 const T *weights_ptr_mover = weights_ptr_row + index_wei_w * kernel_stride_w;
// Vector accumulator with every lane set to zero.
161 vector_type out_temp_vec =
wrapper::vdup_n(
static_cast<T
>(0), tag_type());
// Vectorised pass over input channels, num_elems_read_per_iteration at a
// time, for as long as a full group of channels remains.
163 for (; index_c_in <= index_c_in_end - num_elems_read_per_iteration;
164 index_c_in += num_elems_read_per_iteration,
165 in_ptr_mover += num_elems_read_per_iteration)
// Visit one weight per input channel of the current group; consecutive
// weights are index_c_out_end elements apart.
169 for (
int k = 0; k < num_elems_read_per_iteration;
170 ++k, weights_ptr_mover += index_c_out_end)
// Fold the vector accumulator into the scalar accumulator
// (vreduce presumably sums the lanes - TODO confirm).
176 out_temp +=
vreduce(out_temp_vec);
// Scalar tail: handle the input channels left over after the vector pass.
177 for (; index_c_in < index_c_in_end;
178 ++index_c_in, ++in_ptr_mover, weights_ptr_mover += index_c_out_end)
// Multiply-accumulate one input element with its matching weight.
180 const auto src_val = *(in_ptr_mover);
181 const auto w_val = *(weights_ptr_mover);
182 out_temp += src_val * w_val;
// Store the accumulated value at offset id_w[0] from out_ptr, adding the
// bias at the same index when a bias pointer is available.
187 *(
reinterpret_cast<T *
>(out_ptr + id_w[0])) =
188 (biases_ptr !=
nullptr) ? out_temp + biases_ptr[id_w[0]] : out_temp;
197 #endif // SRC_CORE_NEON_KERNELS_CONV3D_LIST_H