#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void vector_matrix_multiply_f16(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
    const auto width_matrix_b  = static_cast<int>(dst->info()->dimension(0));
    const auto in_b_stride     = static_cast<int>(rhs->info()->strides_in_bytes()[1] / rhs->info()->element_size());
    const auto num_elems_vec_a = static_cast<int>(lhs->info()->dimension(0));
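    // Output columns are split across threads in interleaved 32-column chunks:
    // thread i starts at column 32 * i and advances by 32 * num_threads.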
    const int window_start_x = 32 * info.thread_id;
    const int window_step_x  = 32 * info.num_threads;
    const int window_end_x   = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
    ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x");

    Window win_out(window);

    if(rhs->info()->num_dimensions() >= 3)

    Iterator ina(lhs, win_a);
    Iterator inb(rhs, win_b);
    Iterator out(dst, win_out);

    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);

    int x = window_start_x;

    for(; x < (window_end_x - window_step_x); x += window_step_x)
        if(x > width_matrix_b)
        auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;

        float16x8_t acc0 = vdupq_n_f16(0.f);
        float16x8_t acc1 = vdupq_n_f16(0.f);
        float16x8_t acc2 = vdupq_n_f16(0.f);
        float16x8_t acc3 = vdupq_n_f16(0.f);

        auto vec_a = reinterpret_cast<const float16_t *>(ina.ptr());
        const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
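        // Main loop: each iteration consumes four f16 values from the A vector and four
        // rows of B (32 columns wide, two rows per half-step), accumulating into acc0..acc3.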
        for(; vec_a <= (vec_a_end_addr - 4);)
            const float16x4_t a0l = vld1_f16(vec_a);

            float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
            float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
            float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
            float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
            float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
            float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
            float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
            float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

            matrix_b += 2 * in_b_stride;

            b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
            b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
            b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
            b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
            b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
            b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
            b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
            b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

            matrix_b += 2 * in_b_stride;
        for(; vec_a < vec_a_end_addr; ++vec_a)
            const float16_t a0 = *vec_a;

            const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
            const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
            const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
            const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);

            matrix_b += in_b_stride;

        auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;

        vst1q_f16(vec_out + 0, acc0);
        vst1q_f16(vec_out + 8, acc1);
        vst1q_f16(vec_out + 16, acc2);
        vst1q_f16(vec_out + 24, acc3);
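    // Leftover columns: output columns beyond the last full 32-wide chunk are computed
    // one at a time as dot products between the A vector and a single B column.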
    for(; x < window_end_x; ++x)
        if(x > width_matrix_b)

        auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;

        float16x4_t vacc = vdup_n_f16(0.f);

        auto vec_a = reinterpret_cast<const float16_t *>(ina.ptr());
        const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
        for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
            const float16x4_t a0l = vld1_f16(vec_a);
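            // Gather the value at the current output column from four consecutive rows of B,
            // so it can be combined with the four A values held in a0l.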
            const float16x4_t b_col =
            {
                *(matrix_b + 0 * in_b_stride),
                *(matrix_b + 1 * in_b_stride),
                *(matrix_b + 2 * in_b_stride),
                *(matrix_b + 3 * in_b_stride),
            };

            matrix_b += 4 * in_b_stride;
        float16_t acc = vget_lane_f16(vacc, 0) + vget_lane_f16(vacc, 1) + vget_lane_f16(vacc, 2) + vget_lane_f16(vacc, 3);

        for(; vec_a < vec_a_end_addr; ++vec_a)
            const float16_t a0  = *vec_a;
            const float16_t b00 = *matrix_b;

            matrix_b += in_b_stride;

        acc *= static_cast<float16_t>(alpha);

        auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

void vector_matrix_multiply_f32(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
    const auto width_matrix_b  = static_cast<int>(dst->info()->dimension(0));
    const auto num_elems_vec_a = static_cast<int>(lhs->info()->dimension(0));

    const int window_start_x = 16 * info.thread_id;
    const int window_end_x   = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
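    // Same interleaved column partitioning as the f16 path above, but in 16-wide f32 chunks.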
    const float32x4_t alpha_f32 = vdupq_n_f32(alpha);

    int x = window_start_x;

    for(; x < (window_end_x - window_step_x); x += window_step_x)
        if(x > width_matrix_b)

        float32x4_t acc0 = vdupq_n_f32(0.f);
        float32x4_t acc1 = vdupq_n_f32(0.f);
        float32x4_t acc2 = vdupq_n_f32(0.f);
        float32x4_t acc3 = vdupq_n_f32(0.f);

        auto vec_a    = reinterpret_cast<const float *>(ina.ptr());
        auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;
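        // PLD issues a software prefetch hint so the upcoming A and B cache lines are
        // already being fetched when the loads below need them.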
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
        auto vec_a_end_addr = vec_a + num_elems_vec_a;
        for(; vec_a <= (vec_a_end_addr - 4);)
            float32x2_t a0l = vld1_f32(vec_a);

            float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
            float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
            float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
            float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

            float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
            float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
            float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
            float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
            acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
            acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
            acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
            acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

            acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
            acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
            acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
            acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

            matrix_b += 2 * in_b_stride;

            a0l = vld1_f32(vec_a);

            b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
            b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
            b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
            b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

            b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
            b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
            b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
            b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

            acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
            acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
            acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
            acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

            acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
            acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
            acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
            acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

            matrix_b += 2 * in_b_stride;
        for(; vec_a < vec_a_end_addr; ++vec_a)
            const float a0 = *vec_a;

            const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
            const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
            const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
            const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

            acc0 = vmlaq_n_f32(acc0, b00, a0);
            acc1 = vmlaq_n_f32(acc1, b01, a0);
            acc2 = vmlaq_n_f32(acc2, b02, a0);
            acc3 = vmlaq_n_f32(acc3, b03, a0);

            matrix_b += in_b_stride;

        acc0 = vmulq_f32(acc0, alpha_f32);
        acc1 = vmulq_f32(acc1, alpha_f32);
        acc2 = vmulq_f32(acc2, alpha_f32);
        acc3 = vmulq_f32(acc3, alpha_f32);
        const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;

        vst1q_f32(vec_out + 0, acc0);
        vst1q_f32(vec_out + 4, acc1);
        vst1q_f32(vec_out + 8, acc2);
        vst1q_f32(vec_out + 12, acc3);
    for(; x < window_end_x; ++x)
        if(x > width_matrix_b)

        float32x4_t vacc = vdupq_n_f32(0.f);

        auto vec_a    = reinterpret_cast<const float *>(ina.ptr());
        auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));

        auto vec_a_end_addr = vec_a + num_elems_vec_a;
        for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
            const float32x4_t a0l = vld1q_f32(vec_a);

            const float32x4_t b_col =
            {
                *(matrix_b + 0 * in_b_stride),
                *(matrix_b + 1 * in_b_stride),
                *(matrix_b + 2 * in_b_stride),
                *(matrix_b + 3 * in_b_stride),
            };
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
            vacc = vmlaq_f32(vacc, b_col, a0l);

            matrix_b += 4 * in_b_stride;

        float acc = vgetq_lane_f32(vacc, 0) + vgetq_lane_f32(vacc, 1) + vgetq_lane_f32(vacc, 2) + vgetq_lane_f32(vacc, 3);

        for(; vec_a < vec_a_end_addr; ++vec_a)
            const float a0 = *vec_a;

            const float b00 = *matrix_b;

            matrix_b += in_b_stride;

        const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;
void matrix_matrix_multiply_f32(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
    const int out_width  = static_cast<int>(dst->info()->dimension(0));
    const int out_height = static_cast<int>(dst->info()->dimension(1));

    const size_t out_stride2 = out_stride1 * 2;
    const size_t out_stride3 = out_stride1 * 3;
    const float32x4_t alpha_f32 = vdupq_n_f32(alpha);

    auto mtx_a0 = reinterpret_cast<const float *>(ina.ptr());
    auto mtx_b0 = reinterpret_cast<const float *>(inb.ptr());
    auto mtx_b1 = mtx_b0 + in_b_stride;
    float32x4_t acc00 = vdupq_n_f32(0.f);
    float32x4_t acc10 = vdupq_n_f32(0.f);
    float32x4_t acc20 = vdupq_n_f32(0.f);
    float32x4_t acc30 = vdupq_n_f32(0.f);

    float32x4_t acc01 = vdupq_n_f32(0.f);
    float32x4_t acc11 = vdupq_n_f32(0.f);
    float32x4_t acc21 = vdupq_n_f32(0.f);
    float32x4_t acc31 = vdupq_n_f32(0.f);
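    // The eight accumulators hold a 4x8 tile of the output: accRC covers output row R,
    // with acc*0 holding the first four columns (from mtx_b0) and acc*1 the next four (from mtx_b1).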
    asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
    asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
    asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));

    auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
    for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
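        // Main loop, unrolled four times: each pass works through 32 f32 values from each of the
        // two B streams (mtx_b0 and mtx_b1) together with the matching interleaved A values, keeping
        // all eight accumulators live across the unrolled sub-blocks.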
        float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
        float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
        float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
        float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);

        float32x4_t b00 = vld1q_f32(mtx_b0);
        float32x4_t b10 = vld1q_f32(mtx_b1);
        float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
        float32x4_t b11 = vld1q_f32(mtx_b1 + 4);

        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
        acc00 = vmlaq_f32(acc00, b00, a0);
        acc10 = vmlaq_f32(acc10, b00, a1);
        acc20 = vmlaq_f32(acc20, b00, a2);
        acc30 = vmlaq_f32(acc30, b00, a3);

        float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
        float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
        float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
        float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);

        acc01 = vmlaq_f32(acc01, b10, a0);
        acc11 = vmlaq_f32(acc11, b10, a1);
        acc21 = vmlaq_f32(acc21, b10, a2);
        acc31 = vmlaq_f32(acc31, b10, a3);

        acc00 = vmlaq_f32(acc00, b01, a4);
        acc10 = vmlaq_f32(acc10, b01, a5);
        acc20 = vmlaq_f32(acc20, b01, a6);
        acc30 = vmlaq_f32(acc30, b01, a7);

        acc01 = vmlaq_f32(acc01, b11, a4);
        acc11 = vmlaq_f32(acc11, b11, a5);
        acc21 = vmlaq_f32(acc21, b11, a6);
        acc31 = vmlaq_f32(acc31, b11, a7);
        a0 = vld1q_dup_f32(mtx_a0 + 0);
        a1 = vld1q_dup_f32(mtx_a0 + 1);
        a2 = vld1q_dup_f32(mtx_a0 + 2);
        a3 = vld1q_dup_f32(mtx_a0 + 3);

        b00 = vld1q_f32(mtx_b0);
        b10 = vld1q_f32(mtx_b1);
        b01 = vld1q_f32(mtx_b0 + 4);
        b11 = vld1q_f32(mtx_b1 + 4);

        acc00 = vmlaq_f32(acc00, b00, a0);
        acc10 = vmlaq_f32(acc10, b00, a1);
        acc20 = vmlaq_f32(acc20, b00, a2);
        acc30 = vmlaq_f32(acc30, b00, a3);

        a4 = vld1q_dup_f32(mtx_a0 + 4);
        a5 = vld1q_dup_f32(mtx_a0 + 5);
        a6 = vld1q_dup_f32(mtx_a0 + 6);
        a7 = vld1q_dup_f32(mtx_a0 + 7);

        acc01 = vmlaq_f32(acc01, b10, a0);
        acc11 = vmlaq_f32(acc11, b10, a1);
        acc21 = vmlaq_f32(acc21, b10, a2);
        acc31 = vmlaq_f32(acc31, b10, a3);

        acc00 = vmlaq_f32(acc00, b01, a4);
        acc10 = vmlaq_f32(acc10, b01, a5);
        acc20 = vmlaq_f32(acc20, b01, a6);
        acc30 = vmlaq_f32(acc30, b01, a7);

        acc01 = vmlaq_f32(acc01, b11, a4);
        acc11 = vmlaq_f32(acc11, b11, a5);
        acc21 = vmlaq_f32(acc21, b11, a6);
        acc31 = vmlaq_f32(acc31, b11, a7);
        a0 = vld1q_dup_f32(mtx_a0 + 0);
        a1 = vld1q_dup_f32(mtx_a0 + 1);
        a2 = vld1q_dup_f32(mtx_a0 + 2);
        a3 = vld1q_dup_f32(mtx_a0 + 3);

        b00 = vld1q_f32(mtx_b0);
        b10 = vld1q_f32(mtx_b1);
        b01 = vld1q_f32(mtx_b0 + 4);
        b11 = vld1q_f32(mtx_b1 + 4);

        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
        acc00 = vmlaq_f32(acc00, b00, a0);
        acc10 = vmlaq_f32(acc10, b00, a1);
        acc20 = vmlaq_f32(acc20, b00, a2);
        acc30 = vmlaq_f32(acc30, b00, a3);

        a4 = vld1q_dup_f32(mtx_a0 + 4);
        a5 = vld1q_dup_f32(mtx_a0 + 5);
        a6 = vld1q_dup_f32(mtx_a0 + 6);
        a7 = vld1q_dup_f32(mtx_a0 + 7);

        acc01 = vmlaq_f32(acc01, b10, a0);
        acc11 = vmlaq_f32(acc11, b10, a1);
        acc21 = vmlaq_f32(acc21, b10, a2);
        acc31 = vmlaq_f32(acc31, b10, a3);

        acc00 = vmlaq_f32(acc00, b01, a4);
        acc10 = vmlaq_f32(acc10, b01, a5);
        acc20 = vmlaq_f32(acc20, b01, a6);
        acc30 = vmlaq_f32(acc30, b01, a7);

        acc01 = vmlaq_f32(acc01, b11, a4);
        acc11 = vmlaq_f32(acc11, b11, a5);
        acc21 = vmlaq_f32(acc21, b11, a6);
        acc31 = vmlaq_f32(acc31, b11, a7);
        a0 = vld1q_dup_f32(mtx_a0 + 0);
        a1 = vld1q_dup_f32(mtx_a0 + 1);
        a2 = vld1q_dup_f32(mtx_a0 + 2);
        a3 = vld1q_dup_f32(mtx_a0 + 3);

        b00 = vld1q_f32(mtx_b0);
        b10 = vld1q_f32(mtx_b1);
        b01 = vld1q_f32(mtx_b0 + 4);
        b11 = vld1q_f32(mtx_b1 + 4);

        acc00 = vmlaq_f32(acc00, b00, a0);
        acc10 = vmlaq_f32(acc10, b00, a1);
        acc20 = vmlaq_f32(acc20, b00, a2);
        acc30 = vmlaq_f32(acc30, b00, a3);

        a4 = vld1q_dup_f32(mtx_a0 + 4);
        a5 = vld1q_dup_f32(mtx_a0 + 5);
        a6 = vld1q_dup_f32(mtx_a0 + 6);
        a7 = vld1q_dup_f32(mtx_a0 + 7);

        acc01 = vmlaq_f32(acc01, b10, a0);
        acc11 = vmlaq_f32(acc11, b10, a1);
        acc21 = vmlaq_f32(acc21, b10, a2);
        acc31 = vmlaq_f32(acc31, b10, a3);

        acc00 = vmlaq_f32(acc00, b01, a4);
        acc10 = vmlaq_f32(acc10, b01, a5);
        acc20 = vmlaq_f32(acc20, b01, a6);
        acc30 = vmlaq_f32(acc30, b01, a7);

        acc01 = vmlaq_f32(acc01, b11, a4);
        acc11 = vmlaq_f32(acc11, b11, a5);
        acc21 = vmlaq_f32(acc21, b11, a6);
        acc31 = vmlaq_f32(acc31, b11, a7);
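    // Leftover loop: whatever remains of the B streams (fewer than 32 values) is processed
    // four values per stream at a time, still updating the full 4x8 accumulator tile.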
    for(; mtx_b0 < mtx_b0_end_addr;)
        float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
        float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
        float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
        float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
        float32x4_t b00 = vld1q_f32(mtx_b0);
        float32x4_t b10 = vld1q_f32(mtx_b1);

        asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
        asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
        asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
        acc00 = vmlaq_f32(acc00, b00, a0);
        acc10 = vmlaq_f32(acc10, b00, a1);
        acc20 = vmlaq_f32(acc20, b00, a2);
        acc30 = vmlaq_f32(acc30, b00, a3);

        acc01 = vmlaq_f32(acc01, b10, a0);
        acc11 = vmlaq_f32(acc11, b10, a1);
        acc21 = vmlaq_f32(acc21, b10, a2);
        acc31 = vmlaq_f32(acc31, b10, a3);
    acc00 = vmulq_f32(acc00, alpha_f32);
    acc10 = vmulq_f32(acc10, alpha_f32);
    acc20 = vmulq_f32(acc20, alpha_f32);
    acc30 = vmulq_f32(acc30, alpha_f32);
    acc01 = vmulq_f32(acc01, alpha_f32);
    acc11 = vmulq_f32(acc11, alpha_f32);
    acc21 = vmulq_f32(acc21, alpha_f32);
    acc31 = vmulq_f32(acc31, alpha_f32);
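    // Stores are guarded against the right and bottom edges of the output: full 4x8 tiles use
    // vector stores, while partial tiles fall back to element-wise writes of only the valid columns and rows.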
    const auto mtx_out0 = reinterpret_cast<float *>(out.ptr());
    const auto mtx_out1 = mtx_out0 + 4;

    if(id.x() < (out_width - 8))
        vst1q_f32(mtx_out0, acc00);
        vst1q_f32(mtx_out1, acc01);
        if(id.y() + 1 < out_height)
            vst1q_f32(mtx_out0 + out_stride1, acc10);
            vst1q_f32(mtx_out1 + out_stride1, acc11);
            if(id.y() + 2 < out_height)
                vst1q_f32(mtx_out0 + out_stride2, acc20);
                vst1q_f32(mtx_out1 + out_stride2, acc21);
                if(id.y() + 3 < out_height)
                    vst1q_f32(mtx_out0 + out_stride3, acc30);
                    vst1q_f32(mtx_out1 + out_stride3, acc31);
    else if(id.x() < (out_width - 4))
        vst1q_f32(mtx_out0, acc00);
        if(id.y() + 1 < out_height)
            vst1q_f32(mtx_out0 + out_stride1, acc10);
            if(id.y() + 2 < out_height)
                vst1q_f32(mtx_out0 + out_stride2, acc20);
                if(id.y() + 3 < out_height)
                    vst1q_f32(mtx_out0 + out_stride3, acc30);
        const int columns_left = out_width - id.x() - 4;
        for(auto x = 0; x < columns_left; ++x)
            *(mtx_out1 + x) = acc01[x];
            if(id.y() + 1 < out_height)
                *(mtx_out1 + x + out_stride1) = acc11[x];
                if(id.y() + 2 < out_height)
                    *(mtx_out1 + x + out_stride2) = acc21[x];
                    if(id.y() + 3 < out_height)
                        *(mtx_out1 + x + out_stride3) = acc31[x];
    const int columns_left = out_width - id.x();
    for(int x = 0; x < columns_left; ++x)
        *(mtx_out0 + x) = acc00[x];
        if(id.y() + 1 < out_height)
            *(mtx_out0 + x + out_stride1) = acc10[x];
            if(id.y() + 2 < out_height)
                *(mtx_out0 + x + out_stride2) = acc20[x];
                if(id.y() + 3 < out_height)
                    *(mtx_out0 + x + out_stride3) = acc30[x];
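// FP16 matrix-matrix path: the same 4-row output tiling as above, but guarded by the FP16
// vector arithmetic feature and operating on float16x8_t blocks.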
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    const int out_width  = static_cast<int>(dst->info()->dimension(0));
    const int out_height = static_cast<int>(dst->info()->dimension(1));
    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);

    const auto *mtx_a0  = reinterpret_cast<const float16_t *>(ina.ptr());
    const auto *mtx_b0  = reinterpret_cast<const float16_t *>(inb.ptr());
    auto       *mtx_out = reinterpret_cast<float16_t *>(out.ptr());
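    // Main loop: each pass loads 16 f16 values of interleaved A (p00, p02) and 32 f16 values of B
    // (q00..q06), accumulating into the four float16x8_t rows of c; leftovers are handled eight
    // B values at a time below.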
    const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;

    for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
        const float16x8_t p00 = vld1q_f16(mtx_a0);
        const float16x8_t p02 = vld1q_f16(mtx_a0 + 8);

        const float16x8_t q00 = vld1q_f16(mtx_b0);
        const float16x8_t q02 = vld1q_f16(mtx_b0 + 8);
        const float16x8_t q04 = vld1q_f16(mtx_b0 + 16);
        const float16x8_t q06 = vld1q_f16(mtx_b0 + 24);
    for(; mtx_b0 < mtx_b0_end_addr;)
        const float16x4_t p00 = vld1_f16(mtx_a0);
        const float16x8_t q00 = vld1q_f16(mtx_b0);
    c.val[0] = vmulq_f16(c.val[0], alpha_f16);
    c.val[1] = vmulq_f16(c.val[1], alpha_f16);
    c.val[2] = vmulq_f16(c.val[2], alpha_f16);
    c.val[3] = vmulq_f16(c.val[3], alpha_f16);
    if(id.x() < (out_width - 8))
        vst1q_f16(mtx_out, c.val[0]);
        if(id.y() + 1 < out_height)
            vst1q_f16(mtx_out + 1 * out_stride, c.val[1]);
            if(id.y() + 2 < out_height)
                vst1q_f16(mtx_out + 2 * out_stride, c.val[2]);
                if(id.y() + 3 < out_height)
                    vst1q_f16(mtx_out + 3 * out_stride, c.val[3]);
    const int columns_left = out_width - id.x();
    for(int x = 0; x < columns_left; ++x)
        *(mtx_out + x) = c.val[0][x];
        if(id.y() + 1 < out_height)
            *(mtx_out + x + 1 * out_stride) = c.val[1][x];
            if(id.y() + 2 < out_height)
                *(mtx_out + x + 2 * out_stride) = c.val[2][x];
                if(id.y() + 3 < out_height)
                    *(mtx_out + x + 3 * out_stride) = c.val[3][x];
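// Reference sketch (not part of the kernels above): a minimal scalar version of the computation
// the vectorized paths implement, dst = alpha * (lhs x rhs), assuming row-major, densely packed
// float buffers. The names gemm_reference, M, N and K are illustrative only; the real kernels
// additionally rely on the library's interleaved/transposed operand layouts and window iteration.
#include <cstddef>

static void gemm_reference(const float *lhs, const float *rhs, float *dst,
                           std::size_t M, std::size_t N, std::size_t K, float alpha)
{
    // lhs is M x K, rhs is K x N, dst is M x N. The vector_matrix_multiply_* paths above are the
    // M == 1 case, split across threads along the N (output column) dimension.
    for(std::size_t m = 0; m < M; ++m)
    {
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(std::size_t k = 0; k < K; ++k)
            {
                acc += lhs[m * K + k] * rhs[k * N + n];
            }
            dst[m * N + n] = alpha * acc;
        }
    }
}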