void inline vector_matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
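    // A (1 x width_a) is multiplied against 16-column slices of B: u8 values are widened to u16 and the
    // products accumulated into four u32x4 registers (c0), i.e. 16 s32 output columns per processed block.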
    auto vec_a          = reinterpret_cast<const uint8_t *>(ina.ptr());
    auto matrix_b       = reinterpret_cast<const uint8_t *>(inb.ptr());
    auto vec_a_end_addr = vec_a + width_a;
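
    // Main loop: consume 8 elements of A and 8 rows of B per iteration (8 accumulations).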
    for(; vec_a <= (vec_a_end_addr - 8);)
    {
        const uint8x8_t  a00_u8 = vld1_u8(vec_a);
        const uint8x16_t b00_u8 = vld1q_u8(matrix_b + 0 * stride_b);
        const uint8x16_t b10_u8 = vld1q_u8(matrix_b + 1 * stride_b);
        const uint8x16_t b20_u8 = vld1q_u8(matrix_b + 2 * stride_b);
        const uint8x16_t b30_u8 = vld1q_u8(matrix_b + 3 * stride_b);
        const uint8x16_t b40_u8 = vld1q_u8(matrix_b + 4 * stride_b);
        const uint8x16_t b50_u8 = vld1q_u8(matrix_b + 5 * stride_b);
        const uint8x16_t b60_u8 = vld1q_u8(matrix_b + 6 * stride_b);
        const uint8x16_t b70_u8 = vld1q_u8(matrix_b + 7 * stride_b);
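
        // Widen A and the eight B rows from u8 to u16 so they can feed vmlal_lane_u16 below.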
        const uint16x4x2_t a00_u16 = { {
            vget_low_u16(vmovl_u8(a00_u8)),
            vget_high_u16(vmovl_u8(a00_u8))
        } };

        const uint16x4x4_t b00_u16 = { {
            vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
            vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
            vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
            vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
        } };

        const uint16x4x4_t b10_u16 = { {
            vget_low_u16(vmovl_u8(vget_low_u8(b10_u8))),
            vget_high_u16(vmovl_u8(vget_low_u8(b10_u8))),
            vget_low_u16(vmovl_u8(vget_high_u8(b10_u8))),
            vget_high_u16(vmovl_u8(vget_high_u8(b10_u8)))
        } };

        const uint16x4x4_t b20_u16 = { {
            vget_low_u16(vmovl_u8(vget_low_u8(b20_u8))),
            vget_high_u16(vmovl_u8(vget_low_u8(b20_u8))),
            vget_low_u16(vmovl_u8(vget_high_u8(b20_u8))),
            vget_high_u16(vmovl_u8(vget_high_u8(b20_u8)))
        } };

        const uint16x4x4_t b30_u16 = { {
            vget_low_u16(vmovl_u8(vget_low_u8(b30_u8))),
            vget_high_u16(vmovl_u8(vget_low_u8(b30_u8))),
            vget_low_u16(vmovl_u8(vget_high_u8(b30_u8))),
            vget_high_u16(vmovl_u8(vget_high_u8(b30_u8)))
        } };

        const uint16x4x4_t b40_u16 = { {
            vget_low_u16(vmovl_u8(vget_low_u8(b40_u8))),
            vget_high_u16(vmovl_u8(vget_low_u8(b40_u8))),
            vget_low_u16(vmovl_u8(vget_high_u8(b40_u8))),
            vget_high_u16(vmovl_u8(vget_high_u8(b40_u8)))
        } };

        const uint16x4x4_t b50_u16 = { {
            vget_low_u16(vmovl_u8(vget_low_u8(b50_u8))),
            vget_high_u16(vmovl_u8(vget_low_u8(b50_u8))),
            vget_low_u16(vmovl_u8(vget_high_u8(b50_u8))),
            vget_high_u16(vmovl_u8(vget_high_u8(b50_u8)))
        } };

        const uint16x4x4_t b60_u16 = { {
            vget_low_u16(vmovl_u8(vget_low_u8(b60_u8))),
            vget_high_u16(vmovl_u8(vget_low_u8(b60_u8))),
            vget_low_u16(vmovl_u8(vget_high_u8(b60_u8))),
            vget_high_u16(vmovl_u8(vget_high_u8(b60_u8)))
        } };

        const uint16x4x4_t b70_u16 = { {
            vget_low_u16(vmovl_u8(vget_low_u8(b70_u8))),
            vget_high_u16(vmovl_u8(vget_low_u8(b70_u8))),
            vget_low_u16(vmovl_u8(vget_high_u8(b70_u8))),
            vget_high_u16(vmovl_u8(vget_high_u8(b70_u8)))
        } };
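
        // Accumulate: B row r is scaled by lane r of the widened A values; vmlal_lane_u16
        // widens each u16 * u16 product to u32 before adding it into c0.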
        c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16.val[0], 0);
        c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16.val[0], 0);
        c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16.val[0], 0);
        c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16.val[0], 0);

        c0.val[0] = vmlal_lane_u16(c0.val[0], b10_u16.val[0], a00_u16.val[0], 1);
        c0.val[1] = vmlal_lane_u16(c0.val[1], b10_u16.val[1], a00_u16.val[0], 1);
        c0.val[2] = vmlal_lane_u16(c0.val[2], b10_u16.val[2], a00_u16.val[0], 1);
        c0.val[3] = vmlal_lane_u16(c0.val[3], b10_u16.val[3], a00_u16.val[0], 1);

        c0.val[0] = vmlal_lane_u16(c0.val[0], b20_u16.val[0], a00_u16.val[0], 2);
        c0.val[1] = vmlal_lane_u16(c0.val[1], b20_u16.val[1], a00_u16.val[0], 2);
        c0.val[2] = vmlal_lane_u16(c0.val[2], b20_u16.val[2], a00_u16.val[0], 2);
        c0.val[3] = vmlal_lane_u16(c0.val[3], b20_u16.val[3], a00_u16.val[0], 2);

        c0.val[0] = vmlal_lane_u16(c0.val[0], b30_u16.val[0], a00_u16.val[0], 3);
        c0.val[1] = vmlal_lane_u16(c0.val[1], b30_u16.val[1], a00_u16.val[0], 3);
        c0.val[2] = vmlal_lane_u16(c0.val[2], b30_u16.val[2], a00_u16.val[0], 3);
        c0.val[3] = vmlal_lane_u16(c0.val[3], b30_u16.val[3], a00_u16.val[0], 3);

        c0.val[0] = vmlal_lane_u16(c0.val[0], b40_u16.val[0], a00_u16.val[1], 0);
        c0.val[1] = vmlal_lane_u16(c0.val[1], b40_u16.val[1], a00_u16.val[1], 0);
        c0.val[2] = vmlal_lane_u16(c0.val[2], b40_u16.val[2], a00_u16.val[1], 0);
        c0.val[3] = vmlal_lane_u16(c0.val[3], b40_u16.val[3], a00_u16.val[1], 0);

        c0.val[0] = vmlal_lane_u16(c0.val[0], b50_u16.val[0], a00_u16.val[1], 1);
        c0.val[1] = vmlal_lane_u16(c0.val[1], b50_u16.val[1], a00_u16.val[1], 1);
        c0.val[2] = vmlal_lane_u16(c0.val[2], b50_u16.val[2], a00_u16.val[1], 1);
        c0.val[3] = vmlal_lane_u16(c0.val[3], b50_u16.val[3], a00_u16.val[1], 1);

        c0.val[0] = vmlal_lane_u16(c0.val[0], b60_u16.val[0], a00_u16.val[1], 2);
        c0.val[1] = vmlal_lane_u16(c0.val[1], b60_u16.val[1], a00_u16.val[1], 2);
        c0.val[2] = vmlal_lane_u16(c0.val[2], b60_u16.val[2], a00_u16.val[1], 2);
        c0.val[3] = vmlal_lane_u16(c0.val[3], b60_u16.val[3], a00_u16.val[1], 2);

        c0.val[0] = vmlal_lane_u16(c0.val[0], b70_u16.val[0], a00_u16.val[1], 3);
        c0.val[1] = vmlal_lane_u16(c0.val[1], b70_u16.val[1], a00_u16.val[1], 3);
        c0.val[2] = vmlal_lane_u16(c0.val[2], b70_u16.val[2], a00_u16.val[1], 3);
        c0.val[3] = vmlal_lane_u16(c0.val[3], b70_u16.val[3], a00_u16.val[1], 3);

        matrix_b += 8 * stride_b;
        vec_a += 8;
    }
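
    // Left-over loop: process any remaining elements of A (width_a % 8) one at a time.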
    for(; vec_a < vec_a_end_addr;)
    {
        const uint8x8_t  a00_u8 = vld1_dup_u8(vec_a);
        const uint8x16_t b00_u8 = vld1q_u8(matrix_b);

        const uint16x4x4_t b00_u16 = { {
            vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
            vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
            vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
            vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
        } };

        const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));

        c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
        c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
        c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
        c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);

        matrix_b += stride_b;
        vec_a += 1;
    }
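
    // Store 16 s32 results when a full vector fits in the output row; at the right-hand edge,
    // fall back to scalar stores so we never write past width_out.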
    auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
    if(id.x() < (width_out - 16))
    {
        vst1q_s32(vec_out + 0, vreinterpretq_s32_u32(c0.val[0]));
        vst1q_s32(vec_out + 4, vreinterpretq_s32_u32(c0.val[1]));
        vst1q_s32(vec_out + 8, vreinterpretq_s32_u32(c0.val[2]));
        vst1q_s32(vec_out + 12, vreinterpretq_s32_u32(c0.val[3]));
    }
    else
    {
        auto left_over = width_out - id.x();
        for(auto k = 0; k < 4 && left_over; ++k)
        {
            for(auto j = 0; j < 4 && left_over; ++j, --left_over)
            {
                *(vec_out + k * 4 + j) = c0.val[k][j];
            }
        }
    }
void inline vector_matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
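    // Signed variant of the routine above: identical structure, with s8 values widened to s16
    // and the products accumulated directly into s32 registers.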
    auto vec_a          = reinterpret_cast<const int8_t *>(ina.ptr());
    auto matrix_b       = reinterpret_cast<const int8_t *>(inb.ptr());
    auto vec_a_end_addr = vec_a + width_a;
    // Main loop: 8 accumulations per iteration, mirroring the unsigned path.
    for(; vec_a <= (vec_a_end_addr - 8);)
    {
        const int8x8_t  a00_s8 = vld1_s8(vec_a);
        const int8x16_t b00_s8 = vld1q_s8(matrix_b + 0 * stride_b);
        const int8x16_t b10_s8 = vld1q_s8(matrix_b + 1 * stride_b);
        const int8x16_t b20_s8 = vld1q_s8(matrix_b + 2 * stride_b);
        const int8x16_t b30_s8 = vld1q_s8(matrix_b + 3 * stride_b);
        const int8x16_t b40_s8 = vld1q_s8(matrix_b + 4 * stride_b);
        const int8x16_t b50_s8 = vld1q_s8(matrix_b + 5 * stride_b);
        const int8x16_t b60_s8 = vld1q_s8(matrix_b + 6 * stride_b);
        const int8x16_t b70_s8 = vld1q_s8(matrix_b + 7 * stride_b);
        // Widen A and the eight B rows from s8 to s16 so they can feed vmlal_lane_s16 below.
        const int16x4x2_t a00_s16 = { {
            vget_low_s16(vmovl_s8(a00_s8)),
            vget_high_s16(vmovl_s8(a00_s8))
        } };

        const int16x4x4_t b00_s16 = { {
            vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
            vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
            vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
            vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
        } };

        const int16x4x4_t b10_s16 = { {
            vget_low_s16(vmovl_s8(vget_low_s8(b10_s8))),
            vget_high_s16(vmovl_s8(vget_low_s8(b10_s8))),
            vget_low_s16(vmovl_s8(vget_high_s8(b10_s8))),
            vget_high_s16(vmovl_s8(vget_high_s8(b10_s8)))
        } };

        const int16x4x4_t b20_s16 = { {
            vget_low_s16(vmovl_s8(vget_low_s8(b20_s8))),
            vget_high_s16(vmovl_s8(vget_low_s8(b20_s8))),
            vget_low_s16(vmovl_s8(vget_high_s8(b20_s8))),
            vget_high_s16(vmovl_s8(vget_high_s8(b20_s8)))
        } };

        const int16x4x4_t b30_s16 = { {
            vget_low_s16(vmovl_s8(vget_low_s8(b30_s8))),
            vget_high_s16(vmovl_s8(vget_low_s8(b30_s8))),
            vget_low_s16(vmovl_s8(vget_high_s8(b30_s8))),
            vget_high_s16(vmovl_s8(vget_high_s8(b30_s8)))
        } };

        const int16x4x4_t b40_s16 = { {
            vget_low_s16(vmovl_s8(vget_low_s8(b40_s8))),
            vget_high_s16(vmovl_s8(vget_low_s8(b40_s8))),
            vget_low_s16(vmovl_s8(vget_high_s8(b40_s8))),
            vget_high_s16(vmovl_s8(vget_high_s8(b40_s8)))
        } };

        const int16x4x4_t b50_s16 = { {
            vget_low_s16(vmovl_s8(vget_low_s8(b50_s8))),
            vget_high_s16(vmovl_s8(vget_low_s8(b50_s8))),
            vget_low_s16(vmovl_s8(vget_high_s8(b50_s8))),
            vget_high_s16(vmovl_s8(vget_high_s8(b50_s8)))
        } };

        const int16x4x4_t b60_s16 = { {
            vget_low_s16(vmovl_s8(vget_low_s8(b60_s8))),
            vget_high_s16(vmovl_s8(vget_low_s8(b60_s8))),
            vget_low_s16(vmovl_s8(vget_high_s8(b60_s8))),
            vget_high_s16(vmovl_s8(vget_high_s8(b60_s8)))
        } };

        const int16x4x4_t b70_s16 = { {
            vget_low_s16(vmovl_s8(vget_low_s8(b70_s8))),
            vget_high_s16(vmovl_s8(vget_low_s8(b70_s8))),
            vget_low_s16(vmovl_s8(vget_high_s8(b70_s8))),
            vget_high_s16(vmovl_s8(vget_high_s8(b70_s8)))
        } };
        c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16.val[0], 0);
        c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16.val[0], 0);
        c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16.val[0], 0);
        c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16.val[0], 0);

        c0.val[0] = vmlal_lane_s16(c0.val[0], b10_s16.val[0], a00_s16.val[0], 1);
        c0.val[1] = vmlal_lane_s16(c0.val[1], b10_s16.val[1], a00_s16.val[0], 1);
        c0.val[2] = vmlal_lane_s16(c0.val[2], b10_s16.val[2], a00_s16.val[0], 1);
        c0.val[3] = vmlal_lane_s16(c0.val[3], b10_s16.val[3], a00_s16.val[0], 1);

        c0.val[0] = vmlal_lane_s16(c0.val[0], b20_s16.val[0], a00_s16.val[0], 2);
        c0.val[1] = vmlal_lane_s16(c0.val[1], b20_s16.val[1], a00_s16.val[0], 2);
        c0.val[2] = vmlal_lane_s16(c0.val[2], b20_s16.val[2], a00_s16.val[0], 2);
        c0.val[3] = vmlal_lane_s16(c0.val[3], b20_s16.val[3], a00_s16.val[0], 2);

        c0.val[0] = vmlal_lane_s16(c0.val[0], b30_s16.val[0], a00_s16.val[0], 3);
        c0.val[1] = vmlal_lane_s16(c0.val[1], b30_s16.val[1], a00_s16.val[0], 3);
        c0.val[2] = vmlal_lane_s16(c0.val[2], b30_s16.val[2], a00_s16.val[0], 3);
        c0.val[3] = vmlal_lane_s16(c0.val[3], b30_s16.val[3], a00_s16.val[0], 3);

        c0.val[0] = vmlal_lane_s16(c0.val[0], b40_s16.val[0], a00_s16.val[1], 0);
        c0.val[1] = vmlal_lane_s16(c0.val[1], b40_s16.val[1], a00_s16.val[1], 0);
        c0.val[2] = vmlal_lane_s16(c0.val[2], b40_s16.val[2], a00_s16.val[1], 0);
        c0.val[3] = vmlal_lane_s16(c0.val[3], b40_s16.val[3], a00_s16.val[1], 0);

        c0.val[0] = vmlal_lane_s16(c0.val[0], b50_s16.val[0], a00_s16.val[1], 1);
        c0.val[1] = vmlal_lane_s16(c0.val[1], b50_s16.val[1], a00_s16.val[1], 1);
        c0.val[2] = vmlal_lane_s16(c0.val[2], b50_s16.val[2], a00_s16.val[1], 1);
        c0.val[3] = vmlal_lane_s16(c0.val[3], b50_s16.val[3], a00_s16.val[1], 1);

        c0.val[0] = vmlal_lane_s16(c0.val[0], b60_s16.val[0], a00_s16.val[1], 2);
        c0.val[1] = vmlal_lane_s16(c0.val[1], b60_s16.val[1], a00_s16.val[1], 2);
        c0.val[2] = vmlal_lane_s16(c0.val[2], b60_s16.val[2], a00_s16.val[1], 2);
        c0.val[3] = vmlal_lane_s16(c0.val[3], b60_s16.val[3], a00_s16.val[1], 2);

        c0.val[0] = vmlal_lane_s16(c0.val[0], b70_s16.val[0], a00_s16.val[1], 3);
        c0.val[1] = vmlal_lane_s16(c0.val[1], b70_s16.val[1], a00_s16.val[1], 3);
        c0.val[2] = vmlal_lane_s16(c0.val[2], b70_s16.val[2], a00_s16.val[1], 3);
        c0.val[3] = vmlal_lane_s16(c0.val[3], b70_s16.val[3], a00_s16.val[1], 3);

        matrix_b += 8 * stride_b;
        vec_a += 8;
    }
    // Left-over loop: process the remaining elements of A one at a time.
    for(; vec_a < vec_a_end_addr;)
    {
        const int8x8_t  a00_s8 = vld1_dup_s8(vec_a);
        const int8x16_t b00_s8 = vld1q_s8(matrix_b);

        const int16x4x4_t b00_s16 = { {
            vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
            vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
            vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
            vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
        } };

        const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));

        c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
        c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
        c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
        c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);

        matrix_b += stride_b;
        vec_a += 1;
    }
    // Store 16 s32 results, or fall back to scalar left-over stores at the right-hand edge.
    auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
    if(id.x() < (width_out - 16))
    {
        vst1q_s32(vec_out + 0, c0.val[0]);
        vst1q_s32(vec_out + 4, c0.val[1]);
        vst1q_s32(vec_out + 8, c0.val[2]);
        vst1q_s32(vec_out + 12, c0.val[3]);
    }
    else
    {
        auto left_over = width_out - id.x();
        for(auto k = 0; k < 4 && left_over; ++k)
        {
            for(auto j = 0; j < 4 && left_over; ++j, --left_over)
            {
                *(vec_out + k * 4 + j) = c0.val[k][j];
            }
        }
    }
void inline matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
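    // Blocked u8 matrix multiply: each iteration of the k loop reads 4 values of A and 16 of B and
    // accumulates a 4 x 16 block of the s32 output in c0..c3 (one accumulator set per output row).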
    const auto   width_out  = static_cast<int>(out_info.dimension(0));
    const auto   height_out = static_cast<int>(out_info.dimension(1));
    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();

    const uint8_t *mtx_a0 = ina.ptr();
    const uint8_t *mtx_b0 = inb.ptr();
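
    // Accumulation loop along the common dimension; A advances by 4 values and B by 16 each step.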
    for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
    {
        const uint8x8_t  a00_u8 = vld1_u8(mtx_a0);
        const uint8x16_t b00_u8 = vld1q_u8(mtx_b0);

        const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));

        const uint16x4x4_t b00_u16 = { {
            vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
            vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
            vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
            vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
        } };
        // Output rows 0..3 of the block use lanes 0..3 of a00_u16 respectively.
        c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
        c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
        c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
        c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);

        c1.val[0] = vmlal_lane_u16(c1.val[0], b00_u16.val[0], a00_u16, 1);
        c1.val[1] = vmlal_lane_u16(c1.val[1], b00_u16.val[1], a00_u16, 1);
        c1.val[2] = vmlal_lane_u16(c1.val[2], b00_u16.val[2], a00_u16, 1);
        c1.val[3] = vmlal_lane_u16(c1.val[3], b00_u16.val[3], a00_u16, 1);

        c2.val[0] = vmlal_lane_u16(c2.val[0], b00_u16.val[0], a00_u16, 2);
        c2.val[1] = vmlal_lane_u16(c2.val[1], b00_u16.val[1], a00_u16, 2);
        c2.val[2] = vmlal_lane_u16(c2.val[2], b00_u16.val[2], a00_u16, 2);
        c2.val[3] = vmlal_lane_u16(c2.val[3], b00_u16.val[3], a00_u16, 2);

        c3.val[0] = vmlal_lane_u16(c3.val[0], b00_u16.val[0], a00_u16, 3);
        c3.val[1] = vmlal_lane_u16(c3.val[1], b00_u16.val[1], a00_u16, 3);
        c3.val[2] = vmlal_lane_u16(c3.val[2], b00_u16.val[2], a00_u16, 3);
        c3.val[3] = vmlal_lane_u16(c3.val[3], b00_u16.val[3], a00_u16, 3);
    }
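
    // Store the 4 x 16 block: the y checks skip rows that fall outside the output height, and the
    // x check switches to scalar stores at the right-hand edge of the output.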
    auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
    if(id.y() < height_out && id.x() < (width_out - 16))
    {
        vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
        vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
        vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
        vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
        if(id.y() + 1 < height_out)
        {
            vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
            vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
            vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
            vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
            if(id.y() + 2 < height_out)
            {
                vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
                vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
                vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
                vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
                if(id.y() + 3 < height_out)
                {
                    vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
                    vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
                    vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
                    vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
                }
            }
        }
    }
    else if(id.y() < height_out)
    {
        // Left-over columns for row 0
        const auto left_over_value = width_out - id.x();
        auto       left_over       = left_over_value;
        for(auto k = 0; k < 4 && left_over; ++k)
        {
            for(auto j = 0; j < 4 && left_over; ++j, --left_over)
            {
                *(mtx_out + k * 4 + j) = c0.val[k][j];
            }
        }
        // Left-over columns for row 1
        if(id.y() + 1 < height_out)
        {
            left_over = left_over_value;
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
                }
            }
            // Left-over columns for row 2
            if(id.y() + 2 < height_out)
            {
                left_over = left_over_value;
                for(auto k = 0; k < 4 && left_over; ++k)
                {
                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
                    }
                }
                // Left-over columns for row 3
                if(id.y() + 3 < height_out)
                {
                    left_over = left_over_value;
                    for(auto k = 0; k < 4 && left_over; ++k)
                    {
                        for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                        {
                            *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
                        }
                    }
void inline matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
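    // Signed counterpart of matrix_multiply_u8: s8 values are widened to s16 and accumulated into the
    // s32 blocks c0..c3, one per row of the 4 x 16 output block.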
    const auto   width_out  = static_cast<int>(out_info.dimension(0));
    const auto   height_out = static_cast<int>(out_info.dimension(1));
    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();

    auto *mtx_a0 = reinterpret_cast<const int8_t *>(ina.ptr());
    auto *mtx_b0 = reinterpret_cast<const int8_t *>(inb.ptr());
    for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
    {
        const int8x8_t  a00_s8 = vld1_s8(mtx_a0);
        const int8x16_t b00_s8 = vld1q_s8(mtx_b0);

        const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));

        const int16x4x4_t b00_s16 = { {
            vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
            vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
            vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
            vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
        } };
        c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
        c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
        c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
        c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);

        c1.val[0] = vmlal_lane_s16(c1.val[0], b00_s16.val[0], a00_s16, 1);
        c1.val[1] = vmlal_lane_s16(c1.val[1], b00_s16.val[1], a00_s16, 1);
        c1.val[2] = vmlal_lane_s16(c1.val[2], b00_s16.val[2], a00_s16, 1);
        c1.val[3] = vmlal_lane_s16(c1.val[3], b00_s16.val[3], a00_s16, 1);

        c2.val[0] = vmlal_lane_s16(c2.val[0], b00_s16.val[0], a00_s16, 2);
        c2.val[1] = vmlal_lane_s16(c2.val[1], b00_s16.val[1], a00_s16, 2);
        c2.val[2] = vmlal_lane_s16(c2.val[2], b00_s16.val[2], a00_s16, 2);
        c2.val[3] = vmlal_lane_s16(c2.val[3], b00_s16.val[3], a00_s16, 2);

        c3.val[0] = vmlal_lane_s16(c3.val[0], b00_s16.val[0], a00_s16, 3);
        c3.val[1] = vmlal_lane_s16(c3.val[1], b00_s16.val[1], a00_s16, 3);
        c3.val[2] = vmlal_lane_s16(c3.val[2], b00_s16.val[2], a00_s16, 3);
        c3.val[3] = vmlal_lane_s16(c3.val[3], b00_s16.val[3], a00_s16, 3);
    }
    auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
    if(id.y() < height_out && id.x() < (width_out - 16))
    {
        vst1q_s32(mtx_out + 0 * out_stride + 0, c0.val[0]);
        vst1q_s32(mtx_out + 0 * out_stride + 4, c0.val[1]);
        vst1q_s32(mtx_out + 0 * out_stride + 8, c0.val[2]);
        vst1q_s32(mtx_out + 0 * out_stride + 12, c0.val[3]);
        if(id.y() + 1 < height_out)
        {
            vst1q_s32(mtx_out + 1 * out_stride + 0, c1.val[0]);
            vst1q_s32(mtx_out + 1 * out_stride + 4, c1.val[1]);
            vst1q_s32(mtx_out + 1 * out_stride + 8, c1.val[2]);
            vst1q_s32(mtx_out + 1 * out_stride + 12, c1.val[3]);
            if(id.y() + 2 < height_out)
            {
                vst1q_s32(mtx_out + 2 * out_stride + 0, c2.val[0]);
                vst1q_s32(mtx_out + 2 * out_stride + 4, c2.val[1]);
                vst1q_s32(mtx_out + 2 * out_stride + 8, c2.val[2]);
                vst1q_s32(mtx_out + 2 * out_stride + 12, c2.val[3]);
                if(id.y() + 3 < height_out)
                {
                    vst1q_s32(mtx_out + 3 * out_stride + 0, c3.val[0]);
                    vst1q_s32(mtx_out + 3 * out_stride + 4, c3.val[1]);
                    vst1q_s32(mtx_out + 3 * out_stride + 8, c3.val[2]);
                    vst1q_s32(mtx_out + 3 * out_stride + 12, c3.val[3]);
                }
            }
        }
    }
    else if(id.y() < height_out)
    {
        // Left-over columns for row 0
        const auto left_over_value = width_out - id.x();
        auto       left_over       = left_over_value;
        for(auto k = 0; k < 4 && left_over; ++k)
        {
            for(auto j = 0; j < 4 && left_over; ++j, --left_over)
            {
                *(mtx_out + k * 4 + j) = c0.val[k][j];
            }
        }
        // Left-over columns for row 1
        if(id.y() + 1 < height_out)
        {
            left_over = left_over_value;
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
                }
            }
            // Left-over columns for row 2
            if(id.y() + 2 < height_out)
            {
                left_over = left_over_value;
                for(auto k = 0; k < 4 && left_over; ++k)
                {
                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
                    }
                }
                // Left-over columns for row 3
                if(id.y() + 3 < height_out)
                {
                    left_over = left_over_value;
                    for(auto k = 0; k < 4 && left_over; ++k)
                    {
                        for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                        {
                            *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
                        }
                    }
    TensorShape in0_shape = src0->tensor_shape();
    TensorShape in1_shape = src1->tensor_shape();
    TensorShape out_shape = dst->tensor_shape();

    if(out_shape[1] == 1)

    in0_shape.collapse(2);
    in1_shape.collapse(2);
    out_shape.collapse(2);

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[2] != 1 && in0_shape[2] != in1_shape[2],
                                    "Input1 tensor must have the same number of batches of input0 or the number of batches must be set to 1");
    _slide_matrix_b = in1_shape[2] != 1;

    constexpr unsigned int num_elems_processed_per_iteration_x = 16;
    constexpr unsigned int num_elems_processed_per_iteration_y = 4;
    if(dst->dimension(1) == 1)

    ICpuKernel::configure(win);
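
    // Run path: the vector * matrix kernels are used when the output has a single row
    // (dimension(1) == 1); otherwise the blocked 4 x 16 matrix * matrix kernels above are used.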
    if(dst->info()->dimension(1) == 1)
    {
        const auto width_matrix_a = static_cast<int>(src0->info()->dimension(0));
        const auto width_matrix_b = static_cast<int>(src1->info()->dimension(0));
        const auto width_out      = static_cast<int>(dst->info()->dimension(0));
        const auto in_b_stride    = static_cast<int>(src1->info()->strides_in_bytes()[1] / data_size_from_type(src1->info()->data_type()));
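
        // Each thread works on interleaved 16-element chunks along x: it starts at its own chunk and
        // strides by 16 * num_threads; window_end_x rounds the thread's range up to a whole chunk.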
        const int window_start_x = 16 * info.thread_id;
        const int window_step_x  = 16 * info.num_threads;
        const int window_end_x   = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
        if(src1->info()->num_dimensions() >= 3)
        switch(src0->info()->data_type())

                vector_matrix_multiply_s8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);

                vector_matrix_multiply_u8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
        const size_t in_b_stride = src1->info()->strides_in_bytes()[1];
        const int    width_b     = src1->info()->dimension(0);
        switch(src0->info()->data_type())

                matrix_multiply_s8(ina, inb, out, width_b, *dst->info(), window);

                matrix_multiply_u8(ina, inb, out, width_b, *dst->info(), window);
1049 return "CpuGemmLowpMatrixMultiplyKernel";