Compute Library 22.05
impl.cpp
1 /*
2  * Copyright (c) 2017-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
27 
28 #include <arm_neon.h>
29 
30 namespace arm_compute
31 {
32 namespace cpu
33 {
34 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
35 void vector_matrix_multiply_f16(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
36 {
37  const auto width_matrix_b = static_cast<int>(dst->info()->dimension(0));
38  const auto in_b_stride = static_cast<int>(rhs->info()->strides_in_bytes()[1] / rhs->info()->element_size());
39  const auto num_elems_vec_a = static_cast<int>(lhs->info()->dimension(0));
40 
41  // The implementation computes 32 elements per iteration
42  const int window_start_x = 32 * info.thread_id;
43  const int window_step_x = 32 * info.num_threads;
44  const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
45  ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x");
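 // Note: each thread processes disjoint 32-column chunks of the output in a round-robin fashion:
 // thread t handles the chunks starting at columns 32*t, 32*(t + num_threads), 32*(t + 2*num_threads), ...
 // window_end_x is rounded up so that the left-over loop below can cover the final, partially filled chunk.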
46 
47  Window win_out(window);
48  win_out.set(Window::DimX, Window::Dimension(0, 1, 1));
49  win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
50 
51  Window win_a(window);
52  win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
53  win_a.set(Window::DimY, Window::Dimension(0, 0, 0));
54 
55  Window win_b;
56  // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
57  // This scenario can happen when the matrix multiplication is used to perform a convolution operation
58  if(rhs->info()->num_dimensions() >= 3)
59  {
60  win_b = window;
61  }
62  win_b.set(Window::DimX, Window::Dimension(0, 1, 1));
63  win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
64 
65  Iterator ina(lhs, win_a);
66  Iterator inb(rhs, win_b);
67  Iterator out(dst, win_out);
68 
69  const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));
70 
71  const float16x8_t alpha_f16 = vdupq_n_f16(alpha);
72 
73  execute_window_loop(win_out, [&](const Coordinates &)
74  {
75  int x = window_start_x;
76  // The vectorized loop below deliberately stops one step early (x < window_end_x - window_step_x) because
77  // window_end_x is rounded up above, so the final step may run past width_matrix_b and is instead handled element-wise by the left-over loop to avoid out-of-bound writes to the dst.
78  for(; x < (window_end_x - window_step_x); x += window_step_x)
79  {
80  if(x > width_matrix_b)
81  {
82  return;
83  }
84 
85  auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;
86 
87  float16x8_t acc0 = vdupq_n_f16(0.f);
88  float16x8_t acc1 = vdupq_n_f16(0.f);
89  float16x8_t acc2 = vdupq_n_f16(0.f);
90  float16x8_t acc3 = vdupq_n_f16(0.f);
91 
92  auto vec_a = reinterpret_cast<const float16_t *>(ina.ptr());
93  const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
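 // Main loop: each iteration consumes four lhs (vector A) elements against four rows of a 32-column slice of rhs.
 // acc0..acc3 hold the 32 running column sums; vmulq_lane_f16 broadcasts one lane of a0l across eight rhs values at a time.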
94  for(; vec_a <= (vec_a_end_addr - 4);)
95  {
96  const float16x4_t a0l = vld1_f16(vec_a);
97 
98  float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
99  float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
100  float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
101  float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
102  float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
103  float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
104  float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
105  float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
106 
107  acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0));
108  acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0));
109  acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0));
110  acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0));
111  acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1));
112  acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1));
113  acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1));
114  acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1));
115 
116  matrix_b += 2 * in_b_stride;
117 
118  b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
119  b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
120  b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
121  b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
122  b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
123  b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
124  b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
125  b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
126 
127  acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2));
128  acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2));
129  acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2));
130  acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2));
131  acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3));
132  acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3));
133  acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3));
134  acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3));
135 
136  vec_a += 4;
137  matrix_b += 2 * in_b_stride;
138  }
139 
140  for(; vec_a < vec_a_end_addr; ++vec_a)
141  {
142  const float16_t a0 = *vec_a;
143  const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
144  const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
145  const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
146  const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
147 
148  acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0));
149  acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0));
150  acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0));
151  acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0));
152 
153  matrix_b += in_b_stride;
154  }
155 
156  // Multiply by the weight of matrix product (alpha)
157  if(multiply_alpha)
158  {
159  acc0 = vmulq_f16(acc0, alpha_f16);
160  acc1 = vmulq_f16(acc1, alpha_f16);
161  acc2 = vmulq_f16(acc2, alpha_f16);
162  acc3 = vmulq_f16(acc3, alpha_f16);
163  }
164 
165  auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;
166 
167  vst1q_f16(vec_out + 0, acc0);
168  vst1q_f16(vec_out + 8, acc1);
169  vst1q_f16(vec_out + 16, acc2);
170  vst1q_f16(vec_out + 24, acc3);
171  }
172 
173  for(; x < window_end_x; ++x)
174  {
175  if(x > width_matrix_b)
176  {
177  return;
178  }
179 
180  auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;
181 
182  float16x4_t vacc = vdup_n_f16(0.f);
183 
184  auto vec_a = reinterpret_cast<const float16_t *>(ina.ptr());
185  const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
186  for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
187  {
188  const float16x4_t a0l = vld1_f16(vec_a);
189 
190  const float16x4_t b_col =
191  {
192  *(matrix_b + 0 * in_b_stride),
193  *(matrix_b + 1 * in_b_stride),
194  *(matrix_b + 2 * in_b_stride),
195  *(matrix_b + 3 * in_b_stride),
196  };
197 
198  vacc = vadd_f16(vacc, vmul_f16(a0l, b_col));
199 
200  matrix_b += 4 * in_b_stride;
201  }
202 
203  float16_t acc = vget_lane_f16(vacc, 0) + vget_lane_f16(vacc, 1) + vget_lane_f16(vacc, 2) + vget_lane_f16(vacc, 3);
204 
205  for(; vec_a < vec_a_end_addr; ++vec_a)
206  {
207  const float16_t a0 = *vec_a;
208  const float16_t b00 = *matrix_b;
209 
210  acc += b00 * a0;
211 
212  matrix_b += in_b_stride;
213  }
214 
215  // Multiply by the weight of matrix product (alpha)
216  if(multiply_alpha)
217  {
218  acc *= static_cast<float16_t>(alpha);
219  }
220 
221  auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;
222 
223  *(vec_out) = acc;
224  }
225  },
226  ina, inb, out);
227 }
228 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
229 
230 void vector_matrix_multiply_f32(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
231 {
232  const auto width_matrix_b = static_cast<int>(dst->info()->dimension(0));
233  const auto in_b_stride = static_cast<int>(rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type()));
234  const auto num_elems_vec_a = static_cast<int>(lhs->info()->dimension(0));
235 
236  // The implementation computes 16 elements per iteration
237  const int window_start_x = 16 * info.thread_id;
238  const int window_step_x = 16 * info.num_threads;
239  // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
240  const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
241 
242  Window win_out(window);
243  win_out.set(Window::DimX, Window::Dimension(0, 1, 1));
244  win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
245 
246  Window win_a(window);
247  win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
248  win_a.set(Window::DimY, Window::Dimension(0, 0, 0));
249 
250  Window win_b;
251  // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
252  // This scenario can happen when the matrix multiplication is used to perform a convolution operation
253  if(rhs->info()->num_dimensions() >= 3)
254  {
255  win_b = window;
256  }
257  win_b.set(Window::DimX, Window::Dimension(0, 1, 1));
258  win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
259 
260  Iterator ina(lhs, win_a);
261  Iterator inb(rhs, win_b);
262  Iterator out(dst, win_out);
263 
264  const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));
265 
266  const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
267 
268  execute_window_loop(win_out, [&](const Coordinates &)
269  {
270  int x = window_start_x;
271  // The vectorized loop below deliberately stops one step early (x < window_end_x - window_step_x) because
272  // window_end_x is rounded up above, so the final step may run past width_matrix_b and is instead handled element-wise by the left-over loop to avoid out-of-bound writes to the dst.
273  for(; x < (window_end_x - window_step_x); x += window_step_x)
274  {
275  if(x > width_matrix_b)
276  {
277  return;
278  }
279 
280  float32x4_t acc0 = vdupq_n_f32(0.f);
281  float32x4_t acc1 = vdupq_n_f32(0.f);
282  float32x4_t acc2 = vdupq_n_f32(0.f);
283  float32x4_t acc3 = vdupq_n_f32(0.f);
284 
285  auto vec_a = reinterpret_cast<const float *>(ina.ptr());
286  auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;
287 
288 #if __arm__
289  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
290  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
291  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
292 #endif /* __arm__ */
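 // PLD is the AArch32 data-prefetch hint; these inline-asm blocks are only compiled when __arm__ is
 // defined (32-bit Arm builds) and are omitted on AArch64.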
293 
294  auto vec_a_end_addr = vec_a + num_elems_vec_a;
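 // Main loop: each iteration consumes four lhs elements (two float32x2 loads) against four rows of a
 // 16-column slice of rhs; acc0..acc3 hold the 16 running column sums.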
295  for(; vec_a <= (vec_a_end_addr - 4);)
296  {
297  float32x2_t a0l = vld1_f32(vec_a);
298 
299  float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
300  float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
301  float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
302  float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
303 
304  float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
305  float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
306  float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
307  float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);
308 
309 #if __arm__
310  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
311  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
312  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
313  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
314  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
315 #endif /* __arm__ */
316 
317  acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
318  acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
319  acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
320  acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);
321 
322  acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
323  acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
324  acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
325  acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);
326 
327  vec_a += 2;
328  matrix_b += 2 * in_b_stride;
329 
330  a0l = vld1_f32(vec_a);
331 
332  b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
333  b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
334  b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
335  b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
336 
337  b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
338  b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
339  b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
340  b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);
341 
342  acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
343  acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
344  acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
345  acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);
346 
347  acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
348  acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
349  acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
350  acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);
351 
352  vec_a += 2;
353  matrix_b += 2 * in_b_stride;
354  }
355 
356  for(; vec_a < vec_a_end_addr; ++vec_a)
357  {
358  const float a0 = *vec_a;
359 
360  const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
361  const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
362  const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
363  const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
364 
365  acc0 = vmlaq_n_f32(acc0, b00, a0);
366  acc1 = vmlaq_n_f32(acc1, b01, a0);
367  acc2 = vmlaq_n_f32(acc2, b02, a0);
368  acc3 = vmlaq_n_f32(acc3, b03, a0);
369 
370  matrix_b += in_b_stride;
371  }
372 
373  // Multiply by the weight of matrix product (alpha)
374  if(multiply_alpha)
375  {
376  acc0 = vmulq_f32(acc0, alpha_f32);
377  acc1 = vmulq_f32(acc1, alpha_f32);
378  acc2 = vmulq_f32(acc2, alpha_f32);
379  acc3 = vmulq_f32(acc3, alpha_f32);
380  }
381 
382  const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;
383 
384  vst1q_f32(vec_out + 0, acc0);
385  vst1q_f32(vec_out + 4, acc1);
386  vst1q_f32(vec_out + 8, acc2);
387  vst1q_f32(vec_out + 12, acc3);
388  }
389 
390  // Left-over loop
391  for(; x < window_end_x; ++x)
392  {
393  if(x > width_matrix_b)
394  {
395  return;
396  }
397 
398  float32x4_t vacc = vdupq_n_f32(0.f);
399 
400  auto vec_a = reinterpret_cast<const float *>(ina.ptr());
401  auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;
402 
403 #if __arm__
404  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
405  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
406  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
407 #endif /* __arm__ */
408 
409  auto vec_a_end_addr = vec_a + num_elems_vec_a;
410  for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
411  {
412  const float32x4_t a0l = vld1q_f32(vec_a);
413 
414  const float32x4_t b_col =
415  {
416  *(matrix_b + 0 * in_b_stride),
417  *(matrix_b + 1 * in_b_stride),
418  *(matrix_b + 2 * in_b_stride),
419  *(matrix_b + 3 * in_b_stride),
420  };
421 
422 #if __arm__
423  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
424  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
425  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
426  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
427  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
428 #endif /* __arm__ */
429 
430  vacc = vmlaq_f32(vacc, b_col, a0l);
431 
432  matrix_b += 4 * in_b_stride;
433  }
434 
435  float acc = vgetq_lane_f32(vacc, 0) + vgetq_lane_f32(vacc, 1) + vgetq_lane_f32(vacc, 2) + vgetq_lane_f32(vacc, 3);
436 
437  for(; vec_a < vec_a_end_addr; ++vec_a)
438  {
439  const float a0 = *vec_a;
440 
441  const float b00 = *matrix_b;
442 
443  acc += b00 * a0;
444 
445  matrix_b += in_b_stride;
446  }
447 
448  // Multiply by the weight of matrix product (alpha)
449  if(multiply_alpha)
450  {
451  acc *= alpha;
452  }
453 
454  const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;
455 
456  *vec_out = acc;
457  }
458  },
459  ina, inb, out);
460 }
461 
462 void matrix_matrix_multiply_f32(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
463 {
464  ARM_COMPUTE_UNUSED(info);
465  const int out_width = static_cast<int>(dst->info()->dimension(0));
466  const int out_height = static_cast<int>(dst->info()->dimension(1));
467  const size_t in_b_stride = rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type());
468  const size_t out_stride1 = dst->info()->strides_in_bytes()[1] / data_size_from_type(dst->info()->data_type());
469  const size_t out_stride2 = out_stride1 * 2;
470  const size_t out_stride3 = out_stride1 * 3;
471  const int num_elems_matrix_b_x = rhs->info()->dimension(0);
472 
473  // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the interleaved input matrix A has 4 times fewer rows than the dst matrix
474  Window win_a(window);
475  win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
476  win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));
477 
478  Window win_b;
479  // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
480  // This scenario can happen when the matrix multiplication is used to perform a convolution operation
481  if(rhs->info()->num_dimensions() >= 3)
482  {
483  win_b = window;
484  }
485  // Set step_x and step_y for matrix B. Scale the X range by a factor of 4 as the transposed input matrix B has 4 times fewer columns than the dst matrix
486  // The step along the x direction is 2 times the in_b_stride because for each iteration we compute 2 blocks of size 4x4
487  win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 2 * in_b_stride));
488  win_b.set(Window::DimY, Window::Dimension(0, 0, 0));
489 
490  Iterator ina(lhs, win_a);
491  Iterator inb(rhs, win_b);
492  Iterator out(dst, window);
493 
494  const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));
495 
496  const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
497 
498  // The implementation assumes that matrix A and matrix B have been reshaped with CpuGemmInterleave4x4 and CpuGemmTranspose1xW respectively
499  // Reshaping the matrices gives a cache-friendly implementation and avoids the data re-arrangements needed for computing 16x4 elements per iteration
500  // All the values needed for computing a single 4x4 block will be read from consecutive memory positions
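 // Each window iteration accumulates two adjacent 4x4 blocks of the dst: the interleaved values of A
 // (a0..a3, one per output row) are broadcast with vld1q_dup_f32 and multiplied against 4-wide slices of
 // the two transposed-B streams (mtx_b0 and mtx_b1), accumulating into acc*0 for the first block and acc*1 for the second.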
501  execute_window_loop(window, [&](const Coordinates & id)
502  {
503  auto mtx_a0 = reinterpret_cast<const float *>(ina.ptr());
504  auto mtx_b0 = reinterpret_cast<const float *>(inb.ptr());
505  auto mtx_b1 = mtx_b0 + in_b_stride;
506 
507  float32x4_t acc00 = vdupq_n_f32(0.f);
508  float32x4_t acc10 = vdupq_n_f32(0.f);
509  float32x4_t acc20 = vdupq_n_f32(0.f);
510  float32x4_t acc30 = vdupq_n_f32(0.f);
511 
512  float32x4_t acc01 = vdupq_n_f32(0.f);
513  float32x4_t acc11 = vdupq_n_f32(0.f);
514  float32x4_t acc21 = vdupq_n_f32(0.f);
515  float32x4_t acc31 = vdupq_n_f32(0.f);
516 
517 #if __arm__
518  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
519  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
520  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
521 #endif /* __arm__ */
522 
523  auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
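 // The k-loop is unrolled by a factor of four: each pass consumes 32 interleaved A values and 32 values
 // from each of the two B streams; the loop that follows handles any remaining groups of four.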
524  for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
525  {
526  float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
527  float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
528  float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
529  float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
530 
531  float32x4_t b00 = vld1q_f32(mtx_b0);
532  float32x4_t b10 = vld1q_f32(mtx_b1);
533  float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
534  float32x4_t b11 = vld1q_f32(mtx_b1 + 4);
535 
536 #if __arm__
537  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
538  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
539  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
540 #endif /* __arm__ */
541 
542  // 4x4 block 0
543  acc00 = vmlaq_f32(acc00, b00, a0);
544  acc10 = vmlaq_f32(acc10, b00, a1);
545  acc20 = vmlaq_f32(acc20, b00, a2);
546  acc30 = vmlaq_f32(acc30, b00, a3);
547 
548  float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
549  float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
550  float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
551  float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);
552 
553  // 4x4 block 1
554  acc01 = vmlaq_f32(acc01, b10, a0);
555  acc11 = vmlaq_f32(acc11, b10, a1);
556  acc21 = vmlaq_f32(acc21, b10, a2);
557  acc31 = vmlaq_f32(acc31, b10, a3);
558 
559  // 4x4 block 0
560  acc00 = vmlaq_f32(acc00, b01, a4);
561  acc10 = vmlaq_f32(acc10, b01, a5);
562  acc20 = vmlaq_f32(acc20, b01, a6);
563  acc30 = vmlaq_f32(acc30, b01, a7);
564 
565  // 4x4 block 1
566  acc01 = vmlaq_f32(acc01, b11, a4);
567  acc11 = vmlaq_f32(acc11, b11, a5);
568  acc21 = vmlaq_f32(acc21, b11, a6);
569  acc31 = vmlaq_f32(acc31, b11, a7);
570 
571  mtx_a0 += 8;
572  mtx_b0 += 8;
573  mtx_b1 += 8;
574 
575  a0 = vld1q_dup_f32(mtx_a0 + 0);
576  a1 = vld1q_dup_f32(mtx_a0 + 1);
577  a2 = vld1q_dup_f32(mtx_a0 + 2);
578  a3 = vld1q_dup_f32(mtx_a0 + 3);
579 
580  b00 = vld1q_f32(mtx_b0);
581  b10 = vld1q_f32(mtx_b1);
582  b01 = vld1q_f32(mtx_b0 + 4);
583  b11 = vld1q_f32(mtx_b1 + 4);
584 
585  // 4x4 block 0
586  acc00 = vmlaq_f32(acc00, b00, a0);
587  acc10 = vmlaq_f32(acc10, b00, a1);
588  acc20 = vmlaq_f32(acc20, b00, a2);
589  acc30 = vmlaq_f32(acc30, b00, a3);
590 
591  a4 = vld1q_dup_f32(mtx_a0 + 4);
592  a5 = vld1q_dup_f32(mtx_a0 + 5);
593  a6 = vld1q_dup_f32(mtx_a0 + 6);
594  a7 = vld1q_dup_f32(mtx_a0 + 7);
595 
596  // 4x4 block 1
597  acc01 = vmlaq_f32(acc01, b10, a0);
598  acc11 = vmlaq_f32(acc11, b10, a1);
599  acc21 = vmlaq_f32(acc21, b10, a2);
600  acc31 = vmlaq_f32(acc31, b10, a3);
601 
602  // 4x4 block 0
603  acc00 = vmlaq_f32(acc00, b01, a4);
604  acc10 = vmlaq_f32(acc10, b01, a5);
605  acc20 = vmlaq_f32(acc20, b01, a6);
606  acc30 = vmlaq_f32(acc30, b01, a7);
607 
608  // 4x4 block 1
609  acc01 = vmlaq_f32(acc01, b11, a4);
610  acc11 = vmlaq_f32(acc11, b11, a5);
611  acc21 = vmlaq_f32(acc21, b11, a6);
612  acc31 = vmlaq_f32(acc31, b11, a7);
613 
614  mtx_a0 += 8;
615  mtx_b0 += 8;
616  mtx_b1 += 8;
617 
618  a0 = vld1q_dup_f32(mtx_a0 + 0);
619  a1 = vld1q_dup_f32(mtx_a0 + 1);
620  a2 = vld1q_dup_f32(mtx_a0 + 2);
621  a3 = vld1q_dup_f32(mtx_a0 + 3);
622  b00 = vld1q_f32(mtx_b0);
623  b10 = vld1q_f32(mtx_b1);
624  b01 = vld1q_f32(mtx_b0 + 4);
625  b11 = vld1q_f32(mtx_b1 + 4);
626 
627 #if __arm__
628  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
629  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
630  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
631 #endif /* __arm__ */
632 
633  // 4x4 block 0
634  acc00 = vmlaq_f32(acc00, b00, a0);
635  acc10 = vmlaq_f32(acc10, b00, a1);
636  acc20 = vmlaq_f32(acc20, b00, a2);
637  acc30 = vmlaq_f32(acc30, b00, a3);
638 
639  a4 = vld1q_dup_f32(mtx_a0 + 4);
640  a5 = vld1q_dup_f32(mtx_a0 + 5);
641  a6 = vld1q_dup_f32(mtx_a0 + 6);
642  a7 = vld1q_dup_f32(mtx_a0 + 7);
643 
644  // 4x4 block 1
645  acc01 = vmlaq_f32(acc01, b10, a0);
646  acc11 = vmlaq_f32(acc11, b10, a1);
647  acc21 = vmlaq_f32(acc21, b10, a2);
648  acc31 = vmlaq_f32(acc31, b10, a3);
649 
650  // 4x4 block 0
651  acc00 = vmlaq_f32(acc00, b01, a4);
652  acc10 = vmlaq_f32(acc10, b01, a5);
653  acc20 = vmlaq_f32(acc20, b01, a6);
654  acc30 = vmlaq_f32(acc30, b01, a7);
655 
656  // 4x4 block 1
657  acc01 = vmlaq_f32(acc01, b11, a4);
658  acc11 = vmlaq_f32(acc11, b11, a5);
659  acc21 = vmlaq_f32(acc21, b11, a6);
660  acc31 = vmlaq_f32(acc31, b11, a7);
661 
662  mtx_a0 += 8;
663  mtx_b0 += 8;
664  mtx_b1 += 8;
665 
666  a0 = vld1q_dup_f32(mtx_a0 + 0);
667  a1 = vld1q_dup_f32(mtx_a0 + 1);
668  a2 = vld1q_dup_f32(mtx_a0 + 2);
669  a3 = vld1q_dup_f32(mtx_a0 + 3);
670  b00 = vld1q_f32(mtx_b0);
671  b10 = vld1q_f32(mtx_b1);
672  b01 = vld1q_f32(mtx_b0 + 4);
673  b11 = vld1q_f32(mtx_b1 + 4);
674 
675  // 4x4 block 0
676  acc00 = vmlaq_f32(acc00, b00, a0);
677  acc10 = vmlaq_f32(acc10, b00, a1);
678  acc20 = vmlaq_f32(acc20, b00, a2);
679  acc30 = vmlaq_f32(acc30, b00, a3);
680 
681  a4 = vld1q_dup_f32(mtx_a0 + 4);
682  a5 = vld1q_dup_f32(mtx_a0 + 5);
683  a6 = vld1q_dup_f32(mtx_a0 + 6);
684  a7 = vld1q_dup_f32(mtx_a0 + 7);
685 
686  // 4x4 block 1
687  acc01 = vmlaq_f32(acc01, b10, a0);
688  acc11 = vmlaq_f32(acc11, b10, a1);
689  acc21 = vmlaq_f32(acc21, b10, a2);
690  acc31 = vmlaq_f32(acc31, b10, a3);
691 
692  // 4x4 block 0
693  acc00 = vmlaq_f32(acc00, b01, a4);
694  acc10 = vmlaq_f32(acc10, b01, a5);
695  acc20 = vmlaq_f32(acc20, b01, a6);
696  acc30 = vmlaq_f32(acc30, b01, a7);
697 
698  // 4x4 block 1
699  acc01 = vmlaq_f32(acc01, b11, a4);
700  acc11 = vmlaq_f32(acc11, b11, a5);
701  acc21 = vmlaq_f32(acc21, b11, a6);
702  acc31 = vmlaq_f32(acc31, b11, a7);
703 
704  mtx_a0 += 8;
705  mtx_b0 += 8;
706  mtx_b1 += 8;
707  }
708 
709  for(; mtx_b0 < mtx_b0_end_addr;)
710  {
711  float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
712  float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
713  float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
714  float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
715  float32x4_t b00 = vld1q_f32(mtx_b0);
716  float32x4_t b10 = vld1q_f32(mtx_b1);
717 
718 #if __arm__
719  asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
720  asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
721  asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
722 #endif /* __arm__ */
723  // 4x4 block 0
724  acc00 = vmlaq_f32(acc00, b00, a0);
725  acc10 = vmlaq_f32(acc10, b00, a1);
726  acc20 = vmlaq_f32(acc20, b00, a2);
727  acc30 = vmlaq_f32(acc30, b00, a3);
728 
729  // 4x4 block 1
730  acc01 = vmlaq_f32(acc01, b10, a0);
731  acc11 = vmlaq_f32(acc11, b10, a1);
732  acc21 = vmlaq_f32(acc21, b10, a2);
733  acc31 = vmlaq_f32(acc31, b10, a3);
734 
735  mtx_a0 += 4;
736  mtx_b0 += 4;
737  mtx_b1 += 4;
738  }
739 
740  // Multiply by the weight of matrix product (alpha)
741  if(multiply_alpha)
742  {
743  acc00 = vmulq_f32(acc00, alpha_f32);
744  acc10 = vmulq_f32(acc10, alpha_f32);
745  acc20 = vmulq_f32(acc20, alpha_f32);
746  acc30 = vmulq_f32(acc30, alpha_f32);
747  acc01 = vmulq_f32(acc01, alpha_f32);
748  acc11 = vmulq_f32(acc11, alpha_f32);
749  acc21 = vmulq_f32(acc21, alpha_f32);
750  acc31 = vmulq_f32(acc31, alpha_f32);
751  }
752 
753  const auto mtx_out0 = reinterpret_cast<float *>(out.ptr());
754  const auto mtx_out1 = mtx_out0 + 4;
755 
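 // Boundary-aware stores: when more than 8 columns remain, both 4x4 blocks are stored with full vector
 // stores; when more than 4 remain, the first block is stored vectorised and the second element-wise;
 // otherwise everything is stored element-wise. Row stores past out_height are skipped at every level.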
756  if(id.x() < (out_width - 8))
757  {
758  vst1q_f32(mtx_out0, acc00);
759  vst1q_f32(mtx_out1, acc01);
760  if(id.y() + 1 < out_height)
761  {
762  vst1q_f32(mtx_out0 + out_stride1, acc10);
763  vst1q_f32(mtx_out1 + out_stride1, acc11);
764  if(id.y() + 2 < out_height)
765  {
766  vst1q_f32(mtx_out0 + out_stride2, acc20);
767  vst1q_f32(mtx_out1 + out_stride2, acc21);
768  if(id.y() + 3 < out_height)
769  {
770  vst1q_f32(mtx_out0 + out_stride3, acc30);
771  vst1q_f32(mtx_out1 + out_stride3, acc31);
772  }
773  }
774  }
775  }
776  else if(id.x() < (out_width - 4))
777  {
778  vst1q_f32(mtx_out0, acc00);
779  if(id.y() + 1 < out_height)
780  {
781  vst1q_f32(mtx_out0 + out_stride1, acc10);
782  if(id.y() + 2 < out_height)
783  {
784  vst1q_f32(mtx_out0 + out_stride2, acc20);
785  if(id.y() + 3 < out_height)
786  {
787  vst1q_f32(mtx_out0 + out_stride3, acc30);
788  }
789  }
790  }
791  // Left-over columns
792  const int columns_left = out_width - id.x() - 4;
793  for(auto x = 0; x < columns_left; ++x)
794  {
795  *(mtx_out1 + x) = acc01[x];
796  if(id.y() + 1 < out_height)
797  {
798  *(mtx_out1 + x + out_stride1) = acc11[x];
799  if(id.y() + 2 < out_height)
800  {
801  *(mtx_out1 + x + out_stride2) = acc21[x];
802  if(id.y() + 3 < out_height)
803  {
804  *(mtx_out1 + x + out_stride3) = acc31[x];
805  }
806  }
807  }
808  }
809  }
810  else
811  {
812  // Left-over columns
813  const int columns_left = out_width - id.x();
814  for(int x = 0; x < columns_left; ++x)
815  {
816  *(mtx_out0 + x) = acc00[x];
817  if(id.y() + 1 < out_height)
818  {
819  *(mtx_out0 + x + out_stride1) = acc10[x];
820  if(id.y() + 2 < out_height)
821  {
822  *(mtx_out0 + x + out_stride2) = acc20[x];
823  if(id.y() + 3 < out_height)
824  {
825  *(mtx_out0 + x + out_stride3) = acc30[x];
826  }
827  }
828  }
829  }
830  }
831  },
832  ina, inb, out);
833 }
834 
835 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
836 void matrix_matrix_multiply_f16(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
837 {
838  ARM_COMPUTE_UNUSED(info);
839  const int out_width = static_cast<int>(dst->info()->dimension(0));
840  const int out_height = static_cast<int>(dst->info()->dimension(1));
841  const size_t in_b_stride = rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type());
842  const size_t out_stride = dst->info()->strides_in_bytes()[1] / data_size_from_type(dst->info()->data_type());
843  const int num_elems_matrix_b_x = rhs->info()->dimension(0);
844 
845  // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the interleaved input matrix A has 4 times fewer rows than the dst matrix
846  Window win_a(window);
847  win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
848  win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));
849 
850  Window win_b;
851  // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
852  // This scenario can happen when the matrix multiplication is used to perform a convolution operation
853  if(rhs->info()->num_dimensions() >= 3)
854  {
855  win_b = window;
856  }
857  // Set step_x and step_y for matrix B. Scale the X range by a factor of 8 as the transposed input matrix B has 8 times fewer columns than the dst matrix
858  win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
859  win_b.set(Window::DimY, Window::Dimension(0, 1, 0));
860 
861  Iterator ina(lhs, win_a);
862  Iterator inb(rhs, win_b);
863  Iterator out(dst, window);
864 
865  const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));
866 
867  const float16x8_t alpha_f16 = vdupq_n_f16(alpha);
868 
869  execute_window_loop(window, [&](const Coordinates & id)
870  {
871  const auto *mtx_a0 = reinterpret_cast<const float16_t *>(ina.ptr());
872  const auto *mtx_b0 = reinterpret_cast<const float16_t *>(inb.ptr());
873  auto *mtx_out = reinterpret_cast<float16_t *>(out.ptr());
874  float16x8x4_t c =
875  {
876  {
877  vdupq_n_f16(0.f),
878  vdupq_n_f16(0.f),
879  vdupq_n_f16(0.f),
880  vdupq_n_f16(0.f)
881  }
882  };
883 
884  /*
885  This kernel puts the values in a 4x4 block of Matrix A on the same row (Interleaved values)
886  |a00 a01 a02 a03 | a04 a05 a06 a07|
887  |a10 a11 a12 a13 | a14 a15 a16 a17|
888  |a20 a21 a22 a23 | a24 a25 a26 a27| = | a00 a10 a20 a30 || a01 a11 a21 a31 || a02 a12 a22 a32 || a03 a13 a23 a33 | a40 a50 a60 a70 | ...
889  |a30 a31 a32 a33 | a34 a35 a36 a37| | a04 a14 a24 a34 || a05 a15 a25 a35 || a06 a16 a26 a36 || a07 a17 a27 a37 | a44 a54 a64 a74 | ...
890  |a40 a41 a42 a43 | a44 a45 a46 a47|
891  |a50 a51 a52 a53 | a54 a55 a56 a57|
892  |a60 a61 a62 a63 | a64 a65 a66 a67|
893  |a70 a71 a72 a73 | a74 a75 a76 a77|
894 
895  After this operation, the dst matrix will have the following shape: [ height * 4, width / 4 ]
896 
897  B Matrix has been transposed as shown below
898 
899  |b00 b01 b02 b03 b04 b05 b06 b07|
900  |b10 b11 b12 b13 b14 b15 b16 b17|
901  |b20 b21 b22 b23 b24 b25 b26 b27|
902  |b30 b31 b32 b33 b34 b35 b36 b37|
903  ------------------->
904 
905  |b00 b01 b02 b03 b04 b05 b06 b07||b10 b11 b12 b13 b14 b15 b16 b17||b20 b21 b22 b23 b24 b25 b26 b27||b30 b31 b32 b33 b34 b35 b36 b37|
906 
907  c.val[0][0] = a00*b00 + a01*b10 + a02*b20 + a03*b30
908  c.val[0][1] = a00*b01 + a01*b11 + a02*b21 + a03*b31
909 
910  The dst tensor's XY-plane must have the following shape [ width * 8, height / 8 ]. All other dimensions must have the same size.
911  */
912  const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
913 
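 // Each iteration of the loop below accumulates four k-steps of a 4x8 dst block: lanes 0..3 of p00 hold
 // a(0..3, k), lanes 4..7 hold a(0..3, k+1) (and likewise p02 for k+2 and k+3), while q00/q02/q04/q06
 // hold eight columns of the transposed-B rows k..k+3.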
914  for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
915 
916  {
917  const float16x8_t p00 = vld1q_f16(mtx_a0);
918  const float16x8_t p02 = vld1q_f16(mtx_a0 + 8);
919 
920  const float16x8_t q00 = vld1q_f16(mtx_b0);
921  const float16x8_t q02 = vld1q_f16(mtx_b0 + 8);
922  const float16x8_t q04 = vld1q_f16(mtx_b0 + 16);
923  const float16x8_t q06 = vld1q_f16(mtx_b0 + 24);
924 
925  c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vgetq_lane_f16(p00, 0)));
926  c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vgetq_lane_f16(p00, 1)));
927  c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vgetq_lane_f16(p00, 2)));
928  c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vgetq_lane_f16(p00, 3)));
929 
930  c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q02, vgetq_lane_f16(p00, 4)));
931  c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q02, vgetq_lane_f16(p00, 5)));
932  c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q02, vgetq_lane_f16(p00, 6)));
933  c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q02, vgetq_lane_f16(p00, 7)));
934 
935  c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q04, vgetq_lane_f16(p02, 0)));
936  c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q04, vgetq_lane_f16(p02, 1)));
937  c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q04, vgetq_lane_f16(p02, 2)));
938  c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q04, vgetq_lane_f16(p02, 3)));
939 
940  c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q06, vgetq_lane_f16(p02, 4)));
941  c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q06, vgetq_lane_f16(p02, 5)));
942  c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q06, vgetq_lane_f16(p02, 6)));
943  c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q06, vgetq_lane_f16(p02, 7)));
944 
945  mtx_a0 += 16;
946  mtx_b0 += 32;
947  }
948 
949  for(; mtx_b0 < mtx_b0_end_addr;)
950 
951  {
952  const float16x4_t p00 = vld1_f16(mtx_a0);
953  const float16x8_t q00 = vld1q_f16(mtx_b0);
954 
955  c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vget_lane_f16(p00, 0)));
956  c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vget_lane_f16(p00, 1)));
957  c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vget_lane_f16(p00, 2)));
958  c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vget_lane_f16(p00, 3)));
959 
960  mtx_a0 += 4;
961  mtx_b0 += 8;
962  }
963 
964  if(multiply_alpha)
965  {
966  c.val[0] = vmulq_f16(c.val[0], alpha_f16);
967  c.val[1] = vmulq_f16(c.val[1], alpha_f16);
968  c.val[2] = vmulq_f16(c.val[2], alpha_f16);
969  c.val[3] = vmulq_f16(c.val[3], alpha_f16);
970  }
971 
972  if(id.x() < (out_width - 8))
973  {
974  vst1q_f16(mtx_out, c.val[0]);
975  if(id.y() + 1 < out_height)
976  {
977  vst1q_f16(mtx_out + 1 * out_stride, c.val[1]);
978  if(id.y() + 2 < out_height)
979  {
980  vst1q_f16(mtx_out + 2 * out_stride, c.val[2]);
981  if(id.y() + 3 < out_height)
982  {
983  vst1q_f16(mtx_out + 3 * out_stride, c.val[3]);
984  }
985  }
986  }
987  }
988  else
989  {
990  // Left-over columns
991  const int columns_left = out_width - id.x();
992  for(int x = 0; x < columns_left; ++x)
993  {
994  *(mtx_out + x) = c.val[0][x];
995  if(id.y() + 1 < out_height)
996  {
997  *(mtx_out + x + 1 * out_stride) = c.val[1][x];
998  if(id.y() + 2 < out_height)
999  {
1000  *(mtx_out + x + 2 * out_stride) = c.val[2][x];
1001  if(id.y() + 3 < out_height)
1002  {
1003  *(mtx_out + x + 3 * out_stride) = c.val[3][x];
1004  }
1005  }
1006  }
1007  }
1008  }
1009  },
1010  ina, inb, out);
1011 }
1012 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
1013 
1014 } // namespace cpu
1015 
1016 } // namespace arm_compute
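All four kernels in this file compute dst = alpha * (lhs x rhs); they differ only in element type, vector width, and the reshaping they expect on their operands (interleaved A and transposed B for the matrix-matrix variants). As an editorial aid (not part of impl.cpp), the plain scalar reference below illustrates that contract for row-major, unreshaped operands and can serve as a ground truth when validating the NEON paths on small sizes; the function name and signature are illustrative only.

// Editorial sketch (not from the library): scalar reference for dst = alpha * lhs * rhs.
// lhs is M x K, rhs is K x N, dst is M x N, all row-major and densely packed;
// no Window/Iterator plumbing, multithreading or operand reshaping is modelled.
#include <cstddef>

static void gemm_reference_f32(const float *lhs, const float *rhs, float *dst,
                               std::size_t M, std::size_t N, std::size_t K, float alpha)
{
    for(std::size_t m = 0; m < M; ++m)
    {
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(std::size_t k = 0; k < K; ++k)
            {
                acc += lhs[m * K + k] * rhs[k * N + n];
            }
            dst[m * N + n] = alpha * acc;
        }
    }
}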