Compute Library
 22.02
CpuGemmLowpMatrixMultiplyKernel.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2017-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
26 #include "arm_compute/core/Error.h"
30 #include "arm_compute/core/Types.h"
31 #include "arm_compute/core/Utils.h"
36 
37 #include <arm_neon.h>
38 
39 namespace arm_compute
40 {
41 namespace cpu
42 {
43 namespace kernels
44 {
45 namespace
46 {
// Multiplies a 1xK u8 vector (A) by a KxN u8 matrix (B), producing one s32 output row:
// out[x] = sum_k A[k] * B[k][x].
//
// @param ina       Iterator over the input vector A.
// @param inb       Iterator over the input matrix B.
// @param out       Iterator over the s32 output vector.
// @param width_a   Number of elements (K) of vector A.
// @param width_b   Number of columns (N) of matrix B.
// @param width_out Number of elements of the output row.
// @param stride_b  Row stride of matrix B, in elements.
// @param window    Execution window; each iteration computes 16 output columns.
void inline vector_matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
{
    execute_window_loop(window, [&](const Coordinates & id)
    {
        // The window end is rounded up to a multiple of the 16-column step, so an
        // iteration can start past the last column of B; nothing to compute there.
        if(id.x() > width_b)
        {
            return;
        }

        // Note: Since the input are all positives, we can use uint32_t
        // Accumulators for the block 0 (16 output columns, 4 lanes per register)
        uint32x4x4_t c0 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        auto vec_a          = reinterpret_cast<const uint8_t *>(ina.ptr());
        auto matrix_b       = reinterpret_cast<const uint8_t *>(inb.ptr());
        auto vec_a_end_addr = vec_a + width_a;

        // This for loop performs 8 accumulations per iteration:
        // 8 elements of A against 8 rows x 16 columns of B.
        for(; vec_a <= (vec_a_end_addr - 8);)
        {
            const uint8x8_t  a00_u8 = vld1_u8(vec_a);
            const uint8x16_t b00_u8 = vld1q_u8(matrix_b + 0 * stride_b);
            const uint8x16_t b10_u8 = vld1q_u8(matrix_b + 1 * stride_b);
            const uint8x16_t b20_u8 = vld1q_u8(matrix_b + 2 * stride_b);
            const uint8x16_t b30_u8 = vld1q_u8(matrix_b + 3 * stride_b);
            const uint8x16_t b40_u8 = vld1q_u8(matrix_b + 4 * stride_b);
            const uint8x16_t b50_u8 = vld1q_u8(matrix_b + 5 * stride_b);
            const uint8x16_t b60_u8 = vld1q_u8(matrix_b + 6 * stride_b);
            const uint8x16_t b70_u8 = vld1q_u8(matrix_b + 7 * stride_b);

            // Widen the 8 elements of A to uint16_t (low and high halves)
            const uint16x4x2_t a00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(a00_u8)),
                    vget_high_u16(vmovl_u8(a00_u8))
                }
            };

            // Widen each 16-column row of B to four uint16x4 registers
            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            const uint16x4x4_t b10_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b10_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b10_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b10_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b10_u8)))
                }
            };

            const uint16x4x4_t b20_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b20_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b20_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b20_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b20_u8)))
                }
            };

            const uint16x4x4_t b30_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b30_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b30_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b30_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b30_u8)))
                }
            };

            const uint16x4x4_t b40_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b40_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b40_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b40_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b40_u8)))
                }
            };

            const uint16x4x4_t b50_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b50_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b50_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b50_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b50_u8)))
                }
            };

            const uint16x4x4_t b60_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b60_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b60_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b60_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b60_u8)))
                }
            };

            const uint16x4x4_t b70_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b70_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b70_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b70_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b70_u8)))
                }
            };

            // Accumulate 0: c0 += A[0] * B[0][0..15]
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16.val[0], 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16.val[0], 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16.val[0], 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16.val[0], 0);

            // Accumulate 1:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b10_u16.val[0], a00_u16.val[0], 1);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b10_u16.val[1], a00_u16.val[0], 1);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b10_u16.val[2], a00_u16.val[0], 1);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b10_u16.val[3], a00_u16.val[0], 1);

            // Accumulate 2:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b20_u16.val[0], a00_u16.val[0], 2);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b20_u16.val[1], a00_u16.val[0], 2);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b20_u16.val[2], a00_u16.val[0], 2);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b20_u16.val[3], a00_u16.val[0], 2);

            // Accumulate 3:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b30_u16.val[0], a00_u16.val[0], 3);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b30_u16.val[1], a00_u16.val[0], 3);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b30_u16.val[2], a00_u16.val[0], 3);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b30_u16.val[3], a00_u16.val[0], 3);

            // Accumulate 4 (lanes of the high half of A start here):
            c0.val[0] = vmlal_lane_u16(c0.val[0], b40_u16.val[0], a00_u16.val[1], 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b40_u16.val[1], a00_u16.val[1], 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b40_u16.val[2], a00_u16.val[1], 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b40_u16.val[3], a00_u16.val[1], 0);

            // Accumulate 5:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b50_u16.val[0], a00_u16.val[1], 1);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b50_u16.val[1], a00_u16.val[1], 1);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b50_u16.val[2], a00_u16.val[1], 1);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b50_u16.val[3], a00_u16.val[1], 1);

            // Accumulate 6:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b60_u16.val[0], a00_u16.val[1], 2);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b60_u16.val[1], a00_u16.val[1], 2);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b60_u16.val[2], a00_u16.val[1], 2);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b60_u16.val[3], a00_u16.val[1], 2);

            // Accumulate 7:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b70_u16.val[0], a00_u16.val[1], 3);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b70_u16.val[1], a00_u16.val[1], 3);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b70_u16.val[2], a00_u16.val[1], 3);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b70_u16.val[3], a00_u16.val[1], 3);

            vec_a += 8;
            matrix_b += 8 * stride_b;
        }

        // This for loop performs the left-over accumulations (width_a % 8 elements),
        // one element of A against one row of B per iteration.
        for(; vec_a < vec_a_end_addr;)
        {
            const uint8x8_t  a00_u8 = vld1_dup_u8(vec_a);
            const uint8x16_t b00_u8 = vld1q_u8(matrix_b);

            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            // Convert a00_u8 to uint16_t and get the lower part
            const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));

            // Accumulate 0:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);

            vec_a += 1;
            matrix_b += stride_b;
        }

        auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
        // Full 16-element store when well inside the output; otherwise store
        // element-by-element to avoid writing past the end (this scalar path
        // also handles the exactly-16-remaining case).
        if(id.x() < (width_out - 16))
        {
            vst1q_s32(vec_out + 0, vreinterpretq_s32_u32(c0.val[0]));
            vst1q_s32(vec_out + 4, vreinterpretq_s32_u32(c0.val[1]));
            vst1q_s32(vec_out + 8, vreinterpretq_s32_u32(c0.val[2]));
            vst1q_s32(vec_out + 12, vreinterpretq_s32_u32(c0.val[3]));
        }
        else
        {
            auto left_over = width_out - id.x();
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(vec_out + k * 4 + j) = c0.val[k][j];
                }
            }
        }
    },
    ina, inb, out);
}
277 
// Multiplies a 1xK s8 vector (A) by a KxN s8 matrix (B), producing one s32 output row:
// out[x] = sum_k A[k] * B[k][x].  Signed counterpart of vector_matrix_multiply_u8.
//
// @param ina       Iterator over the input vector A.
// @param inb       Iterator over the input matrix B.
// @param out       Iterator over the s32 output vector.
// @param width_a   Number of elements (K) of vector A.
// @param width_b   Number of columns (N) of matrix B.
// @param width_out Number of elements of the output row.
// @param stride_b  Row stride of matrix B, in elements.
// @param window    Execution window; each iteration computes 16 output columns.
void inline vector_matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
{
    execute_window_loop(window, [&](const Coordinates & id)
    {
        // The window end is rounded up to a multiple of the 16-column step, so an
        // iteration can start past the last column of B; nothing to compute there.
        if(id.x() > width_b)
        {
            return;
        }

        // Accumulators for the block 0 (16 output columns, 4 lanes per register)
        int32x4x4_t c0 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        auto vec_a          = reinterpret_cast<const int8_t *>(ina.ptr());
        auto matrix_b       = reinterpret_cast<const int8_t *>(inb.ptr());
        auto vec_a_end_addr = vec_a + width_a;

        // This for loop performs 8 accumulations per iteration:
        // 8 elements of A against 8 rows x 16 columns of B.
        for(; vec_a <= (vec_a_end_addr - 8);)
        {
            const int8x8_t  a00_s8 = vld1_s8(vec_a);
            const int8x16_t b00_s8 = vld1q_s8(matrix_b + 0 * stride_b);
            const int8x16_t b10_s8 = vld1q_s8(matrix_b + 1 * stride_b);
            const int8x16_t b20_s8 = vld1q_s8(matrix_b + 2 * stride_b);
            const int8x16_t b30_s8 = vld1q_s8(matrix_b + 3 * stride_b);
            const int8x16_t b40_s8 = vld1q_s8(matrix_b + 4 * stride_b);
            const int8x16_t b50_s8 = vld1q_s8(matrix_b + 5 * stride_b);
            const int8x16_t b60_s8 = vld1q_s8(matrix_b + 6 * stride_b);
            const int8x16_t b70_s8 = vld1q_s8(matrix_b + 7 * stride_b);

            // Widen the 8 elements of A to int16_t (low and high halves)
            const int16x4x2_t a00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(a00_s8)),
                    vget_high_s16(vmovl_s8(a00_s8))
                }
            };

            // Widen each 16-column row of B to four int16x4 registers
            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            const int16x4x4_t b10_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b10_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b10_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b10_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b10_s8)))
                }
            };

            const int16x4x4_t b20_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b20_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b20_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b20_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b20_s8)))
                }
            };

            const int16x4x4_t b30_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b30_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b30_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b30_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b30_s8)))
                }
            };

            const int16x4x4_t b40_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b40_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b40_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b40_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b40_s8)))
                }
            };

            const int16x4x4_t b50_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b50_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b50_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b50_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b50_s8)))
                }
            };

            const int16x4x4_t b60_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b60_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b60_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b60_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b60_s8)))
                }
            };

            const int16x4x4_t b70_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b70_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b70_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b70_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b70_s8)))
                }
            };

            // Accumulate 0: c0 += A[0] * B[0][0..15]
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16.val[0], 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16.val[0], 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16.val[0], 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16.val[0], 0);

            // Accumulate 1:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b10_s16.val[0], a00_s16.val[0], 1);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b10_s16.val[1], a00_s16.val[0], 1);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b10_s16.val[2], a00_s16.val[0], 1);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b10_s16.val[3], a00_s16.val[0], 1);

            // Accumulate 2:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b20_s16.val[0], a00_s16.val[0], 2);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b20_s16.val[1], a00_s16.val[0], 2);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b20_s16.val[2], a00_s16.val[0], 2);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b20_s16.val[3], a00_s16.val[0], 2);

            // Accumulate 3:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b30_s16.val[0], a00_s16.val[0], 3);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b30_s16.val[1], a00_s16.val[0], 3);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b30_s16.val[2], a00_s16.val[0], 3);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b30_s16.val[3], a00_s16.val[0], 3);

            // Accumulate 4 (lanes of the high half of A start here):
            c0.val[0] = vmlal_lane_s16(c0.val[0], b40_s16.val[0], a00_s16.val[1], 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b40_s16.val[1], a00_s16.val[1], 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b40_s16.val[2], a00_s16.val[1], 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b40_s16.val[3], a00_s16.val[1], 0);

            // Accumulate 5:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b50_s16.val[0], a00_s16.val[1], 1);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b50_s16.val[1], a00_s16.val[1], 1);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b50_s16.val[2], a00_s16.val[1], 1);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b50_s16.val[3], a00_s16.val[1], 1);

            // Accumulate 6:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b60_s16.val[0], a00_s16.val[1], 2);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b60_s16.val[1], a00_s16.val[1], 2);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b60_s16.val[2], a00_s16.val[1], 2);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b60_s16.val[3], a00_s16.val[1], 2);

            // Accumulate 7:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b70_s16.val[0], a00_s16.val[1], 3);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b70_s16.val[1], a00_s16.val[1], 3);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b70_s16.val[2], a00_s16.val[1], 3);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b70_s16.val[3], a00_s16.val[1], 3);

            vec_a += 8;
            matrix_b += 8 * stride_b;
        }

        // This for loop performs the left-over accumulations (width_a % 8 elements),
        // one element of A against one row of B per iteration.
        for(; vec_a < vec_a_end_addr;)
        {
            const int8x8_t  a00_s8 = vld1_dup_s8(vec_a);
            const int8x16_t b00_s8 = vld1q_s8(matrix_b);

            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            // Convert a00_s8 to int16_t and get the lower part
            const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));

            // Accumulate 0:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);

            vec_a += 1;
            matrix_b += stride_b;
        }

        auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
        // Full 16-element store when well inside the output; otherwise store
        // element-by-element to avoid writing past the end (this scalar path
        // also handles the exactly-16-remaining case).
        if(id.x() < (width_out - 16))
        {
            vst1q_s32(vec_out + 0, c0.val[0]);
            vst1q_s32(vec_out + 4, c0.val[1]);
            vst1q_s32(vec_out + 8, c0.val[2]);
            vst1q_s32(vec_out + 12, c0.val[3]);
        }
        else
        {
            auto left_over = width_out - id.x();
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(vec_out + k * 4 + j) = c0.val[k][j];
                }
            }
        }
    },
    ina, inb, out);
}
507 
// Multiplies reshaped u8 matrix A (interleaved with CpuGemmInterleave4x4) by reshaped
// u8 matrix B (transposed with CpuGemmTranspose1xW), producing a 4x16 s32 block of the
// destination per window iteration.
//
// @param ina      Iterator over the reshaped input matrix A (4 interleaved rows).
// @param inb      Iterator over the reshaped input matrix B (1x16 transposed blocks).
// @param out      Iterator over the s32 output matrix.
// @param width_b  Width of the reshaped matrix B; expected to be a multiple of 16.
// @param out_info Tensor info of the output, used for bounds checks and the row stride.
// @param window   Execution window.
void inline matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
{
    const auto   width_out  = static_cast<int>(out_info.dimension(0));
    const auto   height_out = static_cast<int>(out_info.dimension(1));
    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
    execute_window_loop(window, [&](const Coordinates & id)
    {
        const uint8_t *mtx_a0 = ina.ptr();
        const uint8_t *mtx_b0 = inb.ptr();

        // Note: Since the input are all positives, we can use uint32_t
        // Accumulators for the block 0
        uint32x4x4_t c0 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 1
        uint32x4x4_t c1 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 2
        uint32x4x4_t c2 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 3
        uint32x4x4_t c3 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Each iteration consumes 4 interleaved elements of A (one per output row)
        // and one 16-column block of B.
        for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
        {
            const uint8x8_t  a00_u8 = vld1_u8(mtx_a0);
            const uint8x16_t b00_u8 = vld1q_u8(mtx_b0);

            // Convert a00_u8 to uint16_t and get the lower part
            const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));

            // Convert b00_u8 to uint16_t
            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            // 4x4 block 0
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);

            // 4x4 block 1
            c1.val[0] = vmlal_lane_u16(c1.val[0], b00_u16.val[0], a00_u16, 1);
            c1.val[1] = vmlal_lane_u16(c1.val[1], b00_u16.val[1], a00_u16, 1);
            c1.val[2] = vmlal_lane_u16(c1.val[2], b00_u16.val[2], a00_u16, 1);
            c1.val[3] = vmlal_lane_u16(c1.val[3], b00_u16.val[3], a00_u16, 1);

            // 4x4 block 2
            c2.val[0] = vmlal_lane_u16(c2.val[0], b00_u16.val[0], a00_u16, 2);
            c2.val[1] = vmlal_lane_u16(c2.val[1], b00_u16.val[1], a00_u16, 2);
            c2.val[2] = vmlal_lane_u16(c2.val[2], b00_u16.val[2], a00_u16, 2);
            c2.val[3] = vmlal_lane_u16(c2.val[3], b00_u16.val[3], a00_u16, 2);

            // 4x4 block 3
            c3.val[0] = vmlal_lane_u16(c3.val[0], b00_u16.val[0], a00_u16, 3);
            c3.val[1] = vmlal_lane_u16(c3.val[1], b00_u16.val[1], a00_u16, 3);
            c3.val[2] = vmlal_lane_u16(c3.val[2], b00_u16.val[2], a00_u16, 3);
            c3.val[3] = vmlal_lane_u16(c3.val[3], b00_u16.val[3], a00_u16, 3);
        }

        auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());

        // Vectorized store when fully inside the output; each row is guarded
        // individually since the last block may be cut short vertically.
        if(id.y() < height_out && id.x() < (width_out - 16))
        {
            vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
            vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
            vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
            vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
            if(id.y() + 1 < height_out)
            {
                vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
                vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
                vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
                vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
                if(id.y() + 2 < height_out)
                {
                    vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
                    vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
                    vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
                    vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
                    if(id.y() + 3 < height_out)
                    {
                        vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
                        vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
                        vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
                        vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
                    }
                }
            }
        }
        // Fix: guard the scalar left-over path with `id.y() < height_out` as well,
        // matching matrix_multiply_s8. The previous plain `else` wrote the block-0
        // row even when the iteration's y coordinate was past the output's height,
        // writing out of bounds.
        else if(id.y() < height_out)
        {
            const auto left_over_value = width_out - id.x();
            auto       left_over       = left_over_value;
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(mtx_out + k * 4 + j) = c0.val[k][j];
                }
            }
            if(id.y() + 1 < height_out)
            {
                left_over = left_over_value;
                for(auto k = 0; k < 4 && left_over; ++k)
                {
                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
                    }
                }
                if(id.y() + 2 < height_out)
                {
                    left_over = left_over_value;
                    for(auto k = 0; k < 4 && left_over; ++k)
                    {
                        for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                        {
                            *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
                        }
                    }
                    if(id.y() + 3 < height_out)
                    {
                        left_over = left_over_value;
                        for(auto k = 0; k < 4 && left_over; ++k)
                        {
                            for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                            {
                                *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
                            }
                        }
                    }
                }
            }
        }
    },
    ina, inb, out);
}
685 
// Multiplies reshaped s8 matrix A (interleaved with CpuGemmInterleave4x4) by reshaped
// s8 matrix B (transposed with CpuGemmTranspose1xW), producing a 4x16 s32 block of the
// destination per window iteration.  Signed counterpart of matrix_multiply_u8.
//
// @param ina      Iterator over the reshaped input matrix A (4 interleaved rows).
// @param inb      Iterator over the reshaped input matrix B (1x16 transposed blocks).
// @param out      Iterator over the s32 output matrix.
// @param width_b  Width of the reshaped matrix B; expected to be a multiple of 16.
// @param out_info Tensor info of the output, used for bounds checks and the row stride.
// @param window   Execution window.
void inline matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
{
    const auto   width_out  = static_cast<int>(out_info.dimension(0));
    const auto   height_out = static_cast<int>(out_info.dimension(1));
    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
    // The implementation assumes that the matrix A and Matrix B have been reshaped respectively with CpuGemmInterleave4x4 and CpuGemmTranspose1xW
    // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration
    // All the values needed for computing a single 4x4 block will be read from consecutive memory positions
    execute_window_loop(window, [&](const Coordinates & id)
    {
        auto *mtx_a0 = reinterpret_cast<const int8_t *>(ina.ptr());
        auto *mtx_b0 = reinterpret_cast<const int8_t *>(inb.ptr());

        // Signed path: accumulate in int32_t.
        // Accumulators for the block 0
        int32x4x4_t c0 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 1
        int32x4x4_t c1 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 2
        int32x4x4_t c2 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 3
        int32x4x4_t c3 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Each iteration consumes 4 interleaved elements of A (one per output row)
        // and one 16-column block of B.
        for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
        {
            const int8x8_t  a00_s8 = vld1_s8(mtx_a0);
            const int8x16_t b00_s8 = vld1q_s8(mtx_b0);

            // Convert a00_s8 to int16_t and get the lower part
            const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));

            // Convert b00_s8 to int16_t
            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            // 4x4 block 0
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);

            // 4x4 block 1
            c1.val[0] = vmlal_lane_s16(c1.val[0], b00_s16.val[0], a00_s16, 1);
            c1.val[1] = vmlal_lane_s16(c1.val[1], b00_s16.val[1], a00_s16, 1);
            c1.val[2] = vmlal_lane_s16(c1.val[2], b00_s16.val[2], a00_s16, 1);
            c1.val[3] = vmlal_lane_s16(c1.val[3], b00_s16.val[3], a00_s16, 1);

            // 4x4 block 2
            c2.val[0] = vmlal_lane_s16(c2.val[0], b00_s16.val[0], a00_s16, 2);
            c2.val[1] = vmlal_lane_s16(c2.val[1], b00_s16.val[1], a00_s16, 2);
            c2.val[2] = vmlal_lane_s16(c2.val[2], b00_s16.val[2], a00_s16, 2);
            c2.val[3] = vmlal_lane_s16(c2.val[3], b00_s16.val[3], a00_s16, 2);

            // 4x4 block 3
            c3.val[0] = vmlal_lane_s16(c3.val[0], b00_s16.val[0], a00_s16, 3);
            c3.val[1] = vmlal_lane_s16(c3.val[1], b00_s16.val[1], a00_s16, 3);
            c3.val[2] = vmlal_lane_s16(c3.val[2], b00_s16.val[2], a00_s16, 3);
            c3.val[3] = vmlal_lane_s16(c3.val[3], b00_s16.val[3], a00_s16, 3);
        }
        auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
        // Vectorized store when fully inside the output; each row is guarded
        // individually since the last block may be cut short vertically.
        if(id.y() < height_out && id.x() < (width_out - 16))
        {
            vst1q_s32(mtx_out + 0 * out_stride + 0, c0.val[0]);
            vst1q_s32(mtx_out + 0 * out_stride + 4, c0.val[1]);
            vst1q_s32(mtx_out + 0 * out_stride + 8, c0.val[2]);
            vst1q_s32(mtx_out + 0 * out_stride + 12, c0.val[3]);
            if(id.y() + 1 < height_out)
            {
                vst1q_s32(mtx_out + 1 * out_stride + 0, c1.val[0]);
                vst1q_s32(mtx_out + 1 * out_stride + 4, c1.val[1]);
                vst1q_s32(mtx_out + 1 * out_stride + 8, c1.val[2]);
                vst1q_s32(mtx_out + 1 * out_stride + 12, c1.val[3]);
                if(id.y() + 2 < height_out)
                {
                    vst1q_s32(mtx_out + 2 * out_stride + 0, c2.val[0]);
                    vst1q_s32(mtx_out + 2 * out_stride + 4, c2.val[1]);
                    vst1q_s32(mtx_out + 2 * out_stride + 8, c2.val[2]);
                    vst1q_s32(mtx_out + 2 * out_stride + 12, c2.val[3]);
                    if(id.y() + 3 < height_out)
                    {
                        vst1q_s32(mtx_out + 3 * out_stride + 0, c3.val[0]);
                        vst1q_s32(mtx_out + 3 * out_stride + 4, c3.val[1]);
                        vst1q_s32(mtx_out + 3 * out_stride + 8, c3.val[2]);
                        vst1q_s32(mtx_out + 3 * out_stride + 12, c3.val[3]);
                    }
                }
            }
        }
        // Scalar left-over path: store element-by-element up to the output's
        // right edge, row-by-row while still inside the output's height.
        else if(id.y() < height_out)
        {
            const auto left_over_value = width_out - id.x();
            auto       left_over       = left_over_value;
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(mtx_out + k * 4 + j) = c0.val[k][j];
                }
            }
            if(id.y() + 1 < height_out)
            {
                left_over = left_over_value;
                for(auto k = 0; k < 4 && left_over; ++k)
                {
                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
                    }
                }
                if(id.y() + 2 < height_out)
                {
                    left_over = left_over_value;
                    for(auto k = 0; k < 4 && left_over; ++k)
                    {
                        for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                        {
                            *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
                        }
                    }
                    if(id.y() + 3 < height_out)
                    {
                        left_over = left_over_value;
                        for(auto k = 0; k < 4 && left_over; ++k)
                        {
                            for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                            {
                                *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
                            }
                        }
                    }
                }
            }
        }

    },
    ina, inb, out);
}
865 
// Validates the shapes of the GEMMLowp matrix-multiply inputs/output.
// NOTE(review): some leading validation lines of this function (data-type /
// nullptr macro checks) are elided in this listing; only the shape checks
// visible below are documented here.
Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
{

    TensorShape in0_shape = src0->tensor_shape();
    TensorShape in1_shape = src1->tensor_shape();
    TensorShape out_shape = dst->tensor_shape();

    // Check vector-by-matrix case: a 1-row output means src0 is a vector, so its
    // column count must match src1's row count.
    if(out_shape[1] == 1)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[0] != in1_shape[1], "The number of input0's columns must be equal to input1's rows");
    }
    else
    {
        // Matrix-by-matrix case: fold all dimensions above the first two into a
        // single batch dimension before comparing batch counts.
        in0_shape.collapse(2);
        in1_shape.collapse(2);
        out_shape.collapse(2);

        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[2] != out_shape[2], "Output tensor must have the same number of batches of input0 tensor");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[2] != 1 && in0_shape[2] != in1_shape[2], "Input1 tensor must have the same number of batches of input0 or the number of batches must be set to 1");
        // The reshaped (1xW-transposed) matrix B is processed 16 columns at a time.
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[0] % 16, "Input1's width must be a multiple of 16");
    }

    return Status{};
}
894 } // namespace
895 
897 {
898  ARM_COMPUTE_UNUSED(src0);
899  ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
900  ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, dst));
901 
902  TensorShape in1_shape = src1->tensor_shape();
903  in1_shape.collapse(2);
904 
905  _slide_matrix_b = in1_shape[2] != 1;
906 
907  constexpr unsigned int num_elems_processed_per_iteration_x = 16;
908  constexpr unsigned int num_elems_processed_per_iteration_y = 4;
909 
910  Window win;
911  // Check if the output tensor is a vector. If so,the kernel runs the vector-matrix multiplication
912  if((dst->dimension(1) == 1))
913  {
914  // Configure kernel window
915  win = calculate_max_window(*dst, Steps(num_elems_processed_per_iteration_x));
916  }
917  else
918  {
919  win = calculate_max_window(*dst, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
920  }
921 
922  ICpuKernel::configure(win);
923 }
924 
926 {
927  ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src0, src1, dst));
928  return Status{};
929 }
930 
932 {
 // run_op(): execute the low-precision GEMM on the given window slice.
 // Dispatches to one of two paths: a vector-matrix multiply when the output
 // has a single row, otherwise a blocked matrix-matrix multiply over the
 // interleaved A / transposed B inputs prepared by the surrounding operator.
 // NOTE(review): the signature line (original line 931) and the linked check
 // lines 934-935 are missing from this doxygen dump — presumably the usual
 // ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL / ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW
 // guards, per the macro references below; confirm against the original source.
933  ARM_COMPUTE_UNUSED(info);
936 
 // Fetch the two inputs and the destination from the tensor pack.
937  auto src0 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
938  auto src1 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
939  auto dst = tensors.get_tensor(TensorType::ACL_DST);
940 
941  // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication path
942  if((dst->info()->dimension(1) == 1))
943  {
944  const auto width_matrix_a = static_cast<int>(src0->info()->dimension(0));
945  const auto width_matrix_b = static_cast<int>(src1->info()->dimension(0));
946  const auto width_out = static_cast<int>(dst->info()->dimension(0));
 // Row stride of B expressed in elements (byte stride / element size).
947  const auto in_b_stride = static_cast<int>(src1->info()->strides_in_bytes()[1] / data_size_from_type(src1->info()->data_type()));
948 
949  // The implementation computes 16 elements per iteration
 // Threads interleave along X: thread t starts at 16*t and strides by
 // 16*num_threads, so the columns are partitioned without overlap.
950  const int window_start_x = 16 * info.thread_id;
951  const int window_step_x = 16 * info.num_threads;
952  // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
953  const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
954 
955  Window win_out(window);
956  win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
957  win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
958 
 // Matrix A (the vector) is not advanced by the window loop in X/Y;
 // the inner multiply routine walks it itself.
959  Window win_a(window);
960  win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
961  win_a.set(Window::DimY, Window::Dimension(0, 0, 0));
962 
963  Window win_b;
964  // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
965  // This scenario can happen when the matrix multiplication is used to perform a convolution operation
966  if(src1->info()->num_dimensions() >= 3)
967  {
968  win_b = window;
969  }
970  win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
971  win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
972 
973  Iterator ina(src0, win_a);
974  Iterator inb(src1, win_b);
975  Iterator out(dst, win_out);
976 
 // Select the signed or unsigned NEON implementation by input data type.
 // NOTE(review): a linked case-label line (original line 980) is missing
 // here — presumably DataType::QASYMM8_SIGNED falling through to the S8
 // branch, mirroring how QASYMM8 shares the U8 branch; confirm upstream.
977  switch(src0->info()->data_type())
978  {
979  case DataType::S8:
981  {
982  vector_matrix_multiply_s8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
983  break;
984  }
985  case DataType::U8:
986  case DataType::QASYMM8:
987  {
988  vector_matrix_multiply_u8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
989  break;
990  }
991  default:
992  {
993  ARM_COMPUTE_ERROR("Not supported");
994  break;
995  }
996  }
997  }
998  else
999  {
 // Matrix-matrix path: B's row stride stays in bytes here (the inner
 // routine consumes it directly), unlike the vector path above.
1000  const size_t in_b_stride = src1->info()->strides_in_bytes()[1];
1001  const int width_b = src1->info()->dimension(0);
1002 
1003  // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the output matrix
1004  Window win_a(window);
1005  win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
1006  win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, window.y().end() / 4, 1));
1007 
1008  // Set step_x and step_y for matrix B. Scale by a factor of 16 the X range as the input transposed matrix A has 16 times less the columns of the output matrix
1009  Window win_b;
1010  // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
1011  // This scenario can happen when the matrix multiplication is used to perform a convolution operation
1012  if(_slide_matrix_b)
1013  {
1014  win_b = window;
1015  }
1016  win_b.set(Window::DimX, Window::Dimension(window.x().start() / 16, window.x().end() / 16, in_b_stride));
1017  win_b.set(Window::DimY, Window::Dimension(0, 0, 0));
1018 
1019  // The step x and step y for the output matrix have already been set in configure()
1020  Iterator ina(src0, win_a);
1021  Iterator inb(src1, win_b);
1022  Iterator out(dst, window);
1023 
 // NOTE(review): as in the vector path, a linked case-label line (original
 // line 1027) is missing here — presumably DataType::QASYMM8_SIGNED
 // sharing the S8 branch; confirm against the original source.
1024  switch(src0->info()->data_type())
1025  {
1026  case DataType::S8:
1028  {
1029  matrix_multiply_s8(ina, inb, out, width_b, *dst->info(), window);
1030  break;
1031  }
1032  case DataType::U8:
1033  case DataType::QASYMM8:
1034  {
1035  matrix_multiply_u8(ina, inb, out, width_b, *dst->info(), window);
1036  break;
1037  }
1038  default:
1039  {
1040  ARM_COMPUTE_ERROR("Not supported");
1041  break;
1042  }
1043  }
1044  }
1045 }
1046 
1048 {
 // Returns the kernel's identifying name string.
 // NOTE(review): the signature line (original line 1047) is missing from this
 // doxygen dump; it is the const char *name() const override accessor.
1049  return "CpuGemmLowpMatrixMultiplyKernel";
1050 }
1051 } // namespace kernels
1052 } // namespace cpu
1053 } // namespace arm_compute
Window calculate_max_window(const ValidRegion &valid_region, const Steps &steps, bool skip_border, BorderSize border_size)
const Window & window() const
The maximum window the kernel can be executed on.
Definition: IKernel.cpp:28
Shape of a tensor.
Definition: TensorShape.h:39
virtual size_t dimension(size_t index) const =0
Return the size of the requested dimension.
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
Definition: Error.h:352
1 channel, 1 U8 per channel
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
Definition: Error.h:204
Store the tensor's metadata.
Definition: ITensorInfo.h:40
#define ARM_COMPUTE_ERROR_THROW_ON(status)
Definition: Error.h:455
Describe one of the image&#39;s dimensions with a start, end and step.
Definition: Window.h:77
Status class.
Definition: Error.h:52
Copyright (c) 2017-2021 Arm Limited.
1 channel, 1 S32 per channel
const ITensor * get_const_tensor(int id) const
Get constant tensor of a given id.
Definition: ITensorPack.cpp:54
static constexpr size_t DimX
Alias for dimension 0 also known as X dimension.
Definition: Window.h:43
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
Definition: Error.h:152
void configure(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
Initialise the kernel's input and output.
virtual const TensorShape & tensor_shape() const =0
Size for each dimension of the tensor.
auto ceil_to_multiple(S value, T divisor) -> decltype(((value+divisor - 1)/divisor) *divisor)
Computes the smallest number larger or equal to value that is a multiple of divisor.
Definition: Utils.h:71
quantized, asymmetric fixed-point 8-bit number unsigned
Class to describe a number of elements in each dimension.
Definition: Steps.h:40
size_t data_size_from_type(DataType data_type)
The size in bytes of the data type.
Definition: Utils.h:106
void set(size_t dimension, const Dimension &dim)
Set the values of a given dimension.
Definition: Window.inl:49
#define ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(k)
Definition: Validate.h:915
quantized, symmetric fixed-point 8-bit number
quantized, symmetric per channel fixed-point 8-bit number
static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
Static function to check if given info will lead to a valid configuration.
static constexpr size_t DimY
Alias for dimension 1 also known as Y dimension.
Definition: Window.h:45
ScaleKernelInfo info(interpolation_policy, default_border_mode, PixelValue(), sampling_policy, false)
ITensor * get_tensor(int id)
Get tensor of a given id from the pac.
Definition: ITensorPack.cpp:64
Information about executing thread and CPU.
Definition: CPPTypes.h:169
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override
Execute the kernel on the passed window.
constexpr const Dimension & y() const
Alias to access the second dimension of the window.
Definition: Window.h:154
#define ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(t, c,...)
Definition: Validate.h:788
#define ARM_COMPUTE_RETURN_ERROR_ON_MSG(cond, msg)
If the condition is true, an error is returned.
Definition: Error.h:244
Tensor packing service.
Definition: ITensorPack.h:39
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
Definition: Validate.h:157
void execute_window_loop(const Window &w, L &&lambda_function, Ts &&... iterators)
Iterate through the passed window, automatically adjusting the iterators and calling the lambda_function...
Definition: Helpers.inl:77
quantized, asymmetric fixed-point 8-bit number signed
constexpr int end() const
Return the end of the dimension.
Definition: Window.h:99
Iterator updated by execute_window_loop for each window element.
Definition: Helpers.h:46
constexpr int start() const
Return the start of the dimension.
Definition: Window.h:94
signed 8-bit number
Describe a multidimensional execution window.
Definition: Window.h:39
void collapse(size_t n, size_t first=0)
Collapse the first n dimensions.
Definition: TensorShape.h:133
#define ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(f, s)
Definition: Validate.h:201
constexpr const Dimension & x() const
Alias to access the first dimension of the window.
Definition: Window.h:145