37 void run_permute(ClPermute *op,
const ITensor *
src, ITensor *
dst)
48 : _permute_input(std::make_unique<
ClPermute>()),
49 _permute_output(std::make_unique<
ClPermute>()),
50 _max_shift_exp_sum_kernel(std::make_unique<kernels::ClLogits1DMaxShiftExpSumKernel>()),
51 _norm_kernel(std::make_unique<kernels::ClLogits1DNormKernel>()),
52 _max_info(_internal_info[static_cast<uint32_t>(InternalTensorIdx::
MAX)]),
53 _sum_info(_internal_info[static_cast<uint32_t>(InternalTensorIdx::
SUM)]),
54 _tmp_info(_internal_info[static_cast<uint32_t>(InternalTensorIdx::TMP)]),
55 _permuted_src_info(_internal_info[static_cast<uint32_t>(InternalTensorIdx::PERMUTED_SRC)]),
56 _permuted_dst_info(_internal_info[static_cast<uint32_t>(InternalTensorIdx::PERMUTED_DST)])
60 TensorType ClSoftmax::convert_internal_idx_to_tensor_type(InternalTensorIdx idx)
const 64 case InternalTensorIdx::MAX:
66 case InternalTensorIdx::SUM:
68 case InternalTensorIdx::TMP:
70 case InternalTensorIdx::PERMUTED_SRC:
72 case InternalTensorIdx::PERMUTED_DST:
81 void ClSoftmax::create_internal_tensor(TensorInfo &
info, InternalTensorIdx idx)
83 const auto tensor_idx = static_cast<uint32_t>(idx);
84 if(!_internal_tensor[tensor_idx])
86 _internal_tensor[tensor_idx] = std::make_unique<CLTensor>();
88 _internal_tensor[tensor_idx]->allocator()->init(
info);
91 void ClSoftmax::create_internal_tensor()
93 for(uint32_t i = 0; i < static_cast<uint32_t>(InternalTensorIdx::COUNT); i++)
95 const auto tensor_idx = static_cast<InternalTensorIdx>(i);
97 if(!_needs_permute && (tensor_idx == InternalTensorIdx::PERMUTED_DST || tensor_idx == InternalTensorIdx::PERMUTED_SRC))
101 create_internal_tensor(_internal_info[i], static_cast<InternalTensorIdx>(i));
109 const size_t actual_axis = static_cast<size_t>(
wrap_around(
info.axis, static_cast<int32_t>(
src.num_dimensions())));
111 _needs_permute = actual_axis != 0;
113 const ITensorInfo &tmp_input_info = _needs_permute ? _permuted_src_info :
src;
114 ITensorInfo &tmp_output_info = _needs_permute ? _permuted_dst_info :
dst;
119 _permute_input->configure(compile_context, &
src, &_permuted_src_info, perm_info);
123 _tmp_info = tmp_input_info.
clone()->set_data_type(tmp_data_type);
126 _max_info = tmp_input_info.
clone()->set_tensor_shape(max_sum_shape);
127 _sum_info = tmp_input_info.
clone()->set_tensor_shape(max_sum_shape).set_data_type(tmp_data_type);
132 _max_shift_exp_sum_kernel->configure(compile_context, tmp_input_info, _max_info, _tmp_info, _sum_info,
info);
133 _norm_kernel->configure(compile_context, _tmp_info, _sum_info, tmp_output_info,
info);
138 _permute_output->configure(compile_context, &_permuted_dst_info, &
dst, perm_info);
148 const size_t actual_axis = static_cast<size_t>(
wrap_around(
info.axis, static_cast<int32_t>(
src.num_dimensions())));
149 const bool needs_permute = actual_axis != 0;
154 TensorInfo input_permuted(
src.clone()->set_tensor_shape(permuted_shape));
156 TensorInfo output_permuted(
dst.clone()->set_tensor_shape(permuted_shape));
162 TensorInfo tensor_info_tmp(
src.clone()->set_data_type(tmp_data_type).set_is_resizable(
true));
165 max_sum_shape.set(0, 1);
166 TensorInfo tensor_info_max(
src.clone()->set_tensor_shape(max_sum_shape).set_is_resizable(
true));
167 TensorInfo tensor_info_sum(
src.clone()->set_tensor_shape(max_sum_shape).set_data_type(tmp_data_type).set_quantization_info(
QuantizationInfo()).set_is_resizable(
true));
175 void ClSoftmax::import_workspace_memory(
ITensorPack &tensors)
177 auto import_workspace_memory = [
this, &tensors](InternalTensorIdx idx)
179 const auto workspace_idx = convert_internal_idx_to_tensor_type(idx);
180 auto imported_tensor = tensors.
get_tensor(workspace_idx);
183 auto imported_memory = utils::cast::polymorphic_downcast<ICLTensor *>(imported_tensor)->cl_buffer();
184 _internal_tensor[static_cast<uint32_t>(idx)].get()->allocator()->import_memory(imported_memory);
188 import_workspace_memory(InternalTensorIdx::PERMUTED_SRC);
189 import_workspace_memory(InternalTensorIdx::PERMUTED_DST);
190 import_workspace_memory(InternalTensorIdx::MAX);
191 import_workspace_memory(InternalTensorIdx::SUM);
192 import_workspace_memory(InternalTensorIdx::TMP);
195 void ClSoftmax::run_source_permute(
const ITensor *
src)
199 auto permuted_src = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::PERMUTED_SRC)].get();
200 run_permute(_permute_input.get(),
src, permuted_src);
204 void ClSoftmax::run_destination_permute(ITensor *
dst)
208 auto permuted_dst = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::PERMUTED_DST)].get();
209 run_permute(_permute_output.get(), permuted_dst,
dst);
213 void ClSoftmax::run_max_sum(
const ITensor *
src)
215 auto max = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::MAX)].get();
216 auto sum = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::SUM)].get();
217 auto tmp = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::TMP)].get();
221 ITensorPack sum_pack;
230 void ClSoftmax::run_norm(ITensor *
dst)
232 auto sum = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::SUM)].get();
233 auto tmp = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::TMP)].get();
237 ITensorPack norm_pack;
247 create_internal_tensor();
252 import_workspace_memory(tensors);
253 run_source_permute(
src);
254 run_max_sum(!_needs_permute ?
src : _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::PERMUTED_SRC)].get());
255 run_norm(!_needs_permute ?
dst : _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::PERMUTED_DST)].get());
256 run_destination_permute(
dst);
263 req.emplace_back(convert_internal_idx_to_tensor_type(InternalTensorIdx::SUM), _sum_info.
total_size(), 0);
264 req.emplace_back(convert_internal_idx_to_tensor_type(InternalTensorIdx::TMP), _tmp_info.
total_size(), 0);
265 req.emplace_back(convert_internal_idx_to_tensor_type(InternalTensorIdx::MAX), _max_info.
total_size(), 0);
269 req.emplace_back(convert_internal_idx_to_tensor_type(InternalTensorIdx::PERMUTED_SRC), _permuted_src_info.
total_size(), 0);
270 req.emplace_back(convert_internal_idx_to_tensor_type(InternalTensorIdx::PERMUTED_DST), _permuted_dst_info.
total_size(), 0);
TensorShape compute_permutation_output_shape(const ITensorInfo &input, const PermutationVector &perm)
Calculate the permuted shape of an input given a permutation vector.
static Status validate(const ITensorInfo &src, const ITensorInfo &sum, const ITensorInfo &dst, const SoftmaxKernelInfo &info)
Static function to check if given info will lead to a valid configuration of ClLogits1DNormKernel.
PermutationVector get_permutation_vector_from_softmax_axis(size_t axis)
Given a softmax axis, this function returns the permutation vector required to put the axis to the front.
static CLScheduler & get()
Access the scheduler singleton.
void configure(const CLCompileContext &compile_context, const ITensorInfo &src, ITensorInfo &dst, const SoftmaxKernelInfo &info)
Configure the operator.
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
static Status validate(const ITensorInfo &src, const ITensorInfo &dst, const SoftmaxKernelInfo &info)
Static function to check if the given info will lead to a valid configuration.
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
virtual DataType data_type() const =0
Data type used for each element of the tensor.
void run(ITensorPack &tensors) override
Run the kernels contained in the function.
Store the tensor's metadata.
#define ARM_COMPUTE_ERROR_THROW_ON(status)
#define ARM_COMPUTE_RETURN_ERROR_ON(cond)
If the condition is true, an error is returned.
SimpleTensor< float > src
Copyright (c) 2017-2021 Arm Limited.
std::vector< MemoryInfo > MemoryRequirements
static Status validate(const ITensorInfo &src, const ITensorInfo &max, const ITensorInfo &dst, const ITensorInfo &sum)
Static function to check if given info will lead to a valid configuration of ClLogits1DMaxShiftExpSumKernel.
1 channel, 1 S32 per channel
const ITensor * get_const_tensor(int id) const
Get constant tensor of a given id.
T wrap_around(T x, T m)
Wrap-around a number within the range 0 <= x < m.
Quantization information.
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
virtual const TensorShape & tensor_shape() const =0
Size for each dimension of the tensor.
size_t total_size() const override
Returns the total size of the tensor in bytes.
void enqueue_op(ICLKernel &kernel, ITensorPack &tensors, bool flush=true)
Schedule the execution of the passed kernel if possible.
virtual std::unique_ptr< T > clone() const =0
Provide a clone of the current object of class T.
experimental::MemoryRequirements workspace() const override
Return the memory requirements required by the workspace.
bool is_data_type_quantized_asymmetric(DataType dt)
Check if a given data type is of asymmetric quantized type.
Strides of an item in bytes.
ScaleKernelInfo info(interpolation_policy, default_border_mode, PixelValue(), sampling_policy, false)
ITensor * get_tensor(int id)
Get tensor of a given id from the pack.
#define ARM_COMPUTE_RETURN_ERROR_ON_MSG(cond, msg)
If the condition is true, an error is returned.
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
Store the tensor's metadata.
Descriptor used by the softmax kernels.
Basic function to run kernels::ClPermuteKernel.
DataType
Available data types.
static Status validate(const ITensorInfo *src, const ITensorInfo *dst, const PermutationVector &perm)
Static function to check if given info will lead to a valid configuration of kernels::ClPermuteKernel...