void inline vector_matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
{
    execute_window_loop(window, [&](const Coordinates & id)
    {
        if(id.x() > width_b)
        {
            return;
        }

        // Note: since the inputs are all non-negative, uint32_t accumulators are sufficient
        // Accumulator for the block 0
        uint32x4x4_t c0 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        auto vec_a          = reinterpret_cast<const uint8_t *>(ina.ptr());
        auto matrix_b       = reinterpret_cast<const uint8_t *>(inb.ptr());
        auto vec_a_end_addr = vec_a + width_a;
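        // Each iteration of the main loop below consumes 8 elements of vector A and
        // 8 rows x 16 columns of matrix B, accumulating into the 16 uint32_t lanes of c0.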
        for(; vec_a <= (vec_a_end_addr - 8);)
        {
            const uint8x8_t  a00_u8 = vld1_u8(vec_a);
            const uint8x16_t b00_u8 = vld1q_u8(matrix_b + 0 * stride_b);
            const uint8x16_t b10_u8 = vld1q_u8(matrix_b + 1 * stride_b);
            const uint8x16_t b20_u8 = vld1q_u8(matrix_b + 2 * stride_b);
            const uint8x16_t b30_u8 = vld1q_u8(matrix_b + 3 * stride_b);
            const uint8x16_t b40_u8 = vld1q_u8(matrix_b + 4 * stride_b);
            const uint8x16_t b50_u8 = vld1q_u8(matrix_b + 5 * stride_b);
            const uint8x16_t b60_u8 = vld1q_u8(matrix_b + 6 * stride_b);
            const uint8x16_t b70_u8 = vld1q_u8(matrix_b + 7 * stride_b);
            // Convert a00_u8 to uint16_t and split into lower/upper halves
            const uint16x4x2_t a00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(a00_u8)),
                    vget_high_u16(vmovl_u8(a00_u8))
                }
            };
            // Convert the 16 uint8_t elements of each B row to uint16_t
            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            const uint16x4x4_t b10_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b10_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b10_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b10_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b10_u8)))
                }
            };

            const uint16x4x4_t b20_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b20_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b20_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b20_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b20_u8)))
                }
            };

            const uint16x4x4_t b30_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b30_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b30_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b30_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b30_u8)))
                }
            };

            const uint16x4x4_t b40_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b40_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b40_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b40_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b40_u8)))
                }
            };

            const uint16x4x4_t b50_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b50_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b50_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b50_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b50_u8)))
                }
            };

            const uint16x4x4_t b60_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b60_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b60_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b60_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b60_u8)))
                }
            };

            const uint16x4x4_t b70_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b70_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b70_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b70_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b70_u8)))
                }
            };
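            // Base NEON has no u8 x u8 -> u32 multiply-accumulate, so the operands are
            // widened to u16 and combined with vmlal_lane_u16: the lane argument selects
            // the A element that multiplies all 16 columns of the corresponding B row.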
            // Accumulate 0:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16.val[0], 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16.val[0], 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16.val[0], 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16.val[0], 0);

            // Accumulate 1:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b10_u16.val[0], a00_u16.val[0], 1);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b10_u16.val[1], a00_u16.val[0], 1);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b10_u16.val[2], a00_u16.val[0], 1);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b10_u16.val[3], a00_u16.val[0], 1);

            // Accumulate 2:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b20_u16.val[0], a00_u16.val[0], 2);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b20_u16.val[1], a00_u16.val[0], 2);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b20_u16.val[2], a00_u16.val[0], 2);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b20_u16.val[3], a00_u16.val[0], 2);

            // Accumulate 3:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b30_u16.val[0], a00_u16.val[0], 3);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b30_u16.val[1], a00_u16.val[0], 3);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b30_u16.val[2], a00_u16.val[0], 3);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b30_u16.val[3], a00_u16.val[0], 3);

            // Accumulate 4:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b40_u16.val[0], a00_u16.val[1], 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b40_u16.val[1], a00_u16.val[1], 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b40_u16.val[2], a00_u16.val[1], 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b40_u16.val[3], a00_u16.val[1], 0);

            // Accumulate 5:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b50_u16.val[0], a00_u16.val[1], 1);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b50_u16.val[1], a00_u16.val[1], 1);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b50_u16.val[2], a00_u16.val[1], 1);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b50_u16.val[3], a00_u16.val[1], 1);

            // Accumulate 6:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b60_u16.val[0], a00_u16.val[1], 2);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b60_u16.val[1], a00_u16.val[1], 2);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b60_u16.val[2], a00_u16.val[1], 2);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b60_u16.val[3], a00_u16.val[1], 2);

            // Accumulate 7:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b70_u16.val[0], a00_u16.val[1], 3);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b70_u16.val[1], a00_u16.val[1], 3);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b70_u16.val[2], a00_u16.val[1], 3);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b70_u16.val[3], a00_u16.val[1], 3);
            vec_a += 8;
            matrix_b += 8 * stride_b;
        }
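        // Left-over accumulations: when width_a is not a multiple of 8, process the
        // remaining A elements one at a time against a single row of B.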
        for(; vec_a < vec_a_end_addr;)
        {
            const uint8x8_t  a00_u8 = vld1_dup_u8(vec_a);
            const uint8x16_t b00_u8 = vld1q_u8(matrix_b);
            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };
            // Convert a00_u8 to uint16_t and get the lower part
            const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));
            // Accumulate 0:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);
            vec_a += 1;
            matrix_b += stride_b;
        }
        auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
        if(id.x() < (width_out - 16))
        {
            vst1q_s32(vec_out + 0, vreinterpretq_s32_u32(c0.val[0]));
            vst1q_s32(vec_out + 4, vreinterpretq_s32_u32(c0.val[1]));
            vst1q_s32(vec_out + 8, vreinterpretq_s32_u32(c0.val[2]));
            vst1q_s32(vec_out + 12, vreinterpretq_s32_u32(c0.val[3]));
        }
        else
        {
            // Right edge: store only the valid left-over elements, one by one
            auto left_over = width_out - id.x();
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(vec_out + k * 4 + j) = c0.val[k][j];
                }
            }
        }
    },
    ina, inb, out);
}
void inline vector_matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
{
    execute_window_loop(window, [&](const Coordinates & id)
    {
        if(id.x() > width_b)
        {
            return;
        }

        // Accumulator for the block 0
        int32x4x4_t c0 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        auto vec_a          = reinterpret_cast<const int8_t *>(ina.ptr());
        auto matrix_b       = reinterpret_cast<const int8_t *>(inb.ptr());
        auto vec_a_end_addr = vec_a + width_a;
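        // Same accumulation scheme as vector_matrix_multiply_u8, but with signed
        // widening (vmovl_s8) and signed multiply-accumulate (vmlal_lane_s16).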
        for(; vec_a <= (vec_a_end_addr - 8);)
        {
            const int8x8_t  a00_s8 = vld1_s8(vec_a);
            const int8x16_t b00_s8 = vld1q_s8(matrix_b + 0 * stride_b);
            const int8x16_t b10_s8 = vld1q_s8(matrix_b + 1 * stride_b);
            const int8x16_t b20_s8 = vld1q_s8(matrix_b + 2 * stride_b);
            const int8x16_t b30_s8 = vld1q_s8(matrix_b + 3 * stride_b);
            const int8x16_t b40_s8 = vld1q_s8(matrix_b + 4 * stride_b);
            const int8x16_t b50_s8 = vld1q_s8(matrix_b + 5 * stride_b);
            const int8x16_t b60_s8 = vld1q_s8(matrix_b + 6 * stride_b);
            const int8x16_t b70_s8 = vld1q_s8(matrix_b + 7 * stride_b);
            // Convert a00_s8 to int16_t and split into lower/upper halves
            const int16x4x2_t a00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(a00_s8)),
                    vget_high_s16(vmovl_s8(a00_s8))
                }
            };
            // Convert the 16 int8_t elements of each B row to int16_t
            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            const int16x4x4_t b10_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b10_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b10_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b10_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b10_s8)))
                }
            };

            const int16x4x4_t b20_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b20_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b20_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b20_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b20_s8)))
                }
            };

            const int16x4x4_t b30_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b30_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b30_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b30_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b30_s8)))
                }
            };

            const int16x4x4_t b40_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b40_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b40_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b40_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b40_s8)))
                }
            };

            const int16x4x4_t b50_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b50_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b50_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b50_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b50_s8)))
                }
            };

            const int16x4x4_t b60_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b60_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b60_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b60_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b60_s8)))
                }
            };

            const int16x4x4_t b70_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b70_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b70_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b70_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b70_s8)))
                }
            };
            // Accumulate 0:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16.val[0], 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16.val[0], 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16.val[0], 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16.val[0], 0);

            // Accumulate 1:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b10_s16.val[0], a00_s16.val[0], 1);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b10_s16.val[1], a00_s16.val[0], 1);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b10_s16.val[2], a00_s16.val[0], 1);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b10_s16.val[3], a00_s16.val[0], 1);

            // Accumulate 2:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b20_s16.val[0], a00_s16.val[0], 2);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b20_s16.val[1], a00_s16.val[0], 2);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b20_s16.val[2], a00_s16.val[0], 2);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b20_s16.val[3], a00_s16.val[0], 2);

            // Accumulate 3:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b30_s16.val[0], a00_s16.val[0], 3);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b30_s16.val[1], a00_s16.val[0], 3);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b30_s16.val[2], a00_s16.val[0], 3);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b30_s16.val[3], a00_s16.val[0], 3);

            // Accumulate 4:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b40_s16.val[0], a00_s16.val[1], 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b40_s16.val[1], a00_s16.val[1], 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b40_s16.val[2], a00_s16.val[1], 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b40_s16.val[3], a00_s16.val[1], 0);

            // Accumulate 5:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b50_s16.val[0], a00_s16.val[1], 1);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b50_s16.val[1], a00_s16.val[1], 1);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b50_s16.val[2], a00_s16.val[1], 1);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b50_s16.val[3], a00_s16.val[1], 1);

            // Accumulate 6:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b60_s16.val[0], a00_s16.val[1], 2);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b60_s16.val[1], a00_s16.val[1], 2);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b60_s16.val[2], a00_s16.val[1], 2);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b60_s16.val[3], a00_s16.val[1], 2);

            // Accumulate 7:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b70_s16.val[0], a00_s16.val[1], 3);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b70_s16.val[1], a00_s16.val[1], 3);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b70_s16.val[2], a00_s16.val[1], 3);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b70_s16.val[3], a00_s16.val[1], 3);
            vec_a += 8;
            matrix_b += 8 * stride_b;
        }
        // Left-over accumulations
        for(; vec_a < vec_a_end_addr;)
        {
            const int8x8_t  a00_s8 = vld1_dup_s8(vec_a);
            const int8x16_t b00_s8 = vld1q_s8(matrix_b);
            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };
            // Convert a00_s8 to int16_t and get the lower part
            const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));
            // Accumulate 0:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);
            vec_a += 1;
            matrix_b += stride_b;
        }
        auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
        if(id.x() < (width_out - 16))
        {
            vst1q_s32(vec_out + 0, c0.val[0]);
            vst1q_s32(vec_out + 4, c0.val[1]);
            vst1q_s32(vec_out + 8, c0.val[2]);
            vst1q_s32(vec_out + 12, c0.val[3]);
        }
        else
        {
            // Right edge: store only the valid left-over elements, one by one
            auto left_over = width_out - id.x();
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(vec_out + k * 4 + j) = c0.val[k][j];
                }
            }
        }
    },
    ina, inb, out);
}
void inline matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const ITensorInfo &out_info, const Window &window)
{
    const auto   width_out  = static_cast<int>(out_info.dimension(0));
    const auto   height_out = static_cast<int>(out_info.dimension(1));
    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
    execute_window_loop(window, [&](const Coordinates & id)
    {
        const uint8_t *mtx_a0 = ina.ptr();
        const uint8_t *mtx_b0 = inb.ptr();

        // Accumulators for the four output rows of the 4x16 block
        uint32x4x4_t c0 = { { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) } };
        uint32x4x4_t c1 = { { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) } };
        uint32x4x4_t c2 = { { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) } };
        uint32x4x4_t c3 = { { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) } };
        for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
        {
            const uint8x8_t  a00_u8 = vld1_u8(mtx_a0);
            const uint8x16_t b00_u8 = vld1q_u8(mtx_b0);

            // Convert a00_u8 to uint16_t and get the lower part
            const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));
            // Convert b00_u8 to uint16_t
            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };
            // 4x4 block 0
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);

            // 4x4 block 1
            c1.val[0] = vmlal_lane_u16(c1.val[0], b00_u16.val[0], a00_u16, 1);
            c1.val[1] = vmlal_lane_u16(c1.val[1], b00_u16.val[1], a00_u16, 1);
            c1.val[2] = vmlal_lane_u16(c1.val[2], b00_u16.val[2], a00_u16, 1);
            c1.val[3] = vmlal_lane_u16(c1.val[3], b00_u16.val[3], a00_u16, 1);

            // 4x4 block 2
            c2.val[0] = vmlal_lane_u16(c2.val[0], b00_u16.val[0], a00_u16, 2);
            c2.val[1] = vmlal_lane_u16(c2.val[1], b00_u16.val[1], a00_u16, 2);
            c2.val[2] = vmlal_lane_u16(c2.val[2], b00_u16.val[2], a00_u16, 2);
            c2.val[3] = vmlal_lane_u16(c2.val[3], b00_u16.val[3], a00_u16, 2);

            // 4x4 block 3
            c3.val[0] = vmlal_lane_u16(c3.val[0], b00_u16.val[0], a00_u16, 3);
            c3.val[1] = vmlal_lane_u16(c3.val[1], b00_u16.val[1], a00_u16, 3);
            c3.val[2] = vmlal_lane_u16(c3.val[2], b00_u16.val[2], a00_u16, 3);
            c3.val[3] = vmlal_lane_u16(c3.val[3], b00_u16.val[3], a00_u16, 3);
        }
        auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
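        // Stores are guarded against the bottom edge (height_out) and the right edge
        // (width_out): full vector stores on the fast path below, scalar
        // element-by-element stores for the left-over columns in the else branch.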
        if(id.y() < height_out && id.x() < (width_out - 16))
        {
            vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
            vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
            vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
            vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
            if(id.y() + 1 < height_out)
            {
                vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
                vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
                vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
                vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
                if(id.y() + 2 < height_out)
                {
                    vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
                    vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
                    vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
                    vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
                    if(id.y() + 3 < height_out)
                    {
                        vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
                        vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
                        vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
                        vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
                    }
                }
            }
        }
        else if(id.y() < height_out)
        {
            const auto left_over_value = width_out - id.x();
            auto       left_over       = left_over_value;
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(mtx_out + k * 4 + j) = c0.val[k][j];
                }
            }
            if(id.y() + 1 < height_out)
            {
                left_over = left_over_value;
                for(auto k = 0; k < 4 && left_over; ++k)
                {
                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
                    }
                }
            }
            if(id.y() + 2 < height_out)
            {
                left_over = left_over_value;
                for(auto k = 0; k < 4 && left_over; ++k)
                {
                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
                    }
                }
            }
            if(id.y() + 3 < height_out)
            {
                left_over = left_over_value;
                for(auto k = 0; k < 4 && left_over; ++k)
                {
                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
                    }
                }
            }
        }
    },
    ina, inb, out);
}
void inline matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const ITensorInfo &out_info, const Window &window)
{
    const auto   width_out  = static_cast<int>(out_info.dimension(0));
    const auto   height_out = static_cast<int>(out_info.dimension(1));
    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
    // Signed variant of matrix_multiply_u8: same 4x16 tiling, signed arithmetic
    execute_window_loop(window, [&](const Coordinates & id)
    {
        auto *mtx_a0 = reinterpret_cast<const int8_t *>(ina.ptr());
        auto *mtx_b0 = reinterpret_cast<const int8_t *>(inb.ptr());

        // Accumulators for the four output rows of the 4x16 block
        int32x4x4_t c0 = { { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0) } };
        int32x4x4_t c1 = { { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0) } };
        int32x4x4_t c2 = { { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0) } };
        int32x4x4_t c3 = { { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0) } };
        for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
        {
            const int8x8_t  a00_s8 = vld1_s8(mtx_a0);
            const int8x16_t b00_s8 = vld1q_s8(mtx_b0);

            // Convert a00_s8 to int16_t and get the lower part
            const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));
            // Convert b00_s8 to int16_t
            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };
            // 4x4 block 0
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);

            // 4x4 block 1
            c1.val[0] = vmlal_lane_s16(c1.val[0], b00_s16.val[0], a00_s16, 1);
            c1.val[1] = vmlal_lane_s16(c1.val[1], b00_s16.val[1], a00_s16, 1);
            c1.val[2] = vmlal_lane_s16(c1.val[2], b00_s16.val[2], a00_s16, 1);
            c1.val[3] = vmlal_lane_s16(c1.val[3], b00_s16.val[3], a00_s16, 1);

            // 4x4 block 2
            c2.val[0] = vmlal_lane_s16(c2.val[0], b00_s16.val[0], a00_s16, 2);
            c2.val[1] = vmlal_lane_s16(c2.val[1], b00_s16.val[1], a00_s16, 2);
            c2.val[2] = vmlal_lane_s16(c2.val[2], b00_s16.val[2], a00_s16, 2);
            c2.val[3] = vmlal_lane_s16(c2.val[3], b00_s16.val[3], a00_s16, 2);

            // 4x4 block 3
            c3.val[0] = vmlal_lane_s16(c3.val[0], b00_s16.val[0], a00_s16, 3);
            c3.val[1] = vmlal_lane_s16(c3.val[1], b00_s16.val[1], a00_s16, 3);
            c3.val[2] = vmlal_lane_s16(c3.val[2], b00_s16.val[2], a00_s16, 3);
            c3.val[3] = vmlal_lane_s16(c3.val[3], b00_s16.val[3], a00_s16, 3);
        }
        auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
        if(id.y() < height_out && id.x() < (width_out - 16))
        {
            vst1q_s32(mtx_out + 0 * out_stride + 0, c0.val[0]);
            vst1q_s32(mtx_out + 0 * out_stride + 4, c0.val[1]);
            vst1q_s32(mtx_out + 0 * out_stride + 8, c0.val[2]);
            vst1q_s32(mtx_out + 0 * out_stride + 12, c0.val[3]);
            if(id.y() + 1 < height_out)
            {
                vst1q_s32(mtx_out + 1 * out_stride + 0, c1.val[0]);
                vst1q_s32(mtx_out + 1 * out_stride + 4, c1.val[1]);
                vst1q_s32(mtx_out + 1 * out_stride + 8, c1.val[2]);
                vst1q_s32(mtx_out + 1 * out_stride + 12, c1.val[3]);
                if(id.y() + 2 < height_out)
                {
                    vst1q_s32(mtx_out + 2 * out_stride + 0, c2.val[0]);
                    vst1q_s32(mtx_out + 2 * out_stride + 4, c2.val[1]);
                    vst1q_s32(mtx_out + 2 * out_stride + 8, c2.val[2]);
                    vst1q_s32(mtx_out + 2 * out_stride + 12, c2.val[3]);
                    if(id.y() + 3 < height_out)
                    {
                        vst1q_s32(mtx_out + 3 * out_stride + 0, c3.val[0]);
                        vst1q_s32(mtx_out + 3 * out_stride + 4, c3.val[1]);
                        vst1q_s32(mtx_out + 3 * out_stride + 8, c3.val[2]);
                        vst1q_s32(mtx_out + 3 * out_stride + 12, c3.val[3]);
                    }
                }
            }
        }
        else if(id.y() < height_out)
        {
            const auto left_over_value = width_out - id.x();
            auto       left_over       = left_over_value;
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(mtx_out + k * 4 + j) = c0.val[k][j];
                }
            }
            if(id.y() + 1 < height_out)
            {
                left_over = left_over_value;
                for(auto k = 0; k < 4 && left_over; ++k)
                {
                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
                    }
                }
            }
            if(id.y() + 2 < height_out)
            {
                left_over = left_over_value;
                for(auto k = 0; k < 4 && left_over; ++k)
                {
                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
                    }
                }
            }
            if(id.y() + 3 < height_out)
            {
                left_over = left_over_value;
                for(auto k = 0; k < 4 && left_over; ++k)
                {
                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
                    }
                }
            }
        }
    },
    ina, inb, out);
}
Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
{
    // ...
    // Vector-by-matrix case: the output has a single row
    if(out_shape[1] == 1)
    {
        // ...
    }
    else
    {
        // ...
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[2] != 1 && in0_shape[2] != in1_shape[2],
                                        "Input1 tensor must have the same number of batches of input0 or the number of batches must be set to 1");
    }

    return Status{};
}
NEGEMMLowpMatrixMultiplyKernel::NEGEMMLowpMatrixMultiplyKernel()
    : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true)
{
}

void NEGEMMLowpMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output)
{
    // ...
    _slide_matrix_b = in1_shape[2] != 1;
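    // When input1 has a single batch (in1_shape[2] == 1) the same matrix B is reused
    // for every batch of input0, so the matrix B window must not slide along Z.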
    constexpr unsigned int num_elems_processed_per_iteration_x = 16;
    constexpr unsigned int num_elems_processed_per_iteration_y = 4;
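    // Each window step covers a 16-wide x 4-high block of the S32 output, matching
    // the 4x16 accumulator tile computed by matrix_multiply_u8/_s8.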
    // ...
    INEKernel::configure(win);
}
void NEGEMMLowpMatrixMultiplyKernel::run(const Window &window, const ThreadInfo &info)
{
    // ...
    // Vector-by-matrix path: taken when the output has a single row
    const auto width_matrix_a = static_cast<int>(_input0->info()->dimension(0));
    const auto width_matrix_b = static_cast<int>(_input1->info()->dimension(0));
    const auto width_out      = static_cast<int>(_output->info()->dimension(0));
    const auto in_b_stride    = static_cast<int>(_input1->info()->strides_in_bytes()[1] / data_size_from_type(_input1->info()->data_type()));

    // The implementation computes 16 elements per iteration
    const int window_start_x = 16 * info.thread_id;
    const int window_step_x  = 16 * info.num_threads;
    // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
    const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
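    // The x-range is split across threads: thread i starts at block 16 * i and strides
    // by 16 * num_threads, so the threads interleave over the 16-wide column blocks.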
    // ...
    switch(_input0->info()->data_type())
    {
        case DataType::S8:
        case DataType::QASYMM8_SIGNED:
            vector_matrix_multiply_s8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
            break;
        case DataType::U8:
        case DataType::QASYMM8:
            vector_matrix_multiply_u8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
            break;
        default:
            ARM_COMPUTE_ERROR("Not supported");
            break;
    }
    // Matrix-by-matrix path
    // ...
    switch(_input0->info()->data_type())
    {
        case DataType::S8:
        case DataType::QASYMM8_SIGNED:
            matrix_multiply_s8(ina, inb, out, width_b, *_output->info(), window);
            break;
        case DataType::U8:
        case DataType::QASYMM8:
            matrix_multiply_u8(ina, inb, out, width_b, *_output->info(), window);
            break;
        default:
            ARM_COMPUTE_ERROR("Not supported");
            break;
    }
}