44 intermediate_type val;
50 val =
static_cast<intermediate_type
>(src1) + static_cast<intermediate_type>(src2);
55 val =
static_cast<intermediate_type
>(src1) - static_cast<intermediate_type>(src2);
60 val = std::min(static_cast<intermediate_type>(src1), static_cast<intermediate_type>(src2));
65 val = std::max(static_cast<intermediate_type>(src1), static_cast<intermediate_type>(src2));
70 intermediate_type tmp = (
static_cast<intermediate_type
>(src1) - static_cast<intermediate_type>(src2));
76 val = (
static_cast<intermediate_type
>(src1) / static_cast<intermediate_type>(src2));
77 if(std::is_integral<T>::value)
80 val = (src2 == 0) ? 0 : val;
81 if(static_cast<int32_t>(src1) %
static_cast<int32_t
>(src2) != 0 && ((src1 < 0) != (src2 < 0)))
90 val = std::pow(static_cast<intermediate_type>(src1), static_cast<intermediate_type>(src2));
95 const T x =
static_cast<intermediate_type
>(src1);
96 const T alpha =
static_cast<intermediate_type
>(src2);
97 val = (x > 0 ? x : alpha * x);
112 result =
static_cast<T
>(val);
/** Recursive broadcast iterator: walks dimension (dim - 1) of @p dst and, for every
 *  index along it, delegates the remaining lower dimensions to BroadcastUnroll<dim - 1>.
 *
 *  Broadcasting is implemented by not advancing an input's coordinate along this
 *  dimension when that input's extent differs from dst's extent. The same input
 *  element (index 0) is then reused for every dst index along this dimension.
 *  NOTE(review): presumably a broadcast input has extent 1 here (checked by the
 *  caller, e.g. via broadcast_shape) — this code itself only tests for inequality.
 *
 *  @tparam dim Number of dimensions still to be iterated; this level handles index dim - 1.
 */
117 template <
size_t dim>
118 struct BroadcastUnroll
/** Iterate dimension (dim - 1) and recurse into the lower dimensions.
 *
 *  @param[in]     op             Arithmetic operation, forwarded unchanged to the recursive call.
 *  @param[in]     src1           First input tensor.
 *  @param[in]     src2           Second input tensor.
 *  @param[out]    dst            Destination tensor; its extent along dim - 1 sets the loop count.
 *  @param[in]     convert_policy Overflow policy, forwarded unchanged to the recursive call.
 *  @param[in,out] id_src1        Coordinate cursor into src1; entry dim - 1 is reset to 0, then advanced here.
 *  @param[in,out] id_src2        Coordinate cursor into src2; entry dim - 1 is reset to 0, then advanced here.
 *  @param[in,out] id_dst         Coordinate cursor into dst; entry dim - 1 is reset to 0, then advanced here.
 */
120 template <
typename T>
121 static void unroll(
ArithmeticOperation op,
const SimpleTensor<T> &src1,
const SimpleTensor<T> &src2, SimpleTensor<T> &
dst,
122 ConvertPolicy convert_policy, Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
// An input counts as broadcast along this dimension when its extent differs from dst's.
124 const bool src1_is_broadcast = (src1.shape()[dim - 1] != dst.shape()[dim - 1]);
125 const bool src2_is_broadcast = (src2.shape()[dim - 1] != dst.shape()[dim - 1]);
// Restart this dimension's index on every call from the enclosing level.
127 id_src1.set(dim - 1, 0);
128 id_src2.set(dim - 1, 0);
129 id_dst.set(dim - 1, 0);
// id_dst advances unconditionally along this dimension; the inputs advance below.
131 for(
size_t i = 0; i < dst.shape()[dim - 1]; ++i, ++id_dst[dim - 1])
133 BroadcastUnroll < dim - 1 >::unroll(op, src1, src2, dst, convert_policy, id_src1, id_src2, id_dst);
// bool -> 0/1: a broadcast input stays at index 0, while a full-size input moves in lock-step with dst.
135 id_src1[dim - 1] += !src1_is_broadcast;
136 id_src2[dim - 1] += !src2_is_broadcast;
/** Terminal case of the BroadcastUnroll recursion (dim == 0).
 *  It is reached once every dimension has been fixed in id_src1 / id_src2 / id_dst
 *  by the enclosing BroadcastUnroll<dim> levels.
 *  NOTE(review): this specialization's body is not present in this listing.
 *  Presumably it computes the single dst element at id_dst from the src1/src2
 *  elements at id_src1/id_src2 — confirm against the full source.
 */
142 struct BroadcastUnroll<0>
144 template <
typename T>
/** Per-element step.
 *  The parameters match BroadcastUnroll<dim>::unroll; here all coordinates are fully determined.
 */
145 static void unroll(
ArithmeticOperation op,
const SimpleTensor<T> &src1,
const SimpleTensor<T> &src2, SimpleTensor<T> &dst,
146 ConvertPolicy convert_policy, Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
153 template <
typename T>
156 Coordinates id_src1{};
157 Coordinates id_src2{};
158 Coordinates id_dst{};
160 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2, dst, convert_policy, id_src1, id_src2, id_dst);
174 Coordinates id_src1{};
175 Coordinates id_src2{};
176 Coordinates id_dst{};
178 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1_tmp, src2_tmp, dst_tmp, convert_policy, id_src1, id_src2, id_dst);
180 dst = convert_to_asymmetric<uint8_t>(dst_tmp, dst.quantization_info());
186 Coordinates id_src1{};
187 Coordinates id_src2{};
188 Coordinates id_dst{};
190 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2, dst, convert_policy, id_src1, id_src2, id_dst);
204 Coordinates id_src1{};
205 Coordinates id_src2{};
206 Coordinates id_dst{};
208 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1_tmp, src2_tmp, dst_tmp, convert_policy, id_src1, id_src2, id_dst);
210 dst = convert_to_asymmetric<int8_t>(dst_tmp, dst.quantization_info());
216 Coordinates id_src1{};
217 Coordinates id_src2{};
218 Coordinates id_dst{};
220 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2, dst, convert_policy, id_src1, id_src2, id_dst);
231 SimpleTensor<float> src1_tmp = convert_from_symmetric<int16_t>(src1);
232 SimpleTensor<float> src2_tmp = convert_from_symmetric<int16_t>(src2);
235 Coordinates id_src1{};
236 Coordinates id_src2{};
237 Coordinates id_dst{};
239 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1_tmp, src2_tmp, dst_tmp, convert_policy, id_src1, id_src2, id_dst);
241 dst = convert_to_symmetric<int16_t>(dst_tmp, dst.quantization_info());
247 Coordinates id_src1{};
248 Coordinates id_src2{};
249 Coordinates id_dst{};
251 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2, dst, convert_policy, id_src1, id_src2, id_dst);
262 template <
typename T>
268 arithmetic_operation<T>(op, src1, src2,
dst, convert_policy);
bool is_data_type_quantized(DataType dt)
Check if a given data type is of quantized type.
quantized, symmetric fixed-point 16-bit number
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
static TensorShape broadcast_shape(const Shapes &... shapes)
If shapes are broadcast compatible, return the broadcasted shape.
SimpleTensor< float > convert_from_asymmetric(const SimpleTensor< uint8_t > &src)
SimpleTensor< T > arithmetic_operation(ArithmeticOperation op, const SimpleTensor< T > &src1, const SimpleTensor< T > &src2, SimpleTensor< T > &dst, ConvertPolicy convert_policy)
Copyright (c) 2017-2021 Arm Limited.
int coord2index(const TensorShape &shape, const Coordinates &coord)
Linearise the given coordinate.
quantized, asymmetric fixed-point 8-bit unsigned number
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
y*x if x < 0, x otherwise
ArithmeticOperation
Arithmetic operation types.
quantized, asymmetric fixed-point 8-bit signed number
DataType
Available data types.
ConvertPolicy
Policy to handle overflow.