#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void vector_matrix_multiply_f16(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
{
    const auto width_matrix_b  = static_cast<int>(dst->info()->dimension(0));
    const auto in_b_stride     = static_cast<int>(rhs->info()->strides_in_bytes()[1] / rhs->info()->element_size());
    const auto num_elems_vec_a = static_cast<int>(lhs->info()->dimension(0));

    // The kernel computes 32 elements per iteration, partitioned across the threads
    const int window_start_x = 32 * info.thread_id;
    const int window_step_x  = 32 * info.num_threads;
    const int window_end_x   = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
    ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x");

    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));
    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2.
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation.
    if(rhs->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(lhs, win_a);
    Iterator inb(rhs, win_b);
    Iterator out(dst, win_out);

    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);

    execute_window_loop(win_out, [&](const Coordinates &)
    {
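        // Each pass of the main loop below produces 32 fp16 outputs of one dst row,
        // held in four 8-lane accumulators (acc0..acc3).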
        int x = window_start_x;
        // window_end_x is rounded up to a multiple of window_step_x, so it can overshoot
        // width_matrix_b; the check inside the loop guards against out-of-bound writes.
        for(; x < (window_end_x - window_step_x); x += window_step_x)
        {
            if(x > width_matrix_b)
            {
                return;
            }

            auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;

            float16x8_t acc0 = vdupq_n_f16(0.f);
            float16x8_t acc1 = vdupq_n_f16(0.f);
            float16x8_t acc2 = vdupq_n_f16(0.f);
            float16x8_t acc3 = vdupq_n_f16(0.f);
            auto             vec_a          = reinterpret_cast<const float16_t *>(ina.ptr());
            const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
            for(; vec_a <= (vec_a_end_addr - 4);)
            {
                const float16x4_t a0l = vld1_f16(vec_a);

                float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
                float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
                float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
                float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
                float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
                float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
                float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
                float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0));
                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1));

                matrix_b += 2 * in_b_stride;

                b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
                b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
                b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
                b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
                b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
                b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
                b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
                b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2));
                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3));

                matrix_b += 2 * in_b_stride;
                vec_a += 4;
            }
            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float16_t   a0  = *vec_a;
                const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
                const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
                const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
                const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);

                acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0));
                acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0));
                acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0));
                acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0));

                matrix_b += in_b_stride;
            }
            // Multiply by the weight of matrix product (alpha)
            if(multiply_alpha)
            {
                acc0 = vmulq_f16(acc0, alpha_f16);
                acc1 = vmulq_f16(acc1, alpha_f16);
                acc2 = vmulq_f16(acc2, alpha_f16);
                acc3 = vmulq_f16(acc3, alpha_f16);
            }

            auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;

            vst1q_f16(vec_out + 0, acc0);
            vst1q_f16(vec_out + 8, acc1);
            vst1q_f16(vec_out + 16, acc2);
            vst1q_f16(vec_out + 24, acc3);
        }
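        // Left-over columns: compute one dst element at a time, accumulating the
        // dot product in a 4-lane vector and reducing it at the end.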
        for(; x < window_end_x; ++x)
        {
            if(x > width_matrix_b)
            {
                return;
            }

            auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;

            float16x4_t vacc = vdup_n_f16(0.f);

            auto             vec_a          = reinterpret_cast<const float16_t *>(ina.ptr());
            const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
            for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
            {
                const float16x4_t a0l = vld1_f16(vec_a);

                const float16x4_t b_col =
                {
                    *(matrix_b + 0 * in_b_stride),
                    *(matrix_b + 1 * in_b_stride),
                    *(matrix_b + 2 * in_b_stride),
                    *(matrix_b + 3 * in_b_stride),
                };

                vacc = vadd_f16(vacc, vmul_f16(a0l, b_col));

                matrix_b += 4 * in_b_stride;
            }

            float16_t acc = vget_lane_f16(vacc, 0) + vget_lane_f16(vacc, 1) + vget_lane_f16(vacc, 2) + vget_lane_f16(vacc, 3);
            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float16_t a0  = *vec_a;
                const float16_t b00 = *matrix_b;

                acc += b00 * a0;

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of matrix product (alpha)
            if(multiply_alpha)
            {
                acc *= static_cast<float16_t>(alpha);
            }

            auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;

            *vec_out = acc;
        }
    },
    ina, inb, out);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
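// The f32 variant below follows the same structure as the f16 one, but computes
// 16 floats per outer iteration (four float32x4_t accumulators) and adds PLD
// cache-preload hints ahead of the loads.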
void vector_matrix_multiply_f32(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
{
    const auto width_matrix_b  = static_cast<int>(dst->info()->dimension(0));
    const auto in_b_stride     = static_cast<int>(rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type()));
    const auto num_elems_vec_a = static_cast<int>(lhs->info()->dimension(0));

    // The kernel computes 16 elements per iteration, partitioned across the threads
    const int window_start_x = 16 * info.thread_id;
    const int window_step_x  = 16 * info.num_threads;
    // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
    const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;

    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2.
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation.
    if(rhs->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(lhs, win_a);
    Iterator inb(rhs, win_b);
    Iterator out(dst, win_out);

    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float32x4_t alpha_f32 = vdupq_n_f32(alpha);

    execute_window_loop(win_out, [&](const Coordinates &)
    {
        int x = window_start_x;
        // As in the f16 variant, window_end_x can overshoot width_matrix_b, hence
        // the bounds check inside the loop.
        for(; x < (window_end_x - window_step_x); x += window_step_x)
        {
            if(x > width_matrix_b)
            {
                return;
            }

            float32x4_t acc0 = vdupq_n_f32(0.f);
            float32x4_t acc1 = vdupq_n_f32(0.f);
            float32x4_t acc2 = vdupq_n_f32(0.f);
            float32x4_t acc3 = vdupq_n_f32(0.f);

            auto vec_a    = reinterpret_cast<const float *>(ina.ptr());
            auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
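            // The PLD hints above ask the core to start fetching the upcoming cache
            // lines of the A vector and of the first B rows before the multiply
            // loop touches them.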
            auto vec_a_end_addr = vec_a + num_elems_vec_a;
            for(; vec_a <= (vec_a_end_addr - 4);)
            {
                float32x2_t a0l = vld1_f32(vec_a);

                float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
                float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
                float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
                float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

                float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
                float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
                float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
                float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

                asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));

                acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
                acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
                acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
                acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

                acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
                acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
                acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
                acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

                vec_a += 2;
                matrix_b += 2 * in_b_stride;

                a0l = vld1_f32(vec_a);

                b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
                b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
                b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
                b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

                b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
                b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
                b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
                b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

                acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
                acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
                acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
                acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

                acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
                acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
                acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
                acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

                vec_a += 2;
                matrix_b += 2 * in_b_stride;
            }
            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float a0 = *vec_a;

                const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
                const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
                const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
                const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

                acc0 = vmlaq_n_f32(acc0, b00, a0);
                acc1 = vmlaq_n_f32(acc1, b01, a0);
                acc2 = vmlaq_n_f32(acc2, b02, a0);
                acc3 = vmlaq_n_f32(acc3, b03, a0);

                matrix_b += in_b_stride;
            }
            // Multiply by the weight of matrix product (alpha)
            if(multiply_alpha)
            {
                acc0 = vmulq_f32(acc0, alpha_f32);
                acc1 = vmulq_f32(acc1, alpha_f32);
                acc2 = vmulq_f32(acc2, alpha_f32);
                acc3 = vmulq_f32(acc3, alpha_f32);
            }

            const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;

            vst1q_f32(vec_out + 0, acc0);
            vst1q_f32(vec_out + 4, acc1);
            vst1q_f32(vec_out + 8, acc2);
            vst1q_f32(vec_out + 12, acc3);
        }
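        // Left-over columns: one dst element per iteration, as in the f16 variant.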
        for(; x < window_end_x; ++x)
        {
            if(x > width_matrix_b)
            {
                return;
            }

            float32x4_t vacc = vdupq_n_f32(0.f);

            auto vec_a    = reinterpret_cast<const float *>(ina.ptr());
            auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
            auto vec_a_end_addr = vec_a + num_elems_vec_a;
            for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
            {
                const float32x4_t a0l = vld1q_f32(vec_a);

                const float32x4_t b_col =
                {
                    *(matrix_b + 0 * in_b_stride),
                    *(matrix_b + 1 * in_b_stride),
                    *(matrix_b + 2 * in_b_stride),
                    *(matrix_b + 3 * in_b_stride),
                };

                asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));

                vacc = vmlaq_f32(vacc, b_col, a0l);

                matrix_b += 4 * in_b_stride;
            }

            float acc = vgetq_lane_f32(vacc, 0) + vgetq_lane_f32(vacc, 1) + vgetq_lane_f32(vacc, 2) + vgetq_lane_f32(vacc, 3);
            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float a0  = *vec_a;
                const float b00 = *matrix_b;

                acc += b00 * a0;

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of matrix product (alpha)
            if(multiply_alpha)
            {
                acc *= alpha;
            }

            const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;

            *vec_out = acc;
        }
    },
    ina, inb, out);
}
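// General matrix-matrix case. The kernel assumes lhs has been interleaved 4x4 and
// rhs transposed 1x4, so each window step accumulates a 4x8 tile of dst in eight
// float32x4_t registers (acc00..acc31).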
void matrix_matrix_multiply_f32(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
{
    ARM_COMPUTE_UNUSED(info);
    const int    out_width            = static_cast<int>(dst->info()->dimension(0));
    const int    out_height           = static_cast<int>(dst->info()->dimension(1));
    const size_t in_b_stride          = rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type());
    const size_t out_stride1          = dst->info()->strides_in_bytes()[1] / data_size_from_type(dst->info()->data_type());
    const size_t out_stride2          = out_stride1 * 2;
    const size_t out_stride3          = out_stride1 * 3;
    const int    num_elems_matrix_b_x = rhs->info()->dimension(0);
    // Set step_x and step_y for matrix A. The Y range is divided by 4 because the
    // interleaved matrix A has 4 times fewer rows than the dst matrix.
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2.
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation.
    if(rhs->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set step_x and step_y for matrix B. The X range is divided by 4 because the
    // transposed matrix B has 4 times fewer columns than the dst matrix; the step is
    // 2 * in_b_stride because each iteration computes 2 blocks of size 4x4.
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 2 * in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 0, 0));
    Iterator ina(lhs, win_a);
    Iterator inb(rhs, win_b);
    Iterator out(dst, window);

    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float32x4_t alpha_f32 = vdupq_n_f32(alpha);

    execute_window_loop(window, [&](const Coordinates & id)
    {
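        // mtx_b0 and mtx_b1 walk two consecutive rows of the transposed B block;
        // accumulator accNM holds dst row N, columns 4*M..4*M+3 of the current tile.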
        auto mtx_a0 = reinterpret_cast<const float *>(ina.ptr());
        auto mtx_b0 = reinterpret_cast<const float *>(inb.ptr());
        auto mtx_b1 = mtx_b0 + in_b_stride;

        float32x4_t acc00 = vdupq_n_f32(0.f);
        float32x4_t acc10 = vdupq_n_f32(0.f);
        float32x4_t acc20 = vdupq_n_f32(0.f);
        float32x4_t acc30 = vdupq_n_f32(0.f);

        float32x4_t acc01 = vdupq_n_f32(0.f);
        float32x4_t acc11 = vdupq_n_f32(0.f);
        float32x4_t acc21 = vdupq_n_f32(0.f);
        float32x4_t acc31 = vdupq_n_f32(0.f);
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));

        auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
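        // Main loop: unrolled four times; every sub-iteration consumes 8 values of A
        // and 8 values from each B row and issues 16 multiply-accumulates.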
        for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
        {
            float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
            float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
            float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
            float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);

            float32x4_t b00 = vld1q_f32(mtx_b0);
            float32x4_t b10 = vld1q_f32(mtx_b1);
            float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
            float32x4_t b11 = vld1q_f32(mtx_b1 + 4);

            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
            float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
            float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
            float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;
            a0 = vld1q_dup_f32(mtx_a0 + 0);
            a1 = vld1q_dup_f32(mtx_a0 + 1);
            a2 = vld1q_dup_f32(mtx_a0 + 2);
            a3 = vld1q_dup_f32(mtx_a0 + 3);

            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;
            a0 = vld1q_dup_f32(mtx_a0 + 0);
            a1 = vld1q_dup_f32(mtx_a0 + 1);
            a2 = vld1q_dup_f32(mtx_a0 + 2);
            a3 = vld1q_dup_f32(mtx_a0 + 3);
            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;
            a0 = vld1q_dup_f32(mtx_a0 + 0);
            a1 = vld1q_dup_f32(mtx_a0 + 1);
            a2 = vld1q_dup_f32(mtx_a0 + 2);
            a3 = vld1q_dup_f32(mtx_a0 + 3);
            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;
        }
        for(; mtx_b0 < mtx_b0_end_addr;)
        {
            float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
            float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
            float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
            float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
            float32x4_t b00 = vld1q_f32(mtx_b0);
            float32x4_t b10 = vld1q_f32(mtx_b1);

            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            mtx_a0 += 4;
            mtx_b0 += 4;
            mtx_b1 += 4;
        }
        // Multiply by the weight of matrix product (alpha)
        if(multiply_alpha)
        {
            acc00 = vmulq_f32(acc00, alpha_f32);
            acc10 = vmulq_f32(acc10, alpha_f32);
            acc20 = vmulq_f32(acc20, alpha_f32);
            acc30 = vmulq_f32(acc30, alpha_f32);
            acc01 = vmulq_f32(acc01, alpha_f32);
            acc11 = vmulq_f32(acc11, alpha_f32);
            acc21 = vmulq_f32(acc21, alpha_f32);
            acc31 = vmulq_f32(acc31, alpha_f32);
        }
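        // Store the 4x8 tile. The nested bounds checks below clamp both the rows
        // (dst height may not be a multiple of 4) and the columns (dst width may
        // not be a multiple of 8) of the tile.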
        const auto mtx_out0 = reinterpret_cast<float *>(out.ptr());
        const auto mtx_out1 = mtx_out0 + 4;

        if(id.x() < (out_width - 8))
        {
            vst1q_f32(mtx_out0, acc00);
            vst1q_f32(mtx_out1, acc01);
            if(id.y() + 1 < out_height)
            {
                vst1q_f32(mtx_out0 + out_stride1, acc10);
                vst1q_f32(mtx_out1 + out_stride1, acc11);
                if(id.y() + 2 < out_height)
                {
                    vst1q_f32(mtx_out0 + out_stride2, acc20);
                    vst1q_f32(mtx_out1 + out_stride2, acc21);
                    if(id.y() + 3 < out_height)
                    {
                        vst1q_f32(mtx_out0 + out_stride3, acc30);
                        vst1q_f32(mtx_out1 + out_stride3, acc31);
                    }
                }
            }
        }
        else if(id.x() < (out_width - 4))
        {
            vst1q_f32(mtx_out0, acc00);
            if(id.y() + 1 < out_height)
            {
                vst1q_f32(mtx_out0 + out_stride1, acc10);
                if(id.y() + 2 < out_height)
                {
                    vst1q_f32(mtx_out0 + out_stride2, acc20);
                    if(id.y() + 3 < out_height)
                    {
                        vst1q_f32(mtx_out0 + out_stride3, acc30);
                    }
                }
            }

            // Left-over columns of the second 4x4 block
            const int columns_left = out_width - id.x() - 4;
            for(auto x = 0; x < columns_left; ++x)
            {
                *(mtx_out1 + x) = acc01[x];
                if(id.y() + 1 < out_height)
                {
                    *(mtx_out1 + x + out_stride1) = acc11[x];
                    if(id.y() + 2 < out_height)
                    {
                        *(mtx_out1 + x + out_stride2) = acc21[x];
                        if(id.y() + 3 < out_height)
                        {
                            *(mtx_out1 + x + out_stride3) = acc31[x];
                        }
                    }
                }
            }
        }
        else
        {
            // Left-over columns
            const int columns_left = out_width - id.x();
            for(int x = 0; x < columns_left; ++x)
            {
                *(mtx_out0 + x) = acc00[x];
                if(id.y() + 1 < out_height)
                {
                    *(mtx_out0 + x + out_stride1) = acc10[x];
                    if(id.y() + 2 < out_height)
                    {
                        *(mtx_out0 + x + out_stride2) = acc20[x];
                        if(id.y() + 3 < out_height)
                        {
                            *(mtx_out0 + x + out_stride3) = acc30[x];
                        }
                    }
                }
            }
        }
    },
    ina, inb, out);
}
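// The f16 matrix kernel mirrors the f32 one, but accumulates the 4x8 tile in a
// single float16x8x4_t (one 8-wide register per dst row).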
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void matrix_matrix_multiply_f16(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
{
    ARM_COMPUTE_UNUSED(info);
    const int    out_width            = static_cast<int>(dst->info()->dimension(0));
    const int    out_height           = static_cast<int>(dst->info()->dimension(1));
    const size_t in_b_stride          = rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type());
    const size_t out_stride           = dst->info()->strides_in_bytes()[1] / data_size_from_type(dst->info()->data_type());
    const int    num_elems_matrix_b_x = rhs->info()->dimension(0);
    // Set step_x and step_y for matrix A. The Y range is divided by 4 because the
    // interleaved matrix A has 4 times fewer rows than the dst matrix.
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2.
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation.
    if(rhs->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set step_x and step_y for matrix B. The X range is divided by 8 because the
    // transposed matrix B has 8 times fewer columns than the dst matrix.
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 0, 0));
    Iterator ina(lhs, win_a);
    Iterator inb(rhs, win_b);
    Iterator out(dst, window);

    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);

    execute_window_loop(window, [&](const Coordinates & id)
    {
        const auto *mtx_a0  = reinterpret_cast<const float16_t *>(ina.ptr());
        const auto *mtx_b0  = reinterpret_cast<const float16_t *>(inb.ptr());
        auto       *mtx_out = reinterpret_cast<float16_t *>(out.ptr());

        float16x8x4_t c =
        {
            {
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f)
            }
        };
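        // Matrix A has been interleaved in 4x4 blocks (the 4 rows of a k-step sit in
        // consecutive lanes) and matrix B transposed 1x8, so each main-loop pass reads
        // 16 values of A and 32 values of B, covering 4 k-steps of the 4x8 tile.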
        const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;

        for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
        {
            const float16x8_t p00 = vld1q_f16(mtx_a0);
            const float16x8_t p02 = vld1q_f16(mtx_a0 + 8);

            const float16x8_t q00 = vld1q_f16(mtx_b0);
            const float16x8_t q02 = vld1q_f16(mtx_b0 + 8);
            const float16x8_t q04 = vld1q_f16(mtx_b0 + 16);
            const float16x8_t q06 = vld1q_f16(mtx_b0 + 24);

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vgetq_lane_f16(p00, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vgetq_lane_f16(p00, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vgetq_lane_f16(p00, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vgetq_lane_f16(p00, 3)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q02, vgetq_lane_f16(p00, 4)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q02, vgetq_lane_f16(p00, 5)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q02, vgetq_lane_f16(p00, 6)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q02, vgetq_lane_f16(p00, 7)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q04, vgetq_lane_f16(p02, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q04, vgetq_lane_f16(p02, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q04, vgetq_lane_f16(p02, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q04, vgetq_lane_f16(p02, 3)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q06, vgetq_lane_f16(p02, 4)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q06, vgetq_lane_f16(p02, 5)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q06, vgetq_lane_f16(p02, 6)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q06, vgetq_lane_f16(p02, 7)));

            mtx_a0 += 16;
            mtx_b0 += 32;
        }
        for(; mtx_b0 < mtx_b0_end_addr;)
        {
            const float16x4_t p00 = vld1_f16(mtx_a0);
            const float16x8_t q00 = vld1q_f16(mtx_b0);

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vget_lane_f16(p00, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vget_lane_f16(p00, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vget_lane_f16(p00, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vget_lane_f16(p00, 3)));

            mtx_a0 += 4;
            mtx_b0 += 8;
        }
        // Multiply by the weight of matrix product (alpha)
        if(multiply_alpha)
        {
            c.val[0] = vmulq_f16(c.val[0], alpha_f16);
            c.val[1] = vmulq_f16(c.val[1], alpha_f16);
            c.val[2] = vmulq_f16(c.val[2], alpha_f16);
            c.val[3] = vmulq_f16(c.val[3], alpha_f16);
        }
        if(id.x() < (out_width - 8))
        {
            vst1q_f16(mtx_out, c.val[0]);
            if(id.y() + 1 < out_height)
            {
                vst1q_f16(mtx_out + 1 * out_stride, c.val[1]);
                if(id.y() + 2 < out_height)
                {
                    vst1q_f16(mtx_out + 2 * out_stride, c.val[2]);
                    if(id.y() + 3 < out_height)
                    {
                        vst1q_f16(mtx_out + 3 * out_stride, c.val[3]);
                    }
                }
            }
        }
        else
        {
            // Left-over columns
            const int columns_left = out_width - id.x();
            for(int x = 0; x < columns_left; ++x)
            {
                *(mtx_out + x) = c.val[0][x];
                if(id.y() + 1 < out_height)
                {
                    *(mtx_out + x + 1 * out_stride) = c.val[1][x];
                    if(id.y() + 2 < out_height)
                    {
                        *(mtx_out + x + 2 * out_stride) = c.val[2][x];
                        if(id.y() + 3 < out_height)
                        {
                            *(mtx_out + x + 3 * out_stride) = c.val[3][x];
                        }
                    }
                }
            }
        }
    },
    ina, inb, out);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
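// Shared validation helper: checks data types and, for the reshaped (interleaved)
// path, that lhs/rhs actually have the shapes produced by the 4x4 interleave and
// 1xW transpose of an (m x k) by (k x n) multiplication.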
inline Status validate_arguments(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
{
    ARM_COMPUTE_UNUSED(alpha);

    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(lhs);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs, dst);

    if(!is_interleaved)
    {
        ARM_COMPUTE_RETURN_ERROR_ON(lhs->dimension(0) != rhs->dimension(1));

        if(dst->total_size() != 0)
        {
            ARM_COMPUTE_RETURN_ERROR_ON(rhs->dimension(0) != dst->dimension(0));
            ARM_COMPUTE_RETURN_ERROR_ON(lhs->dimension(1) != dst->dimension(1));
        }
    }
    else
    {
        const int m                         = reshape_info.m();
        const int n                         = reshape_info.n();
        const int k                         = reshape_info.k();
        const int mult_transpose1xW_width   = reshape_info.mult_transpose1xW_width();
        const int mult_interleave4x4_height = reshape_info.mult_interleave4x4_height();

        // Check that lhs matches the interleaved 4x4 shape of an (m x k) matrix
        TensorShape tensor_shape0{ lhs->tensor_shape() };
        tensor_shape0.set(0, k);
        tensor_shape0.set(1, m);

        const TensorInfo tensor_info0          = lhs->clone()->set_tensor_shape(tensor_shape0);
        const TensorInfo tensor_info_reshaped0 = lhs->clone()->set_tensor_shape(compute_interleaved_shape(tensor_info0, mult_interleave4x4_height));
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lhs, &tensor_info_reshaped0);
        // Check that rhs matches the transposed 1xW shape of a (k x n) matrix
        TensorShape tensor_shape1{ rhs->tensor_shape() };
        tensor_shape1.set(0, n);
        tensor_shape1.set(1, k);

        const TensorInfo tensor_info1          = rhs->clone()->set_tensor_shape(tensor_shape1);
        const TensorInfo tensor_info_reshaped1 = rhs->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(tensor_info1, mult_transpose1xW_width));
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(rhs, &tensor_info_reshaped1);
        if(dst->total_size() != 0)
        {
            ARM_COMPUTE_RETURN_ERROR_ON(dst->dimension(0) != static_cast<size_t>(n));
            ARM_COMPUTE_RETURN_ERROR_ON(dst->dimension(1) != static_cast<size_t>(m));
        }
    }

    return Status{};
}

void CpuGemmMatrixMultiplyKernel::configure(const ITensorInfo *lhs, const ITensorInfo *rhs, ITensorInfo *dst, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst);

    // dst tensor auto initialization if not yet initialized
    TensorShape tensor_shape{ lhs->tensor_shape() };
    tensor_shape.set(0, is_interleaved ? reshape_info.n() : rhs->dimension(0));
    tensor_shape.set(1, is_interleaved ? reshape_info.m() : lhs->dimension(1));

    auto_init_if_empty(*dst, lhs->clone()->set_tensor_shape(tensor_shape));

    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(lhs, rhs, dst, alpha, is_interleaved, reshape_info));

    _alpha = alpha;

    // If dst is a vector (one row), run the vector-matrix multiplication path
    const bool is_dst_vector = (dst->dimension(1) == 1);
    // Configure the kernel window
    Window win{};
    if(is_dst_vector)
    {
        const unsigned int num_elems_processed_per_iteration_x = (lhs->data_type() == DataType::F32) ? 16 : 32;
        win = calculate_max_window(*dst, Steps(num_elems_processed_per_iteration_x));
    }
    else
    {
        constexpr unsigned int num_elems_processed_per_iteration_x = 8;
        constexpr unsigned int num_elems_processed_per_iteration_y = 4;
        win = calculate_max_window(*dst, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
    }

    // Select the micro-kernel according to data type and dst shape
    switch(lhs->data_type())
    {
        case DataType::F32:
        {
            _func = (is_dst_vector) ? vector_matrix_multiply_f32 : matrix_matrix_multiply_f32;
            break;
        }
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
        {
            _func = (is_dst_vector) ? vector_matrix_multiply_f16 : matrix_matrix_multiply_f16;
            break;
        }
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        default:
        {
            ARM_COMPUTE_ERROR("Data type not supported");
            break;
        }
    }

    ICPPKernel::configure(win);
}
Status CpuGemmMatrixMultiplyKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(lhs, rhs, dst, alpha, is_interleaved, reshape_info));

    return Status{};
}

void CpuGemmMatrixMultiplyKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(tensors.empty());
    ARM_COMPUTE_ERROR_ON(_func == nullptr);

    const ITensor *lhs = tensors.get_const_tensor(TensorType::ACL_SRC_0);
    const ITensor *rhs = tensors.get_const_tensor(TensorType::ACL_SRC_1);
    ITensor       *dst = tensors.get_tensor(TensorType::ACL_DST);

    (*_func)(lhs, rhs, dst, window, info, _alpha);
}

const char *CpuGemmMatrixMultiplyKernel::name() const
{
    return "CpuGemmMatrixMultiplyKernel";
}
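// Usage sketch (not part of this file): how a caller might drive the kernel for a
// plain, non-interleaved f32 GEMM. The tensor setup is illustrative only; the pack
// slot ids follow the layout read back in run_op() above.
//
//   CpuGemmMatrixMultiplyKernel kernel;
//   kernel.configure(lhs.info(), rhs.info(), dst.info(), 1.0f, /* is_interleaved = */ false);
//
//   ITensorPack pack;
//   pack.add_const_tensor(TensorType::ACL_SRC_0, &lhs);
//   pack.add_const_tensor(TensorType::ACL_SRC_1, &rhs);
//   pack.add_tensor(TensorType::ACL_DST, &dst);
//
//   ThreadInfo info{};
//   kernel.run_op(pack, kernel.window(), info);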