ArmNN
 25.11
Loading...
Searching...
No Matches
FuseBatchNorm.hpp
Go to the documentation of this file.
1//
2// Copyright © 2020,2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "Optimization.hpp"

9#include <armnnUtils/DataLayoutIndexed.hpp>
10#include <ResolveType.hpp>
11
12namespace armnn
13{
14namespace optimizations
15{
16
17template<typename ConvLayer, armnn::DataType ArmnnType,
20{
21public:
22 /// Run for every exclusive connection between any base Convolution layer and a child BatchNorm layer for not
23 /// quantized layers.
24 /// The child will be removed, the base will be removed if it's left unconnected. A new Convolution layer will
25 /// be added, its weights and bias will be calculated using the weights and bias of the base Convolution layer
26 /// combined with the parameters of the child BatchNorm layer.
27 void Run(Graph& graph, InputSlot& connection) const
28 {
29 Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
30 Layer& child = connection.GetOwningLayer();
31
32 bool depthwise = (base.GetType() == LayerType::DepthwiseConvolution2d);
33
34 ARMNN_ASSERT(base.GetType() == LayerType::Convolution2d || depthwise);
36
37 if (base.GetDataType() == ArmnnType && child.GetDataType() == ArmnnType)
38 {
39 OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
40 auto convLayer = PolymorphicDowncast<ConvLayer*>(&base);
41 auto batchNormLayer = PolymorphicDowncast<BatchNormalizationLayer*>(&child);
42
43 // Read convolution and batch norm parameters
44 BatchNormalizationDescriptor batchNormDescriptor = batchNormLayer->GetParameters();
45 auto epsilon = batchNormDescriptor.m_Eps;
46 IgnoreUnused(epsilon);
47
48 ConstTensor betaTensor(batchNormLayer->m_Beta->GetTensorInfo(), batchNormLayer->m_Beta->Map(true));
49 ConstTensor gammaTensor(batchNormLayer->m_Gamma->GetTensorInfo(), batchNormLayer->m_Gamma->Map(true));
50 ConstTensor meanTensor(batchNormLayer->m_Mean->GetTensorInfo(), batchNormLayer->m_Mean->Map(true));
51 ConstTensor varTensor(batchNormLayer->m_Variance->GetTensorInfo(), batchNormLayer->m_Variance->Map(true));
52
53 auto convDescriptor = convLayer->GetParameters();
54 ConstTensor weightsTensor;
55 ARMNN_ASSERT_MSG(convLayer->GetInputSlots()[1].GetConnection() != nullptr,
56 "FuseBatchNorm: Weight data should not be null.");
57
60
61 weightsTensor = ConstTensor(weightLayer->m_LayerOutput->GetTensorInfo(),
62 weightLayer->m_LayerOutput->Map(true));
63
64 armnnUtils::DataLayoutIndexed dataLayout(convDescriptor.m_DataLayout);
65 auto weightsShape = weightsTensor.GetInfo().GetShape();
66 const unsigned int inputChannels = parentOut->GetTensorInfo().GetShape()[dataLayout.GetChannelsIndex()];
67 const unsigned int depthMultiplier = depthwise ? weightsShape[3] / inputChannels : 1;
68 const unsigned int outputChannels = depthwise ? weightsShape[3] : weightsShape[0];
69 const unsigned int weightsHeight = depthwise ? weightsShape[1] :
70 weightsShape[dataLayout.GetHeightIndex()];
71 const unsigned int weightsWidth = depthwise ? weightsShape[2] :
72 weightsShape[dataLayout.GetWidthIndex()];
73
74 const auto* weightsBuffer = static_cast<const T*>(weightsTensor.GetMemoryArea());
75 const auto* betaBuffer = static_cast<const T*>(betaTensor.GetMemoryArea());
76 const auto* gammaBuffer = static_cast<const T*>(gammaTensor.GetMemoryArea());
77 const auto* meanBuffer = static_cast<const T*>(meanTensor.GetMemoryArea());
78 const auto* varBuffer = static_cast<const T*>(varTensor.GetMemoryArea());
79
80 std::vector<T> weightsVector (weightsBuffer, weightsBuffer + weightsTensor.GetNumElements());
81 std::vector<T> betaVector (betaBuffer, betaBuffer + betaTensor.GetNumElements());
82 std::vector<T> gammaVector (gammaBuffer, gammaBuffer + gammaTensor.GetNumElements());
83 std::vector<T> meanVector (meanBuffer, meanBuffer + meanTensor.GetNumElements());
84 std::vector<T> varianceVector(varBuffer, varBuffer + varTensor.GetNumElements());
85
86 // fusedWeights = ( gamma * weights ) / ( std - epsilon);
87 std::vector<T> fusedWeightsVector(weightsVector.size());
88
89 for (unsigned int cInput = 0; cInput < inputChannels; ++cInput)
90 {
91 for (unsigned int cOut = 0; cOut < outputChannels; ++cOut)
92 {
93 T mult = gammaVector[cOut] / static_cast<T>(sqrtf(varianceVector[cOut] + epsilon));
94
95 for (unsigned int h = 0; h < weightsHeight; ++h)
96 {
97 for (unsigned int w = 0; w < weightsWidth; ++w)
98 {
99 unsigned int weightsIdx = 0;
100
101 if (depthwise)
102 {
103 cInput = cOut / depthMultiplier;
104 weightsIdx = w * outputChannels + cOut +
105 h * weightsWidth * outputChannels;
106 }
107 else if (convDescriptor.m_DataLayout == DataLayout::NHWC)
108 {
109 weightsIdx = cOut * weightsHeight * weightsWidth * inputChannels +
110 h * weightsWidth * inputChannels +
111 w * inputChannels +
112 cInput;
113 }
114 else
115 {
116 weightsIdx = cOut * weightsWidth * weightsHeight * inputChannels +
117 cInput * weightsWidth * weightsHeight +
118 h * weightsWidth +
119 w;
120 }
121 fusedWeightsVector[weightsIdx] = mult * weightsVector[weightsIdx];
122 }
123 }
124 }
125 }
126 ConstTensor fusedWeightsTensor(weightsTensor.GetInfo(), fusedWeightsVector);
127
128 // fusedBias = (gamma * (bias - mean)) / (variance - epsilon) + beta;
129 std::vector<T> fusedBiasVector(outputChannels);
130 bool biasWasEnabledBeforeOpt = convDescriptor.m_BiasEnabled;
131 if (biasWasEnabledBeforeOpt)
132 {
133 ConstTensor biasTensor;
134 ARMNN_ASSERT_MSG(convLayer->GetInputSlots()[2].GetConnection() != nullptr,
135 "FuseBatchNorm: Bias data should not be null if bias is enabled.");
136
139
140 biasTensor = ConstTensor(biasLayer->m_LayerOutput->GetTensorInfo(),
141 biasLayer->m_LayerOutput->Map(true));
142
143 const auto* biasBuffer = static_cast<const T*>(biasTensor.GetMemoryArea());
144 std::vector<T> biasVector(biasBuffer, biasBuffer + biasTensor.GetNumElements());
145
146 for (unsigned int cOut = 0; cOut < outputChannels; ++cOut)
147 {
148 fusedBiasVector[cOut] = ((gammaVector[cOut] * (biasVector[cOut] - meanVector[cOut])) /
149 sqrtf(varianceVector[cOut] + epsilon)) + betaVector[cOut];
150 }
151 }
152 else
153 {
154 convDescriptor.m_BiasEnabled = true;
155 std::vector<T> biasVector(outputChannels, T(0));
156
157 for (unsigned int cOut = 0; cOut < outputChannels; ++cOut)
158 {
159 fusedBiasVector[cOut] = ((gammaVector[cOut] * (biasVector[cOut] - meanVector[cOut])) /
160 sqrtf(varianceVector[cOut] + epsilon)) + betaVector[cOut];
161 }
162 }
163 ConstTensor fusedBiasTensor(TensorInfo({outputChannels}, ArmnnType, 0.0f, 0, true), fusedBiasVector);
164
165 // Insert the new convolution layer that has batch norm parameters fused into
166 const std::string name = std::string("fused-") + child.GetName() + std::string("-into-") + base.GetName();
167 auto& newConv2dLayer = *graph.InsertNewLayer<ConvLayer>(base.GetInputSlot(0),
168 convDescriptor,
169 name.c_str());
170
171 // Connect weights and bias from old to new Conv2d layer
172 // This optimization will always have 3 input slots on the Conv2d base layer
173 if (newConv2dLayer.GetNumInputSlots() > 1)
174 {
175 // Remove old connection and connect to new layer2d
176 weightLayer->GetOutputSlot(0).Disconnect(base.GetInputSlot(1));
177 weightLayer->GetOutputSlot(0).Connect(newConv2dLayer.GetInputSlot(1));
178 weightLayer->m_LayerOutput = std::make_unique<ScopedTensorHandle>(fusedWeightsTensor);
179
180 // Move bias const layers as normal if it was enabled before the optimisation
181 ConstantLayer* biasLayer;
182 if (biasWasEnabledBeforeOpt)
183 {
186 // Remove old connection and connect to new layer2d
187 biasLayer->GetOutputSlot(0).Disconnect(base.GetInputSlot(2));
188 biasLayer->GetOutputSlot(0).Connect(newConv2dLayer.GetInputSlot(2));
189
190 }
191 // Otherwise create a new bias layer and add to the new convolution2d
192 else
193 {
194 // Add in bias constant layer
195 biasLayer = graph.AddLayer<ConstantLayer>("Bias");
196 biasLayer->GetOutputSlot(0).SetTensorInfo(fusedBiasTensor.GetInfo());
197 biasLayer->GetOutputSlot(0).Connect(newConv2dLayer.GetInputSlot(2));
198 }
199 biasLayer->m_LayerOutput = std::make_unique<ScopedTensorHandle>(ConstTensor(fusedBiasTensor));
200 }
201
202
203 // Reconnects with original parent.
204 newConv2dLayer.GetOutputSlot().MoveAllConnections(*parentOut);
205 // Parent is now the new convolution2d layer.
206 parentOut = &newConv2dLayer.GetOutputSlot();
207
208 // Moves connections in child output to parent layer.
209 // Child layer will be removed as it's left unconnected.
210 // Base layer will be removed if left unconnected.
211 child.GetOutputSlot().MoveAllConnections(*parentOut);
212 }
213 }
214protected:
215 FuseBatchNorm() = default;
216 ~FuseBatchNorm() = default;
217};
218
223
228
233
238
239} // namespace optimizations
240} // namespace armnn
#define ARMNN_ASSERT(COND)
Definition Assert.hpp:14
#define ARMNN_ASSERT_MSG(COND, MSG)
Definition Assert.hpp:15
const TensorInfo & GetInfo() const
Definition Tensor.hpp:297
unsigned int GetNumElements() const
Definition Tensor.hpp:305
MemoryType GetMemoryArea() const
Definition Tensor.hpp:307
This layer represents a batch normalization operation.
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition Tensor.hpp:330
A layer that the constant data can be bound to.
std::shared_ptr< ConstTensorHandle > m_LayerOutput
This layer represents a convolution 2d operation.
This layer represents a depthwise convolution 2d operation.
LayerT * InsertNewLayer(InputSlot &insertBefore, Args &&... args)
Inserts a new layer between the output slot currently connected to insertBefore and insertBefore itse...
Definition Graph.hpp:481
LayerT * AddLayer(Args &&... args)
Adds a new layer, of type LayerType, to the graph constructed with the arguments passed.
Definition Graph.hpp:466
Layer & GetOwningLayer() const
Definition Layer.hpp:53
const OutputSlot * GetConnectedOutputSlot() const
Definition Layer.hpp:56
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition Layer.hpp:337
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
Definition Layer.hpp:339
const char * GetName() const override
Returns the name of the layer.
Definition Layer.hpp:332
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
Definition Layer.hpp:286
DataType GetDataType() const
Definition Layer.cpp:345
void MoveAllConnections(OutputSlot &destination)
Moves all connections to another OutputSlot.
Definition Layer.cpp:156
void SetTensorInfo(const TensorInfo &tensorInfo) override
Definition Layer.cpp:95
Layer & GetOwningLayer() const
Definition Layer.hpp:132
void Disconnect(InputSlot &slot)
Definition Layer.cpp:131
const TensorInfo & GetTensorInfo() const override
Definition Layer.cpp:100
int Connect(InputSlot &destination)
Definition Layer.cpp:123
const TensorShape & GetShape() const
Definition Tensor.hpp:193
void Run(Graph &graph, InputSlot &connection) const
Run for every exclusive connection between any base Convolution layer and a child BatchNorm layer for...
Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout.
unsigned int GetHeightIndex() const
unsigned int GetChannelsIndex() const
OptimizeForExclusiveConnection< Convolution2dLayer, BatchNormalizationLayer, FuseBatchNorm< Convolution2dLayer, armnn::DataType::Float16 > > FuseBatchNormIntoConvolution2DFloat16
OptimizeForExclusiveConnection< DepthwiseConvolution2dLayer, BatchNormalizationLayer, FuseBatchNorm< DepthwiseConvolution2dLayer, armnn::DataType::Float16 > > FuseBatchNormIntoDepthwiseConvolution2DFloat16
OptimizeForExclusiveConnection< DepthwiseConvolution2dLayer, BatchNormalizationLayer, FuseBatchNorm< DepthwiseConvolution2dLayer, armnn::DataType::Float32 > > FuseBatchNormIntoDepthwiseConvolution2DFloat32
OptimizeForExclusiveConnection< Convolution2dLayer, BatchNormalizationLayer, FuseBatchNorm< Convolution2dLayer, armnn::DataType::Float32 > > FuseBatchNormIntoConvolution2DFloat32
Copyright (c) 2021 ARM Limited and Contributors.
typename ResolveTypeImpl< DT >::Type ResolveType
DestType PolymorphicDowncast(SourceType *value)
Polymorphic downcast for build in pointers only.
DataType
Definition Types.hpp:49
void IgnoreUnused(Ts &&...)
A BatchNormalizationDescriptor for the BatchNormalizationLayer.
float m_Eps
Value to add to the variance. Used to avoid dividing by zero.