ArmNN
 24.08
SquashEqualSiblings.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Optimization.hpp"
8 
11 
12 namespace armnn
13 {
14 namespace optimizations
15 {
16 
17 template <typename Comparable>
19 {
20 public:
21  /// Run for every connection between a base Layer (any) and a child ComparableLayer.
22  /// For all siblings of the child layer that compare equal to it, bypasses and removes
23  /// them. I.e., moves the connections in the outputs of the siblings to the outputs of
24  /// the child layer, so the siblings are left unconnected (and later removed).
25  void Run(Graph& graph, InputSlot& connection) const
26  {
27  IgnoreUnused(graph);
28  auto& child = connection.GetOwningLayer();
29 
30  if (!child.IsOutputUnconnected())
31  {
32  OutputSlot& baseOutput = *connection.GetConnectedOutputSlot();
33 
34  if (baseOutput.GetNumConnections() > 1)
35  {
36  auto& comparableChild = *PolymorphicDowncast<Comparable*>(&child);
37 
38  Layer* lowestPriorityChild = &child;
39  for (auto&& it : baseOutput.GetConnections())
40  {
41  Layer* sibling = &it->GetOwningLayer();
42  if ((sibling != lowestPriorityChild) && comparableChild.IsEqual(*sibling))
43  {
44  if (sibling->GetPriority() < lowestPriorityChild->GetPriority())
45  {
46  std::swap(sibling, lowestPriorityChild);
47  }
48  // Bypasses sibling. It will be removed as it's left unconnected.
49  auto siblingOut = sibling->BeginOutputSlots();
50  for (auto lowestPriorityChildOut = lowestPriorityChild->BeginOutputSlots();
51  lowestPriorityChildOut != lowestPriorityChild->EndOutputSlots(); ++lowestPriorityChildOut)
52  {
53  siblingOut->MoveAllConnections(*lowestPriorityChildOut);
54  ++siblingOut;
55  }
56  }
57  }
58  }
59  }
60  }
61 
62 protected:
63  SquashEqualSiblingsImpl() = default;
64  ~SquashEqualSiblingsImpl() = default;
65 };
66 
71 
72 } // namespace optimizations
73 } // namespace armnn
armnn::InputSlot::GetOwningLayer
Layer & GetOwningLayer() const
Definition: Layer.hpp:53
armnn::OutputSlot
Definition: Layer.hpp:100
armnn::swap
void swap(OriginsDescriptor &first, OriginsDescriptor &second)
Definition: Descriptors.cpp:357
IgnoreUnused.hpp
Optimization.hpp
armnn::optimizations::SquashEqualSiblingsImpl::SquashEqualSiblingsImpl
SquashEqualSiblingsImpl()=default
armnn::optimizations::SquashEqualSiblingsImpl
Definition: SquashEqualSiblings.hpp:18
armnn::Layer
Definition: Layer.hpp:230
armnn::TransposeLayer
This layer represents a transpose operation.
Definition: TransposeLayer.hpp:15
armnn::optimizations::SquashEqualSiblingsImpl::~SquashEqualSiblingsImpl
~SquashEqualSiblingsImpl()=default
armnn::OptimizeForConnection
Definition: Optimization.hpp:118
armnn::OutputSlot::GetNumConnections
unsigned int GetNumConnections() const override
Definition: Layer.hpp:158
PolymorphicDowncast.hpp
armnn::optimizations::SquashEqualSiblingsImpl::Run
void Run(Graph &graph, InputSlot &connection) const
Run for every connection between a base Layer (any) and a child ComparableLayer.
Definition: SquashEqualSiblings.hpp:25
armnn::InputSlot
Definition: Layer.hpp:42
armnn::Layer::BeginOutputSlots
std::vector< OutputSlot >::iterator BeginOutputSlots()
Definition: Layer.hpp:266
armnn::IgnoreUnused
void IgnoreUnused(Ts &&...)
Definition: IgnoreUnused.hpp:14
armnn::OutputSlot::GetConnections
const std::vector< InputSlot * > & GetConnections() const
Definition: Layer.hpp:145
armnn::InputSlot::GetConnectedOutputSlot
const OutputSlot * GetConnectedOutputSlot() const
Definition: Layer.hpp:56
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::Layer::GetPriority
LayerPriority GetPriority() const
Definition: Layer.cpp:360
armnn::Layer::EndOutputSlots
std::vector< OutputSlot >::iterator EndOutputSlots()
Definition: Layer.hpp:267
armnn::Graph
Definition: Graph.hpp:30