Compute Library
 20.05
NEGEMMMatrixMultiplyKernel.cpp
1 /*
2  * Copyright (c) 2017-2019 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
28 #include "arm_compute/core/Error.h"
34 #include "arm_compute/core/Types.h"
35 #include "arm_compute/core/Utils.h"
40 
41 #include <arm_neon.h>
42 #include <cstddef>
43 #include <cstdint>
44 #include <tuple>
45 
46 using namespace arm_compute;
47 
48 namespace arm_compute
49 {
50 class Coordinates;
51 } // namespace arm_compute
52 
53 namespace
54 {
55 template <bool multiply_alpha>
56 void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
57 {
58 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
59  const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
60  const auto in_b_stride = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
61  const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));
62 
63  // The implementation computes 32 elements per iteration
64  const int window_start_x = 32 * info.thread_id;
65  const int window_step_x = 32 * info.num_threads;
66  const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
67  ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, "(window_end_x - window_start_x) must be a multiple of window_step_x");
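 // For example, with width_matrix_b = 100, info.thread_id = 1 and info.num_threads = 2 this thread gets
 // window_start_x = 32, window_step_x = 64 and window_end_x = ceil_to_multiple(68, 64) + 32 = 160, so it
 // visits the 32-wide column blocks starting at x = 32 and x = 96 while thread 0 visits x = 0 and x = 64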
68 
69  Window win_out(window);
70  win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
71  win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
72 
73  Window win_a(window);
74  win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
75  win_a.set(Window::DimY, Window::Dimension(0, 0, 0));
76 
77  Window win_b;
78  // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
79  // This scenario can happen when the matrix multiplication is used to perform a convolution operation
80  if(input1->info()->num_dimensions() >= 3)
81  {
82  win_b = window;
83  }
84  win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
85  win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
86 
87  Iterator ina(input0, win_a);
88  Iterator inb(input1, win_b);
89  Iterator out(output, win_out);
90 
91  const float16x8_t alpha_f16 = vdupq_n_f16(alpha);
92  ARM_COMPUTE_UNUSED(alpha_f16);
93 
94  execute_window_loop(win_out, [&](const Coordinates & id)
95  {
96  if(id.x() > width_matrix_b)
97  {
98  return;
99  }
100 
101  float16x8_t acc0 = vdupq_n_f16(0.f);
102  float16x8_t acc1 = vdupq_n_f16(0.f);
103  float16x8_t acc2 = vdupq_n_f16(0.f);
104  float16x8_t acc3 = vdupq_n_f16(0.f);
105 
106  auto vec_a = reinterpret_cast<const float16_t *>(ina.ptr());
107  auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr());
108 
109  const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
110  for(; vec_a <= (vec_a_end_addr - 4);)
111  {
112  const float16x4_t a0l = vld1_f16(vec_a);
113 
114  float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
115  float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
116  float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
117  float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
118  float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
119  float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
120  float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
121  float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
122 
123  acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0));
124  acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0));
125  acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0));
126  acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0));
127  acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1));
128  acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1));
129  acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1));
130  acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1));
131 
132  matrix_b += 2 * in_b_stride;
133 
134  b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
135  b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
136  b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
137  b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
138  b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
139  b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
140  b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
141  b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
142 
143  acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2));
144  acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2));
145  acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2));
146  acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2));
147  acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3));
148  acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3));
149  acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3));
150  acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3));
151 
152  vec_a += 4;
153  matrix_b += 2 * in_b_stride;
154  }
155 
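 // Leftover loop: fewer than 4 elements of the A vector remain, so accumulate one row of matrix B at a time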
156  for(; vec_a < vec_a_end_addr;)
157  {
158  const float16_t a0 = *vec_a;
159  const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
160  const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
161  const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
162  const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
163 
164  acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0));
165  acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0));
166  acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0));
167  acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0));
168 
169  vec_a += 1;
170  matrix_b += in_b_stride;
171  }
172 
173  // Multiply by the weight of matrix product (alpha)
174  if(multiply_alpha)
175  {
176  acc0 = vmulq_f16(acc0, alpha_f16);
177  acc1 = vmulq_f16(acc1, alpha_f16);
178  acc2 = vmulq_f16(acc2, alpha_f16);
179  acc3 = vmulq_f16(acc3, alpha_f16);
180  }
181 
182  const auto vec_out = reinterpret_cast<float16_t *>(out.ptr());
183 
184  vst1q_f16(vec_out + 0, acc0);
185  vst1q_f16(vec_out + 8, acc1);
186  vst1q_f16(vec_out + 16, acc2);
187  vst1q_f16(vec_out + 24, acc3);
188 
189  },
190  ina, inb, out);
191 #else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
192  ARM_COMPUTE_UNUSED(input0);
193  ARM_COMPUTE_UNUSED(input1);
194  ARM_COMPUTE_UNUSED(output);
195  ARM_COMPUTE_UNUSED(window);
196  ARM_COMPUTE_UNUSED(info);
197  ARM_COMPUTE_UNUSED(alpha);
198  ARM_COMPUTE_ERROR("Not implemented");
199 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
200 }
201 
202 template <bool multiply_alpha>
203 void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
204 {
205  const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
206  const auto in_b_stride = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
207  const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));
208 
209  // The implementation computes 16 elements per iteration
210  const int window_start_x = 16 * info.thread_id;
211  const int window_step_x = 16 * info.num_threads;
212  // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
213  const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
214 
215  Window win_out(window);
216  win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
217  win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
218 
219  Window win_a(window);
220  win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
221  win_a.set(Window::DimY, Window::Dimension(0, 0, 0));
222 
223  Window win_b;
224  // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
225  // This scenario can happen when the matrix multiplication is used to perform a convolution operation
226  if(input1->info()->num_dimensions() >= 3)
227  {
228  win_b = window;
229  }
230  win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
231  win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
232 
233  Iterator ina(input0, win_a);
234  Iterator inb(input1, win_b);
235  Iterator out(output, win_out);
236 
237  execute_window_loop(win_out, [&](const Coordinates & id)
238  {
239  if(id.x() > width_matrix_b)
240  {
241  return;
242  }
243 
244  float32x4_t acc0 = vdupq_n_f32(0.f);
245  float32x4_t acc1 = vdupq_n_f32(0.f);
246  float32x4_t acc2 = vdupq_n_f32(0.f);
247  float32x4_t acc3 = vdupq_n_f32(0.f);
248 
249  auto vec_a = reinterpret_cast<const float *>(ina.ptr());
250  auto matrix_b = reinterpret_cast<const float *>(inb.ptr());
251 
252 #if __arm__
253  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
254  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
255  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
256 #endif /* __arm__ */
257 
258  auto vec_a_end_addr = vec_a + num_elems_vec_a;
259  for(; vec_a <= (vec_a_end_addr - 4);)
260  {
261  float32x2_t a0l = vld1_f32(vec_a);
262 
263  float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
264  float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
265  float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
266  float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
267 
268  float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
269  float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
270  float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
271  float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);
272 
273 #if __arm__
274  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
275  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
276  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
277  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
278  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
279 #endif /* __arm__ */
280 
281  acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
282  acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
283  acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
284  acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);
285 
286  acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
287  acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
288  acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
289  acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);
290 
291  vec_a += 2;
292  matrix_b += 2 * in_b_stride;
293 
294  a0l = vld1_f32(vec_a);
295 
296  b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
297  b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
298  b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
299  b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
300 
301  b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
302  b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
303  b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
304  b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);
305 
306  acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
307  acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
308  acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
309  acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);
310 
311  acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
312  acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
313  acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
314  acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);
315 
316  vec_a += 2;
317  matrix_b += 2 * in_b_stride;
318  }
319 
320  for(; vec_a < vec_a_end_addr;)
321  {
322  const float a0 = *vec_a;
323 
324  const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
325  const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
326  const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
327  const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);
328 
329  acc0 = vmlaq_n_f32(acc0, b00, a0);
330  acc1 = vmlaq_n_f32(acc1, b01, a0);
331  acc2 = vmlaq_n_f32(acc2, b02, a0);
332  acc3 = vmlaq_n_f32(acc3, b03, a0);
333 
334  vec_a += 1;
335  matrix_b += in_b_stride;
336  }
337 
338  // Multiply by the weight of matrix product (alpha)
339  if(multiply_alpha)
340  {
341  const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
342  acc0 = vmulq_f32(acc0, alpha_f32);
343  acc1 = vmulq_f32(acc1, alpha_f32);
344  acc2 = vmulq_f32(acc2, alpha_f32);
345  acc3 = vmulq_f32(acc3, alpha_f32);
346  }
347 
348  const auto vec_out = reinterpret_cast<float *>(out.ptr());
349 
350  vst1q_f32(vec_out + 0, acc0);
351  vst1q_f32(vec_out + 4, acc1);
352  vst1q_f32(vec_out + 8, acc2);
353  vst1q_f32(vec_out + 12, acc3);
354  },
355  ina, inb, out);
356 }
357 
358 template <bool multiply_alpha>
359 void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
360 {
361  const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
362  const size_t out_stride1 = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
363  const size_t out_stride2 = out_stride1 * 2;
364  const size_t out_stride3 = out_stride1 * 3;
365  const int num_elems_matrix_b_x = input1->info()->dimension(0);
366 
367  // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4, as the interleaved input matrix A has 4 times fewer rows than the output matrix
368  Window win_a(window);
369  win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
370  win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));
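 // For example, if this slice of the output window spans rows [0, 16), win_a walks the interleaved matrix A
 // over rows [0, 4), since each row of the interleaved input packs 4 rows of the original matrix A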
371 
372  Window win_b;
373  // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
374  // This scenario can happen when the matrix multiplication is used to perform a convolution operation
375  if(input1->info()->num_dimensions() >= 3)
376  {
377  win_b = window;
378  }
379  // Set step_x and step_y for matrix B. Scale the X range by a factor of 4, as the transposed input matrix B has 4 times fewer columns than the output matrix
380  // The step along the x direction is 2 * in_b_stride because each iteration computes 2 blocks of size 4x4
381  win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 2 * in_b_stride));
382  win_b.set(Window::DimY, Window::Dimension(0, 0, 0));
383 
384  Iterator ina(input0, win_a);
385  Iterator inb(input1, win_b);
386  Iterator out(output, window);
387 
388  // The implementation assumes that matrix A and matrix B have been reshaped with NEGEMMInterleave4x4 and NEGEMMTranspose1xW respectively
389  // Reshaping the matrices gives a cache-friendly implementation and avoids the data re-arrangements that would otherwise be needed to compute 16x4 elements per iteration
390  // All the values needed for computing a single 4x4 block will be read from consecutive memory positions
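 // Sketch of the assumed layouts: mtx_a0 points at | a00 a10 a20 a30 | a01 a11 a21 a31 | ... so that
 // vld1q_dup_f32(mtx_a0 + i) broadcasts row i of the current 4-row band of A, while mtx_b0 points at
 // | b00 b01 b02 b03 | b10 b11 b12 b13 | ... so that each vld1q_f32 from B fetches 4 consecutive output
 // columns of one row of B; combining one such load with the four broadcast A values updates one 4x4 output block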
391  execute_window_loop(window, [&](const Coordinates &)
392  {
393  auto mtx_a0 = reinterpret_cast<const float *>(ina.ptr());
394  auto mtx_b0 = reinterpret_cast<const float *>(inb.ptr());
395  auto mtx_b1 = mtx_b0 + in_b_stride;
396 
397  float32x4_t acc00 = vdupq_n_f32(0.f);
398  float32x4_t acc10 = vdupq_n_f32(0.f);
399  float32x4_t acc20 = vdupq_n_f32(0.f);
400  float32x4_t acc30 = vdupq_n_f32(0.f);
401 
402  float32x4_t acc01 = vdupq_n_f32(0.f);
403  float32x4_t acc11 = vdupq_n_f32(0.f);
404  float32x4_t acc21 = vdupq_n_f32(0.f);
405  float32x4_t acc31 = vdupq_n_f32(0.f);
406 
407 #if __arm__
408  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
409  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
410  asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
411 #endif /* __arm__ */
412 
413  auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
414  for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
415  {
416  float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
417  float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
418  float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
419  float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
420 
421  float32x4_t b00 = vld1q_f32(mtx_b0);
422  float32x4_t b10 = vld1q_f32(mtx_b1);
423  float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
424  float32x4_t b11 = vld1q_f32(mtx_b1 + 4);
425 
426 #if __arm__
427  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
428  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
429  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
430 #endif /* __arm__ */
431 
432  // 4x4 block 0
433  acc00 = vmlaq_f32(acc00, b00, a0);
434  acc10 = vmlaq_f32(acc10, b00, a1);
435  acc20 = vmlaq_f32(acc20, b00, a2);
436  acc30 = vmlaq_f32(acc30, b00, a3);
437 
438  float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
439  float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
440  float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
441  float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);
442 
443  // 4x4 block 1
444  acc01 = vmlaq_f32(acc01, b10, a0);
445  acc11 = vmlaq_f32(acc11, b10, a1);
446  acc21 = vmlaq_f32(acc21, b10, a2);
447  acc31 = vmlaq_f32(acc31, b10, a3);
448 
449  // 4x4 block 0
450  acc00 = vmlaq_f32(acc00, b01, a4);
451  acc10 = vmlaq_f32(acc10, b01, a5);
452  acc20 = vmlaq_f32(acc20, b01, a6);
453  acc30 = vmlaq_f32(acc30, b01, a7);
454 
455  // 4x4 block 1
456  acc01 = vmlaq_f32(acc01, b11, a4);
457  acc11 = vmlaq_f32(acc11, b11, a5);
458  acc21 = vmlaq_f32(acc21, b11, a6);
459  acc31 = vmlaq_f32(acc31, b11, a7);
460 
461  mtx_a0 += 8;
462  mtx_b0 += 8;
463  mtx_b1 += 8;
464 
465  a0 = vld1q_dup_f32(mtx_a0 + 0);
466  a1 = vld1q_dup_f32(mtx_a0 + 1);
467  a2 = vld1q_dup_f32(mtx_a0 + 2);
468  a3 = vld1q_dup_f32(mtx_a0 + 3);
469 
470  b00 = vld1q_f32(mtx_b0);
471  b10 = vld1q_f32(mtx_b1);
472  b01 = vld1q_f32(mtx_b0 + 4);
473  b11 = vld1q_f32(mtx_b1 + 4);
474 
475  // 4x4 block 0
476  acc00 = vmlaq_f32(acc00, b00, a0);
477  acc10 = vmlaq_f32(acc10, b00, a1);
478  acc20 = vmlaq_f32(acc20, b00, a2);
479  acc30 = vmlaq_f32(acc30, b00, a3);
480 
481  a4 = vld1q_dup_f32(mtx_a0 + 4);
482  a5 = vld1q_dup_f32(mtx_a0 + 5);
483  a6 = vld1q_dup_f32(mtx_a0 + 6);
484  a7 = vld1q_dup_f32(mtx_a0 + 7);
485 
486  // 4x4 block 1
487  acc01 = vmlaq_f32(acc01, b10, a0);
488  acc11 = vmlaq_f32(acc11, b10, a1);
489  acc21 = vmlaq_f32(acc21, b10, a2);
490  acc31 = vmlaq_f32(acc31, b10, a3);
491 
492  // 4x4 block 0
493  acc00 = vmlaq_f32(acc00, b01, a4);
494  acc10 = vmlaq_f32(acc10, b01, a5);
495  acc20 = vmlaq_f32(acc20, b01, a6);
496  acc30 = vmlaq_f32(acc30, b01, a7);
497 
498  // 4x4 block 1
499  acc01 = vmlaq_f32(acc01, b11, a4);
500  acc11 = vmlaq_f32(acc11, b11, a5);
501  acc21 = vmlaq_f32(acc21, b11, a6);
502  acc31 = vmlaq_f32(acc31, b11, a7);
503 
504  mtx_a0 += 8;
505  mtx_b0 += 8;
506  mtx_b1 += 8;
507 
508  a0 = vld1q_dup_f32(mtx_a0 + 0);
509  a1 = vld1q_dup_f32(mtx_a0 + 1);
510  a2 = vld1q_dup_f32(mtx_a0 + 2);
511  a3 = vld1q_dup_f32(mtx_a0 + 3);
512  b00 = vld1q_f32(mtx_b0);
513  b10 = vld1q_f32(mtx_b1);
514  b01 = vld1q_f32(mtx_b0 + 4);
515  b11 = vld1q_f32(mtx_b1 + 4);
516 
517 #if __arm__
518  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
519  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
520  asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
521 #endif /* __arm__ */
522 
523  // 4x4 block 0
524  acc00 = vmlaq_f32(acc00, b00, a0);
525  acc10 = vmlaq_f32(acc10, b00, a1);
526  acc20 = vmlaq_f32(acc20, b00, a2);
527  acc30 = vmlaq_f32(acc30, b00, a3);
528 
529  a4 = vld1q_dup_f32(mtx_a0 + 4);
530  a5 = vld1q_dup_f32(mtx_a0 + 5);
531  a6 = vld1q_dup_f32(mtx_a0 + 6);
532  a7 = vld1q_dup_f32(mtx_a0 + 7);
533 
534  // 4x4 block 1
535  acc01 = vmlaq_f32(acc01, b10, a0);
536  acc11 = vmlaq_f32(acc11, b10, a1);
537  acc21 = vmlaq_f32(acc21, b10, a2);
538  acc31 = vmlaq_f32(acc31, b10, a3);
539 
540  // 4x4 block 0
541  acc00 = vmlaq_f32(acc00, b01, a4);
542  acc10 = vmlaq_f32(acc10, b01, a5);
543  acc20 = vmlaq_f32(acc20, b01, a6);
544  acc30 = vmlaq_f32(acc30, b01, a7);
545 
546  // 4x4 block 1
547  acc01 = vmlaq_f32(acc01, b11, a4);
548  acc11 = vmlaq_f32(acc11, b11, a5);
549  acc21 = vmlaq_f32(acc21, b11, a6);
550  acc31 = vmlaq_f32(acc31, b11, a7);
551 
552  mtx_a0 += 8;
553  mtx_b0 += 8;
554  mtx_b1 += 8;
555 
556  a0 = vld1q_dup_f32(mtx_a0 + 0);
557  a1 = vld1q_dup_f32(mtx_a0 + 1);
558  a2 = vld1q_dup_f32(mtx_a0 + 2);
559  a3 = vld1q_dup_f32(mtx_a0 + 3);
560  b00 = vld1q_f32(mtx_b0);
561  b10 = vld1q_f32(mtx_b1);
562  b01 = vld1q_f32(mtx_b0 + 4);
563  b11 = vld1q_f32(mtx_b1 + 4);
564 
565  // 4x4 block 0
566  acc00 = vmlaq_f32(acc00, b00, a0);
567  acc10 = vmlaq_f32(acc10, b00, a1);
568  acc20 = vmlaq_f32(acc20, b00, a2);
569  acc30 = vmlaq_f32(acc30, b00, a3);
570 
571  a4 = vld1q_dup_f32(mtx_a0 + 4);
572  a5 = vld1q_dup_f32(mtx_a0 + 5);
573  a6 = vld1q_dup_f32(mtx_a0 + 6);
574  a7 = vld1q_dup_f32(mtx_a0 + 7);
575 
576  // 4x4 block 1
577  acc01 = vmlaq_f32(acc01, b10, a0);
578  acc11 = vmlaq_f32(acc11, b10, a1);
579  acc21 = vmlaq_f32(acc21, b10, a2);
580  acc31 = vmlaq_f32(acc31, b10, a3);
581 
582  // 4x4 block 0
583  acc00 = vmlaq_f32(acc00, b01, a4);
584  acc10 = vmlaq_f32(acc10, b01, a5);
585  acc20 = vmlaq_f32(acc20, b01, a6);
586  acc30 = vmlaq_f32(acc30, b01, a7);
587 
588  // 4x4 block 1
589  acc01 = vmlaq_f32(acc01, b11, a4);
590  acc11 = vmlaq_f32(acc11, b11, a5);
591  acc21 = vmlaq_f32(acc21, b11, a6);
592  acc31 = vmlaq_f32(acc31, b11, a7);
593 
594  mtx_a0 += 8;
595  mtx_b0 += 8;
596  mtx_b1 += 8;
597  }
598 
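 // Leftover loop: fewer than 32 values of the current reshaped B row remain, so process them in groups of 4,
 // one step of each of the two 4x4 output blocks per iteration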
599  for(; mtx_b0 < mtx_b0_end_addr;)
600  {
601  float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
602  float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
603  float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
604  float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
605  float32x4_t b00 = vld1q_f32(mtx_b0);
606  float32x4_t b10 = vld1q_f32(mtx_b1);
607 
608 #if __arm__
609  asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
610  asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
611  asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
612 #endif /* __arm__ */
613  // 4x4 block 0
614  acc00 = vmlaq_f32(acc00, b00, a0);
615  acc10 = vmlaq_f32(acc10, b00, a1);
616  acc20 = vmlaq_f32(acc20, b00, a2);
617  acc30 = vmlaq_f32(acc30, b00, a3);
618 
619  // 4x4 block 1
620  acc01 = vmlaq_f32(acc01, b10, a0);
621  acc11 = vmlaq_f32(acc11, b10, a1);
622  acc21 = vmlaq_f32(acc21, b10, a2);
623  acc31 = vmlaq_f32(acc31, b10, a3);
624 
625  mtx_a0 += 4;
626  mtx_b0 += 4;
627  mtx_b1 += 4;
628  }
629 
630  // Multiply by the weight of matrix product (alpha)
631  if(multiply_alpha)
632  {
633  const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
634  acc00 = vmulq_f32(acc00, alpha_f32);
635  acc10 = vmulq_f32(acc10, alpha_f32);
636  acc20 = vmulq_f32(acc20, alpha_f32);
637  acc30 = vmulq_f32(acc30, alpha_f32);
638  acc01 = vmulq_f32(acc01, alpha_f32);
639  acc11 = vmulq_f32(acc11, alpha_f32);
640  acc21 = vmulq_f32(acc21, alpha_f32);
641  acc31 = vmulq_f32(acc31, alpha_f32);
642  }
643 
644  const auto mtx_out0 = reinterpret_cast<float *>(out.ptr());
645  const auto mtx_out1 = mtx_out0 + 4;
646 
647  // Store the 4 blocks
648  vst1q_f32(mtx_out0, acc00);
649  vst1q_f32(mtx_out1, acc01);
650  vst1q_f32(mtx_out0 + out_stride1, acc10);
651  vst1q_f32(mtx_out1 + out_stride1, acc11);
652  vst1q_f32(mtx_out0 + out_stride2, acc20);
653  vst1q_f32(mtx_out1 + out_stride2, acc21);
654  vst1q_f32(mtx_out0 + out_stride3, acc30);
655  vst1q_f32(mtx_out1 + out_stride3, acc31);
656  },
657  ina, inb, out);
658 }
659 
660 template <bool multiply_alpha>
661 void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
662 {
663 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
664  const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
665  const size_t out_stride = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
666  const int num_elems_matrix_b_x = input1->info()->dimension(0);
667 
668  // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4, as the interleaved input matrix A has 4 times fewer rows than the output matrix
669  Window win_a(window);
670  win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
671  win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));
672 
673  Window win_b;
674  // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
675  // This scenario can happen when the matrix multiplication is used to perform a convolution operation
676  if(input1->info()->num_dimensions() >= 3)
677  {
678  win_b = window;
679  }
680  // Set step_x and step_y for matrix B. Scale the X range by a factor of 8, as the transposed input matrix B has 8 times fewer columns than the output matrix
681  win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
682  win_b.set(Window::DimY, Window::Dimension(0, 1, 0));
683 
684  Iterator ina(input0, win_a);
685  Iterator inb(input1, win_b);
686  Iterator out(output, window);
687 
688  const float16x8_t alpha_f16 = vdupq_n_f16(alpha);
689 
690  execute_window_loop(window, [&](const Coordinates &)
691  {
692  const auto *mtx_a0 = reinterpret_cast<const float16_t *>(ina.ptr());
693  const auto *mtx_b0 = reinterpret_cast<const float16_t *>(inb.ptr());
694  auto *mtx_out = reinterpret_cast<float16_t *>(out.ptr());
695  float16x8x4_t c =
696  {
697  {
698  vdupq_n_f16(0.f),
699  vdupq_n_f16(0.f),
700  vdupq_n_f16(0.f),
701  vdupq_n_f16(0.f)
702  }
703  };
704 
705  /*
706  This kernel puts the values in a 4x4 block of Matrix A on the same row (Interleaved values)
707  |a00 a01 a02 a03 | a04 a05 a06 a07|
708  |a10 a11 a12 a13 | a14 a15 a16 a17|
709  |a20 a21 a22 a23 | a24 a25 a26 a27| = | a00 a10 a20 a30 || a01 a11 a21 a31 || a02 a12 a22 a32 || a03 a13 a23 a33 | a40 a50 a60 a70 | ...
710  |a30 a31 a32 a33 | a34 a35 a36 a37| | a04 a14 a24 a34 || a05 a15 a25 a35 || a06 a16 a26 a36 || a07 a17 a27 a37 | a44 a54 a64 a74 | ...
711  |a40 a41 a42 a43 | a44 a45 a46 a47|
712  |a50 a51 a52 a53 | a54 a55 a56 a57|
713  |a60 a61 a62 a63 | a64 a65 a66 a67|
714  |a70 a71 a72 a73 | a74 a75 a76 a77|
715 
716  After this operation, the output matrix will have the following shape: [ height * 4, width / 4 ]
717 
718  B Matrix has been transposed as shown below
719 
720  |b00 b01 b02 b03 b04 b05 b06 b07|
721  |b10 b11 b12 b13 b14 b15 b16 b17|
722  |b20 b21 b22 b23 b24 b25 b26 b27|
723  |b30 b31 b32 b33 b34 b35 b36 b37|
724  ------------------->
725 
726  |b00 b01 b02 b03 b04 b05 b06 b07||b10 b11 b12 b13 b14 b15 b16 b17||b20 b21 b22 b23 b24 b25 b26 b27||b30 b31 b32 b33 b34 b35 b36 b37|
727 
728  c.val[0][0] = a00*b00 + a01*b10 + a02*b20 + a03*b30
729  c.val[0][1] = a00*b01 + a01*b11 + a02*b21 + a03*b31
730 
731  The size of the output tensor's XY-plane must be the following shape [ width * 8, height / 8 ]. All other dimensions must have the same size.
732  */
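 // In the loop below p00 holds { a00 a10 a20 a30 a01 a11 a21 a31 } and q00 holds { b00 ... b07 }, so
 // c.val[r] accumulates ar0 * b0x from q00 and ar1 * b1x from q02, which is exactly the sum written above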
733  const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
734 
735  for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
736 
737  {
738  const float16x8_t p00 = vld1q_f16(mtx_a0);
739  const float16x8_t p02 = vld1q_f16(mtx_a0 + 8);
740 
741  const float16x8_t q00 = vld1q_f16(mtx_b0);
742  const float16x8_t q02 = vld1q_f16(mtx_b0 + 8);
743  const float16x8_t q04 = vld1q_f16(mtx_b0 + 16);
744  const float16x8_t q06 = vld1q_f16(mtx_b0 + 24);
745 
746  c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vgetq_lane_f16(p00, 0)));
747  c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vgetq_lane_f16(p00, 1)));
748  c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vgetq_lane_f16(p00, 2)));
749  c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vgetq_lane_f16(p00, 3)));
750 
751  c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q02, vgetq_lane_f16(p00, 4)));
752  c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q02, vgetq_lane_f16(p00, 5)));
753  c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q02, vgetq_lane_f16(p00, 6)));
754  c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q02, vgetq_lane_f16(p00, 7)));
755 
756  c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q04, vgetq_lane_f16(p02, 0)));
757  c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q04, vgetq_lane_f16(p02, 1)));
758  c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q04, vgetq_lane_f16(p02, 2)));
759  c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q04, vgetq_lane_f16(p02, 3)));
760 
761  c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q06, vgetq_lane_f16(p02, 4)));
762  c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q06, vgetq_lane_f16(p02, 5)));
763  c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q06, vgetq_lane_f16(p02, 6)));
764  c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q06, vgetq_lane_f16(p02, 7)));
765 
766  mtx_a0 += 16;
767  mtx_b0 += 32;
768  }
769 
770  for(; mtx_b0 < mtx_b0_end_addr;)
771 
772  {
773  const float16x4_t p00 = vld1_f16(mtx_a0);
774  const float16x8_t q00 = vld1q_f16(mtx_b0);
775 
776  c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vget_lane_f16(p00, 0)));
777  c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vget_lane_f16(p00, 1)));
778  c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vget_lane_f16(p00, 2)));
779  c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vget_lane_f16(p00, 3)));
780 
781  mtx_a0 += 4;
782  mtx_b0 += 8;
783  }
784 
785  if(multiply_alpha)
786  {
787  c.val[0] = vmulq_f16(c.val[0], alpha_f16);
788  c.val[1] = vmulq_f16(c.val[1], alpha_f16);
789  c.val[2] = vmulq_f16(c.val[2], alpha_f16);
790  c.val[3] = vmulq_f16(c.val[3], alpha_f16);
791  }
792 
793  vst1q_f16(mtx_out + 0 * out_stride, c.val[0]);
794  vst1q_f16(mtx_out + 1 * out_stride, c.val[1]);
795  vst1q_f16(mtx_out + 2 * out_stride, c.val[2]);
796  vst1q_f16(mtx_out + 3 * out_stride, c.val[3]);
797  },
798  ina, inb, out);
799 #else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
800  ARM_COMPUTE_UNUSED(input0);
801  ARM_COMPUTE_UNUSED(input1);
802  ARM_COMPUTE_UNUSED(output);
803  ARM_COMPUTE_UNUSED(window);
804  ARM_COMPUTE_UNUSED(alpha);
805  ARM_COMPUTE_ERROR("Not implemented");
806 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
807 }
808 
809 inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
810 {
811  ARM_COMPUTE_UNUSED(alpha);
812 
813  ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input0);
814  ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
815  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output);
816 
817  if(!is_interleaved)
818  {
819  ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != input1->dimension(1));
820 
821  if(output->total_size() != 0)
822  {
823  ARM_COMPUTE_RETURN_ERROR_ON(input1->dimension(0) != output->dimension(0));
824  ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(1) != output->dimension(1));
826  }
827  }
828  else
829  {
830  const int m = reshape_info.m();
831  const int n = reshape_info.n();
832  const int k = reshape_info.k();
833  const int mult_transpose1xW_width = reshape_info.mult_transpose1xW_width();
834  const int mult_interleave4x4_height = reshape_info.mult_interleave4x4_height();
835 
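 // Illustrative example of the checks below, assuming F32 data and both multipliers equal to 1 (so the
 // 1xW transpose uses W = 4): for m = 8, n = 12, k = 4 the interleaved input0 is expected to have shape
 // [ k * 4, ceil(m / 4) ] = [ 16, 2 ], the transposed input1 [ k * 4, ceil(n / 4) ] = [ 16, 3 ], and an
 // already-initialized output [ n, m ] = [ 12, 8 ]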
836  /* Interleave */
837  TensorShape tensor_shape0{ input0->tensor_shape() };
838  tensor_shape0.set(0, k);
839  tensor_shape0.set(1, m);
840 
841  const TensorInfo tensor_info0 = input0->clone()->set_tensor_shape(tensor_shape0);
842  const TensorInfo tensor_info_reshaped0 = input0->clone()->set_tensor_shape(misc::shape_calculator::compute_interleaved_shape(tensor_info0, mult_interleave4x4_height));
843  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0);
844 
845  if(n != 0) /* Transpose */
846  {
847  TensorShape tensor_shape1{ input1->tensor_shape() };
848  tensor_shape1.set(0, n);
849  tensor_shape1.set(1, k);
850 
851  const TensorInfo tensor_info1 = input1->clone()->set_tensor_shape(tensor_shape1);
852  const TensorInfo tensor_info_reshaped1 = input1->clone()->set_tensor_shape(misc::shape_calculator::compute_transpose1xW_with_element_size_shape(tensor_info1, mult_transpose1xW_width));
853  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, &tensor_info_reshaped1);
854  }
855 
856  if(output->total_size() != 0)
857  {
858  if(n != 0)
859  {
860  ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != static_cast<size_t>(n));
861  }
862  ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(1) != static_cast<size_t>(m));
864  }
865  }
866 
867  return Status{};
868 }
869 
870 inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output)
871 {
872  bool window_changed{};
873  Window win{};
874 
875  unsigned int num_elems_processed_per_iteration_x = 0;
876  const unsigned int num_elems_processed_per_iteration_y = 4;
877 
878  // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
879  if((output->dimension(1) == 1))
880  {
881  switch(input0->data_type())
882  {
883  case DataType::F32:
884  {
885  num_elems_processed_per_iteration_x = 16;
886  break;
887  }
888 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
889  case DataType::F16:
890  {
891  num_elems_processed_per_iteration_x = 32;
892  break;
893  }
894 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
895  default:
896  {
897  ARM_COMPUTE_ERROR("Data type not supported");
898  break;
899  }
900  }
901 
902  // Configure kernel window
903  win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x));
904 
905  AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration_x);
906 
907  window_changed = update_window_and_padding(win,
908  AccessWindowStatic(input0, 0, 0, input0->tensor_shape().x(), 1),
909  AccessWindowHorizontal(input1, 0, num_elems_processed_per_iteration_x),
910  output_access);
911 
912  Coordinates coord;
913  coord.set_num_dimensions(output->num_dimensions());
914  output_access.set_valid_region(win, ValidRegion(coord, output->tensor_shape()));
915  }
916  else
917  {
918  switch(input0->data_type())
919  {
920  case DataType::F32:
921  {
922  num_elems_processed_per_iteration_x = 8;
923  break;
924  }
925 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
926  case DataType::F16:
927  {
928  num_elems_processed_per_iteration_x = 8;
929  break;
930  }
931 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
932  default:
933  {
934  ARM_COMPUTE_ERROR("Data type not supported");
935  break;
936  }
937  }
938 
939  // Configure kernel window
940  win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
941 
942  AccessWindowRectangle output_access(output, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
943 
944  window_changed = update_window_and_padding(win,
945  AccessWindowRectangle(input0, 0, 0, 4, 1, 1.f, 0.25f),
946  AccessWindowStatic(input1, 0, 0, input1->tensor_shape().x(), ceil_to_multiple(input1->tensor_shape().y(), 4)),
947  output_access);
948 
949  output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->tensor_shape()));
950  }
951 
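 // The matrix-matrix path above steps over the output in (8, 4) element tiles and reads B rows padded to a
 // multiple of 4 in Y, so window_changed is set when the given tensor infos do not already provide the
 // required padding and the error below is returned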
952  Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
953  return std::make_pair(err, win);
954 }
955 } // namespace
956 
957 NEGEMMMatrixMultiplyKernel::NEGEMMMatrixMultiplyKernel()
958  : _input0(nullptr), _input1(nullptr), _output(nullptr), _alpha(1.0f)
959 {
960 }
961 
962 void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
963 {
964  ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
965 
966  // Output tensor auto initialization if not yet initialized
967  TensorShape tensor_shape{ input0->info()->tensor_shape() };
968  tensor_shape.set(0, is_interleaved ? reshape_info.n() : input1->info()->dimension(0));
969  tensor_shape.set(1, is_interleaved ? reshape_info.m() : input0->info()->dimension(1));
970 
971  auto_init_if_empty(*output->info(), input0->info()->clone()->set_tensor_shape(tensor_shape));
972 
973  // Perform validate step
974  ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), alpha, is_interleaved, reshape_info));
975 
976  _input0 = input0;
977  _input1 = input1;
978  _output = output;
979  _alpha = alpha;
980 
981  // Configure kernel window
982  auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info());
983  ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
984  INEKernel::configure(win_config.second);
985 }
986 
987 Status NEGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved,
988  const GEMMReshapeInfo &reshape_info)
989 {
990  ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, alpha, is_interleaved, reshape_info));
991  ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(), input1->clone().get(), output->clone().get()).first);
992 
993  return Status{};
994 }
995 
996 void NEGEMMMatrixMultiplyKernel::run(const Window &window, const ThreadInfo &info)
997 {
998  ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
999  ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1000 
1001  const bool multiply_alpha = !(helpers::float_ops::is_one(_alpha));
1002 
1003  // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
1004  if((_output->info()->dimension(1) == 1))
1005  {
1006  switch(_input0->info()->data_type())
1007  {
1008  case DataType::F32:
1009  {
1010  multiply_alpha ? vector_matrix_multiply_f32<true>(_input0, _input1, _output, window, info, _alpha) :
1011  vector_matrix_multiply_f32<false>(_input0, _input1, _output, window, info, _alpha);
1012  break;
1013  }
1014 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1015  case DataType::F16:
1016  {
1017  multiply_alpha ? vector_matrix_multiply_f16<true>(_input0, _input1, _output, window, info, _alpha) :
1018  vector_matrix_multiply_f16<false>(_input0, _input1, _output, window, info, _alpha);
1019  break;
1020  }
1021 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
1022  default:
1023  {
1024  ARM_COMPUTE_ERROR("Data type not supported");
1025  break;
1026  }
1027  }
1028  }
1029  else
1030  {
1031  switch(_input0->info()->data_type())
1032  {
1033  case DataType::F32:
1034  {
1035  multiply_alpha ? matrix_matrix_multiply_f32<true>(_input0, _input1, _output, window, _alpha) :
1036  matrix_matrix_multiply_f32<false>(_input0, _input1, _output, window, _alpha);
1037  break;
1038  }
1039 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1040  case DataType::F16:
1041  {
1042  multiply_alpha ? matrix_matrix_multiply_f16<true>(_input0, _input1, _output, window, _alpha) :
1043  matrix_matrix_multiply_f16<false>(_input0, _input1, _output, window, _alpha);
1044  break;
1045  }
1046 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
1047  default:
1048  {
1049  ARM_COMPUTE_ERROR("Data type not supported");
1050  break;
1051  }
1052  }
1053  }
1054 }
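// A minimal usage sketch (illustrative only; tensor allocation, reshaping and window details omitted):
//   NEGEMMMatrixMultiplyKernel mm_kernel;
//   mm_kernel.configure(&a, &b, &dst, 1.0f, false /* is_interleaved */);
//   NEScheduler::get().schedule(&mm_kernel, Window::DimY);
// The interleaved path instead expects a and b to have been reshaped beforehand (e.g. with NEGEMMInterleave4x4
// and NEGEMMTranspose1xW) and a GEMMReshapeInfo describing m, n and k to be passed to configure().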