/// Clamp every lane of @p block to the inclusive range
/// [@p quant_min_vec, @p quant_max_vec], lane-wise.
///
/// @param block         Values to clamp.
/// @param quant_min_vec Per-lane lower bound (here: quantized minimum).
/// @param quant_max_vec Per-lane upper bound (here: quantized maximum).
/// @return The clamped vector.
inline float32x4_t clamp_v4f32(float32x4_t block, float32x4_t quant_min_vec, float32x4_t quant_max_vec)
{
    // Raise lanes below the minimum first, then cap lanes above the maximum.
    const float32x4_t lower_bounded = vmaxq_f32(block, quant_min_vec);
    return vminq_f32(lower_bounded, quant_max_vec);
}
/// Convert two float32x4 vectors to unsigned integers and pack them into a
/// single uint16x8 vector.
///
/// Each float lane is converted with vcvtq_u32_f32 (round-toward-zero per the
/// ARM intrinsic definition) and then narrowed from 32 to 16 bits. Callers are
/// expected to have clamped the inputs to a 16-bit-representable range first.
///
/// @param fb1 Source of the low 4 lanes of the result.
/// @param fb2 Source of the high 4 lanes of the result.
/// @return uint16x8 with fb1's lanes in the low half and fb2's in the high half.
inline uint16x8_t fuse_words_f32(float32x4_t fb1, float32x4_t fb2)
{
    const uint16x4_t low_half  = vmovn_u32(vcvtq_u32_f32(fb1));
    const uint16x4_t high_half = vmovn_u32(vcvtq_u32_f32(fb2));
    return vcombine_u16(low_half, high_half);
}
/// Narrow two uint16x8 vectors to 8 bits each and pack them into one
/// uint8x16 vector.
///
/// vmovn_u16 keeps only the low byte of each lane; callers are expected to
/// have clamped values into [0, 255] beforehand.
///
/// @param sb1 Source of the low 8 lanes of the result.
/// @param sb2 Source of the high 8 lanes of the result.
/// @return uint8x16 with sb1's lanes in the low half and sb2's in the high half.
inline uint8x16_t fuse_shorts_u16(uint16x8_t sb1, uint16x8_t sb2)
{
    const uint8x8_t low_half  = vmovn_u16(sb1);
    const uint8x8_t high_half = vmovn_u16(sb2);
    return vcombine_u8(low_half, high_half);
}
// NOTE(review): This span is the interior of an enclosing function (presumably a
// QASYMM8 mean/stddev normalization kernel iterating a Window row-by-row) whose
// signature, the execute_window_loop lambda header, and several statements are
// not visible in this excerpt. The stray numbers at the start of lines are
// extraction artifacts (original file line numbers), not code.
// Process 16 bytes per vector iteration along the X dimension.
57 const int window_step_x = 16;
58 const int window_start_x =
static_cast<int>(window.
x().
start());
59 const int window_end_x =
static_cast<int>(window.
x().
end());
// Output quantization parameters: real = (q - offset) * scale.
62 const float output_scale = qi_out.
scale;
63 const int output_offset = qi_out.
offset;
// Hoist the division out of the per-element loops.
68 const float output_inv_scale = 1.0f / output_scale;
// Clamp bounds for the unsigned 8-bit quantized output range [0, 255].
69 const float32x4_t quant_max_vec = vdupq_n_f32(255.0f);
70 const float32x4_t quant_min_vec = vdupq_n_f32(0.0f);
76 int x = window_start_x;
77 auto in_ptr =
reinterpret_cast<const uint8_t *
>(input_itr.
ptr());
78 auto out_ptr =
reinterpret_cast<uint8_t *
>(output_itr.
ptr());
// --- Pass 1: accumulate sum and sum of squares over the row ---
80 uint32x4_t sum_vec = vdupq_n_u32(0);
81 uint32x4_t sum_sq_vec = vdupq_n_u32(0);
83 for (; x <= (window_end_x - window_step_x); x += window_step_x)
85 const uint8x16_t data = vld1q_u8(in_ptr + x);
// Pairwise-widen u8 -> u16 -> u32 so 16 bytes fold into four u32 partial sums.
86 sum_vec = vaddq_u32(sum_vec, vpaddlq_u16(vpaddlq_u8(data)));
// u8*u8 widens to u16 (max 255*255 fits); pairwise-widen again to u32.
87 const uint16x8_t squares_low = vmull_u8(vget_low_u8(data), vget_low_u8(data));
88 const uint16x8_t squares_high = vmull_u8(vget_high_u8(data), vget_high_u8(data));
89 sum_sq_vec = vaddq_u32(sum_sq_vec, vaddq_u32(vpaddlq_u16(squares_low), vpaddlq_u16(squares_high)));
// NOTE(review): the opening `#ifdef __aarch64__` for the `#elif`/`#endif` below
// is missing from this excerpt — presumably dropped by the extraction; confirm
// against the full file. AArch64 path: horizontal reduce via vpaddq_u32 (A64-only).
93 sum_vec = vpaddq_u32(sum_vec, sum_vec);
94 sum_vec = vpaddq_u32(sum_vec, sum_vec);
95 uint32_t sum = vgetq_lane_u32(sum_vec, 0);
96 sum_sq_vec = vpaddq_u32(sum_sq_vec, sum_sq_vec);
97 sum_sq_vec = vpaddq_u32(sum_sq_vec, sum_sq_vec);
98 uint32_t sum_sq = vgetq_lane_u32(sum_sq_vec, 0);
99 #elif __arm__ // #ifdef __aarch64__
// AArch32 path: vpaddq_u32 is unavailable, so extract and add the four lanes.
100 uint32_t sum = vgetq_lane_u32(sum_vec, 0) + vgetq_lane_u32(sum_vec, 1) + vgetq_lane_u32(sum_vec, 2) +
101 vgetq_lane_u32(sum_vec, 3);
103 uint32_t sum_sq = vgetq_lane_u32(sum_sq_vec, 0) + vgetq_lane_u32(sum_sq_vec, 1) +
104 vgetq_lane_u32(sum_sq_vec, 2) + vgetq_lane_u32(sum_sq_vec, 3);
105 #endif // #ifdef __aarch64__
// Scalar tail of pass 1 for the remaining (< 16) elements.
// NOTE(review): only `sum_sq` is updated here — the matching `sum += data;`
// statement (original line ~109) appears to have been dropped by the
// extraction; without it the scalar tail would corrupt the mean.
106 for (; x < window_end_x; ++x)
108 auto data =
static_cast<uint32_t
>(*(in_ptr + x));
110 sum_sq += (data * data);
// Row statistics: mean over dimension 0, variance via E[x^2] - E[x]^2.
113 const float mean = (
static_cast<float>(sum) /
static_cast<float>(
input->info()->dimension(0)));
// NOTE(review): the declaration (presumably `const float var =`, original line
// ~114) for the expression below is missing from this excerpt.
115 (
static_cast<float>(sum_sq) /
static_cast<float>(
input->info()->dimension(0))) - (mean * mean);
// epsilon guards against division by zero for constant rows.
116 const float stdev_inv = 1.0f / sqrtf(var +
epsilon);
// Fold normalization and requantization into one multiply-add:
// out = data * (stdev_inv / out_scale) + (-mean * stdev_inv / out_scale + out_offset).
117 const float32x4_t v_scale = vdupq_n_f32(stdev_inv * output_inv_scale);
118 const float32x4_t v_offset = vdupq_n_f32(-mean * stdev_inv * output_inv_scale + output_offset);
// --- Pass 2: normalize, requantize, clamp to [0, 255], and store ---
119 for (x = window_start_x; x <= (window_end_x - window_step_x); x += window_step_x)
121 const uint8x16_t data = vld1q_u8(in_ptr + x);
// Widen 16 u8 lanes to four float32x4 vectors (u8 -> u16 -> u32 -> f32).
122 float32x4_t db1 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(data)))));
123 float32x4_t db2 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(data)))));
124 float32x4_t db3 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(data)))));
125 float32x4_t db4 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(data)))));
126 db1 = clamp_v4f32(vaddq_f32(vmulq_f32(db1, v_scale), v_offset), quant_min_vec, quant_max_vec);
127 db2 = clamp_v4f32(vaddq_f32(vmulq_f32(db2, v_scale), v_offset), quant_min_vec, quant_max_vec);
128 db3 = clamp_v4f32(vaddq_f32(vmulq_f32(db3, v_scale), v_offset), quant_min_vec, quant_max_vec);
129 db4 = clamp_v4f32(vaddq_f32(vmulq_f32(db4, v_scale), v_offset), quant_min_vec, quant_max_vec);
// Narrow back f32 -> u16 -> u8 and store 16 output bytes.
130 const uint8x16_t out = fuse_shorts_u16(fuse_words_f32(db1, db2), fuse_words_f32(db3, db4));
131 vst1q_u8(out_ptr + x, out);
// Scalar tail of pass 2.
134 for (; x < window_end_x; ++x)
136 auto data =
static_cast<float32_t>(*(in_ptr + x));
// NOTE(review): the declaration/clamp producing `res` (original line ~137)
// is missing from this excerpt; the expression below is presumably its
// right-hand side before clamping to [0, 255].
138 data * (stdev_inv * output_inv_scale) + (-mean * stdev_inv * output_inv_scale + output_offset);
139 *(out_ptr + x) = res;
// Tail of the (not fully visible) execute_window_loop invocation advancing
// both iterators per row.
142 input_itr, output_itr);