Compute Library
 21.02
NEGEMMMatrixMultiplyKernel.cpp
1 /*
2  * Copyright (c) 2017-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
26 #include "arm_compute/core/Error.h"
30 #include "arm_compute/core/Types.h"
31 #include "arm_compute/core/Utils.h"
36 #include "src/core/CPP/Validate.h"
41 
42 #include <arm_neon.h>
43 
44 namespace arm_compute
45 {
46 namespace
47 {
48 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
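// Vector-by-matrix product for FP16: computes output = alpha * (A * B), where A is a single row
// (a vector) held in input0 and B is the non-reshaped matrix in input1. Each thread owns a
// disjoint, interleaved set of 32-wide column blocks of B, selected via window_start_x below.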
49 void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
50 {
51  const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
52  const auto in_b_stride = static_cast<int>(input1->info()->strides_in_bytes()[1] / input1->info()->element_size());
53  const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));
54 
55  // The implementation computes 32 elements per iteration
56  const int window_start_x = 32 * info.thread_id;
57  const int window_step_x = 32 * info.num_threads;
58  const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
59  ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be a multiple of window_step_x");
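// For example, with 2 threads: thread 0 starts at column 0 and thread 1 at column 32, and both
// advance by 64 columns per step, so the 32-wide column blocks of B are interleaved across threads.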
60 
61  Window win_out(window);
62  win_out.set(Window::DimX, Window::Dimension(0, 1, 1));
63  win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
64 
65  Window win_a(window);
66  win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
67  win_a.set(Window::DimY, Window::Dimension(0, 0, 0));
68 
69  Window win_b;
70  // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
71  // This scenario can happen when the matrix multiplication is used to perform a convolution operation
72  if(input1->info()->num_dimensions() >= 3)
73  {
74  win_b = window;
75  }
76  win_b.set(Window::DimX, Window::Dimension(0, 1, 1));
77  win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
78 
79  Iterator ina(input0, win_a);
80  Iterator inb(input1, win_b);
81  Iterator out(output, win_out);
82 
83  const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));
84 
85  const float16x8_t alpha_f16 = vdupq_n_f16(alpha);
86 
87  execute_window_loop(win_out, [&](const Coordinates &)
88  {
89  int x = window_start_x;
90  // The vector loop stops at (window_end_x - window_step_x) rather than window_end_x because window_end_x
91  // is rounded up above; processing the last full step here could write out of bounds, so it is handled by the left-over loop below.
92  for(; x < (window_end_x - window_step_x); x += window_step_x)
93  {
94  if(x > width_matrix_b)
95  {
96  return;
97  }
98 
99  auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;
100 
101  float16x8_t acc0 = vdupq_n_f16(0.f);
102  float16x8_t acc1 = vdupq_n_f16(0.f);
103  float16x8_t acc2 = vdupq_n_f16(0.f);
104  float16x8_t acc3 = vdupq_n_f16(0.f);
105 
106  auto vec_a = reinterpret_cast<const float16_t *>(ina.ptr());
107  const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
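// Main accumulation loop: consume 4 FP16 elements of A per pass, each broadcast against a full
// row of B (32 columns kept in acc0..acc3), processing the 4 corresponding rows of B as two pairs.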
108  for(; vec_a <= (vec_a_end_addr - 4);)
109  {
110  const float16x4_t a0l = vld1_f16(vec_a);
111 
112  float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
113  float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
114  float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
115  float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
116  float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
117  float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
118  float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
119  float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
120 
121  acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0));
122  acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0));
123  acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0));
124  acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0));
125  acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1));
126  acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1));
127  acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1));
128  acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1));
129 
130  matrix_b += 2 * in_b_stride;
131 
132  b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
133  b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
134  b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
135  b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
136  b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
137  b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
138  b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
139  b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
140 
141  acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2));
142  acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2));
143  acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2));
144  acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2));
145  acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3));
146  acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3));
147  acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3));
148  acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3));
149 
150  vec_a += 4;
151  matrix_b += 2 * in_b_stride;
152  }
153 
154  for(; vec_a < vec_a_end_addr; ++vec_a)
155  {
156  const float16_t a0 = *vec_a;
157  const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
158  const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
159  const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
160  const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
161 
162  acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0));
163  acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0));
164  acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0));
165  acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0));
166 
167  matrix_b += in_b_stride;
168  }
169 
170  // Multiply by the weight of matrix product (alpha)
171  if(multiply_alpha)
172  {
173  acc0 = vmulq_f16(acc0, alpha_f16);
174  acc1 = vmulq_f16(acc1, alpha_f16);
175  acc2 = vmulq_f16(acc2, alpha_f16);
176  acc3 = vmulq_f16(acc3, alpha_f16);
177  }
178 
179  auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;
180 
181  vst1q_f16(vec_out + 0, acc0);
182  vst1q_f16(vec_out + 8, acc1);
183  vst1q_f16(vec_out + 16, acc2);
184  vst1q_f16(vec_out + 24, acc3);
185  }
186 
187  for(; x < window_end_x; ++x)
188  {
189  if(x > width_matrix_b)
190  {
191  return;
192  }
193 
194  auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;
195 
196  float16x4_t vacc = vdup_n_f16(0.f);
197 
198  auto vec_a = reinterpret_cast<const float16_t *>(ina.ptr());
199  const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
200  for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
201  {
202  const float16x4_t a0l = vld1_f16(vec_a);
203 
204  const float16x4_t b_col =
205  {
206  *(matrix_b + 0 * in_b_stride),
207  *(matrix_b + 1 * in_b_stride),
208  *(matrix_b + 2 * in_b_stride),
209  *(matrix_b + 3 * in_b_stride),
210  };
211 
212  vacc = vadd_f16(vacc, vmul_f16(a0l, b_col));
213 
214  matrix_b += 4 * in_b_stride;
215  }
216 
217  float16_t acc = vget_lane_f16(vacc, 0) + vget_lane_f16(vacc, 1) + vget_lane_f16(vacc, 2) + vget_lane_f16(vacc, 3);
218 
219  for(; vec_a < vec_a_end_addr; ++vec_a)
220  {
221  const float16_t a0 = *vec_a;
222  const float16_t b00 = *matrix_b;
223 
224  acc += b00 * a0;
225 
226  matrix_b += in_b_stride;
227  }
228 
229  // Multiply by the weight of matrix product (alpha)
230  if(multiply_alpha)
231  {
232  acc *= static_cast<float16_t>(alpha);
233  }
234 
235  auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;
236 
237  *(vec_out) = acc;
238  }
239  },
240  ina, inb, out);
241 }
242 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
243 
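// Vector-by-matrix product for FP32: same scheme as the FP16 path above, but with 16 output
// columns per step and explicit PLD prefetch hints on 32-bit Arm.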
244 void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
245 {
246  const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
247  const auto in_b_stride = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
248  const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));
249 
250  // The implementation computes 16 elements per iteration
251  const int window_start_x = 16 * info.thread_id;
252  const int window_step_x = 16 * info.num_threads;
253  // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
254  const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
255 
256  Window win_out(window);
257  win_out.set(Window::DimX, Window::Dimension(0, 1, 1));
258  win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
259 
260  Window win_a(window);
261  win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
262  win_a.set(Window::DimY, Window::Dimension(0, 0, 0));
263 
264  Window win_b;
265  // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
266  // This scenario can happen when the matrix multiplication is used to perform a convolution operation
267  if(input1->info()->num_dimensions() >= 3)
268  {
269  win_b = window;
270  }
271  win_b.set(Window::DimX, Window::Dimension(0, 1, 1));
272  win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
273 
274  Iterator ina(input0, win_a);
275  Iterator inb(input1, win_b);
276  Iterator out(output, win_out);
277 
278  const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));
279 
280  const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
281 
282  execute_window_loop(win_out, [&](const Coordinates &)
283  {
284  int x = window_start_x;
285  // The vector loop stops at (window_end_x - window_step_x) rather than window_end_x because window_end_x
286  // is rounded up above; processing the last full step here could write out of bounds, so it is handled by the left-over loop below.
287  for(; x < (window_end_x - window_step_x); x += window_step_x)
288  {
289  if(x > width_matrix_b)
290  {
291  return;
292  }
293 
294  float32x4_t acc0 = vdupq_n_f32(0.f);
295  float32x4_t acc1 = vdupq_n_f32(0.f);
296  float32x4_t acc2 = vdupq_n_f32(0.f);
297  float32x4_t acc3 = vdupq_n_f32(0.f);
298 
299  auto vec_a = reinterpret_cast<const float *>(ina.ptr());
300  auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;
301 
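// On 32-bit Arm builds, issue software prefetch (PLD) hints for the upcoming portions of
// vector A and matrix B; the hints are compiled out on AArch64.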
302 #if __arm__
303  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
304  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
305  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
306 #endif /* __arm__ */
307 
308  auto vec_a_end_addr = vec_a + num_elems_vec_a;
309  for(; vec_a <= (vec_a_end_addr - 4);)
310  {
311  float32x2_t a0l = vld1_f32(vec_a);
312 
313  float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
314  float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
315  float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
316  float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
317 
318  float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
319  float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
320  float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
321  float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);
322 
323 #if __arm__
324  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
325  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
326  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
327  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
328  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
329 #endif /* __arm__ */
330 
331  acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
332  acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
333  acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
334  acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);
335 
336  acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
337  acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
338  acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
339  acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);
340 
341  vec_a += 2;
342  matrix_b += 2 * in_b_stride;
343 
344  a0l = vld1_f32(vec_a);
345 
346  b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
347  b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
348  b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
349  b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
350 
351  b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
352  b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
353  b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
354  b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);
355 
356  acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
357  acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
358  acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
359  acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);
360 
361  acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
362  acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
363  acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
364  acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);
365 
366  vec_a += 2;
367  matrix_b += 2 * in_b_stride;
368  }
369 
370  for(; vec_a < vec_a_end_addr; ++vec_a)
371  {
372  const float a0 = *vec_a;
373 
374  const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
375  const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
376  const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
377  const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
378 
379  acc0 = vmlaq_n_f32(acc0, b00, a0);
380  acc1 = vmlaq_n_f32(acc1, b01, a0);
381  acc2 = vmlaq_n_f32(acc2, b02, a0);
382  acc3 = vmlaq_n_f32(acc3, b03, a0);
383 
384  matrix_b += in_b_stride;
385  }
386 
387  // Multiply by the weight of matrix product (alpha)
388  if(multiply_alpha)
389  {
390  acc0 = vmulq_f32(acc0, alpha_f32);
391  acc1 = vmulq_f32(acc1, alpha_f32);
392  acc2 = vmulq_f32(acc2, alpha_f32);
393  acc3 = vmulq_f32(acc3, alpha_f32);
394  }
395 
396  const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;
397 
398  vst1q_f32(vec_out + 0, acc0);
399  vst1q_f32(vec_out + 4, acc1);
400  vst1q_f32(vec_out + 8, acc2);
401  vst1q_f32(vec_out + 12, acc3);
402  }
403 
404  // Left-over loop
405  for(; x < window_end_x; ++x)
406  {
407  if(x > width_matrix_b)
408  {
409  return;
410  }
411 
412  float32x4_t vacc = vdupq_n_f32(0.f);
413 
414  auto vec_a = reinterpret_cast<const float *>(ina.ptr());
415  auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;
416 
417 #if __arm__
418  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
419  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
420  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
421 #endif /* __arm__ */
422 
423  auto vec_a_end_addr = vec_a + num_elems_vec_a;
424  for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
425  {
426  const float32x4_t a0l = vld1q_f32(vec_a);
427 
428  const float32x4_t b_col =
429  {
430  *(matrix_b + 0 * in_b_stride),
431  *(matrix_b + 1 * in_b_stride),
432  *(matrix_b + 2 * in_b_stride),
433  *(matrix_b + 3 * in_b_stride),
434  };
435 
436 #if __arm__
437  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
438  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
439  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
440  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
441  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
442 #endif /* __arm__ */
443 
444  vacc = vmlaq_f32(vacc, b_col, a0l);
445 
446  matrix_b += 4 * in_b_stride;
447  }
448 
449  float acc = vgetq_lane_f32(vacc, 0) + vgetq_lane_f32(vacc, 1) + vgetq_lane_f32(vacc, 2) + vgetq_lane_f32(vacc, 3);
450 
451  for(; vec_a < vec_a_end_addr; ++vec_a)
452  {
453  const float a0 = *vec_a;
454 
455  const float b00 = *matrix_b;
456 
457  acc += b00 * a0;
458 
459  matrix_b += in_b_stride;
460  }
461 
462  // Multiply by the weight of matrix product (alpha)
463  if(multiply_alpha)
464  {
465  acc *= alpha;
466  }
467 
468  const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;
469 
470  *vec_out = acc;
471  }
472  },
473  ina, inb, out);
474 }
475 
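// Matrix-by-matrix product for FP32: A is expected in the NEGEMMInterleave4x4 layout and B in the
// NEGEMMTranspose1xW layout, so each window iteration accumulates a 4x8 output tile (two adjacent
// 4x4 blocks) from purely contiguous loads.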
476 void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
477 {
478  const int out_width = static_cast<int>(output->info()->dimension(0));
479  const int out_height = static_cast<int>(output->info()->dimension(1));
480  const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
481  const size_t out_stride1 = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
482  const size_t out_stride2 = out_stride1 * 2;
483  const size_t out_stride3 = out_stride1 * 3;
484  const int num_elems_matrix_b_x = input1->info()->dimension(0);
485 
486  // Set step_x and step_y for matrix A. Divide the Y range by 4 because the interleaved input matrix A has a quarter of the rows of the output matrix
487  Window win_a(window);
488  win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
489  win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));
490 
491  Window win_b;
492  // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
493  // This scenario can happen when the matrix multiplication is used to perform a convolution operation
494  if(input1->info()->num_dimensions() >= 3)
495  {
496  win_b = window;
497  }
498  // Set step_x and step_y for matrix B. Divide the X range by 4 because the transposed input matrix B has a quarter of the columns of the output matrix
499  // The step along the x direction is 2 times the in_b_stride because for each iteration we compute 2 blocks of size 4x4
500  win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 2 * in_b_stride));
501  win_b.set(Window::DimY, Window::Dimension(0, 0, 0));
502 
503  Iterator ina(input0, win_a);
504  Iterator inb(input1, win_b);
505  Iterator out(output, window);
506 
507  const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));
508 
509  const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
510 
511  // The implementation assumes that matrix A and matrix B have been reshaped with NEGEMMInterleave4x4 and NEGEMMTranspose1xW respectively
512  // The reshaping makes the implementation cache friendly and avoids the data re-arrangements needed for computing 16x4 elements per iteration
513  // All the values needed for computing a single 4x4 block are read from consecutive memory positions
514  execute_window_loop(window, [&](const Coordinates & id)
515  {
516  auto mtx_a0 = reinterpret_cast<const float *>(ina.ptr());
517  auto mtx_b0 = reinterpret_cast<const float *>(inb.ptr());
518  auto mtx_b1 = mtx_b0 + in_b_stride;
519 
520  float32x4_t acc00 = vdupq_n_f32(0.f);
521  float32x4_t acc10 = vdupq_n_f32(0.f);
522  float32x4_t acc20 = vdupq_n_f32(0.f);
523  float32x4_t acc30 = vdupq_n_f32(0.f);
524 
525  float32x4_t acc01 = vdupq_n_f32(0.f);
526  float32x4_t acc11 = vdupq_n_f32(0.f);
527  float32x4_t acc21 = vdupq_n_f32(0.f);
528  float32x4_t acc31 = vdupq_n_f32(0.f);
529 
530 #if __arm__
531  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
532  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
533  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
534 #endif /* __arm__ */
535 
536  auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
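// Main loop, unrolled 4x: each pass consumes 32 floats from the interleaved A stream and 32 floats
// from each of the two transposed B streams, accumulating into the eight 4-lane accumulators.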
537  for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
538  {
539  float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
540  float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
541  float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
542  float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
543 
544  float32x4_t b00 = vld1q_f32(mtx_b0);
545  float32x4_t b10 = vld1q_f32(mtx_b1);
546  float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
547  float32x4_t b11 = vld1q_f32(mtx_b1 + 4);
548 
549 #if __arm__
550  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
551  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
552  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
553 #endif /* __arm__ */
554 
555  // 4x4 block 0
556  acc00 = vmlaq_f32(acc00, b00, a0);
557  acc10 = vmlaq_f32(acc10, b00, a1);
558  acc20 = vmlaq_f32(acc20, b00, a2);
559  acc30 = vmlaq_f32(acc30, b00, a3);
560 
561  float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
562  float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
563  float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
564  float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);
565 
566  // 4x4 block 1
567  acc01 = vmlaq_f32(acc01, b10, a0);
568  acc11 = vmlaq_f32(acc11, b10, a1);
569  acc21 = vmlaq_f32(acc21, b10, a2);
570  acc31 = vmlaq_f32(acc31, b10, a3);
571 
572  // 4x4 block 0
573  acc00 = vmlaq_f32(acc00, b01, a4);
574  acc10 = vmlaq_f32(acc10, b01, a5);
575  acc20 = vmlaq_f32(acc20, b01, a6);
576  acc30 = vmlaq_f32(acc30, b01, a7);
577 
578  // 4x4 block 1
579  acc01 = vmlaq_f32(acc01, b11, a4);
580  acc11 = vmlaq_f32(acc11, b11, a5);
581  acc21 = vmlaq_f32(acc21, b11, a6);
582  acc31 = vmlaq_f32(acc31, b11, a7);
583 
584  mtx_a0 += 8;
585  mtx_b0 += 8;
586  mtx_b1 += 8;
587 
588  a0 = vld1q_dup_f32(mtx_a0 + 0);
589  a1 = vld1q_dup_f32(mtx_a0 + 1);
590  a2 = vld1q_dup_f32(mtx_a0 + 2);
591  a3 = vld1q_dup_f32(mtx_a0 + 3);
592 
593  b00 = vld1q_f32(mtx_b0);
594  b10 = vld1q_f32(mtx_b1);
595  b01 = vld1q_f32(mtx_b0 + 4);
596  b11 = vld1q_f32(mtx_b1 + 4);
597 
598  // 4x4 block 0
599  acc00 = vmlaq_f32(acc00, b00, a0);
600  acc10 = vmlaq_f32(acc10, b00, a1);
601  acc20 = vmlaq_f32(acc20, b00, a2);
602  acc30 = vmlaq_f32(acc30, b00, a3);
603 
604  a4 = vld1q_dup_f32(mtx_a0 + 4);
605  a5 = vld1q_dup_f32(mtx_a0 + 5);
606  a6 = vld1q_dup_f32(mtx_a0 + 6);
607  a7 = vld1q_dup_f32(mtx_a0 + 7);
608 
609  // 4x4 block 1
610  acc01 = vmlaq_f32(acc01, b10, a0);
611  acc11 = vmlaq_f32(acc11, b10, a1);
612  acc21 = vmlaq_f32(acc21, b10, a2);
613  acc31 = vmlaq_f32(acc31, b10, a3);
614 
615  // 4x4 block 0
616  acc00 = vmlaq_f32(acc00, b01, a4);
617  acc10 = vmlaq_f32(acc10, b01, a5);
618  acc20 = vmlaq_f32(acc20, b01, a6);
619  acc30 = vmlaq_f32(acc30, b01, a7);
620 
621  // 4x4 block 1
622  acc01 = vmlaq_f32(acc01, b11, a4);
623  acc11 = vmlaq_f32(acc11, b11, a5);
624  acc21 = vmlaq_f32(acc21, b11, a6);
625  acc31 = vmlaq_f32(acc31, b11, a7);
626 
627  mtx_a0 += 8;
628  mtx_b0 += 8;
629  mtx_b1 += 8;
630 
631  a0 = vld1q_dup_f32(mtx_a0 + 0);
632  a1 = vld1q_dup_f32(mtx_a0 + 1);
633  a2 = vld1q_dup_f32(mtx_a0 + 2);
634  a3 = vld1q_dup_f32(mtx_a0 + 3);
635  b00 = vld1q_f32(mtx_b0);
636  b10 = vld1q_f32(mtx_b1);
637  b01 = vld1q_f32(mtx_b0 + 4);
638  b11 = vld1q_f32(mtx_b1 + 4);
639 
640 #if __arm__
641  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
642  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
643  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
644 #endif /* __arm__ */
645 
646  // 4x4 block 0
647  acc00 = vmlaq_f32(acc00, b00, a0);
648  acc10 = vmlaq_f32(acc10, b00, a1);
649  acc20 = vmlaq_f32(acc20, b00, a2);
650  acc30 = vmlaq_f32(acc30, b00, a3);
651 
652  a4 = vld1q_dup_f32(mtx_a0 + 4);
653  a5 = vld1q_dup_f32(mtx_a0 + 5);
654  a6 = vld1q_dup_f32(mtx_a0 + 6);
655  a7 = vld1q_dup_f32(mtx_a0 + 7);
656 
657  // 4x4 block 1
658  acc01 = vmlaq_f32(acc01, b10, a0);
659  acc11 = vmlaq_f32(acc11, b10, a1);
660  acc21 = vmlaq_f32(acc21, b10, a2);
661  acc31 = vmlaq_f32(acc31, b10, a3);
662 
663  // 4x4 block 0
664  acc00 = vmlaq_f32(acc00, b01, a4);
665  acc10 = vmlaq_f32(acc10, b01, a5);
666  acc20 = vmlaq_f32(acc20, b01, a6);
667  acc30 = vmlaq_f32(acc30, b01, a7);
668 
669  // 4x4 block 1
670  acc01 = vmlaq_f32(acc01, b11, a4);
671  acc11 = vmlaq_f32(acc11, b11, a5);
672  acc21 = vmlaq_f32(acc21, b11, a6);
673  acc31 = vmlaq_f32(acc31, b11, a7);
674 
675  mtx_a0 += 8;
676  mtx_b0 += 8;
677  mtx_b1 += 8;
678 
679  a0 = vld1q_dup_f32(mtx_a0 + 0);
680  a1 = vld1q_dup_f32(mtx_a0 + 1);
681  a2 = vld1q_dup_f32(mtx_a0 + 2);
682  a3 = vld1q_dup_f32(mtx_a0 + 3);
683  b00 = vld1q_f32(mtx_b0);
684  b10 = vld1q_f32(mtx_b1);
685  b01 = vld1q_f32(mtx_b0 + 4);
686  b11 = vld1q_f32(mtx_b1 + 4);
687 
688  // 4x4 block 0
689  acc00 = vmlaq_f32(acc00, b00, a0);
690  acc10 = vmlaq_f32(acc10, b00, a1);
691  acc20 = vmlaq_f32(acc20, b00, a2);
692  acc30 = vmlaq_f32(acc30, b00, a3);
693 
694  a4 = vld1q_dup_f32(mtx_a0 + 4);
695  a5 = vld1q_dup_f32(mtx_a0 + 5);
696  a6 = vld1q_dup_f32(mtx_a0 + 6);
697  a7 = vld1q_dup_f32(mtx_a0 + 7);
698 
699  // 4x4 block 1
700  acc01 = vmlaq_f32(acc01, b10, a0);
701  acc11 = vmlaq_f32(acc11, b10, a1);
702  acc21 = vmlaq_f32(acc21, b10, a2);
703  acc31 = vmlaq_f32(acc31, b10, a3);
704 
705  // 4x4 block 0
706  acc00 = vmlaq_f32(acc00, b01, a4);
707  acc10 = vmlaq_f32(acc10, b01, a5);
708  acc20 = vmlaq_f32(acc20, b01, a6);
709  acc30 = vmlaq_f32(acc30, b01, a7);
710 
711  // 4x4 block 1
712  acc01 = vmlaq_f32(acc01, b11, a4);
713  acc11 = vmlaq_f32(acc11, b11, a5);
714  acc21 = vmlaq_f32(acc21, b11, a6);
715  acc31 = vmlaq_f32(acc31, b11, a7);
716 
717  mtx_a0 += 8;
718  mtx_b0 += 8;
719  mtx_b1 += 8;
720  }
721 
722  for(; mtx_b0 < mtx_b0_end_addr;)
723  {
724  float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
725  float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
726  float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
727  float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
728  float32x4_t b00 = vld1q_f32(mtx_b0);
729  float32x4_t b10 = vld1q_f32(mtx_b1);
730 
731 #if __arm__
732  asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
733  asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
734  asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
735 #endif /* __arm__ */
736  // 4x4 block 0
737  acc00 = vmlaq_f32(acc00, b00, a0);
738  acc10 = vmlaq_f32(acc10, b00, a1);
739  acc20 = vmlaq_f32(acc20, b00, a2);
740  acc30 = vmlaq_f32(acc30, b00, a3);
741 
742  // 4x4 block 1
743  acc01 = vmlaq_f32(acc01, b10, a0);
744  acc11 = vmlaq_f32(acc11, b10, a1);
745  acc21 = vmlaq_f32(acc21, b10, a2);
746  acc31 = vmlaq_f32(acc31, b10, a3);
747 
748  mtx_a0 += 4;
749  mtx_b0 += 4;
750  mtx_b1 += 4;
751  }
752 
753  // Multiply by the weight of matrix product (alpha)
754  if(multiply_alpha)
755  {
756  acc00 = vmulq_f32(acc00, alpha_f32);
757  acc10 = vmulq_f32(acc10, alpha_f32);
758  acc20 = vmulq_f32(acc20, alpha_f32);
759  acc30 = vmulq_f32(acc30, alpha_f32);
760  acc01 = vmulq_f32(acc01, alpha_f32);
761  acc11 = vmulq_f32(acc11, alpha_f32);
762  acc21 = vmulq_f32(acc21, alpha_f32);
763  acc31 = vmulq_f32(acc31, alpha_f32);
764  }
765 
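// Store the 4x8 result tile, writing only the rows and columns that actually fall inside the
// output so that partial tiles at the right and bottom borders are handled correctly.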
766  const auto mtx_out0 = reinterpret_cast<float *>(out.ptr());
767  const auto mtx_out1 = mtx_out0 + 4;
768 
769  if(id.x() < (out_width - 8))
770  {
771  vst1q_f32(mtx_out0, acc00);
772  vst1q_f32(mtx_out1, acc01);
773  if(id.y() + 1 < out_height)
774  {
775  vst1q_f32(mtx_out0 + out_stride1, acc10);
776  vst1q_f32(mtx_out1 + out_stride1, acc11);
777  if(id.y() + 2 < out_height)
778  {
779  vst1q_f32(mtx_out0 + out_stride2, acc20);
780  vst1q_f32(mtx_out1 + out_stride2, acc21);
781  if(id.y() + 3 < out_height)
782  {
783  vst1q_f32(mtx_out0 + out_stride3, acc30);
784  vst1q_f32(mtx_out1 + out_stride3, acc31);
785  }
786  }
787  }
788  }
789  else if(id.x() < (out_width - 4))
790  {
791  vst1q_f32(mtx_out0, acc00);
792  if(id.y() + 1 < out_height)
793  {
794  vst1q_f32(mtx_out0 + out_stride1, acc10);
795  if(id.y() + 2 < out_height)
796  {
797  vst1q_f32(mtx_out0 + out_stride2, acc20);
798  if(id.y() + 3 < out_height)
799  {
800  vst1q_f32(mtx_out0 + out_stride3, acc30);
801  }
802  }
803  }
804  // Left-over columns
805  const int columns_left = out_width - id.x() - 4;
806  for(auto x = 0; x < columns_left; ++x)
807  {
808  *(mtx_out1 + x) = acc01[x];
809  if(id.y() + 1 < out_height)
810  {
811  *(mtx_out1 + x + out_stride1) = acc11[x];
812  if(id.y() + 2 < out_height)
813  {
814  *(mtx_out1 + x + out_stride2) = acc21[x];
815  if(id.y() + 3 < out_height)
816  {
817  *(mtx_out1 + x + out_stride3) = acc31[x];
818  }
819  }
820  }
821  }
822  }
823  else
824  {
825  // Left-over columns
826  const int columns_left = out_width - id.x();
827  for(int x = 0; x < columns_left; ++x)
828  {
829  *(mtx_out0 + x) = acc00[x];
830  if(id.y() + 1 < out_height)
831  {
832  *(mtx_out0 + x + out_stride1) = acc10[x];
833  if(id.y() + 2 < out_height)
834  {
835  *(mtx_out0 + x + out_stride2) = acc20[x];
836  if(id.y() + 3 < out_height)
837  {
838  *(mtx_out0 + x + out_stride3) = acc30[x];
839  }
840  }
841  }
842  }
843  }
844  },
845  ina, inb, out);
846 }
847 
848 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
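// Matrix-by-matrix product for FP16: A is expected in the NEGEMMInterleave4x4 layout and B in the
// NEGEMMTranspose1xW layout (1x8 for FP16), so each window iteration accumulates a 4x8 output tile
// in the four rows of c.val.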
849 void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
850 {
851  const int out_width = static_cast<int>(output->info()->dimension(0));
852  const int out_height = static_cast<int>(output->info()->dimension(1));
853  const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
854  const size_t out_stride = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
855  const int num_elems_matrix_b_x = input1->info()->dimension(0);
856 
857  // Set step_x and step_y for matrix A. Divide the Y range by 4 because the interleaved input matrix A has a quarter of the rows of the output matrix
858  Window win_a(window);
859  win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
860  win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));
861 
862  Window win_b;
863  // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
864  // This scenario can happen when the matrix multiplication is used to perform a convolution operation
865  if(input1->info()->num_dimensions() >= 3)
866  {
867  win_b = window;
868  }
869  // Set step_x and step_y for matrix B. Divide the X range by 8 because the transposed input matrix B has one eighth of the columns of the output matrix
870  win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
871  win_b.set(Window::DimY, Window::Dimension(0, 1, 0));
872 
873  Iterator ina(input0, win_a);
874  Iterator inb(input1, win_b);
875  Iterator out(output, window);
876 
877  const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));
878 
879  const float16x8_t alpha_f16 = vdupq_n_f16(alpha);
880 
881  execute_window_loop(window, [&](const Coordinates & id)
882  {
883  const auto *mtx_a0 = reinterpret_cast<const float16_t *>(ina.ptr());
884  const auto *mtx_b0 = reinterpret_cast<const float16_t *>(inb.ptr());
885  auto *mtx_out = reinterpret_cast<float16_t *>(out.ptr());
886  float16x8x4_t c =
887  {
888  {
889  vdupq_n_f16(0.f),
890  vdupq_n_f16(0.f),
891  vdupq_n_f16(0.f),
892  vdupq_n_f16(0.f)
893  }
894  };
895 
896  /*
897  This kernel puts the values in a 4x4 block of Matrix A on the same row (Interleaved values)
898  |a00 a01 a02 a03 | a04 a05 a06 a07|
899  |a10 a11 a12 a13 | a14 a15 a16 a17|
900  |a20 a21 a22 a23 | a24 a25 a26 a27| = | a00 a10 a20 a30 || a01 a11 a21 a31 || a02 a12 a22 a32 || a03 a13 a23 a33 | a40 a50 a60 a70 | ...
901  |a30 a31 a32 a33 | a34 a35 a36 a37| | a04 a14 a24 a34 || a05 a15 a25 a35 || a06 a16 a26 a36 || a07 a17 a27 a37 | a44 a54 a64 a74 | ...
902  |a40 a41 a42 a43 | a44 a45 a46 a47|
903  |a50 a51 a52 a53 | a54 a55 a56 a57|
904  |a60 a61 a62 a63 | a64 a65 a66 a67|
905  |a70 a71 a72 a73 | a74 a75 a76 a77|
906 
907  After this operation, the output matrix will have the following shape: [ height * 4, width / 4 ]
908 
909  B Matrix has been transposed as shown below
910 
911  |b00 b01 b02 b03 b04 b05 b06 b07|
912  |b10 b11 b12 b13 b14 b15 b16 b17|
913  |b20 b21 b22 b23 b24 b25 b26 b27|
914  |b30 b31 b32 b33 b34 b35 b36 b37|
915  ------------------->
916 
917  |b00 b01 b02 b03 b04 b05 b06 b07||b10 b11 b12 b13 b14 b15 b16 b17||b20 b21 b22 b23 b24 b25 b26 b27||b30 b31 b32 b33 b34 b35 b36 b37|
918 
919  c.val[0][0] = a00*b00 + a01*b10 + a02*b20 + a03*b30
920  c.val[0][1] = a00*b01 + a01*b11 + a02*b21 + a03*b31
921 
922  The size of the output tensor's XY-plane must be the following shape [ width * 8, height / 8 ]. All other dimensions must have the same size.
923  */
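// Consequently, each 8-element load from mtx_a0 below yields two interleaved column groups
// (a0k a1k a2k a3k followed by the next k) and each 8-element load from mtx_b0 yields one full row
// of the transposed B block, so the 4x8 tile is built from purely sequential reads.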
924  const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
925 
926  for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
927 
928  {
929  const float16x8_t p00 = vld1q_f16(mtx_a0);
930  const float16x8_t p02 = vld1q_f16(mtx_a0 + 8);
931 
932  const float16x8_t q00 = vld1q_f16(mtx_b0);
933  const float16x8_t q02 = vld1q_f16(mtx_b0 + 8);
934  const float16x8_t q04 = vld1q_f16(mtx_b0 + 16);
935  const float16x8_t q06 = vld1q_f16(mtx_b0 + 24);
936 
937  c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vgetq_lane_f16(p00, 0)));
938  c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vgetq_lane_f16(p00, 1)));
939  c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vgetq_lane_f16(p00, 2)));
940  c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vgetq_lane_f16(p00, 3)));
941 
942  c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q02, vgetq_lane_f16(p00, 4)));
943  c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q02, vgetq_lane_f16(p00, 5)));
944  c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q02, vgetq_lane_f16(p00, 6)));
945  c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q02, vgetq_lane_f16(p00, 7)));
946 
947  c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q04, vgetq_lane_f16(p02, 0)));
948  c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q04, vgetq_lane_f16(p02, 1)));
949  c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q04, vgetq_lane_f16(p02, 2)));
950  c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q04, vgetq_lane_f16(p02, 3)));
951 
952  c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q06, vgetq_lane_f16(p02, 4)));
953  c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q06, vgetq_lane_f16(p02, 5)));
954  c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q06, vgetq_lane_f16(p02, 6)));
955  c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q06, vgetq_lane_f16(p02, 7)));
956 
957  mtx_a0 += 16;
958  mtx_b0 += 32;
959  }
960 
961  for(; mtx_b0 < mtx_b0_end_addr;)
962 
963  {
964  const float16x4_t p00 = vld1_f16(mtx_a0);
965  const float16x8_t q00 = vld1q_f16(mtx_b0);
966 
967  c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vget_lane_f16(p00, 0)));
968  c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vget_lane_f16(p00, 1)));
969  c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vget_lane_f16(p00, 2)));
970  c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vget_lane_f16(p00, 3)));
971 
972  mtx_a0 += 4;
973  mtx_b0 += 8;
974  }
975 
976  if(multiply_alpha)
977  {
978  c.val[0] = vmulq_f16(c.val[0], alpha_f16);
979  c.val[1] = vmulq_f16(c.val[1], alpha_f16);
980  c.val[2] = vmulq_f16(c.val[2], alpha_f16);
981  c.val[3] = vmulq_f16(c.val[3], alpha_f16);
982  }
983 
984  if(id.x() < (out_width - 8))
985  {
986  vst1q_f16(mtx_out, c.val[0]);
987  if(id.y() + 1 < out_height)
988  {
989  vst1q_f16(mtx_out + 1 * out_stride, c.val[1]);
990  if(id.y() + 2 < out_height)
991  {
992  vst1q_f16(mtx_out + 2 * out_stride, c.val[2]);
993  if(id.y() + 3 < out_height)
994  {
995  vst1q_f16(mtx_out + 3 * out_stride, c.val[3]);
996  }
997  }
998  }
999  }
1000  else
1001  {
1002  // Left-over columns
1003  const int columns_left = out_width - id.x();
1004  for(int x = 0; x < columns_left; ++x)
1005  {
1006  *(mtx_out + x) = c.val[0][x];
1007  if(id.y() + 1 < out_height)
1008  {
1009  *(mtx_out + x + 1 * out_stride) = c.val[1][x];
1010  if(id.y() + 2 < out_height)
1011  {
1012  *(mtx_out + x + 2 * out_stride) = c.val[2][x];
1013  if(id.y() + 3 < out_height)
1014  {
1015  *(mtx_out + x + 3 * out_stride) = c.val[3][x];
1016  }
1017  }
1018  }
1019  }
1020  }
1021  },
1022  ina, inb, out);
1023 }
1024 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
1025 
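// Validates the operand shapes. When is_interleaved is true, input0 and input1 are expected to
// already be in the NEGEMMInterleave4x4 / NEGEMMTranspose1xW layouts described by reshape_info.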
1026 inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
1027 {
1028  ARM_COMPUTE_UNUSED(alpha);
1029 
1033 
1034  if(!is_interleaved)
1035  {
1036  ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != input1->dimension(1));
1037 
1038  if(output->total_size() != 0)
1039  {
1040  ARM_COMPUTE_RETURN_ERROR_ON(input1->dimension(0) != output->dimension(0));
1041  ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(1) != output->dimension(1));
1043  }
1044  }
1045  else
1046  {
1047  const int m = reshape_info.m();
1048  const int n = reshape_info.n();
1049  const int k = reshape_info.k();
1050  const int mult_transpose1xW_width = reshape_info.mult_transpose1xW_width();
1051  const int mult_interleave4x4_height = reshape_info.mult_interleave4x4_height();
1052 
1053  /* Interleave */
1054  TensorShape tensor_shape0{ input0->tensor_shape() };
1055  tensor_shape0.set(0, k);
1056  tensor_shape0.set(1, m);
1057 
1058  const TensorInfo tensor_info0 = input0->clone()->set_tensor_shape(tensor_shape0);
1059  const TensorInfo tensor_info_reshaped0 = input0->clone()->set_tensor_shape(misc::shape_calculator::compute_interleaved_shape(tensor_info0, mult_interleave4x4_height));
1060  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0);
1061 
1062  if(n != 0) /* Transpose */
1063  {
1064  TensorShape tensor_shape1{ input1->tensor_shape() };
1065  tensor_shape1.set(0, n);
1066  tensor_shape1.set(1, k);
1067 
1068  const TensorInfo tensor_info1 = input1->clone()->set_tensor_shape(tensor_shape1);
1069  const TensorInfo tensor_info_reshaped1 = input1->clone()->set_tensor_shape(misc::shape_calculator::compute_transpose1xW_with_element_size_shape(tensor_info1, mult_transpose1xW_width));
1070  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, &tensor_info_reshaped1);
1071  }
1072 
1073  if(output->total_size() != 0)
1074  {
1075  if(n != 0)
1076  {
1077  ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != static_cast<size_t>(n));
1078  }
1079  ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(1) != static_cast<size_t>(m));
1081  }
1082  }
1083 
1084  return Status{};
1085 }
1086 } // namespace
1087 
1088 NEGEMMMatrixMultiplyKernel::NEGEMMMatrixMultiplyKernel()
1089  : _input0(nullptr), _input1(nullptr), _output(nullptr), _alpha(1.0f)
1090 {
1091 }
1092 
1093 void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
1094 {
1095  ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
1096 
1097  // Auto-initialize the output tensor if not yet initialized
1098  TensorShape tensor_shape{ input0->info()->tensor_shape() };
1099  tensor_shape.set(0, is_interleaved ? reshape_info.n() : input1->info()->dimension(0));
1100  tensor_shape.set(1, is_interleaved ? reshape_info.m() : input0->info()->dimension(1));
1101 
1102  auto_init_if_empty(*output->info(), input0->info()->clone()->set_tensor_shape(tensor_shape));
1103 
1104  // Perform validate step
1105  ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), alpha, is_interleaved, reshape_info));
1106 
1107  _input0 = input0;
1108  _input1 = input1;
1109  _output = output;
1110  _alpha = alpha;
1111 
1112  // Configure kernel window
1113  Window win{};
1114 
1115  // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
1116  if((output->info()->dimension(1) == 1))
1117  {
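// The vector-matrix kernels above process 16 FP32 or 32 FP16 output elements per step, so size
// the window accordingly.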
1118  const unsigned int num_elems_processed_per_iteration_x = (input0->info()->data_type() == DataType::F32) ? 16 : 32;
1119 
1120  win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));
1121  }
1122  else
1123  {
1124  constexpr unsigned int num_elems_processed_per_iteration_x = 8;
1125  constexpr unsigned int num_elems_processed_per_iteration_y = 4;
1126 
1127  win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
1128  }
1129 
1130  Coordinates coord;
1131  coord.set_num_dimensions(output->info()->num_dimensions());
1132  output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));
1133  INEKernel::configure(win);
1134 }
1135 
1136 Status NEGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved,
1137  const GEMMReshapeInfo &reshape_info)
1138 {
1139  ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, alpha, is_interleaved, reshape_info));
1140 
1141  return Status{};
1142 }
1143 
1144 void NEGEMMMatrixMultiplyKernel::run(const Window &window, const ThreadInfo &info)
1145 {
1148 
1149  // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
1150  const bool is_output_vector = (_output->info()->dimension(1) == 1);
1151  switch(_input0->info()->data_type())
1152  {
1153  case DataType::F32:
1154  {
1155  is_output_vector ? vector_matrix_multiply_f32(_input0, _input1, _output, window, info, _alpha) :
1156  matrix_matrix_multiply_f32(_input0, _input1, _output, window, _alpha);
1157  break;
1158  }
1159 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1160  case DataType::F16:
1161  {
1162  is_output_vector ? vector_matrix_multiply_f16(_input0, _input1, _output, window, info, _alpha) :
1163  matrix_matrix_multiply_f16(_input0, _input1, _output, window, _alpha);
1164  break;
1165  }
1166 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
1167  default:
1168  {
1169  ARM_COMPUTE_ERROR("Data type not supported");
1170  break;
1171  }
1172  }
1173 }
1174 } // namespace arm_compute