39 template <
bool IS_LOG>
41 : _memory_group(
std::move(memory_manager)),
55 template <
bool IS_LOG>
58 template <
bool IS_LOG>
64 template <
bool IS_LOG>
73 _needs_permute = actual_axis != 0;
75 const ICLTensor *tmp_input = _needs_permute ? &_input_permuted :
input;
78 _memory_group.
manage(&_input_permuted);
79 _memory_group.
manage(&_output_permuted);
81 tmp_output = &_output_permuted;
89 max_sum_shape.
set(0, 1);
91 _sum.
allocator()->
init(tmp_input->
info()->
clone()->set_tensor_shape(max_sum_shape).set_data_type(tmp_data_type));
97 _memory_group.
manage(&_tmp);
98 _memory_group.
manage(&_max);
99 _memory_group.
manage(&_sum);
102 softmax_info.
beta = beta;
103 softmax_info.
is_log = IS_LOG;
107 _max_shift_exp_sum_kernel->configure(compile_context, tmp_input, &_max, &_tmp, &_sum, softmax_info);
108 _norm_kernel->configure(compile_context, &_tmp, &_sum, tmp_output, softmax_info);
122 template <
bool IS_LOG>
131 const bool needs_permute = actual_axis != 0;
136 TensorInfo input_permuted(input->
clone()->set_tensor_shape(permuted_shape));
138 TensorInfo output_permuted(output->
clone()->set_tensor_shape(permuted_shape));
144 TensorInfo tensor_info_tmp(input->
clone()->set_data_type(tmp_data_type).set_is_resizable(
true));
147 max_sum_shape.
set(0, 1);
148 TensorInfo tensor_info_max(input->
clone()->set_tensor_shape(max_sum_shape).set_is_resizable(
true));
149 TensorInfo tensor_info_sum(input->
clone()->set_tensor_shape(max_sum_shape).set_data_type(tmp_data_type).set_quantization_info(
QuantizationInfo()).set_is_resizable(
true));
152 softmax_info.
beta = beta;
153 softmax_info.is_log = IS_LOG;
154 softmax_info.input_data_type = input->
data_type();
162 template <
bool IS_LOG>
169 _permute_input.
run();
177 _permute_output.
run();
Interface for computing the max, then shifting, exponentiating and summing the logits.
void run() override
Run the kernels contained in the function.
virtual size_t num_dimensions() const =0
The number of dimensions of the tensor (rank)
TensorShape compute_permutation_output_shape(const ITensorInfo &input, const PermutationVector &perm)
Calculate the permuted shape of an input given a permutation vector.
CLSoftmaxLayerGeneric(std::shared_ptr< IMemoryManager > memory_manager=nullptr)
Constructor.
PermutationVector get_permutation_vector_from_softmax_axis(size_t axis)
Given a softmax axis, this function returns the permutation vector required to put the axis to the front.
static Status validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output, const SoftmaxKernelInfo &info)
Static function to check if given info will lead to a valid configuration of CLLogits1DNormKernel.
static CLScheduler & get()
Access the scheduler singleton.
float beta
A scaling factor for the exponent with default value 1.0.
void configure(const ICLTensor *input, ICLTensor *output, float beta=1.0f, int32_t axis=0)
Set the input and output tensors.
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
virtual DataType data_type() const =0
Data type used for each element of the tensor.
static CLKernelLibrary & get()
Access the KernelLibrary singleton.
Store the tensor's metadata.
CLTensorAllocator * allocator()
Return a pointer to the tensor's allocator.
#define ARM_COMPUTE_ERROR_THROW_ON(status)
#define ARM_COMPUTE_RETURN_ERROR_ON(cond)
If the condition is true, an error is returned.
void init(const TensorInfo &input, size_t alignment=0)
Initialize a tensor based on the passed TensorInfo.
Copyright (c) 2017-2021 Arm Limited.
void run() override
Run the kernels contained in the function.
#define ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(...)
1 channel, 1 S32 per channel
void manage(IMemoryManageable *obj) override
Sets an object to be managed by the given memory group.
Basic function to compute a SoftmaxLayer.
Interface to enqueue OpenCL kernels and get/set the OpenCL CommandQueue and ICLTuner.
T wrap_around(T x, T m)
Wrap a number around so that it lies within the range 0 <= x < m.
Quantization information.
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
virtual const TensorShape & tensor_shape() const =0
Size for each dimension of the tensor.
DataType input_data_type
Input tensor data type.
bool is_log
Flag used to perform Log Softmax operation.
Interface for calculating the final step of the Softmax Layer, where each logit value is multiplied by the inverse of the sum of the logits.
virtual std::unique_ptr< T > clone() const =0
Provide a clone of the current object of class T.
virtual ITensorInfo * info() const =0
Interface to be implemented by the child class to return the tensor's metadata.
static Status validate(const ITensorInfo *input, const ITensorInfo *output, float beta=1.0f, int32_t axis=0)
Static function to check if given info will lead to a valid configuration of CLSoftmaxLayer.
void enqueue(ICLKernel &kernel, bool flush=true)
Schedule the execution of the passed kernel if possible.
bool is_data_type_quantized_asymmetric(DataType dt)
Check if a given data type is of asymmetric quantized type.
Strides of an item in bytes.
void allocate() override
Allocate size specified by TensorInfo of OpenCL memory.
Memory group resources scope handling class.
Interface for OpenCL tensor.
static Status validate(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum)
Static function to check if given info will lead to a valid configuration of CLLogits1DMaxShiftExpSumKernel.
#define ARM_COMPUTE_RETURN_ERROR_ON_MSG(cond, msg)
If the condition is true, an error is returned.
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
Store the tensor's metadata.
void configure(const ICLTensor *input, ICLTensor *output, const PermutationVector &perm)
Set the input and output tensors.
static Status validate(const ITensorInfo *input, const ITensorInfo *output, const PermutationVector &perm)
Static function to check if given info will lead to a valid configuration of CLPermute.
Descriptor used by the softmax kernels.
DataType
Available data types.
TensorShape & set(size_t dimension, size_t value, bool apply_dim_correction=true, bool increase_dim_unit=true)
Accessor to set the value of one of the dimensions.
~CLSoftmaxLayerGeneric()
Default destructor.