ArmNN
 25.11
Loading...
Searching...
No Matches
SquashEqualSiblings.hpp
Go to the documentation of this file.
1//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "Optimization.hpp"
8
11
12namespace armnn
13{
14namespace optimizations
15{
16
17template <typename Comparable>
19{
20public:
21 /// Run for every connection between a base Layer (any) and a child ComparableLayer.
22 /// For all siblings of the child layer that compare equal to it, bypasses and removes
23 /// them. I.e., moves the connections in the outputs of the siblings to the outputs of
24 /// the child layer, so the siblings are left unconnected (and later removed).
25 void Run(Graph& graph, InputSlot& connection) const
26 {
27 IgnoreUnused(graph);
28 auto& child = connection.GetOwningLayer();
29
30 if (!child.IsOutputUnconnected())
31 {
32 OutputSlot& baseOutput = *connection.GetConnectedOutputSlot();
33
34 if (baseOutput.GetNumConnections() > 1)
35 {
36 auto& comparableChild = *PolymorphicDowncast<Comparable*>(&child);
37
38 Layer* lowestPriorityChild = &child;
39 for (auto&& it : baseOutput.GetConnections())
40 {
41 Layer* sibling = &it->GetOwningLayer();
42 if ((sibling != lowestPriorityChild) && comparableChild.IsEqual(*sibling))
43 {
44 if (sibling->GetPriority() < lowestPriorityChild->GetPriority())
45 {
46 std::swap(sibling, lowestPriorityChild);
47 }
48 // Bypasses sibling. It will be removed as it's left unconnected.
49 auto siblingOut = sibling->BeginOutputSlots();
50 for (auto lowestPriorityChildOut = lowestPriorityChild->BeginOutputSlots();
51 lowestPriorityChildOut != lowestPriorityChild->EndOutputSlots(); ++lowestPriorityChildOut)
52 {
53 siblingOut->MoveAllConnections(*lowestPriorityChildOut);
54 ++siblingOut;
55 }
56 }
57 }
58 }
59 }
60 }
61
62protected:
65};
66
71
72} // namespace optimizations
73} // namespace armnn
Layer & GetOwningLayer() const
Definition Layer.hpp:53
const OutputSlot * GetConnectedOutputSlot() const
Definition Layer.hpp:56
std::vector< OutputSlot >::iterator BeginOutputSlots()
Definition Layer.hpp:266
LayerPriority GetPriority() const
Definition Layer.cpp:360
std::vector< OutputSlot >::iterator EndOutputSlots()
Definition Layer.hpp:267
unsigned int GetNumConnections() const override
Definition Layer.hpp:158
const std::vector< InputSlot * > & GetConnections() const
Definition Layer.hpp:145
This layer represents a transpose operation.
void Run(Graph &graph, InputSlot &connection) const
Run for every connection between a base Layer (any) and a child ComparableLayer.
OptimizeForConnection< Layer, TransposeLayer, SquashEqualSiblingsImpl< TransposeLayer > > SquashEqualTransposeSiblings
OptimizeForConnection< Layer, ReshapeLayer, SquashEqualSiblingsImpl< ReshapeLayer > > SquashEqualReshapeSiblings
OptimizeForConnection< Layer, PermuteLayer, SquashEqualSiblingsImpl< PermuteLayer > > SquashEqualPermuteSiblings
Copyright (c) 2021 ARM Limited and Contributors.
DestType PolymorphicDowncast(SourceType *value)
Polymorphic downcast for built-in pointers only.
void IgnoreUnused(Ts &&...)