// Returns how many elements are processed per iteration for a tensor whose
// elements are `element_size` bytes wide; the result initialises
// num_elems_processed_per_iteration_y when the kernel window is set up.
// NOTE(review): presumably this matches window_step_y of the corresponding
// transpose_*_elements helper (8 for 1-byte, 4 for 2-byte, and 8 or 4 for
// 4-byte elements depending on __aarch64__) — confirm against the body.
47 unsigned int num_elems_processed(
size_t element_size)
// Transposes a tensor of 1-byte elements over `window`: the element at
// (x, y) of `in` is written to (y, x) of `out` (see the destination offsets
// below: id.y() selects the position inside an output row, x / id.x() selects
// the output row).
//
// Work is split into:
//  - a NEON pass over whole blocks of window_step_y rows, which transposes
//    8x8 byte tiles and handles the last (< 8) columns of each block with a
//    per-column gather;
//  - a scalar pass, one element at a time, over the rows that remain when the
//    y range is not a multiple of window_step_y.
//
// @param in     Source tensor with 1-byte elements.
// @param out    Destination tensor.
// @param window Region of `in` to process; its y end is clamped to the input
//               height.
68 void transpose_8bit_elements(
const ITensor *in, ITensor *out,
const Window &window)
// NEON tile: 8 columns (x) by 8 rows (y) per iteration.
70 const int window_step_x = 8;
71 const int window_step_y = 8;
72 const int window_start_x = window.x().start();
73 const int window_end_x = window.x().end();
74 const int window_start_y = window.y().start();
// Never read past the last input row.
75 const int window_end_y = std::min(window.y().end(),
static_cast<int>(in->info()->dimension(1)));
// Number of rows, counted from window_start_y, that form whole blocks of
// window_step_y rows.
// NOTE(review): this is a length relative to window_start_y, but below it is
// used as an absolute y coordinate (compared with window_start_y, used as the
// end of the NEON y range and as the start of the left-over y range). That is
// only consistent when window_start_y == 0. If the window can start at a
// non-zero y (e.g. when split along Y), window_start_y + (...) looks intended;
// as written the left-over pass may also cover rows below window_start_y —
// confirm.
76 const int window_end_y_multiple_of = ((window_end_y - window_start_y) / window_step_y) * window_step_y;
// Row pitches, in bytes, of the input and the output.
77 const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
78 const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
// Set when the y range is not a whole number of blocks; those rows are handled
// by the scalar pass at the end.
81 bool left_over_loop_y = (((window_end_y - window_start_y) % window_step_y) != 0);
// Window driving the NEON pass. When at least one whole block exists, its y
// range is limited to [window_start_y, window_end_y_multiple_of) in steps of
// window_step_y.
83 Window window_in(window);
88 if (window_end_y_multiple_of > window_start_y)
90 window_in.set(
Window::DimY, Window::Dimension(window_start_y, window_end_y_multiple_of, window_step_y));
// X and Y of the output window are collapsed to (0, 0, 0); destination
// addresses are computed explicitly from `id` and the output row pitch below.
98 Window window_out(window);
99 window_out.set(
Window::DimX, Window::Dimension(0, 0, 0));
100 window_out.set(
Window::DimY, Window::Dimension(0, 0, 0));
102 Iterator output(out, window_out);
// NEON pass; skipped when the input has a single row (that case presumably
// falls entirely to the left-over y pass below).
105 if (in->info()->dimension(1) != 1)
107 Iterator
input(in, window_in);
// Per-block body: id.y() is the first of the 8 input rows read here, and the
// body walks x across the row itself.
110 [&](
const Coordinates &
id)
// Full 8x8 tiles: load 8 bytes from each of the 8 rows, starting at column x.
113 int x = window_start_x;
114 for (; x <= (window_end_x - window_step_x); x += window_step_x)
116 const uint8x8_t row0 =
117 vld1_u8(
reinterpret_cast<const uint8_t *
>(
input.ptr() + x + 0 * input_stride_in_bytes));
118 const uint8x8_t row1 =
119 vld1_u8(
reinterpret_cast<const uint8_t *
>(
input.ptr() + x + 1 * input_stride_in_bytes));
120 const uint8x8_t row2 =
121 vld1_u8(
reinterpret_cast<const uint8_t *
>(
input.ptr() + x + 2 * input_stride_in_bytes));
122 const uint8x8_t row3 =
123 vld1_u8(
reinterpret_cast<const uint8_t *
>(
input.ptr() + x + 3 * input_stride_in_bytes));
124 const uint8x8_t row4 =
125 vld1_u8(
reinterpret_cast<const uint8_t *
>(
input.ptr() + x + 4 * input_stride_in_bytes));
126 const uint8x8_t row5 =
127 vld1_u8(
reinterpret_cast<const uint8_t *
>(
input.ptr() + x + 5 * input_stride_in_bytes));
128 const uint8x8_t row6 =
129 vld1_u8(
reinterpret_cast<const uint8_t *
>(
input.ptr() + x + 6 * input_stride_in_bytes));
130 const uint8x8_t row7 =
131 vld1_u8(
reinterpret_cast<const uint8_t *
>(
input.ptr() + x + 7 * input_stride_in_bytes));
// Transpose with three rounds of vtrn on 8-, 16- and 32-bit lanes. Afterwards
// k0_u32, k2_u32, k1_u32, k3_u32 (in that order) hold tile columns 0..3 in
// .val[0] and tile columns 4..7 in .val[1].
134 const uint8x8x2_t k0_u8 = vtrn_u8(row0, row1);
135 const uint8x8x2_t k1_u8 = vtrn_u8(row2, row3);
136 const uint8x8x2_t k2_u8 = vtrn_u8(row4, row5);
137 const uint8x8x2_t k3_u8 = vtrn_u8(row6, row7);
140 const uint16x4x2_t k0_u16 =
141 vtrn_u16(vreinterpret_u16_u8(k0_u8.val[0]), vreinterpret_u16_u8(k1_u8.val[0]));
142 const uint16x4x2_t k1_u16 =
143 vtrn_u16(vreinterpret_u16_u8(k0_u8.val[1]), vreinterpret_u16_u8(k1_u8.val[1]));
144 const uint16x4x2_t k2_u16 =
145 vtrn_u16(vreinterpret_u16_u8(k2_u8.val[0]), vreinterpret_u16_u8(k3_u8.val[0]));
146 const uint16x4x2_t k3_u16 =
147 vtrn_u16(vreinterpret_u16_u8(k2_u8.val[1]), vreinterpret_u16_u8(k3_u8.val[1]));
150 const uint32x2x2_t k0_u32 =
151 vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k2_u16.val[0]));
152 const uint32x2x2_t k1_u32 =
153 vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k2_u16.val[1]));
154 const uint32x2x2_t k2_u32 =
155 vtrn_u32(vreinterpret_u32_u16(k1_u16.val[0]), vreinterpret_u32_u16(k3_u16.val[0]));
156 const uint32x2x2_t k3_u32 =
157 vtrn_u32(vreinterpret_u32_u16(k1_u16.val[1]), vreinterpret_u32_u16(k3_u16.val[1]));
// Tile column j becomes output row x + j, starting at byte id.y() of that row.
160 const size_t dst_offset_in_bytes =
id.y() *
sizeof(uint8_t) + x * output_stride_in_bytes;
163 reinterpret_cast<uint8_t *
>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes),
164 vreinterpret_u8_u16(vreinterpret_u16_u32(k0_u32.val[0])));
166 reinterpret_cast<uint8_t *
>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes),
167 vreinterpret_u8_u16(vreinterpret_u16_u32(k2_u32.val[0])));
169 reinterpret_cast<uint8_t *
>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes),
170 vreinterpret_u8_u16(vreinterpret_u16_u32(k1_u32.val[0])));
172 reinterpret_cast<uint8_t *
>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes),
173 vreinterpret_u8_u16(vreinterpret_u16_u32(k3_u32.val[0])));
175 reinterpret_cast<uint8_t *
>(output.ptr() + dst_offset_in_bytes + 4 * output_stride_in_bytes),
176 vreinterpret_u8_u16(vreinterpret_u16_u32(k0_u32.val[1])));
178 reinterpret_cast<uint8_t *
>(output.ptr() + dst_offset_in_bytes + 5 * output_stride_in_bytes),
179 vreinterpret_u8_u16(vreinterpret_u16_u32(k2_u32.val[1])));
181 reinterpret_cast<uint8_t *
>(output.ptr() + dst_offset_in_bytes + 6 * output_stride_in_bytes),
182 vreinterpret_u8_u16(vreinterpret_u16_u32(k1_u32.val[1])));
184 reinterpret_cast<uint8_t *
>(output.ptr() + dst_offset_in_bytes + 7 * output_stride_in_bytes),
185 vreinterpret_u8_u16(vreinterpret_u16_u32(k3_u32.val[1])));
// Tail columns (fewer than 8 left in the row): gather one byte from each of
// the 8 rows and store them as 8 consecutive bytes of output row x.
189 for (; x < window_end_x; ++x)
191 const uint8_t val0 = *(
input.ptr() + x + 0 * input_stride_in_bytes);
192 const uint8_t val1 = *(
input.ptr() + x + 1 * input_stride_in_bytes);
193 const uint8_t val2 = *(
input.ptr() + x + 2 * input_stride_in_bytes);
194 const uint8_t val3 = *(
input.ptr() + x + 3 * input_stride_in_bytes);
195 const uint8_t val4 = *(
input.ptr() + x + 4 * input_stride_in_bytes);
196 const uint8_t val5 = *(
input.ptr() + x + 5 * input_stride_in_bytes);
197 const uint8_t val6 = *(
input.ptr() + x + 6 * input_stride_in_bytes);
198 const uint8_t val7 = *(
input.ptr() + x + 7 * input_stride_in_bytes);
200 uint8x8_t result = vdup_n_u8(0);
201 result = vset_lane_u8(val0, result, 0);
202 result = vset_lane_u8(val1, result, 1);
203 result = vset_lane_u8(val2, result, 2);
204 result = vset_lane_u8(val3, result, 3);
205 result = vset_lane_u8(val4, result, 4);
206 result = vset_lane_u8(val5, result, 5);
207 result = vset_lane_u8(val6, result, 6);
208 result = vset_lane_u8(val7, result, 7);
211 const size_t dst_offset_in_bytes =
id.y() *
sizeof(uint8_t) + x * output_stride_in_bytes;
213 vst1_u8(output.ptr() + dst_offset_in_bytes, result);
// Scalar pass over the left-over rows [window_end_y_multiple_of, window_end_y)
// and the full x range of `window`, one element per iteration.
219 if (left_over_loop_y)
221 window_in.set(
Window::DimX, Window::Dimension(window.x().start(), window.x().end(), 1));
222 window_in.set(
Window::DimY, Window::Dimension(window_end_y_multiple_of, window_end_y, 1));
224 Iterator
input(in, window_in);
225 Iterator output(out, window_out);
230 [&](
const Coordinates &
id)
// out(row id.x(), byte id.y()) = in(row id.y(), column id.x())
232 const uint8_t val0 = *
input.ptr();
235 const size_t dst_offset_in_bytes =
id.y() *
sizeof(uint8_t) +
id.x() * output_stride_in_bytes;
237 *(output.ptr() + dst_offset_in_bytes) = val0;
// Transposes a tensor of 2-byte elements over `window`: the element at (x, y)
// of `in` is written to (y, x) of `out`.
//
// Work is split into a NEON pass over whole blocks of window_step_y rows
// (4x4 tiles of uint16_t plus a per-column gather for the last < 4 columns)
// and a scalar pass, one element at a time, over any remaining rows.
//
// @param in     Source tensor with 2-byte elements.
// @param out    Destination tensor.
// @param window Region of `in` to process; its y end is clamped to the input
//               height.
243 void transpose_16bit_elements(
const ITensor *in, ITensor *out,
const Window &window)
// NEON tile: 4 columns (x) by 4 rows (y) per iteration.
245 const int window_step_x = 4;
246 const int window_step_y = 4;
247 const int window_start_x = window.x().start();
248 const int window_end_x = window.x().end();
249 const int window_start_y = window.y().start();
// Never read past the last input row.
250 const int window_end_y = std::min(window.y().end(),
static_cast<int>(in->info()->dimension(1)));
// Number of rows, counted from window_start_y, that form whole blocks of
// window_step_y rows.
// NOTE(review): this is a length relative to window_start_y, but below it is
// used as an absolute y coordinate (compared with window_start_y, used as the
// end of the NEON y range and as the start of the left-over y range). That is
// only consistent when window_start_y == 0. If the window can start at a
// non-zero y (e.g. when split along Y), window_start_y + (...) looks intended;
// as written the left-over pass may also cover rows below window_start_y —
// confirm.
251 const int window_end_y_multiple_of = ((window_end_y - window_start_y) / window_step_y) * window_step_y;
// Row pitches, in bytes, of the input and the output.
252 const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
253 const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
// Set when the y range is not a whole number of blocks; those rows are handled
// by the scalar pass at the end.
256 bool left_over_loop_y = (((window_end_y - window_start_y) % window_step_y) != 0);
// Window driving the NEON pass. X is reduced to a single iteration (0, 1, 1)
// because the body walks x across the row itself.
258 Window window_in(window);
259 window_in.set(
Window::DimX, Window::Dimension(0, 1, 1));
// With a y remainder, limit the NEON pass to whole blocks of window_step_y
// rows when there is at least one.
260 if (left_over_loop_y)
263 if (window_end_y_multiple_of > window_start_y)
265 window_in.set(
Window::DimY, Window::Dimension(window_start_y, window_end_y_multiple_of, window_step_y));
// Presumably the else-branch of the check above: an empty y range, so the
// NEON pass does no work.
269 window_in.set(
Window::DimY, Window::Dimension(0, 0, 1));
// X and Y of the output window are collapsed to (0, 0, 0); destination
// addresses are computed explicitly from `id` and the output row pitch below.
273 Window window_out(window);
274 window_out.set(
Window::DimX, Window::Dimension(0, 0, 0));
275 window_out.set(
Window::DimY, Window::Dimension(0, 0, 0));
277 Iterator output(out, window_out);
// NEON pass; skipped when the input has a single row (that case presumably
// falls entirely to the left-over y pass below).
280 if (in->info()->dimension(1) != 1)
282 Iterator
input(in, window_in);
// Per-block body: id.y() is the first of the 4 input rows read here.
285 [&](
const Coordinates &
id)
// Full 4x4 tiles. x counts elements, so it is added after the cast to
// uint16_t * rather than to the byte pointer.
288 int x = window_start_x;
289 for (; x <= (window_end_x - window_step_x); x += window_step_x)
291 const uint16x4_t row0 =
292 vld1_u16(
reinterpret_cast<const uint16_t *
>(
input.ptr() + 0 * input_stride_in_bytes) + x);
293 const uint16x4_t row1 =
294 vld1_u16(
reinterpret_cast<const uint16_t *
>(
input.ptr() + 1 * input_stride_in_bytes) + x);
295 const uint16x4_t row2 =
296 vld1_u16(
reinterpret_cast<const uint16_t *
>(
input.ptr() + 2 * input_stride_in_bytes) + x);
297 const uint16x4_t row3 =
298 vld1_u16(
reinterpret_cast<const uint16_t *
>(
input.ptr() + 3 * input_stride_in_bytes) + x);
// Transpose with vtrn on 16-bit then 32-bit lanes. Afterwards k0_u32.val[0],
// k1_u32.val[0], k0_u32.val[1], k1_u32.val[1] hold tile columns 0, 1, 2, 3.
301 const uint16x4x2_t k0_u16 = vtrn_u16(row0, row1);
302 const uint16x4x2_t k1_u16 = vtrn_u16(row2, row3);
305 const uint32x2x2_t k0_u32 =
306 vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k1_u16.val[0]));
307 const uint32x2x2_t k1_u32 =
308 vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k1_u16.val[1]));
// Tile column j becomes output row x + j, starting at element id.y() of that
// row.
311 const size_t dst_offset_in_bytes =
id.y() *
sizeof(uint16_t) + x * output_stride_in_bytes;
314 reinterpret_cast<uint16_t *
>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes),
315 vreinterpret_u16_u32(k0_u32.val[0]));
317 reinterpret_cast<uint16_t *
>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes),
318 vreinterpret_u16_u32(k1_u32.val[0]));
320 reinterpret_cast<uint16_t *
>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes),
321 vreinterpret_u16_u32(k0_u32.val[1]));
323 reinterpret_cast<uint16_t *
>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes),
324 vreinterpret_u16_u32(k1_u32.val[1]));
// Tail columns (fewer than 4 left in the row): gather one element from each of
// the 4 rows and store them as 4 consecutive elements of output row x.
328 for (; x < window_end_x; ++x)
330 const uint16_t val0 = *(
reinterpret_cast<uint16_t *
>(
input.ptr() + 0 * input_stride_in_bytes) + x);
331 const uint16_t val1 = *(
reinterpret_cast<uint16_t *
>(
input.ptr() + 1 * input_stride_in_bytes) + x);
332 const uint16_t val2 = *(
reinterpret_cast<uint16_t *
>(
input.ptr() + 2 * input_stride_in_bytes) + x);
333 const uint16_t val3 = *(
reinterpret_cast<uint16_t *
>(
input.ptr() + 3 * input_stride_in_bytes) + x);
335 uint16x4_t result = vdup_n_u16(0);
336 result = vset_lane_u16(val0, result, 0);
337 result = vset_lane_u16(val1, result, 1);
338 result = vset_lane_u16(val2, result, 2);
339 result = vset_lane_u16(val3, result, 3);
342 const size_t dst_offset_in_bytes =
id.y() *
sizeof(uint16_t) + x * output_stride_in_bytes;
344 vst1_u16(
reinterpret_cast<uint16_t *
>(output.ptr() + dst_offset_in_bytes), result);
// Scalar pass over the left-over rows [window_end_y_multiple_of, window_end_y)
// and the full x range of `window`, one element per iteration.
350 if (left_over_loop_y)
352 window_in.set(
Window::DimX, Window::Dimension(window.x().start(), window.x().end(), 1));
353 window_in.set(
Window::DimY, Window::Dimension(window_end_y_multiple_of, window_end_y, 1));
355 Iterator
input(in, window_in);
356 Iterator output(out, window_out);
361 [&](
const Coordinates &
id)
// out(row id.x(), element id.y()) = in(row id.y(), column id.x())
363 const uint16_t val0 = *(
reinterpret_cast<uint16_t *
>(
input.ptr()));
366 const size_t dst_offset_in_bytes =
id.y() *
sizeof(uint16_t) +
id.x() * output_stride_in_bytes;
368 *(
reinterpret_cast<uint16_t *
>(output.ptr() + dst_offset_in_bytes)) = val0;
// Loads eight consecutive uint32_t values starting at `ptr` as a pair of
// 128-bit vectors.
//
// @param ptr Source of at least 8 readable uint32_t elements.
// @return    ptr[0..3] in .val[0] and ptr[4..7] in .val[1].
inline uint32x4x2_t vld1q_u32_x2_(const uint32_t *ptr)
{
    const uint32x4_t first_half  = vld1q_u32(ptr);
    const uint32x4_t second_half = vld1q_u32(ptr + 4);
    return uint32x4x2_t{{first_half, second_half}};
}
// Stores the two 128-bit halves of `val` to eight consecutive uint32_t values
// starting at `ptr`: val.val[0] goes to ptr[0..3], val.val[1] to ptr[4..7].
//
// The destination is taken as a non-const pointer because this function
// writes through it. The previous `const uint32_t *` parameter plus
// const_cast would silently accept pointers to const storage, and writing to
// such storage is undefined behaviour.
//
// @param ptr Destination of at least 8 writable uint32_t elements.
// @param val Pair of vectors to store.
inline void vst1q_u32_x2_(uint32_t *ptr, const uint32x4x2_t &val)
{
    vst1q_u32(ptr, val.val[0]);
    vst1q_u32(ptr + 4, val.val[1]);
}
// Transposes a tensor of 4-byte elements over `window`: the element at (x, y)
// of `in` is written to (y, x) of `out`.
//
// This variant transposes 8x8 tiles, loading each row as two 128-bit vectors
// via vld1q_u32_x2_ and shuffling with vtrn1q_* / vtrn2q_*. Those are AArch64
// intrinsics, so this definition presumably sits in the __aarch64__ branch
// (an `#endif // __aarch64__` follows the 4x4 variant of the same name) —
// confirm the guard.
// Rows that do not fill a whole block are handled one element at a time at
// the end.
//
// @param in     Source tensor with 4-byte elements.
// @param out    Destination tensor.
// @param window Region of `in` to process; its y end is clamped to the input
//               height.
388 void transpose_32bit_elements(
const ITensor *in, ITensor *out,
const Window &window)
// NEON tile: 8 columns (x) by 8 rows (y) per iteration.
390 constexpr
int window_step_x = 8;
391 constexpr
int window_step_y = 8;
392 const int window_start_x = window.x().start();
393 const int window_end_x = window.x().end();
394 const int window_start_y = window.y().start();
// Never read past the last input row.
395 const int window_end_y = std::min(window.y().end(),
static_cast<int>(in->info()->dimension(1)));
// Number of rows, counted from window_start_y, that form whole blocks of
// window_step_y rows.
// NOTE(review): this is a length relative to window_start_y, but below it is
// used as an absolute y coordinate (compared with window_start_y, used as the
// end of the NEON y range and as the start of the left-over y range). That is
// only consistent when window_start_y == 0. If the window can start at a
// non-zero y (e.g. when split along Y), window_start_y + (...) looks intended;
// as written the left-over pass may also cover rows below window_start_y —
// confirm.
396 const int window_end_y_multiple_of = ((window_end_y - window_start_y) / window_step_y) * window_step_y;
// Row pitches, in bytes, of the input and the output.
397 const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
398 const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
// Set when the y range is not a whole number of blocks; those rows are handled
// by the scalar pass at the end.
401 bool left_over_loop_y = (((window_end_y - window_start_y) % window_step_y) != 0);
// Window driving the NEON pass. X is reduced to a single iteration (0, 1, 1)
// because the body walks x across the row itself.
403 Window window_in(window);
404 window_in.set(
Window::DimX, Window::Dimension(0, 1, 1));
// With a y remainder, limit the NEON pass to whole blocks of window_step_y
// rows when there is at least one.
405 if (left_over_loop_y)
408 if (window_end_y_multiple_of > window_start_y)
410 window_in.set(
Window::DimY, Window::Dimension(window_start_y, window_end_y_multiple_of, window_step_y));
// Presumably the else-branch of the check above: an empty y range, so the
// NEON pass does no work.
414 window_in.set(
Window::DimY, Window::Dimension(0, 0, 1));
// X and Y of the output window are collapsed to (0, 0, 0); destination
// addresses are computed explicitly from `id` and the output row pitch below.
418 Window window_out(window);
419 window_out.set(
Window::DimX, Window::Dimension(0, 0, 0));
420 window_out.set(
Window::DimY, Window::Dimension(0, 0, 0));
422 Iterator output(out, window_out);
// NEON pass; skipped when the input has a single row (that case presumably
// falls entirely to the left-over y pass below).
425 if (in->info()->dimension(1) != 1)
427 Iterator
input(in, window_in);
// Per-block body: id.y() is the first of the 8 input rows read here.
430 [&](
const Coordinates &
id)
// Full 8x8 tiles: row r is loaded as rowR.val[0] (columns x..x+3) and
// rowR.val[1] (columns x+4..x+7).
433 int x = window_start_x;
434 for (; x <= (window_end_x - window_step_x); x += window_step_x)
437 const uint32x4x2_t row0 =
438 vld1q_u32_x2_(
reinterpret_cast<const uint32_t *
>(
input.ptr() + 0 * input_stride_in_bytes) + x);
439 const uint32x4x2_t row1 =
440 vld1q_u32_x2_(
reinterpret_cast<const uint32_t *
>(
input.ptr() + 1 * input_stride_in_bytes) + x);
441 const uint32x4x2_t row2 =
442 vld1q_u32_x2_(
reinterpret_cast<const uint32_t *
>(
input.ptr() + 2 * input_stride_in_bytes) + x);
443 const uint32x4x2_t row3 =
444 vld1q_u32_x2_(
reinterpret_cast<const uint32_t *
>(
input.ptr() + 3 * input_stride_in_bytes) + x);
445 const uint32x4x2_t row4 =
446 vld1q_u32_x2_(
reinterpret_cast<const uint32_t *
>(
input.ptr() + 4 * input_stride_in_bytes) + x);
447 const uint32x4x2_t row5 =
448 vld1q_u32_x2_(
reinterpret_cast<const uint32_t *
>(
input.ptr() + 5 * input_stride_in_bytes) + x);
449 const uint32x4x2_t row6 =
450 vld1q_u32_x2_(
reinterpret_cast<const uint32_t *
>(
input.ptr() + 6 * input_stride_in_bytes) + x);
451 const uint32x4x2_t row7 =
452 vld1q_u32_x2_(
reinterpret_cast<const uint32_t *
>(
input.ptr() + 7 * input_stride_in_bytes) + x);
// Round 1: vtrn1q/vtrn2q on 32-bit lanes of row pairs p = (2p, 2p+1).
// k(2p)_u32 covers tile columns 0..3 and k(2p+1)_u32 columns 4..7 of pair p.
455 const uint32x4x2_t k0_u32 = {vtrn1q_u32(row0.val[0], row1.val[0]),
456 vtrn2q_u32(row0.val[0], row1.val[0])};
457 const uint32x4x2_t k1_u32 = {vtrn1q_u32(row0.val[1], row1.val[1]),
458 vtrn2q_u32(row0.val[1], row1.val[1])};
459 const uint32x4x2_t k2_u32 = {vtrn1q_u32(row2.val[0], row3.val[0]),
460 vtrn2q_u32(row2.val[0], row3.val[0])};
461 const uint32x4x2_t k3_u32 = {vtrn1q_u32(row2.val[1], row3.val[1]),
462 vtrn2q_u32(row2.val[1], row3.val[1])};
463 const uint32x4x2_t k4_u32 = {vtrn1q_u32(row4.val[0], row5.val[0]),
464 vtrn2q_u32(row4.val[0], row5.val[0])};
465 const uint32x4x2_t k5_u32 = {vtrn1q_u32(row4.val[1], row5.val[1]),
466 vtrn2q_u32(row4.val[1], row5.val[1])};
467 const uint32x4x2_t k6_u32 = {vtrn1q_u32(row6.val[0], row7.val[0]),
468 vtrn2q_u32(row6.val[0], row7.val[0])};
469 const uint32x4x2_t k7_u32 = {vtrn1q_u32(row6.val[1], row7.val[1]),
470 vtrn2q_u32(row6.val[1], row7.val[1])};
// Round 2: vtrn1q/vtrn2q on 64-bit lanes merge neighbouring row pairs into
// 4-row column fragments (k0..k3_u64: rows 0..3, k4..k7_u64: rows 4..7).
473 const uint64x2x2_t k0_u64 = {
474 vtrn1q_u64(vreinterpretq_u64_u32(k0_u32.val[0]), vreinterpretq_u64_u32(k2_u32.val[0])),
475 vtrn2q_u64(vreinterpretq_u64_u32(k0_u32.val[0]), vreinterpretq_u64_u32(k2_u32.val[0]))};
476 const uint64x2x2_t k1_u64 = {
477 vtrn1q_u64(vreinterpretq_u64_u32(k0_u32.val[1]), vreinterpretq_u64_u32(k2_u32.val[1])),
478 vtrn2q_u64(vreinterpretq_u64_u32(k0_u32.val[1]), vreinterpretq_u64_u32(k2_u32.val[1]))};
479 const uint64x2x2_t k2_u64 = {
480 vtrn1q_u64(vreinterpretq_u64_u32(k1_u32.val[0]), vreinterpretq_u64_u32(k3_u32.val[0])),
481 vtrn2q_u64(vreinterpretq_u64_u32(k1_u32.val[0]), vreinterpretq_u64_u32(k3_u32.val[0]))};
482 const uint64x2x2_t k3_u64 = {
483 vtrn1q_u64(vreinterpretq_u64_u32(k1_u32.val[1]), vreinterpretq_u64_u32(k3_u32.val[1])),
484 vtrn2q_u64(vreinterpretq_u64_u32(k1_u32.val[1]), vreinterpretq_u64_u32(k3_u32.val[1]))};
485 const uint64x2x2_t k4_u64 = {
486 vtrn1q_u64(vreinterpretq_u64_u32(k4_u32.val[0]), vreinterpretq_u64_u32(k6_u32.val[0])),
487 vtrn2q_u64(vreinterpretq_u64_u32(k4_u32.val[0]), vreinterpretq_u64_u32(k6_u32.val[0]))};
488 const uint64x2x2_t k5_u64 = {
489 vtrn1q_u64(vreinterpretq_u64_u32(k4_u32.val[1]), vreinterpretq_u64_u32(k6_u32.val[1])),
490 vtrn2q_u64(vreinterpretq_u64_u32(k4_u32.val[1]), vreinterpretq_u64_u32(k6_u32.val[1]))};
491 const uint64x2x2_t k6_u64 = {
492 vtrn1q_u64(vreinterpretq_u64_u32(k5_u32.val[0]), vreinterpretq_u64_u32(k7_u32.val[0])),
493 vtrn2q_u64(vreinterpretq_u64_u32(k5_u32.val[0]), vreinterpretq_u64_u32(k7_u32.val[0]))};
494 const uint64x2x2_t k7_u64 = {
495 vtrn1q_u64(vreinterpretq_u64_u32(k5_u32.val[1]), vreinterpretq_u64_u32(k7_u32.val[1])),
496 vtrn2q_u64(vreinterpretq_u64_u32(k5_u32.val[1]), vreinterpretq_u64_u32(k7_u32.val[1]))};
// colN = tile column N: rows 0..3 in .val[0], rows 4..7 in .val[1].
499 const uint32x4x2_t col0 = {vreinterpretq_u32_u64(k0_u64.val[0]),
500 vreinterpretq_u32_u64(k4_u64.val[0])};
501 const uint32x4x2_t col1 = {vreinterpretq_u32_u64(k1_u64.val[0]),
502 vreinterpretq_u32_u64(k5_u64.val[0])};
503 const uint32x4x2_t col2 = {vreinterpretq_u32_u64(k0_u64.val[1]),
504 vreinterpretq_u32_u64(k4_u64.val[1])};
505 const uint32x4x2_t col3 = {vreinterpretq_u32_u64(k1_u64.val[1]),
506 vreinterpretq_u32_u64(k5_u64.val[1])};
507 const uint32x4x2_t col4 = {vreinterpretq_u32_u64(k2_u64.val[0]),
508 vreinterpretq_u32_u64(k6_u64.val[0])};
509 const uint32x4x2_t col5 = {vreinterpretq_u32_u64(k3_u64.val[0]),
510 vreinterpretq_u32_u64(k7_u64.val[0])};
511 const uint32x4x2_t col6 = {vreinterpretq_u32_u64(k2_u64.val[1]),
512 vreinterpretq_u32_u64(k6_u64.val[1])};
513 const uint32x4x2_t col7 = {vreinterpretq_u32_u64(k3_u64.val[1]),
514 vreinterpretq_u32_u64(k7_u64.val[1])};
// Tile column j becomes output row x + j, starting at element id.y() of that
// row.
517 const size_t dst_offset_in_bytes =
id.y() *
sizeof(uint32_t) + x * output_stride_in_bytes;
// Destinations for output rows x + 0 .. x + 7.
// NOTE(review): the stored operands are not on these lines; presumably colN is
// stored to output row x + N — confirm.
521 reinterpret_cast<uint32_t *
>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes),
524 reinterpret_cast<uint32_t *
>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes),
527 reinterpret_cast<uint32_t *
>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes),
530 reinterpret_cast<uint32_t *
>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes),
533 reinterpret_cast<uint32_t *
>(output.ptr() + dst_offset_in_bytes + 4 * output_stride_in_bytes),
536 reinterpret_cast<uint32_t *
>(output.ptr() + dst_offset_in_bytes + 5 * output_stride_in_bytes),
539 reinterpret_cast<uint32_t *
>(output.ptr() + dst_offset_in_bytes + 6 * output_stride_in_bytes),
542 reinterpret_cast<uint32_t *
>(output.ptr() + dst_offset_in_bytes + 7 * output_stride_in_bytes),
// Tail columns (fewer than 8 left in the row): gather one element from each of
// the 8 rows and store them as 8 consecutive elements of output row x.
547 for (; x < window_end_x; ++x)
549 const uint32_t val0 = *(
reinterpret_cast<uint32_t *
>(
input.ptr() + 0 * input_stride_in_bytes) + x);
550 const uint32_t val1 = *(
reinterpret_cast<uint32_t *
>(
input.ptr() + 1 * input_stride_in_bytes) + x);
551 const uint32_t val2 = *(
reinterpret_cast<uint32_t *
>(
input.ptr() + 2 * input_stride_in_bytes) + x);
552 const uint32_t val3 = *(
reinterpret_cast<uint32_t *
>(
input.ptr() + 3 * input_stride_in_bytes) + x);
553 const uint32_t val4 = *(
reinterpret_cast<uint32_t *
>(
input.ptr() + 4 * input_stride_in_bytes) + x);
554 const uint32_t val5 = *(
reinterpret_cast<uint32_t *
>(
input.ptr() + 5 * input_stride_in_bytes) + x);
555 const uint32_t val6 = *(
reinterpret_cast<uint32_t *
>(
input.ptr() + 6 * input_stride_in_bytes) + x);
556 const uint32_t val7 = *(
reinterpret_cast<uint32_t *
>(
input.ptr() + 7 * input_stride_in_bytes) + x);
558 uint32x4_t result0 = vdupq_n_u32(0);
559 uint32x4_t result1 = vdupq_n_u32(0);
560 result0 = vsetq_lane_u32(val0, result0, 0);
561 result0 = vsetq_lane_u32(val1, result0, 1);
562 result0 = vsetq_lane_u32(val2, result0, 2);
563 result0 = vsetq_lane_u32(val3, result0, 3);
564 result1 = vsetq_lane_u32(val4, result1, 0);
565 result1 = vsetq_lane_u32(val5, result1, 1);
566 result1 = vsetq_lane_u32(val6, result1, 2);
567 result1 = vsetq_lane_u32(val7, result1, 3);
570 const size_t dst_offset_in_bytes =
id.y() *
sizeof(uint32_t) + x * output_stride_in_bytes;
572 vst1q_u32_x2_(
reinterpret_cast<uint32_t *
>(output.ptr() + dst_offset_in_bytes), {result0, result1});
// Scalar pass over the left-over rows [window_end_y_multiple_of, window_end_y)
// and the full x range of `window`, one element per iteration.
578 if (left_over_loop_y)
580 window_in.set(
Window::DimX, Window::Dimension(window.x().start(), window.x().end(), 1));
581 window_in.set(
Window::DimY, Window::Dimension(window_end_y_multiple_of, window_end_y, 1));
583 Iterator
input(in, window_in);
584 Iterator output(out, window_out);
589 [&](
const Coordinates &
id)
// out(row id.x(), element id.y()) = in(row id.y(), column id.x())
591 const uint32_t val0 = *(
reinterpret_cast<uint32_t *
>(
input.ptr()));
594 const size_t dst_offset_in_bytes =
id.y() *
sizeof(uint32_t) +
id.x() * output_stride_in_bytes;
596 *(
reinterpret_cast<uint32_t *
>(output.ptr() + dst_offset_in_bytes)) = val0;
// Transposes a tensor of 4-byte elements over `window`: the element at (x, y)
// of `in` is written to (y, x) of `out`.
//
// This variant transposes 4x4 tiles using only 64-bit vtrn_u32 and
// vcombine_u32. It presumably is the non-__aarch64__ counterpart of the 8x8
// variant of the same name (it is followed by `#endif // __aarch64__`) —
// confirm the guard. Rows that do not fill a whole block are handled one
// element at a time at the end.
//
// @param in     Source tensor with 4-byte elements.
// @param out    Destination tensor.
// @param window Region of `in` to process; its y end is clamped to the input
//               height.
602 void transpose_32bit_elements(
const ITensor *in, ITensor *out,
const Window &window)
// NEON tile: 4 columns (x) by 4 rows (y) per iteration.
604 const int window_step_x = 4;
605 const int window_step_y = 4;
606 const int window_start_x = window.x().start();
607 const int window_end_x = window.x().end();
608 const int window_start_y = window.y().start();
// Never read past the last input row.
609 const int window_end_y = std::min(window.y().end(),
static_cast<int>(in->info()->dimension(1)));
// Number of rows, counted from window_start_y, that form whole blocks of
// window_step_y rows.
// NOTE(review): this is a length relative to window_start_y, but below it is
// used as an absolute y coordinate (compared with window_start_y, used as the
// end of the NEON y range and as the start of the left-over y range). That is
// only consistent when window_start_y == 0. If the window can start at a
// non-zero y (e.g. when split along Y), window_start_y + (...) looks intended;
// as written the left-over pass may also cover rows below window_start_y —
// confirm.
610 const int window_end_y_multiple_of = ((window_end_y - window_start_y) / window_step_y) * window_step_y;
// Row pitches, in bytes, of the input and the output.
611 const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
612 const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
// Set when the y range is not a whole number of blocks; those rows are handled
// by the scalar pass at the end.
615 bool left_over_loop_y = (((window_end_y - window_start_y) % window_step_y) != 0);
// Window driving the NEON pass. X is reduced to a single iteration (0, 1, 1)
// because the body walks x across the row itself.
617 Window window_in(window);
618 window_in.set(
Window::DimX, Window::Dimension(0, 1, 1));
// With a y remainder, limit the NEON pass to whole blocks of window_step_y
// rows when there is at least one.
619 if (left_over_loop_y)
622 if (window_end_y_multiple_of > window_start_y)
624 window_in.set(
Window::DimY, Window::Dimension(window_start_y, window_end_y_multiple_of, window_step_y));
// Presumably the else-branch of the check above: an empty y range, so the
// NEON pass does no work.
628 window_in.set(
Window::DimY, Window::Dimension(0, 0, 1));
// X and Y of the output window are collapsed to (0, 0, 0); destination
// addresses are computed explicitly from `id` and the output row pitch below.
632 Window window_out(window);
633 window_out.set(
Window::DimX, Window::Dimension(0, 0, 0));
634 window_out.set(
Window::DimY, Window::Dimension(0, 0, 0));
636 Iterator output(out, window_out);
// NEON pass; skipped when the input has a single row (that case presumably
// falls entirely to the left-over y pass below).
639 if (in->info()->dimension(1) != 1)
641 Iterator
input(in, window_in);
// Per-block body: id.y() is the first of the 4 input rows read here.
644 [&](
const Coordinates &
id)
// Full 4x4 tiles: one 128-bit vector per row, columns x..x+3.
647 int x = window_start_x;
648 for (; x <= (window_end_x - window_step_x); x += window_step_x)
650 const uint32x4_t row0 =
651 vld1q_u32(
reinterpret_cast<const uint32_t *
>(
input.ptr() + 0 * input_stride_in_bytes) + x);
652 const uint32x4_t row1 =
653 vld1q_u32(
reinterpret_cast<const uint32_t *
>(
input.ptr() + 1 * input_stride_in_bytes) + x);
654 const uint32x4_t row2 =
655 vld1q_u32(
reinterpret_cast<const uint32_t *
>(
input.ptr() + 2 * input_stride_in_bytes) + x);
656 const uint32x4_t row3 =
657 vld1q_u32(
reinterpret_cast<const uint32_t *
>(
input.ptr() + 3 * input_stride_in_bytes) + x);
// 2x2 sub-tile transposes on 64-bit halves:
//   k0_u32 = rows 0-1 x columns 0-1,  k2_u32 = rows 0-1 x columns 2-3,
//   k3_u32 = rows 2-3 x columns 0-1,  k1_u32 = rows 2-3 x columns 2-3.
660 const uint32x2x2_t k0_u32 = vtrn_u32(vget_low_u32(row0), vget_low_u32(row1));
661 const uint32x2x2_t k1_u32 = vtrn_u32(vget_high_u32(row2), vget_high_u32(row3));
662 const uint32x2x2_t k2_u32 = vtrn_u32(vget_high_u32(row0), vget_high_u32(row1));
663 const uint32x2x2_t k3_u32 = vtrn_u32(vget_low_u32(row2), vget_low_u32(row3));
// Output row x + j receives tile column j; vcombine_u32 joins its rows 0-1
// half with its rows 2-3 half. The row starts at element id.y().
666 const size_t dst_offset_in_bytes =
id.y() *
sizeof(uint32_t) + x * output_stride_in_bytes;
670 reinterpret_cast<uint32_t *
>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes),
671 vcombine_u32(k0_u32.val[0], k3_u32.val[0]));
673 reinterpret_cast<uint32_t *
>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes),
674 vcombine_u32(k0_u32.val[1], k3_u32.val[1]));
676 reinterpret_cast<uint32_t *
>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes),
677 vcombine_u32(k2_u32.val[0], k1_u32.val[0]));
679 reinterpret_cast<uint32_t *
>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes),
680 vcombine_u32(k2_u32.val[1], k1_u32.val[1]));
// Tail columns (fewer than 4 left in the row): gather one element from each of
// the 4 rows and store them as 4 consecutive elements of output row x.
684 for (; x < window_end_x; ++x)
686 const uint32_t val0 = *(
reinterpret_cast<uint32_t *
>(
input.ptr() + 0 * input_stride_in_bytes) + x);
687 const uint32_t val1 = *(
reinterpret_cast<uint32_t *
>(
input.ptr() + 1 * input_stride_in_bytes) + x);
688 const uint32_t val2 = *(
reinterpret_cast<uint32_t *
>(
input.ptr() + 2 * input_stride_in_bytes) + x);
689 const uint32_t val3 = *(
reinterpret_cast<uint32_t *
>(
input.ptr() + 3 * input_stride_in_bytes) + x);
691 uint32x4_t result = vdupq_n_u32(0);
692 result = vsetq_lane_u32(val0, result, 0);
693 result = vsetq_lane_u32(val1, result, 1);
694 result = vsetq_lane_u32(val2, result, 2);
695 result = vsetq_lane_u32(val3, result, 3);
698 const size_t dst_offset_in_bytes =
id.y() *
sizeof(uint32_t) + x * output_stride_in_bytes;
700 vst1q_u32(
reinterpret_cast<uint32_t *
>(output.ptr() + dst_offset_in_bytes), result);
// Scalar pass over the left-over rows [window_end_y_multiple_of, window_end_y)
// and the full x range of `window`, one element per iteration.
706 if (left_over_loop_y)
708 window_in.set(
Window::DimX, Window::Dimension(window.x().start(), window.x().end(), 1));
709 window_in.set(
Window::DimY, Window::Dimension(window_end_y_multiple_of, window_end_y, 1));
711 Iterator
input(in, window_in);
712 Iterator output(out, window_out);
717 [&](
const Coordinates &
id)
// out(row id.x(), element id.y()) = in(row id.y(), column id.x())
719 const uint32_t val0 = *(
reinterpret_cast<uint32_t *
>(
input.ptr()));
722 const size_t dst_offset_in_bytes =
id.y() *
sizeof(uint32_t) +
id.x() * output_stride_in_bytes;
724 *(
reinterpret_cast<uint32_t *
>(output.ptr() + dst_offset_in_bytes)) = val0;
729 #endif // __aarch64__
749 const unsigned int num_elems_processed_per_iteration_x = 1;
750 const unsigned int num_elems_processed_per_iteration_y = num_elems_processed(
src->element_size());
761 ICpuKernel::configure(win);
772 "Element size not supported");
775 if (
dst->total_size() != 0)
796 switch (
src->info()->element_size())
815 return "CpuTransposeKernel";