24 #ifndef SRC_CORE_NEON_KERNELS_POOL3D_QUANTIZED_H
25 #define SRC_CORE_NEON_KERNELS_POOL3D_QUANTIZED_H
50 int pool_stride_x =
static_cast<int>(pool_info.
stride.
width);
51 int pool_stride_y =
static_cast<int>(pool_info.
stride.
height);
52 int pool_stride_z =
static_cast<int>(pool_info.
stride.
depth);
58 const int pool_pad_top =
static_cast<int>(pool_info.
padding.
top);
59 const int pool_pad_bottom =
static_cast<int>(pool_info.
padding.
bottom);
60 const int pool_pad_left =
static_cast<int>(pool_info.
padding.
left);
61 const int pool_pad_right =
static_cast<int>(pool_info.
padding.
right);
62 const int pool_pad_front =
static_cast<int>(pool_info.
padding.
front);
63 const int pool_pad_back =
static_cast<int>(pool_info.
padding.
back);
65 const int upper_bound_w =
src->info()->dimension(1) + (pool_info.
exclude_padding ? 0 : pool_pad_right);
66 const int upper_bound_h =
src->info()->dimension(2) + (pool_info.
exclude_padding ? 0 : pool_pad_bottom);
67 const int upper_bound_d =
src->info()->dimension(3) + (pool_info.
exclude_padding ? 0 : pool_pad_back);
69 const int input_dim_c =
src->info()->dimension(0);
70 const int input_dim_w =
src->info()->dimension(1);
71 const int input_dim_h =
src->info()->dimension(2);
72 const int input_dim_d =
src->info()->dimension(3);
74 const int y_stride =
static_cast<int>(
src->info()->strides_in_bytes().y());
75 const int z_stride =
static_cast<int>(
src->info()->strides_in_bytes().z());
76 const int w_stride =
static_cast<int>(
src->info()->strides_in_bytes()[3]);
77 const int n_stride =
static_cast<int>(
src->info()->strides_in_bytes()[4]);
79 const uint8_t *in_ptr_start =
src->buffer() +
src->info()->offset_first_element_in_bytes();
81 const int window_end_x = input_dim_c;
82 const int window_start_x = 0;
86 const float32x4_t half_scale_v = vdupq_n_f32(0.5f);
90 const float quant_rescale = dst_qinfo.
scale / src_qinfo.
scale;
93 const int32_t new_offset =
94 dst_qinfo.
offset -
static_cast<int32_t
>(
static_cast<float>(src_qinfo.
offset) / quant_rescale);
101 const int in_idx_width =
static_cast<int>(
id.y()) * pool_stride_x - pool_pad_left;
102 const int in_idx_height =
static_cast<int>(
id.z()) * pool_stride_y - pool_pad_top;
103 const int in_idx_depth =
static_cast<int>(
id[3]) * pool_stride_z - pool_pad_front;
105 const int pool_start_x = std::max(0, -in_idx_width);
106 const int pool_end_x_t = std::min(input_dim_w + pool_pad_left - in_idx_width, pool_size_x);
107 const int pool_start_y = std::max(0, -in_idx_height);
108 const int pool_end_y_t = std::min(input_dim_h + pool_pad_top - in_idx_height, pool_size_y);
110 const int pool_start_z = std::max(0, -in_idx_depth);
111 const int pool_end_z_t = std::min(input_dim_d + pool_pad_front - in_idx_depth, pool_size_z);
114 const int pool_end_x = std::min(pool_end_x_t, input_dim_w - in_idx_width);
115 const int pool_end_y = std::min(pool_end_y_t, input_dim_h - in_idx_height);
116 const int pool_end_z = std::min(pool_end_z_t, input_dim_d - in_idx_depth);
120 calculate_avg_scale_pool3d(pool_info.
exclude_padding,
id, pool_size_x, pool_size_y, pool_size_z,
121 upper_bound_w, upper_bound_h, upper_bound_d, pool_pad_left, pool_pad_top,
122 pool_pad_front, pool_stride_x, pool_stride_y, pool_stride_z);
124 const uint8_t *in_ptr_n = in_ptr_start +
id[4] * n_stride;
126 int x_off = window_start_x;
128 for (; x_off <= (window_end_x - window_step_x); x_off += window_step_x)
136 for (
int z = pool_start_z; z < pool_end_z; ++z)
138 const uint8_t *in_ptr_z = in_ptr_n + (z + in_idx_depth) * w_stride;
139 for (
int y = pool_start_y; y < pool_end_y; ++y)
141 const uint8_t *in_ptr_y = in_ptr_z + (y + in_idx_height) * z_stride;
142 for (
int x = pool_start_x; x < pool_end_x; ++x)
144 const uint8_t *in_ptr_x = in_ptr_y + (x + in_idx_width) * y_stride;
145 const q8x16_t data =
wrapper::vloadq(
reinterpret_cast<const T *
>(in_ptr_x) + x_off);
157 if (src_qinfo != dst_qinfo)
159 const float32x4x4_t vres = {{
160 vcvtq_f32_q32(vres1),
161 vcvtq_f32_q32(vres2),
162 vcvtq_f32_q32(vres3),
163 vcvtq_f32_q32(vres4),
165 const auto requantized_dst =
166 vrequantize_pooling_with_scale<q8x16_t>(vres, quant_rescale,
scale, new_offset);
173 const float32x4_t scale_v = vdupq_n_f32(
scale);
175 vres1 = vcvtq_q32_f32<q32x4_t>(
wrapper::vmla(half_scale_v, vcvtq_f32_q32(vres1), scale_v));
176 vres2 = vcvtq_q32_f32<q32x4_t>(
wrapper::vmla(half_scale_v, vcvtq_f32_q32(vres2), scale_v));
177 vres3 = vcvtq_q32_f32<q32x4_t>(
wrapper::vmla(half_scale_v, vcvtq_f32_q32(vres3), scale_v));
178 vres4 = vcvtq_q32_f32<q32x4_t>(
wrapper::vmla(half_scale_v, vcvtq_f32_q32(vres4), scale_v));
189 for (; x_off < window_end_x; ++x_off)
191 q32_t res =
static_cast<q32_t
>(0.f);
194 for (
int z = pool_start_z; z < pool_end_z; ++z)
196 const uint8_t *in_ptr_z = in_ptr_n + (z + in_idx_depth) * w_stride;
197 for (
int y = pool_start_y; y < pool_end_y; ++y)
199 const uint8_t *in_ptr_y = in_ptr_z + (y + in_idx_height) * z_stride;
200 for (
int x = pool_start_x; x < pool_end_x; ++x)
202 const uint8_t *in_ptr_x = in_ptr_y + (x + in_idx_width) * y_stride;
203 const T data = *(
reinterpret_cast<const T *
>(in_ptr_x) + x_off);
209 if (src_qinfo != dst_qinfo)
211 const float res_f =
static_cast<float>(res);
212 const float new_scale = quant_rescale /
scale;
216 *(
reinterpret_cast<T *
>(out.
ptr()) + x_off) = requantized_dst;
221 res =
static_cast<T
>(0.5f +
static_cast<float>(res) *
scale);
224 *(
reinterpret_cast<T *
>(out.
ptr()) + x_off) = res;
231 template <
typename T>
239 const int window_half_step_x = window_step_x / 2;
241 int pool_stride_x =
static_cast<int>(pool_info.
stride.
width);
242 int pool_stride_y =
static_cast<int>(pool_info.
stride.
height);
243 int pool_stride_z =
static_cast<int>(pool_info.
stride.
depth);
249 const int pool_pad_top =
static_cast<int>(pool_info.
padding.
top);
250 const int pool_pad_left =
static_cast<int>(pool_info.
padding.
left);
251 const int pool_pad_front =
static_cast<int>(pool_info.
padding.
front);
253 const int input_dim_c =
src->info()->dimension(0);
254 const int input_dim_w =
src->info()->dimension(1);
255 const int input_dim_h =
src->info()->dimension(2);
256 const int input_dim_d =
src->info()->dimension(3);
258 const int y_stride =
static_cast<int>(
src->info()->strides_in_bytes().y());
259 const int z_stride =
static_cast<int>(
src->info()->strides_in_bytes().z());
260 const int w_stride =
static_cast<int>(
src->info()->strides_in_bytes()[3]);
261 const int n_stride =
static_cast<int>(
src->info()->strides_in_bytes()[4]);
263 const uint8_t *in_ptr_start =
src->buffer() +
src->info()->offset_first_element_in_bytes();
265 const int window_end_x = input_dim_c;
266 const int window_start_x = 0;
273 const float requant_scale = dst_qinfo.
scale / src_qinfo.
scale;
274 const int32_t requant_offset =
275 dst_qinfo.
offset -
static_cast<int32_t
>(
static_cast<float>(src_qinfo.
offset) / requant_scale);
283 const int in_idx_width =
static_cast<int>(
id.y()) * pool_stride_x - pool_pad_left;
284 const int in_idx_height =
static_cast<int>(
id.z()) * pool_stride_y - pool_pad_top;
285 const int in_idx_depth =
static_cast<int>(
id[3]) * pool_stride_z - pool_pad_front;
287 const int pool_start_x = std::max(0, -in_idx_width);
288 const int pool_end_x_t = std::min(input_dim_w + pool_pad_left - in_idx_width, pool_size_x);
289 const int pool_start_y = std::max(0, -in_idx_height);
290 const int pool_end_y_t = std::min(input_dim_h + pool_pad_top - in_idx_height, pool_size_y);
292 const int pool_start_z = std::max(0, -in_idx_depth);
293 const int pool_end_z_t = std::min(input_dim_d + pool_pad_front - in_idx_depth, pool_size_z);
296 const int pool_end_x = std::min(pool_end_x_t, input_dim_w - in_idx_width);
297 const int pool_end_y = std::min(pool_end_y_t, input_dim_h - in_idx_height);
298 const int pool_end_z = std::min(pool_end_z_t, input_dim_d - in_idx_depth);
300 const uint8_t *in_ptr_n = in_ptr_start +
id[4] * n_stride;
302 int x_off = window_start_x;
304 for (; x_off <= (window_end_x - window_step_x); x_off += window_step_x)
309 for (
int z = pool_start_z; z < pool_end_z; ++z)
311 const uint8_t *in_ptr_z = in_ptr_n + (z + in_idx_depth) * w_stride;
312 for (
int y = pool_start_y; y < pool_end_y; ++y)
314 const uint8_t *in_ptr_y = in_ptr_z + (y + in_idx_height) * z_stride;
315 for (
int x = pool_start_x; x < pool_end_x; ++x)
317 const uint8_t *in_ptr_x = in_ptr_y + (x + in_idx_width) * y_stride;
318 const q8x16_t data =
wrapper::vloadq(
reinterpret_cast<const T *
>(in_ptr_x) + x_off);
327 (src_qinfo != dst_qinfo)
334 for (; x_off <= (window_end_x - window_half_step_x); x_off += window_half_step_x)
339 for (
int z = pool_start_z; z < pool_end_z; ++z)
341 const uint8_t *in_ptr_z = in_ptr_n + (z + in_idx_depth) * w_stride;
342 for (
int y = pool_start_y; y < pool_end_y; ++y)
344 const uint8_t *in_ptr_y = in_ptr_z + (y + in_idx_height) * z_stride;
345 for (
int x = pool_start_x; x < pool_end_x; ++x)
347 const uint8_t *in_ptr_x = in_ptr_y + (x + in_idx_width) * y_stride;
348 const q8x8_t data =
wrapper::vload(
reinterpret_cast<const T *
>(in_ptr_x) + x_off);
357 (src_qinfo != dst_qinfo) ? vrequantize_pooling<q8x8_t>(vres, requant_qinfo) : vres);
361 for (; x_off < window_end_x; ++x_off)
363 T res = std::numeric_limits<T>::min();
365 for (
int z = pool_start_z; z < pool_end_z; ++z)
367 const uint8_t *in_ptr_z = in_ptr_n + (z + in_idx_depth) * w_stride;
368 for (
int y = pool_start_y; y < pool_end_y; ++y)
370 const uint8_t *in_ptr_y = in_ptr_z + (y + in_idx_height) * z_stride;
371 for (
int x = pool_start_x; x < pool_end_x; ++x)
373 const uint8_t *in_ptr_x = in_ptr_y + (x + in_idx_width) * y_stride;
374 const T data = *(
reinterpret_cast<const T *
>(in_ptr_x) + x_off);
376 res = std::max(res, data);
382 if (src_qinfo != dst_qinfo)
384 const float res_f =
static_cast<float>(res);
385 *(
reinterpret_cast<T *
>(out.
ptr()) + x_off) = quantize<T>(res_f, requant_qinfo);
389 *(
reinterpret_cast<T *
>(out.
ptr()) + x_off) = res;
399 #endif // SRC_CORE_NEON_KERNELS_POOL3D_QUANTIZED_H