#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
{
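    // Vector-by-matrix multiplication for fp16: computes a 1xN output row as
    // alpha * (A_vec . B), streaming the full A vector against 32 columns of B
    // at a time (four float16x8_t accumulators).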
    const auto width_matrix_b  = static_cast<int>(output->info()->dimension(0));
    const auto in_b_stride     = static_cast<int>(input1->info()->strides_in_bytes()[1] / input1->info()->element_size());
    const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));

    // The implementation computes 32 fp16 output elements per iteration, interleaved across threads
    const int window_start_x = 32 * info.thread_id;
    const int window_step_x  = 32 * info.num_threads;
    const int window_end_x   = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
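    // Example: with width_matrix_b = 100 and num_threads = 2, thread 0 starts at x = 0
    // and thread 1 at x = 32, both stepping by 64, so the 32-wide output strips
    // alternate between threads; window_end_x is rounded up, and the bound checks in
    // the loops below guard against the overshoot.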
    ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x");
    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2,
    // which can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, win_out);
    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);

    execute_window_loop(win_out, [&](const Coordinates &)
    {
        int x = window_start_x;
        // window_end_x is rounded up above and may exceed width_matrix_b, so the loops
        // guard against running past the end of the output row
        for(; x < (window_end_x - window_step_x); x += window_step_x)
        {
            if(x > width_matrix_b)
            {
                return;
            }
            auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;

            float16x8_t acc0 = vdupq_n_f16(0.f);
            float16x8_t acc1 = vdupq_n_f16(0.f);
            float16x8_t acc2 = vdupq_n_f16(0.f);
            float16x8_t acc3 = vdupq_n_f16(0.f);

            auto             vec_a          = reinterpret_cast<const float16_t *>(ina.ptr());
            const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
            for(; vec_a <= (vec_a_end_addr - 4);)
            {
                const float16x4_t a0l = vld1_f16(vec_a);

                float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
                float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
                float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
                float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
                float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
                float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
                float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
                float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

                // Accumulate the first two rows of B against lanes 0 and 1 of the A vector
                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0));
                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1));

                matrix_b += 2 * in_b_stride;

                b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
                b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
                b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
                b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
                b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
                b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
                b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
                b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

                // Accumulate the next two rows of B against lanes 2 and 3 of the A vector
                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2));
                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3));

                vec_a += 4;
                matrix_b += 2 * in_b_stride;
            }
            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float16_t   a0  = *vec_a;
                const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
                const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
                const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
                const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);

                acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0));
                acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0));
                acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0));
                acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0));

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of the matrix product (alpha)
            if(multiply_alpha)
            {
                acc0 = vmulq_f16(acc0, alpha_f16);
                acc1 = vmulq_f16(acc1, alpha_f16);
                acc2 = vmulq_f16(acc2, alpha_f16);
                acc3 = vmulq_f16(acc3, alpha_f16);
            }
            auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;

            vst1q_f16(vec_out + 0, acc0);
            vst1q_f16(vec_out + 8, acc1);
            vst1q_f16(vec_out + 16, acc2);
            vst1q_f16(vec_out + 24, acc3);
        }
        // Left-over loop: one output element per iteration
        for(; x < window_end_x; ++x)
        {
            if(x > width_matrix_b)
            {
                return;
            }

            auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;
            float16x4_t vacc = vdup_n_f16(0.f);

            auto             vec_a          = reinterpret_cast<const float16_t *>(ina.ptr());
            const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
            for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
            {
                const float16x4_t a0l = vld1_f16(vec_a);

                const float16x4_t b_col =
                {
                    *(matrix_b + 0 * in_b_stride),
                    *(matrix_b + 1 * in_b_stride),
                    *(matrix_b + 2 * in_b_stride),
                    *(matrix_b + 3 * in_b_stride),
                };

                vacc = vadd_f16(vacc, vmul_f16(a0l, b_col));

                matrix_b += 4 * in_b_stride;
            }

            float16_t acc = vget_lane_f16(vacc, 0) + vget_lane_f16(vacc, 1) + vget_lane_f16(vacc, 2) + vget_lane_f16(vacc, 3);
            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float16_t a0  = *vec_a;
                const float16_t b00 = *matrix_b;

                acc += b00 * a0;

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of the matrix product (alpha)
            if(multiply_alpha)
            {
                acc *= static_cast<float16_t>(alpha);
            }
            auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;

            *(vec_out) = acc;
        }
    },
    ina, inb, out);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
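
// Vector-by-matrix multiplication for fp32: same structure as the fp16 kernel above,
// but each thread strip is 16 outputs wide (four float32x4_t accumulators) and the
// inner loops add PLD prefetch hints on 32-bit Arm builds.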
void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
{
    const auto width_matrix_b  = static_cast<int>(output->info()->dimension(0));
    const auto in_b_stride     = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
    const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));

    // The implementation computes 16 fp32 output elements per iteration, interleaved across threads
    const int window_start_x = 16 * info.thread_id;
    const int window_step_x  = 16 * info.num_threads;
    // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
    const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2,
    // which can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, win_out);
    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float32x4_t alpha_f32 = vdupq_n_f32(alpha);

    execute_window_loop(win_out, [&](const Coordinates &)
    {
        int x = window_start_x;
        // window_end_x is rounded up above and may exceed width_matrix_b, so the loops
        // guard against running past the end of the output row
        for(; x < (window_end_x - window_step_x); x += window_step_x)
        {
            if(x > width_matrix_b)
            {
                return;
            }
            float32x4_t acc0 = vdupq_n_f32(0.f);
            float32x4_t acc1 = vdupq_n_f32(0.f);
            float32x4_t acc2 = vdupq_n_f32(0.f);
            float32x4_t acc3 = vdupq_n_f32(0.f);

            auto vec_a    = reinterpret_cast<const float *>(ina.ptr());
            auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;
#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
#endif /* __arm__ */
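            // The PLD hints above are AArch32 cache-preload instructions: they request the
            // lines at the given byte offsets ahead of the pointers so the vector loads in
            // the dot-product loop below are more likely to hit in cache; they have no
            // architectural side effects.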
            auto vec_a_end_addr = vec_a + num_elems_vec_a;
            for(; vec_a <= (vec_a_end_addr - 4);)
            {
                float32x2_t a0l = vld1_f32(vec_a);

                float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
                float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
                float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
                float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

                float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
                float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
                float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
                float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);
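                // Lanes 0 and 1 of a0l pair with B rows 0 and 1 loaded above; the second
                // half of this loop body repeats the pattern for the next two A elements.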
#if __arm__
                asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
#endif /* __arm__ */
                acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
                acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
                acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
                acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

                acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
                acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
                acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
                acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

                vec_a += 2;
                matrix_b += 2 * in_b_stride;
                a0l = vld1_f32(vec_a);

                b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
                b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
                b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
                b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

                b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
                b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
                b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
                b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

                acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
                acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
                acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
                acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

                acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
                acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
                acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
                acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

                vec_a += 2;
                matrix_b += 2 * in_b_stride;
            }
            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float a0 = *vec_a;

                const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
                const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
                const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
                const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

                acc0 = vmlaq_n_f32(acc0, b00, a0);
                acc1 = vmlaq_n_f32(acc1, b01, a0);
                acc2 = vmlaq_n_f32(acc2, b02, a0);
                acc3 = vmlaq_n_f32(acc3, b03, a0);

                matrix_b += in_b_stride;
            }
            // Multiply by the weight of the matrix product (alpha)
            if(multiply_alpha)
            {
                acc0 = vmulq_f32(acc0, alpha_f32);
                acc1 = vmulq_f32(acc1, alpha_f32);
                acc2 = vmulq_f32(acc2, alpha_f32);
                acc3 = vmulq_f32(acc3, alpha_f32);
            }

            const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;

            vst1q_f32(vec_out + 0, acc0);
            vst1q_f32(vec_out + 4, acc1);
            vst1q_f32(vec_out + 8, acc2);
            vst1q_f32(vec_out + 12, acc3);
        }
        // Left-over loop: one output element per iteration
        for(; x < window_end_x; ++x)
        {
            if(x > width_matrix_b)
            {
                return;
            }

            float32x4_t vacc = vdupq_n_f32(0.f);

            auto vec_a    = reinterpret_cast<const float *>(ina.ptr());
            auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;
#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
#endif /* __arm__ */
            auto vec_a_end_addr = vec_a + num_elems_vec_a;
            for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
            {
                const float32x4_t a0l = vld1q_f32(vec_a);

                const float32x4_t b_col =
                {
                    *(matrix_b + 0 * in_b_stride),
                    *(matrix_b + 1 * in_b_stride),
                    *(matrix_b + 2 * in_b_stride),
                    *(matrix_b + 3 * in_b_stride),
                };
#if __arm__
                asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
#endif /* __arm__ */
                vacc = vmlaq_f32(vacc, b_col, a0l);

                matrix_b += 4 * in_b_stride;
            }

            float acc = vgetq_lane_f32(vacc, 0) + vgetq_lane_f32(vacc, 1) + vgetq_lane_f32(vacc, 2) + vgetq_lane_f32(vacc, 3);
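            // vacc holds four partial sums of the column dot product; the scalar
            // horizontal add above folds them into the single output element.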
            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float a0  = *vec_a;
                const float b00 = *matrix_b;

                acc += b00 * a0;

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of the matrix product (alpha)
            if(multiply_alpha)
            {
                acc *= alpha;
            }

            const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;

            *(vec_out) = acc;
        }
    },
    ina, inb, out);
}
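
// Matrix multiplication for fp32 on reshaped operands: input0 is matrix A interleaved
// in 4x4 blocks and input1 is matrix B transposed in 1x4 strips (as produced by the
// interleave/transpose1xW kernels), so both are streamed sequentially. Each window
// step computes a 4x8 block of output = alpha * (A * B).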
void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
    const int    out_width            = static_cast<int>(output->info()->dimension(0));
    const int    out_height           = static_cast<int>(output->info()->dimension(1));
    const size_t in_b_stride          = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
    const size_t out_stride1          = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
    const size_t out_stride2          = out_stride1 * 2;
    const size_t out_stride3          = out_stride1 * 3;
    const int    num_elems_matrix_b_x = input1->info()->dimension(0);
    // Set the window for matrix A: the interleaved input has 4 times fewer rows than the output, so scale the Y range by 4
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2,
    // which can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set the window for matrix B: scale the X range by 4 for the transposed layout; the step is
    // 2 * in_b_stride because each iteration consumes two transposed strips (mtx_b0 and mtx_b1)
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 2 * in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, window);
    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float32x4_t alpha_f32 = vdupq_n_f32(alpha);

    execute_window_loop(window, [&](const Coordinates &id)
    {
        auto mtx_a0 = reinterpret_cast<const float *>(ina.ptr());
        auto mtx_b0 = reinterpret_cast<const float *>(inb.ptr());
        auto mtx_b1 = mtx_b0 + in_b_stride;

        float32x4_t acc00 = vdupq_n_f32(0.f);
        float32x4_t acc10 = vdupq_n_f32(0.f);
        float32x4_t acc20 = vdupq_n_f32(0.f);
        float32x4_t acc30 = vdupq_n_f32(0.f);

        float32x4_t acc01 = vdupq_n_f32(0.f);
        float32x4_t acc11 = vdupq_n_f32(0.f);
        float32x4_t acc21 = vdupq_n_f32(0.f);
        float32x4_t acc31 = vdupq_n_f32(0.f);
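        // acc<r><c> is row r of the 4x8 output tile; column block 0 accumulates against
        // mtx_b0 and column block 1 against mtx_b1 (the next transposed B strip).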
#if __arm__
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */
        auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
        for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
        {
            float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
            float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
            float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
            float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);

            float32x4_t b00 = vld1q_f32(mtx_b0);
            float32x4_t b10 = vld1q_f32(mtx_b1);
            float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
            float32x4_t b11 = vld1q_f32(mtx_b1 + 4);
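            // This is the first of four unrolled steps; each step consumes 8 floats of A
            // and 8 of each B strip, i.e. 32 B elements per loop iteration, matching the
            // (mtx_b0_end_addr - 32) bound above.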
#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */
            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
            float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
            float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
            float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;
            a0 = vld1q_dup_f32(mtx_a0 + 0);
            a1 = vld1q_dup_f32(mtx_a0 + 1);
            a2 = vld1q_dup_f32(mtx_a0 + 2);
            a3 = vld1q_dup_f32(mtx_a0 + 3);

            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;
            a0 = vld1q_dup_f32(mtx_a0 + 0);
            a1 = vld1q_dup_f32(mtx_a0 + 1);
            a2 = vld1q_dup_f32(mtx_a0 + 2);
            a3 = vld1q_dup_f32(mtx_a0 + 3);
            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */
            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;
            a0 = vld1q_dup_f32(mtx_a0 + 0);
            a1 = vld1q_dup_f32(mtx_a0 + 1);
            a2 = vld1q_dup_f32(mtx_a0 + 2);
            a3 = vld1q_dup_f32(mtx_a0 + 3);
            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;
        }
        for(; mtx_b0 < mtx_b0_end_addr;)
        {
            float32x4_t a0  = vld1q_dup_f32(mtx_a0 + 0);
            float32x4_t a1  = vld1q_dup_f32(mtx_a0 + 1);
            float32x4_t a2  = vld1q_dup_f32(mtx_a0 + 2);
            float32x4_t a3  = vld1q_dup_f32(mtx_a0 + 3);
            float32x4_t b00 = vld1q_f32(mtx_b0);
            float32x4_t b10 = vld1q_f32(mtx_b1);

#if __arm__
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            mtx_a0 += 4;
            mtx_b0 += 4;
            mtx_b1 += 4;
        }
        // Multiply by the weight of the matrix product (alpha)
        if(multiply_alpha)
        {
            acc00 = vmulq_f32(acc00, alpha_f32);
            acc10 = vmulq_f32(acc10, alpha_f32);
            acc20 = vmulq_f32(acc20, alpha_f32);
            acc30 = vmulq_f32(acc30, alpha_f32);
            acc01 = vmulq_f32(acc01, alpha_f32);
            acc11 = vmulq_f32(acc11, alpha_f32);
            acc21 = vmulq_f32(acc21, alpha_f32);
            acc31 = vmulq_f32(acc31, alpha_f32);
        }

        const auto mtx_out0 = reinterpret_cast<float *>(out.ptr());
        const auto mtx_out1 = mtx_out0 + 4;
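        // Store the 4x8 tile with edge clamping: full-width vector stores when at least
        // 8 columns remain, one vector store plus scalar lane stores for 4-7 columns,
        // and scalar lane stores alone for fewer than 4 columns; rows that would fall
        // past out_height are skipped.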
        if(id.x() < (out_width - 8))
        {
            vst1q_f32(mtx_out0, acc00);
            vst1q_f32(mtx_out1, acc01);
            if(id.y() + 1 < out_height)
            {
                vst1q_f32(mtx_out0 + out_stride1, acc10);
                vst1q_f32(mtx_out1 + out_stride1, acc11);
                if(id.y() + 2 < out_height)
                {
                    vst1q_f32(mtx_out0 + out_stride2, acc20);
                    vst1q_f32(mtx_out1 + out_stride2, acc21);
                    if(id.y() + 3 < out_height)
                    {
                        vst1q_f32(mtx_out0 + out_stride3, acc30);
                        vst1q_f32(mtx_out1 + out_stride3, acc31);
                    }
                }
            }
        }
        else if(id.x() < (out_width - 4))
        {
            vst1q_f32(mtx_out0, acc00);
            if(id.y() + 1 < out_height)
            {
                vst1q_f32(mtx_out0 + out_stride1, acc10);
                if(id.y() + 2 < out_height)
                {
                    vst1q_f32(mtx_out0 + out_stride2, acc20);
                    if(id.y() + 3 < out_height)
                    {
                        vst1q_f32(mtx_out0 + out_stride3, acc30);
                    }
                }
            }
            // Left-over columns of the second 4-wide block
            const int columns_left = out_width - id.x() - 4;
            for(auto x = 0; x < columns_left; ++x)
            {
                *(mtx_out1 + x) = acc01[x];
                if(id.y() + 1 < out_height)
                {
                    *(mtx_out1 + x + out_stride1) = acc11[x];
                    if(id.y() + 2 < out_height)
                    {
                        *(mtx_out1 + x + out_stride2) = acc21[x];
                        if(id.y() + 3 < out_height)
                        {
                            *(mtx_out1 + x + out_stride3) = acc31[x];
                        }
                    }
                }
            }
        }
        else
        {
            // Left-over columns
            const int columns_left = out_width - id.x();
            for(int x = 0; x < columns_left; ++x)
            {
                *(mtx_out0 + x) = acc00[x];
                if(id.y() + 1 < out_height)
                {
                    *(mtx_out0 + x + out_stride1) = acc10[x];
                    if(id.y() + 2 < out_height)
                    {
                        *(mtx_out0 + x + out_stride2) = acc20[x];
                        if(id.y() + 3 < out_height)
                        {
                            *(mtx_out0 + x + out_stride3) = acc30[x];
                        }
                    }
                }
            }
        }
    },
    ina, inb, out);
}

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
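    // fp16 variant of the reshaped GEMM above: A is interleaved in 4x4 blocks, B is
    // transposed in 1x8 strips, and each window step produces a 4x8 fp16 output tile
    // held in the float16x8x4_t accumulator c.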
    const int    out_width            = static_cast<int>(output->info()->dimension(0));
    const int    out_height           = static_cast<int>(output->info()->dimension(1));
    const size_t in_b_stride          = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
    const size_t out_stride           = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
    const int    num_elems_matrix_b_x = input1->info()->dimension(0);

    // Set the window for matrix A: the interleaved input has 4 times fewer rows than the output, so scale the Y range by 4
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2,
    // which can happen when the matrix multiplication is used to perform a convolution operation
    if(input1->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set the window for matrix B: the transposed input has 8 times fewer columns than the output, so scale the X range by 8
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

    Iterator ina(input0, win_a);
    Iterator inb(input1, win_b);
    Iterator out(output, window);
    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);

    execute_window_loop(window, [&](const Coordinates &id)
    {
        const auto *mtx_a0  = reinterpret_cast<const float16_t *>(ina.ptr());
        const auto *mtx_b0  = reinterpret_cast<const float16_t *>(inb.ptr());
        auto       *mtx_out = reinterpret_cast<float16_t *>(out.ptr());

        // 4x8 tile of accumulators, one float16x8_t row each
        float16x8x4_t c =
        {
            {
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f)
            }
        };

        const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;

        for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
        {
            const float16x8_t p00 = vld1q_f16(mtx_a0);
            const float16x8_t p02 = vld1q_f16(mtx_a0 + 8);

            const float16x8_t q00 = vld1q_f16(mtx_b0);
            const float16x8_t q02 = vld1q_f16(mtx_b0 + 8);
            const float16x8_t q04 = vld1q_f16(mtx_b0 + 16);
            const float16x8_t q06 = vld1q_f16(mtx_b0 + 24);

            // Each B strip q0x holds 8 columns for one k index; scale it by the four
            // interleaved A values (one per output row) for that k
            c.val[0] = vaddq_f16(c.val[0], vmulq_lane_f16(q00, vget_low_f16(p00), 0));
            c.val[1] = vaddq_f16(c.val[1], vmulq_lane_f16(q00, vget_low_f16(p00), 1));
            c.val[2] = vaddq_f16(c.val[2], vmulq_lane_f16(q00, vget_low_f16(p00), 2));
            c.val[3] = vaddq_f16(c.val[3], vmulq_lane_f16(q00, vget_low_f16(p00), 3));

            c.val[0] = vaddq_f16(c.val[0], vmulq_lane_f16(q02, vget_high_f16(p00), 0));
            c.val[1] = vaddq_f16(c.val[1], vmulq_lane_f16(q02, vget_high_f16(p00), 1));
            c.val[2] = vaddq_f16(c.val[2], vmulq_lane_f16(q02, vget_high_f16(p00), 2));
            c.val[3] = vaddq_f16(c.val[3], vmulq_lane_f16(q02, vget_high_f16(p00), 3));

            c.val[0] = vaddq_f16(c.val[0], vmulq_lane_f16(q04, vget_low_f16(p02), 0));
            c.val[1] = vaddq_f16(c.val[1], vmulq_lane_f16(q04, vget_low_f16(p02), 1));
            c.val[2] = vaddq_f16(c.val[2], vmulq_lane_f16(q04, vget_low_f16(p02), 2));
            c.val[3] = vaddq_f16(c.val[3], vmulq_lane_f16(q04, vget_low_f16(p02), 3));

            c.val[0] = vaddq_f16(c.val[0], vmulq_lane_f16(q06, vget_high_f16(p02), 0));
            c.val[1] = vaddq_f16(c.val[1], vmulq_lane_f16(q06, vget_high_f16(p02), 1));
            c.val[2] = vaddq_f16(c.val[2], vmulq_lane_f16(q06, vget_high_f16(p02), 2));
            c.val[3] = vaddq_f16(c.val[3], vmulq_lane_f16(q06, vget_high_f16(p02), 3));

            mtx_a0 += 16;
            mtx_b0 += 32;
        }
        for(; mtx_b0 < mtx_b0_end_addr;)
        {
            const float16x4_t p00 = vld1_f16(mtx_a0);
            const float16x8_t q00 = vld1q_f16(mtx_b0);

            c.val[0] = vaddq_f16(c.val[0], vmulq_lane_f16(q00, p00, 0));
            c.val[1] = vaddq_f16(c.val[1], vmulq_lane_f16(q00, p00, 1));
            c.val[2] = vaddq_f16(c.val[2], vmulq_lane_f16(q00, p00, 2));
            c.val[3] = vaddq_f16(c.val[3], vmulq_lane_f16(q00, p00, 3));

            mtx_a0 += 4;
            mtx_b0 += 8;
        }

        // Multiply by the weight of the matrix product (alpha)
        if(multiply_alpha)
        {
            c.val[0] = vmulq_f16(c.val[0], alpha_f16);
            c.val[1] = vmulq_f16(c.val[1], alpha_f16);
            c.val[2] = vmulq_f16(c.val[2], alpha_f16);
            c.val[3] = vmulq_f16(c.val[3], alpha_f16);
        }
        if(id.x() < (out_width - 8))
        {
            vst1q_f16(mtx_out, c.val[0]);
            if(id.y() + 1 < out_height)
            {
                vst1q_f16(mtx_out + 1 * out_stride, c.val[1]);
                if(id.y() + 2 < out_height)
                {
                    vst1q_f16(mtx_out + 2 * out_stride, c.val[2]);
                    if(id.y() + 3 < out_height)
                    {
                        vst1q_f16(mtx_out + 3 * out_stride, c.val[3]);
                    }
                }
            }
        }
        else
        {
            // Left-over columns
            const int columns_left = out_width - id.x();
            for(int x = 0; x < columns_left; ++x)
            {
                *(mtx_out + x) = c.val[0][x];
                if(id.y() + 1 < out_height)
                {
                    *(mtx_out + x + 1 * out_stride) = c.val[1][x];
                    if(id.y() + 2 < out_height)
                    {
                        *(mtx_out + x + 2 * out_stride) = c.val[2][x];
                        if(id.y() + 3 < out_height)
                        {
                            *(mtx_out + x + 3 * out_stride) = c.val[3][x];
                        }
                    }
                }
            }
        }
    },
    ina, inb, out);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
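
// Validates the kernel configuration: inputs must be F16/F32 with matching data types;
// for the non-interleaved (vector) path the plain GEMV shapes must agree, while for the
// interleaved path the inputs must match the reshaped geometry implied by the m/n/k
// values carried in GEMMReshapeInfo.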
inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
{
    ARM_COMPUTE_UNUSED(alpha);

    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input0);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output);

    if(!is_interleaved)
    {
        ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != input1->dimension(1));

        if(output->total_size() != 0)
        {
            ARM_COMPUTE_RETURN_ERROR_ON(input1->dimension(0) != output->dimension(0));
            ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(1) != output->dimension(1));
        }
    }
    else
    {
        const int m                         = reshape_info.m();
        const int n                         = reshape_info.n();
        const int k                         = reshape_info.k();
        const int mult_transpose1xW_width   = reshape_info.mult_transpose1xW_width();
        const int mult_interleave4x4_height = reshape_info.mult_interleave4x4_height();

        /* Interleave */
        TensorShape tensor_shape0{ input0->tensor_shape() };
        tensor_shape0.set(0, k);
        tensor_shape0.set(1, m);

        const TensorInfo tensor_info0          = input0->clone()->set_tensor_shape(tensor_shape0);
        const TensorInfo tensor_info_reshaped0 = input0->clone()->set_tensor_shape(compute_interleaved_shape(tensor_info0, mult_interleave4x4_height));
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0);

        /* Transpose */
        TensorShape tensor_shape1{ input1->tensor_shape() };
        tensor_shape1.set(0, n);
        tensor_shape1.set(1, k);

        const TensorInfo tensor_info1          = input1->clone()->set_tensor_shape(tensor_shape1);
        const TensorInfo tensor_info_reshaped1 = input1->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(tensor_info1, mult_transpose1xW_width));
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, &tensor_info_reshaped1);
        if(output->total_size() != 0)
        {
            ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != static_cast<size_t>(n));
            ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(1) != static_cast<size_t>(m));
        }
    }

    return Status{};
}
} // namespace

NEGEMMMatrixMultiplyKernel::NEGEMMMatrixMultiplyKernel()
    : _input0(nullptr), _input1(nullptr), _output(nullptr), _alpha(1.0f)
{
}
void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);

    // Output tensor auto initialization if not yet initialized
    TensorShape tensor_shape{ input0->info()->tensor_shape() };
    tensor_shape.set(0, is_interleaved ? reshape_info.n() : input1->info()->dimension(0));
    tensor_shape.set(1, is_interleaved ? reshape_info.m() : input0->info()->dimension(1));

    auto_init_if_empty(*output->info(), tensor_shape, 1, input0->info()->data_type());

    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), alpha, is_interleaved, reshape_info));

    _input0 = input0;
    _input1 = input1;
    _output = output;
    _alpha  = alpha;
    // Configure kernel window
    Window win{};

    // Check if the output tensor is a vector; if so, the kernel runs the vector-matrix multiplication path
    const bool is_output_vector = (output->info()->dimension(1) == 1);
    if(is_output_vector)
    {
        const unsigned int num_elems_processed_per_iteration_x = (input0->info()->data_type() == DataType::F32) ? 16 : 32;

        win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));
    }
    else
    {
        constexpr unsigned int num_elems_processed_per_iteration_x = 8;
        constexpr unsigned int num_elems_processed_per_iteration_y = 4;

        win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
    }

    output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));

    INEKernel::configure(win);
}
Status NEGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved,
                                            const GEMMReshapeInfo &reshape_info)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, alpha, is_interleaved, reshape_info));

    return Status{};
}

void NEGEMMMatrixMultiplyKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    // Check if the output tensor is a vector; if so, run the vector-matrix multiplication path
    const bool is_output_vector = (_output->info()->dimension(1) == 1);
    switch(_input0->info()->data_type())
    {
        case DataType::F32:
        {
            is_output_vector ? vector_matrix_multiply_f32(_input0, _input1, _output, window, info, _alpha) :
                               matrix_matrix_multiply_f32(_input0, _input1, _output, window, _alpha);
            break;
        }
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
        {
            is_output_vector ? vector_matrix_multiply_f16(_input0, _input1, _output, window, info, _alpha) :
                               matrix_matrix_multiply_f16(_input0, _input1, _output, window, _alpha);
            break;
        }
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        default:
        {
            ARM_COMPUTE_ERROR("Data type not supported");
            break;
        }
    }
}
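
// --------------------------------------------------------------------------
// Illustrative usage (a minimal sketch, not part of the original file): how a
// caller might configure and run this kernel directly for a non-interleaved
// fp32 GEMV/GEMM. The tensor names and shapes below are hypothetical;
// NEScheduler is the library's multi-threaded kernel scheduler.
//
//     Tensor a, b, c;
//     a.allocator()->init(TensorInfo(TensorShape(64U, 16U), 1, DataType::F32)); // A: 16x64
//     b.allocator()->init(TensorInfo(TensorShape(32U, 64U), 1, DataType::F32)); // B: 64x32
//     c.allocator()->init(TensorInfo(TensorShape(32U, 16U), 1, DataType::F32)); // C: 16x32
//
//     NEGEMMMatrixMultiplyKernel kernel;
//     kernel.configure(&a, &b, &c, /* alpha */ 1.0f, /* is_interleaved */ false);
//
//     a.allocator()->allocate();
//     b.allocator()->allocate();
//     c.allocator()->allocate();
//     // ... fill a and b ...
//     NEScheduler::get().schedule(&kernel, Window::DimY);
// --------------------------------------------------------------------------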