44 intermediate_type val = (op ==
ArithmeticOperation::ADD) ? static_cast<intermediate_type>(src1) + static_cast<intermediate_type>(src2) : static_cast<intermediate_type>
45 (src1) - static_cast<intermediate_type>(src2);
53 struct BroadcastUnroll
56 static void unroll(
ArithmeticOperation op,
const SimpleTensor<T> &src1,
const SimpleTensor<T> &src2, SimpleTensor<T> &
dst,
57 ConvertPolicy convert_policy, Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
59 const bool src1_is_broadcast = (src1.shape()[dim - 1] !=
dst.shape()[dim - 1]);
60 const bool src2_is_broadcast = (src2.shape()[dim - 1] !=
dst.shape()[dim - 1]);
62 id_src1.set(dim - 1, 0);
63 id_src2.set(dim - 1, 0);
64 id_dst.set(dim - 1, 0);
66 #pragma omp parallel for 68 for(
size_t i = 0; i <
dst.shape()[dim - 1]; ++i)
70 BroadcastUnroll < dim - 1 >::unroll(op, src1, src2,
dst, convert_policy, id_src1, id_src2, id_dst);
72 id_src1[dim - 1] += !src1_is_broadcast;
73 id_src2[dim - 1] += !src2_is_broadcast;
80 struct BroadcastUnroll<0>
83 static void unroll(
ArithmeticOperation op,
const SimpleTensor<T> &src1,
const SimpleTensor<T> &src2, SimpleTensor<T> &
dst,
84 ConvertPolicy convert_policy, Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
98 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2,
dst, convert_policy, id_src1, id_src2, id_dst);
116 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1_tmp, src2_tmp, dst_tmp, convert_policy, id_src1, id_src2, id_dst);
118 dst = convert_to_asymmetric<uint8_t>(dst_tmp,
dst.quantization_info());
124 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2,
dst, convert_policy, id_src1, id_src2, id_dst);
143 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1_tmp, src2_tmp, dst_tmp, convert_policy, id_src1, id_src2, id_dst);
145 dst = convert_to_asymmetric<int8_t>(dst_tmp,
dst.quantization_info());
151 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2,
dst, convert_policy, id_src1, id_src2, id_dst);
170 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1_tmp, src2_tmp, dst_tmp, convert_policy, id_src1, id_src2, id_dst);
172 dst = convert_to_symmetric<int16_t>(dst_tmp,
dst.quantization_info());
178 BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2,
dst, convert_policy, id_src1, id_src2, id_dst);
187 template <
typename T>
193 arithmetic_operation<T>(op, src1, src2,
dst, convert_policy);
bool is_data_type_quantized(DataType dt)
Check if a given data type is of quantized type.
ArithmeticOperation
Available element-wise operations.
quantized, symmetric fixed-point 16-bit number
static TensorShape broadcast_shape(const Shapes &... shapes)
If shapes are broadcast compatible, return the broadcasted shape.
SimpleTensor< float > convert_from_asymmetric(const SimpleTensor< uint8_t > &src)
TensorShape shape() const override
Shape of the tensor.
SimpleTensor< T > arithmetic_operation(ArithmeticOperation op, const SimpleTensor< T > &src1, const SimpleTensor< T > &src2, SimpleTensor< T > &dst, ConvertPolicy convert_policy)
Copyright (c) 2017-2021 Arm Limited.
int coord2index(const TensorShape &shape, const Coordinates &coord)
Linearise the given coordinate.
quantized, asymmetric fixed-point 8-bit number unsigned
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
Simple tensor object that stores elements in a consecutive chunk of memory.
ArithmeticOperation
Arithmetic operation types.
quantized, asymmetric fixed-point 8-bit number signed
DataType
Available data types.
ConvertPolicy
Policy to handle overflow.