24.02.1
|
Go to the documentation of this file.
47 struct FuseBatchNormalizeSelectorData
52 cpuinfo::CpuIsaInfo
isa;
55 using FBNSelectorPtr = std::add_pointer<bool(
const FuseBatchNormalizeSelectorData &data)>
::type;
56 using FBNUKernelPtr = std::add_pointer<void(
const ITensor *,
74 static const FBNUKernel available_kernels[] = {
75 {
"fused_batch_normalization_conv_NHWC_F16",
76 [](
const FuseBatchNormalizeSelectorData &data)
82 {
"fused_batch_normalization_conv_NCHW_F16",
83 [](
const FuseBatchNormalizeSelectorData &data)
89 {
"fused_batch_normalization_dwc_NHWC_F16",
90 [](
const FuseBatchNormalizeSelectorData &data)
96 {
"fused_batch_normalization_dwc_NCHW_F16",
97 [](
const FuseBatchNormalizeSelectorData &data)
103 {
"fused_batch_normalization_conv_NHWC_F32",
104 [](
const FuseBatchNormalizeSelectorData &data)
110 {
"fused_batch_normalization_conv_NCHW_F32",
111 [](
const FuseBatchNormalizeSelectorData &data)
117 {
"fused_batch_normalization_dwc_NHWC_F32",
118 [](
const FuseBatchNormalizeSelectorData &data)
124 {
"fused_batch_normalization_dwc_NCHW_F32",
125 [](
const FuseBatchNormalizeSelectorData &data)
140 const FBNUKernel *get_implementation(
const FuseBatchNormalizeSelectorData &data)
142 for (
const auto &uk : available_kernels)
144 if (uk.is_selected(data))
153 const ITensorInfo *bn_mean,
154 const ITensorInfo *bn_var,
155 const ITensorInfo *fused_weights,
156 const ITensorInfo *fused_bias,
157 const ITensorInfo *input_bias,
158 const ITensorInfo *bn_beta,
159 const ITensorInfo *bn_gamma,
183 if (input_bias !=
nullptr)
189 if (bn_beta !=
nullptr)
195 if (bn_gamma !=
nullptr)
202 if (fused_weights !=
nullptr && fused_weights->total_size() != 0)
209 if (fused_bias !=
nullptr && fused_bias->total_size() != 0)
221 : _input_weights(nullptr),
222 _input_bias(nullptr),
227 _fused_weights(nullptr),
228 _fused_bias(nullptr),
230 _run_in_place_weights(false),
231 _run_in_place_bias(false),
249 _input_weights = input_weights;
250 _input_bias = input_bias;
254 _bn_gamma = bn_gamma;
255 _fused_weights = fused_weights;
256 _fused_bias = fused_bias;
259 _run_in_place_weights = (fused_weights ==
nullptr) || (fused_weights == input_weights);
260 _run_in_place_bias = (fused_bias ==
nullptr) || (input_bias !=
nullptr && fused_bias == input_bias);
263 if (_fused_weights !=
nullptr)
268 if (_fused_bias !=
nullptr)
277 (fused_weights !=
nullptr) ? fused_weights->
info() :
nullptr,
278 (fused_bias !=
nullptr) ? fused_bias->
info() :
nullptr, (input_bias !=
nullptr) ? input_bias->
info() :
nullptr,
279 (bn_beta !=
nullptr) ? bn_beta->
info() :
nullptr, (bn_gamma !=
nullptr) ? bn_gamma->
info() :
nullptr,
epsilon,
282 const auto *uk = get_implementation(FuseBatchNormalizeSelectorData{
290 INEKernel::configure(win);
316 (*_func)(_input_weights, _input_bias, _fused_weights, _fused_bias, _bn_mean, _bn_var, _bn_beta, _bn_gamma, _epsilon,
@ NCHW
Num samples, channels, height, width.
virtual DataLayout data_layout() const =0
Get the data layout of the tensor.
decltype(strategy::transforms) typedef type
FuseBatchNormalizationType fbn_type
DataLayout
[DataLayout enum definition]
Window calculate_max_window(const ValidRegion &valid_region, const Steps &steps, bool skip_border, BorderSize border_size)
NEFuseBatchNormalizationKernel()
Default constructor.
void fused_batch_normalization_conv_f16(const ITensor *conv_weights, const ITensor *conv_bias, ITensor *fused_weights, ITensor *fused_bias, const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
@ NHWC
Num samples, height, width, channels.
static Status validate(const ITensorInfo *input_weights, const ITensorInfo *bn_mean, const ITensorInfo *bn_var, const ITensorInfo *fused_weights, const ITensorInfo *fused_bias, const ITensorInfo *input_bias=nullptr, const ITensorInfo *bn_beta=nullptr, const ITensorInfo *bn_gamma=nullptr, float epsilon=0.001f, FuseBatchNormalizationType fbn_type=FuseBatchNormalizationType::CONVOLUTION)
Static function to check if given info will lead to a valid configuration of NEFuseBatchNormalizationKernel.
Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const PadStrideInfo &conv_info)
#define ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(k)
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(...)
static CPUInfo & get()
Access the CPUInfo singleton.
Interface for CPU tensor.
#define REGISTER_FP16_NEON(func_name)
void fused_batch_normalization_dwc_nchw_f16(const ITensor *dwc_weights, const ITensor *dwc_bias, ITensor *fused_weights, ITensor *fused_bias, const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
Includes all wrapper headers at once.
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(...)
#define ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(t, c,...)
FuseBatchNormalizationType
Available FuseBatchNormalizationType.
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
#define REGISTER_FP32_NEON(func_name)
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
virtual ITensorInfo * info() const =0
Interface to be implemented by the child class to return the tensor's metadata.
void configure(const ITensor *input_weights, const ITensor *bn_mean, const ITensor *bn_var, ITensor *fused_weights, ITensor *fused_bias, const ITensor *input_bias=nullptr, const ITensor *bn_beta=nullptr, const ITensor *bn_gamma=nullptr, float epsilon=0.001f, FuseBatchNormalizationType fbn_type=FuseBatchNormalizationType::CONVOLUTION)
Set the source, destination of the kernel.
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true then an error message is printed and an exception thrown.
void fused_batch_normalization_conv_f32(const ITensor *conv_weights, const ITensor *conv_bias, ITensor *fused_weights, ITensor *fused_bias, const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
#define ARM_COMPUTE_ERROR_THROW_ON(status)
void run(const Window &window, const ThreadInfo &info) override
Execute the kernel on the passed window.
#define ARM_COMPUTE_RETURN_ERROR_ON(cond)
If the condition is true, an error is returned.
#define ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(tensor)
bool auto_init_if_empty(ITensorInfo &info, const TensorShape &shape, int num_channels, DataType data_type, QuantizationInfo quantization_info=QuantizationInfo())
Auto initialize the tensor info (shape, number of channels and data type) if the current assignment is empty.
@ CONVOLUTION
For Convolution weights.
#define ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(f, s)
void fused_batch_normalization_dwc_nhwc_f32(const ITensor *dwc_weights, const ITensor *dwc_bias, ITensor *fused_weights, ITensor *fused_bias, const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(...)
@ DEPTHWISECONVOLUTION
For Depthwise Convolution weights.
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
cpuinfo::CpuIsaInfo get_isa() const
Gets the current cpu's ISA information.
const Window & window() const
The maximum window the kernel can be executed on.
Information about executing thread and CPU.
virtual std::unique_ptr< T > clone() const =0
Provide a clone of the current object of class T.
size_t get_data_layout_dimension_index(const DataLayout &data_layout, const DataLayoutDimension &data_layout_dimension)
Get the index of the given dimension.
Describe a multidimensional execution window.
Copyright (c) 2017-2024 Arm Limited.
@ F16
16-bit floating-point number
#define ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(...)
Store the tensor's metadata.
void fused_batch_normalization_dwc_nhwc_f16(const ITensor *dwc_weights, const ITensor *dwc_bias, ITensor *fused_weights, ITensor *fused_bias, const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
@ F32
32-bit floating-point number
void fused_batch_normalization_dwc_nchw_f32(const ITensor *dwc_weights, const ITensor *dwc_bias, ITensor *fused_weights, ITensor *fused_bias, const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
ScaleKernelInfo info(interpolation_policy, default_border_mode, PixelValue(), sampling_policy, false)
DataType
Available data types.
const FBNSelectorPtr is_selected