ArmNN
 25.02
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
BatchNormImpl.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
#include "BatchNormImpl.hpp"
#include "RefWorkloadUtils.hpp"

#include <armnn/Tensor.hpp>

#include <armnnUtils/DataLayoutIndexed.hpp>

#include <cmath>
14 
15 namespace armnn
16 {
17 
19  Decoder<float>& meanDecoder,
20  Decoder<float>& varianceDecoder,
21  Decoder<float>& betaDecoder,
22  Decoder<float>& gammaDecoder,
23  Decoder<float>& inputDecoder,
24  Encoder<float>& outputEncoder)
25 {
26  const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]);
27  const TensorShape inputShape = inputInfo.GetShape();
28 
30 
31  unsigned int inputBatches = inputShape[0];
32  unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()];
33  unsigned int inputWidth = inputShape[dataLayout.GetWidthIndex()];
34  unsigned int inputChannels = inputShape[dataLayout.GetChannelsIndex()];
35 
36  for (unsigned int c = 0; c < inputChannels; c++)
37  {
38  meanDecoder[c];
39  varianceDecoder[c];
40  betaDecoder[c];
41  gammaDecoder[c];
42  float mean = meanDecoder.Get();
43  float var = varianceDecoder.Get();
44  float beta = betaDecoder.Get();
45  float gamma = gammaDecoder.Get();
46 
47  float mult = gamma / sqrtf(var + data.m_Parameters.m_Eps);
48  float add = beta - mult * mean;
49 
50  for (unsigned int n = 0; n < inputBatches; n++)
51  {
52  for (unsigned int h = 0; h < inputHeight; h++)
53  {
54  for (unsigned int w = 0; w < inputWidth; w++)
55  {
56  unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w);
57  inputDecoder[index];
58  outputEncoder[index];
59  outputEncoder.Set(mult * inputDecoder.Get() + add);
60  }
61  }
62  }
63  }
64 }
65 
66 } // namespace armnn
virtual IType Get() const =0
virtual void Set(IType right)=0
const TensorShape & GetShape() const
Definition: Tensor.hpp:193
Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout.
unsigned int GetIndex(const armnn::TensorShape &shape, unsigned int batchIndex, unsigned int channelIndex, unsigned int heightIndex, unsigned int widthIndex) const
unsigned int GetWidthIndex() const
unsigned int GetHeightIndex() const
unsigned int GetChannelsIndex() const
Copyright (c) 2021 ARM Limited and Contributors.
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers
void BatchNormImpl(const BatchNormalizationQueueDescriptor &data, Decoder< float > &meanDecoder, Decoder< float > &varianceDecoder, Decoder< float > &betaDecoder, Decoder< float > &gammaDecoder, Decoder< float > &inputDecoder, Encoder< float > &outputEncoder)
float m_Eps
Value to add to the variance. Used to avoid dividing by zero.
DataLayout m_DataLayout
The data layout to be used (NCHW, NHWC).
std::vector< ITensorHandle * > m_Inputs