18 inline bool ValidateAxis(
int axis,
unsigned int numDimensions)
20 const int sNumDimensions = armnn::numeric_cast<int>(numDimensions);
21 return axis < sNumDimensions && axis >= -sNumDimensions;
36 bool axisIsValid = ValidateAxis(descriptor.
m_Axis, numDimensions);
38 "Axis index is not in range [-numDimensions, numDimensions).");
41 unsigned int uAxis = descriptor.
m_Axis < 0 ?
42 numDimensions - armnn::numeric_cast<unsigned int>(std::abs(descriptor.
m_Axis)) :
43 armnn::numeric_cast<unsigned int>(descriptor.
m_Axis);
47 const unsigned int axisSize = inputShape[uAxis];
52 for (
unsigned int outer = 0; outer < outerSize; ++outer)
54 for (
unsigned int inner = 0; inner < innerSize; ++inner)
57 input[outer * axisSize * innerSize + inner];
58 float maxValue = input.
Get();
59 for (
unsigned int i = 1u; i < axisSize; ++i)
61 input[(outer * axisSize + i) * innerSize + inner];
62 maxValue = std::max(maxValue, input.
Get());
67 for (
unsigned int i = 0u; i < axisSize; ++i)
69 input[(outer * axisSize + i) * innerSize + inner];
70 sum += std::exp((input.
Get() - maxValue) * descriptor.
m_Beta);
74 const float logSum = std::log(sum);
77 for (
unsigned int i = 0u; i < axisSize; ++i)
79 const unsigned int index = (outer * axisSize + i) * innerSize + inner;
84 output.
Set((input.
Get() - maxValue) * descriptor.
m_Beta - logSum);