24.02.1
|
Go to the documentation of this file.
39 : _inst_norm_kernel(), _mean_var_kernel(), _mean_var_tensor(), _ctx(ctx)
58 bool use_mixed_precision)
61 auto w = std::make_unique<CLComputeMeanVariance>();
62 w->configure(compile_context,
input, &_mean_var_tensor, use_mixed_precision);
63 _mean_var_kernel = std::move(
w);
64 auto k = std::make_unique<CLInstanceNormalizationLayerKernel>();
65 k->configure(compile_context,
input, &_mean_var_tensor, output,
67 _inst_norm_kernel = std::move(k);
76 bool use_mixed_precision)
85 "The child class didn't set the CL kernel or function isn't configured");
void schedule_kernel_on_ctx(CLRuntimeContext *ctx, ICLKernel *kernel, bool flush=true)
Schedules a kernel using the context if it is not nullptr; otherwise uses the legacy scheduling flow.
~CLInstanceNormalizationLayer()
Default destructor.
Interface for OpenCL tensor.
void configure(ICLTensor *input, ICLTensor *output, float gamma=1.0f, float beta=0.0f, float epsilon=1e-12f, bool use_mixed_precision=true)
Set the input and output tensors.
static CLKernelLibrary & get()
Access the KernelLibrary singleton.
CLInstanceNormalizationLayer(CLRuntimeContext *ctx=nullptr)
Constructor.
void run() override
Run the kernels contained in the function.
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
Interface to enqueue OpenCL kernels and get/set the OpenCL CommandQueue and ICLTuner.
static Status validate(const ITensorInfo *input, const ITensorInfo *output, float gamma=1.0f, float beta=0.0f, float epsilon=1e-12f, bool use_mixed_precision=true)
Static function to check if given info will lead to a valid configuration of CLInstanceNormalizationLayer.
void allocate() override
Allocate OpenCL memory of the size specified by TensorInfo.
static Status validate(const ITensorInfo *input, const ITensorInfo *output, const InstanceNormalizationLayerKernelInfo &info)
Static function to check if given info will lead to a valid configuration of CLInstanceNormalizationL...
CLTensorAllocator * allocator()
Return a pointer to the tensor's allocator.
Copyright (c) 2017-2024 Arm Limited.
Store the tensor's metadata.
#define ARM_COMPUTE_LOG_PARAMS(...)