39 #include "src/core/NEON/kernels/convolution/common/shims.hpp"
52 static const std::array<PermutationVector, 2> permutations2 = {{
56 static const std::array<PermutationVector, 6> permutations3 = {{
// Lookup table of four-element permutation vectors.
// The 24 entries are every ordering of the indices {0, 1, 2, 3}, each listed
// exactly once; the table is searched linearly with std::find further down.
64 static const std::array<PermutationVector, 24> permutations4 = {
65 {
PermutationVector(0U, 1U, 2U, 3U),
PermutationVector(1U, 0U, 2U, 3U),
PermutationVector(2U, 0U, 1U, 3U),
66 PermutationVector(0U, 2U, 1U, 3U),
PermutationVector(1U, 2U, 0U, 3U),
PermutationVector(2U, 1U, 0U, 3U),
67 PermutationVector(2U, 1U, 3U, 0U),
PermutationVector(1U, 2U, 3U, 0U),
PermutationVector(3U, 2U, 1U, 0U),
68 PermutationVector(2U, 3U, 1U, 0U),
PermutationVector(1U, 3U, 2U, 0U),
PermutationVector(3U, 1U, 2U, 0U),
69 PermutationVector(3U, 0U, 2U, 1U),
PermutationVector(0U, 3U, 2U, 1U),
PermutationVector(2U, 3U, 0U, 1U),
70 PermutationVector(3U, 2U, 0U, 1U),
PermutationVector(0U, 2U, 3U, 1U),
PermutationVector(2U, 0U, 3U, 1U),
71 PermutationVector(1U, 0U, 3U, 2U),
PermutationVector(0U, 1U, 3U, 2U),
PermutationVector(3U, 1U, 0U, 2U),
72 PermutationVector(1U, 3U, 0U, 2U),
PermutationVector(0U, 3U, 1U, 2U),
PermutationVector(3U, 0U, 1U, 2U)}};
// Membership test: the result is true when v compares equal to an entry of at
// least one of the three tables (permutations2, permutations3, permutations4).
// Each table is scanned linearly with std::find; the || chain short-circuits
// on the first table that contains v.
74 return (permutations2.end() != std::find(permutations2.begin(), permutations2.end(), v)) ||
75 (permutations3.end() != std::find(permutations3.begin(), permutations3.end(), v)) ||
76 (permutations4.end() != std::find(permutations4.begin(), permutations4.end(), v));
87 if (
dst->total_size() != 0)
// The source window starts out as a copy of the execution window.
103 Window window_src = window;
// Each Dimension built below keeps the start/end of the corresponding window
// dimension (x, y, z and index 3) but uses a step of (end - start), i.e. the
// full extent of that dimension.
// NOTE(review): presumably this makes the window loop visit each source
// dimension in a single step so that the copy routine handles the whole
// extent itself - confirm against the code that consumes window_src.
111 Window::Dimension(window.x().start(), window.x().end(), window.x().end() - window.x().start()));
113 Window::Dimension(window.y().start(), window.y().end(), window.y().end() - window.y().start()));
115 Window::Dimension(window.z().start(), window.z().end(), window.z().end() - window.z().start()));
116 window_src.set(3, Window::Dimension(window[3].start(), window[3].
end(), window[3].
end() - window[3].start()));
// The destination window is a copy of the execution window whose leading
// dimensions are then overwritten with the empty Dimension (0, 0, 0).
120 Window window_dst(window);
121 const Window::Dimension zero_window = Window::Dimension(0, 0, 0);
// NOTE(review): the bound is inclusive (d <= num_dimensions()), so the loop
// also resets the dimension one past the destination's rank - confirm that
// this is intentional and that the index always stays within the Window.
122 for (
size_t d = 0; d <=
dst->info()->num_dimensions(); ++d)
124 window_dst.set(d, zero_window);
// Iterators over the two tensors: src is walked with window_src, dst with the
// zeroed window_dst.
128 Iterator src_it(
src, window_src);
129 Iterator dst_it(
dst, window_dst);
// Input strides, all defaulting to 0. The assignments further down fill in
// only the ones relevant to the layout being handled and express them in
// elements of T (byte stride / sizeof(T)).
131 int in_row_stride = 0;
132 int in_col_stride = 0;
133 int in_channel_stride = 0;
134 int in_batch_stride = 0;
// Source described with x = columns, y = rows, z = channels, [3] = batches.
// Byte strides are divided by sizeof(T) to obtain strides in elements.
// NOTE(review): presumably the NCHW-source case, since in_channel_stride is
// what reorder::nchw_to_nhwc receives below - the branch condition is not
// visible here, confirm.
144 in_row_stride =
src->info()->strides_in_bytes().y() /
sizeof(T);
145 in_channel_stride =
src->info()->strides_in_bytes().z() /
sizeof(T);
146 in_batch_stride =
src->info()->strides_in_bytes()[3] /
sizeof(T);
147 n_cols =
src->info()->tensor_shape().x();
// The row count comes from the y step of window_src, not from the tensor shape.
148 n_rows = window_src.y().step();
149 n_channels =
src->info()->tensor_shape().z();
150 n_batches =
src->info()->tensor_shape()[3];
// Source described with x = channels, y = columns, z = rows, [3] = batches.
// Byte strides are divided by sizeof(T) to obtain strides in elements.
// NOTE(review): presumably the NHWC-source case, since in_col_stride is what
// reorder::nhwc_to_nchw receives below - the branch condition is not visible
// here, confirm.
155 in_col_stride =
src->info()->strides_in_bytes().y() /
sizeof(T);
156 in_row_stride =
src->info()->strides_in_bytes().z() /
sizeof(T);
157 in_batch_stride =
src->info()->strides_in_bytes()[3] /
sizeof(T);
158 n_channels =
src->info()->tensor_shape().x();
// The column count comes from the y step of window_src, not from the tensor shape.
159 n_cols = window_src.y().step();
160 n_rows =
src->info()->tensor_shape().z();
161 n_batches =
src->info()->tensor_shape()[3];
// Destination strides in elements of T (byte stride / sizeof(T)):
// x -> channel, y -> column, z -> row, [3] -> batch.
174 const int out_channel_stride =
dst->info()->strides_in_bytes().x() /
sizeof(T);
175 const int out_col_stride =
dst->info()->strides_in_bytes().y() /
sizeof(T);
176 const int out_row_stride =
dst->info()->strides_in_bytes().z() /
sizeof(T);
177 const int out_batch_stride =
dst->info()->strides_in_bytes()[3] /
sizeof(T);
// Per-position callback; id holds the current coordinates.
180 [&](
const Coordinates &
id)
// Element offset into dst derived from the first three coordinates of id.
182 const int idx =
id[0] * out_col_stride +
id[1] * out_row_stride +
id[2] * out_channel_stride;
// idx is added after the cast to T*, so it is applied in elements, matching
// the element-based strides computed above.
183 reorder::nchw_to_nhwc(
reinterpret_cast<const T *
>(src_it.ptr()),
184 reinterpret_cast<T *
>(dst_it.ptr()) + idx, n_batches, n_channels, n_rows, n_cols,
185 in_batch_stride, in_channel_stride, in_row_stride, out_batch_stride,
186 out_row_stride, out_col_stride);
// Destination strides in elements of T (byte stride / sizeof(T)):
// x -> column, y -> row, z -> channel, [3] -> batch.
193 const int out_col_stride =
dst->info()->strides_in_bytes().x() /
sizeof(T);
194 const int out_row_stride =
dst->info()->strides_in_bytes().y() /
sizeof(T);
195 const int out_channel_stride =
dst->info()->strides_in_bytes().z() /
sizeof(T);
196 const int out_batch_stride =
dst->info()->strides_in_bytes()[3] /
sizeof(T);
// Per-position callback; id holds the current coordinates.
199 [&](
const Coordinates &
id)
// Element offset into dst derived from the first three coordinates of id.
201 const int idx =
id[0] * out_channel_stride +
id[1] * out_col_stride +
id[2] * out_row_stride;
// idx is added after the cast to T*, so it is applied in elements, matching
// the element-based strides computed above.
202 reorder::nhwc_to_nchw(
reinterpret_cast<const T *
>(src_it.ptr()),
203 reinterpret_cast<T *
>(dst_it.ptr()) + idx, n_batches, n_rows, n_cols, n_channels,
204 in_batch_stride, in_row_stride, in_col_stride, out_batch_stride,
205 out_channel_stride, out_row_stride);
// Generic path: one element of T is copied per callback invocation.
// perm_strides starts as a copy of the destination byte strides.
// NOTE(review): any reordering of perm_strides happens in lines not shown
// here - confirm before relying on its contents.
213 Strides strides =
dst->info()->strides_in_bytes();
214 Strides perm_strides = strides;
// The fourth stride is only used when src has at least four dimensions;
// otherwise it is 0 so the id[3] term below contributes nothing.
216 const int perm_stride_3 =
src->info()->num_dimensions() >= 4 ? perm_strides[3] : 0;
// Per-position callback; id holds the current coordinates.
219 [&](
const Coordinates &
id)
// Offset into dst for the current coordinates, built from byte strides.
222 id[0] * perm_strides[0] +
id[1] * perm_strides[1] +
id[2] * perm_strides[2] +
id[3] * perm_stride_3;
// Here idx is added to the raw pointer before the cast to T*, i.e. it is
// applied as a byte offset, consistent with the byte strides used above.
223 *(
reinterpret_cast<T *
>(dst_it.ptr() + idx)) = *(
reinterpret_cast<const T *
>(src_it.ptr()));
247 ICpuKernel::configure(win);
265 switch (
src->info()->element_size())
284 return "CpuPermuteKernel";