24.02.1
|
Go to the documentation of this file.
44 CpuSoftmaxGeneric::CpuSoftmaxGeneric()
51 _needs_permute(false),
52 _aux_mem(InternalTensorIdx::COUNT)
63 const unsigned int actual_axis =
64 static_cast<unsigned int>(
wrap_around(axis,
static_cast<int32_t
>(
src->num_dimensions())));
66 _needs_permute = actual_axis > 0;
76 const ITensorInfo *tmp_input = (_needs_permute ? &_input_permuted :
src);
90 auto sm = std::make_unique<kernels::CpuSoftmaxKernel>();
94 sm->configure(tmp_input, &_output_permuted, beta, is_log, &_tmp);
103 sm->configure(tmp_input,
dst, beta, is_log, &_tmp);
105 _softmax_kernel = std::move(sm);
109 _aux_mem[InternalTensorIdx::TMP] =
114 MemoryLifetime::Temporary, _input_permuted.
total_size());
116 MemoryLifetime::Temporary, _output_permuted.
total_size());
127 static_cast<int32_t
>(
src->num_dimensions()) <= axis);
134 tensor_info_tmp =
src->clone()->set_data_type(
DataType::F32).set_is_resizable(
true);
137 const unsigned int actual_axis =
138 static_cast<unsigned int>(
wrap_around(axis,
static_cast<int32_t
>(
src->num_dimensions())));
140 const bool needs_permute = actual_axis > 0;
148 TensorInfo input_permuted(
src->clone()->set_tensor_shape(permuted_shape));
150 TensorInfo output_permuted(
dst->clone()->set_tensor_shape(permuted_shape));
177 _permute_input.
run(permute_in_pack);
195 _permute_output.
run(permute_out_pack);
Tensor handler to wrap and handle tensor allocations on workspace buffers.
std::vector< MemoryInfo > MemoryRequirements
SimpleTensor< float > src
experimental::MemoryRequirements workspace() const override
Return the memory requirements required by the workspace.
void run(ITensorPack &tensors) override
Run the kernels contained in the function.
TensorShape compute_permutation_output_shape(const ITensorInfo &input, const PermutationVector &perm)
Calculate the permuted shape of an input given a permutation vector.
void configure(const ITensorInfo *src, ITensorInfo *dst, float beta=1.0f, int32_t axis=0, bool is_log=false)
Set the input and output tensors.
static Status validate(const ITensorInfo *src, const ITensorInfo *dst, float beta, bool is_log, const ITensorInfo *tmp)
Static function to check if given info will lead to a valid configuration.
virtual void schedule_op(ICPPKernel *kernel, const Hints &hints, const Window &window, ITensorPack &tensors)=0
Runs the kernel in the same thread as the caller synchronously.
static Status validate(const ITensorInfo *src, const ITensorInfo *dst, const PermutationVector &perm)
Static function to check if given info will lead to a valid configuration.
void add_tensor(int id, ITensor *tensor)
Add tensor to the pack.
ITensor * get_tensor(int id)
Get tensor of a given id from the pack.
T wrap_around(T x, T m)
Wrap-around a number within the range 0 <= x < m.
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
Strides of an item in bytes.
bool empty() const
Checks if pack is empty.
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
const ITensor * get_const_tensor(int id) const
Get constant tensor of a given id.
#define ARM_COMPUTE_ERROR_THROW_ON(status)
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
#define ARM_COMPUTE_RETURN_ERROR_ON(cond)
If the condition is true, an error is returned.
size_t total_size() const override
Returns the total size of the tensor in bytes.
static IScheduler & get()
Access the scheduler singleton.
PermutationVector get_permutation_vector_from_softmax_axis(size_t axis)
Given a softmax axis, this function returns the permutation vector required to put the axis to the front.
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
static constexpr size_t DimY
Alias for dimension 1 also known as Y dimension.
void run(ITensorPack &tensors) override
Run the kernels contained in the function.
virtual std::unique_ptr< T > clone() const =0
Provide a clone of the current object of class T.
Store the tensor's metadata.
int offset_int_vec(int offset)
#define ARM_COMPUTE_RETURN_ERROR_ON_MSG(cond, msg)
If the condition is true, an error is returned.
Copyright (c) 2017-2024 Arm Limited.
void configure(const ITensorInfo *src, ITensorInfo *dst, const PermutationVector &perm)
Configure operator for a given list of arguments.
bool is_data_type_quantized_asymmetric(DataType dt)
Check if a given data type is of asymmetric quantized type.
#define ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(...)
static Status validate(const ITensorInfo *src, const ITensorInfo *dst, float beta=1.0f, int32_t axis=0, bool is_log=false)
Static function to check if given info will lead to a valid configuration.
Store the tensor's metadata.
@ F32
32-bit floating-point number
#define ARM_COMPUTE_LOG_PARAMS(...)