24 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// vector_matrix_multiply_f16: FP16 (float16_t) product of the lhs vector with
// the rhs matrix, written to dst, with alpha applied to the result.
// NOTE(review): this listing is an excerpt -- braces, the multiply-accumulate
// steps, the bodies of the early-out guards and several declarations are not
// shown, so the comments below describe only the visible statements.
35 void vector_matrix_multiply_f16(
36 const ITensor *lhs,
const ITensor *rhs, ITensor *
dst,
const Window &window,
const ThreadInfo &
info,
float alpha)
// Output width (dst dimension 0), rhs row stride converted from bytes to
// elements, and the number of elements in the lhs vector.
38 const auto width_matrix_b =
static_cast<int>(
dst->info()->dimension(0));
39 const auto in_b_stride =
static_cast<int>(rhs->info()->strides_in_bytes()[1] / rhs->info()->element_size());
40 const auto num_elems_vec_a =
static_cast<int>(lhs->info()->dimension(0));
// Column partitioning: each thread starts at column 32 * thread_id and advances
// by 32 * num_threads, so threads take interleaved 32-column chunks.
// window_end_x rounds the remaining width up to a whole number of steps, so it
// can lie beyond width_matrix_b.
43 const int window_start_x = 32 *
info.thread_id;
44 const int window_step_x = 32 *
info.num_threads;
45 const int window_end_x =
ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
// Message of a consistency check on the column range (the check call itself is
// not visible in this excerpt).
47 " (window_end_x - window_start_x) must be multiple of window_step_x");
49 Window win_out(window);
// How rhs is sliced depends on whether it has 3+ dimensions (branch body not
// visible here).
60 if (rhs->info()->num_dimensions() >= 3)
67 Iterator ina(lhs, win_a);
68 Iterator inb(rhs, win_b);
69 Iterator out(
dst, win_out);
// alpha broadcast to all 8 fp16 lanes for vectorised scaling.
73 const float16x8_t alpha_f16 = vdupq_n_f16(alpha);
77 [&](
const Coordinates &)
79 int x = window_start_x;
// Vector path: 32 output columns per iteration. The strict `<` leaves the last
// (possibly partial) step to the scalar tail loop below.
82 for (; x < (window_end_x - window_step_x); x += window_step_x)
// Early-out guard on the column index (body not visible here).
84 if (x > width_matrix_b)
89 auto matrix_b =
reinterpret_cast<const float16_t *
>(inb.ptr()) + x;
// Four 8-lane accumulators cover the 32 output columns of this chunk.
91 float16x8_t acc0 = vdupq_n_f16(0.f);
92 float16x8_t acc1 = vdupq_n_f16(0.f);
93 float16x8_t acc2 = vdupq_n_f16(0.f);
94 float16x8_t acc3 = vdupq_n_f16(0.f);
96 auto vec_a =
reinterpret_cast<const float16_t *
>(ina.ptr());
97 const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
// Unrolled loop: each pass loads 4 lhs values (a0l) and walks 4 rhs rows.
// It does this in two halves, each loading 32 columns from 2 consecutive rows
// (offset in_b_stride) and then advancing matrix_b by 2 rows.
98 for (; vec_a <= (vec_a_end_addr - 4);)
100 const float16x4_t a0l = vld1_f16(vec_a);
102 float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
103 float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
104 float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
105 float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
106 float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
107 float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
108 float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
109 float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
120 matrix_b += 2 * in_b_stride;
122 b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
123 b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
124 b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
125 b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
126 b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
127 b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
128 b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
129 b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
141 matrix_b += 2 * in_b_stride;
// Leftover lhs elements (when num_elems_vec_a is not a multiple of 4): one lhs
// value and one 32-column rhs row per iteration.
144 for (; vec_a < vec_a_end_addr; ++vec_a)
146 const float16_t a0 = *vec_a;
147 const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
148 const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
149 const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
150 const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
157 matrix_b += in_b_stride;
// Store the 32 accumulated results starting at output column x.
169 auto vec_out =
reinterpret_cast<float16_t *
>(out.ptr()) + x;
171 vst1q_f16(vec_out + 0, acc0);
172 vst1q_f16(vec_out + 8, acc1);
173 vst1q_f16(vec_out + 16, acc2);
174 vst1q_f16(vec_out + 24, acc3);
// Scalar tail: the remaining columns up to window_end_x, one column at a time.
177 for (; x < window_end_x; ++x)
// NOTE(review): `>` lets x == width_matrix_b through this guard. Because
// window_end_x is rounded up, the tail loop does reach that value (e.g. width 40,
// one thread: x runs 32..63). Column width_matrix_b of rhs would then be read and
// dst would be written one past dimension(0), unless the tensors are padded.
// `x >= width_matrix_b` looks intended -- confirm (guard body not visible here).
179 if (x > width_matrix_b)
184 auto matrix_b =
reinterpret_cast<const float16_t *
>(inb.ptr()) + x;
// 4-lane partial dot product for this single column.
186 float16x4_t vacc = vdup_n_f16(0.f);
188 auto vec_a =
reinterpret_cast<const float16_t *
>(ina.ptr());
189 const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
190 for (; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
192 const float16x4_t a0l = vld1_f16(vec_a);
// Gather 4 elements of column x from 4 consecutive rhs rows (strided access).
194 const float16x4_t b_col = {
195 *(matrix_b + 0 * in_b_stride),
196 *(matrix_b + 1 * in_b_stride),
197 *(matrix_b + 2 * in_b_stride),
198 *(matrix_b + 3 * in_b_stride),
203 matrix_b += 4 * in_b_stride;
// Horizontal sum of the four vacc lanes (assignment target not visible here).
207 vget_lane_f16(vacc, 0) + vget_lane_f16(vacc, 1) + vget_lane_f16(vacc, 2) + vget_lane_f16(vacc, 3);
// Remaining lhs elements, one scalar multiply per rhs row.
209 for (; vec_a < vec_a_end_addr; ++vec_a)
211 const float16_t a0 = *vec_a;
212 const float16_t b00 = *matrix_b;
216 matrix_b += in_b_stride;
// Scale the scalar result by alpha (any enclosing condition is not visible).
222 acc *=
static_cast<float16_t
>(alpha);
225 auto vec_out =
reinterpret_cast<float16_t *
>(out.ptr()) + x;
// matrix_matrix_multiply_f16: FP16 (float16_t) matrix product of lhs and rhs
// written to dst, with alpha applied to the result. Each window step
// accumulates a 4-row x 8-column output tile in `c` (four float16x8_t rows)
// and stores it row by row at mtx_out.
// NOTE(review): this listing is an excerpt -- braces, the multiply-accumulate
// steps, pointer increments and some declarations (e.g. win_b, out_stride) are
// not shown, so these comments describe only the visible statements.
233 void matrix_matrix_multiply_f16(
234 const ITensor *lhs,
const ITensor *rhs, ITensor *
dst,
const Window &window,
const ThreadInfo &
info,
float alpha)
// Output extent and rhs row stride converted from bytes to elements.
237 const int out_width =
static_cast<int>(
dst->info()->dimension(0));
238 const int out_height =
static_cast<int>(
dst->info()->dimension(1));
239 const size_t in_b_stride = rhs->info()->strides_in_bytes()[1] /
data_size_from_type(rhs->info()->data_type());
// Number of elements in one rhs row; bounds the accumulation loops below.
241 const int num_elems_matrix_b_x = rhs->info()->dimension(0);
// lhs is walked at a quarter of the output Y resolution (at least one row):
// one lhs row per 4-row output tile. Presumably lhs is stored 4-row
// interleaved -- confirm against the producing reshape kernel.
244 Window win_a(window);
246 win_a.set(
Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));
// How rhs is sliced depends on whether it has 3+ dimensions (branch body not
// visible here).
251 if (rhs->info()->num_dimensions() >= 3)
// Every 8 output columns map to one rhs row (step in_b_stride). Presumably rhs
// is stored as transposed 1x8 blocks -- confirm.
256 win_b.set(
Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
259 Iterator ina(lhs, win_a);
260 Iterator inb(rhs, win_b);
261 Iterator out(
dst, window);
// alpha broadcast to all 8 fp16 lanes for vectorised scaling.
265 const float16x8_t alpha_f16 = vdupq_n_f16(alpha);
269 [&](
const Coordinates &
id)
271 const auto *mtx_a0 =
reinterpret_cast<const float16_t *
>(ina.ptr());
272 const auto *mtx_b0 =
reinterpret_cast<const float16_t *
>(inb.ptr());
273 auto *mtx_out =
reinterpret_cast<float16_t *
>(out.ptr());
// Zero-initialised accumulators for the 4 output rows of the current tile.
274 float16x8x4_t c = {{vdupq_n_f16(0.f), vdupq_n_f16(0.f), vdupq_n_f16(0.f), vdupq_n_f16(0.f)}};
304 const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
// Main loop: each pass loads 16 lhs values (p00, p02) and 32 rhs values
// (q00..q06).
306 for (; mtx_b0 <= (mtx_b0_end_addr - 32);)
309 const float16x8_t p00 = vld1q_f16(mtx_a0);
310 const float16x8_t p02 = vld1q_f16(mtx_a0 + 8);
312 const float16x8_t q00 = vld1q_f16(mtx_b0);
313 const float16x8_t q02 = vld1q_f16(mtx_b0 + 8);
314 const float16x8_t q04 = vld1q_f16(mtx_b0 + 16);
315 const float16x8_t q06 = vld1q_f16(mtx_b0 + 24);
// Leftover loop: each pass loads 4 lhs values and 8 rhs values.
341 for (; mtx_b0 < mtx_b0_end_addr;)
344 const float16x4_t p00 = vld1_f16(mtx_a0);
345 const float16x8_t q00 = vld1q_f16(mtx_b0);
// Scale all four tile rows by alpha (any enclosing condition is not visible).
358 c.val[0] =
vmulq_f16(c.val[0], alpha_f16);
359 c.val[1] =
vmulq_f16(c.val[1], alpha_f16);
360 c.val[2] =
vmulq_f16(c.val[2], alpha_f16);
361 c.val[3] =
vmulq_f16(c.val[3], alpha_f16);
// Vector store path: full 8-lane stores. Rows 1-3 are written only when they
// fall inside out_height.
// NOTE(review): a tile starting exactly at out_width - 8 also fits one vector
// store but fails this `<` test and takes the per-column path. The result is
// still correct, only slower; `<=` appears safe -- confirm.
364 if (
id.x() < (out_width - 8))
366 vst1q_f16(mtx_out, c.val[0]);
367 if (
id.y() + 1 < out_height)
369 vst1q_f16(mtx_out + 1 * out_stride, c.val[1]);
370 if (
id.y() + 2 < out_height)
372 vst1q_f16(mtx_out + 2 * out_stride, c.val[2]);
373 if (
id.y() + 3 < out_height)
375 vst1q_f16(mtx_out + 3 * out_stride, c.val[3]);
// Per-column store path: writes only the columns_left columns that remain
// inside out_width. Individual lanes are read with the compiler's
// vector-subscript extension (c.val[k][x]).
383 const int columns_left = out_width -
id.x();
384 for (
int x = 0; x < columns_left; ++x)
386 *(mtx_out + x) = c.val[0][x];
387 if (
id.y() + 1 < out_height)
389 *(mtx_out + x + 1 * out_stride) = c.val[1][x];
390 if (
id.y() + 2 < out_height)
392 *(mtx_out + x + 2 * out_stride) = c.val[2][x];
393 if (
id.y() + 3 < out_height)
395 *(mtx_out + x + 3 * out_stride) = c.val[3][x];
// Trailing parameters of the FP16 GEMM dispatcher. Its name and leading
// parameters (lhs, rhs, dst, alpha) are outside this excerpt.
408 const Window &window,
409 const ThreadInfo &
info,
411 const bool is_dst_vector)
// Dispatch: the vector x matrix kernel when dst is a vector, otherwise the
// general matrix x matrix kernel. All arguments, including alpha, are
// forwarded unchanged.
413 return (is_dst_vector) ? vector_matrix_multiply_f16(lhs, rhs,
dst, window,
info, alpha)
414 : matrix_matrix_multiply_f16(lhs, rhs,
dst, window,
info, alpha);
418 #endif //__ARM_FEATURE_FP16_VECTOR_ARITHMETIC