Compute Library
GCSoftmaxLayer Class Reference

Basic function to compute a SoftmaxLayer. More...

#include <GCSoftmaxLayer.h>


Public Member Functions

 GCSoftmaxLayer (std::shared_ptr< IMemoryManager > memory_manager=nullptr)
 Constructor. More...
void configure (const IGCTensor *input, IGCTensor *output, float beta=1.0f, int32_t axis=0)
 Set the input and output tensors. More...
void run () override
 Run the kernels contained in the function. More...
- Public Member Functions inherited from IFunction
virtual ~IFunction ()=default
 Destructor. More...
virtual void prepare ()
 Prepare the function for executing. More...

Detailed Description

Basic function to compute a SoftmaxLayer.

Softmax is calculated by :

\[ out = exp(x - max(x)) / sum(exp(x - max(x))) \]

This function runs the following kernels:

  1. GCLogits1DMaxKernel
  2. GCLogits1DShiftExpSumKernel
  3. GCLogits1DNormKernel
This function is deprecated and is intended to be removed in the 21.05 release.

Definition at line 49 of file GCSoftmaxLayer.h.

Constructor & Destructor Documentation

◆ GCSoftmaxLayer()

GCSoftmaxLayer (std::shared_ptr< IMemoryManager > memory_manager = nullptr)


Definition at line 32 of file GCSoftmaxLayer.cpp.

    : _memory_group(std::move(memory_manager)), _max_kernel(), _shift_exp_sum_kernel(), _norm_kernel(), _max(), _sum(), _tmp()
{
}

Member Function Documentation

◆ configure()

void configure (const IGCTensor * input,
                IGCTensor *       output,
                float             beta = 1.0f,
                int32_t           axis = 0
               )

Set the input and output tensors.

Parameters

  [in]  input   Source tensor. Data types supported: F16/F32.
  [out] output  Destination tensor. Data types supported: same as input.
  [in]  beta    (Optional) A scaling factor for the exponent. Only beta = 1 is supported.
  [in]  axis    (Optional) The dimension in which to apply the function. E.g. for an input of shape 4x5x6 and axis=1, softmax will be applied to 4x6=24 vectors of size 5. Defaults to 0.

Note
  The value of axis must always be 0 for GLES.

Definition at line 37 of file GCSoftmaxLayer.cpp.

References ITensorAllocator::allocate(), GCTensor::allocator(), ARM_COMPUTE_ERROR_ON, ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN, ARM_COMPUTE_ERROR_ON_MSG, ARM_COMPUTE_UNUSED, GCLogits1DMaxKernel::configure(), GCLogits1DShiftExpSumKernel::configure(), GCLogits1DNormKernel::configure(), ITensorInfo::data_type(), arm_compute::F16, arm_compute::F32, ITensor::info(), ITensorAllocator::init(), MemoryGroup::manage(), ITensorInfo::num_channels(), TensorShape::set(), arm_compute::test::validation::shape, and ITensorInfo::tensor_shape().

{
    ARM_COMPUTE_UNUSED(beta, axis);
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
    ARM_COMPUTE_ERROR_ON(beta != 1.0f);
    ARM_COMPUTE_ERROR_ON_MSG(axis != 0, "axis must be 0 for GLES");

    // Create intermediate tensors shapes
    _tmp.allocator()->init(TensorInfo(input->info()->tensor_shape(), input->info()->num_channels(), input->info()->data_type()));

    TensorShape shape = input->info()->tensor_shape();
    shape.set(0, 1);
    TensorInfo tensor_info_max_sum(shape, input->info()->num_channels(), input->info()->data_type());
    _max.allocator()->init(tensor_info_max_sum);
    _sum.allocator()->init(tensor_info_max_sum);

    // Manage intermediate buffers
    _memory_group.manage(&_tmp);
    _memory_group.manage(&_max);
    _memory_group.manage(&_sum);

    // Configure Kernels
    _max_kernel.configure(input, &_max);
    _shift_exp_sum_kernel.configure(input, &_max, &_tmp, &_sum);
    _norm_kernel.configure(&_tmp, &_sum, output);

    // Allocate intermediate buffers
    _tmp.allocator()->allocate();
    _max.allocator()->allocate();
    _sum.allocator()->allocate();
}
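A minimal usage sketch, assuming an initialised GLES compute context (e.g. via GCScheduler::get().default_init()); the tensor shapes are illustrative, not taken from the library's examples:

```cpp
// Hypothetical shapes: 32 vectors of 128 elements, softmax applied along axis 0.
GCTensor input;
GCTensor output;
input.allocator()->init(TensorInfo(TensorShape(128U, 32U), 1, DataType::F32));
output.allocator()->init(TensorInfo(TensorShape(128U, 32U), 1, DataType::F32));

GCSoftmaxLayer softmax;
softmax.configure(&input, &output); // beta and axis keep their defaults (1.0f, 0)

input.allocator()->allocate();
output.allocator()->allocate();
// ... upload input data ...

softmax.run(); // enqueues the three kernels; does not block
```

Since run() does not block, the caller must synchronise (e.g. by mapping the output tensor) before reading results on the CPU.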

◆ run()

void run ( )

Run the kernels contained in the function.

For Neon kernels:

  • Multi-threading is used for the kernels which are parallelisable.
  • By default std::thread::hardware_concurrency() threads are used.
CPPScheduler::set_num_threads() can be used to manually set the number of threads

For OpenCL kernels:

  • All the kernels are enqueued on the queue associated with CLScheduler.
  • The queue is then flushed.
The function will not block until the kernels are executed. It is the user's responsibility to wait.
prepare() will be called on the first run if it has not already been done.

Implements IFunction.

Definition at line 70 of file GCSoftmaxLayer.cpp.

References GCScheduler::dispatch(), GCScheduler::get(), and GCScheduler::memory_barrier().

{
    MemoryGroupResourceScope scope_mg(_memory_group);

    GCScheduler::get().dispatch(_max_kernel, false);
    GCScheduler::get().memory_barrier();
    GCScheduler::get().dispatch(_shift_exp_sum_kernel, false);
    GCScheduler::get().memory_barrier();
    GCScheduler::get().dispatch(_norm_kernel);
}

The documentation for this class was generated from the following files: