24 #ifndef ARM_COMPUTE_NEFUSEBATCHNORMALIZATION_H 25 #define ARM_COMPUTE_NEFUSEBATCHNORMALIZATION_H 35 class NEFuseBatchNormalizationKernel;
69 const ITensor *input_bias =
nullptr,
const ITensor *bn_beta =
nullptr,
const ITensor *bn_gamma =
nullptr,
97 std::unique_ptr<NEFuseBatchNormalizationKernel> _fuse_bn_kernel;
void configure(const ITensor *input_weights, const ITensor *bn_mean, const ITensor *bn_var, ITensor *fused_weights, ITensor *fused_bias, const ITensor *input_bias=nullptr, const ITensor *bn_beta=nullptr, const ITensor *bn_gamma=nullptr, float epsilon=0.001f, FuseBatchNormalizationType fbn_type=FuseBatchNormalizationType::CONVOLUTION)
Set the input and output tensors.
Base class for all functions.
NEFuseBatchNormalization()
Default constructor.
void run() override
Run the kernels contained in the function.
NEFuseBatchNormalization & operator=(const NEFuseBatchNormalization &)=delete
Prevent instances of this class from being copied (as this class contains pointers).
Store the tensor's metadata.
Interface for a Neon tensor.
Basic function to fuse the batch normalization node into a preceding convolution node.
Copyright (c) 2017-2021 Arm Limited.
FuseBatchNormalizationType
Available FuseBatchNormalizationType values.
~NEFuseBatchNormalization()
Default destructor.
static Status validate(const ITensorInfo *input_weights, const ITensorInfo *bn_mean, const ITensorInfo *bn_var, const ITensorInfo *fused_weights, const ITensorInfo *fused_bias, const ITensorInfo *input_bias=nullptr, const ITensorInfo *bn_beta=nullptr, const ITensorInfo *bn_gamma=nullptr, float epsilon=0.001f, FuseBatchNormalizationType fbn_type=FuseBatchNormalizationType::CONVOLUTION)
Static function to check whether the given info will lead to a valid configuration of NEFuseBatchNormalization.