Compute Library
 22.02
ClSoftmax Class Reference

#include <ClSoftmax.h>

Collaboration diagram for ClSoftmax:
[legend]

Public Member Functions

 ClSoftmax ()
 Constructor. More...
 
void configure (const CLCompileContext &compile_context, const ITensorInfo &src, ITensorInfo &dst, const SoftmaxKernelInfo &info)
 Configure the operator. More...
 
void run (ITensorPack &tensors) override
 Run the kernels contained in the function. More...
 
experimental::MemoryRequirements workspace () const override
 Return the memory requirements required by the workspace. More...
 
- Public Member Functions inherited from ICLOperator
 ICLOperator (IRuntimeContext *ctx=nullptr)
 Constructor. More...
 
 ICLOperator (const ICLOperator &)=delete
 Prevent instances of this class from being copied (As this class contains pointers) More...
 
 ICLOperator (ICLOperator &&)=default
 Default move constructor. More...
 
ICLOperator & operator= (const ICLOperator &)=delete
 Prevent instances of this class from being copied (As this class contains pointers) More...
 
ICLOperator & operator= (ICLOperator &&)=default
 Default move assignment operator. More...
 
void prepare (ITensorPack &constants) override
 Prepare the function for executing. More...
 
- Public Member Functions inherited from IOperator
virtual ~IOperator ()=default
 Destructor. More...
 

Static Public Member Functions

static Status validate (const ITensorInfo &src, const ITensorInfo &dst, const SoftmaxKernelInfo &info)
 Static function to check if given info will lead to a valid configuration. More...
 

Detailed Description

Definition at line 43 of file ClSoftmax.h.

Constructor & Destructor Documentation

◆ ClSoftmax()

ClSoftmax ( )

Constructor.

Definition at line 41 of file ClSoftmax.cpp.

42  : _permute_input(std::make_unique<ClPermute>()),
43  _permute_output(std::make_unique<ClPermute>()),
44  _max_shift_exp_sum_kernel(std::make_unique<kernels::ClLogits1DMaxShiftExpSumKernel>()),
45  _norm_kernel(std::make_unique<kernels::ClLogits1DNormKernel>()),
46  _max_info(),
47  _sum_info(),
48  _tmp_info(),
49  _permuted_src_info(),
50  _permuted_dst_info(),
51  _aux_mem(InternalTensorIdx::COUNT)
52 {
53 }

Member Function Documentation

◆ configure()

void configure ( const CLCompileContext & compile_context,
const ITensorInfo & src,
ITensorInfo & dst,
const SoftmaxKernelInfo & info 
)

Configure the operator.

Parameters
[in] compile_context The compile context to be used.
[in] src Source tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32 for Softmax and F16/F32 for Log Softmax.
[out] dst Destination tensor info. Data types supported: same as src.
[in] info Contains information consumed by kernels for softmax described in SoftmaxKernelInfo.

Definition at line 55 of file ClSoftmax.cpp.

References ARM_COMPUTE_ERROR_THROW_ON, ARM_COMPUTE_LOG_PARAMS, SoftmaxKernelInfo::axis, ICloneable< T >::clone(), ITensorInfo::data_type(), arm_compute::test::validation::dst, CLScheduler::get(), arm_compute::softmax_helpers::get_permutation_vector_from_softmax_axis(), arm_compute::is_data_type_quantized_asymmetric(), arm_compute::MAX, ITensorInfo::num_dimensions(), arm_compute::offset_int_vec(), arm_compute::S32, arm_compute::test::validation::src, arm_compute::SUM, ITensorInfo::tensor_shape(), TensorInfo::total_size(), ClSoftmax::validate(), and arm_compute::wrap_around().

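56 {
57  ARM_COMPUTE_ERROR_THROW_ON(validate(src, dst, info));
58  ARM_COMPUTE_LOG_PARAMS(src, dst, info);
59 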
60  const size_t actual_axis = static_cast<size_t>(wrap_around(info.axis, static_cast<int32_t>(src.num_dimensions())));
61 
62  _needs_permute = actual_axis != 0;
63 
64  const ITensorInfo &tmp_input_info = _needs_permute ? _permuted_src_info : src;
65  ITensorInfo &tmp_output_info = _needs_permute ? _permuted_dst_info : dst;
66 
67  if(_needs_permute)
68  {
69  const auto perm_info = softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis);
70  _permute_input->configure(compile_context, &src, &_permuted_src_info, perm_info);
71  }
72 
73  DataType tmp_data_type = is_data_type_quantized_asymmetric(tmp_input_info.data_type()) ? DataType::S32 : tmp_input_info.data_type();
74  _tmp_info = tmp_input_info.clone()->set_data_type(tmp_data_type);
75 
76  TensorShape max_sum_shape = tmp_input_info.tensor_shape();
77  _max_info = tmp_input_info.clone()->set_tensor_shape(max_sum_shape);
78  _sum_info = tmp_input_info.clone()->set_tensor_shape(max_sum_shape).set_data_type(tmp_data_type);
79 
80  // Set GPU target to kernels
81  _max_shift_exp_sum_kernel->set_target(CLScheduler::get().target());
82 
83  _max_shift_exp_sum_kernel->configure(compile_context, tmp_input_info, _max_info, _tmp_info, _sum_info, info);
84  _norm_kernel->configure(compile_context, _tmp_info, _sum_info, tmp_output_info, info);
85 
86  if(_needs_permute)
87  {
88  const auto perm_info = softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis);
89  _permute_output->configure(compile_context, &_permuted_dst_info, &dst, perm_info);
90  }
91 
92  _aux_mem[InternalTensorIdx::SUM] = MemoryInfo(offset_int_vec(InternalTensorIdx::SUM), MemoryLifetime::Temporary, _sum_info.total_size());
93  _aux_mem[InternalTensorIdx::TMP] = MemoryInfo(offset_int_vec(InternalTensorIdx::TMP), MemoryLifetime::Temporary, _tmp_info.total_size());
94  _aux_mem[InternalTensorIdx::MAX] = MemoryInfo(offset_int_vec(InternalTensorIdx::MAX), MemoryLifetime::Temporary, _max_info.total_size());
95 
96  _aux_mem[InternalTensorIdx::PERMUTED_SRC] = MemoryInfo(offset_int_vec(InternalTensorIdx::PERMUTED_SRC), MemoryLifetime::Temporary, _permuted_src_info.total_size());
97  _aux_mem[InternalTensorIdx::PERMUTED_DST] = MemoryInfo(offset_int_vec(InternalTensorIdx::PERMUTED_DST), MemoryLifetime::Temporary, _permuted_dst_info.total_size());
98 }
std::unique_ptr< ITensorInfo > clone() const override
Provide a clone of the current object of class T.
Definition: TensorInfo.cpp:282
PermutationVector get_permutation_vector_from_softmax_axis(size_t axis)
Given a softmax axis, this function returns the permutation vector required to put the axis to the fr...
static CLScheduler & get()
Access the scheduler singleton.
static Status validate(const ITensorInfo &src, const ITensorInfo &dst, const SoftmaxKernelInfo &info)
Static function to check if given info will lead to a valid configuration.
Definition: ClSoftmax.cpp:100
#define ARM_COMPUTE_ERROR_THROW_ON(status)
Definition: Error.h:455
SimpleTensor< float > src
Definition: DFT.cpp:155
1 channel, 1 S32 per channel
T wrap_around(T x, T m)
Wrap-around a number within the range 0 <= x < m.
Definition: Helpers.h:247
size_t total_size() const override
Returns the total size of the tensor in bytes.
Definition: TensorInfo.h:250
bool is_data_type_quantized_asymmetric(DataType dt)
Check if a given data type is of asymmetric quantized type.
Definition: Utils.h:1018
ScaleKernelInfo info(interpolation_policy, default_border_mode, PixelValue(), sampling_policy, false)
#define ARM_COMPUTE_LOG_PARAMS(...)
int offset_int_vec(int offset)
Definition: MemoryHelpers.h:38
DataType
Available data types.
Definition: Types.h:79

◆ run()

void run ( ITensorPack & tensors)
override virtual

Run the kernels contained in the function.

Parameters
[in] tensors Vector that contains the tensors to operate on.

Reimplemented from ICLOperator.

Definition at line 133 of file ClSoftmax.cpp.

References arm_compute::ACL_DST, arm_compute::ACL_INT_0, arm_compute::ACL_INT_1, arm_compute::ACL_SRC, ITensorPack::add_const_tensor(), ITensorPack::add_tensor(), arm_compute::test::validation::dst, CLScheduler::enqueue_op(), CLScheduler::get(), CLAuxTensorHandler::get(), ITensorPack::get_const_tensor(), ITensorPack::get_tensor(), arm_compute::MAX, arm_compute::offset_int_vec(), arm_compute::test::validation::pack, arm_compute::test::validation::src, and arm_compute::SUM.

134 {
135  auto src = tensors.get_const_tensor(TensorType::ACL_SRC);
136  auto dst = tensors.get_tensor(TensorType::ACL_DST);
137 
138  CLAuxTensorHandler sum(offset_int_vec(InternalTensorIdx::SUM), _sum_info, tensors, false);
139  CLAuxTensorHandler tmp(offset_int_vec(InternalTensorIdx::TMP), _tmp_info, tensors, false);
140  CLAuxTensorHandler max(offset_int_vec(InternalTensorIdx::MAX), _max_info, tensors, false);
141 
142  CLAuxTensorHandler permuted_src(offset_int_vec(InternalTensorIdx::PERMUTED_SRC), _permuted_src_info, tensors, false);
143  CLAuxTensorHandler permuted_dst(offset_int_vec(InternalTensorIdx::PERMUTED_DST), _permuted_dst_info, tensors, false);
144 
145  if(_needs_permute)
146  {
147  ITensorPack pack;
148  pack.add_const_tensor(TensorType::ACL_SRC, src);
149  pack.add_tensor(TensorType::ACL_DST, permuted_src.get());
150  _permute_input.get()->run(pack);
151  }
152 
153  ITensorPack sum_pack;
154  ITensorPack norm_pack;
155  if(_needs_permute)
156  {
157  sum_pack.add_const_tensor(TensorType::ACL_SRC, permuted_src.get());
158  norm_pack.add_tensor(TensorType::ACL_DST, permuted_dst.get());
159  }
160  else
161  {
162  sum_pack.add_const_tensor(TensorType::ACL_SRC, src);
163  norm_pack.add_tensor(TensorType::ACL_DST, dst);
164  }
165  sum_pack.add_tensor(TensorType::ACL_DST, tmp.get());
166  sum_pack.add_tensor(TensorType::ACL_INT_0, max.get());
167  sum_pack.add_tensor(TensorType::ACL_INT_1, sum.get());
168 
169  norm_pack.add_const_tensor(TensorType::ACL_SRC, tmp.get());
170  norm_pack.add_tensor(TensorType::ACL_INT_0, sum.get());
171 
172  CLScheduler::get().enqueue_op(*_max_shift_exp_sum_kernel.get(), sum_pack, false);
173  CLScheduler::get().enqueue_op(*_norm_kernel.get(), norm_pack, false);
174 
175  if(_needs_permute)
176  {
177  ITensorPack pack;
178  pack.add_const_tensor(TensorType::ACL_SRC, permuted_dst.get());
179  pack.add_tensor(TensorType::ACL_DST, dst);
180  _permute_output.get()->run(pack);
181  }
182 }
void add_const_tensor(int id, const ITensor *tensor)
Add const tensor to the pack.
Definition: ITensorPack.cpp:49
static CLScheduler & get()
Access the scheduler singleton.
SimpleTensor< float > src
Definition: DFT.cpp:155
void enqueue_op(ICLKernel &kernel, ITensorPack &tensors, bool flush=true)
Schedule the execution of the passed kernel if possible.
int offset_int_vec(int offset)
Definition: MemoryHelpers.h:38

◆ validate()

Status validate ( const ITensorInfo & src,
const ITensorInfo & dst,
const SoftmaxKernelInfo & info 
)
static

Static function to check if given info will lead to a valid configuration.

Similar to ClSoftmax::configure()

Returns
a status

Definition at line 100 of file ClSoftmax.cpp.

References ARM_COMPUTE_RETURN_ERROR_ON, ARM_COMPUTE_RETURN_ERROR_ON_MSG, ARM_COMPUTE_RETURN_ON_ERROR, ARM_COMPUTE_UNUSED, SoftmaxKernelInfo::axis, SoftmaxKernelInfo::beta, ICloneable< T >::clone(), arm_compute::misc::shape_calculator::compute_permutation_output_shape(), ITensorInfo::data_type(), arm_compute::softmax_helpers::get_permutation_vector_from_softmax_axis(), arm_compute::is_data_type_quantized_asymmetric(), ITensorInfo::num_dimensions(), arm_compute::S32, TensorShape::set(), ITensorInfo::tensor_shape(), ClPermute::validate(), ClLogits1DMaxShiftExpSumKernel::validate(), ClLogits1DNormKernel::validate(), and arm_compute::wrap_around().

Referenced by ClSoftmax::configure().

101 {
102  ARM_COMPUTE_RETURN_ERROR_ON_MSG(src.num_dimensions() > 4, "Only up to 4 dimensions are supported");
103  ARM_COMPUTE_UNUSED(info.beta);
104  ARM_COMPUTE_RETURN_ERROR_ON(info.axis < static_cast<int32_t>(-src.num_dimensions()) || static_cast<int32_t>(src.num_dimensions()) <= info.axis);
105 
106  const size_t actual_axis = static_cast<size_t>(wrap_around(info.axis, static_cast<int32_t>(src.num_dimensions())));
107  const bool needs_permute = actual_axis != 0;
108  if(needs_permute)
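109  {
110  const PermutationVector permutation_vector = softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis);
111  const TensorShape permuted_shape = misc::shape_calculator::compute_permutation_output_shape(src, permutation_vector);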
112  TensorInfo input_permuted(src.clone()->set_tensor_shape(permuted_shape));
113  ARM_COMPUTE_RETURN_ON_ERROR(ClPermute::validate(&src, &input_permuted, permutation_vector));
114  TensorInfo output_permuted(dst.clone()->set_tensor_shape(permuted_shape));
115  ARM_COMPUTE_RETURN_ON_ERROR(ClPermute::validate(&output_permuted, &dst, permutation_vector));
116  }
117 
118  // Create intermediate tensor info
119  DataType tmp_data_type = is_data_type_quantized_asymmetric(src.data_type()) ? DataType::S32 : src.data_type();
120  TensorInfo tensor_info_tmp(src.clone()->set_data_type(tmp_data_type).set_is_resizable(true));
121 
122  TensorShape max_sum_shape = src.tensor_shape();
123  max_sum_shape.set(0, 1);
124  TensorInfo tensor_info_max(src.clone()->set_tensor_shape(max_sum_shape).set_is_resizable(true));
125  TensorInfo tensor_info_sum(src.clone()->set_tensor_shape(max_sum_shape).set_data_type(tmp_data_type).set_quantization_info(QuantizationInfo()).set_is_resizable(true));
126 
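127  ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClLogits1DMaxShiftExpSumKernel::validate(src, tensor_info_max, tensor_info_tmp, tensor_info_sum));
128  ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClLogits1DNormKernel::validate(tensor_info_tmp, tensor_info_sum, dst, info));
129 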
130  return Status{};
131 }
TensorShape compute_permutation_output_shape(const ITensorInfo &input, const PermutationVector &perm)
Calculate the permuted shape of an input given a permutation vector.
static Status validate(const ITensorInfo &src, const ITensorInfo &sum, const ITensorInfo &dst, const SoftmaxKernelInfo &info)
Static function to check if given info will lead to a valid configuration.
PermutationVector get_permutation_vector_from_softmax_axis(size_t axis)
Given a softmax axis, this function returns the permutation vector required to put the axis to the fr...
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
Definition: Error.h:204
Strides PermutationVector
Permutation vector.
Definition: Types.h:51
#define ARM_COMPUTE_RETURN_ERROR_ON(cond)
If the condition is true, an error is returned.
Definition: Error.h:296
SimpleTensor< float > src
Definition: DFT.cpp:155
static Status validate(const ITensorInfo &src, const ITensorInfo &max, const ITensorInfo &dst, const ITensorInfo &sum)
Static function to check if given info will lead to a valid configuration.
1 channel, 1 S32 per channel
T wrap_around(T x, T m)
Wrap-around a number within the range 0 <= x < m.
Definition: Helpers.h:247
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
Definition: Error.h:152
bool is_data_type_quantized_asymmetric(DataType dt)
Check if a given data type is of asymmetric quantized type.
Definition: Utils.h:1018
ScaleKernelInfo info(interpolation_policy, default_border_mode, PixelValue(), sampling_policy, false)
#define ARM_COMPUTE_RETURN_ERROR_ON_MSG(cond, msg)
If the condition is true, an error is returned.
Definition: Error.h:244
DataType
Available data types.
Definition: Types.h:79
static Status validate(const ITensorInfo *src, const ITensorInfo *dst, const PermutationVector &perm)
Static function to check if given info will lead to a valid configuration.
Definition: ClPermute.cpp:43

◆ workspace()

experimental::MemoryRequirements workspace ( ) const
override virtual

Return the memory requirements required by the workspace.

Reimplemented from ICLOperator.

Definition at line 184 of file ClSoftmax.cpp.

185 {
186  return _aux_mem;
187 }

The documentation for this class was generated from the following files: