ArmNN
 25.11
Loading...
Searching...
No Matches
OptimizeInversePermutes.hpp
Go to the documentation of this file.
1//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "Optimization.hpp"
8
11
12namespace armnn
13{
14namespace optimizations
15{
16
17template <typename PermuteType>
19{
20public:
21 /// Run for every connection between a base PermuteLayer and a child PermuteLayer.
22 /// Bypasses both layers for that connection if one is the inverse of the other.
23 void Run(Graph& graph, InputSlot& connection) const
24 {
25 IgnoreUnused(graph);
26 Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
27 auto child = PolymorphicDowncast<PermuteType*>(&connection.GetOwningLayer());
28
29 if (child->IsInverse(*PolymorphicDowncast<PermuteType*>(&base)))
30 {
31 // Bypass both layers. Child will be removed as it's left unconnected.
32 // Base layer will be removed if left unconnected.
33 child->GetOutputSlot().MoveAllConnections(*base.GetInputSlot(0).GetConnectedOutputSlot());
34 }
35 }
36
37protected:
40};
41
46
47} // namespace optimizations
48} // namespace armnn
Layer & GetOwningLayer() const
Definition Layer.hpp:53
const OutputSlot * GetConnectedOutputSlot() const
Definition Layer.hpp:56
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition Layer.hpp:337
Layer & GetOwningLayer() const
Definition Layer.hpp:132
This layer represents a permutation operation.
This layer represents a transpose operation.
void Run(Graph &graph, InputSlot &connection) const
Run for every connection between a base PermuteLayer and a child PermuteLayer.
OptimizeForConnection< PermuteLayer, PermuteLayer, OptimizeInversePermutesImpl< PermuteLayer > > OptimizeInversePermutes
OptimizeForConnection< TransposeLayer, TransposeLayer, OptimizeInversePermutesImpl< TransposeLayer > > OptimizeInverseTransposes
Copyright (c) 2021 ARM Limited and Contributors.
DestType PolymorphicDowncast(SourceType *value)
Polymorphic downcast for built-in pointers only.
void IgnoreUnused(Ts &&...)