47 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC 48 void vector_matrix_multiply_f16(
const ITensor *input0,
const ITensor *input1, ITensor *output,
const Window &window,
const ThreadInfo &
info,
float alpha)
50 const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
51 const auto in_b_stride = static_cast<int>(input1->info()->strides_in_bytes()[1] / input1->info()->element_size());
52 const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));
55 const int window_start_x = 32 *
info.thread_id;
56 const int window_step_x = 32 *
info.num_threads;
57 const int window_end_x =
ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
58 ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x,
" (window_end_x - window_start_x) must be multiple of window_step_x");
60 Window win_out(window);
71 if(input1->info()->num_dimensions() >= 3)
78 Iterator ina(input0, win_a);
79 Iterator inb(input1, win_b);
80 Iterator out(output, win_out);
84 const float16x8_t alpha_f16 = vdupq_n_f16(alpha);
88 int x = window_start_x;
91 for(; x < (window_end_x - window_step_x); x += window_step_x)
93 if(x > width_matrix_b)
98 auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;
100 float16x8_t acc0 = vdupq_n_f16(0.f);
101 float16x8_t acc1 = vdupq_n_f16(0.f);
102 float16x8_t acc2 = vdupq_n_f16(0.f);
103 float16x8_t acc3 = vdupq_n_f16(0.f);
105 auto vec_a = reinterpret_cast<const float16_t *>(ina.ptr());
106 const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
107 for(; vec_a <= (vec_a_end_addr - 4);)
109 const float16x4_t a0l = vld1_f16(vec_a);
111 float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
112 float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
113 float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
114 float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
115 float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
116 float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
117 float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
118 float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
129 matrix_b += 2 * in_b_stride;
131 b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
132 b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
133 b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
134 b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
135 b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
136 b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
137 b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
138 b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
150 matrix_b += 2 * in_b_stride;
153 for(; vec_a < vec_a_end_addr; ++vec_a)
155 const float16_t a0 = *vec_a;
156 const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
157 const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
158 const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
159 const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
166 matrix_b += in_b_stride;
178 auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;
180 vst1q_f16(vec_out + 0, acc0);
181 vst1q_f16(vec_out + 8, acc1);
182 vst1q_f16(vec_out + 16, acc2);
183 vst1q_f16(vec_out + 24, acc3);
186 for(; x < window_end_x; ++x)
188 if(x > width_matrix_b)
193 auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;
195 float16x4_t vacc = vdup_n_f16(0.f);
197 auto vec_a = reinterpret_cast<const float16_t *>(ina.ptr());
198 const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
199 for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
201 const float16x4_t a0l = vld1_f16(vec_a);
203 const float16x4_t b_col =
205 *(matrix_b + 0 * in_b_stride),
206 *(matrix_b + 1 * in_b_stride),
207 *(matrix_b + 2 * in_b_stride),
208 *(matrix_b + 3 * in_b_stride),
213 matrix_b += 4 * in_b_stride;
216 float16_t acc = vget_lane_f16(vacc, 0) + vget_lane_f16(vacc, 1) + vget_lane_f16(vacc, 2) + vget_lane_f16(vacc, 3);
218 for(; vec_a < vec_a_end_addr; ++vec_a)
220 const float16_t a0 = *vec_a;
221 const float16_t b00 = *matrix_b;
225 matrix_b += in_b_stride;
231 acc *= static_cast<float16_t>(alpha);
234 auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;
243 void vector_matrix_multiply_f32(
const ITensor *input0,
const ITensor *input1, ITensor *output,
const Window &window,
const ThreadInfo &
info,
float alpha)
245 const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
246 const auto in_b_stride = static_cast<int>(input1->info()->strides_in_bytes()[1] /
data_size_from_type(input1->info()->data_type()));
247 const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));
250 const int window_start_x = 16 *
info.thread_id;
251 const int window_step_x = 16 *
info.num_threads;
253 const int window_end_x =
ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
255 Window win_out(window);
259 Window win_a(window);
266 if(input1->info()->num_dimensions() >= 3)
273 Iterator ina(input0, win_a);
274 Iterator inb(input1, win_b);
275 Iterator out(output, win_out);
279 const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
283 int x = window_start_x;
286 for(; x < (window_end_x - window_step_x); x += window_step_x)
288 if(x > width_matrix_b)
293 float32x4_t acc0 = vdupq_n_f32(0.f);
294 float32x4_t acc1 = vdupq_n_f32(0.f);
295 float32x4_t acc2 = vdupq_n_f32(0.f);
296 float32x4_t acc3 = vdupq_n_f32(0.f);
298 auto vec_a = reinterpret_cast<const float *>(ina.ptr());
299 auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;
302 asm volatile(
"PLD [%0, #128*4]" ::
"r"(reinterpret_cast<const uint8_t *>(vec_a)));
303 asm volatile(
"PLD [%0, #128*4]" ::
"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
304 asm volatile(
"PLD [%0, #128*4]" ::
"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
307 auto vec_a_end_addr = vec_a + num_elems_vec_a;
308 for(; vec_a <= (vec_a_end_addr - 4);)
310 float32x2_t a0l = vld1_f32(vec_a);
312 float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
313 float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
314 float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
315 float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
317 float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
318 float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
319 float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
320 float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);
323 asm volatile(
"PLD [%0, #128*4]" ::
"r"(reinterpret_cast<const uint8_t *>(vec_a)));
324 asm volatile(
"PLD [%0, #128*1]" ::
"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
325 asm volatile(
"PLD [%0, #128*1]" ::
"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
326 asm volatile(
"PLD [%0, #128*1]" ::
"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
327 asm volatile(
"PLD [%0, #128*1]" ::
"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
330 acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
331 acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
332 acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
333 acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);
335 acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
336 acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
337 acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
338 acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);
341 matrix_b += 2 * in_b_stride;
343 a0l = vld1_f32(vec_a);
345 b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
346 b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
347 b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
348 b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
350 b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
351 b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
352 b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
353 b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);
355 acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
356 acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
357 acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
358 acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);
360 acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
361 acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
362 acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
363 acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);
366 matrix_b += 2 * in_b_stride;
369 for(; vec_a < vec_a_end_addr; ++vec_a)
371 const float a0 = *vec_a;
373 const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
374 const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
375 const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
376 const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
378 acc0 = vmlaq_n_f32(acc0, b00, a0);
379 acc1 = vmlaq_n_f32(acc1, b01, a0);
380 acc2 = vmlaq_n_f32(acc2, b02, a0);
381 acc3 = vmlaq_n_f32(acc3, b03, a0);
383 matrix_b += in_b_stride;
389 acc0 = vmulq_f32(acc0, alpha_f32);
390 acc1 = vmulq_f32(acc1, alpha_f32);
391 acc2 = vmulq_f32(acc2, alpha_f32);
392 acc3 = vmulq_f32(acc3, alpha_f32);
395 const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;
397 vst1q_f32(vec_out + 0, acc0);
398 vst1q_f32(vec_out + 4, acc1);
399 vst1q_f32(vec_out + 8, acc2);
400 vst1q_f32(vec_out + 12, acc3);
404 for(; x < window_end_x; ++x)
406 if(x > width_matrix_b)
411 float32x4_t vacc = vdupq_n_f32(0.f);
413 auto vec_a = reinterpret_cast<const float *>(ina.ptr());
414 auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;
417 asm volatile(
"PLD [%0, #128*4]" ::
"r"(reinterpret_cast<const uint8_t *>(vec_a)));
418 asm volatile(
"PLD [%0, #128*4]" ::
"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
419 asm volatile(
"PLD [%0, #128*4]" ::
"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
422 auto vec_a_end_addr = vec_a + num_elems_vec_a;
423 for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
425 const float32x4_t a0l = vld1q_f32(vec_a);
427 const float32x4_t b_col =
429 *(matrix_b + 0 * in_b_stride),
430 *(matrix_b + 1 * in_b_stride),
431 *(matrix_b + 2 * in_b_stride),
432 *(matrix_b + 3 * in_b_stride),
436 asm volatile(
"PLD [%0, #128*4]" ::
"r"(reinterpret_cast<const uint8_t *>(vec_a)));
437 asm volatile(
"PLD [%0, #128*1]" ::
"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
438 asm volatile(
"PLD [%0, #128*1]" ::
"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
439 asm volatile(
"PLD [%0, #128*1]" ::
"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
440 asm volatile(
"PLD [%0, #128*1]" ::
"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
443 vacc = vmlaq_f32(vacc, b_col, a0l);
445 matrix_b += 4 * in_b_stride;
448 float acc = vgetq_lane_f32(vacc, 0) + vgetq_lane_f32(vacc, 1) + vgetq_lane_f32(vacc, 2) + vgetq_lane_f32(vacc, 3);
450 for(; vec_a < vec_a_end_addr; ++vec_a)
452 const float a0 = *vec_a;
454 const float b00 = *matrix_b;
458 matrix_b += in_b_stride;
467 const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;
475 void matrix_matrix_multiply_f32(
const ITensor *input0,
const ITensor *input1, ITensor *output,
const Window &window,
float alpha)
477 const int out_width = static_cast<int>(output->info()->dimension(0));
478 const int out_height = static_cast<int>(output->info()->dimension(1));
479 const size_t in_b_stride = input1->info()->strides_in_bytes()[1] /
data_size_from_type(input1->info()->data_type());
480 const size_t out_stride1 = output->info()->strides_in_bytes()[1] /
data_size_from_type(output->info()->data_type());
481 const size_t out_stride2 = out_stride1 * 2;
482 const size_t out_stride3 = out_stride1 * 3;
483 const int num_elems_matrix_b_x = input1->info()->dimension(0);
486 Window win_a(window);
488 win_a.set(
Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));
493 if(input1->info()->num_dimensions() >= 3)
499 win_b.set(
Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 2 * in_b_stride));
502 Iterator ina(input0, win_a);
503 Iterator inb(input1, win_b);
504 Iterator out(output, window);
508 const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
515 auto mtx_a0 = reinterpret_cast<const float *>(ina.ptr());
516 auto mtx_b0 = reinterpret_cast<const float *>(inb.ptr());
517 auto mtx_b1 = mtx_b0 + in_b_stride;
519 float32x4_t acc00 = vdupq_n_f32(0.f);
520 float32x4_t acc10 = vdupq_n_f32(0.f);
521 float32x4_t acc20 = vdupq_n_f32(0.f);
522 float32x4_t acc30 = vdupq_n_f32(0.f);
524 float32x4_t acc01 = vdupq_n_f32(0.f);
525 float32x4_t acc11 = vdupq_n_f32(0.f);
526 float32x4_t acc21 = vdupq_n_f32(0.f);
527 float32x4_t acc31 = vdupq_n_f32(0.f);
530 asm volatile(
"PLD [%0, #128*1]" ::
"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
531 asm volatile(
"PLD [%0, #128*1]" ::
"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
532 asm volatile(
"PLD [%0, #128*1]" ::
"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
535 auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
536 for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
538 float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
539 float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
540 float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
541 float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
543 float32x4_t b00 = vld1q_f32(mtx_b0);
544 float32x4_t b10 = vld1q_f32(mtx_b1);
545 float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
546 float32x4_t b11 = vld1q_f32(mtx_b1 + 4);
549 asm volatile(
"PLD [%0, #128*4]" ::
"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
550 asm volatile(
"PLD [%0, #128*4]" ::
"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
551 asm volatile(
"PLD [%0, #128*4]" ::
"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
555 acc00 = vmlaq_f32(acc00, b00, a0);
556 acc10 = vmlaq_f32(acc10, b00, a1);
557 acc20 = vmlaq_f32(acc20, b00, a2);
558 acc30 = vmlaq_f32(acc30, b00, a3);
560 float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
561 float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
562 float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
563 float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);
566 acc01 = vmlaq_f32(acc01, b10, a0);
567 acc11 = vmlaq_f32(acc11, b10, a1);
568 acc21 = vmlaq_f32(acc21, b10, a2);
569 acc31 = vmlaq_f32(acc31, b10, a3);
572 acc00 = vmlaq_f32(acc00, b01, a4);
573 acc10 = vmlaq_f32(acc10, b01, a5);
574 acc20 = vmlaq_f32(acc20, b01, a6);
575 acc30 = vmlaq_f32(acc30, b01, a7);
578 acc01 = vmlaq_f32(acc01, b11, a4);
579 acc11 = vmlaq_f32(acc11, b11, a5);
580 acc21 = vmlaq_f32(acc21, b11, a6);
581 acc31 = vmlaq_f32(acc31, b11, a7);
587 a0 = vld1q_dup_f32(mtx_a0 + 0);
588 a1 = vld1q_dup_f32(mtx_a0 + 1);
589 a2 = vld1q_dup_f32(mtx_a0 + 2);
590 a3 = vld1q_dup_f32(mtx_a0 + 3);
592 b00 = vld1q_f32(mtx_b0);
593 b10 = vld1q_f32(mtx_b1);
594 b01 = vld1q_f32(mtx_b0 + 4);
595 b11 = vld1q_f32(mtx_b1 + 4);
598 acc00 = vmlaq_f32(acc00, b00, a0);
599 acc10 = vmlaq_f32(acc10, b00, a1);
600 acc20 = vmlaq_f32(acc20, b00, a2);
601 acc30 = vmlaq_f32(acc30, b00, a3);
603 a4 = vld1q_dup_f32(mtx_a0 + 4);
604 a5 = vld1q_dup_f32(mtx_a0 + 5);
605 a6 = vld1q_dup_f32(mtx_a0 + 6);
606 a7 = vld1q_dup_f32(mtx_a0 + 7);
609 acc01 = vmlaq_f32(acc01, b10, a0);
610 acc11 = vmlaq_f32(acc11, b10, a1);
611 acc21 = vmlaq_f32(acc21, b10, a2);
612 acc31 = vmlaq_f32(acc31, b10, a3);
615 acc00 = vmlaq_f32(acc00, b01, a4);
616 acc10 = vmlaq_f32(acc10, b01, a5);
617 acc20 = vmlaq_f32(acc20, b01, a6);
618 acc30 = vmlaq_f32(acc30, b01, a7);
621 acc01 = vmlaq_f32(acc01, b11, a4);
622 acc11 = vmlaq_f32(acc11, b11, a5);
623 acc21 = vmlaq_f32(acc21, b11, a6);
624 acc31 = vmlaq_f32(acc31, b11, a7);
630 a0 = vld1q_dup_f32(mtx_a0 + 0);
631 a1 = vld1q_dup_f32(mtx_a0 + 1);
632 a2 = vld1q_dup_f32(mtx_a0 + 2);
633 a3 = vld1q_dup_f32(mtx_a0 + 3);
634 b00 = vld1q_f32(mtx_b0);
635 b10 = vld1q_f32(mtx_b1);
636 b01 = vld1q_f32(mtx_b0 + 4);
637 b11 = vld1q_f32(mtx_b1 + 4);
640 asm volatile(
"PLD [%0, #128*4]" ::
"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
641 asm volatile(
"PLD [%0, #128*4]" ::
"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
642 asm volatile(
"PLD [%0, #128*4]" ::
"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
646 acc00 = vmlaq_f32(acc00, b00, a0);
647 acc10 = vmlaq_f32(acc10, b00, a1);
648 acc20 = vmlaq_f32(acc20, b00, a2);
649 acc30 = vmlaq_f32(acc30, b00, a3);
651 a4 = vld1q_dup_f32(mtx_a0 + 4);
652 a5 = vld1q_dup_f32(mtx_a0 + 5);
653 a6 = vld1q_dup_f32(mtx_a0 + 6);
654 a7 = vld1q_dup_f32(mtx_a0 + 7);
657 acc01 = vmlaq_f32(acc01, b10, a0);
658 acc11 = vmlaq_f32(acc11, b10, a1);
659 acc21 = vmlaq_f32(acc21, b10, a2);
660 acc31 = vmlaq_f32(acc31, b10, a3);
663 acc00 = vmlaq_f32(acc00, b01, a4);
664 acc10 = vmlaq_f32(acc10, b01, a5);
665 acc20 = vmlaq_f32(acc20, b01, a6);
666 acc30 = vmlaq_f32(acc30, b01, a7);
669 acc01 = vmlaq_f32(acc01, b11, a4);
670 acc11 = vmlaq_f32(acc11, b11, a5);
671 acc21 = vmlaq_f32(acc21, b11, a6);
672 acc31 = vmlaq_f32(acc31, b11, a7);
678 a0 = vld1q_dup_f32(mtx_a0 + 0);
679 a1 = vld1q_dup_f32(mtx_a0 + 1);
680 a2 = vld1q_dup_f32(mtx_a0 + 2);
681 a3 = vld1q_dup_f32(mtx_a0 + 3);
682 b00 = vld1q_f32(mtx_b0);
683 b10 = vld1q_f32(mtx_b1);
684 b01 = vld1q_f32(mtx_b0 + 4);
685 b11 = vld1q_f32(mtx_b1 + 4);
688 acc00 = vmlaq_f32(acc00, b00, a0);
689 acc10 = vmlaq_f32(acc10, b00, a1);
690 acc20 = vmlaq_f32(acc20, b00, a2);
691 acc30 = vmlaq_f32(acc30, b00, a3);
693 a4 = vld1q_dup_f32(mtx_a0 + 4);
694 a5 = vld1q_dup_f32(mtx_a0 + 5);
695 a6 = vld1q_dup_f32(mtx_a0 + 6);
696 a7 = vld1q_dup_f32(mtx_a0 + 7);
699 acc01 = vmlaq_f32(acc01, b10, a0);
700 acc11 = vmlaq_f32(acc11, b10, a1);
701 acc21 = vmlaq_f32(acc21, b10, a2);
702 acc31 = vmlaq_f32(acc31, b10, a3);
705 acc00 = vmlaq_f32(acc00, b01, a4);
706 acc10 = vmlaq_f32(acc10, b01, a5);
707 acc20 = vmlaq_f32(acc20, b01, a6);
708 acc30 = vmlaq_f32(acc30, b01, a7);
711 acc01 = vmlaq_f32(acc01, b11, a4);
712 acc11 = vmlaq_f32(acc11, b11, a5);
713 acc21 = vmlaq_f32(acc21, b11, a6);
714 acc31 = vmlaq_f32(acc31, b11, a7);
721 for(; mtx_b0 < mtx_b0_end_addr;)
723 float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
724 float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
725 float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
726 float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
727 float32x4_t b00 = vld1q_f32(mtx_b0);
728 float32x4_t b10 = vld1q_f32(mtx_b1);
731 asm volatile(
"PLD [%0, #128*2]" ::
"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
732 asm volatile(
"PLD [%0, #128*2]" ::
"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
733 asm volatile(
"PLD [%0, #128*2]" ::
"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
736 acc00 = vmlaq_f32(acc00, b00, a0);
737 acc10 = vmlaq_f32(acc10, b00, a1);
738 acc20 = vmlaq_f32(acc20, b00, a2);
739 acc30 = vmlaq_f32(acc30, b00, a3);
742 acc01 = vmlaq_f32(acc01, b10, a0);
743 acc11 = vmlaq_f32(acc11, b10, a1);
744 acc21 = vmlaq_f32(acc21, b10, a2);
745 acc31 = vmlaq_f32(acc31, b10, a3);
755 acc00 = vmulq_f32(acc00, alpha_f32);
756 acc10 = vmulq_f32(acc10, alpha_f32);
757 acc20 = vmulq_f32(acc20, alpha_f32);
758 acc30 = vmulq_f32(acc30, alpha_f32);
759 acc01 = vmulq_f32(acc01, alpha_f32);
760 acc11 = vmulq_f32(acc11, alpha_f32);
761 acc21 = vmulq_f32(acc21, alpha_f32);
762 acc31 = vmulq_f32(acc31, alpha_f32);
765 const auto mtx_out0 = reinterpret_cast<float *>(out.ptr());
766 const auto mtx_out1 = mtx_out0 + 4;
768 if(
id.x() < (out_width - 8))
770 vst1q_f32(mtx_out0, acc00);
771 vst1q_f32(mtx_out1, acc01);
772 if(
id.y() + 1 < out_height)
774 vst1q_f32(mtx_out0 + out_stride1, acc10);
775 vst1q_f32(mtx_out1 + out_stride1, acc11);
776 if(
id.y() + 2 < out_height)
778 vst1q_f32(mtx_out0 + out_stride2, acc20);
779 vst1q_f32(mtx_out1 + out_stride2, acc21);
780 if(
id.y() + 3 < out_height)
782 vst1q_f32(mtx_out0 + out_stride3, acc30);
783 vst1q_f32(mtx_out1 + out_stride3, acc31);
788 else if(
id.x() < (out_width - 4))
790 vst1q_f32(mtx_out0, acc00);
791 if(
id.y() + 1 < out_height)
793 vst1q_f32(mtx_out0 + out_stride1, acc10);
794 if(
id.y() + 2 < out_height)
796 vst1q_f32(mtx_out0 + out_stride2, acc20);
797 if(
id.y() + 3 < out_height)
799 vst1q_f32(mtx_out0 + out_stride3, acc30);
804 const int columns_left = out_width -
id.x() - 4;
805 for(
auto x = 0; x < columns_left; ++x)
807 *(mtx_out1 + x) = acc01[x];
808 if(
id.y() + 1 < out_height)
810 *(mtx_out1 + x + out_stride1) = acc11[x];
811 if(
id.y() + 2 < out_height)
813 *(mtx_out1 + x + out_stride2) = acc21[x];
814 if(
id.y() + 3 < out_height)
816 *(mtx_out1 + x + out_stride3) = acc31[x];
825 const int columns_left = out_width -
id.x();
826 for(
int x = 0; x < columns_left; ++x)
828 *(mtx_out0 + x) = acc00[x];
829 if(
id.y() + 1 < out_height)
831 *(mtx_out0 + x + out_stride1) = acc10[x];
832 if(
id.y() + 2 < out_height)
834 *(mtx_out0 + x + out_stride2) = acc20[x];
835 if(
id.y() + 3 < out_height)
837 *(mtx_out0 + x + out_stride3) = acc30[x];
847 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC 848 void matrix_matrix_multiply_f16(
const ITensor *input0,
const ITensor *input1, ITensor *output,
const Window &window,
float alpha)
850 const int out_width = static_cast<int>(output->info()->dimension(0));
851 const int out_height = static_cast<int>(output->info()->dimension(1));
852 const size_t in_b_stride = input1->info()->strides_in_bytes()[1] /
data_size_from_type(input1->info()->data_type());
853 const size_t out_stride = output->info()->strides_in_bytes()[1] /
data_size_from_type(output->info()->data_type());
854 const int num_elems_matrix_b_x = input1->info()->dimension(0);
857 Window win_a(window);
859 win_a.set(
Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));
864 if(input1->info()->num_dimensions() >= 3)
869 win_b.set(
Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
872 Iterator ina(input0, win_a);
873 Iterator inb(input1, win_b);
874 Iterator out(output, window);
878 const float16x8_t alpha_f16 = vdupq_n_f16(alpha);
882 const auto *mtx_a0 = reinterpret_cast<const float16_t *>(ina.ptr());
883 const auto *mtx_b0 = reinterpret_cast<const float16_t *>(inb.ptr());
884 auto *mtx_out = reinterpret_cast<float16_t *>(out.ptr());
923 const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
925 for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
928 const float16x8_t p00 = vld1q_f16(mtx_a0);
929 const float16x8_t p02 = vld1q_f16(mtx_a0 + 8);
931 const float16x8_t q00 = vld1q_f16(mtx_b0);
932 const float16x8_t q02 = vld1q_f16(mtx_b0 + 8);
933 const float16x8_t q04 = vld1q_f16(mtx_b0 + 16);
934 const float16x8_t q06 = vld1q_f16(mtx_b0 + 24);
960 for(; mtx_b0 < mtx_b0_end_addr;)
963 const float16x4_t p00 = vld1_f16(mtx_a0);
964 const float16x8_t q00 = vld1q_f16(mtx_b0);
977 c.val[0] =
vmulq_f16(c.val[0], alpha_f16);
978 c.val[1] =
vmulq_f16(c.val[1], alpha_f16);
979 c.val[2] =
vmulq_f16(c.val[2], alpha_f16);
980 c.val[3] =
vmulq_f16(c.val[3], alpha_f16);
983 if(
id.x() < (out_width - 8))
985 vst1q_f16(mtx_out, c.val[0]);
986 if(
id.y() + 1 < out_height)
988 vst1q_f16(mtx_out + 1 * out_stride, c.val[1]);
989 if(
id.y() + 2 < out_height)
991 vst1q_f16(mtx_out + 2 * out_stride, c.val[2]);
992 if(
id.y() + 3 < out_height)
994 vst1q_f16(mtx_out + 3 * out_stride, c.val[3]);
1002 const int columns_left = out_width -
id.x();
1003 for(
int x = 0; x < columns_left; ++x)
1005 *(mtx_out + x) = c.val[0][x];
1006 if(
id.y() + 1 < out_height)
1008 *(mtx_out + x + 1 * out_stride) = c.val[1][x];
1009 if(
id.y() + 2 < out_height)
1011 *(mtx_out + x + 2 * out_stride) = c.val[2][x];
1012 if(
id.y() + 3 < out_height)
1014 *(mtx_out + x + 3 * out_stride) = c.val[3][x];
1025 inline Status
validate_arguments(
const ITensorInfo *input0,
const ITensorInfo *input1,
const ITensorInfo *output,
float alpha,
bool is_interleaved,
const GEMMReshapeInfo &reshape_info)
1037 if(output->total_size() != 0)
1046 const int m = reshape_info.m();
1047 const int n = reshape_info.n();
1048 const int k = reshape_info.k();
1049 const int mult_transpose1xW_width = reshape_info.mult_transpose1xW_width();
1050 const int mult_interleave4x4_height = reshape_info.mult_interleave4x4_height();
1053 TensorShape tensor_shape0{ input0->tensor_shape() };
1054 tensor_shape0.set(0, k);
1055 tensor_shape0.set(1, m);
1057 const TensorInfo tensor_info0 = input0->clone()->set_tensor_shape(tensor_shape0);
1063 TensorShape tensor_shape1{ input1->tensor_shape() };
1064 tensor_shape1.set(0, n);
1065 tensor_shape1.set(1, k);
1067 const TensorInfo tensor_info1 = input1->clone()->set_tensor_shape(tensor_shape1);
1072 if(output->total_size() != 0)
1088 : _input0(nullptr), _input1(nullptr), _output(nullptr), _alpha(1.0f)
1098 tensor_shape.
set(0, is_interleaved ? reshape_info.
n() : input1->
info()->
dimension(0));
1099 tensor_shape.
set(1, is_interleaved ? reshape_info.
m() : input0->
info()->
dimension(1));
1123 constexpr
unsigned int num_elems_processed_per_iteration_x = 8;
1124 constexpr
unsigned int num_elems_processed_per_iteration_y = 4;
1129 INEKernel::configure(win);
1146 const bool is_output_vector = (_output->
info()->
dimension(1) == 1);
1151 is_output_vector ? vector_matrix_multiply_f32(_input0, _input1, _output,
window,
info, _alpha) :
1152 matrix_matrix_multiply_f32(_input0, _input1, _output,
window, _alpha);
1155 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC 1158 is_output_vector ? vector_matrix_multiply_f16(_input0, _input1, _output,
window,
info, _alpha) :
1159 matrix_matrix_multiply_f16(_input0, _input1, _output,
window, _alpha);
bool is_one(float a, float epsilon=0.00001f)
Checks whether the input floating-point number is 1.0f by testing that its difference from 1.0f is within the range defined by epsilon.
Window calculate_max_window(const ValidRegion &valid_region, const Steps &steps, bool skip_border, BorderSize border_size)
const Window & window() const
The maximum window the kernel can be executed on.
TensorShape compute_transpose1xW_with_element_size_shape(const ITensorInfo &b, int mult_transpose1xW_width=1)
Calculate the transposed 1xW width element shape.
float16x8_t vmulq_f16(float16x8_t, float16x8_t)
#define ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(tensor)
virtual size_t dimension(size_t index) const =0
Return the size of the requested dimension.
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
GEMM reshape information class.
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
virtual DataType data_type() const =0
Data type used for each element of the tensor.
1 channel, 1 F32 per channel
Store the tensor's metadata.
float16x8_t vaddq_f16(float16x8_t, float16x8_t)
#define ARM_COMPUTE_ERROR_THROW_ON(status)
NEGEMMMatrixMultiplyKernel()
Constructor.
#define ARM_COMPUTE_RETURN_ERROR_ON(cond)
If the condition is true, an error is returned.
float16x4_t vadd_f16(float16x4_t, float16x4_t)
Interface for CPU tensor.
float16x8_t vmulq_n_f16(float16x8_t, float16_t)
TensorShape compute_interleaved_shape(const ITensorInfo &a, int mult_interleave4x4_height=1, bool reinterpret_input_as_3d=false)
Calculate the interleaved shape of an input tensor.
Copyright (c) 2017-2021 Arm Limited.
1 channel, 1 F16 per channel
void run(const Window &window, const ThreadInfo &info) override
Execute the kernel on the passed window.
int n() const
Number of matrix B columns.
static constexpr size_t DimX
Alias for dimension 0 also known as X dimension.
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
virtual const TensorShape & tensor_shape() const =0
Size for each dimension of the tensor.
auto ceil_to_multiple(S value, T divisor) -> decltype(((value+divisor - 1)/divisor) *divisor)
Computes the smallest number larger or equal to value that is a multiple of divisor.
Class to describe a number of elements in each dimension.
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
bool auto_init_if_empty(ITensorInfo &info, const TensorShape &shape, int num_channels, DataType data_type, QuantizationInfo quantization_info=QuantizationInfo())
Auto-initialize the tensor info (shape, number of channels and data type) if the current assignment is empty.
virtual std::unique_ptr< T > clone() const =0
Provide a clone of the current object of class T.
virtual ITensorInfo * info() const =0
Interface to be implemented by the child class to return the tensor's metadata.
size_t data_size_from_type(DataType data_type)
The size in bytes of the data type.
float16x4_t vmul_f16(float16x4_t, float16x4_t)
#define ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(k)
static constexpr size_t DimY
Alias for dimension 1 also known as Y dimension.
ScaleKernelInfo info(interpolation_policy, default_border_mode, PixelValue(), sampling_policy, false)
Information about executing thread and CPU.
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(...)
void configure(const ITensor *input0, const ITensor *input1, ITensor *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info=GEMMReshapeInfo())
Initialise the kernel's input and output.
int m() const
Number of matrix A rows.
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(...)
#define ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(t, c,...)
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const GEMMLowpOutputStageInfo *output_stage)
float16x8_t vmulq_lane_f16(float16x8_t, float16x4_t, const int)
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
void execute_window_loop(const Window &w, L &&lambda_function, Ts &&... iterators)
Iterate through the passed window, automatically adjusting the iterators and calling the lambda_function.
static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
Static function to check if the given info will lead to a valid configuration of NEGEMMMatrixMultiplyKernel.
Describe a multidimensional execution window.
TensorShape & set(size_t dimension, size_t value, bool apply_dim_correction=true, bool increase_dim_unit=true)
Accessor to set the value of one of the dimensions.
#define ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(f, s)