Runs for every exclusive connection between a base Convolution layer and a child BatchNorm layer, for non-quantized layers only.
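The child is removed, and the base is removed if it is left unconnected. A new Convolution layer is added in their place. Its weights and bias are calculated from the weights and bias of the base Convolution layer combined with the parameters of the child BatchNorm layer. For each output channel c: $W'_c = W_c \cdot \frac{\gamma_c}{\sqrt{\sigma^2_c + \epsilon}}$ and $b'_c = \frac{\gamma_c (b_c - \mu_c)}{\sqrt{\sigma^2_c + \epsilon}} + \beta_c$, where $b_c = 0$ if the base layer had no bias.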
29 Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
30 Layer& child = connection.GetOwningLayer();
37 if (base.GetDataType() == ArmnnType && child.GetDataType() == ArmnnType)
39 OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
40 auto convLayer = PolymorphicDowncast<ConvLayer*>(&base);
41 auto batchNormLayer = PolymorphicDowncast<BatchNormalizationLayer*>(&child);
44 BatchNormalizationDescriptor batchNormDescriptor = batchNormLayer->GetParameters();
45 auto epsilon = batchNormDescriptor.m_Eps;
48 ConstTensor betaTensor(batchNormLayer->m_Beta->GetTensorInfo(), batchNormLayer->m_Beta->Map(
true));
49 ConstTensor gammaTensor(batchNormLayer->m_Gamma->GetTensorInfo(), batchNormLayer->m_Gamma->Map(
true));
50 ConstTensor meanTensor(batchNormLayer->m_Mean->GetTensorInfo(), batchNormLayer->m_Mean->Map(
true));
51 ConstTensor varTensor(batchNormLayer->m_Variance->GetTensorInfo(), batchNormLayer->m_Variance->Map(
true));
53 auto convDescriptor = convLayer->GetParameters();
54 ConstTensor weightsTensor;
56 "FuseBatchNorm: Weight data should not be null.");
58 ConstantLayer* weightLayer = PolymorphicDowncast<ConstantLayer*>(
59 &base.GetInputSlot(1).GetConnectedOutputSlot()->GetOwningLayer());
61 weightsTensor = ConstTensor(weightLayer->m_LayerOutput->GetTensorInfo(),
62 weightLayer->m_LayerOutput->Map(
true));
65 auto weightsShape = weightsTensor.GetInfo().GetShape();
66 const unsigned int inputChannels = parentOut->GetTensorInfo().GetShape()[dataLayout.GetChannelsIndex()];
67 const unsigned int depthMultiplier = depthwise ? weightsShape[3] / inputChannels : 1;
68 const unsigned int outputChannels = depthwise ? weightsShape[3] : weightsShape[0];
69 const unsigned int weightsHeight = depthwise ? weightsShape[1] :
70 weightsShape[dataLayout.GetHeightIndex()];
71 const unsigned int weightsWidth = depthwise ? weightsShape[2] :
72 weightsShape[dataLayout.GetWidthIndex()];
74 const auto* weightsBuffer =
static_cast<const T*
>(weightsTensor.GetMemoryArea());
75 const auto* betaBuffer =
static_cast<const T*
>(betaTensor.GetMemoryArea());
76 const auto* gammaBuffer =
static_cast<const T*
>(gammaTensor.GetMemoryArea());
77 const auto* meanBuffer =
static_cast<const T*
>(meanTensor.GetMemoryArea());
78 const auto* varBuffer =
static_cast<const T*
>(varTensor.GetMemoryArea());
80 std::vector<T> weightsVector (weightsBuffer, weightsBuffer + weightsTensor.GetNumElements());
81 std::vector<T> betaVector (betaBuffer, betaBuffer + betaTensor.GetNumElements());
82 std::vector<T> gammaVector (gammaBuffer, gammaBuffer + gammaTensor.GetNumElements());
83 std::vector<T> meanVector (meanBuffer, meanBuffer + meanTensor.GetNumElements());
84 std::vector<T> varianceVector(varBuffer, varBuffer + varTensor.GetNumElements());
87 std::vector<T> fusedWeightsVector(weightsVector.size());
89 for (
unsigned int cInput = 0; cInput < inputChannels; ++cInput)
91 for (
unsigned int cOut = 0; cOut < outputChannels; ++cOut)
93 T mult = gammaVector[cOut] /
static_cast<T
>(sqrtf(varianceVector[cOut] + epsilon));
95 for (
unsigned int h = 0; h < weightsHeight; ++h)
97 for (
unsigned int w = 0; w < weightsWidth; ++w)
99 unsigned int weightsIdx = 0;
103 cInput = cOut / depthMultiplier;
104 weightsIdx = w * outputChannels + cOut +
105 h * weightsWidth * outputChannels;
109 weightsIdx = cOut * weightsHeight * weightsWidth * inputChannels +
110 h * weightsWidth * inputChannels +
116 weightsIdx = cOut * weightsWidth * weightsHeight * inputChannels +
117 cInput * weightsWidth * weightsHeight +
121 fusedWeightsVector[weightsIdx] = mult * weightsVector[weightsIdx];
126 ConstTensor fusedWeightsTensor(weightsTensor.GetInfo(), fusedWeightsVector);
129 std::vector<T> fusedBiasVector(outputChannels);
130 bool biasWasEnabledBeforeOpt = convDescriptor.m_BiasEnabled;
131 if (biasWasEnabledBeforeOpt)
133 ConstTensor biasTensor;
135 "FuseBatchNorm: Bias data should not be null if bias is enabled.");
137 ConstantLayer* biasLayer = PolymorphicDowncast<ConstantLayer*>(
138 &base.GetInputSlot(2).GetConnectedOutputSlot()->GetOwningLayer());
140 biasTensor = ConstTensor(biasLayer->m_LayerOutput->GetTensorInfo(),
141 biasLayer->m_LayerOutput->Map(
true));
143 const auto* biasBuffer =
static_cast<const T*
>(biasTensor.GetMemoryArea());
144 std::vector<T> biasVector(biasBuffer, biasBuffer + biasTensor.GetNumElements());
146 for (
unsigned int cOut = 0; cOut < outputChannels; ++cOut)
148 fusedBiasVector[cOut] = ((gammaVector[cOut] * (biasVector[cOut] - meanVector[cOut])) /
149 sqrtf(varianceVector[cOut] + epsilon)) + betaVector[cOut];
154 convDescriptor.m_BiasEnabled =
true;
155 std::vector<T> biasVector(outputChannels, T(0));
157 for (
unsigned int cOut = 0; cOut < outputChannels; ++cOut)
159 fusedBiasVector[cOut] = ((gammaVector[cOut] * (biasVector[cOut] - meanVector[cOut])) /
160 sqrtf(varianceVector[cOut] + epsilon)) + betaVector[cOut];
163 ConstTensor fusedBiasTensor(TensorInfo({outputChannels}, ArmnnType, 0.0f, 0,
true), fusedBiasVector);
166 const std::string name = std::string(
"fused-") + child.GetName() + std::string(
"-into-") + base.GetName();
167 auto& newConv2dLayer = *graph.InsertNewLayer<ConvLayer>(base.GetInputSlot(0),
173 if (newConv2dLayer.GetNumInputSlots() > 1)
176 weightLayer->GetOutputSlot(0).Disconnect(base.GetInputSlot(1));
177 weightLayer->GetOutputSlot(0).Connect(newConv2dLayer.GetInputSlot(1));
178 weightLayer->m_LayerOutput = std::make_unique<ScopedTensorHandle>(fusedWeightsTensor);
181 ConstantLayer* biasLayer;
182 if (biasWasEnabledBeforeOpt)
184 biasLayer = PolymorphicDowncast<ConstantLayer*>(
185 &base.GetInputSlot(2).GetConnectedOutputSlot()->GetOwningLayer());
187 biasLayer->GetOutputSlot(0).Disconnect(base.GetInputSlot(2));
188 biasLayer->GetOutputSlot(0).Connect(newConv2dLayer.GetInputSlot(2));
195 biasLayer = graph.AddLayer<ConstantLayer>(
"Bias");
196 biasLayer->GetOutputSlot(0).SetTensorInfo(fusedBiasTensor.GetInfo());
197 biasLayer->GetOutputSlot(0).Connect(newConv2dLayer.GetInputSlot(2));
199 biasLayer->m_LayerOutput = std::make_unique<ScopedTensorHandle>(ConstTensor(fusedBiasTensor));
204 newConv2dLayer.GetOutputSlot().MoveAllConnections(*parentOut);
206 parentOut = &newConv2dLayer.GetOutputSlot();
211 child.GetOutputSlot().MoveAllConnections(*parentOut);