24.02.1
#ifndef ARM_COMPUTE_NEFUSEBATCHNORMALIZATION_H
#define ARM_COMPUTE_NEFUSEBATCHNORMALIZATION_H
class NEFuseBatchNormalizationKernel;
const ITensor *input_bias = nullptr,
const ITensor *bn_beta = nullptr,
const ITensor *bn_gamma = nullptr,
std::unique_ptr<NEFuseBatchNormalizationKernel> _fuse_bn_kernel;
FuseBatchNormalizationType fbn_type
IFunction
Base class for all functions.
void run() override
Run the kernels contained in the function.
ITensor
Interface for CPU tensor.
FuseBatchNormalizationType
Available FuseBatchNormalizationType.
NEFuseBatchNormalization & operator=(const NEFuseBatchNormalization &)=delete
Prevent instances of this class from being copied (as this class contains pointers).
CONVOLUTION
For Convolution weights.
static Status validate(const ITensorInfo *input_weights, const ITensorInfo *bn_mean, const ITensorInfo *bn_var, const ITensorInfo *fused_weights, const ITensorInfo *fused_bias, const ITensorInfo *input_bias=nullptr, const ITensorInfo *bn_beta=nullptr, const ITensorInfo *bn_gamma=nullptr, float epsilon=0.001f, FuseBatchNormalizationType fbn_type=FuseBatchNormalizationType::CONVOLUTION)
Static function to check if given info will lead to a valid configuration of NEFuseBatchNormalization...
NEFuseBatchNormalization()
Default constructor.
Copyright (c) 2017-2024 Arm Limited.
NEFuseBatchNormalization
Basic function to fuse the batch normalization node to a preceding convolution node.
void configure(const ITensor *input_weights, const ITensor *bn_mean, const ITensor *bn_var, ITensor *fused_weights, ITensor *fused_bias, const ITensor *input_bias=nullptr, const ITensor *bn_beta=nullptr, const ITensor *bn_gamma=nullptr, float epsilon=0.001f, FuseBatchNormalizationType fbn_type=FuseBatchNormalizationType::CONVOLUTION)
Set the input and output tensors.
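The arithmetic behind this fusion can be illustrated with a minimal, library-independent sketch. It is not the Arm Compute Library implementation: `fuse_batch_norm_reference` and `FusedParams` are hypothetical names, the weights are assumed to be flattened as `[out_channels][elements_per_channel]`, and omitted optional inputs are assumed to default to bias 0, beta 0, and gamma 1. Each output channel is scaled by `gamma / sqrt(var + epsilon)`, and the bias absorbs the mean and beta terms.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical container for the fused result (not an Arm Compute Library type).
struct FusedParams
{
    std::vector<float> weights; // same layout as the input weights
    std::vector<float> bias;    // one value per output channel
};

// Hypothetical reference helper: folds per-channel batch normalization
// statistics into the weights and bias of the preceding convolution.
// Assumption: weights are flattened as [out_channels][elements_per_channel].
// Assumption: a null bias/beta is treated as 0 and a null gamma as 1.
FusedParams fuse_batch_norm_reference(const std::vector<float> &weights,
                                      const std::vector<float> &mean,
                                      const std::vector<float> &var,
                                      const std::vector<float> *bias    = nullptr,
                                      const std::vector<float> *beta    = nullptr,
                                      const std::vector<float> *gamma   = nullptr,
                                      float                     epsilon = 0.001f)
{
    const std::size_t out_channels = mean.size();
    const std::size_t per_channel  = out_channels != 0 ? weights.size() / out_channels : 0;

    FusedParams out{ std::vector<float>(weights.size()), std::vector<float>(out_channels) };

    for (std::size_t c = 0; c < out_channels; ++c)
    {
        const float g     = gamma != nullptr ? (*gamma)[c] : 1.0f;
        const float b     = beta != nullptr ? (*beta)[c] : 0.0f;
        const float in_b  = bias != nullptr ? (*bias)[c] : 0.0f;
        const float scale = g / std::sqrt(var[c] + epsilon);

        // Scale every weight that contributes to output channel c.
        for (std::size_t i = 0; i < per_channel; ++i)
        {
            out.weights[c * per_channel + i] = weights[c * per_channel + i] * scale;
        }
        // Fold the mean shift and beta into the bias.
        out.bias[c] = (in_b - mean[c]) * scale + b;
    }
    return out;
}
```

With the fused weights and bias, a single convolution produces the same result as the convolution followed by batch normalization at inference time, which is why the fusion only needs to run once before inference.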
ITensorInfo
Store the tensor's metadata.
~NEFuseBatchNormalization()
Default destructor.