44 intermediate_type val;
50 val =
static_cast<intermediate_type
>(src1) +
static_cast<intermediate_type
>(src2);
55 val =
static_cast<intermediate_type
>(src1) -
static_cast<intermediate_type
>(src2);
60 val = std::min(
static_cast<intermediate_type
>(src1),
static_cast<intermediate_type
>(src2));
65 val = std::max(
static_cast<intermediate_type
>(src1),
static_cast<intermediate_type
>(src2));
70 intermediate_type tmp = (
static_cast<intermediate_type
>(src1) -
static_cast<intermediate_type
>(src2));
76 val = (
static_cast<intermediate_type
>(src1) /
static_cast<intermediate_type
>(src2));
77 if(std::is_integral<T>::value)
80 val = (src2 == 0) ? 0 : val;
81 if(
static_cast<int32_t
>(src1) %
static_cast<int32_t
>(src2) != 0 && ((src1 < 0) != (src2 < 0)))
90 val = std::pow(
static_cast<intermediate_type
>(src1),
static_cast<intermediate_type
>(src2));
95 const T x =
static_cast<intermediate_type
>(src1);
96 const T alpha =
static_cast<intermediate_type
>(src2);
97 val = (x > 0 ? x : alpha * x);
112 result =
static_cast<T
>(val);
117 template <
size_t dim>
118 struct BroadcastUnroll
120 template <
typename T>
121 static void unroll(
ArithmeticOperation op,
const SimpleTensor<T> &src1,
const SimpleTensor<T> &src2, SimpleTensor<T> &
dst,
122 ConvertPolicy convert_policy, Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
124 const bool src1_is_broadcast = (src1.shape()[dim - 1] !=
dst.shape()[dim - 1]);
125 const bool src2_is_broadcast = (src2.shape()[dim - 1] !=
dst.shape()[dim - 1]);
127 id_src1.set(dim - 1, 0);
128 id_src2.set(dim - 1, 0);
129 id_dst.set(dim - 1, 0);
131 for(
size_t i = 0; i <
dst.shape()[dim - 1]; ++i, ++id_dst[dim - 1])
133 BroadcastUnroll < dim - 1 >::unroll(op, src1, src2,
dst, convert_policy, id_src1, id_src2, id_dst);
135 id_src1[dim - 1] += !src1_is_broadcast;
136 id_src2[dim - 1] += !src2_is_broadcast;
// Specialisation of BroadcastUnroll for dim == 0.
// The primary template calls BroadcastUnroll<dim - 1>::unroll for every index it visits,
// so this specialisation is where that chain of calls stops descending.
142 struct BroadcastUnroll<0>
// Same parameter list as the primary template's unroll: the operation selector, the two
// source tensors, the destination tensor, the conversion policy and one coordinate cursor
// per tensor.
144 template <
typename T>
145 static void unroll(
ArithmeticOperation op,
const SimpleTensor<T> &src1,
const SimpleTensor<T> &src2, SimpleTensor<T> &
dst,
146 ConvertPolicy convert_policy, Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
153 template <
typename T>
156 Coordinates id_src1{};
157 Coordinates id_src2{};
158 Coordinates id_dst{};
160 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2,
dst, convert_policy, id_src1, id_src2, id_dst);
174 Coordinates id_src1{};
175 Coordinates id_src2{};
176 Coordinates id_dst{};
178 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1_tmp, src2_tmp, dst_tmp, convert_policy, id_src1, id_src2, id_dst);
180 dst = convert_to_asymmetric<uint8_t>(dst_tmp,
dst.quantization_info());
186 Coordinates id_src1{};
187 Coordinates id_src2{};
188 Coordinates id_dst{};
190 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2,
dst, convert_policy, id_src1, id_src2, id_dst);
204 Coordinates id_src1{};
205 Coordinates id_src2{};
206 Coordinates id_dst{};
208 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1_tmp, src2_tmp, dst_tmp, convert_policy, id_src1, id_src2, id_dst);
210 dst = convert_to_asymmetric<int8_t>(dst_tmp,
dst.quantization_info());
216 Coordinates id_src1{};
217 Coordinates id_src2{};
218 Coordinates id_dst{};
220 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2,
dst, convert_policy, id_src1, id_src2, id_dst);
231 SimpleTensor<float> src1_tmp = convert_from_symmetric<int16_t>(src1);
232 SimpleTensor<float> src2_tmp = convert_from_symmetric<int16_t>(src2);
235 Coordinates id_src1{};
236 Coordinates id_src2{};
237 Coordinates id_dst{};
239 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1_tmp, src2_tmp, dst_tmp, convert_policy, id_src1, id_src2, id_dst);
241 dst = convert_to_symmetric<int16_t>(dst_tmp,
dst.quantization_info());
247 Coordinates id_src1{};
248 Coordinates id_src2{};
249 Coordinates id_dst{};
251 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2,
dst, convert_policy, id_src1, id_src2, id_dst);
262 template <
typename T>
268 arithmetic_operation<T>(op, src1, src2,
dst, convert_policy);