24.02.1
|
Go to the documentation of this file.
42 template <bool IS_LOG>
43 struct CLSoftmaxLayerGeneric<IS_LOG>::Impl
45 const ICLTensor *src{nullptr};
46 ICLTensor *dst{nullptr};
47 std::unique_ptr<OperatorType> op{nullptr};
48 MemoryGroup memory_group{};
49 ITensorPack run_pack{};
50 WorkspaceData<CLTensor> workspace_tensors{};
53 template <bool IS_LOG>
55 : _impl(std::make_unique<Impl>())
57 _impl->memory_group = MemoryGroup(std::move(memory_manager));
60 template <bool IS_LOG>
63 template <bool IS_LOG>
69 template <bool IS_LOG>
75 _impl->op = std::make_unique<OperatorType>();
78 _impl->op->configure(compile_context, *input->info(), *output->info(), softmax_info);
81 _impl->workspace_tensors = manage_workspace<CLTensor>(_impl->op->workspace(), _impl->memory_group, _impl->run_pack);
84 template <bool IS_LOG>
92 template <bool IS_LOG>
98 _impl->op->run(_impl->run_pack);
Descriptor used by the softmax kernels.
im2col_func configure(src_target.info(), dst_target.info(), spatial_kernel, conv_info, has_bias)
SimpleTensor< float > src
void run() override
Run the kernels contained in the function.
Interface for OpenCL tensor.
CLSoftmaxLayerGeneric(std::shared_ptr< IMemoryManager > memory_manager=nullptr)
Constructor.
static CLKernelLibrary & get()
Access the KernelLibrary singleton.
Manages all the OpenCL kernels compilation and caching, provides accessors for the OpenCL Context.
void configure(const ICLTensor *input, ICLTensor *output, float beta=1.0f, int32_t axis=0)
Set the input and output tensors.
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
virtual ITensorInfo * info() const =0
Interface to be implemented by the child class to return the tensor's metadata.
~CLSoftmaxLayerGeneric()
Default destructor.
Basic function to compute a SoftmaxLayer.
opencl::ClGemm OperatorType
static Status validate(const ITensorInfo *input, const ITensorInfo *output, float beta=1.0f, int32_t axis=0)
Static function to check if given info will lead to a valid configuration of CLSoftmaxLayer.
Memory group resources scope handling class.
Copyright (c) 2017-2024 Arm Limited.
Store the tensor's metadata.
static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
Static function to check if given info will lead to a valid configuration.