49 struct ElementwiseKernel
56 template <DataType dt>
62 template <DataType input_data_type, DataType output_data_type = input_data_type>
63 static ElementwiseKernel generate_kernel(UKernelType *
ukernel)
70 return { kernel_name.c_str(), is_selected<input_data_type>, ukernel };
73 template <ArithmeticOperation op>
74 std::function<void(const ITensor *, const ITensor *, ITensor *, const Window &)>
75 configure_arithm_func(
const ITensorInfo *input1,
const ITensorInfo *input2, ITensorInfo *output)
78 static ElementwiseKernel kernels[] =
80 #if defined(__ARM_FEATURE_SVE) 81 generate_kernel<DataType::F32>(
REGISTER_FP32_SVE((arm_compute::cpu::sve::elementwise_arithmetic_op<op, float32_t>))),
82 generate_kernel<DataType::S32>(
REGISTER_INTEGER_SVE((arm_compute::cpu::sve::elementwise_arithmetic_op<op, int32_t>))),
87 #
if defined(__ARM_FEATURE_SVE2)
88 generate_kernel<DataType::QASYMM8>(
REGISTER_QASYMM8_SVE((arm_compute::cpu::sve::elementwise_arithmetic_quantized_op<op, uint8_t>))),
89 generate_kernel<DataType::QASYMM8_SIGNED>(
REGISTER_QASYMM8_SIGNED_SVE((arm_compute::cpu::sve::elementwise_arithmetic_quantized_op<op, int8_t>))),
91 generate_kernel<DataType::QASYMM8>(
REGISTER_QASYMM8_NEON((arm_compute::cpu::elementwise_arithm_op_quantized<op>))),
94 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
95 #
if defined(__ARM_FEATURE_SVE)
96 generate_kernel<DataType::F16>(
REGISTER_FP16_SVE((arm_compute::cpu::sve::elementwise_arithmetic_op<op, float16_t>))),
104 for(
const auto &uk : kernels)
106 if(uk.is_selected(input1->data_type()))
115 template <ComparisonOperation op>
116 std::function<void(const ITensor *input1, const ITensor *input2, ITensor *output, const Window &window)>
117 configure_comp_func(
const ITensorInfo *input1,
const ITensorInfo *input2, ITensorInfo *output)
120 static ElementwiseKernel kernels[] =
122 #if defined(__ARM_FEATURE_SVE) 123 generate_kernel<DataType::U8, DataType::U8>(
REGISTER_INTEGER_SVE((arm_compute::cpu::sve::elementwise_comparison_op<op, uint8_t>))),
124 generate_kernel<DataType::F32, DataType::U8>(
REGISTER_FP32_SVE((arm_compute::cpu::sve::elementwise_comparison_op<op, float>))),
125 generate_kernel<DataType::S16, DataType::U8>(
REGISTER_INTEGER_SVE((arm_compute::cpu::sve::elementwise_comparison_op<op, int16_t>))),
126 generate_kernel<DataType::S32, DataType::U8>(
REGISTER_INTEGER_SVE((arm_compute::cpu::sve::elementwise_comparison_op<op, int32_t>))),
128 generate_kernel<DataType::U8, DataType::U8>(
REGISTER_INTEGER_NEON((arm_compute::cpu::elementwise_comp_op_8<op, uint8_t, uint8x16_t>))),
129 generate_kernel<DataType::F32, DataType::U8>(
REGISTER_FP32_NEON((arm_compute::cpu::elementwise_comp_op_32<op, float, float32x4_t>))),
130 generate_kernel<DataType::S16, DataType::U8>(
REGISTER_INTEGER_NEON((arm_compute::cpu::elementwise_comp_op_16<op, int16_t, int16x8_t>))),
131 generate_kernel<DataType::S32, DataType::U8>(
REGISTER_INTEGER_NEON((arm_compute::cpu::elementwise_comp_op_32<op, int32_t, int32x4_t>))),
133 #
if defined(__ARM_FEATURE_SVE2)
134 generate_kernel<DataType::QASYMM8_SIGNED, DataType::U8>(
REGISTER_QASYMM8_SIGNED_SVE((arm_compute::cpu::sve::elementwise_comparison_quantized_op<op, int8_t>))),
135 generate_kernel<DataType::QASYMM8, DataType::U8>(
REGISTER_QASYMM8_SVE((arm_compute::cpu::sve::elementwise_comparison_quantized_op<op, uint8_t>))),
137 generate_kernel<DataType::QASYMM8_SIGNED, DataType::U8>(
REGISTER_QASYMM8_SIGNED_NEON((arm_compute::cpu::elementwise_comp_op_quantized_signed<op>))),
138 generate_kernel<DataType::QASYMM8, DataType::U8>(
REGISTER_QASYMM8_NEON((arm_compute::cpu::elementwise_comp_op_quantized<op>))),
140 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
141 #
if defined(__ARM_FEATURE_SVE)
142 generate_kernel<DataType::F16, DataType::U8>(
REGISTER_FP16_SVE((arm_compute::cpu::sve::elementwise_comparison_op<op, float16_t>))),
144 generate_kernel<DataType::F16, DataType::U8>(
REGISTER_FP16_NEON((arm_compute::cpu::elementwise_comp_op_16<op, float16_t, float16x8_t>))),
149 for(
const auto &uk : kernels)
151 if(uk.is_selected(input1->data_type()))
161 Status CpuElementwiseKernel::validate_arguments_common(
const ITensorInfo &input1,
const ITensorInfo &input2,
const ITensorInfo &output)
171 if(output.total_size() > 0)
174 "Wrong shape for output");
180 void CpuElementwiseKernel::configure_common(
const ITensorInfo *input1,
const ITensorInfo *input2, ITensorInfo *output)
186 const TensorShape &out_shape = broadcast_pair.first;
187 const ValidRegion &
valid_region = broadcast_pair.second;
194 ICpuKernel::configure(win);
207 auto function = get_implementation(src0->info(), src1->info(),
dst->info());
209 function(src0, src1,
dst, window);
216 configure_common(input1, input2, output);
228 return validate_arguments_common(input1, input2, output);
239 std::function<CpuElementwiseKernel::ElementwiseFunction>
245 return configure_arithm_func<ArithmeticOperation::MAX>(input1, input2, output);
247 return configure_arithm_func<ArithmeticOperation::MIN>(input1, input2, output);
249 return configure_arithm_func<ArithmeticOperation::SQUARED_DIFF>(input1, input2, output);
251 return configure_arithm_func<ArithmeticOperation::PRELU>(input1, input2, output);
253 return configure_arithm_func<ArithmeticOperation::DIV>(input1, input2, output);
255 return configure_arithm_func<ArithmeticOperation::POWER>(input1, input2, output);
267 configure_common(input1, input2, output);
274 return CpuArithmeticKernel::validate_arguments(input1, input2, output);
288 configure_common(input1, input2, output);
295 return CpuArithmeticKernel::validate_arguments(input1, input2, output);
309 configure_common(input1, input2, output);
321 return validate_arguments_common(input1, input2, output);
332 std::function<CpuElementwiseKernel::ElementwiseFunction>
338 return configure_comp_func<ComparisonOperation::Equal>(input1, input2, output);
340 return configure_comp_func<ComparisonOperation::NotEqual>(input1, input2, output);
342 return configure_comp_func<ComparisonOperation::Greater>(input1, input2, output);
344 return configure_comp_func<ComparisonOperation::GreaterEqual>(input1, input2, output);
346 return configure_comp_func<ComparisonOperation::Less>(input1, input2, output);
348 return configure_comp_func<ComparisonOperation::LessEqual>(input1, input2, output);
static Status validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
Static function to check if given info will lead to a valid configuration of CpuPowerKernel.
Window calculate_max_window(const ValidRegion &valid_region, const Steps &steps, bool skip_border, BorderSize border_size)
ArithmeticOperation
Available element-wise operations.
const Window & window() const
The maximum window the kernel can be executed on.
#define ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(tensor)
#define REGISTER_FP16_NEON(func_name)
void configure(const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output)
Configure kernel.
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
1 channel, 1 U8 per channel
#define REGISTER_FP32_NEON(func_name)
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
1 channel, 1 F32 per channel
#define REGISTER_FP32_SVE(func_name)
static TensorShape broadcast_shape(const Shapes &... shapes)
If shapes are broadcast compatible, return the broadcasted shape.
#define REGISTER_QASYMM8_SVE(func_name)
void configure(const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output)
Configure kernel.
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true then an error message is printed and an exception thrown.
#define REGISTER_QASYMM8_SIGNED_NEON(func_name)
Store the tensor's metadata.
#define ARM_COMPUTE_ERROR_THROW_ON(status)
const ValidRegion valid_region
static std::pair< TensorShape, ValidRegion > broadcast_shape_and_valid_region(const Infos &... infos)
If infos are broadcast-compatible tensor infos, return the broadcasted shape and the intersection of the broadcasted valid regions of the tensors.
decltype(strategy::transforms) typedef type
void configure(ComparisonOperation op, const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output)
Configure kernel.
Copyright (c) 2017-2021 Arm Limited.
1 channel, 1 F16 per channel
Greater equal comparison ( $x \geq y$ )
#define REGISTER_INTEGER_NEON(func_name)
#define ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(...)
#define REGISTER_QASYMM8_SIGNED_SVE(func_name)
1 channel, 1 S32 per channel
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override
Execute the kernel on the passed window.
VectorType::type elementwise_arithm_op(const typename VectorType::type &a, const typename VectorType::type &b)
const ITensor * get_const_tensor(int id) const
Get constant tensor of a given id.
const std::string & string_from_data_type(DataType dt)
Convert a data type identity into a string.
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
#define REGISTER_QASYMM8_NEON(func_name)
quantized, asymmetric fixed-point 8-bit number unsigned
#define REGISTER_INTEGER_SVE(func_name)
bool auto_init_if_empty(ITensorInfo &info, const TensorShape &shape, int num_channels, DataType data_type, QuantizationInfo quantization_info=QuantizationInfo())
Auto initialize the tensor info (shape, number of channels and data type) if the current assignment is empty.
bool have_different_dimensions(const Dimensions< T > &dim1, const Dimensions< T > &dim2, unsigned int upper_dim)
const ElementwiseSelector is_selected
ComparisonOperation
Supported comparison operations.
y*x if x < 0, x otherwise
#define ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(k)
1 channel, 1 S16 per channel
static Status validate(ComparisonOperation op, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
Static function to check if given info will lead to a valid configuration of cpu::kernels::CpuComparisonKernel.
ScaleKernelInfo info(interpolation_policy, default_border_mode, PixelValue(), sampling_policy, false)
ITensor * get_tensor(int id)
Get tensor of a given id from the pack.
Information about executing thread and CPU.
virtual size_t total_size() const =0
Returns the total size of the tensor in bytes.
#define REGISTER_FP16_SVE(func_name)
void(const ITensor *, const ITensor *, ITensor *, const Window &) ElementwiseFunction
Common signature for all the specialised arithmetic functions.
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(...)
#define ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(t, c,...)
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const GEMMLowpOutputStageInfo *output_stage)
Less equal comparison ( $x \leq y$ )
#define ARM_COMPUTE_RETURN_ERROR_ON_MSG(cond, msg)
If the condition is true, an error is returned.
static Status validate(ArithmeticOperation op, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
Static function to check if given info will lead to a valid configuration of cpu::kernels::CpuArithmeticKernel.
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
quantized, asymmetric fixed-point 8-bit number signed
static Status validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
Static function to check if given info will lead to a valid configuration of CpuDivisionKernel.
DataType
Available data types.
void configure(ArithmeticOperation op, const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output)
Configure kernel.
Describe a multidimensional execution window.
#define ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(f, s)