38 const uintptr_t src_shape[4],
39 const uintptr_t src_strides[4],
40 const uintptr_t dst_strides[4],
41 uintptr_t element_size,
47 const auto dst_channels = src_shape[2] / (block_size * block_size);
48 const auto src_block_col_stride = dst_channels * src_strides[2];
49 const auto src_block_row_stride = block_size * dst_channels * src_strides[2];
51 auto *src_batch_ptr =
src;
52 auto *dst_batch_ptr =
dst;
54 for (uintptr_t batch = 0; batch < src_shape[3]; ++batch)
56 auto *src_channel_ptr = src_batch_ptr;
57 auto *dst_channel_ptr = dst_batch_ptr;
59 for (uintptr_t channel = 0; channel < dst_channels; ++channel)
61 auto *src_height_block_ptr = src_channel_ptr;
62 auto *dst_row_ptr = dst_channel_ptr;
64 for (uintptr_t height_block = 0; height_block < src_shape[1]; ++height_block)
66 auto *src_block_row_ptr = src_height_block_ptr;
68 for (uintptr_t block_row = 0; block_row < block_size; ++block_row)
70 auto *src_width_block_ptr = src_block_row_ptr;
71 auto *dst_col_ptr = dst_row_ptr;
73 for (uintptr_t width_block = 0; width_block < src_shape[0]; ++width_block)
75 auto *src_block_col_ptr = src_width_block_ptr;
77 for (uintptr_t block_col = 0; block_col < block_size; ++block_col)
97 std::memcpy(dst_col_ptr, src_block_col_ptr, element_size);
99 src_block_col_ptr += src_block_col_stride;
100 dst_col_ptr += element_size;
103 src_width_block_ptr += element_size;
106 src_block_row_ptr += src_block_row_stride;
107 dst_row_ptr += dst_strides[1];
110 src_height_block_ptr += src_strides[1];
113 src_channel_ptr += src_strides[2];
114 dst_channel_ptr += dst_strides[2];
117 src_batch_ptr += src_strides[3];
118 dst_batch_ptr += dst_strides[3];