/// Checks whether @p axis is a valid axis index for a tensor with
/// @p numDimensions dimensions.
///
/// Both non-negative and negative indices are accepted: the valid range is the
/// half-open interval [-numDimensions, numDimensions).
///
/// @param axis           Axis index to check; may be negative.
/// @param numDimensions  Number of dimensions of the tensor.
/// @return true if -numDimensions <= axis < numDimensions, false otherwise
///         (always false when numDimensions is 0).
inline bool ValidateAxis(int axis, unsigned int numDimensions)
{
    // Compare in long long (at least 64 bits), which represents every int and
    // every unsigned int exactly on common ABIs. This avoids narrowing
    // numDimensions to int, which would be out of range for values above
    // INT_MAX.
    const long long rank  = static_cast<long long>(numDimensions);
    const long long index = static_cast<long long>(axis);
    return index >= -rank && index < rank;
}
35 "Axis index is not in range [-numDimensions, numDimensions).");
37 unsigned int uAxis = descriptor.
m_Axis < 0 ?
38 numDimensions - armnn::numeric_cast<unsigned int>(std::abs(descriptor.
m_Axis)) :
39 armnn::numeric_cast<unsigned int>(descriptor.
m_Axis);
43 const unsigned int axisSize = inputShape[uAxis];
48 for (
unsigned int outer = 0; outer < outerSize; ++outer)
50 for (
unsigned int inner = 0; inner < innerSize; ++inner)
53 input[outer * axisSize * innerSize + inner];
54 float maxValue = input.
Get();
55 for (
unsigned int i = 1u; i < axisSize; ++i)
57 input[(outer * axisSize + i) * innerSize + inner];
58 maxValue = std::max(maxValue, input.
Get());
63 for (
unsigned int i = 0u; i < axisSize; ++i)
65 input[(outer * axisSize + i) * innerSize + inner];
66 sum += std::exp((input.
Get() - maxValue) * descriptor.
m_Beta);
70 const float logSum = std::log(sum);
73 for (
unsigned int i = 0u; i < axisSize; ++i)
75 const unsigned int index = (outer * axisSize + i) * innerSize + inner;
80 output.
Set((input.
Get() - maxValue) * descriptor.
m_Beta - logSum);