25 #ifndef ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H
26 #define ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H
48 const float32x4x3_t r = {{vld1q_dup_f32(ptr), vld1q_dup_f32(1 + ptr), vld1q_dup_f32(2 + ptr)}};
59 template <
typename T, ARM_COMPUTE_REQUIRES_TA(std::is_same<T, u
int8_t>::value || std::is_same<T,
int8_t>::value)>
62 const int32x4_t v_weights_offset = vdupq_n_s32(weights_offset);
66 int32x4x3_t r = {{vaddq_s32(v_weights_offset, vdupq_n_s32(*ptr)),
67 vaddq_s32(v_weights_offset, vdupq_n_s32(*(ptr + 1))),
68 vaddq_s32(v_weights_offset, vdupq_n_s32(*(ptr + 2)))}};
/** Primary template declaration of store_results for float data.
 *
 * This is a declaration only (it ends in a semicolon and has no body); it is
 * parameterised by a single unsigned int non-type template argument.
 *
 * @param buffer Non-const pointer to float data.
 *               NOTE(review): presumably the destination written by the
 *               per-argument specializations - confirm against them.
 * @param values Pair of float32x4 vectors, taken by const reference.
 */
78 template <
unsigned int str
idex>
79 void store_results(
float *buffer,
const float32x4x2_t &values);
84 vst1q_f32(buffer, values.val[0]);
85 vst1q_f32(buffer + 4, values.val[1]);
91 vst1q_f32(buffer, values.val[0]);
97 vst1_f32(buffer, vget_low_f32(values.val[0]));
/** Primary template declaration of store_results for 32-bit signed integer data.
 *
 * This is a declaration only (it ends in a semicolon and has no body); it is
 * parameterised by a single unsigned int non-type template argument.
 *
 * @param buffer Non-const pointer to int32_t data.
 *               NOTE(review): presumably the destination written by the
 *               per-argument specializations - confirm against them.
 * @param values Pair of int32x4 vectors, taken by const reference.
 */
106 template <
unsigned int str
idex>
107 void store_results(int32_t *buffer,
const int32x4x2_t &values);
112 vst1q_s32(buffer, values.val[0]);
113 vst1q_s32(buffer + 4, values.val[1]);
119 vst1q_s32(buffer, values.val[0]);
125 vst1_s32(buffer, vget_low_s32(values.val[0]));
128 template <
unsigned int str
idex>
134 vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
135 vst1q_f32(buffer + 4, vaddq_f32(vld1q_f32(buffer + 4), values.val[1]));
141 vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
147 vst1_f32(buffer, vadd_f32(vld1_f32(buffer), vget_low_f32(values.val[0])));
150 template <
unsigned int str
idex>
156 vst1q_s32(buffer, vaddq_s32(vld1q_s32(buffer), values.val[0]));
157 vst1q_s32(buffer + 4, vaddq_s32(vld1q_s32(buffer + 4), values.val[1]));
163 vst1q_s32(buffer, vaddq_s32(vld1q_s32(buffer), values.val[0]));
169 vst1_s32(buffer, vadd_s32(vld1_s32(buffer), vget_low_s32(values.val[0])));
172 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Primary template declaration of store_results for half-precision float data.
 *
 * This is a declaration only (it ends in a semicolon and has no body); it is
 * parameterised by a single unsigned int non-type template argument. The
 * explicit specializations for the arguments 1, 2 and 3 follow it.
 *
 * @param buffer Non-const pointer to float16_t data; the specializations store through it.
 * @param values Pair of float16x8 vectors, taken by const reference.
 */
179 template <
unsigned int str
idex>
180 void store_results(float16_t *buffer,
const float16x8x2_t &values);
// store_results with template argument 1 (float16_t): both vectors of
// `values` are stored, so 16 consecutive half-precision elements are written
// starting at `buffer`.
183 inline void store_results<1>(float16_t *buffer,
const float16x8x2_t &values)
// First 8 elements come from values.val[0].
185 vst1q_f16(buffer, values.val[0]);
// Next 8 elements (offset 8) come from values.val[1].
186 vst1q_f16(buffer + 8, values.val[1]);
// store_results with template argument 2 (float16_t): only values.val[0] is
// stored, i.e. 8 half-precision elements starting at `buffer`; values.val[1]
// is not read by the visible statement.
190 inline void store_results<2>(float16_t *buffer,
const float16x8x2_t &values)
192 vst1q_f16(buffer, values.val[0]);
// store_results with template argument 3 (float16_t): only the low half of
// values.val[0] is stored, i.e. 4 half-precision elements starting at
// `buffer`.
196 inline void store_results<3>(float16_t *buffer,
const float16x8x2_t &values)
// vget_low_f16 extracts the lower 4 lanes; vst1_f16 is the 64-bit store.
198 vst1_f16(buffer, vget_low_f16(values.val[0]));
201 template <
unsigned int str
idex>
207 vst1q_f16(buffer,
vaddq_f16(vld1q_f16(buffer), values.val[0]));
208 vst1q_f16(buffer + 8,
vaddq_f16(vld1q_f16(buffer + 8), values.val[1]));
214 vst1q_f16(buffer,
vaddq_f16(vld1q_f16(buffer), values.val[0]));
220 vst1_f16(buffer,
vadd_f16(vld1_f16(buffer), vget_low_f16(values.val[0])));
239 const float32x4x3_t &m0,
240 const float32x4x3_t &m1,
241 const float32x4x3_t &m2,
242 const size_t dilation_x,
247 const float32x4x3_t vtop = {
248 {vld1q_f32(in_top), vld1q_f32(in_top + dilation_x), vld1q_f32(in_top + 2 * dilation_x)}};
249 const float32x4x3_t vmid = {
250 {vld1q_f32(in_mid), vld1q_f32(in_mid + dilation_x), vld1q_f32(in_mid + 2 * dilation_x)}};
251 const float32x4x3_t vlow = {
252 {vld1q_f32(in_low), vld1q_f32(in_low + dilation_x), vld1q_f32(in_low + 2 * dilation_x)}};
253 float32x4_t out = vmulq_f32(vtop.val[0], m0.val[0]);
254 out = vmlaq_f32(out, vtop.val[1], m0.val[1]);
255 out = vmlaq_f32(out, vtop.val[2], m0.val[2]);
257 out = vmlaq_f32(out, vmid.val[0], m1.val[0]);
258 out = vmlaq_f32(out, vmid.val[1], m1.val[1]);
259 out = vmlaq_f32(out, vmid.val[2], m1.val[2]);
261 out = vmlaq_f32(out, vlow.val[0], m2.val[0]);
262 out = vmlaq_f32(out, vlow.val[1], m2.val[1]);
263 out = vmlaq_f32(out, vlow.val[2], m2.val[2]);
284 const float32x4x3_t &m0,
285 const float32x4x3_t &m1,
286 const float32x4x3_t &m2,
287 const size_t dilation_x,
288 unsigned int stridex,
289 int input_offset = 0)
292 float32x4x2_t out = {
298 out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
299 out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
300 out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
302 else if (stridex == 3)
304 out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
323 template <
bool accumulate>
328 const float32x4x3_t &m0,
329 const float32x4x3_t &m1,
330 const float32x4x3_t &m2,
331 unsigned int stridex,
332 int input_offset = 0);
334 template <
bool accumulate>
339 const float32x4x3_t &m0,
340 const float32x4x3_t &m1,
341 const float32x4x3_t &m2,
342 unsigned int stridex,
348 float32x4x2_t out = {{vdupq_n_f32(0.f), vdupq_n_f32(0.f)}};
351 const float32x4x2_t vtop = vld2q_f32(in_top);
352 const float32x4x2_t vmid = vld2q_f32(in_mid);
353 const float32x4x2_t vlow = vld2q_f32(in_low);
354 const float32x4_t vtop_end = vld1q_f32(in_top + 8);
355 const float32x4_t vmid_end = vld1q_f32(in_mid + 8);
356 const float32x4_t vlow_end = vld1q_f32(in_low + 8);
358 out.val[0] = vmulq_f32(vtop.val[0], m0.val[0]);
360 out.val[0] = vmlaq_f32(out.val[0], vtop.val[1], m0.val[1]);
361 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop_end, 1), m0.val[2]);
363 out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
364 out.val[0] = vmlaq_f32(out.val[0], vmid.val[1], m1.val[1]);
365 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid_end, 1), m1.val[2]);
367 out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
368 out.val[0] = vmlaq_f32(out.val[0], vlow.val[1], m2.val[1]);
369 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow_end, 1), m2.val[2]);
375 const float32x4x3_t vtop = {{vld1q_f32(in_top), vld1q_f32(in_top + 4), vld1q_f32(in_top + 8)}};
376 const float32x4x3_t vmid = {{vld1q_f32(in_mid), vld1q_f32(in_mid + 4), vld1q_f32(in_mid + 8)}};
377 const float32x4x3_t vlow = {{vld1q_f32(in_low), vld1q_f32(in_low + 4), vld1q_f32(in_low + 8)}};
378 out.val[0] = vmulq_f32(vtop.val[0], m0.val[0]);
379 out.val[1] = vmulq_f32(vtop.val[1], m0.val[0]);
381 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
382 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);
384 out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
385 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
386 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);
388 out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
389 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
390 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);
392 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
393 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);
395 out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
396 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
397 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);
399 out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
400 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
401 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);
405 out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
427 template <
typename T, ARM_COMPUTE_REQUIRES_TA(std::is_same<T, u
int8_t>::value || std::is_same<T,
int8_t>::value)>
431 const int32x4x3_t &m0,
432 const int32x4x3_t &m1,
433 const int32x4x3_t &m2,
435 int32_t input_offset)
437 using VectorType =
typename std::conditional<std::is_same<T, uint8_t>::value, uint8x8x3_t, int8x8x3_t>
::type;
440 const int32x4_t v_input_offset =
wrapper::vdup_n(input_offset, OutputTagType{});
442 const VectorType vtop = {
444 const VectorType vmid = {
446 const VectorType vlow = {
449 const int32x4x3_t vtop_s32 = {{
454 const int32x4x3_t vmid_s32 = {{
459 const int32x4x3_t vlow_s32 = {{
493 template <
typename T, ARM_COMPUTE_REQUIRES_TA(std::is_same<T, u
int8_t>::value || std::is_same<T,
int8_t>::value)>
497 const int32x4x3_t &m0,
498 const int32x4x3_t &m1,
499 const int32x4x3_t &m2,
500 const size_t dilation_x,
501 unsigned int stridex,
515 else if (stridex == 3)
543 const int32x4x3_t &m0,
544 const int32x4x3_t &m1,
545 const int32x4x3_t &m2,
546 unsigned int stridex,
547 int32_t input_offset)
550 using VectorType =
typename std::conditional<std::is_same<T1, uint8_t>::value, uint8x8x2_t, int8x8x2_t>
::type;
553 const int32x4_t v_input_offset =
wrapper::vdup_n(input_offset, OutputTagType{});
559 const int32x4x3_t vtop_s32 = {{
564 const int32x4x3_t vmid_s32 = {{
569 const int32x4x3_t vlow_s32 = {{
581 out.val[0] =
wrapper::vmla(out.val[0], vtop_s32.val[0], m0.val[0]);
585 out.val[0] =
wrapper::vmla(out.val[0], vmid_s32.val[0], m1.val[0]);
589 out.val[0] =
wrapper::vmla(out.val[0], vlow_s32.val[0], m2.val[0]);
594 out.val[1] =
wrapper::vmla(out.val[1], vtop_s32.val[1], m0.val[0]);
598 out.val[1] =
wrapper::vmla(out.val[1], vmid_s32.val[1], m1.val[0]);
602 out.val[1] =
wrapper::vmla(out.val[1], vlow_s32.val[1], m2.val[0]);
610 else if (stridex == 2)
618 else if (stridex == 3)
625 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
632 inline float16x8x3_t
load_matrix_row(
const float16_t *ptr,
int weights_offset = 0)
637 const float16x8x3_t r = {{vld1q_dup_f16(ptr), vld1q_dup_f16(1 + ptr), vld1q_dup_f16(2 + ptr)}};
654 const float16_t *in_mid,
655 const float16_t *in_low,
656 const float16x8x3_t &m0,
657 const float16x8x3_t &m1,
658 const float16x8x3_t &m2,
659 const size_t dilation_x,
660 int input_offset = 0)
663 const float16x8x3_t vtop = {
664 {vld1q_f16(in_top), vld1q_f16(in_top + dilation_x), vld1q_f16(in_top + 2 * dilation_x)}};
665 const float16x8x3_t vmid = {
666 {vld1q_f16(in_mid), vld1q_f16(in_mid + dilation_x), vld1q_f16(in_mid + 2 * dilation_x)}};
667 const float16x8x3_t vlow = {
668 {vld1q_f16(in_low), vld1q_f16(in_low + dilation_x), vld1q_f16(in_low + 2 * dilation_x)}};
669 float16x8_t out =
vmulq_f16(vtop.val[0], m0.val[0]);
698 const float16_t *in_mid,
699 const float16_t *in_low,
700 const float16x8x3_t &m0,
701 const float16x8x3_t &m1,
702 const float16x8x3_t &m2,
703 const size_t dilation_x,
704 unsigned int stridex,
705 int input_offset = 0)
707 float16x8x2_t out = {
713 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1);
714 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 4), out.val[0], 2);
715 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 3);
716 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 4);
717 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 5);
718 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 4), out.val[0], 6);
719 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 6), out.val[0], 7);
721 else if (stridex == 3)
723 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
724 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2);
725 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3);
744 template <
bool accumulate>
746 const float16_t *in_mid,
747 const float16_t *in_low,
749 const float16x8x3_t &m0,
750 const float16x8x3_t &m1,
751 const float16x8x3_t &m2,
752 unsigned int stridex,
753 int input_offset = 0)
757 float16x8x2_t out = {{vdupq_n_f16(0), vdupq_n_f16(0)}};
760 const float16x8x2_t vtop = vld2q_f16(in_top);
761 const float16x8x2_t vmid = vld2q_f16(in_mid);
762 const float16x8x2_t vlow = vld2q_f16(in_low);
763 const float16x8_t vtop_end = vld1q_f16(in_top + 16);
764 const float16x8_t vmid_end = vld1q_f16(in_mid + 16);
765 const float16x8_t vlow_end = vld1q_f16(in_low + 16);
767 out.val[0] =
vmulq_f16(vtop.val[0], m0.val[0]);
784 const float16x8x3_t vtop = {{vld1q_f16(in_top), vld1q_f16(in_top + 8), vld1q_f16(in_top + 16)}};
785 const float16x8x3_t vmid = {{vld1q_f16(in_mid), vld1q_f16(in_mid + 8), vld1q_f16(in_mid + 16)}};
786 const float16x8x3_t vlow = {{vld1q_f16(in_low), vld1q_f16(in_low + 8), vld1q_f16(in_low + 16)}};
787 out.val[0] =
vmulq_f16(vtop.val[0], m0.val[0]);
788 out.val[1] =
vmulq_f16(vtop.val[1], m0.val[0]);
809 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
810 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2);
811 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3);
835 return num_elems_written_per_iteration;
837 return num_elems_written_per_iteration << 1;
839 return num_elems_written_per_iteration * 3;