31 namespace addressing {
35 void **dest_raw,
const unsigned int array_rows,
const unsigned int array_cols,
36 void *base_ptr_raw,
size_t ld_row,
size_t ld_col,
38 const unsigned int pad_top,
const unsigned int valid_rows,
39 const unsigned int pad_left,
const unsigned int valid_cols
42 auto dest =
reinterpret_cast<char **
>(dest_raw);
43 auto base_ptr =
reinterpret_cast<char *
>(base_ptr_raw);
44 auto pad_buffer =
reinterpret_cast<char *
>(pad_buffer_raw);
45 ld_row *= element_size;
46 ld_col *= element_size;
48 const auto last_valid_row = std::min(pad_top + valid_rows, array_rows);
49 const auto last_valid_col = std::min(pad_left + valid_cols, array_cols);
52 for (; i < pad_top; i++)
54 for (
unsigned int j = 0; j < array_cols; j++)
56 *(
dest++) = pad_buffer;
59 for (; i < last_valid_row; i++)
62 auto colptr = base_ptr;
65 for (; j < pad_left; j++)
67 *(
dest++) = pad_buffer;
69 for (; j < last_valid_col; j++)
74 for (; j < array_cols; j++)
76 *(
dest++) = pad_buffer;
79 for (; i < array_rows; i++)
81 for (
unsigned int j = 0; j < array_cols; j++)
83 *(
dest++) = pad_buffer;
90 const size_t element_size,
92 const unsigned int output_rows,
const unsigned int output_cols,
93 const unsigned int kernel_rows,
const unsigned int kernel_cols,
94 const unsigned int stride_rows,
const unsigned int stride_cols,
95 void *base_ptr_raw,
size_t ld_row,
size_t ld_col,
97 const unsigned int pad_top,
const unsigned int valid_rows,
98 const unsigned int pad_left,
const unsigned int valid_cols
101 auto dest =
reinterpret_cast<char **
>(dest_raw);
102 auto base_ptr =
reinterpret_cast<char *
>(base_ptr_raw);
103 auto pad_buffer =
reinterpret_cast<char *
>(pad_buffer_raw);
104 ld_row *= element_size;
105 ld_col *= element_size;
107 const auto last_valid_row = pad_top + valid_rows;
108 const auto last_valid_col = pad_left + valid_cols;
109 const auto point_stride = output_rows * output_cols;
113 for (
unsigned int oi = 0; oi < output_rows; oi++)
115 for (
unsigned int oj = 0; oj < output_cols; oj++)
117 auto point_dest =
dest;
121 unsigned int ki = 0, ii = oi*stride_rows;
122 for (; ii < pad_top && ki < kernel_rows; ii++, ki++)
125 for (
unsigned int j = 0; j < kernel_cols; j++)
127 *point_dest = pad_buffer;
128 point_dest += point_stride;
131 for (; ii < last_valid_row && ki < kernel_rows; ii++, ki++)
133 unsigned int kj = 0, ij = oj*stride_cols;
134 for (; ij < pad_left && kj < kernel_cols; ij++, kj++)
137 *point_dest = pad_buffer;
138 point_dest += point_stride;
140 for (; ij < last_valid_col && kj < kernel_cols; ij++, kj++)
142 *point_dest = base_ptr + (ii - pad_top)*ld_row + (ij - pad_left)*ld_col;
143 point_dest += point_stride;
145 for (; kj < kernel_cols; kj++)
148 *point_dest = pad_buffer;
149 point_dest += point_stride;
152 for (; ki < kernel_rows; ki++)
155 for (
unsigned int j = 0; j < kernel_cols; j++)
157 *point_dest = pad_buffer;
158 point_dest += point_stride;
173 const void **dest_row_pointers_raw,
174 void *dest_patch_raw,
175 const unsigned int patch_rows,
unsigned int patch_cols,
176 const void *src_ptr_raw,
size_t ld_row,
size_t ld_col,
178 const unsigned int pad_top,
const unsigned int valid_rows,
179 const unsigned int pad_left,
const unsigned int valid_cols
183 auto row_pointers =
reinterpret_cast<const char **
>(dest_row_pointers_raw);
184 auto dest_patch =
reinterpret_cast<char *
>(dest_patch_raw);
185 auto src =
reinterpret_cast<const char *
>(src_ptr_raw);
186 ld_row *= element_size;
187 ld_col *= element_size;
190 patch_cols = arm_gemm::roundup<unsigned int>(patch_cols, 16 / element_size);
192 const auto last_valid_row = std::min(pad_top + valid_rows, patch_rows);
193 const auto last_valid_col = std::min(pad_left + valid_cols, patch_cols);
197 for (; i < pad_top; i++)
200 *(row_pointers++) =
reinterpret_cast<const char *
>(pad_row);
202 for (; i < last_valid_row; i++)
211 if (ld_col == element_size && pad_left == 0 && last_valid_col == patch_cols)
213 *(row_pointers++) = colptr;
217 auto patch_col = dest_patch;
218 *(row_pointers++) = dest_patch;
219 dest_patch += element_size * patch_cols;
223 memcpy(patch_col, pad_row, element_size * patch_cols);
224 patch_col += pad_left * element_size;
226 if (ld_col == element_size)
229 memcpy(patch_col, colptr, (last_valid_col - pad_left) * element_size);
234 for (
auto j = pad_left; j < last_valid_col; j++)
236 memcpy(patch_col, colptr, element_size);
237 patch_col += element_size;
243 for (; i < patch_rows; i++)
246 *(row_pointers++) =
reinterpret_cast<const char *
>(pad_row);
261 const void **dest_pointers_raw,
263 const unsigned int output_rows,
const unsigned int output_cols,
264 const unsigned int kernel_rows,
const unsigned int kernel_cols,
265 const unsigned int stride_rows,
const unsigned int stride_cols,
266 const void *src_ptr_raw,
size_t ld_row,
size_t ld_col,
268 const unsigned int pad_top,
const unsigned int valid_rows,
269 const unsigned int pad_left,
const unsigned int valid_cols
272 auto dest =
reinterpret_cast<const char **
>(dest_pointers_raw);
273 auto patch =
reinterpret_cast<char *
>(patch_raw);
274 auto src_ptr =
reinterpret_cast<const char *
>(src_ptr_raw);
275 ld_row *= element_size;
276 ld_col *= element_size;
279 const auto patch_cols = arm_gemm::roundup<unsigned int>(output_cols, 16 / element_size);
281 const auto input_rows = kernel_rows + (output_rows - 1) * stride_rows;
282 const auto last_valid_row = std::min(pad_top + valid_rows, input_rows);
284 const auto input_cols = kernel_cols + (output_cols - 1) * stride_cols;
285 const auto last_valid_col = std::min(pad_left + valid_cols, input_cols);
287 for (
auto ki = 0u; ki < kernel_rows; ki++)
289 for (
auto kj = 0u; kj < kernel_cols; kj++)
291 auto oi = 0u, ii = ki;
292 for (; oi < output_rows && ii < pad_top; oi++, ii += stride_rows)
295 *(
dest++) =
reinterpret_cast<const char *
>(pad_row);
297 for (; oi < output_rows && ii < last_valid_row; oi++, ii += stride_rows)
299 auto rowptr = src_ptr + (ii - pad_top) * ld_row;
302 auto patch_pos = patch;
304 patch += patch_cols * element_size;
307 memcpy(patch_pos, pad_row, patch_cols * element_size);
310 auto oj = 0u, ij = kj;
311 for (; oj < patch_cols && ij < pad_left; oj++, ij += stride_cols)
314 patch_pos += element_size;
316 for (; oj < patch_cols && ij < last_valid_col; oj++, ij += stride_cols)
319 memcpy(patch_pos, rowptr + (ij - pad_left)*ld_col, element_size);
320 patch_pos += element_size;
324 for (; oi < output_rows; oi++)
326 *(
dest++) =
reinterpret_cast<const char *
>(pad_row);