30 #if !defined(_WIN64) && !defined(__OpenBSD__)
// Strategy description, parameterised on the element type read (TInput) and
// the element type written (TOutput).
38 template <
typename TInput,
typename TOutput>
// Extents of the input and output tiles handled by the strategy.
41 unsigned int input_rows, input_cols, output_rows, output_cols;
// Constructor parameter remnant: strides and output tile extents.
45 unsigned int stride_rows,
unsigned int stride_cols,
46 unsigned int output_rows,
unsigned int output_cols)
// The input tile extent is derived from the output tile extent, the window
// size and the stride.
// NOTE(review): the usual pooling relation is (output - 1) * stride + window;
// the expressions below agree with it only when the stride is 1 — confirm
// that only stride-1 strategies rely on this.
47 : input_rows(output_rows + (window_rows - 1) * stride_rows),
48 input_cols(output_cols + (window_cols - 1) * stride_cols),
49 output_rows(output_rows), output_cols(output_cols)
// Parameter-list remnant: a channel count, a pointer to const pointers to
// const TInput, and padding amounts. Presumably part of the kernel
// entry-point signature — TODO confirm against the full file.
59 unsigned int n_channels,
60 const TInput *
const *,
63 unsigned int pad_left,
65 unsigned int pad_right,
66 unsigned int pad_bottom
79 template <
typename TInput,
typename TOutput=TInput,
class OutputStage=Nothing>
// Returns the number of bytes needed to hold one TInput element per channel
// (sizeof(TInput) multiplied by m_args.n_channels).
82 size_t sizeof_input_buffer(
void)
const
84 return sizeof(TInput) * this->m_args.n_channels;
// Returns the number of bytes needed to hold one TOutput element per channel
// (sizeof(TOutput) multiplied by m_args.n_channels).
87 size_t sizeof_output_buffer(
void)
const
89 return sizeof(TOutput) * this->m_args.n_channels;
// Returns the per-thread working-space size in bytes: the WorkingSpace
// structure itself followed by room for one TInput and one TOutput element
// per channel.
94 size_t get_working_size_per_thread()
const override
96 return sizeof(
WorkingSpace) + this->m_args.n_channels * (
sizeof(TInput) +
sizeof(TOutput));
// Prepares a raw working-space allocation: sets up the buffer pointers stored
// in the WorkingSpace header and walks the input buffer once per channel.
100 void initialise_working_space(
void *raw_ws)
const override
// The output buffer starts sizeof(TInput) * n_channels bytes past the end of
// the WorkingSpace structure; presumably the input buffer occupies that gap —
// TODO confirm (the assignment of input_buffer is not visible here).
104 ws->output_buffer =
reinterpret_cast<char *
>(ws + 1) +
sizeof(TInput) * this->m_args.n_channels;
// Choose the fill value: negative infinity when TInput has one, otherwise
// numeric_limits<TInput>::min().
// NOTE(review): for integer types min() is the most negative value, but for a
// floating-point type without infinity min() is the smallest positive normal
// value; lowest() would express "most negative" for every type — confirm.
110 using limits = std::numeric_limits<TInput>;
111 if (limits::has_infinity)
113 fill_val = -limits::infinity();
117 fill_val = limits::min();
// Iterate over the input buffer, one TInput element per channel (the loop
// body is not visible here; presumably it stores fill_val — TODO confirm).
121 auto ptr =
reinterpret_cast<TInput *
>(ws->input_buffer);
122 auto n_channels = this->m_args.n_channels;
123 for (; n_channels; n_channels--)
// Computes one output tile at (output_i, output_j) for the channel range
// [channel_start, channel_end), working out how much of the corresponding
// input tile falls into padding on each of the four sides.
130 void compute_tile_padded(
131 unsigned int output_i,
unsigned int output_j,
132 unsigned int channel_start,
unsigned int channel_end,
// Fetch the kernel from the strategy object.
139 this->m_strat.get())->get_kernel();
// View the caller-provided working space as a WorkingSpace header.
142 auto ws =
reinterpret_cast<WorkingSpace *
>(working_space);
// Pointer table with one entry per input-tile element, allocated with alloca.
143 auto inptr_array =
reinterpret_cast<const TInput **
>(alloca(
144 sizeof(TInput *) * this->m_strat->get_input_rows() * this->m_strat->get_input_cols()));
// Pointer table with one entry per output-tile element, allocated with alloca.
145 auto outptr_array =
reinterpret_cast<TOutput **
>(alloca(
146 sizeof(TOutput *) * this->m_strat->get_output_rows() * this->m_strat->get_output_cols()));
// Signed input row at which the tile starts; negative when the tile begins
// inside the top padding.
149 const int ii =
static_cast<int>(output_i * this->m_args.pool_stride.rows) - this->m_args.padding.top;
// Number of tile rows lying in the top padding.
150 const auto input_pad_top =
static_cast<unsigned int>(ii < 0 ? -ii : 0);
// First input row actually read (clamped to 0).
151 const auto input_i =
static_cast<unsigned int>(ii < 0 ? 0 : ii);
// Number of tile rows lying past the last input row.
153 const unsigned int end_ii = ii + this->m_strat->get_input_rows();
154 const auto input_pad_bottom = end_ii < this->m_args.input_rows ? 0 : end_ii - this->m_args.input_rows;
// Same computation for columns: signed start column, left padding, clamped
// start column.
156 const int ij =
static_cast<int>(output_j * this->m_args.pool_stride.cols) - this->m_args.padding.left;
157 const auto input_pad_left =
static_cast<unsigned int>(ij < 0 ? -ij : 0);
158 const auto input_j =
static_cast<unsigned int>(ij < 0 ? 0 : ij);
// Number of tile columns lying past the last input column.
160 const unsigned int end_ij = ij + this->m_strat->get_input_cols();
161 const auto input_pad_right = end_ij < this->m_args.input_cols ? 0 : end_ij - this->m_args.input_cols;
// Fill the input pointer table, starting from the element at
// (input_i, input_j, channel_start). ws->input_buffer is also passed in;
// presumably padded entries are pointed at it — TODO confirm in
// fill_pointer_array.
163 fill_pointer_array<const TInput>(
164 inptr_array, this->m_strat->get_input_rows(), this->m_strat->get_input_cols(),
165 input.base + input_i*
input.ld_row + input_j*
input.ld_col + channel_start,
167 reinterpret_cast<const TInput *
>(ws->input_buffer),
168 input_pad_top, this->m_args.input_rows - input_i,
169 input_pad_left, this->m_args.input_cols - input_j
// Fill the output pointer table, starting from the element at
// (output_i, output_j, channel_start). No leading padding is passed (both pad
// arguments are 0) and ws->output_buffer is supplied as the alternative
// buffer.
174 outptr_array, this->m_strat->get_output_rows(), this->m_strat->get_output_cols(),
175 output.
base + output_i*output.
ld_row + output_j*output.
ld_col + channel_start,
177 reinterpret_cast<TOutput *
>(ws->output_buffer),
178 0, this->m_args.output_rows - output_i,
179 0, this->m_args.output_cols - output_j
// Kernel arguments: the exclude_padding flag and the four padding amounts.
185 this->m_args.exclude_padding,
186 input_pad_left, input_pad_top,
187 input_pad_right, input_pad_bottom
// Computes n_tile_cols consecutive output tiles along one tile row, starting
// at (output_i, output_j), for the channel range [channel_start, channel_end).
// Only top/bottom padding is computed here; the pointer tables are built once
// and then advanced by one tile width per iteration.
192 void compute_row_padded_tile_row(
193 const unsigned int output_i,
unsigned int output_j,
unsigned int n_tile_cols,
194 const unsigned int channel_start,
const unsigned int channel_end,
// Fetch the kernel from the strategy object.
201 this->m_strat.get())->get_kernel();
// View the caller-provided working space as a WorkingSpace header.
204 auto ws =
reinterpret_cast<WorkingSpace *
>(working_space);
// Pointer table with one entry per input-tile element, allocated with alloca.
205 auto inptr_array =
reinterpret_cast<const TInput **
>(alloca(
206 sizeof(TInput *) * this->m_strat->get_input_rows() * this->m_strat->get_input_cols()));
// Pointer table with one entry per output-tile element, allocated with alloca.
207 auto outptr_array =
reinterpret_cast<TOutput **
>(alloca(
208 sizeof(TOutput *) * this->m_strat->get_output_rows() * this->m_strat->get_output_cols()));
// Signed input row at which the tile starts; negative when the tile begins
// inside the top padding.
211 const int ii =
static_cast<int>(output_i * this->m_args.pool_stride.rows) - this->m_args.padding.top;
// Number of tile rows lying in the top padding.
212 const auto input_pad_top =
static_cast<unsigned int>(ii < 0 ? -ii : 0);
// First input row actually read (clamped to 0).
213 const auto input_i =
static_cast<unsigned int>(ii < 0 ? 0 : ii);
// Number of tile rows lying past the last input row.
215 const unsigned int end_ii = ii + this->m_strat->get_input_rows();
216 const auto input_pad_bottom = end_ii < this->m_args.input_rows ? 0 : end_ii - this->m_args.input_rows;
// Start column of the first tile, clamped to 0.
218 const int ij =
static_cast<int>(output_j * this->m_args.pool_stride.cols) - this->m_args.padding.left;
219 const auto input_j =
static_cast<unsigned int>(ij < 0 ? 0 : ij);
// Number of output-tile rows lying below the last output row. The tile
// height is the strategy's output ROW count: end_oi is compared against
// m_args.output_rows, and the result is subtracted from get_output_rows()
// in the output-pointer advance loop further down.
221 const auto end_oi = output_i + this->m_strat->get_output_rows();
222 const auto output_pad_bottom = end_oi < this->m_args.output_rows ? 0 : end_oi - this->m_args.output_rows;
// Fill the input pointer table, starting from the element at
// (input_i, input_j, channel_start). No left padding is passed on this path
// (the column pad argument is 0). ws->input_buffer is also passed in;
// presumably padded entries are pointed at it — TODO confirm in
// fill_pointer_array.
224 fill_pointer_array<const TInput>(
225 inptr_array, this->m_strat->get_input_rows(), this->m_strat->get_input_cols(),
226 input.base + input_i*
input.ld_row + input_j*
input.ld_col + channel_start,
228 reinterpret_cast<const TInput *
>(ws->input_buffer),
229 input_pad_top, this->m_args.input_rows - input_i,
230 0, this->m_args.input_cols - input_j
// Fill the output pointer table, starting from the element at
// (output_i, output_j, channel_start), with ws->output_buffer supplied as the
// alternative buffer.
235 outptr_array, this->m_strat->get_output_rows(), this->m_strat->get_output_cols(),
236 output.
base + output_i*output.
ld_row + output_j*output.
ld_col + channel_start,
238 reinterpret_cast<TOutput *
>(ws->output_buffer),
239 0, this->m_args.output_rows - output_i,
240 0, this->m_args.output_cols - output_j
// One iteration per tile in the row.
244 for (; n_tile_cols; n_tile_cols--)
248 this->m_args.exclude_padding,
// Distance between the input origins of horizontally adjacent tiles.
254 const auto input_col_stride =
input.ld_col * this->m_strat->get_output_cols() * this->m_args.pool_stride.cols;
// Advance only the input pointers of rows between the top and bottom padding.
256 auto n = input_pad_top * this->m_strat->get_input_cols();
257 n < (this->m_strat->get_input_rows() - input_pad_bottom) * this->m_strat->get_input_cols();
261 inptr_array[n] += input_col_stride;
// Distance between the output origins of horizontally adjacent tiles.
264 const auto output_col_stride = output.
ld_col * this->m_strat->get_output_cols();
// Advance only the output pointers of rows above the bottom padding.
267 n < (this->m_strat->get_output_rows() - output_pad_bottom) * this->m_strat->get_output_cols();
278 const PoolingArgs &
args,
const OutputStage &os = {})