Compute Library 21.02
NEGEMMLowpMatrixMultiplyKernel.cpp
1 /*
2  * Copyright (c) 2017-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h"
25 
26 #include "arm_compute/core/Error.h"
27 #include "arm_compute/core/Helpers.h"
28 #include "arm_compute/core/ITensor.h"
29 #include "arm_compute/core/TensorInfo.h"
30 #include "arm_compute/core/Types.h"
31 #include "arm_compute/core/Utils.h"
32 #include "arm_compute/core/Validate.h"
33 #include "arm_compute/core/Window.h"
34 #include "src/core/helpers/AutoConfiguration.h"
35 #include "src/core/helpers/WindowHelpers.h"
36 
37 
38 #include <arm_neon.h>
39 
40 using namespace arm_compute;
41 
42 namespace arm_compute
43 {
44 namespace
45 {
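// Vector-by-matrix multiplication for the unsigned 8-bit path: computes out = vec_a (1 x width_a) * matrix_b (width_a x width_b),
// producing a 16-element slice of S32 results per window step; stride_b is the row stride of matrix_b in elements.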
46 void inline vector_matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
47 {
48  execute_window_loop(window, [&](const Coordinates & id)
49  {
50  if(id.x() > width_b)
51  {
52  return;
53  }
54 
55  // Note: Since the inputs are all positive, we can use uint32_t
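 // (each u8*u8 product fits in 16 bits, so a 32-bit accumulator can sum 2^16 such products without overflowing)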
56  // Accumulators for the block 0
57  uint32x4x4_t c0 =
58  {
59  {
60  vdupq_n_u32(0),
61  vdupq_n_u32(0),
62  vdupq_n_u32(0),
63  vdupq_n_u32(0)
64  }
65  };
66 
67  auto vec_a = reinterpret_cast<const uint8_t *>(ina.ptr());
68  auto matrix_b = reinterpret_cast<const uint8_t *>(inb.ptr());
69  auto vec_a_end_addr = vec_a + width_a;
70 
71  // This for loop performs 8 accumulations
72  for(; vec_a <= (vec_a_end_addr - 8);)
73  {
74  const uint8x8_t a00_u8 = vld1_u8(vec_a);
75  const uint8x16_t b00_u8 = vld1q_u8(matrix_b + 0 * stride_b);
76  const uint8x16_t b10_u8 = vld1q_u8(matrix_b + 1 * stride_b);
77  const uint8x16_t b20_u8 = vld1q_u8(matrix_b + 2 * stride_b);
78  const uint8x16_t b30_u8 = vld1q_u8(matrix_b + 3 * stride_b);
79  const uint8x16_t b40_u8 = vld1q_u8(matrix_b + 4 * stride_b);
80  const uint8x16_t b50_u8 = vld1q_u8(matrix_b + 5 * stride_b);
81  const uint8x16_t b60_u8 = vld1q_u8(matrix_b + 6 * stride_b);
82  const uint8x16_t b70_u8 = vld1q_u8(matrix_b + 7 * stride_b);
83 
84  // Convert a00_u8 to uint16_t and split it into lower and upper halves
85  const uint16x4x2_t a00_u16 =
86  {
87  {
88  vget_low_u16(vmovl_u8(a00_u8)),
89  vget_high_u16(vmovl_u8(a00_u8))
90  }
91  };
92 
93  const uint16x4x4_t b00_u16 =
94  {
95  {
96  vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
97  vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
98  vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
99  vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
100  }
101  };
102 
103  const uint16x4x4_t b10_u16 =
104  {
105  {
106  vget_low_u16(vmovl_u8(vget_low_u8(b10_u8))),
107  vget_high_u16(vmovl_u8(vget_low_u8(b10_u8))),
108  vget_low_u16(vmovl_u8(vget_high_u8(b10_u8))),
109  vget_high_u16(vmovl_u8(vget_high_u8(b10_u8)))
110  }
111  };
112 
113  const uint16x4x4_t b20_u16 =
114  {
115  {
116  vget_low_u16(vmovl_u8(vget_low_u8(b20_u8))),
117  vget_high_u16(vmovl_u8(vget_low_u8(b20_u8))),
118  vget_low_u16(vmovl_u8(vget_high_u8(b20_u8))),
119  vget_high_u16(vmovl_u8(vget_high_u8(b20_u8)))
120  }
121  };
122 
123  const uint16x4x4_t b30_u16 =
124  {
125  {
126  vget_low_u16(vmovl_u8(vget_low_u8(b30_u8))),
127  vget_high_u16(vmovl_u8(vget_low_u8(b30_u8))),
128  vget_low_u16(vmovl_u8(vget_high_u8(b30_u8))),
129  vget_high_u16(vmovl_u8(vget_high_u8(b30_u8)))
130  }
131  };
132 
133  const uint16x4x4_t b40_u16 =
134  {
135  {
136  vget_low_u16(vmovl_u8(vget_low_u8(b40_u8))),
137  vget_high_u16(vmovl_u8(vget_low_u8(b40_u8))),
138  vget_low_u16(vmovl_u8(vget_high_u8(b40_u8))),
139  vget_high_u16(vmovl_u8(vget_high_u8(b40_u8)))
140  }
141  };
142 
143  const uint16x4x4_t b50_u16 =
144  {
145  {
146  vget_low_u16(vmovl_u8(vget_low_u8(b50_u8))),
147  vget_high_u16(vmovl_u8(vget_low_u8(b50_u8))),
148  vget_low_u16(vmovl_u8(vget_high_u8(b50_u8))),
149  vget_high_u16(vmovl_u8(vget_high_u8(b50_u8)))
150  }
151  };
152 
153  const uint16x4x4_t b60_u16 =
154  {
155  {
156  vget_low_u16(vmovl_u8(vget_low_u8(b60_u8))),
157  vget_high_u16(vmovl_u8(vget_low_u8(b60_u8))),
158  vget_low_u16(vmovl_u8(vget_high_u8(b60_u8))),
159  vget_high_u16(vmovl_u8(vget_high_u8(b60_u8)))
160  }
161  };
162 
163  const uint16x4x4_t b70_u16 =
164  {
165  {
166  vget_low_u16(vmovl_u8(vget_low_u8(b70_u8))),
167  vget_high_u16(vmovl_u8(vget_low_u8(b70_u8))),
168  vget_low_u16(vmovl_u8(vget_high_u8(b70_u8))),
169  vget_high_u16(vmovl_u8(vget_high_u8(b70_u8)))
170  }
171  };
172 
173  // Accumulate 0:
174  c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16.val[0], 0);
175  c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16.val[0], 0);
176  c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16.val[0], 0);
177  c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16.val[0], 0);
178 
179  // Accumulate 1:
180  c0.val[0] = vmlal_lane_u16(c0.val[0], b10_u16.val[0], a00_u16.val[0], 1);
181  c0.val[1] = vmlal_lane_u16(c0.val[1], b10_u16.val[1], a00_u16.val[0], 1);
182  c0.val[2] = vmlal_lane_u16(c0.val[2], b10_u16.val[2], a00_u16.val[0], 1);
183  c0.val[3] = vmlal_lane_u16(c0.val[3], b10_u16.val[3], a00_u16.val[0], 1);
184 
185  // Accumulate 2:
186  c0.val[0] = vmlal_lane_u16(c0.val[0], b20_u16.val[0], a00_u16.val[0], 2);
187  c0.val[1] = vmlal_lane_u16(c0.val[1], b20_u16.val[1], a00_u16.val[0], 2);
188  c0.val[2] = vmlal_lane_u16(c0.val[2], b20_u16.val[2], a00_u16.val[0], 2);
189  c0.val[3] = vmlal_lane_u16(c0.val[3], b20_u16.val[3], a00_u16.val[0], 2);
190 
191  // Accumulate 3:
192  c0.val[0] = vmlal_lane_u16(c0.val[0], b30_u16.val[0], a00_u16.val[0], 3);
193  c0.val[1] = vmlal_lane_u16(c0.val[1], b30_u16.val[1], a00_u16.val[0], 3);
194  c0.val[2] = vmlal_lane_u16(c0.val[2], b30_u16.val[2], a00_u16.val[0], 3);
195  c0.val[3] = vmlal_lane_u16(c0.val[3], b30_u16.val[3], a00_u16.val[0], 3);
196 
197  // Accumulate 4:
198  c0.val[0] = vmlal_lane_u16(c0.val[0], b40_u16.val[0], a00_u16.val[1], 0);
199  c0.val[1] = vmlal_lane_u16(c0.val[1], b40_u16.val[1], a00_u16.val[1], 0);
200  c0.val[2] = vmlal_lane_u16(c0.val[2], b40_u16.val[2], a00_u16.val[1], 0);
201  c0.val[3] = vmlal_lane_u16(c0.val[3], b40_u16.val[3], a00_u16.val[1], 0);
202 
203  // Accumulate 5:
204  c0.val[0] = vmlal_lane_u16(c0.val[0], b50_u16.val[0], a00_u16.val[1], 1);
205  c0.val[1] = vmlal_lane_u16(c0.val[1], b50_u16.val[1], a00_u16.val[1], 1);
206  c0.val[2] = vmlal_lane_u16(c0.val[2], b50_u16.val[2], a00_u16.val[1], 1);
207  c0.val[3] = vmlal_lane_u16(c0.val[3], b50_u16.val[3], a00_u16.val[1], 1);
208 
209  // Accumulate 6:
210  c0.val[0] = vmlal_lane_u16(c0.val[0], b60_u16.val[0], a00_u16.val[1], 2);
211  c0.val[1] = vmlal_lane_u16(c0.val[1], b60_u16.val[1], a00_u16.val[1], 2);
212  c0.val[2] = vmlal_lane_u16(c0.val[2], b60_u16.val[2], a00_u16.val[1], 2);
213  c0.val[3] = vmlal_lane_u16(c0.val[3], b60_u16.val[3], a00_u16.val[1], 2);
214 
215  // Accumulate 7:
216  c0.val[0] = vmlal_lane_u16(c0.val[0], b70_u16.val[0], a00_u16.val[1], 3);
217  c0.val[1] = vmlal_lane_u16(c0.val[1], b70_u16.val[1], a00_u16.val[1], 3);
218  c0.val[2] = vmlal_lane_u16(c0.val[2], b70_u16.val[2], a00_u16.val[1], 3);
219  c0.val[3] = vmlal_lane_u16(c0.val[3], b70_u16.val[3], a00_u16.val[1], 3);
220 
221  vec_a += 8;
222  matrix_b += 8 * stride_b;
223  }
224 
225  // This for loop performs the left-over accumulations
226  for(; vec_a < vec_a_end_addr;)
227  {
228  const uint8x8_t a00_u8 = vld1_dup_u8(vec_a);
229  const uint8x16_t b00_u8 = vld1q_u8(matrix_b);
230 
231  const uint16x4x4_t b00_u16 =
232  {
233  {
234  vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
235  vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
236  vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
237  vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
238  }
239  };
240 
241  // Convert a00_u8 to uint16_t and get the lower part
242  const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));
243 
244  // Accumulate 0:
245  c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
246  c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
247  c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
248  c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);
249 
250  vec_a += 1;
251  matrix_b += stride_b;
252  }
253 
254  auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
255  if(id.x() < (width_out - 16))
256  {
257  vst1q_s32(vec_out + 0, vreinterpretq_s32_u32(c0.val[0]));
258  vst1q_s32(vec_out + 4, vreinterpretq_s32_u32(c0.val[1]));
259  vst1q_s32(vec_out + 8, vreinterpretq_s32_u32(c0.val[2]));
260  vst1q_s32(vec_out + 12, vreinterpretq_s32_u32(c0.val[3]));
261  }
262  else
263  {
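 // Right edge of the output: fewer than 16 results remain, so store them element by element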
264  auto left_over = width_out - id.x();
265  for(auto k = 0; k < 4 && left_over; ++k)
266  {
267  for(auto j = 0; j < 4 && left_over; ++j, --left_over)
268  {
269  *(vec_out + k * 4 + j) = c0.val[k][j];
270  }
271  }
272  }
273  },
274  ina, inb, out);
275 }
276 
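// Signed 8-bit variant of the vector-by-matrix multiplication above: same structure, using int16/int32 intrinsics and int32_t accumulators.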
277 void inline vector_matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
278 {
279  execute_window_loop(window, [&](const Coordinates & id)
280  {
281  if(id.x() > width_b)
282  {
283  return;
284  }
285 
286  // Accumulators for the block 0
287  int32x4x4_t c0 =
288  {
289  {
290  vdupq_n_s32(0),
291  vdupq_n_s32(0),
292  vdupq_n_s32(0),
293  vdupq_n_s32(0)
294  }
295  };
296 
297  auto vec_a = reinterpret_cast<const int8_t *>(ina.ptr());
298  auto matrix_b = reinterpret_cast<const int8_t *>(inb.ptr());
299  auto vec_a_end_addr = vec_a + width_a;
300 
301  // This for loop performs 8 accumulations
302  for(; vec_a <= (vec_a_end_addr - 8);)
303  {
304  const int8x8_t a00_s8 = vld1_s8(vec_a);
305  const int8x16_t b00_s8 = vld1q_s8(matrix_b + 0 * stride_b);
306  const int8x16_t b10_s8 = vld1q_s8(matrix_b + 1 * stride_b);
307  const int8x16_t b20_s8 = vld1q_s8(matrix_b + 2 * stride_b);
308  const int8x16_t b30_s8 = vld1q_s8(matrix_b + 3 * stride_b);
309  const int8x16_t b40_s8 = vld1q_s8(matrix_b + 4 * stride_b);
310  const int8x16_t b50_s8 = vld1q_s8(matrix_b + 5 * stride_b);
311  const int8x16_t b60_s8 = vld1q_s8(matrix_b + 6 * stride_b);
312  const int8x16_t b70_s8 = vld1q_s8(matrix_b + 7 * stride_b);
313 
314  // Convert a00_s8 to int16_t and split it into lower and upper halves
315  const int16x4x2_t a00_s16 =
316  {
317  {
318  vget_low_s16(vmovl_s8(a00_s8)),
319  vget_high_s16(vmovl_s8(a00_s8))
320  }
321  };
322 
323  const int16x4x4_t b00_s16 =
324  {
325  {
326  vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
327  vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
328  vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
329  vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
330  }
331  };
332 
333  const int16x4x4_t b10_s16 =
334  {
335  {
336  vget_low_s16(vmovl_s8(vget_low_s8(b10_s8))),
337  vget_high_s16(vmovl_s8(vget_low_s8(b10_s8))),
338  vget_low_s16(vmovl_s8(vget_high_s8(b10_s8))),
339  vget_high_s16(vmovl_s8(vget_high_s8(b10_s8)))
340  }
341  };
342 
343  const int16x4x4_t b20_s16 =
344  {
345  {
346  vget_low_s16(vmovl_s8(vget_low_s8(b20_s8))),
347  vget_high_s16(vmovl_s8(vget_low_s8(b20_s8))),
348  vget_low_s16(vmovl_s8(vget_high_s8(b20_s8))),
349  vget_high_s16(vmovl_s8(vget_high_s8(b20_s8)))
350  }
351  };
352 
353  const int16x4x4_t b30_s16 =
354  {
355  {
356  vget_low_s16(vmovl_s8(vget_low_s8(b30_s8))),
357  vget_high_s16(vmovl_s8(vget_low_s8(b30_s8))),
358  vget_low_s16(vmovl_s8(vget_high_s8(b30_s8))),
359  vget_high_s16(vmovl_s8(vget_high_s8(b30_s8)))
360  }
361  };
362 
363  const int16x4x4_t b40_s16 =
364  {
365  {
366  vget_low_s16(vmovl_s8(vget_low_s8(b40_s8))),
367  vget_high_s16(vmovl_s8(vget_low_s8(b40_s8))),
368  vget_low_s16(vmovl_s8(vget_high_s8(b40_s8))),
369  vget_high_s16(vmovl_s8(vget_high_s8(b40_s8)))
370  }
371  };
372 
373  const int16x4x4_t b50_s16 =
374  {
375  {
376  vget_low_s16(vmovl_s8(vget_low_s8(b50_s8))),
377  vget_high_s16(vmovl_s8(vget_low_s8(b50_s8))),
378  vget_low_s16(vmovl_s8(vget_high_s8(b50_s8))),
379  vget_high_s16(vmovl_s8(vget_high_s8(b50_s8)))
380  }
381  };
382 
383  const int16x4x4_t b60_s16 =
384  {
385  {
386  vget_low_s16(vmovl_s8(vget_low_s8(b60_s8))),
387  vget_high_s16(vmovl_s8(vget_low_s8(b60_s8))),
388  vget_low_s16(vmovl_s8(vget_high_s8(b60_s8))),
389  vget_high_s16(vmovl_s8(vget_high_s8(b60_s8)))
390  }
391  };
392 
393  const int16x4x4_t b70_s16 =
394  {
395  {
396  vget_low_s16(vmovl_s8(vget_low_s8(b70_s8))),
397  vget_high_s16(vmovl_s8(vget_low_s8(b70_s8))),
398  vget_low_s16(vmovl_s8(vget_high_s8(b70_s8))),
399  vget_high_s16(vmovl_s8(vget_high_s8(b70_s8)))
400  }
401  };
402 
403  // Accumulate 0:
404  c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16.val[0], 0);
405  c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16.val[0], 0);
406  c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16.val[0], 0);
407  c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16.val[0], 0);
408 
409  // Accumulate 1:
410  c0.val[0] = vmlal_lane_s16(c0.val[0], b10_s16.val[0], a00_s16.val[0], 1);
411  c0.val[1] = vmlal_lane_s16(c0.val[1], b10_s16.val[1], a00_s16.val[0], 1);
412  c0.val[2] = vmlal_lane_s16(c0.val[2], b10_s16.val[2], a00_s16.val[0], 1);
413  c0.val[3] = vmlal_lane_s16(c0.val[3], b10_s16.val[3], a00_s16.val[0], 1);
414 
415  // Accumulate 2:
416  c0.val[0] = vmlal_lane_s16(c0.val[0], b20_s16.val[0], a00_s16.val[0], 2);
417  c0.val[1] = vmlal_lane_s16(c0.val[1], b20_s16.val[1], a00_s16.val[0], 2);
418  c0.val[2] = vmlal_lane_s16(c0.val[2], b20_s16.val[2], a00_s16.val[0], 2);
419  c0.val[3] = vmlal_lane_s16(c0.val[3], b20_s16.val[3], a00_s16.val[0], 2);
420 
421  // Accumulate 3:
422  c0.val[0] = vmlal_lane_s16(c0.val[0], b30_s16.val[0], a00_s16.val[0], 3);
423  c0.val[1] = vmlal_lane_s16(c0.val[1], b30_s16.val[1], a00_s16.val[0], 3);
424  c0.val[2] = vmlal_lane_s16(c0.val[2], b30_s16.val[2], a00_s16.val[0], 3);
425  c0.val[3] = vmlal_lane_s16(c0.val[3], b30_s16.val[3], a00_s16.val[0], 3);
426 
427  // Accumulate 4:
428  c0.val[0] = vmlal_lane_s16(c0.val[0], b40_s16.val[0], a00_s16.val[1], 0);
429  c0.val[1] = vmlal_lane_s16(c0.val[1], b40_s16.val[1], a00_s16.val[1], 0);
430  c0.val[2] = vmlal_lane_s16(c0.val[2], b40_s16.val[2], a00_s16.val[1], 0);
431  c0.val[3] = vmlal_lane_s16(c0.val[3], b40_s16.val[3], a00_s16.val[1], 0);
432 
433  // Accumulate 5:
434  c0.val[0] = vmlal_lane_s16(c0.val[0], b50_s16.val[0], a00_s16.val[1], 1);
435  c0.val[1] = vmlal_lane_s16(c0.val[1], b50_s16.val[1], a00_s16.val[1], 1);
436  c0.val[2] = vmlal_lane_s16(c0.val[2], b50_s16.val[2], a00_s16.val[1], 1);
437  c0.val[3] = vmlal_lane_s16(c0.val[3], b50_s16.val[3], a00_s16.val[1], 1);
438 
439  // Accumulate 6:
440  c0.val[0] = vmlal_lane_s16(c0.val[0], b60_s16.val[0], a00_s16.val[1], 2);
441  c0.val[1] = vmlal_lane_s16(c0.val[1], b60_s16.val[1], a00_s16.val[1], 2);
442  c0.val[2] = vmlal_lane_s16(c0.val[2], b60_s16.val[2], a00_s16.val[1], 2);
443  c0.val[3] = vmlal_lane_s16(c0.val[3], b60_s16.val[3], a00_s16.val[1], 2);
444 
445  // Accumulate 7:
446  c0.val[0] = vmlal_lane_s16(c0.val[0], b70_s16.val[0], a00_s16.val[1], 3);
447  c0.val[1] = vmlal_lane_s16(c0.val[1], b70_s16.val[1], a00_s16.val[1], 3);
448  c0.val[2] = vmlal_lane_s16(c0.val[2], b70_s16.val[2], a00_s16.val[1], 3);
449  c0.val[3] = vmlal_lane_s16(c0.val[3], b70_s16.val[3], a00_s16.val[1], 3);
450 
451  vec_a += 8;
452  matrix_b += 8 * stride_b;
453  }
454 
455  // This for loop performs the left-over accumulations
456  for(; vec_a < vec_a_end_addr;)
457  {
458  const int8x8_t a00_s8 = vld1_dup_s8(vec_a);
459  const int8x16_t b00_s8 = vld1q_s8(matrix_b);
460 
461  const int16x4x4_t b00_s16 =
462  {
463  {
464  vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
465  vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
466  vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
467  vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
468  }
469  };
470 
471  // Convert a00_s8 to int16_t and get the lower part
472  const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));
473 
474  // Accumulate 0:
475  c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
476  c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
477  c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
478  c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);
479 
480  vec_a += 1;
481  matrix_b += stride_b;
482  }
483 
484  auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
485  if(id.x() < (width_out - 16))
486  {
487  vst1q_s32(vec_out + 0, c0.val[0]);
488  vst1q_s32(vec_out + 4, c0.val[1]);
489  vst1q_s32(vec_out + 8, c0.val[2]);
490  vst1q_s32(vec_out + 12, c0.val[3]);
491  }
492  else
493  {
494  auto left_over = width_out - id.x();
495  for(auto k = 0; k < 4 && left_over; ++k)
496  {
497  for(auto j = 0; j < 4 && left_over; ++j, --left_over)
498  {
499  *(vec_out + k * 4 + j) = c0.val[k][j];
500  }
501  }
502  }
503  },
504  ina, inb, out);
505 }
506 
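// Matrix multiplication for the unsigned 8-bit path: computes a 4x16 block of S32 output per window step.
// As noted in the signed variant below, matrix A is expected to be interleaved with NEGEMMInterleave4x4 and matrix B transposed with NEGEMMTranspose1xW.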
507 void inline matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
508 {
509  const auto width_out = static_cast<int>(out_info.dimension(0));
510  const auto height_out = static_cast<int>(out_info.dimension(1));
511  const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
512  execute_window_loop(window, [&](const Coordinates & id)
513  {
514  const uint8_t *mtx_a0 = ina.ptr();
515  const uint8_t *mtx_b0 = inb.ptr();
516 
517  // Note: Since the inputs are all positive, we can use uint32_t
518  // Accumulators for the block 0
519  uint32x4x4_t c0 =
520  {
521  {
522  vdupq_n_u32(0),
523  vdupq_n_u32(0),
524  vdupq_n_u32(0),
525  vdupq_n_u32(0)
526  }
527  };
528 
529  // Accumulators for the block 1
530  uint32x4x4_t c1 =
531  {
532  {
533  vdupq_n_u32(0),
534  vdupq_n_u32(0),
535  vdupq_n_u32(0),
536  vdupq_n_u32(0)
537  }
538  };
539 
540  // Accumulators for the block 2
541  uint32x4x4_t c2 =
542  {
543  {
544  vdupq_n_u32(0),
545  vdupq_n_u32(0),
546  vdupq_n_u32(0),
547  vdupq_n_u32(0)
548  }
549  };
550 
551  // Accumulators for the block 3
552  uint32x4x4_t c3 =
553  {
554  {
555  vdupq_n_u32(0),
556  vdupq_n_u32(0),
557  vdupq_n_u32(0),
558  vdupq_n_u32(0)
559  }
560  };
561 
562  for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
563  {
564  const uint8x8_t a00_u8 = vld1_u8(mtx_a0);
565  const uint8x16_t b00_u8 = vld1q_u8(mtx_b0);
566 
567  // Convert a00_u8 to uint16_t and get the lower part
568  const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));
569 
570  // Convert b00_u8 to uint16_t
571  const uint16x4x4_t b00_u16 =
572  {
573  {
574  vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
575  vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
576  vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
577  vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
578  }
579  };
580 
581  // 4x4 block 0
582  c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
583  c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
584  c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
585  c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);
586 
587  // 4x4 block 1
588  c1.val[0] = vmlal_lane_u16(c1.val[0], b00_u16.val[0], a00_u16, 1);
589  c1.val[1] = vmlal_lane_u16(c1.val[1], b00_u16.val[1], a00_u16, 1);
590  c1.val[2] = vmlal_lane_u16(c1.val[2], b00_u16.val[2], a00_u16, 1);
591  c1.val[3] = vmlal_lane_u16(c1.val[3], b00_u16.val[3], a00_u16, 1);
592 
593  // 4x4 block 2
594  c2.val[0] = vmlal_lane_u16(c2.val[0], b00_u16.val[0], a00_u16, 2);
595  c2.val[1] = vmlal_lane_u16(c2.val[1], b00_u16.val[1], a00_u16, 2);
596  c2.val[2] = vmlal_lane_u16(c2.val[2], b00_u16.val[2], a00_u16, 2);
597  c2.val[3] = vmlal_lane_u16(c2.val[3], b00_u16.val[3], a00_u16, 2);
598 
599  // 4x4 block 3
600  c3.val[0] = vmlal_lane_u16(c3.val[0], b00_u16.val[0], a00_u16, 3);
601  c3.val[1] = vmlal_lane_u16(c3.val[1], b00_u16.val[1], a00_u16, 3);
602  c3.val[2] = vmlal_lane_u16(c3.val[2], b00_u16.val[2], a00_u16, 3);
603  c3.val[3] = vmlal_lane_u16(c3.val[3], b00_u16.val[3], a00_u16, 3);
604  }
605 
606  auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
607 
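 // Store the 4x16 block: guard each of the 4 output rows against height_out, and fall back to an element-wise store at the right edge of the output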
608  if(id.y() < height_out && id.x() < (width_out - 16))
609  {
610  vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
611  vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
612  vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
613  vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
614  if(id.y() + 1 < height_out)
615  {
616  vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
617  vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
618  vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
619  vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
620  if(id.y() + 2 < height_out)
621  {
622  vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
623  vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
624  vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
625  vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
626  if(id.y() + 3 < height_out)
627  {
628  vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
629  vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
630  vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
631  vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
632  }
633  }
634  }
635  }
636  else
637  {
638  const auto left_over_value = width_out - id.x();
639  auto left_over = left_over_value;
640  for(auto k = 0; k < 4 && left_over; ++k)
641  {
642  for(auto j = 0; j < 4 && left_over; ++j, --left_over)
643  {
644  *(mtx_out + k * 4 + j) = c0.val[k][j];
645  }
646  }
647  if(id.y() + 1 < height_out)
648  {
649  left_over = left_over_value;
650  for(auto k = 0; k < 4 && left_over; ++k)
651  {
652  for(auto j = 0; j < 4 && left_over; ++j, --left_over)
653  {
654  *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
655  }
656  }
657  if(id.y() + 2 < height_out)
658  {
659  left_over = left_over_value;
660  for(auto k = 0; k < 4 && left_over; ++k)
661  {
662  for(auto j = 0; j < 4 && left_over; ++j, --left_over)
663  {
664  *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
665  }
666  }
667  if(id.y() + 3 < height_out)
668  {
669  left_over = left_over_value;
670  for(auto k = 0; k < 4 && left_over; ++k)
671  {
672  for(auto j = 0; j < 4 && left_over; ++j, --left_over)
673  {
674  *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
675  }
676  }
677  }
678  }
679  }
680  }
681  },
682  ina, inb, out);
683 }
684 
685 void inline matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
686 {
687  const auto width_out = static_cast<int>(out_info.dimension(0));
688  const auto height_out = static_cast<int>(out_info.dimension(1));
689  const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
690  // The implementation assumes that matrix A and matrix B have been reshaped with NEGEMMInterleave4x4 and NEGEMMTranspose1xW respectively.
691  // The reshaping makes the implementation cache friendly and avoids the data re-arrangements needed to compute the 4x16 output block per iteration:
692  // all the values needed for computing a single 4x4 block are read from consecutive memory positions
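 // For example, NEGEMMInterleave4x4 stores rows a0..a3 of A column-interleaved as a0[k] a1[k] a2[k] a3[k] a0[k+1] a1[k+1] ...,
 // so the 4 values multiplied per step are contiguous, while NEGEMMTranspose1xW (roughly) lays out 16 consecutive elements of each
 // row of B back to back, row after row for a given 16-column block.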
693  execute_window_loop(window, [&](const Coordinates & id)
694  {
695  auto *mtx_a0 = reinterpret_cast<const int8_t *>(ina.ptr());
696  auto *mtx_b0 = reinterpret_cast<const int8_t *>(inb.ptr());
697 
698  // Note: The inputs are signed, so int32_t accumulators are used
699  // Accumulators for the block 0
700  int32x4x4_t c0 =
701  {
702  {
703  vdupq_n_s32(0),
704  vdupq_n_s32(0),
705  vdupq_n_s32(0),
706  vdupq_n_s32(0)
707  }
708  };
709 
710  // Accumulators for the block 1
711  int32x4x4_t c1 =
712  {
713  {
714  vdupq_n_s32(0),
715  vdupq_n_s32(0),
716  vdupq_n_s32(0),
717  vdupq_n_s32(0)
718  }
719  };
720 
721  // Accumulators for the block 2
722  int32x4x4_t c2 =
723  {
724  {
725  vdupq_n_s32(0),
726  vdupq_n_s32(0),
727  vdupq_n_s32(0),
728  vdupq_n_s32(0)
729  }
730  };
731 
732  // Accumulators for the block 3
733  int32x4x4_t c3 =
734  {
735  {
736  vdupq_n_s32(0),
737  vdupq_n_s32(0),
738  vdupq_n_s32(0),
739  vdupq_n_s32(0)
740  }
741  };
742 
743  for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
744  {
745  const int8x8_t a00_s8 = vld1_s8(mtx_a0);
746  const int8x16_t b00_s8 = vld1q_s8(mtx_b0);
747 
748  // Convert a00_s8 to int16_t and get the lower part
749  const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));
750 
751  // Convert b00_s8 to int16_t
752  const int16x4x4_t b00_s16 =
753  {
754  {
755  vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
756  vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
757  vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
758  vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
759  }
760  };
761 
762  // 4x4 block 0
763  c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
764  c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
765  c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
766  c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);
767 
768  // 4x4 block 1
769  c1.val[0] = vmlal_lane_s16(c1.val[0], b00_s16.val[0], a00_s16, 1);
770  c1.val[1] = vmlal_lane_s16(c1.val[1], b00_s16.val[1], a00_s16, 1);
771  c1.val[2] = vmlal_lane_s16(c1.val[2], b00_s16.val[2], a00_s16, 1);
772  c1.val[3] = vmlal_lane_s16(c1.val[3], b00_s16.val[3], a00_s16, 1);
773 
774  // 4x4 block 2
775  c2.val[0] = vmlal_lane_s16(c2.val[0], b00_s16.val[0], a00_s16, 2);
776  c2.val[1] = vmlal_lane_s16(c2.val[1], b00_s16.val[1], a00_s16, 2);
777  c2.val[2] = vmlal_lane_s16(c2.val[2], b00_s16.val[2], a00_s16, 2);
778  c2.val[3] = vmlal_lane_s16(c2.val[3], b00_s16.val[3], a00_s16, 2);
779 
780  // 4x4 block 3
781  c3.val[0] = vmlal_lane_s16(c3.val[0], b00_s16.val[0], a00_s16, 3);
782  c3.val[1] = vmlal_lane_s16(c3.val[1], b00_s16.val[1], a00_s16, 3);
783  c3.val[2] = vmlal_lane_s16(c3.val[2], b00_s16.val[2], a00_s16, 3);
784  c3.val[3] = vmlal_lane_s16(c3.val[3], b00_s16.val[3], a00_s16, 3);
785  }
786  auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
787  if(id.y() < height_out && id.x() < (width_out - 16))
788  {
789  vst1q_s32(mtx_out + 0 * out_stride + 0, c0.val[0]);
790  vst1q_s32(mtx_out + 0 * out_stride + 4, c0.val[1]);
791  vst1q_s32(mtx_out + 0 * out_stride + 8, c0.val[2]);
792  vst1q_s32(mtx_out + 0 * out_stride + 12, c0.val[3]);
793  if(id.y() + 1 < height_out)
794  {
795  vst1q_s32(mtx_out + 1 * out_stride + 0, c1.val[0]);
796  vst1q_s32(mtx_out + 1 * out_stride + 4, c1.val[1]);
797  vst1q_s32(mtx_out + 1 * out_stride + 8, c1.val[2]);
798  vst1q_s32(mtx_out + 1 * out_stride + 12, c1.val[3]);
799  if(id.y() + 2 < height_out)
800  {
801  vst1q_s32(mtx_out + 2 * out_stride + 0, c2.val[0]);
802  vst1q_s32(mtx_out + 2 * out_stride + 4, c2.val[1]);
803  vst1q_s32(mtx_out + 2 * out_stride + 8, c2.val[2]);
804  vst1q_s32(mtx_out + 2 * out_stride + 12, c2.val[3]);
805  if(id.y() + 3 < height_out)
806  {
807  vst1q_s32(mtx_out + 3 * out_stride + 0, c3.val[0]);
808  vst1q_s32(mtx_out + 3 * out_stride + 4, c3.val[1]);
809  vst1q_s32(mtx_out + 3 * out_stride + 8, c3.val[2]);
810  vst1q_s32(mtx_out + 3 * out_stride + 12, c3.val[3]);
811  }
812  }
813  }
814  }
815  else if(id.y() < height_out)
816  {
817  const auto left_over_value = width_out - id.x();
818  auto left_over = left_over_value;
819  for(auto k = 0; k < 4 && left_over; ++k)
820  {
821  for(auto j = 0; j < 4 && left_over; ++j, --left_over)
822  {
823  *(mtx_out + k * 4 + j) = c0.val[k][j];
824  }
825  }
826  if(id.y() + 1 < height_out)
827  {
828  left_over = left_over_value;
829  for(auto k = 0; k < 4 && left_over; ++k)
830  {
831  for(auto j = 0; j < 4 && left_over; ++j, --left_over)
832  {
833  *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
834  }
835  }
836  if(id.y() + 2 < height_out)
837  {
838  left_over = left_over_value;
839  for(auto k = 0; k < 4 && left_over; ++k)
840  {
841  for(auto j = 0; j < 4 && left_over; ++j, --left_over)
842  {
843  *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
844  }
845  }
846  if(id.y() + 3 < height_out)
847  {
848  left_over = left_over_value;
849  for(auto k = 0; k < 4 && left_over; ++k)
850  {
851  for(auto j = 0; j < 4 && left_over; ++j, --left_over)
852  {
853  *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
854  }
855  }
856  }
857  }
858  }
859  }
860 
861  },
862  ina, inb, out);
863 }
864 } // namespace
865 
866 namespace
867 {
868 Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
869 {
870  ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::U8, DataType::QASYMM8, DataType::S8, DataType::QASYMM8_SIGNED);
871  ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::QSYMM8, DataType::QSYMM8_PER_CHANNEL, DataType::S8, DataType::QASYMM8_SIGNED);
872  ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
873 
874  TensorShape in0_shape = input0->tensor_shape();
875  TensorShape in1_shape = input1->tensor_shape();
876  TensorShape out_shape = output->tensor_shape();
877 
878  // Check vector-by-matrix case
879  if(out_shape[1] == 1)
880  {
881  ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[0] != in1_shape[1], "The number of input0's columns must be equal to input1's rows");
882  }
883  else
884  {
885  in0_shape.collapse(2);
886  in1_shape.collapse(2);
887  out_shape.collapse(2);
888 
889  ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[2] != out_shape[2], "Output tensor must have the same number of batches of input0 tensor");
890  ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[2] != 1 && in0_shape[2] != in1_shape[2], "Input1 tensor must have the same number of batches of input0 or the number of batches must be set to 1");
891  ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[0] % 16, "Input1's width must be a multiple of 16");
892  }
893 
894  return Status{};
895 }
896 } // namespace
897 
898 NEGEMMLowpMatrixMultiplyKernel::NEGEMMLowpMatrixMultiplyKernel()
899  : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true)
900 {
901 }
902 
903 void NEGEMMLowpMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output)
904 {
905  ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
906  ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info()));
907 
908  TensorShape in1_shape = input1->info()->tensor_shape();
909  in1_shape.collapse(2);
910 
911  _input0 = input0;
912  _input1 = input1;
913  _output = output;
914  _slide_matrix_b = in1_shape[2] != 1;
915 
916  constexpr unsigned int num_elems_processed_per_iteration_x = 16;
917  constexpr unsigned int num_elems_processed_per_iteration_y = 4;
918 
919  Window win;
920 
921  // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication path
922  if((output->info()->dimension(1) == 1))
923  {
924  // Configure kernel window
925  win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));
926 
927  Coordinates coord;
928  coord.set_num_dimensions(output->info()->num_dimensions());
929  output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));
930  }
931  else
932  {
933  win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
934  output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
935  }
936 
937  INEKernel::configure(win);
938 }
939 
940 Status NEGEMMLowpMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
941 {
942  ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output));
943 
944  return Status{};
945 }
946 
947 void NEGEMMLowpMatrixMultiplyKernel::run(const Window &window, const ThreadInfo &info)
948 {
949  ARM_COMPUTE_UNUSED(info);
950  ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
951  ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
952 
953  // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication path
954  if((_output->info()->dimension(1) == 1))
955  {
956  const auto width_matrix_a = static_cast<int>(_input0->info()->dimension(0));
957  const auto width_matrix_b = static_cast<int>(_input1->info()->dimension(0));
958  const auto width_out = static_cast<int>(_output->info()->dimension(0));
959  const auto in_b_stride = static_cast<int>(_input1->info()->strides_in_bytes()[1] / data_size_from_type(_input1->info()->data_type()));
960 
961  // The implementation computes 16 elements per iteration
962  const int window_start_x = 16 * info.thread_id;
963  const int window_step_x = 16 * info.num_threads;
964  // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
965  const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
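 // e.g. with 4 threads, thread 0 processes output columns [0,16), [64,80), ... and thread 1 processes [16,32), [80,96), ...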
966 
967  Window win_out(window);
968  win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
969  win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
970 
971  Window win_a(window);
972  win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
973  win_a.set(Window::DimY, Window::Dimension(0, 0, 0));
974 
975  Window win_b;
976  // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2.
977  // This scenario can happen when the matrix multiplication is used to perform a convolution operation
978  if(_input1->info()->num_dimensions() >= 3)
979  {
980  win_b = window;
981  }
982  win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
983  win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
984 
985  Iterator ina(_input0, win_a);
986  Iterator inb(_input1, win_b);
987  Iterator out(_output, win_out);
988 
989  switch(_input0->info()->data_type())
990  {
991  case DataType::S8:
992  case DataType::QASYMM8_SIGNED:
993  {
994  vector_matrix_multiply_s8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
995  break;
996  }
997  case DataType::U8:
998  case DataType::QASYMM8:
999  {
1000  vector_matrix_multiply_u8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
1001  break;
1002  }
1003  default:
1004  {
1005  ARM_COMPUTE_ERROR("Not supported");
1006  break;
1007  }
1008  }
1009  }
1010  else
1011  {
1012  const size_t in_b_stride = _input1->info()->strides_in_bytes()[1];
1013  const int width_b = _input1->info()->dimension(0);
1014 
1015  // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4, as the interleaved input matrix A has 4 times fewer rows than the output matrix
1016  Window win_a(window);
1017  win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
1018  win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, window.y().end() / 4, 1));
1019 
1020  // Set step_x and step_y for matrix B. Scale the X range by a factor of 16, as the transposed input matrix B has 16 times fewer columns than the output matrix
1021  Window win_b;
1022  // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2.
1023  // This scenario can happen when the matrix multiplication is used to perform a convolution operation
1024  if(_slide_matrix_b)
1025  {
1026  win_b = window;
1027  }
1028  win_b.set(Window::DimX, Window::Dimension(window.x().start() / 16, window.x().end() / 16, in_b_stride));
1029  win_b.set(Window::DimY, Window::Dimension(0, 0, 0));
1030 
1031  // The step x and step y for the output matrix have already been set in configure()
1032  Iterator ina(_input0, win_a);
1033  Iterator inb(_input1, win_b);
1034  Iterator out(_output, window);
1035 
1036  switch(_input0->info()->data_type())
1037  {
1038  case DataType::S8:
1039  case DataType::QASYMM8_SIGNED:
1040  {
1041  matrix_multiply_s8(ina, inb, out, width_b, *_output->info(), window);
1042  break;
1043  }
1044  case DataType::U8:
1045  case DataType::QASYMM8:
1046  {
1047  matrix_multiply_u8(ina, inb, out, width_b, *_output->info(), window);
1048  break;
1049  }
1050  default:
1051  {
1052  ARM_COMPUTE_ERROR("Not supported");
1053  break;
1054  }
1055  }
1056  }
1057 }
1058 } // namespace arm_compute
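
// A minimal usage sketch (illustrative only, not part of the original file). It assumes the tensors are already
// allocated, quantized appropriately, and that A has been reshaped with NEGEMMInterleave4x4 and B with
// NEGEMMTranspose1xW elsewhere, e.g. by NEGEMMLowpMatrixMultiplyCore; the tensor names below are hypothetical.
//
//   Tensor a_interleaved, b_transposed, out_s32;        // hypothetical, already allocated tensors
//   NEGEMMLowpMatrixMultiplyKernel mm_kernel;
//   mm_kernel.configure(&a_interleaved, &b_transposed, &out_s32);  // validates shapes/types and sets the window
//   NEScheduler::get().schedule(&mm_kernel, Window::DimY);         // split the work across threads along Y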