Compute Library
 20.05
NEGEMMLowpMatrixMultiplyKernel.cpp
1 /*
2  * Copyright (c) 2017-2020 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 #include "arm_compute/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h"
27 #include "arm_compute/core/Error.h"
31 #include "arm_compute/core/Types.h"
32 #include "arm_compute/core/Utils.h"
35 
36 #include <arm_neon.h>
37 #include <cstddef>
38 #include <cstdint>
39 #include <tuple>
40 
41 using namespace arm_compute;
42 
43 namespace arm_compute
44 {
45 namespace
46 {
47 void inline vector_matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, size_t stride_b, const Window &window)
48 {
49  execute_window_loop(window, [&](const Coordinates & id)
50  {
51  if(id.x() > width_b)
52  {
53  return;
54  }
55 
56  // Note: Since the inputs are all positive, we can accumulate in uint32_t
57  // Accumulators for the block 0
58  uint32x4x4_t c0 =
59  {
60  {
61  vdupq_n_u32(0),
62  vdupq_n_u32(0),
63  vdupq_n_u32(0),
64  vdupq_n_u32(0)
65  }
66  };
67 
68  auto vec_a = reinterpret_cast<const uint8_t *>(ina.ptr());
69  auto matrix_b = reinterpret_cast<const uint8_t *>(inb.ptr());
70  auto vec_a_end_addr = vec_a + width_a;
71 
72  // This for loop performs 8 accumulation steps per iteration
73  for(; vec_a <= (vec_a_end_addr - 8);)
74  {
75  const uint8x8_t a00_u8 = vld1_u8(vec_a);
76  const uint8x16_t b00_u8 = vld1q_u8(matrix_b + 0 * stride_b);
77  const uint8x16_t b10_u8 = vld1q_u8(matrix_b + 1 * stride_b);
78  const uint8x16_t b20_u8 = vld1q_u8(matrix_b + 2 * stride_b);
79  const uint8x16_t b30_u8 = vld1q_u8(matrix_b + 3 * stride_b);
80  const uint8x16_t b40_u8 = vld1q_u8(matrix_b + 4 * stride_b);
81  const uint8x16_t b50_u8 = vld1q_u8(matrix_b + 5 * stride_b);
82  const uint8x16_t b60_u8 = vld1q_u8(matrix_b + 6 * stride_b);
83  const uint8x16_t b70_u8 = vld1q_u8(matrix_b + 7 * stride_b);
84 
85  // Convert a00_u8 to uint16_t and split it into its lower and upper halves
86  const uint16x4x2_t a00_u16 =
87  {
88  {
89  vget_low_u16(vmovl_u8(a00_u8)),
90  vget_high_u16(vmovl_u8(a00_u8))
91  }
92  };
93 
94  const uint16x4x4_t b00_u16 =
95  {
96  {
97  vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
98  vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
99  vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
100  vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
101  }
102  };
103 
104  const uint16x4x4_t b10_u16 =
105  {
106  {
107  vget_low_u16(vmovl_u8(vget_low_u8(b10_u8))),
108  vget_high_u16(vmovl_u8(vget_low_u8(b10_u8))),
109  vget_low_u16(vmovl_u8(vget_high_u8(b10_u8))),
110  vget_high_u16(vmovl_u8(vget_high_u8(b10_u8)))
111  }
112  };
113 
114  const uint16x4x4_t b20_u16 =
115  {
116  {
117  vget_low_u16(vmovl_u8(vget_low_u8(b20_u8))),
118  vget_high_u16(vmovl_u8(vget_low_u8(b20_u8))),
119  vget_low_u16(vmovl_u8(vget_high_u8(b20_u8))),
120  vget_high_u16(vmovl_u8(vget_high_u8(b20_u8)))
121  }
122  };
123 
124  const uint16x4x4_t b30_u16 =
125  {
126  {
127  vget_low_u16(vmovl_u8(vget_low_u8(b30_u8))),
128  vget_high_u16(vmovl_u8(vget_low_u8(b30_u8))),
129  vget_low_u16(vmovl_u8(vget_high_u8(b30_u8))),
130  vget_high_u16(vmovl_u8(vget_high_u8(b30_u8)))
131  }
132  };
133 
134  const uint16x4x4_t b40_u16 =
135  {
136  {
137  vget_low_u16(vmovl_u8(vget_low_u8(b40_u8))),
138  vget_high_u16(vmovl_u8(vget_low_u8(b40_u8))),
139  vget_low_u16(vmovl_u8(vget_high_u8(b40_u8))),
140  vget_high_u16(vmovl_u8(vget_high_u8(b40_u8)))
141  }
142  };
143 
144  const uint16x4x4_t b50_u16 =
145  {
146  {
147  vget_low_u16(vmovl_u8(vget_low_u8(b50_u8))),
148  vget_high_u16(vmovl_u8(vget_low_u8(b50_u8))),
149  vget_low_u16(vmovl_u8(vget_high_u8(b50_u8))),
150  vget_high_u16(vmovl_u8(vget_high_u8(b50_u8)))
151  }
152  };
153 
154  const uint16x4x4_t b60_u16 =
155  {
156  {
157  vget_low_u16(vmovl_u8(vget_low_u8(b60_u8))),
158  vget_high_u16(vmovl_u8(vget_low_u8(b60_u8))),
159  vget_low_u16(vmovl_u8(vget_high_u8(b60_u8))),
160  vget_high_u16(vmovl_u8(vget_high_u8(b60_u8)))
161  }
162  };
163 
164  const uint16x4x4_t b70_u16 =
165  {
166  {
167  vget_low_u16(vmovl_u8(vget_low_u8(b70_u8))),
168  vget_high_u16(vmovl_u8(vget_low_u8(b70_u8))),
169  vget_low_u16(vmovl_u8(vget_high_u8(b70_u8))),
170  vget_high_u16(vmovl_u8(vget_high_u8(b70_u8)))
171  }
172  };
173 
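 // Each of the eight accumulation steps below multiplies one loaded row of B (16 output
 // columns spread across c0.val[0..3]) by a single lane of the A vector: rows 0-3 use the
 // four lanes of a00_u16.val[0] (A elements 0-3) and rows 4-7 the lanes of a00_u16.val[1]
 // (A elements 4-7).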
174  // Accumulate 0:
175  c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16.val[0], 0);
176  c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16.val[0], 0);
177  c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16.val[0], 0);
178  c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16.val[0], 0);
179 
180  // Accumulate 1:
181  c0.val[0] = vmlal_lane_u16(c0.val[0], b10_u16.val[0], a00_u16.val[0], 1);
182  c0.val[1] = vmlal_lane_u16(c0.val[1], b10_u16.val[1], a00_u16.val[0], 1);
183  c0.val[2] = vmlal_lane_u16(c0.val[2], b10_u16.val[2], a00_u16.val[0], 1);
184  c0.val[3] = vmlal_lane_u16(c0.val[3], b10_u16.val[3], a00_u16.val[0], 1);
185 
186  // Accumulate 2:
187  c0.val[0] = vmlal_lane_u16(c0.val[0], b20_u16.val[0], a00_u16.val[0], 2);
188  c0.val[1] = vmlal_lane_u16(c0.val[1], b20_u16.val[1], a00_u16.val[0], 2);
189  c0.val[2] = vmlal_lane_u16(c0.val[2], b20_u16.val[2], a00_u16.val[0], 2);
190  c0.val[3] = vmlal_lane_u16(c0.val[3], b20_u16.val[3], a00_u16.val[0], 2);
191 
192  // Accumulate 3:
193  c0.val[0] = vmlal_lane_u16(c0.val[0], b30_u16.val[0], a00_u16.val[0], 3);
194  c0.val[1] = vmlal_lane_u16(c0.val[1], b30_u16.val[1], a00_u16.val[0], 3);
195  c0.val[2] = vmlal_lane_u16(c0.val[2], b30_u16.val[2], a00_u16.val[0], 3);
196  c0.val[3] = vmlal_lane_u16(c0.val[3], b30_u16.val[3], a00_u16.val[0], 3);
197 
198  // Accumulate 4:
199  c0.val[0] = vmlal_lane_u16(c0.val[0], b40_u16.val[0], a00_u16.val[1], 0);
200  c0.val[1] = vmlal_lane_u16(c0.val[1], b40_u16.val[1], a00_u16.val[1], 0);
201  c0.val[2] = vmlal_lane_u16(c0.val[2], b40_u16.val[2], a00_u16.val[1], 0);
202  c0.val[3] = vmlal_lane_u16(c0.val[3], b40_u16.val[3], a00_u16.val[1], 0);
203 
204  // Accumulate 5:
205  c0.val[0] = vmlal_lane_u16(c0.val[0], b50_u16.val[0], a00_u16.val[1], 1);
206  c0.val[1] = vmlal_lane_u16(c0.val[1], b50_u16.val[1], a00_u16.val[1], 1);
207  c0.val[2] = vmlal_lane_u16(c0.val[2], b50_u16.val[2], a00_u16.val[1], 1);
208  c0.val[3] = vmlal_lane_u16(c0.val[3], b50_u16.val[3], a00_u16.val[1], 1);
209 
210  // Accumulate 6:
211  c0.val[0] = vmlal_lane_u16(c0.val[0], b60_u16.val[0], a00_u16.val[1], 2);
212  c0.val[1] = vmlal_lane_u16(c0.val[1], b60_u16.val[1], a00_u16.val[1], 2);
213  c0.val[2] = vmlal_lane_u16(c0.val[2], b60_u16.val[2], a00_u16.val[1], 2);
214  c0.val[3] = vmlal_lane_u16(c0.val[3], b60_u16.val[3], a00_u16.val[1], 2);
215 
216  // Accumulate 7:
217  c0.val[0] = vmlal_lane_u16(c0.val[0], b70_u16.val[0], a00_u16.val[1], 3);
218  c0.val[1] = vmlal_lane_u16(c0.val[1], b70_u16.val[1], a00_u16.val[1], 3);
219  c0.val[2] = vmlal_lane_u16(c0.val[2], b70_u16.val[2], a00_u16.val[1], 3);
220  c0.val[3] = vmlal_lane_u16(c0.val[3], b70_u16.val[3], a00_u16.val[1], 3);
221 
222  vec_a += 8;
223  matrix_b += 8 * stride_b;
224  }
225 
226  // This for loop performs the left-over accumulations
227  for(; vec_a < vec_a_end_addr;)
228  {
229  const uint8x8_t a00_u8 = vld1_dup_u8(vec_a);
230  const uint8x16_t b00_u8 = vld1q_u8(matrix_b);
231 
232  const uint16x4x4_t b00_u16 =
233  {
234  {
235  vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
236  vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
237  vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
238  vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
239  }
240  };
241 
242  // Convert a00_u8 to uint16_t and get the lower part
243  const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));
244 
245  // Accumulate 0:
246  c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
247  c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
248  c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
249  c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);
250 
251  vec_a += 1;
252  matrix_b += stride_b;
253  }
254 
255  auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
256  vst1q_s32(vec_out + 0, vreinterpretq_s32_u32(c0.val[0]));
257  vst1q_s32(vec_out + 4, vreinterpretq_s32_u32(c0.val[1]));
258  vst1q_s32(vec_out + 8, vreinterpretq_s32_u32(c0.val[2]));
259  vst1q_s32(vec_out + 12, vreinterpretq_s32_u32(c0.val[3]));
260  },
261  ina, inb, out);
262 }
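// Illustrative scalar sketch (not part of the original file): the computation that the
// NEON routine above vectorises for one in-range window position, producing the 16
// consecutive output columns written by the vst1q_s32 stores. The "ref_" name is
// hypothetical and chosen only for this example.
void ref_vector_matrix_multiply_u8(const uint8_t *vec_a, const uint8_t *matrix_b,
                                   int32_t *out, int width_a, size_t stride_b)
{
    for(int x = 0; x < 16; ++x)
    {
        uint32_t acc = 0;
        for(int k = 0; k < width_a; ++k)
        {
            // Element k of the A vector times element x of row k of matrix B
            acc += static_cast<uint32_t>(vec_a[k]) * static_cast<uint32_t>(matrix_b[k * stride_b + x]);
        }
        out[x] = static_cast<int32_t>(acc);
    }
}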
263 
264 void inline vector_matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, size_t stride_b, const Window &window)
265 {
266  execute_window_loop(window, [&](const Coordinates & id)
267  {
268  if(id.x() > width_b)
269  {
270  return;
271  }
272 
273  // Accumulators for the block 0
274  int32x4x4_t c0 =
275  {
276  {
277  vdupq_n_s32(0),
278  vdupq_n_s32(0),
279  vdupq_n_s32(0),
280  vdupq_n_s32(0)
281  }
282  };
283 
284  auto vec_a = reinterpret_cast<const int8_t *>(ina.ptr());
285  auto matrix_b = reinterpret_cast<const int8_t *>(inb.ptr());
286  auto vec_a_end_addr = vec_a + width_a;
287 
288  // This for loop performs 8 accumulation steps per iteration
289  for(; vec_a <= (vec_a_end_addr - 8);)
290  {
291  const int8x8_t a00_s8 = vld1_s8(vec_a);
292  const int8x16_t b00_s8 = vld1q_s8(matrix_b + 0 * stride_b);
293  const int8x16_t b10_s8 = vld1q_s8(matrix_b + 1 * stride_b);
294  const int8x16_t b20_s8 = vld1q_s8(matrix_b + 2 * stride_b);
295  const int8x16_t b30_s8 = vld1q_s8(matrix_b + 3 * stride_b);
296  const int8x16_t b40_s8 = vld1q_s8(matrix_b + 4 * stride_b);
297  const int8x16_t b50_s8 = vld1q_s8(matrix_b + 5 * stride_b);
298  const int8x16_t b60_s8 = vld1q_s8(matrix_b + 6 * stride_b);
299  const int8x16_t b70_s8 = vld1q_s8(matrix_b + 7 * stride_b);
300 
301  // Convert a00_s8 to int16_t and split it into its lower and upper halves
302  const int16x4x2_t a00_s16 =
303  {
304  {
305  vget_low_s16(vmovl_s8(a00_s8)),
306  vget_high_s16(vmovl_s8(a00_s8))
307  }
308  };
309 
310  const int16x4x4_t b00_s16 =
311  {
312  {
313  vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
314  vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
315  vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
316  vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
317  }
318  };
319 
320  const int16x4x4_t b10_s16 =
321  {
322  {
323  vget_low_s16(vmovl_s8(vget_low_s8(b10_s8))),
324  vget_high_s16(vmovl_s8(vget_low_s8(b10_s8))),
325  vget_low_s16(vmovl_s8(vget_high_s8(b10_s8))),
326  vget_high_s16(vmovl_s8(vget_high_s8(b10_s8)))
327  }
328  };
329 
330  const int16x4x4_t b20_s16 =
331  {
332  {
333  vget_low_s16(vmovl_s8(vget_low_s8(b20_s8))),
334  vget_high_s16(vmovl_s8(vget_low_s8(b20_s8))),
335  vget_low_s16(vmovl_s8(vget_high_s8(b20_s8))),
336  vget_high_s16(vmovl_s8(vget_high_s8(b20_s8)))
337  }
338  };
339 
340  const int16x4x4_t b30_s16 =
341  {
342  {
343  vget_low_s16(vmovl_s8(vget_low_s8(b30_s8))),
344  vget_high_s16(vmovl_s8(vget_low_s8(b30_s8))),
345  vget_low_s16(vmovl_s8(vget_high_s8(b30_s8))),
346  vget_high_s16(vmovl_s8(vget_high_s8(b30_s8)))
347  }
348  };
349 
350  const int16x4x4_t b40_s16 =
351  {
352  {
353  vget_low_s16(vmovl_s8(vget_low_s8(b40_s8))),
354  vget_high_s16(vmovl_s8(vget_low_s8(b40_s8))),
355  vget_low_s16(vmovl_s8(vget_high_s8(b40_s8))),
356  vget_high_s16(vmovl_s8(vget_high_s8(b40_s8)))
357  }
358  };
359 
360  const int16x4x4_t b50_s16 =
361  {
362  {
363  vget_low_s16(vmovl_s8(vget_low_s8(b50_s8))),
364  vget_high_s16(vmovl_s8(vget_low_s8(b50_s8))),
365  vget_low_s16(vmovl_s8(vget_high_s8(b50_s8))),
366  vget_high_s16(vmovl_s8(vget_high_s8(b50_s8)))
367  }
368  };
369 
370  const int16x4x4_t b60_s16 =
371  {
372  {
373  vget_low_s16(vmovl_s8(vget_low_s8(b60_s8))),
374  vget_high_s16(vmovl_s8(vget_low_s8(b60_s8))),
375  vget_low_s16(vmovl_s8(vget_high_s8(b60_s8))),
376  vget_high_s16(vmovl_s8(vget_high_s8(b60_s8)))
377  }
378  };
379 
380  const int16x4x4_t b70_s16 =
381  {
382  {
383  vget_low_s16(vmovl_s8(vget_low_s8(b70_s8))),
384  vget_high_s16(vmovl_s8(vget_low_s8(b70_s8))),
385  vget_low_s16(vmovl_s8(vget_high_s8(b70_s8))),
386  vget_high_s16(vmovl_s8(vget_high_s8(b70_s8)))
387  }
388  };
389 
390  // Accumulate 0:
391  c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16.val[0], 0);
392  c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16.val[0], 0);
393  c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16.val[0], 0);
394  c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16.val[0], 0);
395 
396  // Accumulate 1:
397  c0.val[0] = vmlal_lane_s16(c0.val[0], b10_s16.val[0], a00_s16.val[0], 1);
398  c0.val[1] = vmlal_lane_s16(c0.val[1], b10_s16.val[1], a00_s16.val[0], 1);
399  c0.val[2] = vmlal_lane_s16(c0.val[2], b10_s16.val[2], a00_s16.val[0], 1);
400  c0.val[3] = vmlal_lane_s16(c0.val[3], b10_s16.val[3], a00_s16.val[0], 1);
401 
402  // Accumulate 2:
403  c0.val[0] = vmlal_lane_s16(c0.val[0], b20_s16.val[0], a00_s16.val[0], 2);
404  c0.val[1] = vmlal_lane_s16(c0.val[1], b20_s16.val[1], a00_s16.val[0], 2);
405  c0.val[2] = vmlal_lane_s16(c0.val[2], b20_s16.val[2], a00_s16.val[0], 2);
406  c0.val[3] = vmlal_lane_s16(c0.val[3], b20_s16.val[3], a00_s16.val[0], 2);
407 
408  // Accumulate 3:
409  c0.val[0] = vmlal_lane_s16(c0.val[0], b30_s16.val[0], a00_s16.val[0], 3);
410  c0.val[1] = vmlal_lane_s16(c0.val[1], b30_s16.val[1], a00_s16.val[0], 3);
411  c0.val[2] = vmlal_lane_s16(c0.val[2], b30_s16.val[2], a00_s16.val[0], 3);
412  c0.val[3] = vmlal_lane_s16(c0.val[3], b30_s16.val[3], a00_s16.val[0], 3);
413 
414  // Accumulate 4:
415  c0.val[0] = vmlal_lane_s16(c0.val[0], b40_s16.val[0], a00_s16.val[1], 0);
416  c0.val[1] = vmlal_lane_s16(c0.val[1], b40_s16.val[1], a00_s16.val[1], 0);
417  c0.val[2] = vmlal_lane_s16(c0.val[2], b40_s16.val[2], a00_s16.val[1], 0);
418  c0.val[3] = vmlal_lane_s16(c0.val[3], b40_s16.val[3], a00_s16.val[1], 0);
419 
420  // Accumulate 5:
421  c0.val[0] = vmlal_lane_s16(c0.val[0], b50_s16.val[0], a00_s16.val[1], 1);
422  c0.val[1] = vmlal_lane_s16(c0.val[1], b50_s16.val[1], a00_s16.val[1], 1);
423  c0.val[2] = vmlal_lane_s16(c0.val[2], b50_s16.val[2], a00_s16.val[1], 1);
424  c0.val[3] = vmlal_lane_s16(c0.val[3], b50_s16.val[3], a00_s16.val[1], 1);
425 
426  // Accumulate 6:
427  c0.val[0] = vmlal_lane_s16(c0.val[0], b60_s16.val[0], a00_s16.val[1], 2);
428  c0.val[1] = vmlal_lane_s16(c0.val[1], b60_s16.val[1], a00_s16.val[1], 2);
429  c0.val[2] = vmlal_lane_s16(c0.val[2], b60_s16.val[2], a00_s16.val[1], 2);
430  c0.val[3] = vmlal_lane_s16(c0.val[3], b60_s16.val[3], a00_s16.val[1], 2);
431 
432  // Accumulate 7:
433  c0.val[0] = vmlal_lane_s16(c0.val[0], b70_s16.val[0], a00_s16.val[1], 3);
434  c0.val[1] = vmlal_lane_s16(c0.val[1], b70_s16.val[1], a00_s16.val[1], 3);
435  c0.val[2] = vmlal_lane_s16(c0.val[2], b70_s16.val[2], a00_s16.val[1], 3);
436  c0.val[3] = vmlal_lane_s16(c0.val[3], b70_s16.val[3], a00_s16.val[1], 3);
437 
438  vec_a += 8;
439  matrix_b += 8 * stride_b;
440  }
441 
442  // This for loop performs the left-over accumulations
443  for(; vec_a < vec_a_end_addr;)
444  {
445  const int8x8_t a00_s8 = vld1_dup_s8(vec_a);
446  const int8x16_t b00_s8 = vld1q_s8(matrix_b);
447 
448  const int16x4x4_t b00_s16 =
449  {
450  {
451  vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
452  vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
453  vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
454  vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
455  }
456  };
457 
458  // Convert a00_s8 to int16_t and get the lower part
459  const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));
460 
461  // Accumulate 0:
462  c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
463  c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
464  c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
465  c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);
466 
467  vec_a += 1;
468  matrix_b += stride_b;
469  }
470 
471  auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
472  vst1q_s32(vec_out + 0, c0.val[0]);
473  vst1q_s32(vec_out + 4, c0.val[1]);
474  vst1q_s32(vec_out + 8, c0.val[2]);
475  vst1q_s32(vec_out + 12, c0.val[3]);
476  },
477  ina, inb, out);
478 }
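// The signed path above mirrors the unsigned one: the 8-bit values are widened to int16
// (products of two int8 values always fit in int16), vmlal_lane_s16 accumulates the
// products directly into the int32 accumulators, and the results are stored to the S32
// output without the reinterpret cast used in the unsigned path.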
479 
480 void inline matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, size_t out_stride, const Window &window)
481 {
482  execute_window_loop(window, [&](const Coordinates &)
483  {
484  const uint8_t *mtx_a0 = ina.ptr();
485  const uint8_t *mtx_b0 = inb.ptr();
486 
487  // Note: Since the inputs are all positive, we can accumulate in uint32_t
488  // Accumulators for the block 0
489  uint32x4x4_t c0 =
490  {
491  {
492  vdupq_n_u32(0),
493  vdupq_n_u32(0),
494  vdupq_n_u32(0),
495  vdupq_n_u32(0)
496  }
497  };
498 
499  // Accumulators for the block 1
500  uint32x4x4_t c1 =
501  {
502  {
503  vdupq_n_u32(0),
504  vdupq_n_u32(0),
505  vdupq_n_u32(0),
506  vdupq_n_u32(0)
507  }
508  };
509 
510  // Accumulators for the block 2
511  uint32x4x4_t c2 =
512  {
513  {
514  vdupq_n_u32(0),
515  vdupq_n_u32(0),
516  vdupq_n_u32(0),
517  vdupq_n_u32(0)
518  }
519  };
520 
521  // Accumulators for the block 3
522  uint32x4x4_t c3 =
523  {
524  {
525  vdupq_n_u32(0),
526  vdupq_n_u32(0),
527  vdupq_n_u32(0),
528  vdupq_n_u32(0)
529  }
530  };
531 
532  for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
533  {
534  const uint8x8_t a00_u8 = vld1_u8(mtx_a0);
535  const uint8x16_t b00_u8 = vld1q_u8(mtx_b0);
536 
537  // Convert a00_u8 to uint16_t and get the lower part
538  const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));
539 
540  // Convert b00_u8 to uint16_t
541  const uint16x4x4_t b00_u16 =
542  {
543  {
544  vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
545  vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
546  vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
547  vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
548  }
549  };
550 
551  // 4x4 block 0
552  c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
553  c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
554  c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
555  c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);
556 
557  // 4x4 block 1
558  c1.val[0] = vmlal_lane_u16(c1.val[0], b00_u16.val[0], a00_u16, 1);
559  c1.val[1] = vmlal_lane_u16(c1.val[1], b00_u16.val[1], a00_u16, 1);
560  c1.val[2] = vmlal_lane_u16(c1.val[2], b00_u16.val[2], a00_u16, 1);
561  c1.val[3] = vmlal_lane_u16(c1.val[3], b00_u16.val[3], a00_u16, 1);
562 
563  // 4x4 block 2
564  c2.val[0] = vmlal_lane_u16(c2.val[0], b00_u16.val[0], a00_u16, 2);
565  c2.val[1] = vmlal_lane_u16(c2.val[1], b00_u16.val[1], a00_u16, 2);
566  c2.val[2] = vmlal_lane_u16(c2.val[2], b00_u16.val[2], a00_u16, 2);
567  c2.val[3] = vmlal_lane_u16(c2.val[3], b00_u16.val[3], a00_u16, 2);
568 
569  // 4x4 block 3
570  c3.val[0] = vmlal_lane_u16(c3.val[0], b00_u16.val[0], a00_u16, 3);
571  c3.val[1] = vmlal_lane_u16(c3.val[1], b00_u16.val[1], a00_u16, 3);
572  c3.val[2] = vmlal_lane_u16(c3.val[2], b00_u16.val[2], a00_u16, 3);
573  c3.val[3] = vmlal_lane_u16(c3.val[3], b00_u16.val[3], a00_u16, 3);
574  }
575 
576  auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
577  vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
578  vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
579  vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
580  vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
581  vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
582  vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
583  vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
584  vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
585  vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
586  vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
587  vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
588  vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
589  vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
590  vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
591  vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
592  vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
593  },
594  ina, inb, out);
595 }
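// Illustrative scalar sketch (not part of the original file): the 4x16 output tile
// accumulated by matrix_multiply_u8 above, with mtx_a pointing to the 4x4-interleaved A
// block, mtx_b to the 1x16-transposed B block, and out_stride expressed in int32_t
// elements as in the intrinsic code. The "ref_" name is hypothetical.
void ref_matrix_multiply_u8(const uint8_t *mtx_a, const uint8_t *mtx_b,
                            int32_t *out, int width_b, size_t out_stride)
{
    int32_t acc[4][16] = {};
    for(int k = 0; k < width_b; k += 16, mtx_a += 4, mtx_b += 16)
    {
        for(int row = 0; row < 4; ++row)
        {
            for(int col = 0; col < 16; ++col)
            {
                // One element of interleaved A per output row, one element of
                // transposed B per output column, at the same depth step
                acc[row][col] += static_cast<int32_t>(mtx_a[row]) * static_cast<int32_t>(mtx_b[col]);
            }
        }
    }
    for(int row = 0; row < 4; ++row)
    {
        for(int col = 0; col < 16; ++col)
        {
            out[row * out_stride + col] = acc[row][col];
        }
    }
}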
596 
597 void inline matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, size_t out_stride, const Window &window)
598 {
599  // The implementation assumes that matrix A and matrix B have been reshaped with NEGEMMInterleave4x4 and NEGEMMTranspose1xW, respectively
600  // Reshaping the matrices keeps the implementation cache friendly and avoids the data re-arrangements that would otherwise be needed to compute 16x4 elements per iteration
601  // All the values needed to compute a single 4x4 block are read from consecutive memory positions
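 // Layout of the reshaped inputs (illustrative): with original A rows a0..a3 and original
 // B rows b0..b(K-1), the kernel reads
 //   interleaved A : a0[0] a1[0] a2[0] a3[0] a0[1] a1[1] a2[1] a3[1] ...
 //   transposed  B : b0[0..15] b1[0..15] b2[0..15] ...
 // so every iteration of the "k" loop below consumes 4 consecutive bytes of A (one per
 // output row) and 16 consecutive bytes of B (one per output column).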
602  execute_window_loop(window, [&](const Coordinates &)
603  {
604  auto *mtx_a0 = reinterpret_cast<const int8_t *>(ina.ptr());
605  auto *mtx_b0 = reinterpret_cast<const int8_t *>(inb.ptr());
606 
607  // Note: the inputs are signed, so the accumulators are int32_t
608  // Accumulators for the block 0
609  int32x4x4_t c0 =
610  {
611  {
612  vdupq_n_s32(0),
613  vdupq_n_s32(0),
614  vdupq_n_s32(0),
615  vdupq_n_s32(0)
616  }
617  };
618 
619  // Accumulators for the block 1
620  int32x4x4_t c1 =
621  {
622  {
623  vdupq_n_s32(0),
624  vdupq_n_s32(0),
625  vdupq_n_s32(0),
626  vdupq_n_s32(0)
627  }
628  };
629 
630  // Accumulators for the block 2
631  int32x4x4_t c2 =
632  {
633  {
634  vdupq_n_s32(0),
635  vdupq_n_s32(0),
636  vdupq_n_s32(0),
637  vdupq_n_s32(0)
638  }
639  };
640 
641  // Accumulators for the block 3
642  int32x4x4_t c3 =
643  {
644  {
645  vdupq_n_s32(0),
646  vdupq_n_s32(0),
647  vdupq_n_s32(0),
648  vdupq_n_s32(0)
649  }
650  };
651 
652  for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
653  {
654  const int8x8_t a00_s8 = vld1_s8(mtx_a0);
655  const int8x16_t b00_s8 = vld1q_s8(mtx_b0);
656 
657  // Convert a00_s8 to int16_t and get the lower part
658  const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));
659 
660  // Convert b00_s8 to int16_t
661  const int16x4x4_t b00_s16 =
662  {
663  {
664  vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
665  vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
666  vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
667  vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
668  }
669  };
670 
671  // 4x4 block 0
672  c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
673  c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
674  c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
675  c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);
676 
677  // 4x4 block 1
678  c1.val[0] = vmlal_lane_s16(c1.val[0], b00_s16.val[0], a00_s16, 1);
679  c1.val[1] = vmlal_lane_s16(c1.val[1], b00_s16.val[1], a00_s16, 1);
680  c1.val[2] = vmlal_lane_s16(c1.val[2], b00_s16.val[2], a00_s16, 1);
681  c1.val[3] = vmlal_lane_s16(c1.val[3], b00_s16.val[3], a00_s16, 1);
682 
683  // 4x4 block 2
684  c2.val[0] = vmlal_lane_s16(c2.val[0], b00_s16.val[0], a00_s16, 2);
685  c2.val[1] = vmlal_lane_s16(c2.val[1], b00_s16.val[1], a00_s16, 2);
686  c2.val[2] = vmlal_lane_s16(c2.val[2], b00_s16.val[2], a00_s16, 2);
687  c2.val[3] = vmlal_lane_s16(c2.val[3], b00_s16.val[3], a00_s16, 2);
688 
689  // 4x4 block 3
690  c3.val[0] = vmlal_lane_s16(c3.val[0], b00_s16.val[0], a00_s16, 3);
691  c3.val[1] = vmlal_lane_s16(c3.val[1], b00_s16.val[1], a00_s16, 3);
692  c3.val[2] = vmlal_lane_s16(c3.val[2], b00_s16.val[2], a00_s16, 3);
693  c3.val[3] = vmlal_lane_s16(c3.val[3], b00_s16.val[3], a00_s16, 3);
694  }
695 
696  auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
697  vst1q_s32(mtx_out + 0 * out_stride + 0, c0.val[0]);
698  vst1q_s32(mtx_out + 0 * out_stride + 4, c0.val[1]);
699  vst1q_s32(mtx_out + 0 * out_stride + 8, c0.val[2]);
700  vst1q_s32(mtx_out + 0 * out_stride + 12, c0.val[3]);
701  vst1q_s32(mtx_out + 1 * out_stride + 0, c1.val[0]);
702  vst1q_s32(mtx_out + 1 * out_stride + 4, c1.val[1]);
703  vst1q_s32(mtx_out + 1 * out_stride + 8, c1.val[2]);
704  vst1q_s32(mtx_out + 1 * out_stride + 12, c1.val[3]);
705  vst1q_s32(mtx_out + 2 * out_stride + 0, c2.val[0]);
706  vst1q_s32(mtx_out + 2 * out_stride + 4, c2.val[1]);
707  vst1q_s32(mtx_out + 2 * out_stride + 8, c2.val[2]);
708  vst1q_s32(mtx_out + 2 * out_stride + 12, c2.val[3]);
709  vst1q_s32(mtx_out + 3 * out_stride + 0, c3.val[0]);
710  vst1q_s32(mtx_out + 3 * out_stride + 4, c3.val[1]);
711  vst1q_s32(mtx_out + 3 * out_stride + 8, c3.val[2]);
712  vst1q_s32(mtx_out + 3 * out_stride + 12, c3.val[3]);
713  },
714  ina, inb, out);
715 }
716 } // namespace
717 
718 class Coordinates;
719 } // namespace arm_compute
720 
721 namespace
722 {
723 Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
724 {
725  ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::U8, DataType::S8);
726  ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8, DataType::QSYMM8_PER_CHANNEL, DataType::U8, DataType::S8);
727  ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
728 
729  TensorShape in0_shape = input0->tensor_shape();
730  TensorShape in1_shape = input1->tensor_shape();
731  TensorShape out_shape = output->tensor_shape();
732 
733  // Check vector-by-matrix case
734  if(out_shape[1] == 1)
735  {
736  ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[0] != in1_shape[1], "The number of input0's columns must be equal to input1's rows");
737  }
738  else
739  {
740  in0_shape.collapse(2);
741  in1_shape.collapse(2);
742  out_shape.collapse(2);
743 
744  ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[2] != out_shape[2], "Output tensor must have the same number of batches as the input0 tensor");
745  ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[2] != 1 && in0_shape[2] != in1_shape[2], "Input1 tensor must have the same number of batches as input0 or the number of batches must be set to 1");
746  ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[0] % 16, "Input1's width must be a multiple of 16");
747  }
748 
749  return Status{};
750 }
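// Shape example for the checks above (illustrative): in the vector-by-matrix case an
// input0 of shape (K, 1) and an input1 of shape (N, K) pass the first check, with an S32
// output of shape (N, 1). In the matrix-by-matrix case input1 is the output of
// NEGEMMTranspose1xW, whose width is a multiple of 16 by construction for 8-bit data.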
751 
752 std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output)
753 {
754  constexpr unsigned int num_elems_processed_per_iteration_x = 16;
755  constexpr unsigned int num_elems_processed_per_iteration_y = 4;
756 
757  Window win;
758  bool window_changed = false;
759 
760  // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication path
761  if((output->dimension(1) == 1))
762  {
763  // Configure kernel window
764  win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x));
765 
766  // We cannot read out-of-bounds elements from matrix A, as the left-over loop handles the tail
767  AccessWindowStatic in0_access(input0, 0, 0, input0->tensor_shape().x(), 1);
768  AccessWindowHorizontal in1_access(input1, 0, num_elems_processed_per_iteration_x);
769  AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration_x);
770 
771  window_changed = update_window_and_padding(win, in0_access, in1_access, output_access);
772 
773  Coordinates coord;
774  coord.set_num_dimensions(output->num_dimensions());
775  output_access.set_valid_region(win, ValidRegion(coord, output->tensor_shape()));
776  }
777  else
778  {
779  win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
780 
781  unsigned int num_k_iterations = ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x) / 16;
782  // For each iteration of "k" we increment the input pointer by 4 and load 8 elements at a time (see the worked example after this function):
783  AccessWindowStatic in0_access(input0, 0, 0, (num_k_iterations - 1) * 4 + 8, input0->dimension(1));
784  AccessWindowStatic in1_access(input1, 0, 0, ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x), input1->dimension(1));
785  AccessWindowRectangle output_access(output, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
786 
787  window_changed = update_window_and_padding(win, in0_access, in1_access, output_access);
788 
789  output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
790  }
791 
792  Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
793  return std::make_pair(err, win);
794 }
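// Worked example for the access windows above (illustrative numbers): if input1 has 48
// columns, num_k_iterations = ceil_to_multiple(48, 16) / 16 = 3, so matrix A must be
// readable up to (3 - 1) * 4 + 8 = 16 elements along x, because each "k" iteration
// advances the A pointer by 4 while the 8-bit loads read 8 bytes at a time.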
795 } // namespace
796 
797 NEGEMMLowpMatrixMultiplyKernel::NEGEMMLowpMatrixMultiplyKernel()
798  : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true)
799 {
800 }
801 
802 void NEGEMMLowpMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output)
803 {
804  ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
805  ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info()));
806 
807  TensorShape in1_shape = input1->info()->tensor_shape();
808  in1_shape.collapse(2);
809 
810  _input0 = input0;
811  _input1 = input1;
812  _output = output;
813  _slide_matrix_b = in1_shape[2] != 1;
814 
815  // Configure kernel window
816  auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info());
817  ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
818  INEKernel::configure(win_config.second);
819 }
820 
821 Status NEGEMMLowpMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
822 {
823  ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output));
824  ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(), input1->clone().get(), output->clone().get()).first);
825 
826  return Status{};
827 }
828 
829 void NEGEMMLowpMatrixMultiplyKernel::run(const Window &window, const ThreadInfo &info)
830 {
831  ARM_COMPUTE_UNUSED(info);
832  ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
833  ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
834 
835  // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication path
836  if((_output->info()->dimension(1) == 1))
837  {
838  const auto width_matrix_a = static_cast<int>(_input0->info()->dimension(0));
839  const auto width_matrix_b = static_cast<int>(_input1->info()->dimension(0));
840  const auto in_b_stride = static_cast<int>(_input1->info()->strides_in_bytes()[1] / data_size_from_type(_input1->info()->data_type()));
841 
842  // The implementation computes 16 elements per iteration
843  const int window_start_x = 16 * info.thread_id;
844  const int window_step_x = 16 * info.num_threads;
845  // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
846  const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
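 // Worked example (illustrative numbers): with width_matrix_b = 100 and 2 threads,
 // thread 0 gets start = 0, step = 32, end = ceil_to_multiple(100, 32) = 128 and processes
 // the 16-wide chunks starting at x = 0, 32, 64, 96, while thread 1 gets start = 16,
 // step = 32, end = ceil_to_multiple(84, 32) + 16 = 112 and processes x = 16, 48, 80.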
847 
848  Window win_out(window);
849  win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
850  win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
851 
852  Window win_a(window);
853  win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
854  win_a.set(Window::DimY, Window::Dimension(0, 0, 0));
855 
856  Window win_b;
857  // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
858  // This scenario can happen when the matrix multiplication is used to perform a convolution operation
859  if(_input1->info()->num_dimensions() >= 3)
860  {
861  win_b = window;
862  }
863  win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
864  win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
865 
866  Iterator ina(_input0, win_a);
867  Iterator inb(_input1, win_b);
868  Iterator out(_output, win_out);
869 
870  switch(_input0->info()->data_type())
871  {
872  case DataType::S8:
873  case DataType::QASYMM8_SIGNED:
874  {
875  vector_matrix_multiply_s8(ina, inb, out, width_matrix_a, width_matrix_b, in_b_stride, window);
876  break;
877  }
878  case DataType::U8:
879  case DataType::QASYMM8:
880  {
881  vector_matrix_multiply_u8(ina, inb, out, width_matrix_a, width_matrix_b, in_b_stride, window);
882  break;
883  }
884  default:
885  {
886  ARM_COMPUTE_ERROR("Not supported");
887  break;
888  }
889  }
890  }
891  else
892  {
893  const size_t in_b_stride = _input1->info()->strides_in_bytes()[1];
894  const size_t out_stride = _output->info()->strides_in_bytes()[1] / _output->info()->element_size();
895 
896  // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4, as the interleaved input matrix A has 4 times fewer rows than the output matrix
897  Window win_a(window);
898  win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
899  win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, window.y().end() / 4, 1));
900 
901  // Set step_x and step_y for matrix B. Scale the X range by a factor of 16, as the transposed input matrix B has 16 times fewer columns than the output matrix
902  Window win_b;
903  // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
904  // This scenario can happen when the matrix multiplication is used to perform a convolution operation
905  if(_slide_matrix_b)
906  {
907  win_b = window;
908  }
909  win_b.set(Window::DimX, Window::Dimension(window.x().start() / 16, window.x().end() / 16, in_b_stride));
910  win_b.set(Window::DimY, Window::Dimension(0, 0, 0));
911 
912  // The step_x and step_y for the output matrix have already been set in configure()
913  Iterator ina(_input0, win_a);
914  Iterator inb(_input1, win_b);
915  Iterator out(_output, window);
916 
917  const int width_b = _input1->info()->dimension(0);
918  switch(_input0->info()->data_type())
919  {
920  case DataType::S8:
921  case DataType::QASYMM8_SIGNED:
922  {
923  matrix_multiply_s8(ina, inb, out, width_b, out_stride, window);
924  break;
925  }
926  case DataType::U8:
927  case DataType::QASYMM8:
928  {
929  matrix_multiply_u8(ina, inb, out, width_b, out_stride, window);
930  break;
931  }
932  default:
933  {
934  ARM_COMPUTE_ERROR("Not supported");
935  break;
936  }
937  }
938  }
939 }
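// Illustrative usage sketch (assumptions: ACL 20.05 runtime headers, the vector-by-matrix
// path, and a DimX split hint; the function name and the sizes K and N are chosen only for
// this example, and quantization info is omitted because the kernel performs the raw
// multiply-accumulate). In practice this kernel is normally driven by
// NEGEMMLowpMatrixMultiplyCore, which also performs the NEGEMMInterleave4x4 and
// NEGEMMTranspose1xW reshapes required by the matrix-by-matrix path and handles the
// quantization offset contributions.
#include "arm_compute/core/Error.h"
#include "arm_compute/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "arm_compute/runtime/Tensor.h"

void example_vector_times_matrix()
{
    using namespace arm_compute;
    constexpr unsigned int K = 32, N = 64;

    Tensor a, b, dst;
    a.allocator()->init(TensorInfo(TensorShape(K, 1U), 1, DataType::QASYMM8)); // 1 x K vector
    b.allocator()->init(TensorInfo(TensorShape(N, K), 1, DataType::QASYMM8));  // K x N matrix
    dst.allocator()->init(TensorInfo(TensorShape(N, 1U), 1, DataType::S32));   // 1 x N result

    NEGEMMLowpMatrixMultiplyKernel mm_kernel;
    ARM_COMPUTE_ERROR_THROW_ON(NEGEMMLowpMatrixMultiplyKernel::validate(a.info(), b.info(), dst.info()));
    mm_kernel.configure(&a, &b, &dst);

    // Allocate after configure() so that any padding requested by the kernel is honoured
    a.allocator()->allocate();
    b.allocator()->allocate();
    dst.allocator()->allocate();

    // ... fill a and b with quantized 8-bit data ...

    NEScheduler::get().schedule(&mm_kernel, Window::DimX);
}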