27 #include "pooling.hpp"
// Template parameter list: TInput and TOutput are the two type parameters that
// are forwarded to PoolingCommon below; TOutput is also the pointee type the
// raw output pointer is cast to in execute_internal.
// NOTE(review): the class head that this parameter list belongs to is not
// visible in this excerpt.
57 template <
typename TInput,
typename TOutput>
// Shorthand for the PoolingCommon instantiation over the same type pair.
61 using Parent = PoolingCommon<TInput, TOutput>;
// Owning pointer to an immutable IDepthfirstStrategy. The code in this file
// queries it for tile dimensions via get_output_rows(), get_output_cols(),
// get_input_rows() and get_input_cols().
64 std::unique_ptr<const IDepthfirstStrategy> m_strat;
// Pure virtual hook returning the amount of working space one thread needs.
// The value is used as a byte count: execute_internal multiplies it by
// thread_id to offset a uint8_t pointer into the shared working-space buffer,
// and it is also multiplied by the thread count to produce a total size.
67 virtual size_t get_working_size_per_thread()
const = 0;
// Pure virtual hook that prepares a block of working space for use.
// execute_internal calls it once per invocation, passing the pointer to the
// calling thread's own slice of the working-space buffer, before any tile is
// computed.
70 virtual void initialise_working_space(
void *)
const = 0;
// Virtual hook that computes a single output tile located at output
// coordinates (output_i, output_j) for the channels between
// output_channel_start and output_channel_end (presumably a half-open range,
// since callers clamp the end to n_channels - TODO confirm).
//
// Within this file it is used for the 1x1-output path, as the per-tile worker
// of compute_row_padded_tile_row, and for positions where no run of unpadded
// tiles could be formed.
//
// NOTE(review): the rest of the parameter list and the declaration's
// terminator are not visible in this excerpt; call sites pass an input tensor,
// an output tensor and a working-space pointer after these four arguments.
73 virtual void compute_tile_padded(
74 unsigned int output_i,
unsigned int output_j,
75 unsigned int output_channel_start,
unsigned int output_channel_end,
// Virtual hook that computes n_tile_cols horizontally consecutive tiles in one
// tile row, beginning at output column output_j.
//
// The implementation shown here calls compute_tile_padded once per tile,
// forwarding output_i and the channel range unchanged and moving output_j on
// by the strategy's output tile width after every call.
//
// NOTE(review): the trailing parameters (input, output, working_space are
// referenced below) are not visible in this excerpt.
86 virtual void compute_row_padded_tile_row(
87 const unsigned int output_i,
unsigned int output_j,
unsigned int n_tile_cols,
88 const unsigned int output_channel_start,
const unsigned int output_channel_end,
// n_tile_cols doubles as the countdown. Both it and output_j are by-value
// parameters, so the caller's variables are not modified.
94 for (; n_tile_cols; n_tile_cols--, output_j += m_strat->get_output_cols())
96 this->compute_tile_padded(
97 output_i, output_j, output_channel_start, output_channel_end,
98 input, output, working_space
// Virtual hook that computes a grid of n_tile_rows x n_tile_cols tiles whose
// first tile sits at output coordinates (start_output_i, start_output_j).
//
// The implementation shown here issues one compute_row_padded_tile_row call
// per tile row and then advances start_output_i by the strategy's output tile
// height. Despite the "unpadded" name, it therefore routes through the padded
// row helper.
//
// NOTE(review): the trailing parameters (input, output, working_space are
// referenced below) are not visible in this excerpt.
108 virtual void compute_tiles_unpadded(
109 unsigned int start_output_i,
unsigned int start_output_j,
110 unsigned int n_tile_rows,
unsigned int n_tile_cols,
111 unsigned int output_channel_start,
unsigned int output_channel_end,
117 for (
unsigned int tile_i = 0; tile_i < n_tile_rows; tile_i++)
// Every tile row starts at the same column and spans the same number of tiles.
119 this->compute_row_padded_tile_row(
120 start_output_i, start_output_j, n_tile_cols,
121 output_channel_start, output_channel_end,
122 input, output, working_space
// Step down to the next tile row (start_output_i is a by-value parameter).
124 start_output_i += m_strat->get_output_rows();
// Performs the share of the pooling work that belongs to thread `thread_id`
// out of `n_threads` (presumably called once per thread with thread_id in
// [0, n_threads) - TODO confirm against the caller).
//
// Two work-splitting schemes are visible:
//   * 1x1 output with more than one thread: the channels are divided between
//     threads and each thread computes its channel range for every batch.
//   * otherwise: tile rows are striped across threads, and within each tile
//     row the columns are walked left to right, handing runs of tiles to
//     compute_tiles_unpadded / compute_row_padded_tile_row and single tiles to
//     compute_tile_padded.
//
// NOTE(review): several parameters (e.g. input, output, working_space) and a
// number of statements are not visible in this excerpt; comments below only
// describe the lines that are.
128 void execute_internal(
129 unsigned int n_batches,
130 unsigned int input_height,
131 unsigned int input_width,
132 unsigned int n_channels,
133 const PaddingValues &padding,
137 size_t ld_input_batch,
138 unsigned int output_height,
139 unsigned int output_width,
141 size_t ld_output_col,
142 size_t ld_output_row,
143 size_t ld_output_batch,
145 unsigned int thread_id,
146 unsigned int n_threads
// Locate this thread's slice of the shared working space: slices are laid out
// back to back, each get_working_size_per_thread() bytes long.
150 void *thread_working_space =
151 static_cast<uint8_t *
>(working_space) + thread_id * this->get_working_size_per_thread();
152 this->initialise_working_space(thread_working_space);
// Bundle the raw output pointer with ld_output_row and ld_output_col.
156 TensorSpec<TOutput *> output_tensor(
reinterpret_cast<TOutput *
>(output), ld_output_row, ld_output_col);
// Single output point and several threads: split the work by channel.
161 if (n_threads > 1 && output_height == 1 && output_width == 1)
// This thread's channel range; the end is clamped so it never exceeds
// n_channels. (channels_per_thread is defined on a line not visible here.)
167 const auto start_channel = thread_id * channels_per_thread;
168 const auto end_channel = std::min(start_channel + channels_per_thread, n_channels);
// True when the clamped range is empty, i.e. no channels are left for this
// thread.
170 if (start_channel >= end_channel)
// One tile per batch; n_batches (a by-value parameter) is used as the counter.
176 for (; n_batches; n_batches--)
180 this->compute_tile_padded(
182 start_channel, end_channel,
183 input_tensor, output_tensor, thread_working_space
// Move both tensors on to the next batch.
187 input_tensor.
base += ld_input_batch;
188 output_tensor.
base += ld_output_batch;
// General scheme: iterate over batches, then over this thread's tile rows.
195 for (
unsigned int batch = 0; batch < n_batches; batch++)
// Tile rows are striped over threads: this thread starts at tile row
// `thread_id` and then skips ahead n_threads tile rows at a time.
198 for (
unsigned int start_output_i = thread_id * m_strat->get_output_rows();
199 start_output_i < output_height;
200 start_output_i += n_threads * m_strat->get_output_rows())
// The tile row overhangs the output if its end lies beyond output_height.
204 const auto end_output_i = start_output_i + m_strat->get_output_rows();
205 const bool pad_output_bottom = output_height < end_output_i;
// First input row read by this tile row; negative when it falls inside the
// top padding.
207 const int start_input_i = start_output_i * this->m_args.pool_stride.rows - padding.top;
208 const bool pad_input_top = start_input_i < 0;
// The input window overhangs the input if its end lies beyond input_height.
209 const int end_input_i = start_input_i + m_strat->get_input_rows();
210 const bool pad_input_bottom =
static_cast<int>(input_height) < end_input_i;
// Any vertical overhang marks the whole tile row as padded.
211 const bool pad_row = pad_input_top || pad_input_bottom || pad_output_bottom;
// Walk the columns of this tile row from left to right.
216 unsigned int start_output_j = 0;
217 while (start_output_j < output_width)
// First input column read by the tile at start_output_j; negative when it
// falls inside the left padding.
219 const int start_in_j = start_output_j * this->m_args.pool_stride.cols - padding.left;
220 const bool pad_input_left = start_in_j < 0;
// Number of consecutive tiles, starting here, that need no horizontal padding.
223 int n_unpadded_tiles = 0;
// Upper bound: whole tiles that fit in the remaining output width.
// NOTE(review): presumably only evaluated when pad_input_left is false; the
// guarding condition is not visible in this excerpt.
227 n_unpadded_tiles = (output_width - start_output_j) / m_strat->get_output_cols();
// Input columns covered by a step of one tile.
230 const int tile_stride = m_strat->get_output_cols() * this->m_args.pool_stride.cols;
// End positions, in output and input columns, of the candidate run of tiles.
231 int end_output_j = start_output_j + n_unpadded_tiles * m_strat->get_output_cols();
232 int end_input_j = start_in_j + m_strat->get_input_cols() + (n_unpadded_tiles - 1)*tile_stride;
// Pull the end of the run back one tile at a time while it still extends past
// the output width or the input width.
234 while (n_unpadded_tiles > 0 &&
235 (
static_cast<int>(output_width) < end_output_j ||
236 static_cast<int>(input_width) < end_input_j))
239 end_output_j -= m_strat->get_output_cols();
240 end_input_j -= tile_stride;
// A non-empty run was found: process it in one call.
// NOTE(review): the condition choosing between the two calls below is not
// visible here; presumably it is pad_row - TODO confirm.
245 if (n_unpadded_tiles)
250 this->compute_tiles_unpadded(
251 start_output_i, start_output_j,
254 input_tensor, output_tensor, thread_working_space
260 this->compute_row_padded_tile_row(
261 start_output_i, start_output_j, n_unpadded_tiles,
263 input_tensor, output_tensor, thread_working_space
// Skip past every tile of the run just processed.
266 start_output_j += n_unpadded_tiles * m_strat->get_output_cols();
// No run available: compute one padded tile and advance by one tile width.
270 this->compute_tile_padded(
271 start_output_i, start_output_j,
273 input_tensor, output_tensor, thread_working_space
275 start_output_j += m_strat->get_output_cols();
// Move both tensors on to the next batch.
281 input_tensor.
base += ld_input_batch;
282 output_tensor.
base += ld_output_batch;
// Total working space for n_threads threads: one per-thread slice each, which
// matches the back-to-back slice layout assumed when execute_internal offsets
// into the buffer by thread_id.
// NOTE(review): the header of the function containing this statement is not
// visible in this excerpt.
294 return n_threads * this->get_working_size_per_thread();