27 #include "depthwise.hpp"
// Depth-first depthwise-convolution driver, templated on the input,
// weight and output element types.
// NOTE(review): the enclosing class declaration itself falls outside this
// view; these members appear to belong to a driver class deriving from
// DepthwiseCommon -- confirm against the full header.
65 template <
typename TInput,
typename TWeight,
typename TOutput>
// Convenience alias for the base class, used to refer to inherited members.
69 using Parent = DepthwiseCommon<TInput, TWeight, TOutput>;
// Strategy object describing the tile geometry (output rows/cols and input
// rows/cols per tile -- see the get_output_rows()/get_output_cols()/
// get_input_rows()/get_input_cols() calls below). Owned by the driver and
// held as pointer-to-const, so it is never mutated after construction.
72 std::unique_ptr<const IDepthfirstStrategy> m_strat;
// Returns the number of bytes of scratch working space required by a single
// thread; execute_internal() below offsets into a shared buffer by
// thread_id * this value. Implemented by derived classes.
75 virtual size_t get_working_size_per_thread()
const = 0;
// Initialises one thread's slice of the working space (the pointer passed is
// already offset for the calling thread -- see execute_internal()).
// Implemented by derived classes.
78 virtual void initialise_working_space(
void *)
const = 0;
// Computes a single output tile, handling any input/output padding.
// (output_i, output_j) give the tile's top-left position in the output
// tensor; [output_channel_start, output_channel_end) bounds the channels
// processed. NOTE(review): the remainder of this parameter list (tensor
// specs, parameters, working space -- as suggested by the call sites below)
// is missing from this view; confirm against the full header.
81 virtual void compute_tile_padded(
82 const DepthwiseArgs &
args,
83 unsigned int output_i,
unsigned int output_j,
84 unsigned int output_channel_start,
unsigned int output_channel_end,
// Computes a row of n_tile_cols output tiles which require top/bottom (row)
// padding but no left/right padding. The default implementation simply
// delegates each tile to compute_tile_padded(), stepping output_j forward by
// the strategy's output-column count per tile; derived classes may override
// with a faster path. NOTE(review): part of the signature and the end of the
// body are missing from this view.
96 virtual void compute_row_padded_tile_row(
97 const DepthwiseArgs &
args,
98 const unsigned int output_i,
unsigned int output_j,
unsigned int n_tile_cols,
99 const unsigned int output_channel_start,
const unsigned int output_channel_end,
// One padded-tile computation per column tile; n_tile_cols counts down while
// output_j advances one tile-width at a time.
106 for (; n_tile_cols; n_tile_cols--, output_j += m_strat->get_output_cols())
108 this->compute_tile_padded(
110 output_i, output_j, output_channel_start, output_channel_end,
// Computes an n_tile_rows x n_tile_cols grid of output tiles which need no
// padding handling. The default implementation iterates the grid and calls
// compute_tile_padded() for each tile (correct, if not maximally fast --
// derived classes may override). Row position advances by the strategy's
// output-row count per tile row; column position by its output-column count
// per tile. NOTE(review): part of the signature and some call arguments are
// missing from this view.
121 virtual void compute_tiles_unpadded(
122 const DepthwiseArgs &
args,
123 unsigned int start_output_i,
unsigned int start_output_j,
124 unsigned int n_tile_rows,
unsigned int n_tile_cols,
125 unsigned int output_channel_start,
unsigned int output_channel_end,
132 for (
unsigned int tile_i = 0; tile_i < n_tile_rows; tile_i++)
// Each row of tiles restarts from the same starting output column.
134 unsigned int row_start_output_j = start_output_j;
135 for (
unsigned int tile_j = 0; tile_j < n_tile_cols; tile_j++)
137 this->compute_tile_padded(
139 start_output_i, row_start_output_j,
140 output_channel_start, output_channel_end,
143 row_start_output_j += m_strat->get_output_cols();
// Move down one tile row in the output tensor.
145 start_output_i += m_strat->get_output_rows();
// Entry point for one thread's share of the depthwise convolution. Splits
// the work either over output rows (the usual case) or over batches (when
// there is only a single output row), then walks the output tile grid,
// dispatching unpadded tiles to the fast path and padded tiles to
// compute_tile_padded(). NOTE(review): several parameters (input/output
// pointers, working_space, parameters) and some statements fall in gaps in
// this view; comments below describe only what is visible.
149 void execute_internal(
150 const DepthwiseArgs &
args,
154 size_t ld_input_batch,
157 size_t ld_output_col,
158 size_t ld_output_row,
159 size_t ld_output_batch,
161 unsigned int thread_id,
162 unsigned int n_threads
// Each thread gets a private slice of the shared working-space buffer,
// offset by its thread id, and initialises it before use.
166 void *thread_working_space =
167 static_cast<uint8_t *
>(working_space) + thread_id * this->get_working_size_per_thread();
168 this->initialise_working_space(thread_working_space);
// Wrap the raw output pointer with its row/column strides.
172 TensorSpec<TOutput *> output_tensor(
reinterpret_cast<TOutput *
>(output), ld_output_row, ld_output_col);
// Total output channels = input channels x depth multiplier.
174 const auto n_output_channels =
args.input_channels *
args.channel_multiplier;
// By default, parallelise across output rows. If there is only one output
// row there is nothing to split there, so parallelise across batches
// instead (exactly one of the two schemes is active).
178 auto thread_id_for_rows = thread_id;
179 auto n_threads_for_rows = n_threads;
180 auto thread_id_for_batches = 0;
181 auto n_threads_for_batches = 1;
182 if (
args.output_rows == 1) {
183 thread_id_for_rows = 0;
184 n_threads_for_rows = 1;
185 thread_id_for_batches = thread_id;
186 n_threads_for_batches = n_threads;
// Advance both tensors to this thread's first batch.
190 input_tensor.
base += ld_input_batch*thread_id_for_batches;
191 output_tensor.
base += ld_output_batch*thread_id_for_batches;
// Batch loop: strided by the number of batch-parallel threads (stride 1,
// i.e. every batch, when parallelising over rows instead).
192 for (
unsigned int batch = thread_id_for_batches;
193 batch <
args.n_batches;
194 batch += n_threads_for_batches)
// Row-tile loop: each thread starts at its own tile row and strides by
// n_threads_for_rows tile rows.
197 for (
unsigned int start_output_i = thread_id_for_rows * m_strat->get_output_rows();
198 start_output_i <
args.output_rows;
199 start_output_i += n_threads_for_rows * m_strat->get_output_rows())
// Work out whether this tile row runs off the bottom of the output, and
// whether the corresponding input window needs top/bottom padding.
203 const auto end_output_i = start_output_i + m_strat->get_output_rows();
204 const bool pad_output_bottom =
args.output_rows < end_output_i;
206 const int start_input_i = start_output_i *
args.stride_rows -
args.padding.top;
207 const bool pad_input_top = start_input_i < 0;
208 const int end_input_i = start_input_i + m_strat->get_input_rows();
209 const bool pad_input_bottom =
static_cast<int>(
args.input_rows) < end_input_i;
212 || pad_output_bottom;
// Column loop: march across the output one tile (or run of tiles) at a
// time until the whole row of tiles is covered.
217 unsigned int start_output_j = 0;
218 while (start_output_j <
args.output_cols)
220 const int start_in_j = start_output_j *
args.stride_cols -
args.padding.left;
221 const bool pad_input_left = start_in_j < 0;
// Count how many whole tiles fit before the right edge of the output;
// these can take the unpadded fast path.
224 int n_unpadded_tiles = 0;
228 n_unpadded_tiles = (
args.output_cols - start_output_j) / m_strat->get_output_cols();
// Shrink the unpadded run while it would overrun either the output
// columns or the input columns (the last tiles then fall through to the
// padded path below).
231 const int tile_stride = m_strat->get_output_cols() *
args.stride_cols;
232 int end_output_j = start_output_j + n_unpadded_tiles * m_strat->get_output_cols();
233 int end_input_j = start_in_j + m_strat->get_input_cols() + (n_unpadded_tiles - 1)*tile_stride;
235 while (n_unpadded_tiles > 0 &&
236 (
static_cast<int>(
args.output_cols) < end_output_j ||
237 static_cast<int>(
args.input_cols) < end_input_j))
240 end_output_j -= m_strat->get_output_cols();
241 end_input_j -= tile_stride;
// Dispatch the unpadded run: the fully-unpadded grid routine when no row
// padding is needed, otherwise the row-padded variant. (NOTE(review): the
// branch condition selecting between the two calls is in a gap in this
// view -- presumably it tests the row-padding flags computed above.)
246 if (n_unpadded_tiles)
251 this->compute_tiles_unpadded(
253 start_output_i, start_output_j,
255 0, n_output_channels,
256 input_tensor, output_tensor,
parameters, thread_working_space
262 this->compute_row_padded_tile_row(
264 start_output_i, start_output_j, n_unpadded_tiles,
265 0, n_output_channels,
266 input_tensor, output_tensor,
parameters, thread_working_space
269 start_output_j += n_unpadded_tiles * m_strat->get_output_cols();
// Any remaining (padded) tile is handled one at a time by the general
// padded routine.
273 this->compute_tile_padded(
275 start_output_i, start_output_j,
276 0, n_output_channels,
277 input_tensor, output_tensor,
parameters, thread_working_space
279 start_output_j += m_strat->get_output_cols();
// Step both tensors to this thread's next batch.
285 input_tensor.
base += ld_input_batch*n_threads_for_batches;
286 output_tensor.
base += ld_output_batch*n_threads_for_batches;
// Total working-space requirement = per-thread requirement x thread count.
// NOTE(review): the enclosing function's signature (presumably a
// get_working_size(n_threads) override) is outside this view -- confirm.
298 return n_threads * this->get_working_size_per_thread();