// NOTE(review): fragment of the setup for a NEON f32 vector*matrix multiply
// (GEMV-style) kernel. Several original lines are elided in this extract, so
// some statements below are visibly incomplete.
// Number of output columns = width of the destination.
38 const auto width_matrix_b =
static_cast<int>(
dst->info()->dimension(0));
// NOTE(review): the right-hand side of this assignment is missing from the
// fragment — presumably the row stride of matrix B in elements; confirm
// against the full source.
39 const auto in_b_stride =
// Length of the A vector (reduction/K dimension) = width of the LHS.
41 const auto num_elems_vec_a =
static_cast<int>(lhs->
info()->
dimension(0));
// Per-thread output partitioning: each thread starts at its own 16-wide
// column chunk and strides by 16 * num_threads, so threads interleave chunks.
44 const int window_start_x = 16 *
info.thread_id;
45 const int window_step_x = 16 *
info.num_threads;
// Round the end up to a multiple of the step so the stepped loop has a
// uniform bound; the "x > width_matrix_b" guards downstream handle the
// resulting overshoot.
47 const int window_end_x =
ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
// Broadcast the alpha scalar into all four lanes for vectorized scaling.
73 const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
79 int x = window_start_x;
// Main loop: each iteration produces 16 contiguous output columns using four
// float32x4 accumulators (acc0..acc3). Stops one step early; the tail is
// handled by the scalar per-column loop that follows.
82 for (; x < (window_end_x - window_step_x); x += window_step_x)
// Guard against the overshoot introduced by the rounded-up window_end_x.
// NOTE(review): the guarded statement (presumably break/continue) is elided
// from this fragment.
84 if (x > width_matrix_b)
89 float32x4_t acc0 = vdupq_n_f32(0.f);
90 float32x4_t acc1 = vdupq_n_f32(0.f);
91 float32x4_t acc2 = vdupq_n_f32(0.f);
92 float32x4_t acc3 = vdupq_n_f32(0.f);
// vec_a walks the A vector (K dimension); matrix_b points at column x of the
// current B row and advances by whole rows (in_b_stride) as K progresses.
94 auto vec_a =
reinterpret_cast<const float *
>(ina.
ptr());
95 auto matrix_b =
reinterpret_cast<const float *
>(inb.
ptr()) + x;
// Software prefetch (AArch32 PLD) of the A vector and the first two B rows.
98 asm volatile(
"PLD [%0, #128*4]" ::
"r"(
reinterpret_cast<const uint8_t *
>(vec_a)));
99 asm volatile(
"PLD [%0, #128*4]" ::
"r"(
reinterpret_cast<const uint8_t *
>(matrix_b)));
100 asm volatile(
"PLD [%0, #128*4]" ::
"r"(
reinterpret_cast<const uint8_t *
>(matrix_b + in_b_stride)));
// K loop: consumes 4 A elements per iteration, in two 2-element halves.
// NOTE(review): the "vec_a += 2;" advances between the two halves are elided
// in this extract — the loop condition requires vec_a to advance somewhere.
103 auto vec_a_end_addr = vec_a + num_elems_vec_a;
104 for (; vec_a <= (vec_a_end_addr - 4);)
106 float32x2_t a0l = vld1_f32(vec_a);
// 16 B values from row k (b00..b03) and 16 from row k+1 (b10..b13).
108 float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
109 float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
110 float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
111 float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
113 float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
114 float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
115 float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
116 float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);
// Prefetch ahead in A and the next four B rows.
// NOTE(review): the "asm volatile(" openers for the four prefetches after
// this one are elided from this extract.
119 asm volatile(
"PLD [%0, #128*4]" ::
"r"(
reinterpret_cast<const uint8_t *
>(vec_a)));
"PLD [%0, #128*1]" ::
"r"(
reinterpret_cast<const uint8_t *
>(matrix_b + 1 * in_b_stride)));
"PLD [%0, #128*1]" ::
"r"(
reinterpret_cast<const uint8_t *
>(matrix_b + 2 * in_b_stride)));
"PLD [%0, #128*1]" ::
"r"(
reinterpret_cast<const uint8_t *
>(matrix_b + 3 * in_b_stride)));
"PLD [%0, #128*1]" ::
"r"(
reinterpret_cast<const uint8_t *
>(matrix_b + 4 * in_b_stride)));
// Multiply-accumulate: lane 0 of a0l scales B row k, lane 1 scales row k+1.
130 acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
131 acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
132 acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
133 acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);
135 acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
136 acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
137 acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
138 acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);
// Advance B by two rows and repeat for the second pair of K steps.
141 matrix_b += 2 * in_b_stride;
143 a0l = vld1_f32(vec_a);
145 b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
146 b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
147 b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
148 b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
150 b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
151 b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
152 b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
153 b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);
155 acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
156 acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
157 acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
158 acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);
160 acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
161 acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
162 acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
163 acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);
166 matrix_b += 2 * in_b_stride;
// Scalar K tail: one A element (one B row) per iteration.
169 for (; vec_a < vec_a_end_addr; ++vec_a)
171 const float a0 = *vec_a;
173 const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
174 const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
175 const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
176 const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
178 acc0 = vmlaq_n_f32(acc0, b00, a0);
179 acc1 = vmlaq_n_f32(acc1, b01, a0);
180 acc2 = vmlaq_n_f32(acc2, b02, a0);
181 acc3 = vmlaq_n_f32(acc3, b03, a0);
183 matrix_b += in_b_stride;
// Scale accumulators by alpha.
// NOTE(review): in comparable kernels this scaling is guarded by a
// "multiply_alpha" flag — no guard is visible in this fragment; confirm.
189 acc0 = vmulq_f32(acc0, alpha_f32);
190 acc1 = vmulq_f32(acc1, alpha_f32);
191 acc2 = vmulq_f32(acc2, alpha_f32);
192 acc3 = vmulq_f32(acc3, alpha_f32);
// Store the 16 finished output columns at offset x of the output row.
195 const auto vec_out =
reinterpret_cast<float *
>(out.
ptr()) + x;
197 vst1q_f32(vec_out + 0, acc0);
198 vst1q_f32(vec_out + 4, acc1);
199 vst1q_f32(vec_out + 8, acc2);
200 vst1q_f32(vec_out + 12, acc3);
// Leftover loop: one output column per iteration for the columns the 16-wide
// main loop did not cover.
204 for (; x < window_end_x; ++x)
// Same overshoot guard as the main loop; guarded statement elided here.
206 if (x > width_matrix_b)
211 float32x4_t vacc = vdupq_n_f32(0.f);
213 auto vec_a =
reinterpret_cast<const float *
>(ina.
ptr());
214 auto matrix_b =
reinterpret_cast<const float *
>(inb.
ptr()) + x;
// Prefetch A and the first two B rows (AArch32 PLD).
217 asm volatile(
"PLD [%0, #128*4]" ::
"r"(
reinterpret_cast<const uint8_t *
>(vec_a)));
218 asm volatile(
"PLD [%0, #128*4]" ::
"r"(
reinterpret_cast<const uint8_t *
>(matrix_b)));
219 asm volatile(
"PLD [%0, #128*4]" ::
"r"(
reinterpret_cast<const uint8_t *
>(matrix_b + in_b_stride)));
// Vectorize over K instead: 4 A elements against a gathered column of 4
// consecutive B rows (one element from each row at column x).
222 auto vec_a_end_addr = vec_a + num_elems_vec_a;
223 for (; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
225 const float32x4_t a0l = vld1q_f32(vec_a);
// NOTE(review): the closing "};" of this initializer is elided.
227 const float32x4_t b_col = {
228 *(matrix_b + 0 * in_b_stride),
229 *(matrix_b + 1 * in_b_stride),
230 *(matrix_b + 2 * in_b_stride),
231 *(matrix_b + 3 * in_b_stride),
// Prefetch ahead; as in the main loop, the "asm volatile(" openers for the
// four prefetches after this one are elided from this extract.
235 asm volatile(
"PLD [%0, #128*4]" ::
"r"(
reinterpret_cast<const uint8_t *
>(vec_a)));
"PLD [%0, #128*1]" ::
"r"(
reinterpret_cast<const uint8_t *
>(matrix_b + 1 * in_b_stride)));
"PLD [%0, #128*1]" ::
"r"(
reinterpret_cast<const uint8_t *
>(matrix_b + 2 * in_b_stride)));
"PLD [%0, #128*1]" ::
"r"(
reinterpret_cast<const uint8_t *
>(matrix_b + 3 * in_b_stride)));
"PLD [%0, #128*1]" ::
"r"(
reinterpret_cast<const uint8_t *
>(matrix_b + 4 * in_b_stride)));
// Partial dot product of 4 K steps; B advances 4 rows.
246 vacc = vmlaq_f32(vacc, b_col, a0l);
248 matrix_b += 4 * in_b_stride;
// Horizontal add of the 4 partial sums into a scalar accumulator.
251 float acc = vgetq_lane_f32(vacc, 0) + vgetq_lane_f32(vacc, 1) + vgetq_lane_f32(vacc, 2) +
252 vgetq_lane_f32(vacc, 3);
// Scalar K tail. NOTE(review): the accumulation statement (presumably
// "acc += a0 * b00;") is elided from this fragment.
254 for (; vec_a < vec_a_end_addr; ++vec_a)
256 const float a0 = *vec_a;
258 const float b00 = *matrix_b;
262 matrix_b += in_b_stride;
// Output address for column x. NOTE(review): the scalar store (and any alpha
// scaling of "acc") is elided after this point.
271 const auto vec_out =
reinterpret_cast<float *
>(out.
ptr()) + x;
// NOTE(review): fragment of the setup for a blocked NEON f32 matrix*matrix
// multiply kernel that produces a 4-row x 8-column output tile per iteration.
// Output bounds, used below for right/bottom edge handling.
283 const int out_width =
static_cast<int>(
dst->info()->dimension(0));
284 const int out_height =
static_cast<int>(
dst->info()->dimension(1));
// Row strides (in elements) for rows 2 and 3 of the output tile;
// out_stride1 is declared outside this fragment.
287 const size_t out_stride2 = out_stride1 * 2;
288 const size_t out_stride3 = out_stride1 * 3;
// Broadcast alpha for vectorized scaling of the finished tile.
314 const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
// mtx_a0 walks the (interleaved) A block; mtx_b0/mtx_b1 walk two B rows,
// one in_b_stride apart.
323 auto mtx_a0 =
reinterpret_cast<const float *
>(ina.
ptr());
324 auto mtx_b0 =
reinterpret_cast<const float *
>(inb.
ptr());
325 auto mtx_b1 = mtx_b0 + in_b_stride;
// Accumulators for the 4x8 tile: acc{r}{c} holds output row r, column half c
// (c == 0 -> columns 0..3, c == 1 -> columns 4..7).
327 float32x4_t acc00 = vdupq_n_f32(0.f);
328 float32x4_t acc10 = vdupq_n_f32(0.f);
329 float32x4_t acc20 = vdupq_n_f32(0.f);
330 float32x4_t acc30 = vdupq_n_f32(0.f);
332 float32x4_t acc01 = vdupq_n_f32(0.f);
333 float32x4_t acc11 = vdupq_n_f32(0.f);
334 float32x4_t acc21 = vdupq_n_f32(0.f);
335 float32x4_t acc31 = vdupq_n_f32(0.f);
// Initial software prefetch of all three streams.
338 asm volatile(
"PLD [%0, #128*1]" ::
"r"(
reinterpret_cast<const uint8_t *
>(mtx_a0)));
339 asm volatile(
"PLD [%0, #128*1]" ::
"r"(
reinterpret_cast<const uint8_t *
>(mtx_b0)));
340 asm volatile(
"PLD [%0, #128*1]" ::
"r"(
reinterpret_cast<const uint8_t *
>(mtx_b1)));
// Main inner loop, unrolled 4x: each unroll step broadcasts 8 A values
// (a0..a7) and multiplies them against 8 values from each of the two B rows.
// NOTE(review): the pointer advances between unroll steps (presumably
// "mtx_a0 += 8; mtx_b0 += 8; mtx_b1 += 8;") are elided from this fragment —
// the loop condition requires mtx_b0 to advance by 32 per full iteration.
343 auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
344 for (; mtx_b0 <= (mtx_b0_end_addr - 32);)
// --- Unroll step 1 ---
// a0..a3: one broadcast A value per output row.
346 float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
347 float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
348 float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
349 float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
351 float32x4_t b00 = vld1q_f32(mtx_b0);
352 float32x4_t b10 = vld1q_f32(mtx_b1);
353 float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
354 float32x4_t b11 = vld1q_f32(mtx_b1 + 4);
// Prefetch ahead in all three streams.
357 asm volatile(
"PLD [%0, #128*4]" ::
"r"(
reinterpret_cast<const uint8_t *
>(mtx_a0)));
358 asm volatile(
"PLD [%0, #128*4]" ::
"r"(
reinterpret_cast<const uint8_t *
>(mtx_b0)));
359 asm volatile(
"PLD [%0, #128*4]" ::
"r"(
reinterpret_cast<const uint8_t *
>(mtx_b1)));
// Column half 0 gets b00 scaled by each row's A value.
363 acc00 = vmlaq_f32(acc00, b00, a0);
364 acc10 = vmlaq_f32(acc10, b00, a1);
365 acc20 = vmlaq_f32(acc20, b00, a2);
366 acc30 = vmlaq_f32(acc30, b00, a3);
// Second set of broadcast A values for the next 4-wide group.
368 float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
369 float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
370 float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
371 float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);
// Column half 1 gets b10 with the same a0..a3.
374 acc01 = vmlaq_f32(acc01, b10, a0);
375 acc11 = vmlaq_f32(acc11, b10, a1);
376 acc21 = vmlaq_f32(acc21, b10, a2);
377 acc31 = vmlaq_f32(acc31, b10, a3);
380 acc00 = vmlaq_f32(acc00, b01, a4);
381 acc10 = vmlaq_f32(acc10, b01, a5);
382 acc20 = vmlaq_f32(acc20, b01, a6);
383 acc30 = vmlaq_f32(acc30, b01, a7);
386 acc01 = vmlaq_f32(acc01, b11, a4);
387 acc11 = vmlaq_f32(acc11, b11, a5);
388 acc21 = vmlaq_f32(acc21, b11, a6);
389 acc31 = vmlaq_f32(acc31, b11, a7);
// --- Unroll step 2 (same pattern, reusing the a*/b* registers) ---
395 a0 = vld1q_dup_f32(mtx_a0 + 0);
396 a1 = vld1q_dup_f32(mtx_a0 + 1);
397 a2 = vld1q_dup_f32(mtx_a0 + 2);
398 a3 = vld1q_dup_f32(mtx_a0 + 3);
400 b00 = vld1q_f32(mtx_b0);
401 b10 = vld1q_f32(mtx_b1);
402 b01 = vld1q_f32(mtx_b0 + 4);
403 b11 = vld1q_f32(mtx_b1 + 4);
406 acc00 = vmlaq_f32(acc00, b00, a0);
407 acc10 = vmlaq_f32(acc10, b00, a1);
408 acc20 = vmlaq_f32(acc20, b00, a2);
409 acc30 = vmlaq_f32(acc30, b00, a3);
411 a4 = vld1q_dup_f32(mtx_a0 + 4);
412 a5 = vld1q_dup_f32(mtx_a0 + 5);
413 a6 = vld1q_dup_f32(mtx_a0 + 6);
414 a7 = vld1q_dup_f32(mtx_a0 + 7);
417 acc01 = vmlaq_f32(acc01, b10, a0);
418 acc11 = vmlaq_f32(acc11, b10, a1);
419 acc21 = vmlaq_f32(acc21, b10, a2);
420 acc31 = vmlaq_f32(acc31, b10, a3);
423 acc00 = vmlaq_f32(acc00, b01, a4);
424 acc10 = vmlaq_f32(acc10, b01, a5);
425 acc20 = vmlaq_f32(acc20, b01, a6);
426 acc30 = vmlaq_f32(acc30, b01, a7);
429 acc01 = vmlaq_f32(acc01, b11, a4);
430 acc11 = vmlaq_f32(acc11, b11, a5);
431 acc21 = vmlaq_f32(acc21, b11, a6);
432 acc31 = vmlaq_f32(acc31, b11, a7);
// --- Unroll step 3 (with an extra round of prefetches) ---
438 a0 = vld1q_dup_f32(mtx_a0 + 0);
439 a1 = vld1q_dup_f32(mtx_a0 + 1);
440 a2 = vld1q_dup_f32(mtx_a0 + 2);
441 a3 = vld1q_dup_f32(mtx_a0 + 3);
442 b00 = vld1q_f32(mtx_b0);
443 b10 = vld1q_f32(mtx_b1);
444 b01 = vld1q_f32(mtx_b0 + 4);
445 b11 = vld1q_f32(mtx_b1 + 4);
448 asm volatile(
"PLD [%0, #128*4]" ::
"r"(
reinterpret_cast<const uint8_t *
>(mtx_a0)));
449 asm volatile(
"PLD [%0, #128*4]" ::
"r"(
reinterpret_cast<const uint8_t *
>(mtx_b0)));
450 asm volatile(
"PLD [%0, #128*4]" ::
"r"(
reinterpret_cast<const uint8_t *
>(mtx_b1)));
454 acc00 = vmlaq_f32(acc00, b00, a0);
455 acc10 = vmlaq_f32(acc10, b00, a1);
456 acc20 = vmlaq_f32(acc20, b00, a2);
457 acc30 = vmlaq_f32(acc30, b00, a3);
459 a4 = vld1q_dup_f32(mtx_a0 + 4);
460 a5 = vld1q_dup_f32(mtx_a0 + 5);
461 a6 = vld1q_dup_f32(mtx_a0 + 6);
462 a7 = vld1q_dup_f32(mtx_a0 + 7);
465 acc01 = vmlaq_f32(acc01, b10, a0);
466 acc11 = vmlaq_f32(acc11, b10, a1);
467 acc21 = vmlaq_f32(acc21, b10, a2);
468 acc31 = vmlaq_f32(acc31, b10, a3);
471 acc00 = vmlaq_f32(acc00, b01, a4);
472 acc10 = vmlaq_f32(acc10, b01, a5);
473 acc20 = vmlaq_f32(acc20, b01, a6);
474 acc30 = vmlaq_f32(acc30, b01, a7);
477 acc01 = vmlaq_f32(acc01, b11, a4);
478 acc11 = vmlaq_f32(acc11, b11, a5);
479 acc21 = vmlaq_f32(acc21, b11, a6);
480 acc31 = vmlaq_f32(acc31, b11, a7);
// --- Unroll step 4 ---
486 a0 = vld1q_dup_f32(mtx_a0 + 0);
487 a1 = vld1q_dup_f32(mtx_a0 + 1);
488 a2 = vld1q_dup_f32(mtx_a0 + 2);
489 a3 = vld1q_dup_f32(mtx_a0 + 3);
490 b00 = vld1q_f32(mtx_b0);
491 b10 = vld1q_f32(mtx_b1);
492 b01 = vld1q_f32(mtx_b0 + 4);
493 b11 = vld1q_f32(mtx_b1 + 4);
496 acc00 = vmlaq_f32(acc00, b00, a0);
497 acc10 = vmlaq_f32(acc10, b00, a1);
498 acc20 = vmlaq_f32(acc20, b00, a2);
499 acc30 = vmlaq_f32(acc30, b00, a3);
501 a4 = vld1q_dup_f32(mtx_a0 + 4);
502 a5 = vld1q_dup_f32(mtx_a0 + 5);
503 a6 = vld1q_dup_f32(mtx_a0 + 6);
504 a7 = vld1q_dup_f32(mtx_a0 + 7);
507 acc01 = vmlaq_f32(acc01, b10, a0);
508 acc11 = vmlaq_f32(acc11, b10, a1);
509 acc21 = vmlaq_f32(acc21, b10, a2);
510 acc31 = vmlaq_f32(acc31, b10, a3);
513 acc00 = vmlaq_f32(acc00, b01, a4);
514 acc10 = vmlaq_f32(acc10, b01, a5);
515 acc20 = vmlaq_f32(acc20, b01, a6);
516 acc30 = vmlaq_f32(acc30, b01, a7);
519 acc01 = vmlaq_f32(acc01, b11, a4);
520 acc11 = vmlaq_f32(acc11, b11, a5);
521 acc21 = vmlaq_f32(acc21, b11, a6);
522 acc31 = vmlaq_f32(acc31, b11, a7);
// Leftover inner loop: one group of 4 broadcast A values against one 4-wide
// load per B row at a time. NOTE(review): the loop-body pointer advances
// (presumably "mtx_a0 += 4; mtx_b0 += 4; mtx_b1 += 4;") are elided from this
// fragment — the loop condition requires mtx_b0 to advance.
529 for (; mtx_b0 < mtx_b0_end_addr;)
531 float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
532 float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
533 float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
534 float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
535 float32x4_t b00 = vld1q_f32(mtx_b0);
536 float32x4_t b10 = vld1q_f32(mtx_b1);
// Smaller prefetch distance than the main loop (2 lines vs 4).
539 asm volatile(
"PLD [%0, #128*2]" ::
"r"(
reinterpret_cast<const uint8_t *
>(mtx_a0)));
540 asm volatile(
"PLD [%0, #128*2]" ::
"r"(
reinterpret_cast<const uint8_t *
>(mtx_b0)));
541 asm volatile(
"PLD [%0, #128*2]" ::
"r"(
reinterpret_cast<const uint8_t *
>(mtx_b1)));
544 acc00 = vmlaq_f32(acc00, b00, a0);
545 acc10 = vmlaq_f32(acc10, b00, a1);
546 acc20 = vmlaq_f32(acc20, b00, a2);
547 acc30 = vmlaq_f32(acc30, b00, a3);
550 acc01 = vmlaq_f32(acc01, b10, a0);
551 acc11 = vmlaq_f32(acc11, b10, a1);
552 acc21 = vmlaq_f32(acc21, b10, a2);
553 acc31 = vmlaq_f32(acc31, b10, a3);
// Scale the whole 4x8 tile by alpha.
// NOTE(review): in comparable kernels this is guarded by a "multiply_alpha"
// flag — no guard is visible in this fragment; confirm against the full file.
563 acc00 = vmulq_f32(acc00, alpha_f32);
564 acc10 = vmulq_f32(acc10, alpha_f32);
565 acc20 = vmulq_f32(acc20, alpha_f32);
566 acc30 = vmulq_f32(acc30, alpha_f32);
567 acc01 = vmulq_f32(acc01, alpha_f32);
568 acc11 = vmulq_f32(acc11, alpha_f32);
569 acc21 = vmulq_f32(acc21, alpha_f32);
570 acc31 = vmulq_f32(acc31, alpha_f32);
// mtx_out0/mtx_out1: destinations for the two 4-wide column halves of row 0;
// higher rows are reached via out_stride1..out_stride3.
573 const auto mtx_out0 =
reinterpret_cast<float *
>(out.
ptr());
574 const auto mtx_out1 = mtx_out0 + 4;
// Fast path: the full 8-wide tile fits inside the output width. Each row
// store is additionally guarded against the bottom edge (out_height).
576 if (
id.x() < (out_width - 8))
578 vst1q_f32(mtx_out0, acc00);
579 vst1q_f32(mtx_out1, acc01);
580 if (
id.y() + 1 < out_height)
582 vst1q_f32(mtx_out0 + out_stride1, acc10);
583 vst1q_f32(mtx_out1 + out_stride1, acc11);
584 if (
id.y() + 2 < out_height)
586 vst1q_f32(mtx_out0 + out_stride2, acc20);
587 vst1q_f32(mtx_out1 + out_stride2, acc21);
588 if (
id.y() + 3 < out_height)
590 vst1q_f32(mtx_out0 + out_stride3, acc30);
591 vst1q_f32(mtx_out1 + out_stride3, acc31);
// Right-edge path A: at least the first 4 columns fit. Store column half 0
// with full vectors, then write the surviving columns of half 1 one scalar
// at a time (acc01[x] etc. index individual vector lanes — a GCC/Clang
// vector-extension feature).
596 else if (
id.x() < (out_width - 4))
598 vst1q_f32(mtx_out0, acc00);
599 if (
id.y() + 1 < out_height)
601 vst1q_f32(mtx_out0 + out_stride1, acc10);
602 if (
id.y() + 2 < out_height)
604 vst1q_f32(mtx_out0 + out_stride2, acc20);
605 if (
id.y() + 3 < out_height)
607 vst1q_f32(mtx_out0 + out_stride3, acc30);
// Columns remaining in half 1 (tile columns 4..7).
612 const int columns_left = out_width -
id.x() - 4;
613 for (
auto x = 0; x < columns_left; ++x)
615 *(mtx_out1 + x) = acc01[x];
616 if (
id.y() + 1 < out_height)
618 *(mtx_out1 + x + out_stride1) = acc11[x];
619 if (
id.y() + 2 < out_height)
621 *(mtx_out1 + x + out_stride2) = acc21[x];
622 if (
id.y() + 3 < out_height)
624 *(mtx_out1 + x + out_stride3) = acc31[x];
// Right-edge path B: fewer than 4 columns remain — scalar stores from
// column half 0 only, again guarded per row against the bottom edge.
633 const int columns_left = out_width -
id.x();
634 for (
int x = 0; x < columns_left; ++x)
636 *(mtx_out0 + x) = acc00[x];
637 if (
id.y() + 1 < out_height)
639 *(mtx_out0 + x + out_stride1) = acc10[x];
640 if (
id.y() + 2 < out_height)
642 *(mtx_out0 + x + out_stride2) = acc20[x];
643 if (
id.y() + 3 < out_height)
645 *(mtx_out0 + x + out_stride3) = acc30[x];