ArmNN
 24.02
LogSoftmax.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "LogSoftmax.hpp"
7 
12 
13 #include <cmath>
14 
15 namespace
16 {
17 
18 inline bool ValidateAxis(int axis, unsigned int numDimensions)
19 {
20  const int sNumDimensions = armnn::numeric_cast<int>(numDimensions);
21  return axis < sNumDimensions && axis >= -sNumDimensions;
22 }
23 
24 } // anonymous namespace
25 
26 namespace armnn
27 {
28 
30  Encoder<float>& output,
31  const TensorInfo& inputInfo,
32  const LogSoftmaxDescriptor& descriptor)
33 {
34  const unsigned int numDimensions = inputInfo.GetNumDimensions();
35 
36  bool axisIsValid = ValidateAxis(descriptor.m_Axis, numDimensions);
37  ARMNN_ASSERT_MSG(axisIsValid,
38  "Axis index is not in range [-numDimensions, numDimensions).");
39  IgnoreUnused(axisIsValid);
40 
41  unsigned int uAxis = descriptor.m_Axis < 0 ?
42  numDimensions - armnn::numeric_cast<unsigned int>(std::abs(descriptor.m_Axis)) :
43  armnn::numeric_cast<unsigned int>(descriptor.m_Axis);
44 
45  const TensorShape& inputShape = inputInfo.GetShape();
46  const unsigned int outerSize = armnnUtils::GetNumElementsBetween(inputShape, 0, uAxis);
47  const unsigned int axisSize = inputShape[uAxis];
48  const unsigned int innerSize = armnnUtils::GetNumElementsBetween(inputShape,
49  uAxis + 1,
50  inputShape.GetNumDimensions());
51 
52  for (unsigned int outer = 0; outer < outerSize; ++outer)
53  {
54  for (unsigned int inner = 0; inner < innerSize; ++inner)
55  {
56  // Find max
57  input[outer * axisSize * innerSize + inner];
58  float maxValue = input.Get();
59  for (unsigned int i = 1u; i < axisSize; ++i)
60  {
61  input[(outer * axisSize + i) * innerSize + inner];
62  maxValue = std::max(maxValue, input.Get());
63  }
64 
65  // Compute sum
66  float sum = 0.0f;
67  for (unsigned int i = 0u; i < axisSize; ++i)
68  {
69  input[(outer * axisSize + i) * innerSize + inner];
70  sum += std::exp((input.Get() - maxValue) * descriptor.m_Beta);
71  }
72 
73  // Compute log sum
74  const float logSum = std::log(sum);
75 
76  // Compute result
77  for (unsigned int i = 0u; i < axisSize; ++i)
78  {
79  const unsigned int index = (outer * axisSize + i) * innerSize + inner;
80 
81  input [index];
82  output[index];
83 
84  output.Set((input.Get() - maxValue) * descriptor.m_Beta - logSum);
85  }
86  }
87  }
88 }
89 
90 } // namespace armnn
armnn::Decoder< float >
armnn::Encoder::Set
virtual void Set(IType right)=0
armnn::SoftmaxDescriptor::m_Beta
float m_Beta
Exponentiation value.
Definition: Descriptors.hpp:190
armnn::TensorInfo
Definition: Tensor.hpp:152
armnn::TensorInfo::GetNumDimensions
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:197
armnn::LogSoftmax
void LogSoftmax(Decoder< float > &input, Encoder< float > &output, const TensorInfo &inputInfo, const LogSoftmaxDescriptor &descriptor)
Definition: LogSoftmax.cpp:29
LogSoftmax.hpp
IgnoreUnused.hpp
ARMNN_ASSERT_MSG
#define ARMNN_ASSERT_MSG(COND, MSG)
Definition: Assert.hpp:15
NumericCast.hpp
TensorUtils.hpp
Assert.hpp
armnn::TensorShape
Definition: Tensor.hpp:20
armnn::Encoder< float >
armnn::TensorShape::GetNumDimensions
unsigned int GetNumDimensions() const
Function that returns the tensor rank.
Definition: Tensor.cpp:174
armnn::Decoder::Get
virtual IType Get() const =0
armnnUtils::GetNumElementsBetween
unsigned int GetNumElementsBetween(const armnn::TensorShape &shape, unsigned int firstAxisInclusive, unsigned int lastAxisExclusive)
Definition: TensorUtils.cpp:209
armnn::TensorInfo::GetShape
const TensorShape & GetShape() const
Definition: Tensor.hpp:193
armnn::SoftmaxDescriptor::m_Axis
int m_Axis
Scalar, defaulted to the last index (-1), specifying the dimension the activation will be performed on.
Definition: Descriptors.hpp:192
armnn::IgnoreUnused
void IgnoreUnused(Ts &&...)
Definition: IgnoreUnused.hpp:14
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::SoftmaxDescriptor
A SoftmaxDescriptor for the SoftmaxLayer.
Definition: Descriptors.hpp:177