// NOTE(review): doxygen source excerpt — the leading numbers on statements are the
// original file's line numbers, and interior lines (braces, the declaration of the
// 32-bit accumulator c0, and the vec_a increments) are missing from this dump.
// Comments below describe only what the visible code demonstrates.
//
// Multiplies a u8 vector A (width_a elements, via iterator `ina`) by a u8 matrix B
// (via `inb`), accumulating widened products into 32-bit lanes and writing up to
// 16 int32 results per window position to `out`.
46 void inline vector_matrix_multiply_u8(
Iterator &ina,
Iterator &inb,
Iterator &out,
int width_a,
int width_b,
int width_out,
size_t stride_b,
const Window &window)
// Raw element pointers for the current window position of A and B.
67 auto vec_a =
reinterpret_cast<const uint8_t *
>(ina.
ptr());
68 auto matrix_b =
reinterpret_cast<const uint8_t *
>(inb.
ptr());
69 auto vec_a_end_addr = vec_a + width_a;
// Main loop: consume 8 elements of A per iteration against 8 rows of B (each row
// load covers 16 columns). stride_b is in bytes, which for u8 data equals
// elements — presumably; confirm against the caller. The `vec_a += 8` advance is
// one of the lines missing from this excerpt.
72 for(; vec_a <= (vec_a_end_addr - 8);)
74 const uint8x8_t a00_u8 = vld1_u8(vec_a);
75 const uint8x16_t b00_u8 = vld1q_u8(matrix_b + 0 * stride_b);
76 const uint8x16_t b10_u8 = vld1q_u8(matrix_b + 1 * stride_b);
77 const uint8x16_t b20_u8 = vld1q_u8(matrix_b + 2 * stride_b);
78 const uint8x16_t b30_u8 = vld1q_u8(matrix_b + 3 * stride_b);
79 const uint8x16_t b40_u8 = vld1q_u8(matrix_b + 4 * stride_b);
80 const uint8x16_t b50_u8 = vld1q_u8(matrix_b + 5 * stride_b);
81 const uint8x16_t b60_u8 = vld1q_u8(matrix_b + 6 * stride_b);
82 const uint8x16_t b70_u8 = vld1q_u8(matrix_b + 7 * stride_b);
// Widen the 8 A values to u16 (two uint16x4 halves) so vmlal_lane_u16 can
// accumulate u16*u16 products into u32 lanes.
85 const uint16x4x2_t a00_u16 =
88 vget_low_u16(vmovl_u8(a00_u8)),
89 vget_high_u16(vmovl_u8(a00_u8))
// Widen each 16-column row of B into four uint16x4 quarters.
93 const uint16x4x4_t b00_u16 =
96 vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
97 vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
98 vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
99 vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
103 const uint16x4x4_t b10_u16 =
106 vget_low_u16(vmovl_u8(vget_low_u8(b10_u8))),
107 vget_high_u16(vmovl_u8(vget_low_u8(b10_u8))),
108 vget_low_u16(vmovl_u8(vget_high_u8(b10_u8))),
109 vget_high_u16(vmovl_u8(vget_high_u8(b10_u8)))
113 const uint16x4x4_t b20_u16 =
116 vget_low_u16(vmovl_u8(vget_low_u8(b20_u8))),
117 vget_high_u16(vmovl_u8(vget_low_u8(b20_u8))),
118 vget_low_u16(vmovl_u8(vget_high_u8(b20_u8))),
119 vget_high_u16(vmovl_u8(vget_high_u8(b20_u8)))
123 const uint16x4x4_t b30_u16 =
126 vget_low_u16(vmovl_u8(vget_low_u8(b30_u8))),
127 vget_high_u16(vmovl_u8(vget_low_u8(b30_u8))),
128 vget_low_u16(vmovl_u8(vget_high_u8(b30_u8))),
129 vget_high_u16(vmovl_u8(vget_high_u8(b30_u8)))
133 const uint16x4x4_t b40_u16 =
136 vget_low_u16(vmovl_u8(vget_low_u8(b40_u8))),
137 vget_high_u16(vmovl_u8(vget_low_u8(b40_u8))),
138 vget_low_u16(vmovl_u8(vget_high_u8(b40_u8))),
139 vget_high_u16(vmovl_u8(vget_high_u8(b40_u8)))
143 const uint16x4x4_t b50_u16 =
146 vget_low_u16(vmovl_u8(vget_low_u8(b50_u8))),
147 vget_high_u16(vmovl_u8(vget_low_u8(b50_u8))),
148 vget_low_u16(vmovl_u8(vget_high_u8(b50_u8))),
149 vget_high_u16(vmovl_u8(vget_high_u8(b50_u8)))
153 const uint16x4x4_t b60_u16 =
156 vget_low_u16(vmovl_u8(vget_low_u8(b60_u8))),
157 vget_high_u16(vmovl_u8(vget_low_u8(b60_u8))),
158 vget_low_u16(vmovl_u8(vget_high_u8(b60_u8))),
159 vget_high_u16(vmovl_u8(vget_high_u8(b60_u8)))
163 const uint16x4x4_t b70_u16 =
166 vget_low_u16(vmovl_u8(vget_low_u8(b70_u8))),
167 vget_high_u16(vmovl_u8(vget_low_u8(b70_u8))),
168 vget_low_u16(vmovl_u8(vget_high_u8(b70_u8))),
169 vget_high_u16(vmovl_u8(vget_high_u8(b70_u8)))
// Accumulate: row r of B is scaled by A element r (lane r of the widened A)
// and added into the 16-wide accumulator c0 (declared outside this excerpt,
// presumably a uint32x4x4_t — TODO confirm).
174 c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16.val[0], 0);
175 c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16.val[0], 0);
176 c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16.val[0], 0);
177 c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16.val[0], 0);
180 c0.val[0] = vmlal_lane_u16(c0.val[0], b10_u16.val[0], a00_u16.val[0], 1);
181 c0.val[1] = vmlal_lane_u16(c0.val[1], b10_u16.val[1], a00_u16.val[0], 1);
182 c0.val[2] = vmlal_lane_u16(c0.val[2], b10_u16.val[2], a00_u16.val[0], 1);
183 c0.val[3] = vmlal_lane_u16(c0.val[3], b10_u16.val[3], a00_u16.val[0], 1);
186 c0.val[0] = vmlal_lane_u16(c0.val[0], b20_u16.val[0], a00_u16.val[0], 2);
187 c0.val[1] = vmlal_lane_u16(c0.val[1], b20_u16.val[1], a00_u16.val[0], 2);
188 c0.val[2] = vmlal_lane_u16(c0.val[2], b20_u16.val[2], a00_u16.val[0], 2);
189 c0.val[3] = vmlal_lane_u16(c0.val[3], b20_u16.val[3], a00_u16.val[0], 2);
192 c0.val[0] = vmlal_lane_u16(c0.val[0], b30_u16.val[0], a00_u16.val[0], 3);
193 c0.val[1] = vmlal_lane_u16(c0.val[1], b30_u16.val[1], a00_u16.val[0], 3);
194 c0.val[2] = vmlal_lane_u16(c0.val[2], b30_u16.val[2], a00_u16.val[0], 3);
195 c0.val[3] = vmlal_lane_u16(c0.val[3], b30_u16.val[3], a00_u16.val[0], 3);
// Rows 4..7 use the high half of the widened A (a00_u16.val[1]).
198 c0.val[0] = vmlal_lane_u16(c0.val[0], b40_u16.val[0], a00_u16.val[1], 0);
199 c0.val[1] = vmlal_lane_u16(c0.val[1], b40_u16.val[1], a00_u16.val[1], 0);
200 c0.val[2] = vmlal_lane_u16(c0.val[2], b40_u16.val[2], a00_u16.val[1], 0);
201 c0.val[3] = vmlal_lane_u16(c0.val[3], b40_u16.val[3], a00_u16.val[1], 0);
204 c0.val[0] = vmlal_lane_u16(c0.val[0], b50_u16.val[0], a00_u16.val[1], 1);
205 c0.val[1] = vmlal_lane_u16(c0.val[1], b50_u16.val[1], a00_u16.val[1], 1);
206 c0.val[2] = vmlal_lane_u16(c0.val[2], b50_u16.val[2], a00_u16.val[1], 1);
207 c0.val[3] = vmlal_lane_u16(c0.val[3], b50_u16.val[3], a00_u16.val[1], 1);
210 c0.val[0] = vmlal_lane_u16(c0.val[0], b60_u16.val[0], a00_u16.val[1], 2);
211 c0.val[1] = vmlal_lane_u16(c0.val[1], b60_u16.val[1], a00_u16.val[1], 2);
212 c0.val[2] = vmlal_lane_u16(c0.val[2], b60_u16.val[2], a00_u16.val[1], 2);
213 c0.val[3] = vmlal_lane_u16(c0.val[3], b60_u16.val[3], a00_u16.val[1], 2);
216 c0.val[0] = vmlal_lane_u16(c0.val[0], b70_u16.val[0], a00_u16.val[1], 3);
217 c0.val[1] = vmlal_lane_u16(c0.val[1], b70_u16.val[1], a00_u16.val[1], 3);
218 c0.val[2] = vmlal_lane_u16(c0.val[2], b70_u16.val[2], a00_u16.val[1], 3);
219 c0.val[3] = vmlal_lane_u16(c0.val[3], b70_u16.val[3], a00_u16.val[1], 3);
222 matrix_b += 8 * stride_b;
// Leftover loop: one A element per iteration (broadcast with vld1_dup_u8)
// against one 16-wide row of B. The `vec_a++` advance is missing from this dump.
226 for(; vec_a < vec_a_end_addr;)
228 const uint8x8_t a00_u8 = vld1_dup_u8(vec_a);
229 const uint8x16_t b00_u8 = vld1q_u8(matrix_b);
231 const uint16x4x4_t b00_u16 =
234 vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
235 vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
236 vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
237 vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
242 const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));
245 c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
246 c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
247 c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
248 c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);
251 matrix_b += stride_b;
// Store: full 16-wide vector stores (reinterpreted u32 -> s32) when more than 16
// outputs remain past the current x position; otherwise spill the remaining lanes
// one by one. `id` comes from the enclosing execute_window_loop lambda, which is
// not visible in this excerpt.
254 auto vec_out =
reinterpret_cast<int32_t *
>(out.
ptr());
255 if(
id.x() < (width_out - 16))
257 vst1q_s32(vec_out + 0, vreinterpretq_s32_u32(c0.val[0]));
258 vst1q_s32(vec_out + 4, vreinterpretq_s32_u32(c0.val[1]));
259 vst1q_s32(vec_out + 8, vreinterpretq_s32_u32(c0.val[2]));
260 vst1q_s32(vec_out + 12, vreinterpretq_s32_u32(c0.val[3]));
// Tail store: write at most width_out - id.x() scalar lanes from c0.
264 auto left_over = width_out -
id.x();
265 for(
auto k = 0; k < 4 && left_over; ++k)
267 for(
auto j = 0; j < 4 && left_over; ++j, --left_over)
269 *(vec_out + k * 4 + j) = c0.val[k][j];
// NOTE(review): doxygen source excerpt (embedded numbers are original file line
// numbers; braces, the accumulator c0 declaration and vec_a increments are missing).
//
// Signed counterpart of vector_matrix_multiply_u8: multiplies an s8 vector A by an
// s8 matrix B, widening to s16 and accumulating products into 32-bit lanes, then
// writes up to 16 int32 results per window position.
277 void inline vector_matrix_multiply_s8(
Iterator &ina,
Iterator &inb,
Iterator &out,
int width_a,
int width_b,
int width_out,
size_t stride_b,
const Window &window)
// Raw element pointers for the current window position of A and B.
297 auto vec_a =
reinterpret_cast<const int8_t *
>(ina.
ptr());
298 auto matrix_b =
reinterpret_cast<const int8_t *
>(inb.
ptr());
299 auto vec_a_end_addr = vec_a + width_a;
// Main loop: 8 A elements per iteration against 8 rows of B, 16 columns each.
// The `vec_a += 8` advance is one of the lines missing from this excerpt.
302 for(; vec_a <= (vec_a_end_addr - 8);)
304 const int8x8_t a00_s8 = vld1_s8(vec_a);
305 const int8x16_t b00_s8 = vld1q_s8(matrix_b + 0 * stride_b);
306 const int8x16_t b10_s8 = vld1q_s8(matrix_b + 1 * stride_b);
307 const int8x16_t b20_s8 = vld1q_s8(matrix_b + 2 * stride_b);
308 const int8x16_t b30_s8 = vld1q_s8(matrix_b + 3 * stride_b);
309 const int8x16_t b40_s8 = vld1q_s8(matrix_b + 4 * stride_b);
310 const int8x16_t b50_s8 = vld1q_s8(matrix_b + 5 * stride_b);
311 const int8x16_t b60_s8 = vld1q_s8(matrix_b + 6 * stride_b);
312 const int8x16_t b70_s8 = vld1q_s8(matrix_b + 7 * stride_b);
// Widen the 8 A values to s16 (two int16x4 halves) for vmlal_lane_s16.
315 const int16x4x2_t a00_s16 =
318 vget_low_s16(vmovl_s8(a00_s8)),
319 vget_high_s16(vmovl_s8(a00_s8))
// Widen each 16-column row of B into four int16x4 quarters.
323 const int16x4x4_t b00_s16 =
326 vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
327 vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
328 vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
329 vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
333 const int16x4x4_t b10_s16 =
336 vget_low_s16(vmovl_s8(vget_low_s8(b10_s8))),
337 vget_high_s16(vmovl_s8(vget_low_s8(b10_s8))),
338 vget_low_s16(vmovl_s8(vget_high_s8(b10_s8))),
339 vget_high_s16(vmovl_s8(vget_high_s8(b10_s8)))
343 const int16x4x4_t b20_s16 =
346 vget_low_s16(vmovl_s8(vget_low_s8(b20_s8))),
347 vget_high_s16(vmovl_s8(vget_low_s8(b20_s8))),
348 vget_low_s16(vmovl_s8(vget_high_s8(b20_s8))),
349 vget_high_s16(vmovl_s8(vget_high_s8(b20_s8)))
353 const int16x4x4_t b30_s16 =
356 vget_low_s16(vmovl_s8(vget_low_s8(b30_s8))),
357 vget_high_s16(vmovl_s8(vget_low_s8(b30_s8))),
358 vget_low_s16(vmovl_s8(vget_high_s8(b30_s8))),
359 vget_high_s16(vmovl_s8(vget_high_s8(b30_s8)))
363 const int16x4x4_t b40_s16 =
366 vget_low_s16(vmovl_s8(vget_low_s8(b40_s8))),
367 vget_high_s16(vmovl_s8(vget_low_s8(b40_s8))),
368 vget_low_s16(vmovl_s8(vget_high_s8(b40_s8))),
369 vget_high_s16(vmovl_s8(vget_high_s8(b40_s8)))
373 const int16x4x4_t b50_s16 =
376 vget_low_s16(vmovl_s8(vget_low_s8(b50_s8))),
377 vget_high_s16(vmovl_s8(vget_low_s8(b50_s8))),
378 vget_low_s16(vmovl_s8(vget_high_s8(b50_s8))),
379 vget_high_s16(vmovl_s8(vget_high_s8(b50_s8)))
383 const int16x4x4_t b60_s16 =
386 vget_low_s16(vmovl_s8(vget_low_s8(b60_s8))),
387 vget_high_s16(vmovl_s8(vget_low_s8(b60_s8))),
388 vget_low_s16(vmovl_s8(vget_high_s8(b60_s8))),
389 vget_high_s16(vmovl_s8(vget_high_s8(b60_s8)))
393 const int16x4x4_t b70_s16 =
396 vget_low_s16(vmovl_s8(vget_low_s8(b70_s8))),
397 vget_high_s16(vmovl_s8(vget_low_s8(b70_s8))),
398 vget_low_s16(vmovl_s8(vget_high_s8(b70_s8))),
399 vget_high_s16(vmovl_s8(vget_high_s8(b70_s8)))
// Accumulate: row r of B is scaled by A lane r into the 16-wide accumulator c0
// (declared outside this excerpt, presumably int32x4x4_t — TODO confirm).
404 c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16.val[0], 0);
405 c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16.val[0], 0);
406 c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16.val[0], 0);
407 c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16.val[0], 0);
410 c0.val[0] = vmlal_lane_s16(c0.val[0], b10_s16.val[0], a00_s16.val[0], 1);
411 c0.val[1] = vmlal_lane_s16(c0.val[1], b10_s16.val[1], a00_s16.val[0], 1);
412 c0.val[2] = vmlal_lane_s16(c0.val[2], b10_s16.val[2], a00_s16.val[0], 1);
413 c0.val[3] = vmlal_lane_s16(c0.val[3], b10_s16.val[3], a00_s16.val[0], 1);
416 c0.val[0] = vmlal_lane_s16(c0.val[0], b20_s16.val[0], a00_s16.val[0], 2);
417 c0.val[1] = vmlal_lane_s16(c0.val[1], b20_s16.val[1], a00_s16.val[0], 2);
418 c0.val[2] = vmlal_lane_s16(c0.val[2], b20_s16.val[2], a00_s16.val[0], 2);
419 c0.val[3] = vmlal_lane_s16(c0.val[3], b20_s16.val[3], a00_s16.val[0], 2);
422 c0.val[0] = vmlal_lane_s16(c0.val[0], b30_s16.val[0], a00_s16.val[0], 3);
423 c0.val[1] = vmlal_lane_s16(c0.val[1], b30_s16.val[1], a00_s16.val[0], 3);
424 c0.val[2] = vmlal_lane_s16(c0.val[2], b30_s16.val[2], a00_s16.val[0], 3);
425 c0.val[3] = vmlal_lane_s16(c0.val[3], b30_s16.val[3], a00_s16.val[0], 3);
// Rows 4..7 use the high half of the widened A (a00_s16.val[1]).
428 c0.val[0] = vmlal_lane_s16(c0.val[0], b40_s16.val[0], a00_s16.val[1], 0);
429 c0.val[1] = vmlal_lane_s16(c0.val[1], b40_s16.val[1], a00_s16.val[1], 0);
430 c0.val[2] = vmlal_lane_s16(c0.val[2], b40_s16.val[2], a00_s16.val[1], 0);
431 c0.val[3] = vmlal_lane_s16(c0.val[3], b40_s16.val[3], a00_s16.val[1], 0);
434 c0.val[0] = vmlal_lane_s16(c0.val[0], b50_s16.val[0], a00_s16.val[1], 1);
435 c0.val[1] = vmlal_lane_s16(c0.val[1], b50_s16.val[1], a00_s16.val[1], 1);
436 c0.val[2] = vmlal_lane_s16(c0.val[2], b50_s16.val[2], a00_s16.val[1], 1);
437 c0.val[3] = vmlal_lane_s16(c0.val[3], b50_s16.val[3], a00_s16.val[1], 1);
440 c0.val[0] = vmlal_lane_s16(c0.val[0], b60_s16.val[0], a00_s16.val[1], 2);
441 c0.val[1] = vmlal_lane_s16(c0.val[1], b60_s16.val[1], a00_s16.val[1], 2);
442 c0.val[2] = vmlal_lane_s16(c0.val[2], b60_s16.val[2], a00_s16.val[1], 2);
443 c0.val[3] = vmlal_lane_s16(c0.val[3], b60_s16.val[3], a00_s16.val[1], 2);
446 c0.val[0] = vmlal_lane_s16(c0.val[0], b70_s16.val[0], a00_s16.val[1], 3);
447 c0.val[1] = vmlal_lane_s16(c0.val[1], b70_s16.val[1], a00_s16.val[1], 3);
448 c0.val[2] = vmlal_lane_s16(c0.val[2], b70_s16.val[2], a00_s16.val[1], 3);
449 c0.val[3] = vmlal_lane_s16(c0.val[3], b70_s16.val[3], a00_s16.val[1], 3);
452 matrix_b += 8 * stride_b;
// Leftover loop: one A element (broadcast) per remaining row of B; the `vec_a++`
// advance is missing from this dump.
456 for(; vec_a < vec_a_end_addr;)
458 const int8x8_t a00_s8 = vld1_dup_s8(vec_a);
459 const int8x16_t b00_s8 = vld1q_s8(matrix_b);
461 const int16x4x4_t b00_s16 =
464 vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
465 vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
466 vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
467 vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
472 const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));
475 c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
476 c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
477 c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
478 c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);
481 matrix_b += stride_b;
// Store: full 16-wide stores (no reinterpret needed — accumulator is already
// signed) when more than 16 outputs remain past the current x; otherwise a
// scalar tail. `id` comes from the enclosing execute_window_loop lambda.
484 auto vec_out =
reinterpret_cast<int32_t *
>(out.
ptr());
485 if(
id.x() < (width_out - 16))
487 vst1q_s32(vec_out + 0, c0.val[0]);
488 vst1q_s32(vec_out + 4, c0.val[1]);
489 vst1q_s32(vec_out + 8, c0.val[2]);
490 vst1q_s32(vec_out + 12, c0.val[3]);
// Tail store: write at most width_out - id.x() scalar lanes from c0.
494 auto left_over = width_out -
id.x();
495 for(
auto k = 0; k < 4 && left_over; ++k)
497 for(
auto j = 0; j < 4 && left_over; ++j, --left_over)
499 *(vec_out + k * 4 + j) = c0.val[k][j];
// NOTE(review): excerpt of the interior of the u8 matrix-by-matrix multiply helper
// (invoked as matrix_multiply_u8 later in this dump). The function signature, the
// declarations of the four accumulators c0..c3 and out_stride, and the
// execute_window_loop lambda providing `id` lie outside this excerpt.
509 const auto width_out =
static_cast<int>(out_info.
dimension(0));
510 const auto height_out =
static_cast<int>(out_info.
dimension(1));
// Pointers into the (pre-reshaped) operands: A advances 4 bytes and B advances
// 16 bytes per k step — consistent with an interleaved-4x / transposed-16x layout,
// presumably; confirm against the reshape kernels.
514 const uint8_t *mtx_a0 = ina.
ptr();
515 const uint8_t *mtx_b0 = inb.
ptr();
562 for(
int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
564 const uint8x8_t a00_u8 = vld1_u8(mtx_a0);
565 const uint8x16_t b00_u8 = vld1q_u8(mtx_b0);
// Widen 4 A values (one per output row) and the 16 B columns to u16.
568 const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));
571 const uint16x4x4_t b00_u16 =
574 vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
575 vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
576 vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
577 vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
// Each A lane feeds one 16-wide output row: c0..c3 form a 4x16 accumulator tile.
582 c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
583 c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
584 c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
585 c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);
588 c1.val[0] = vmlal_lane_u16(c1.val[0], b00_u16.val[0], a00_u16, 1);
589 c1.val[1] = vmlal_lane_u16(c1.val[1], b00_u16.val[1], a00_u16, 1);
590 c1.val[2] = vmlal_lane_u16(c1.val[2], b00_u16.val[2], a00_u16, 1);
591 c1.val[3] = vmlal_lane_u16(c1.val[3], b00_u16.val[3], a00_u16, 1);
594 c2.val[0] = vmlal_lane_u16(c2.val[0], b00_u16.val[0], a00_u16, 2);
595 c2.val[1] = vmlal_lane_u16(c2.val[1], b00_u16.val[1], a00_u16, 2);
596 c2.val[2] = vmlal_lane_u16(c2.val[2], b00_u16.val[2], a00_u16, 2);
597 c2.val[3] = vmlal_lane_u16(c2.val[3], b00_u16.val[3], a00_u16, 2);
600 c3.val[0] = vmlal_lane_u16(c3.val[0], b00_u16.val[0], a00_u16, 3);
601 c3.val[1] = vmlal_lane_u16(c3.val[1], b00_u16.val[1], a00_u16, 3);
602 c3.val[2] = vmlal_lane_u16(c3.val[2], b00_u16.val[2], a00_u16, 3);
603 c3.val[3] = vmlal_lane_u16(c3.val[3], b00_u16.val[3], a00_u16, 3);
// Store the 4x16 tile. Each row is guarded against the output height; the x
// guard picks the vector path only when more than 16 columns remain.
606 auto mtx_out =
reinterpret_cast<int32_t *
>(out.
ptr());
608 if(
id.y() < height_out &&
id.x() < (width_out - 16))
610 vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
611 vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
612 vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
613 vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
614 if(
id.y() + 1 < height_out)
616 vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
617 vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
618 vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
619 vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
620 if(
id.y() + 2 < height_out)
622 vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
623 vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
624 vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
625 vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
626 if(
id.y() + 3 < height_out)
628 vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
629 vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
630 vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
631 vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
// Partial-tile path: spill at most width_out - id.x() lanes per surviving row.
// (The `else if(id.y() < height_out)` guard visible in the s8 variant is one of
// the lines missing from this excerpt.)
638 const auto left_over_value = width_out -
id.x();
639 auto left_over = left_over_value;
640 for(
auto k = 0; k < 4 && left_over; ++k)
642 for(
auto j = 0; j < 4 && left_over; ++j, --left_over)
644 *(mtx_out + k * 4 + j) = c0.val[k][j];
647 if(
id.y() + 1 < height_out)
649 left_over = left_over_value;
650 for(
auto k = 0; k < 4 && left_over; ++k)
652 for(
auto j = 0; j < 4 && left_over; ++j, --left_over)
654 *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
657 if(
id.y() + 2 < height_out)
659 left_over = left_over_value;
660 for(
auto k = 0; k < 4 && left_over; ++k)
662 for(
auto j = 0; j < 4 && left_over; ++j, --left_over)
664 *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
667 if(
id.y() + 3 < height_out)
669 left_over = left_over_value;
670 for(
auto k = 0; k < 4 && left_over; ++k)
672 for(
auto j = 0; j < 4 && left_over; ++j, --left_over)
674 *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
// NOTE(review): excerpt of the interior of the s8 matrix-by-matrix multiply helper
// (invoked as matrix_multiply_s8 later in this dump). Signature, accumulator
// declarations (c0..c3), out_stride and the execute_window_loop lambda providing
// `id` are outside this excerpt. Mirrors the u8 variant with signed intrinsics.
687 const auto width_out =
static_cast<int>(out_info.
dimension(0));
688 const auto height_out =
static_cast<int>(out_info.
dimension(1));
// Pointers into the (pre-reshaped) operands; A advances 4 bytes and B 16 bytes
// per k step, matching the u8 variant.
695 auto *mtx_a0 =
reinterpret_cast<const int8_t *
>(ina.
ptr());
696 auto *mtx_b0 =
reinterpret_cast<const int8_t *
>(inb.
ptr());
743 for(
int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
745 const int8x8_t a00_s8 = vld1_s8(mtx_a0);
746 const int8x16_t b00_s8 = vld1q_s8(mtx_b0);
// Widen 4 A values (one per output row) and the 16 B columns to s16.
749 const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));
752 const int16x4x4_t b00_s16 =
755 vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
756 vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
757 vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
758 vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
// Each A lane feeds one 16-wide output row: c0..c3 form a 4x16 accumulator tile.
763 c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
764 c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
765 c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
766 c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);
769 c1.val[0] = vmlal_lane_s16(c1.val[0], b00_s16.val[0], a00_s16, 1);
770 c1.val[1] = vmlal_lane_s16(c1.val[1], b00_s16.val[1], a00_s16, 1);
771 c1.val[2] = vmlal_lane_s16(c1.val[2], b00_s16.val[2], a00_s16, 1);
772 c1.val[3] = vmlal_lane_s16(c1.val[3], b00_s16.val[3], a00_s16, 1);
775 c2.val[0] = vmlal_lane_s16(c2.val[0], b00_s16.val[0], a00_s16, 2);
776 c2.val[1] = vmlal_lane_s16(c2.val[1], b00_s16.val[1], a00_s16, 2);
777 c2.val[2] = vmlal_lane_s16(c2.val[2], b00_s16.val[2], a00_s16, 2);
778 c2.val[3] = vmlal_lane_s16(c2.val[3], b00_s16.val[3], a00_s16, 2);
781 c3.val[0] = vmlal_lane_s16(c3.val[0], b00_s16.val[0], a00_s16, 3);
782 c3.val[1] = vmlal_lane_s16(c3.val[1], b00_s16.val[1], a00_s16, 3);
783 c3.val[2] = vmlal_lane_s16(c3.val[2], b00_s16.val[2], a00_s16, 3);
784 c3.val[3] = vmlal_lane_s16(c3.val[3], b00_s16.val[3], a00_s16, 3);
// Store the 4x16 tile: per-row height guards plus a width guard that picks the
// vector path only when more than 16 columns remain after id.x().
786 auto mtx_out =
reinterpret_cast<int32_t *
>(out.
ptr());
787 if(
id.y() < height_out &&
id.x() < (width_out - 16))
789 vst1q_s32(mtx_out + 0 * out_stride + 0, c0.val[0]);
790 vst1q_s32(mtx_out + 0 * out_stride + 4, c0.val[1]);
791 vst1q_s32(mtx_out + 0 * out_stride + 8, c0.val[2]);
792 vst1q_s32(mtx_out + 0 * out_stride + 12, c0.val[3]);
793 if(
id.y() + 1 < height_out)
795 vst1q_s32(mtx_out + 1 * out_stride + 0, c1.val[0]);
796 vst1q_s32(mtx_out + 1 * out_stride + 4, c1.val[1]);
797 vst1q_s32(mtx_out + 1 * out_stride + 8, c1.val[2]);
798 vst1q_s32(mtx_out + 1 * out_stride + 12, c1.val[3]);
799 if(
id.y() + 2 < height_out)
801 vst1q_s32(mtx_out + 2 * out_stride + 0, c2.val[0]);
802 vst1q_s32(mtx_out + 2 * out_stride + 4, c2.val[1]);
803 vst1q_s32(mtx_out + 2 * out_stride + 8, c2.val[2]);
804 vst1q_s32(mtx_out + 2 * out_stride + 12, c2.val[3]);
805 if(
id.y() + 3 < height_out)
807 vst1q_s32(mtx_out + 3 * out_stride + 0, c3.val[0]);
808 vst1q_s32(mtx_out + 3 * out_stride + 4, c3.val[1]);
809 vst1q_s32(mtx_out + 3 * out_stride + 8, c3.val[2]);
810 vst1q_s32(mtx_out + 3 * out_stride + 12, c3.val[3]);
// Partial-tile path: only entered for rows inside the output; spills at most
// width_out - id.x() lanes per surviving row.
815 else if(
id.y() < height_out)
817 const auto left_over_value = width_out -
id.x();
818 auto left_over = left_over_value;
819 for(
auto k = 0; k < 4 && left_over; ++k)
821 for(
auto j = 0; j < 4 && left_over; ++j, --left_over)
823 *(mtx_out + k * 4 + j) = c0.val[k][j];
826 if(
id.y() + 1 < height_out)
828 left_over = left_over_value;
829 for(
auto k = 0; k < 4 && left_over; ++k)
831 for(
auto j = 0; j < 4 && left_over; ++j, --left_over)
833 *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
836 if(
id.y() + 2 < height_out)
838 left_over = left_over_value;
839 for(
auto k = 0; k < 4 && left_over; ++k)
841 for(
auto j = 0; j < 4 && left_over; ++j, --left_over)
843 *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
846 if(
id.y() + 3 < height_out)
848 left_over = left_over_value;
849 for(
auto k = 0; k < 4 && left_over; ++k)
851 for(
auto j = 0; j < 4 && left_over; ++j, --left_over)
853 *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
// NOTE(review): scattered fragments from the validation, constructor, configure
// and run stages of NEGEMMLowpMatrixMultiplyKernel. Most surrounding lines are
// missing from this dump, so only per-line intent is noted and hedged.
//
// Presumably: a vector-by-matrix special case is selected when the output has a
// single row — TODO confirm against the full source.
879 if(out_shape[1] == 1)
// Batch compatibility: input1 must either match input0's batch count or provide
// a single (broadcast) batch.
890 ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[2] != 1 && in0_shape[2] != in1_shape[2],
"Input1 tensor must have the same number of batches of input0 or the number of batches must be set to 1");
// Constructor member-initializer list: all tensor pointers start null and matrix
// B is assumed to slide across batches by default.
899 : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true)
// B slides per batch only when it actually has more than one batch.
914 _slide_matrix_b = in1_shape[2] != 1;
// Window step sizes: 16 output elements in x and 4 rows in y per iteration,
// matching the 4x16 accumulator tiles in the multiply helpers above.
916 constexpr
unsigned int num_elems_processed_per_iteration_x = 16;
917 constexpr
unsigned int num_elems_processed_per_iteration_y = 4;
937 INEKernel::configure(win);
// Run stage (vector-by-matrix path): operand widths taken from the tensor infos.
956 const auto width_matrix_a =
static_cast<int>(_input0->
info()->
dimension(0));
957 const auto width_matrix_b =
static_cast<int>(_input1->
info()->
dimension(0));
958 const auto width_out =
static_cast<int>(_output->
info()->
dimension(0));
// Per-thread x partitioning: threads start 16 columns apart; the end is rounded
// up to a multiple of the window step so each thread covers whole steps.
962 const int window_start_x = 16 * info.
thread_id;
965 const int window_end_x =
ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
// Dispatch on the signedness of the quantized input type (s8 vs u8 kernels).
994 vector_matrix_multiply_s8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
1000 vector_matrix_multiply_u8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
// Matrix-by-matrix path dispatch, also by signedness.
1041 matrix_multiply_s8(ina, inb, out, width_b, *_output->
info(),
window);
1047 matrix_multiply_u8(ina, inb, out, width_b, *_output->
info(),
window);
virtual size_t num_dimensions() const =0
The number of dimensions of the tensor (rank)
Window calculate_max_window(const ValidRegion &valid_region, const Steps &steps, bool skip_border, BorderSize border_size)
const Window & window() const
The maximum window the kernel can be executed on.
virtual size_t dimension(size_t index) const =0
Return the size of the requested dimension.
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
NEGEMMLowpMatrixMultiplyKernel()
Constructor.
1 channel, 1 U8 per channel
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
size_t dimension(size_t index) const override
Return the size of the requested dimension.
virtual DataType data_type() const =0
Data type used for each element of the tensor.
Store the tensor's metadata.
void run(const Window &window, const ThreadInfo &info) override
Execute the kernel on the passed window.
#define ARM_COMPUTE_ERROR_THROW_ON(status)
Describe one of the image's dimensions with a start, end and step.
const Strides & strides_in_bytes() const override
The strides in bytes for accessing each dimension of the tensor.
void configure(const ITensor *input0, const ITensor *input1, ITensor *output)
Initialise the kernel's input and output.
Interface for Neon tensor.
Copyright (c) 2017-2021 Arm Limited.
virtual void set_valid_region(const ValidRegion &valid_region)=0
Set the valid region of the tensor.
1 channel, 1 S32 per channel
static constexpr size_t DimX
Alias for dimension 0 also known as X dimension.
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
virtual const TensorShape & tensor_shape() const =0
Size for each dimension of the tensor.
auto ceil_to_multiple(S value, T divisor) -> decltype(((value+divisor - 1)/divisor) *divisor)
Computes the smallest number larger or equal to value that is a multiple of divisor.
quantized, asymmetric fixed-point 8-bit number unsigned
Class to describe a number of elements in each dimension.
virtual ITensorInfo * info() const =0
Interface to be implemented by the child class to return the tensor's metadata.
size_t data_size_from_type(DataType data_type)
The size in bytes of the data type.
constexpr uint8_t * ptr() const
Return a pointer to the current pixel.
void set(size_t dimension, const Dimension &dim)
Set the values of a given dimension.
#define ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(k)
quantized, symmetric fixed-point 8-bit number
quantized, symmetric per channel fixed-point 8-bit number
static constexpr size_t DimY
Alias for dimension 1 also known as Y dimension.
ScaleKernelInfo info(interpolation_policy, default_border_mode, PixelValue(), sampling_policy, false)
Information about executing thread and CPU.
constexpr const Dimension & y() const
Alias to access the second dimension of the window.
#define ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(t, c,...)
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const GEMMLowpOutputStageInfo *output_stage)
#define ARM_COMPUTE_RETURN_ERROR_ON_MSG(cond, msg)
If the condition is true, an error is returned.
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
Store the tensor's metadata.
void execute_window_loop(const Window &w, L &&lambda_function, Ts &&... iterators)
Iterate through the passed window, automatically adjusting the iterators and calling the lambda function for each element of the window.
static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
Static function to check if given info will lead to a valid configuration of NEGEMMLowpMatrixMultiplyKernel.
void set_num_dimensions(size_t num_dimensions)
Set number of dimensions.
quantized, asymmetric fixed-point 8-bit number signed
virtual const Strides & strides_in_bytes() const =0
The strides in bytes for accessing each dimension of the tensor.
Container for valid region of a window.
constexpr int end() const
Return the end of the dimension.
Iterator updated by execute_window_loop for each window element.
constexpr int start() const
Return the start of the dimension.
size_t element_size() const override
Element size in bytes calculated as data_size() * num_channels()
Describe a multidimensional execution window.
void collapse(size_t n, size_t first=0)
Collapse the first n dimensions.
#define ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(f, s)
constexpr const Dimension & x() const
Alias to access the first dimension of the window.