ArmNN
 25.11
Loading...
Searching...
No Matches
BatchNormImpl.cpp
Go to the documentation of this file.
1//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "BatchNormImpl.hpp"
8
9#include <armnn/Tensor.hpp>
10
12
13#include <cmath>
14
15namespace armnn
16{
17
19 Decoder<float>& meanDecoder,
20 Decoder<float>& varianceDecoder,
21 Decoder<float>& betaDecoder,
22 Decoder<float>& gammaDecoder,
23 Decoder<float>& inputDecoder,
24 Encoder<float>& outputEncoder)
25{
26 const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]);
27 const TensorShape inputShape = inputInfo.GetShape();
28
30
31 unsigned int inputBatches = inputShape[0];
32 unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()];
33 unsigned int inputWidth = inputShape[dataLayout.GetWidthIndex()];
34 unsigned int inputChannels = inputShape[dataLayout.GetChannelsIndex()];
35
36 for (unsigned int c = 0; c < inputChannels; c++)
37 {
38 meanDecoder[c];
39 varianceDecoder[c];
40 betaDecoder[c];
41 gammaDecoder[c];
42 float mean = meanDecoder.Get();
43 float var = varianceDecoder.Get();
44 float beta = betaDecoder.Get();
45 float gamma = gammaDecoder.Get();
46
47 float mult = gamma / sqrtf(var + data.m_Parameters.m_Eps);
48 float add = beta - mult * mean;
49
50 for (unsigned int n = 0; n < inputBatches; n++)
51 {
52 for (unsigned int h = 0; h < inputHeight; h++)
53 {
54 for (unsigned int w = 0; w < inputWidth; w++)
55 {
56 unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w);
57 inputDecoder[index];
58 outputEncoder[index];
59 outputEncoder.Set(mult * inputDecoder.Get() + add);
60 }
61 }
62 }
63 }
64}
65
66} // namespace armnn
virtual IType Get() const =0
virtual void Set(IType right)=0
const TensorShape & GetShape() const
Definition Tensor.hpp:193
Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout.
unsigned int GetIndex(const armnn::TensorShape &shape, unsigned int batchIndex, unsigned int channelIndex, unsigned int heightIndex, unsigned int widthIndex) const
unsigned int GetHeightIndex() const
unsigned int GetChannelsIndex() const
Copyright (c) 2021 ARM Limited and Contributors.
void BatchNormImpl(const BatchNormalizationQueueDescriptor &data, Decoder< float > &meanDecoder, Decoder< float > &varianceDecoder, Decoder< float > &betaDecoder, Decoder< float > &gammaDecoder, Decoder< float > &inputDecoder, Encoder< float > &outputEncoder)
armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches, unsigned int numberOfChannels, unsigned int height, unsigned int width, const armnn::DataLayout dataLayout, const armnn::DataType dataType)
float m_Eps
Value to add to the variance. Used to avoid dividing by zero.
DataLayout m_DataLayout
The data layout to be used (NCHW, NHWC).
std::vector< ITensorHandle * > m_Inputs