16 inline bool ValidateAxis(
int axis,
unsigned int numDimensions)
18 const int sNumDimensions = armnn::numeric_cast<int>(numDimensions);
19 return axis < sNumDimensions && axis >= -sNumDimensions;
35 "Axis index is not in range [-numDimensions, numDimensions).");
37 unsigned int uAxis = descriptor.
m_Axis < 0 ?
38 numDimensions - armnn::numeric_cast<unsigned int>(std::abs(descriptor.
m_Axis)) :
39 armnn::numeric_cast<unsigned int>(descriptor.
m_Axis);
43 const unsigned int axisSize = inputShape[uAxis];
48 for (
unsigned int outer = 0; outer < outerSize; ++outer)
50 for (
unsigned int inner = 0; inner < innerSize; ++inner)
53 input[outer * axisSize * innerSize + inner];
54 float maxValue = input.
Get();
55 for (
unsigned int i = 1u; i < axisSize; ++i)
57 input[(outer * axisSize + i) * innerSize + inner];
58 maxValue = std::max(maxValue, input.
Get());
63 for (
unsigned int i = 0u; i < axisSize; ++i)
65 input[(outer * axisSize + i) * innerSize + inner];
66 sum += std::exp((input.
Get() - maxValue) * descriptor.
m_Beta);
70 const float logSum = std::log(sum);
73 for (
unsigned int i = 0u; i < axisSize; ++i)
75 const unsigned int index = (outer * axisSize + i) * innerSize + inner;
80 output.
Set((input.
Get() - maxValue) * descriptor.
m_Beta - logSum);
#define ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(_cond, _str)
virtual IType Get() const =0
virtual void Set(IType right)=0
unsigned int GetNumDimensions() const
const TensorShape & GetShape() const
unsigned int GetNumDimensions() const
Function that returns the tensor rank.
Copyright (c) 2021 ARM Limited and Contributors.
void LogSoftmax(Decoder< float > &input, Encoder< float > &output, const TensorInfo &inputInfo, const LogSoftmaxDescriptor &descriptor)
unsigned int GetNumElementsBetween(const armnn::TensorShape &shape, unsigned int firstAxisInclusive, unsigned int lastAxisExclusive)
A SoftmaxDescriptor for the SoftmaxLayer.
int m_Axis
Scalar, defaulted to the last index (-1), specifying the dimension the activation will be performed o...
float m_Beta
Exponentiation value.