67 using TInput =
typename strategy::input_type;
68 using TOutput =
typename strategy::return_type;
72 const int start_out_height = std::min(thread_id * n_rows_per_thread, output_height);
73 const int end_out_height = std::min(start_out_height + n_rows_per_thread, output_height);
76 const TInput *
const inptr =
static_cast<const TInput *
>(_input);
77 TOutput *
const outptr =
static_cast<TOutput *
>(_output);
83 TInput rearranged_input[strategy::input_rows][strategy::input_col_quads*(16 /
sizeof(TInput))];
84 const TInput *inptrs[strategy::input_rows];
87 TOutput * _outptr_array[strategy::output_rows * strategy::output_cols];
88 TOutput **
const outptr_array = _outptr_array;
91 uint8_t *
const working_space =
static_cast<uint8_t *
>(_working_space);
92 TOutput *
const output_buffer =
reinterpret_cast<TOutput *
>(working_space);
96 for (
unsigned int batch = 0; batch <
batches; batch++)
99 const auto inptr_batch = inptr + batch * ld_input_batch;
100 const auto outptr_batch = outptr + batch * ld_output_batch;
102 for (
int start_out_i = start_out_height;
103 start_out_i < end_out_height;
104 start_out_i +=
static_cast<int>(strategy::output_rows))
106 const int end_out_i = start_out_i + strategy::output_rows;
107 const int start_in_i = start_out_i * strategy::stride_rows - padding.top;
108 const int end_in_i = start_in_i + strategy::input_rows;
111 const auto pad_top =
static_cast<unsigned int>(-std::min(start_in_i, 0));
112 const auto pad_bottom =
static_cast<unsigned int>(-std::min(static_cast<int>(
input_height) - end_in_i, 0));
113 const unsigned int valid_output_rows = std::min(
114 end_out_i - start_out_i,
115 static_cast<int>(output_height) - start_out_i
118 for (
int start_out_j = 0; start_out_j < static_cast<int>(output_width);)
120 const int start_in_j = start_out_j * strategy::stride_cols -
args.padding.left;
121 const int pad_left = -std::min(0, start_in_j);
123 const int end_out_j = start_out_j + strategy::output_cols;
124 const int end_in_j = start_in_j + strategy::input_cols;
126 const auto pad_right =
static_cast<unsigned int>(-std::min(static_cast<int>(
input_width) - end_in_j, 0));
127 const unsigned int valid_output_cols = std::min(
128 end_out_j - start_out_j,
129 static_cast<int>(output_width) - start_out_j
133 TOutput **outptr_pos = outptr_array;
134 for (
auto i = 0u; i < valid_output_rows; i++)
137 TOutput *colptr = outptr_batch + (start_out_i + i) * ld_output_row + start_out_j * ld_output_col;
138 for (; j < valid_output_cols; j++)
140 *(outptr_pos++) = colptr;
141 colptr += ld_output_col;
143 for (; j < strategy::output_cols; j++)
145 *(outptr_pos++) = output_buffer;
148 for (
auto i = valid_output_rows; i < strategy::output_rows; i++)
150 for (
auto j = 0u; j < strategy::output_cols; j++)
152 *(outptr_pos++) = output_buffer;
156 start_out_j += strategy::output_cols;
158 const uint8_t *params =
static_cast<const uint8_t *
>(
parameters);
161 for (
unsigned int in_c = 0; in_c < input_channels; in_c++)
165 for (
unsigned int i = 0; i < strategy::input_rows; i++)
167 for (
unsigned int j = 0;
168 j < (16 /
sizeof(TInput)) * strategy::input_col_quads; j++)
170 rearranged_input[i][j] = pad_value;
172 inptrs[i] = rearranged_input[i];
175 auto inptr_row = inptr_batch + in_c +
176 (start_in_i + pad_top) * ld_input_row +
177 (start_in_j + pad_left) * ld_input_col;
178 if (ld_input_col == 1 && !pad_left &&
179 start_in_j + (16 /
sizeof(TInput)) * strategy::input_col_quads <
input_width)
184 for (
unsigned int i = pad_top; i < strategy::input_rows - pad_bottom; i++)
186 inptrs[i] = inptr_row;
187 inptr_row += ld_input_row;
195 for (
unsigned int i = pad_top; i < strategy::input_rows - pad_bottom; i++)
197 auto inptr_col = inptr_row;
198 for (
unsigned int j = pad_left; j < strategy::input_cols - pad_right; j++)
200 rearranged_input[i][j] = *inptr_col;
201 inptr_col += ld_input_col;
203 inptr_row += ld_input_row;
207 execute_tile(inptrs, outptr_array, params);
210 TOutput **outptr_pos = outptr_array;
211 for (
auto i = 0u; i < strategy::output_rows * strategy::output_cols; i++)
213 outptr_pos[i] +=
args.channel_multiplier;
217 params += param_stride;
T iceildiv(const T a, const T b)
std::unique_ptr< ParametersLibrary > parameters
const size_t input_height