39 static_assert(std::is_same<T, qasymm8_t>::value || std::is_same<T, qasymm8_signed_t>::value,
40 "quantized type should be either qasymm8_t or qasymm8_signed_t.");
46 const auto scale_beta_vec = vdupq_n_f32(scale_beta);
51 constexpr
int vec_size = 16;
58 const auto in_ptr =
reinterpret_cast<const T *
>(in_it.
ptr()) + start_x;
59 const auto out_ptr =
reinterpret_cast<T *
>(out_it.
ptr()) + start_x;
60 const auto tmp_ptr =
reinterpret_cast<float *
>(tmp);
68 const auto max_val = *
reinterpret_cast<const T *
>(max_it.
ptr());
72 float32x4x4_t vec_sum = {
81 for (; x <= (input_width - vec_size); x += vec_size)
85 auto vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements);
89 vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec);
90 vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec);
91 vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec);
92 vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec);
93 vec_sum.val[0] = vaddq_f32(vec_sum.val[0],
vexpq_f32(vec_elements_flt.val[0]));
94 vec_sum.val[1] = vaddq_f32(vec_sum.val[1],
vexpq_f32(vec_elements_flt.val[1]));
95 vec_sum.val[2] = vaddq_f32(vec_sum.val[2],
vexpq_f32(vec_elements_flt.val[2]));
96 vec_sum.val[3] = vaddq_f32(vec_sum.val[3],
vexpq_f32(vec_elements_flt.val[3]));
100 vec_elements_flt.val[0] =
vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec));
101 vec_elements_flt.val[1] =
vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec));
102 vec_elements_flt.val[2] =
vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec));
103 vec_elements_flt.val[3] =
vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec));
104 vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]);
105 vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]);
106 vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]);
107 vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]);
110 vst4q_f32(tmp_ptr + x, vec_elements_flt);
114 const auto sum_16_byte =
115 vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]), vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
116 auto sum_res = vpadd_f32(vget_high_f32(sum_16_byte), vget_low_f32(sum_16_byte));
117 sum_res = vpadd_f32(sum_res, sum_res);
121 for (; x < input_width; ++x)
126 element = (max_val - in_ptr[x]) * scale_beta;
127 sum += std::exp(element);
131 element = std::exp((max_val - in_ptr[x]) * scale_beta);
135 tmp_ptr[x] = element;
140 sum_inversed = 256.f / sum;
150 constexpr
bool is_qasymm8_signed = std::is_same<T, qasymm8_signed_t>::value;
153 for (; x <= (input_width - vec_size); x += vec_size)
156 float32x4x4_t vec_in = vld4q_f32(tmp_ptr + x);
157 int_vec_type normalized_value{};
160 const float32x4x4_t sub = {
161 vsubq_f32(vec_in.val[0], vdupq_n_f32(sum)),
162 vsubq_f32(vec_in.val[1], vdupq_n_f32(sum)),
163 vsubq_f32(vec_in.val[2], vdupq_n_f32(sum)),
164 vsubq_f32(vec_in.val[3], vdupq_n_f32(sum)),
166 normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(sub);
170 float32x4x4_t mul = {
171 vmulq_f32(vec_in.val[0], vdupq_n_f32(sum_inversed)),
172 vmulq_f32(vec_in.val[1], vdupq_n_f32(sum_inversed)),
173 vmulq_f32(vec_in.val[2], vdupq_n_f32(sum_inversed)),
174 vmulq_f32(vec_in.val[3], vdupq_n_f32(sum_inversed)),
177 if (is_qasymm8_signed)
186 normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(mul);
191 for (; x < input_width; ++x)
195 out_ptr[x] = utils::cast::saturate_cast<T>(tmp_ptr[x] - sum);
199 out_ptr[x] = utils::cast::saturate_cast<T>((tmp_ptr[x] * sum_inversed) -
200 (is_qasymm8_signed ? 128.f : 0));
205 in_it, max_it, out_it);