void inline vector_matrix_multiply_u8(Iterator     &ina,
                                      Iterator     &inb,
                                      Iterator     &out,
                                      int           width_a,
                                      int           width_b,
                                      int           width_out,
                                      size_t        stride_b,
                                      const Window &window)
{
    execute_window_loop(
        window,
        [&](const Coordinates &id)
        {
            if (id.x() > width_b)
            {
                return;
            }

            // Note: Since the inputs are all positive, we can use uint32_t accumulators
            // Accumulators for the block 0
            uint32x4x4_t c0 = {{vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0)}};
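            // The four lanes-of-four in c0 cover the 16 output columns written per
            // window step: val[0] -> columns 0..3, val[1] -> 4..7, val[2] -> 8..11,
            // val[3] -> 12..15. Each vmlal_lane_u16(acc, b, a, lane) used below
            // computes acc[i] += uint32_t(b[i]) * uint32_t(a[lane]); a u8*u8 product
            // fits in 16 bits, leaving 16 bits of headroom for accumulation.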
            auto vec_a          = reinterpret_cast<const uint8_t *>(ina.ptr());
            auto matrix_b       = reinterpret_cast<const uint8_t *>(inb.ptr());
            auto vec_a_end_addr = vec_a + width_a;

            // This for loop performs 8 accumulations: 8 elements of A against 8 rows of B
            for (; vec_a <= (vec_a_end_addr - 8);)
            {
                const uint8x8_t  a00_u8 = vld1_u8(vec_a);
                const uint8x16_t b00_u8 = vld1q_u8(matrix_b + 0 * stride_b);
                const uint8x16_t b10_u8 = vld1q_u8(matrix_b + 1 * stride_b);
                const uint8x16_t b20_u8 = vld1q_u8(matrix_b + 2 * stride_b);
                const uint8x16_t b30_u8 = vld1q_u8(matrix_b + 3 * stride_b);
                const uint8x16_t b40_u8 = vld1q_u8(matrix_b + 4 * stride_b);
                const uint8x16_t b50_u8 = vld1q_u8(matrix_b + 5 * stride_b);
                const uint8x16_t b60_u8 = vld1q_u8(matrix_b + 6 * stride_b);
                const uint8x16_t b70_u8 = vld1q_u8(matrix_b + 7 * stride_b);
                // Convert a00_u8 to uint16_t and get the lower and higher halves
                const uint16x4x2_t a00_u16 = {{vget_low_u16(vmovl_u8(a00_u8)), vget_high_u16(vmovl_u8(a00_u8))}};

                // Convert b00_u8 .. b70_u8 to uint16_t
                const uint16x4x4_t b00_u16 = {
                    {vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))), vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                     vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))), vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))}};

                const uint16x4x4_t b10_u16 = {
                    {vget_low_u16(vmovl_u8(vget_low_u8(b10_u8))), vget_high_u16(vmovl_u8(vget_low_u8(b10_u8))),
                     vget_low_u16(vmovl_u8(vget_high_u8(b10_u8))), vget_high_u16(vmovl_u8(vget_high_u8(b10_u8)))}};

                const uint16x4x4_t b20_u16 = {
                    {vget_low_u16(vmovl_u8(vget_low_u8(b20_u8))), vget_high_u16(vmovl_u8(vget_low_u8(b20_u8))),
                     vget_low_u16(vmovl_u8(vget_high_u8(b20_u8))), vget_high_u16(vmovl_u8(vget_high_u8(b20_u8)))}};

                const uint16x4x4_t b30_u16 = {
                    {vget_low_u16(vmovl_u8(vget_low_u8(b30_u8))), vget_high_u16(vmovl_u8(vget_low_u8(b30_u8))),
                     vget_low_u16(vmovl_u8(vget_high_u8(b30_u8))), vget_high_u16(vmovl_u8(vget_high_u8(b30_u8)))}};

                const uint16x4x4_t b40_u16 = {
                    {vget_low_u16(vmovl_u8(vget_low_u8(b40_u8))), vget_high_u16(vmovl_u8(vget_low_u8(b40_u8))),
                     vget_low_u16(vmovl_u8(vget_high_u8(b40_u8))), vget_high_u16(vmovl_u8(vget_high_u8(b40_u8)))}};

                const uint16x4x4_t b50_u16 = {
                    {vget_low_u16(vmovl_u8(vget_low_u8(b50_u8))), vget_high_u16(vmovl_u8(vget_low_u8(b50_u8))),
                     vget_low_u16(vmovl_u8(vget_high_u8(b50_u8))), vget_high_u16(vmovl_u8(vget_high_u8(b50_u8)))}};

                const uint16x4x4_t b60_u16 = {
                    {vget_low_u16(vmovl_u8(vget_low_u8(b60_u8))), vget_high_u16(vmovl_u8(vget_low_u8(b60_u8))),
                     vget_low_u16(vmovl_u8(vget_high_u8(b60_u8))), vget_high_u16(vmovl_u8(vget_high_u8(b60_u8)))}};

                const uint16x4x4_t b70_u16 = {
                    {vget_low_u16(vmovl_u8(vget_low_u8(b70_u8))), vget_high_u16(vmovl_u8(vget_low_u8(b70_u8))),
                     vget_low_u16(vmovl_u8(vget_high_u8(b70_u8))), vget_high_u16(vmovl_u8(vget_high_u8(b70_u8)))}};
                // Accumulate 0:
                c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16.val[0], 0);
                c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16.val[0], 0);
                c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16.val[0], 0);
                c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16.val[0], 0);

                // Accumulate 1:
                c0.val[0] = vmlal_lane_u16(c0.val[0], b10_u16.val[0], a00_u16.val[0], 1);
                c0.val[1] = vmlal_lane_u16(c0.val[1], b10_u16.val[1], a00_u16.val[0], 1);
                c0.val[2] = vmlal_lane_u16(c0.val[2], b10_u16.val[2], a00_u16.val[0], 1);
                c0.val[3] = vmlal_lane_u16(c0.val[3], b10_u16.val[3], a00_u16.val[0], 1);

                // Accumulate 2:
                c0.val[0] = vmlal_lane_u16(c0.val[0], b20_u16.val[0], a00_u16.val[0], 2);
                c0.val[1] = vmlal_lane_u16(c0.val[1], b20_u16.val[1], a00_u16.val[0], 2);
                c0.val[2] = vmlal_lane_u16(c0.val[2], b20_u16.val[2], a00_u16.val[0], 2);
                c0.val[3] = vmlal_lane_u16(c0.val[3], b20_u16.val[3], a00_u16.val[0], 2);

                // Accumulate 3:
                c0.val[0] = vmlal_lane_u16(c0.val[0], b30_u16.val[0], a00_u16.val[0], 3);
                c0.val[1] = vmlal_lane_u16(c0.val[1], b30_u16.val[1], a00_u16.val[0], 3);
                c0.val[2] = vmlal_lane_u16(c0.val[2], b30_u16.val[2], a00_u16.val[0], 3);
                c0.val[3] = vmlal_lane_u16(c0.val[3], b30_u16.val[3], a00_u16.val[0], 3);

                // Accumulate 4:
                c0.val[0] = vmlal_lane_u16(c0.val[0], b40_u16.val[0], a00_u16.val[1], 0);
                c0.val[1] = vmlal_lane_u16(c0.val[1], b40_u16.val[1], a00_u16.val[1], 0);
                c0.val[2] = vmlal_lane_u16(c0.val[2], b40_u16.val[2], a00_u16.val[1], 0);
                c0.val[3] = vmlal_lane_u16(c0.val[3], b40_u16.val[3], a00_u16.val[1], 0);

                // Accumulate 5:
                c0.val[0] = vmlal_lane_u16(c0.val[0], b50_u16.val[0], a00_u16.val[1], 1);
                c0.val[1] = vmlal_lane_u16(c0.val[1], b50_u16.val[1], a00_u16.val[1], 1);
                c0.val[2] = vmlal_lane_u16(c0.val[2], b50_u16.val[2], a00_u16.val[1], 1);
                c0.val[3] = vmlal_lane_u16(c0.val[3], b50_u16.val[3], a00_u16.val[1], 1);

                // Accumulate 6:
                c0.val[0] = vmlal_lane_u16(c0.val[0], b60_u16.val[0], a00_u16.val[1], 2);
                c0.val[1] = vmlal_lane_u16(c0.val[1], b60_u16.val[1], a00_u16.val[1], 2);
                c0.val[2] = vmlal_lane_u16(c0.val[2], b60_u16.val[2], a00_u16.val[1], 2);
                c0.val[3] = vmlal_lane_u16(c0.val[3], b60_u16.val[3], a00_u16.val[1], 2);

                // Accumulate 7:
                c0.val[0] = vmlal_lane_u16(c0.val[0], b70_u16.val[0], a00_u16.val[1], 3);
                c0.val[1] = vmlal_lane_u16(c0.val[1], b70_u16.val[1], a00_u16.val[1], 3);
                c0.val[2] = vmlal_lane_u16(c0.val[2], b70_u16.val[2], a00_u16.val[1], 3);
                c0.val[3] = vmlal_lane_u16(c0.val[3], b70_u16.val[3], a00_u16.val[1], 3);

                matrix_b += 8 * stride_b;
                vec_a += 8;
            }
            // This for loop performs the left-over accumulations, one element of A at a time
            for (; vec_a < vec_a_end_addr;)
            {
                const uint8x8_t  a00_u8 = vld1_dup_u8(vec_a);
                const uint8x16_t b00_u8 = vld1q_u8(matrix_b);

                // Convert b00_u8 to uint16_t
                const uint16x4x4_t b00_u16 = {
                    {vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))), vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                     vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))), vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))}};

                // Convert a00_u8 to uint16_t
                const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));

                // Accumulate 0:
                c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
                c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
                c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
                c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);

                matrix_b += stride_b;
                vec_a += 1;
            }
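
            // Store the result: a full 16-wide block is written with four vector
            // stores, otherwise the tail is written lane by lane below.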
            auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
            if (id.x() < (width_out - 16))
            {
                vst1q_s32(vec_out + 0, vreinterpretq_s32_u32(c0.val[0]));
                vst1q_s32(vec_out + 4, vreinterpretq_s32_u32(c0.val[1]));
                vst1q_s32(vec_out + 8, vreinterpretq_s32_u32(c0.val[2]));
                vst1q_s32(vec_out + 12, vreinterpretq_s32_u32(c0.val[3]));
            }
            else
            {
                auto left_over = width_out - id.x();
                for (auto k = 0; k < 4 && left_over; ++k)
                {
                    for (auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(vec_out + k * 4 + j) = c0.val[k][j];
                    }
                }
            }
        },
        ina, inb, out);
}
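
// For reference, the vector-by-matrix kernels compute, in scalar terms (a sketch
// for exposition only, with illustrative names):
//
//   for (int x = 0; x < 16; ++x)                 // 16 output columns per step
//   {
//       int32_t acc = 0;
//       for (int k = 0; k < width_a; ++k)
//       {
//           acc += int32_t(a[k]) * int32_t(b[k * stride_b + x]);
//       }
//       out[x] = acc;
//   }
//
// vector_matrix_multiply_s8 below mirrors the u8 version above with the signed
// intrinsics (vld1_s8, vmovl_s8, vmlal_lane_s16) and int32 accumulators.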
void inline vector_matrix_multiply_s8(Iterator     &ina,
                                      Iterator     &inb,
                                      Iterator     &out,
                                      int           width_a,
                                      int           width_b,
                                      int           width_out,
                                      size_t        stride_b,
                                      const Window &window)
{
    execute_window_loop(
        window,
        [&](const Coordinates &id)
        {
            if (id.x() > width_b)
            {
                return;
            }

            // Accumulators for the block 0
            int32x4x4_t c0 = {{vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0)}};
            auto vec_a          = reinterpret_cast<const int8_t *>(ina.ptr());
            auto matrix_b       = reinterpret_cast<const int8_t *>(inb.ptr());
            auto vec_a_end_addr = vec_a + width_a;

            // This for loop performs 8 accumulations: 8 elements of A against 8 rows of B
            for (; vec_a <= (vec_a_end_addr - 8);)
            {
                const int8x8_t  a00_s8 = vld1_s8(vec_a);
                const int8x16_t b00_s8 = vld1q_s8(matrix_b + 0 * stride_b);
                const int8x16_t b10_s8 = vld1q_s8(matrix_b + 1 * stride_b);
                const int8x16_t b20_s8 = vld1q_s8(matrix_b + 2 * stride_b);
                const int8x16_t b30_s8 = vld1q_s8(matrix_b + 3 * stride_b);
                const int8x16_t b40_s8 = vld1q_s8(matrix_b + 4 * stride_b);
                const int8x16_t b50_s8 = vld1q_s8(matrix_b + 5 * stride_b);
                const int8x16_t b60_s8 = vld1q_s8(matrix_b + 6 * stride_b);
                const int8x16_t b70_s8 = vld1q_s8(matrix_b + 7 * stride_b);
                // Convert a00_s8 to int16_t and get the lower and higher halves
                const int16x4x2_t a00_s16 = {{vget_low_s16(vmovl_s8(a00_s8)), vget_high_s16(vmovl_s8(a00_s8))}};

                // Convert b00_s8 .. b70_s8 to int16_t
                const int16x4x4_t b00_s16 = {
                    {vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))), vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                     vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))), vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))}};

                const int16x4x4_t b10_s16 = {
                    {vget_low_s16(vmovl_s8(vget_low_s8(b10_s8))), vget_high_s16(vmovl_s8(vget_low_s8(b10_s8))),
                     vget_low_s16(vmovl_s8(vget_high_s8(b10_s8))), vget_high_s16(vmovl_s8(vget_high_s8(b10_s8)))}};

                const int16x4x4_t b20_s16 = {
                    {vget_low_s16(vmovl_s8(vget_low_s8(b20_s8))), vget_high_s16(vmovl_s8(vget_low_s8(b20_s8))),
                     vget_low_s16(vmovl_s8(vget_high_s8(b20_s8))), vget_high_s16(vmovl_s8(vget_high_s8(b20_s8)))}};

                const int16x4x4_t b30_s16 = {
                    {vget_low_s16(vmovl_s8(vget_low_s8(b30_s8))), vget_high_s16(vmovl_s8(vget_low_s8(b30_s8))),
                     vget_low_s16(vmovl_s8(vget_high_s8(b30_s8))), vget_high_s16(vmovl_s8(vget_high_s8(b30_s8)))}};

                const int16x4x4_t b40_s16 = {
                    {vget_low_s16(vmovl_s8(vget_low_s8(b40_s8))), vget_high_s16(vmovl_s8(vget_low_s8(b40_s8))),
                     vget_low_s16(vmovl_s8(vget_high_s8(b40_s8))), vget_high_s16(vmovl_s8(vget_high_s8(b40_s8)))}};

                const int16x4x4_t b50_s16 = {
                    {vget_low_s16(vmovl_s8(vget_low_s8(b50_s8))), vget_high_s16(vmovl_s8(vget_low_s8(b50_s8))),
                     vget_low_s16(vmovl_s8(vget_high_s8(b50_s8))), vget_high_s16(vmovl_s8(vget_high_s8(b50_s8)))}};

                const int16x4x4_t b60_s16 = {
                    {vget_low_s16(vmovl_s8(vget_low_s8(b60_s8))), vget_high_s16(vmovl_s8(vget_low_s8(b60_s8))),
                     vget_low_s16(vmovl_s8(vget_high_s8(b60_s8))), vget_high_s16(vmovl_s8(vget_high_s8(b60_s8)))}};

                const int16x4x4_t b70_s16 = {
                    {vget_low_s16(vmovl_s8(vget_low_s8(b70_s8))), vget_high_s16(vmovl_s8(vget_low_s8(b70_s8))),
                     vget_low_s16(vmovl_s8(vget_high_s8(b70_s8))), vget_high_s16(vmovl_s8(vget_high_s8(b70_s8)))}};
                // Accumulate 0:
                c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16.val[0], 0);
                c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16.val[0], 0);
                c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16.val[0], 0);
                c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16.val[0], 0);

                // Accumulate 1:
                c0.val[0] = vmlal_lane_s16(c0.val[0], b10_s16.val[0], a00_s16.val[0], 1);
                c0.val[1] = vmlal_lane_s16(c0.val[1], b10_s16.val[1], a00_s16.val[0], 1);
                c0.val[2] = vmlal_lane_s16(c0.val[2], b10_s16.val[2], a00_s16.val[0], 1);
                c0.val[3] = vmlal_lane_s16(c0.val[3], b10_s16.val[3], a00_s16.val[0], 1);

                // Accumulate 2:
                c0.val[0] = vmlal_lane_s16(c0.val[0], b20_s16.val[0], a00_s16.val[0], 2);
                c0.val[1] = vmlal_lane_s16(c0.val[1], b20_s16.val[1], a00_s16.val[0], 2);
                c0.val[2] = vmlal_lane_s16(c0.val[2], b20_s16.val[2], a00_s16.val[0], 2);
                c0.val[3] = vmlal_lane_s16(c0.val[3], b20_s16.val[3], a00_s16.val[0], 2);

                // Accumulate 3:
                c0.val[0] = vmlal_lane_s16(c0.val[0], b30_s16.val[0], a00_s16.val[0], 3);
                c0.val[1] = vmlal_lane_s16(c0.val[1], b30_s16.val[1], a00_s16.val[0], 3);
                c0.val[2] = vmlal_lane_s16(c0.val[2], b30_s16.val[2], a00_s16.val[0], 3);
                c0.val[3] = vmlal_lane_s16(c0.val[3], b30_s16.val[3], a00_s16.val[0], 3);

                // Accumulate 4:
                c0.val[0] = vmlal_lane_s16(c0.val[0], b40_s16.val[0], a00_s16.val[1], 0);
                c0.val[1] = vmlal_lane_s16(c0.val[1], b40_s16.val[1], a00_s16.val[1], 0);
                c0.val[2] = vmlal_lane_s16(c0.val[2], b40_s16.val[2], a00_s16.val[1], 0);
                c0.val[3] = vmlal_lane_s16(c0.val[3], b40_s16.val[3], a00_s16.val[1], 0);

                // Accumulate 5:
                c0.val[0] = vmlal_lane_s16(c0.val[0], b50_s16.val[0], a00_s16.val[1], 1);
                c0.val[1] = vmlal_lane_s16(c0.val[1], b50_s16.val[1], a00_s16.val[1], 1);
                c0.val[2] = vmlal_lane_s16(c0.val[2], b50_s16.val[2], a00_s16.val[1], 1);
                c0.val[3] = vmlal_lane_s16(c0.val[3], b50_s16.val[3], a00_s16.val[1], 1);

                // Accumulate 6:
                c0.val[0] = vmlal_lane_s16(c0.val[0], b60_s16.val[0], a00_s16.val[1], 2);
                c0.val[1] = vmlal_lane_s16(c0.val[1], b60_s16.val[1], a00_s16.val[1], 2);
                c0.val[2] = vmlal_lane_s16(c0.val[2], b60_s16.val[2], a00_s16.val[1], 2);
                c0.val[3] = vmlal_lane_s16(c0.val[3], b60_s16.val[3], a00_s16.val[1], 2);

                // Accumulate 7:
                c0.val[0] = vmlal_lane_s16(c0.val[0], b70_s16.val[0], a00_s16.val[1], 3);
                c0.val[1] = vmlal_lane_s16(c0.val[1], b70_s16.val[1], a00_s16.val[1], 3);
                c0.val[2] = vmlal_lane_s16(c0.val[2], b70_s16.val[2], a00_s16.val[1], 3);
                c0.val[3] = vmlal_lane_s16(c0.val[3], b70_s16.val[3], a00_s16.val[1], 3);

                matrix_b += 8 * stride_b;
                vec_a += 8;
            }
            // This for loop performs the left-over accumulations, one element of A at a time
            for (; vec_a < vec_a_end_addr;)
            {
                const int8x8_t  a00_s8 = vld1_dup_s8(vec_a);
                const int8x16_t b00_s8 = vld1q_s8(matrix_b);

                // Convert b00_s8 to int16_t
                const int16x4x4_t b00_s16 = {
                    {vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))), vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                     vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))), vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))}};

                // Convert a00_s8 to int16_t
                const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));

                // Accumulate 0:
                c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
                c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
                c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
                c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);

                matrix_b += stride_b;
                vec_a += 1;
            }
            auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
            if (id.x() < (width_out - 16))
            {
                vst1q_s32(vec_out + 0, c0.val[0]);
                vst1q_s32(vec_out + 4, c0.val[1]);
                vst1q_s32(vec_out + 8, c0.val[2]);
                vst1q_s32(vec_out + 12, c0.val[3]);
            }
            else
            {
                auto left_over = width_out - id.x();
                for (auto k = 0; k < 4 && left_over; ++k)
                {
                    for (auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(vec_out + k * 4 + j) = c0.val[k][j];
                    }
                }
            }
        },
        ina, inb, out);
}
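
// The matrix-by-matrix path below computes a 4x16 block of the output per window
// step: four output rows (one accumulator set c0..c3 each) by sixteen output
// columns (four 32-bit lanes per uint32x4_t, four uint32x4_t per set). The inner
// loop increments (mtx_a0 += 4 and mtx_b0 += 16 for every k += 16) assume the
// usual GEMMLowp pre-processing, i.e. matrix A interleaved in blocks of 4 rows
// and matrix B transposed in 1x16 tiles, so both operands are read sequentially.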
void inline matrix_multiply_u8(
    Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
{
    const auto   width_out  = static_cast<int>(out_info.dimension(0));
    const auto   height_out = static_cast<int>(out_info.dimension(1));
    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();

    execute_window_loop(
        window,
        [&](const Coordinates &id)
        {
            const uint8_t *mtx_a0 = ina.ptr();
            const uint8_t *mtx_b0 = inb.ptr();

            // Note: Since the inputs are all positive, we can use uint32_t accumulators
            // Accumulators for the block 0
            uint32x4x4_t c0 = {{vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0)}};

            // Accumulators for the block 1
            uint32x4x4_t c1 = {{vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0)}};

            // Accumulators for the block 2
            uint32x4x4_t c2 = {{vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0)}};

            // Accumulators for the block 3
            uint32x4x4_t c3 = {{vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0)}};
            for (int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
            {
                const uint8x8_t  a00_u8 = vld1_u8(mtx_a0);
                const uint8x16_t b00_u8 = vld1q_u8(mtx_b0);

                // Convert a00_u8 to uint16_t and get the lower half
                const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));

                // Convert b00_u8 to uint16_t
                const uint16x4x4_t b00_u16 = {
                    {vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))), vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                     vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))), vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))}};
                // 4x4 block 0
                c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
                c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
                c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
                c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);

                // 4x4 block 1
                c1.val[0] = vmlal_lane_u16(c1.val[0], b00_u16.val[0], a00_u16, 1);
                c1.val[1] = vmlal_lane_u16(c1.val[1], b00_u16.val[1], a00_u16, 1);
                c1.val[2] = vmlal_lane_u16(c1.val[2], b00_u16.val[2], a00_u16, 1);
                c1.val[3] = vmlal_lane_u16(c1.val[3], b00_u16.val[3], a00_u16, 1);

                // 4x4 block 2
                c2.val[0] = vmlal_lane_u16(c2.val[0], b00_u16.val[0], a00_u16, 2);
                c2.val[1] = vmlal_lane_u16(c2.val[1], b00_u16.val[1], a00_u16, 2);
                c2.val[2] = vmlal_lane_u16(c2.val[2], b00_u16.val[2], a00_u16, 2);
                c2.val[3] = vmlal_lane_u16(c2.val[3], b00_u16.val[3], a00_u16, 2);

                // 4x4 block 3
                c3.val[0] = vmlal_lane_u16(c3.val[0], b00_u16.val[0], a00_u16, 3);
                c3.val[1] = vmlal_lane_u16(c3.val[1], b00_u16.val[1], a00_u16, 3);
                c3.val[2] = vmlal_lane_u16(c3.val[2], b00_u16.val[2], a00_u16, 3);
                c3.val[3] = vmlal_lane_u16(c3.val[3], b00_u16.val[3], a00_u16, 3);
            }
            auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
            if (id.y() < height_out && id.x() < (width_out - 16))
            {
                vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
                vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
                vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
                vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
                if (id.y() + 1 < height_out)
                {
                    vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
                    vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
                    vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
                    vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
                    if (id.y() + 2 < height_out)
                    {
                        vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
                        vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
                        vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
                        vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
                        if (id.y() + 3 < height_out)
                        {
                            vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
                            vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
                            vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
                            vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
                        }
                    }
                }
            }
            else if (id.y() < height_out)
            {
                const auto left_over_value = width_out - id.x();
                auto       left_over       = left_over_value;
                for (auto k = 0; k < 4 && left_over; ++k)
                {
                    for (auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + k * 4 + j) = c0.val[k][j];
                    }
                }
                if (id.y() + 1 < height_out)
                {
                    left_over = left_over_value;
                    for (auto k = 0; k < 4 && left_over; ++k)
                    {
                        for (auto j = 0; j < 4 && left_over; ++j, --left_over)
                        {
                            *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
                        }
                    }
                    if (id.y() + 2 < height_out)
                    {
                        left_over = left_over_value;
                        for (auto k = 0; k < 4 && left_over; ++k)
                        {
                            for (auto j = 0; j < 4 && left_over; ++j, --left_over)
                            {
                                *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
                            }
                        }
                        if (id.y() + 3 < height_out)
                        {
                            left_over = left_over_value;
                            for (auto k = 0; k < 4 && left_over; ++k)
                            {
                                for (auto j = 0; j < 4 && left_over; ++j, --left_over)
                                {
                                    *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
                                }
                            }
                        }
                    }
                }
            }
        },
        ina, inb, out);
}
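
// matrix_multiply_s8 is the signed twin of matrix_multiply_u8: identical 4x16
// blocking and store/tail handling, with the unsigned widening intrinsics
// replaced by vmovl_s8/vmlal_lane_s16 and int32x4x4_t accumulators.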
void inline matrix_multiply_s8(
    Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
{
    const auto   width_out  = static_cast<int>(out_info.dimension(0));
    const auto   height_out = static_cast<int>(out_info.dimension(1));
    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();

    execute_window_loop(
        window,
        [&](const Coordinates &id)
        {
            auto *mtx_a0 = reinterpret_cast<const int8_t *>(ina.ptr());
            auto *mtx_b0 = reinterpret_cast<const int8_t *>(inb.ptr());

            // Accumulators for the block 0
            int32x4x4_t c0 = {{vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0)}};

            // Accumulators for the block 1
            int32x4x4_t c1 = {{vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0)}};

            // Accumulators for the block 2
            int32x4x4_t c2 = {{vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0)}};

            // Accumulators for the block 3
            int32x4x4_t c3 = {{vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0)}};
            for (int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
            {
                const int8x8_t  a00_s8 = vld1_s8(mtx_a0);
                const int8x16_t b00_s8 = vld1q_s8(mtx_b0);

                // Convert a00_s8 to int16_t and get the lower half
                const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));

                // Convert b00_s8 to int16_t
                const int16x4x4_t b00_s16 = {
                    {vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))), vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                     vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))), vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))}};
                // 4x4 block 0
                c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
                c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
                c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
                c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);

                // 4x4 block 1
                c1.val[0] = vmlal_lane_s16(c1.val[0], b00_s16.val[0], a00_s16, 1);
                c1.val[1] = vmlal_lane_s16(c1.val[1], b00_s16.val[1], a00_s16, 1);
                c1.val[2] = vmlal_lane_s16(c1.val[2], b00_s16.val[2], a00_s16, 1);
                c1.val[3] = vmlal_lane_s16(c1.val[3], b00_s16.val[3], a00_s16, 1);

                // 4x4 block 2
                c2.val[0] = vmlal_lane_s16(c2.val[0], b00_s16.val[0], a00_s16, 2);
                c2.val[1] = vmlal_lane_s16(c2.val[1], b00_s16.val[1], a00_s16, 2);
                c2.val[2] = vmlal_lane_s16(c2.val[2], b00_s16.val[2], a00_s16, 2);
                c2.val[3] = vmlal_lane_s16(c2.val[3], b00_s16.val[3], a00_s16, 2);

                // 4x4 block 3
                c3.val[0] = vmlal_lane_s16(c3.val[0], b00_s16.val[0], a00_s16, 3);
                c3.val[1] = vmlal_lane_s16(c3.val[1], b00_s16.val[1], a00_s16, 3);
                c3.val[2] = vmlal_lane_s16(c3.val[2], b00_s16.val[2], a00_s16, 3);
                c3.val[3] = vmlal_lane_s16(c3.val[3], b00_s16.val[3], a00_s16, 3);
            }
            auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
            if (id.y() < height_out && id.x() < (width_out - 16))
            {
                vst1q_s32(mtx_out + 0 * out_stride + 0, c0.val[0]);
                vst1q_s32(mtx_out + 0 * out_stride + 4, c0.val[1]);
                vst1q_s32(mtx_out + 0 * out_stride + 8, c0.val[2]);
                vst1q_s32(mtx_out + 0 * out_stride + 12, c0.val[3]);
                if (id.y() + 1 < height_out)
                {
                    vst1q_s32(mtx_out + 1 * out_stride + 0, c1.val[0]);
                    vst1q_s32(mtx_out + 1 * out_stride + 4, c1.val[1]);
                    vst1q_s32(mtx_out + 1 * out_stride + 8, c1.val[2]);
                    vst1q_s32(mtx_out + 1 * out_stride + 12, c1.val[3]);
                    if (id.y() + 2 < height_out)
                    {
                        vst1q_s32(mtx_out + 2 * out_stride + 0, c2.val[0]);
                        vst1q_s32(mtx_out + 2 * out_stride + 4, c2.val[1]);
                        vst1q_s32(mtx_out + 2 * out_stride + 8, c2.val[2]);
                        vst1q_s32(mtx_out + 2 * out_stride + 12, c2.val[3]);
                        if (id.y() + 3 < height_out)
                        {
                            vst1q_s32(mtx_out + 3 * out_stride + 0, c3.val[0]);
                            vst1q_s32(mtx_out + 3 * out_stride + 4, c3.val[1]);
                            vst1q_s32(mtx_out + 3 * out_stride + 8, c3.val[2]);
                            vst1q_s32(mtx_out + 3 * out_stride + 12, c3.val[3]);
                        }
                    }
                }
            }
            else if (id.y() < height_out)
            {
                const auto left_over_value = width_out - id.x();
                auto       left_over       = left_over_value;
                for (auto k = 0; k < 4 && left_over; ++k)
                {
                    for (auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + k * 4 + j) = c0.val[k][j];
                    }
                }
                if (id.y() + 1 < height_out)
                {
                    left_over = left_over_value;
                    for (auto k = 0; k < 4 && left_over; ++k)
                    {
                        for (auto j = 0; j < 4 && left_over; ++j, --left_over)
                        {
                            *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
                        }
                    }
                    if (id.y() + 2 < height_out)
                    {
                        left_over = left_over_value;
                        for (auto k = 0; k < 4 && left_over; ++k)
                        {
                            for (auto j = 0; j < 4 && left_over; ++j, --left_over)
                            {
                                *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
                            }
                        }
                        if (id.y() + 3 < height_out)
                        {
                            left_over = left_over_value;
                            for (auto k = 0; k < 4 && left_over; ++k)
                            {
                                for (auto j = 0; j < 4 && left_over; ++j, --left_over)
                                {
                                    *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
                                }
                            }
                        }
                    }
                }
            }
        },
        ina, inb, out);
}
Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
{
    TensorShape in0_shape = src0->tensor_shape();
    TensorShape in1_shape = src1->tensor_shape();
    TensorShape out_shape = dst->tensor_shape();

    // Check vector-by-matrix case
    if (out_shape[1] == 1)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[0] != in1_shape[1],
                                        "The number of input0's columns must be equal to input1's rows");
    }
    else
    {
        in0_shape.collapse(2);
        in1_shape.collapse(2);
        out_shape.collapse(2);

        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[2] != out_shape[2],
                                        "Output tensor must have the same number of batches of input0 tensor");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(
            in1_shape[2] != 1 && in0_shape[2] != in1_shape[2],
            "Input1 tensor must have the same number of batches of input0 or the number of batches must be set to 1");
    }

    return Status{};
}

void CpuGemmLowpMatrixMultiplyKernel::configure(ITensorInfo *src0, ITensorInfo *src1, ITensorInfo *dst)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, dst));

    TensorShape in1_shape = src1->tensor_shape();
    in1_shape.collapse(2);

    _slide_matrix_b = in1_shape[2] != 1;

    constexpr unsigned int num_elems_processed_per_iteration_x = 16;
    constexpr unsigned int num_elems_processed_per_iteration_y = 4;
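
    // The 16x4 steps match the blocking of the kernels above: each window step
    // writes 16 output columns, and the matrix-by-matrix path writes 4 rows.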

    Window win;
    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
    if ((dst->dimension(1) == 1))
    {
        win = calculate_max_window(*dst, Steps(num_elems_processed_per_iteration_x));
    }
    else
    {
        win = calculate_max_window(*dst,
                                   Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
    }

    ICpuKernel::configure(win);
}

void CpuGemmLowpMatrixMultiplyKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);

    auto src0 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
    auto src1 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
    auto dst  = tensors.get_tensor(TensorType::ACL_DST);

    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication path
    if ((dst->info()->dimension(1) == 1))
    {
        const auto width_matrix_a = static_cast<int>(src0->info()->dimension(0));
        const auto width_matrix_b = static_cast<int>(src1->info()->dimension(0));
        const auto width_out      = static_cast<int>(dst->info()->dimension(0));
        const auto in_b_stride    = static_cast<int>(src1->info()->strides_in_bytes()[1] /
                                                     data_size_from_type(src1->info()->data_type()));

        // The implementation computes 16 elements per iteration
        const int window_start_x = 16 * info.thread_id;
        const int window_step_x  = 16 * info.num_threads;
        // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
        const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
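        // e.g. width_matrix_b = 100 with 4 threads: thread 1 gets window_start_x = 16,
        // window_step_x = 64 and window_end_x = ceil_to_multiple(84, 64) + 16 = 144,
        // so each thread covers a whole number of 64-element strides and the threads
        // interleave in 16-element chunks across the width of matrix B.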

        Window win_out(window);
        win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
        win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

        Window win_a(window);
        win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

        Window win_b;
        // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2.
        // This scenario can happen when the matrix multiplication is used to perform a convolution operation
        if (src1->info()->num_dimensions() >= 3)
        {
            win_b = window;
        }
        win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
        win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

        Iterator ina(src0, win_a);
        Iterator inb(src1, win_b);
        Iterator out(dst, win_out);

        switch (src0->info()->data_type())
        {
            case DataType::S8:
            case DataType::QASYMM8_SIGNED:
            {
                vector_matrix_multiply_s8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride,
                                          window);
                break;
            }
            case DataType::U8:
            case DataType::QASYMM8:
            {
                vector_matrix_multiply_u8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride,
                                          window);
                break;
            }
            default:
            {
                ARM_COMPUTE_ERROR("Not supported");
                break;
            }
        }
    }
    else
    {
        const size_t in_b_stride = src1->info()->strides_in_bytes()[1];
        const int    width_b     = src1->info()->dimension(0);

        // Window for matrix A: the interleaved input has 4 times fewer rows than the output
        Window win_a(window);
        win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, window.y().end() / 4, 1));

        // Window for matrix B: the transposed input has 16 times fewer columns than the output
        Window win_b;
        if (_slide_matrix_b)
        {
            win_b = window;
        }
        win_b.set(Window::DimX, Window::Dimension(window.x().start() / 16, window.x().end() / 16, in_b_stride));
        win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

        Iterator ina(src0, win_a);
        Iterator inb(src1, win_b);
        Iterator out(dst, window);

        switch (src0->info()->data_type())
        {
            case DataType::S8:
            case DataType::QASYMM8_SIGNED:
            {
                matrix_multiply_s8(ina, inb, out, width_b, *dst->info(), window);
                break;
            }
            case DataType::U8:
            case DataType::QASYMM8:
            {
                matrix_multiply_u8(ina, inb, out, width_b, *dst->info(), window);
                break;
            }
            default:
            {
                ARM_COMPUTE_ERROR("Not supported");
                break;
            }
        }
    }
}

const char *CpuGemmLowpMatrixMultiplyKernel::name() const
{
    return "CpuGemmLowpMatrixMultiplyKernel";
}