14 namespace optimizations
40 auto convLayer = PolymorphicDowncast<ConvLayer*>(&base);
41 auto batchNormLayer = PolymorphicDowncast<BatchNormalizationLayer*>(&child);
45 auto epsilon = batchNormDescriptor.
m_Eps;
48 ConstTensor betaTensor(batchNormLayer->m_Beta->GetTensorInfo(), batchNormLayer->m_Beta->Map(
true));
49 ConstTensor gammaTensor(batchNormLayer->m_Gamma->GetTensorInfo(), batchNormLayer->m_Gamma->Map(
true));
50 ConstTensor meanTensor(batchNormLayer->m_Mean->GetTensorInfo(), batchNormLayer->m_Mean->Map(
true));
51 ConstTensor varTensor(batchNormLayer->m_Variance->GetTensorInfo(), batchNormLayer->m_Variance->Map(
true));
53 auto convDescriptor = convLayer->GetParameters();
54 auto weightsInfo(convLayer->m_Weight->GetTensorInfo());
55 ConstTensor weightsTensor(weightsInfo, convLayer->m_Weight->Map(
true));
58 auto weightsShape = weightsInfo.GetShape();
59 const unsigned int depthMultiplier = depthwise ? weightsShape[0] : 1;
60 const unsigned int inputChannels = depthwise ? weightsShape[1] :
61 weightsShape[dataLayout.GetChannelsIndex()];
62 const unsigned int outputChannels = depthwise ? inputChannels * depthMultiplier : weightsShape[0];
63 const unsigned int weightsHeight = depthwise ? weightsShape[2] :
64 weightsShape[dataLayout.GetHeightIndex()];
65 const unsigned int weightsWidth = depthwise ? weightsShape[3] :
66 weightsShape[dataLayout.GetWidthIndex()];
68 const auto* weightsBuffer =
static_cast<const T*
>(weightsTensor.GetMemoryArea());
69 const auto* betaBuffer =
static_cast<const T*
>(betaTensor.GetMemoryArea());
70 const auto* gammaBuffer =
static_cast<const T*
>(gammaTensor.GetMemoryArea());
71 const auto* meanBuffer =
static_cast<const T*
>(meanTensor.GetMemoryArea());
72 const auto* varBuffer =
static_cast<const T*
>(varTensor.GetMemoryArea());
74 std::vector<T> weightsVector (weightsBuffer, weightsBuffer + weightsTensor.GetNumElements());
75 std::vector<T> betaVector (betaBuffer, betaBuffer + betaTensor.GetNumElements());
76 std::vector<T> gammaVector (gammaBuffer, gammaBuffer + gammaTensor.GetNumElements());
77 std::vector<T> meanVector (meanBuffer, meanBuffer + meanTensor.GetNumElements());
78 std::vector<T> varianceVector(varBuffer, varBuffer + varTensor.GetNumElements());
81 std::vector<T> fusedWeightsVector(weightsVector.size());
82 unsigned int depthwiseMultiplierIdx = 0;
84 for (
unsigned int cInput = 0; cInput < inputChannels; ++cInput)
86 for (
unsigned int cOut = 0; cOut < outputChannels; ++cOut)
88 T mult = gammaVector[cOut] /
static_cast<T
>(sqrtf (varianceVector[cOut] + epsilon));
92 cInput = cOut / depthMultiplier;
93 depthwiseMultiplierIdx = cOut % depthMultiplier;
96 for (
unsigned int h = 0; h < weightsHeight; ++h)
98 for (
unsigned int w = 0; w < weightsWidth; ++w)
100 unsigned int weightsIdx = 0;
104 weightsIdx = depthwiseMultiplierIdx * weightsWidth * weightsHeight * inputChannels +
105 cInput * weightsWidth * weightsHeight +
111 weightsIdx = cOut * weightsHeight * weightsWidth * inputChannels +
112 h * weightsWidth * inputChannels +
118 weightsIdx = cOut * weightsWidth * weightsHeight * inputChannels +
119 cInput * weightsWidth * weightsHeight +
123 fusedWeightsVector[weightsIdx] = mult * weightsVector[weightsIdx];
128 ConstTensor fusedWeightsTensor(weightsInfo, fusedWeightsVector);
131 std::vector<T> fusedBiasVector(outputChannels);
132 if (convDescriptor.m_BiasEnabled)
135 "FuseBatchNorm: Bias data should not be null if bias is enabled.");
137 ConstTensor biasTensor(convLayer->m_Bias->GetTensorInfo(), convLayer->m_Bias->Map(
true));
138 const auto* biasBuffer =
static_cast<const T*
>(biasTensor.GetMemoryArea());
139 std::vector<T> biasVector(biasBuffer, biasBuffer + biasTensor.GetNumElements());
141 for (
unsigned int cOut = 0; cOut < outputChannels; ++cOut)
143 fusedBiasVector[cOut] = ((gammaVector[cOut] * (biasVector[cOut] - meanVector[cOut])) /
144 sqrtf(varianceVector[cOut] + epsilon)) + betaVector[cOut];
149 convDescriptor.m_BiasEnabled =
true;
150 std::vector<T> biasVector(outputChannels, T(0));
152 for (
unsigned int cOut = 0; cOut < outputChannels; ++cOut)
154 fusedBiasVector[cOut] = ((gammaVector[cOut] * (biasVector[cOut] - meanVector[cOut])) /
155 sqrtf(varianceVector[cOut] + epsilon)) + betaVector[cOut];
161 const std::string name = std::string(
"fused-") + child.
GetName() + std::string(
"-into-") + base.
GetName();
165 newConv2dLayer.m_Weight = std::make_unique<ScopedTensorHandle>(fusedWeightsTensor);
166 newConv2dLayer.m_Bias = std::make_unique<ScopedTensorHandle>(
ConstTensor(fusedBiasTensor));
169 newConv2dLayer.GetOutputSlot().MoveAllConnections(*parentOut);
171 parentOut = &newConv2dLayer.GetOutputSlot();
191 BatchNormalizationLayer,
196 BatchNormalizationLayer,
201 BatchNormalizationLayer,
This layer represents a batch normalization operation.
This layer represents a depthwise convolution 2d operation.
Layer & GetOwningLayer() const
float m_Eps
Value to add to the variance. Used to avoid dividing by zero.
typename ResolveTypeImpl< DT >::Type ResolveType
Copyright (c) 2021 ARM Limited and Contributors.
void IgnoreUnused(Ts &&...)
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
#define ARMNN_ASSERT_MSG(COND, MSG)
Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout...
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
#define ARMNN_ASSERT(COND)
void Run(Graph &graph, InputSlot &connection) const
Run for every exclusive connection between any base Convolution layer and a child BatchNorm layer for...
DataType GetDataType() const
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
const char * GetName() const override
Returns the name of the layer.
This layer represents a convolution 2d operation.
LayerT * InsertNewLayer(InputSlot &insertBefore, Args &&... args)
Inserts a new layer between the output slot currently connected to insertBefore and insertBefore itself.
void MoveAllConnections(OutputSlot &destination)
Moves all connections to another OutputSlot.
A BatchNormalizationDescriptor for the BatchNormalizationLayer.