ArmNN
 24.02
Activation.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "Activation.hpp"
7 
8 #include <cmath>
9 
10 namespace armnn
11 {
12 
13 float Activation(float in,
14  ActivationFunction function,
15  float a,
16  float b)
17 {
18  float output;
19 
20  // Compute the result of the activation function.
21  switch (function)
22  {
24  {
25  output = a * in + b;
26  break;
27  }
29  {
30  output = 1.f / (1.f + expf(-in));
31  break;
32  }
34  {
35  output = std::max(0.f, in);
36  break;
37  }
39  {
40  output = std::min(a, std::max(b, in));
41  break;
42  }
44  {
45  output = logf(1.0f + expf(in));
46  break;
47  }
49  {
50  output = in > 0.0f ? in : (in * a);
51  break;
52  }
54  {
55  output = in < 0 ? -in : in;
56  break;
57  }
59  {
60  output = sqrtf(in);
61  break;
62  }
64  {
65  output = in * in;
66  break;
67  }
69  {
70  output = a * tanhf(b * in);
71  break;
72  }
74  {
75  output = (in >= 0) ? in : a * (expf(in) - 1);
76  break;
77  }
79  {
80  // hard_swish(x) = x * relu6(x+3) / 6
81  // relu6(x) = min(max(x,0),6)
82  output = in * (std::min(std::max((in + 3),0.0f),6.0f)) / 6;
83  break;
84  }
86  {
87  // gelu(x) = x * 1/2 * (1 + erf(x / sqrt(2))),
88  // where erf is Gaussian error function
89  output = in * (0.5f * (1.0f + erff(static_cast<float>(in / std::sqrt(2)))));
90  break;
91  }
92  default:
93  {
94  throw InvalidArgumentException("Unsupported activation function");
95  }
96  }
97 
98  return output;
99 }
100 
101 
103  Encoder<float>& out,
104  const TensorInfo& tensorInfo,
105  ActivationFunction function,
106  float a,
107  float b)
108 {
109  unsigned int numElements = tensorInfo.GetNumElements();
110 
111  for (unsigned int i = 0; i < numElements; i++)
112  {
113  out.Set(Activation(in.Get(), function, a, b));
114  ++in;
115  ++out;
116  }
117  in -= numElements;
118  out -= numElements;
119 }
120 
121 } //namespace armnn
armnn::Decoder< float >
armnn::TensorInfo::GetNumElements
unsigned int GetNumElements() const
Definition: Tensor.hpp:198
armnn::Encoder::Set
virtual void Set(IType right)=0
armnn::ActivationFunction::LeakyReLu
@ LeakyReLu
armnn::ActivationFunction::SoftReLu
@ SoftReLu
armnn::ActivationFunction::Sqrt
@ Sqrt
armnn::TensorInfo
Definition: Tensor.hpp:152
armnn::ActivationFunction::TanH
@ TanH
armnn::ActivationFunction::BoundedReLu
@ BoundedReLu
min(a, max(b, input)) ReLu1 & ReLu6.
armnn::ActivationFunction::HardSwish
@ HardSwish
armnn::ActivationFunction::Gelu
@ Gelu
armnn::Encoder< float >
armnn::ActivationFunction::Elu
@ Elu
armnn::InvalidArgumentException
Definition: Exceptions.hpp:80
armnn::ActivationFunction::Linear
@ Linear
Activation.hpp
armnn::ActivationFunction
ActivationFunction
Definition: Types.hpp:86
armnn::Decoder::Get
virtual IType Get() const =0
armnn::ActivationFunction::Abs
@ Abs
armnn::ActivationFunction::ReLu
@ ReLu
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::ActivationFunction::Square
@ Square
armnn::Activation
float Activation(float in, ActivationFunction function, float a, float b)
Definition: Activation.cpp:13
armnn::ActivationFunction::Sigmoid
@ Sigmoid