// NOTE(review): this span looks like a damaged extraction of a reference
// pooling kernel. The leading numbers on some lines ("40", "50", ...) are not
// C++, and the function signature, braces, the branch conditions and the
// declarations of max_val / max_index / avg_val / indices are missing.
// TODO: restore from the upstream source; the comments below describe only
// what the visible lines do.
//
// Template enabled only for floating-point element types T
// (std::enable_if<is_floating_point<T>::value>). Source values are read as
// ACC_T for the window computations.
40 template <typename T, typename ACC_T, typename std::enable_if<is_floating_point<T>::value,
int>
::type>
// Window size: with global pooling the window covers the whole x/y extent of
// src; otherwise the configured info.pool_size is used.
50 const int pool_size_x =
info.is_global_pooling ?
src.shape().x() :
info.pool_size.width;
51 const int pool_size_y =
info.is_global_pooling ?
src.shape().y() :
info.pool_size.height;
// Stride and per-side padding come from info.pad_stride_info.
53 int pool_stride_x =
info.pad_stride_info.stride().first;
54 int pool_stride_y =
info.pad_stride_info.stride().second;
55 int pad_left =
info.pad_stride_info.pad_left();
56 int pad_top =
info.pad_stride_info.pad_top();
57 int pad_right =
info.pad_stride_info.pad_right();
58 int pad_bottom =
info.pad_stride_info.pad_bottom();
// NOTE(review): exclude_padding is not referenced by any visible line; see
// the note at the second `pool` computation in the averaging path below.
59 bool exclude_padding =
info.exclude_padding;
// Source extents along dims 0..3 (named w, h, z, b here).
61 const auto w_src =
static_cast<int>(
src.shape()[0]);
62 const auto h_src =
static_cast<int>(
src.shape()[1]);
63 const auto z_src =
static_cast<int>(
src.shape()[2]);
64 const auto b_src =
static_cast<int>(
src.shape()[3]);
// Number of (h x w) planes in src: the product of all dimensions except the
// first two.
66 const int upper_dims =
src.shape().total_size() / (w_src * h_src);
// Destination extents along dims 0..2.
68 const auto w_dst =
static_cast<int>(
dst.shape()[0]);
69 const auto h_dst =
static_cast<int>(
dst.shape()[1]);
70 const auto z_dst =
static_cast<int>(
dst.shape()[2]);
// Loop nest that tracks a running maximum (max_val / max_index /
// max_ker_index); its guarding condition is not visible in this fragment.
// Visits every batch b, plane r and output position (h, w).
76 for(
int b = 0;
b < b_src; ++
b)
78 for(
int r = 0; r < z_src; ++r)
80 for(
int h = 0; h < h_dst; ++h)
82 for(
int w = 0;
w < w_dst; ++
w)
// Top-left corner of the pooling window in src coordinates; negative when
// the window starts inside the left/top padding.
84 int wstart =
w * pool_stride_x - pad_left;
85 int hstart = h * pool_stride_y - pad_top;
// Kernel offset of the first in-bounds tap, so kernel-relative indices stay
// correct when the window is clipped by padding.
88 int kh_start = std::max(0, -hstart);
89 int kw_start = std::max(0, -wstart);
// Kernel-relative index of the selected tap (row-major over the kernel).
90 int max_ker_index{ 0 };
// Clip the window to the valid src region; padding taps are never read.
92 int wend = std::min(wstart + pool_size_x, w_src);
93 int hend = std::min(hstart + pool_size_y, h_src);
94 wstart = std::max(wstart, 0);
95 hstart = std::max(hstart, 0);
99 for(
int y = hstart, kh = kh_start; y < hend; ++y, ++kh)
101 for(
int x = wstart, kw = kw_start; x < wend; ++x, ++kw)
// Linear offset with dim 0 fastest: ((b * z + r) * h + y) * w + x.
103 const auto val =
static_cast<ACC_T
>(
src[
b * z_src * h_src * w_src + r * h_src * w_src + y * w_src + x]);
// NOTE(review): in the visible lines this assignment is unconditional, which
// would make max_ker_index the *last* tap rather than the maximum. Presumably
// the missing lines wrap it (together with the max_val / max_index updates)
// in a `val > max_val` check. Confirm against upstream.
107 max_ker_index = pool_size_x * (kh) + (kw);
// Store the window maximum, cast back to T.
120 dst[
b * z_dst * h_dst * w_dst + r * h_dst * w_dst + h * w_dst +
w] =
static_cast<T
>(max_val);
// Record where the maximum came from: the kernel-relative index when
// info.use_kernel_indices is set, otherwise max_index.
// NOTE(review): any null check on `indices` is not visible here.
123 (*indices)[
b * z_dst * h_dst * w_dst + r * h_dst * w_dst + h * w_dst +
w] = (
info.use_kernel_indices) ? max_ker_index : max_index;
// Averaging loop nest (average and L2 variants below; the selecting
// conditions are not visible). It iterates over all collapsed upper planes,
// so batch and z are not indexed separately. dst is indexed with the same
// plane index r, which presumably requires dst to have the same number of
// upper planes as src. Verify against the caller.
132 for(
int r = 0; r < upper_dims; ++r)
134 for(
int h = 0; h < h_dst; ++h)
136 for(
int w = 0;
w < w_dst; ++
w)
// Window bounds, first clipped only to the padded extent (src size plus
// right/bottom padding), so this first `pool` also counts padding taps.
139 int wstart =
w * pool_stride_x - pad_left;
140 int hstart = h * pool_stride_y - pad_top;
141 int wend = std::min(wstart + pool_size_x, w_src + pad_right);
142 int hend = std::min(hstart + pool_size_y, h_src + pad_bottom);
143 int pool = (hend - hstart) * (wend - wstart);
// Then clip to the valid src region for the actual reads.
144 wstart = std::max(wstart, 0);
145 hstart = std::max(hstart, 0);
146 wend = std::min(wend, w_src);
147 hend = std::min(hend, h_src);
// Re-count only the in-bounds taps.
// NOTE(review): as visible, this unconditionally overwrites the padded count
// above, which would make exclude_padding a no-op. Upstream presumably guards
// it with `if(exclude_padding)` on a line missing from this fragment. Confirm.
151 pool = (hend - hstart) * (wend - wstart);
// Average: sum of the in-window values divided by pool. The declaration and
// initialisation of avg_val are not visible.
156 for(
int y = hstart; y < hend; ++y)
158 for(
int x = wstart; x < wend; ++x)
160 avg_val +=
static_cast<ACC_T
>(
src[r * h_src * w_src + y * w_src + x]);
163 dst[r * h_dst * w_dst + h * w_dst +
w] = avg_val / pool;
// L2: square root of the mean of the squared in-window values, cast to T.
167 for(
int y = hstart; y < hend; ++y)
169 for(
int x = wstart; x < wend; ++x)
171 const auto val =
static_cast<ACC_T
>(
src[r * h_src * w_src + y * w_src + x]);
172 avg_val += val * val;
175 dst[r * h_dst * w_dst + h * w_dst +
w] =
static_cast<T
>(std::sqrt(avg_val / pool));
190 template <
typename T>