// Combines the two operands after widening both to intermediate_type:
// src1 + src2 when op is ArithmeticOperation::ADD, otherwise src1 - src2
// (any op other than ADD takes the subtraction branch).
// NOTE(review): the enclosing function and the definition of intermediate_type
// are not visible in this view — confirm how val is later narrowed/saturated
// (e.g. according to a convert policy) before being stored.
44 intermediate_type val = (op ==
ArithmeticOperation::ADD) ?
static_cast<intermediate_type
>(src1) +
static_cast<intermediate_type
>(src2) :
static_cast<intermediate_type
>
45 (src1) -
static_cast<intermediate_type
>(src2);
// Recursive helper that walks dimension (dim - 1) of dst and, for every index
// along it, recurses into BroadcastUnroll<dim - 1> to handle the lower dimensions.
// Broadcasting: an input whose extent along (dim - 1) differs from dst's is not
// advanced along that dimension, so its coordinate stays at 0 for every dst index.
// NOTE(review): the template header (e.g. the declaration of dim and T) is not
// visible in this view.
53 struct BroadcastUnroll
// Parameters: op selects the arithmetic operation; src1/src2 are the inputs;
// dst receives the result; convert_policy is forwarded unchanged to the recursion;
// id_src1/id_src2/id_dst are in/out coordinates, mutated in place at index (dim - 1).
56 static void unroll(
ArithmeticOperation op,
const SimpleTensor<T> &src1,
const SimpleTensor<T> &src2, SimpleTensor<T> &
dst,
57 ConvertPolicy convert_policy, Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
// An input is treated as broadcast along (dim - 1) when its extent there
// differs from dst's extent.
59 const bool src1_is_broadcast = (src1.shape()[dim - 1] !=
dst.shape()[dim - 1]);
60 const bool src2_is_broadcast = (src2.shape()[dim - 1] !=
dst.shape()[dim - 1]);
// Restart all three coordinates at 0 along this dimension.
62 id_src1.set(dim - 1, 0);
63 id_src2.set(dim - 1, 0);
64 id_dst.set(dim - 1, 0);
// NOTE(review): possible data race / ordering bug. Every iteration shares the same
// id_src1/id_src2/id_dst objects: they are reference parameters, passed by reference
// into the recursive call and incremented below. Correctness also depends on the
// iterations running in order, because the coordinate for iteration i is the running
// sum of the previous increments. Under `omp parallel for` neither holds.
// Consider deriving per-iteration coordinate copies from i, or dropping the pragma.
// TODO confirm whether OpenMP is enabled in the builds that use this.
66 #pragma omp parallel for
68 for(
size_t i = 0; i <
dst.shape()[dim - 1]; ++i)
70 BroadcastUnroll < dim - 1 >::unroll(op, src1, src2,
dst, convert_policy, id_src1, id_src2, id_dst);
// Advance a non-broadcast input by 1 (bool -> 1); a broadcast input gets +0 and
// stays at index 0.
// NOTE(review): the matching advance of id_dst along (dim - 1) is not shown in
// this view — confirm it exists in the full loop body.
72 id_src1[dim - 1] += !src1_is_broadcast;
73 id_src2[dim - 1] += !src2_is_broadcast;
// Terminal specialization of the BroadcastUnroll recursion (dim == 0), reached
// through the BroadcastUnroll<dim - 1> calls of the general template. It takes
// the same parameter list as the general case.
// NOTE(review): its body (presumably the per-element computation at the current
// coordinates) is not visible in this view — confirm.
80 struct BroadcastUnroll<0>
83 static void unroll(
ArithmeticOperation op,
const SimpleTensor<T> &src1,
const SimpleTensor<T> &src2, SimpleTensor<T> &
dst,
84 ConvertPolicy convert_policy, Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
// Entry into the broadcast recursion at the highest supported rank
// (Coordinates::num_max_dimensions), operating directly on src1/src2/dst.
98 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2,
dst, convert_policy, id_src1, id_src2, id_dst);
// Path 1: run the recursion on temporaries src1_tmp/src2_tmp/dst_tmp, then write
// dst back from dst_tmp via convert_to_asymmetric<uint8_t> using dst's own
// quantization info.
// NOTE(review): the construction of the *_tmp tensors is not visible here;
// presumably they are dequantized copies — verify.
116 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1_tmp, src2_tmp, dst_tmp, convert_policy, id_src1, id_src2, id_dst);
118 dst = convert_to_asymmetric<uint8_t>(dst_tmp,
dst.quantization_info());
// Path 2: operate directly on src1/src2/dst. The condition that selects this
// path is not visible in this view.
124 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2,
dst, convert_policy, id_src1, id_src2, id_dst);
// Path 1: run the recursion on temporaries, then write dst back from dst_tmp via
// convert_to_asymmetric<int8_t> using dst's quantization info.
// NOTE(review): the *_tmp tensors are created outside this view — presumably
// dequantized copies; verify.
143 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1_tmp, src2_tmp, dst_tmp, convert_policy, id_src1, id_src2, id_dst);
145 dst = convert_to_asymmetric<int8_t>(dst_tmp,
dst.quantization_info());
// Path 2: operate directly on src1/src2/dst. The selecting condition is not
// visible in this view.
151 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2,
dst, convert_policy, id_src1, id_src2, id_dst);
// Path 1: run the recursion on temporaries, then write dst back from dst_tmp via
// convert_to_symmetric<int16_t> (symmetric, unlike the 8-bit paths) using dst's
// quantization info.
// NOTE(review): the *_tmp tensors are created outside this view — presumably
// dequantized copies; verify.
170 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1_tmp, src2_tmp, dst_tmp, convert_policy, id_src1, id_src2, id_dst);
172 dst = convert_to_symmetric<int16_t>(dst_tmp,
dst.quantization_info());
// Path 2: operate directly on src1/src2/dst. The selecting condition is not
// visible in this view.
178 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2,
dst, convert_policy, id_src1, id_src2, id_dst);
// Generic template that forwards op, src1, src2, dst and convert_policy
// unchanged to arithmetic_operation<T>.
// NOTE(review): the function's own name and signature, how dst is created, and
// what is returned are not visible in this view — confirm against the full file.
187 template <
typename T>
193 arithmetic_operation<T>(op, src1, src2,
dst, convert_policy);