ArmNN
 24.08
Optimization.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Graph.hpp"
8 #include "LayersFwd.hpp"
9 
11 
12 namespace armnn
13 {
14 
16 {
17 public:
18  Optimization() = default;
19  virtual ~Optimization() = default;
20  virtual void Run(Graph& graph, Layer& base) const = 0;
21 protected:
22 };
23 
24 // Wrappers
25 // The implementation of the following wrappers make use of the CRTP C++ idiom
26 // (curiously recurring template pattern).
27 // For details, see https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
28 
29 /// Wrapper Optimization base class that calls Wrapped::Run() for every layer of type BaseType.
30 /// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
31 /// after applying each optimization.
32 template <typename BaseType, typename Wrapped>
33 class OptimizeForTypeImpl : public armnn::Optimization, public Wrapped
34 {
35 public:
36  using Wrapped::Wrapped;
37 
38  void Run(Graph& graph, Layer& base) const override
39  {
40  if (base.GetType() == LayerEnumOf<BaseType>())
41  {
42  Wrapped::Run(graph, *PolymorphicDowncast<BaseType*>(&base));
43  }
44  }
45 
46 protected:
47  ~OptimizeForTypeImpl() = default;
48 };
49 
50 /// Specialization that calls Wrapped::Run() for any layer type.
51 template <typename Wrapped>
52 class OptimizeForTypeImpl<Layer, Wrapped> : public armnn::Optimization, public Wrapped
53 {
54 public:
55  using Wrapped::Wrapped;
56 
57  void Run(Graph& graph, Layer& base) const override
58  {
59  Wrapped::Run(graph, base);
60  }
61 
62 protected:
63  ~OptimizeForTypeImpl() = default;
64 };
65 
66 template <typename BaseType, typename Wrapped>
67 class OptimizeForType final : public OptimizeForTypeImpl<BaseType, Wrapped>
68 {
69 public:
71 };
72 
73 /// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType.
74 /// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
75 /// after applying each optimization.
76 /// - Wrapped class mustn't affect existing connections in the same output. It might add new ones.
77 /// - Children layers are removed if left unconnected after applying the wrapped optimization.
78 template <typename BaseType, typename ChildType, typename Wrapped>
79 class OptimizeForConnectionImpl : public Wrapped
80 {
81 public:
82  using Wrapped::Wrapped;
83 
84  void Run(Graph& graph, BaseType& base) const
85  {
86  for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output)
87  {
88  for (auto&& childInput : output->GetConnections())
89  {
90  if (childInput->GetOwningLayer().GetType() == LayerEnumOf<ChildType>())
91  {
92  Wrapped::Run(graph, *childInput);
93  }
94  }
95 
96  // Removes unconnected children.
97  for (unsigned int i = 0; i < output->GetNumConnections();)
98  {
99  Layer* child = &output->GetConnection(i)->GetOwningLayer();
100 
101  if (child->IsOutputUnconnected())
102  {
103  graph.EraseLayer(child);
104  }
105  else
106  {
107  ++i;
108  }
109  }
110  }
111  }
112 
113 protected:
114  ~OptimizeForConnectionImpl() = default;
115 };
116 
117 template <typename BaseType, typename ChildType, typename Wrapped>
119  : public OptimizeForTypeImpl<BaseType, OptimizeForConnectionImpl<BaseType, ChildType, Wrapped>>
120 {
121 public:
123 };
124 
125 /// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType.
126 /// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
127 /// after applying each optimization.
128 /// - Wrapped class mustn't affect existing connections in the same output. It might add new ones.
129 /// - Children layers are removed if left unconnected after applying the wrapped optimization.
130 template <typename BaseType, typename ChildType, typename Wrapped>
132 {
133 public:
134  using Wrapped::Wrapped;
135 
136  void Run(Graph& graph, BaseType& base) const
137  {
138  for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output)
139  {
140  if (output->GetNumConnections() == 1)
141  {
142  for (auto&& childInput : output->GetConnections())
143  {
144  if (childInput->GetOwningLayer().GetType() == LayerEnumOf<ChildType>())
145  {
146  Wrapped::Run(graph, *childInput);
147  }
148  }
149 
150  // Removes unconnected children.
151  for (unsigned int i = 0; i < output->GetNumConnections();)
152  {
153  Layer* child = &output->GetConnection(i)->GetOwningLayer();
154 
155  if (child->IsOutputUnconnected())
156  {
157  graph.EraseLayer(child);
158  }
159  else
160  {
161  ++i;
162  }
163  }
164  }
165  }
166  }
167 
168 protected:
170 };
171 
172 template <typename BaseType, typename ChildType, typename Wrapped>
174  : public OptimizeForTypeImpl<BaseType, OptimizeForExclusiveConnectionImpl<BaseType, ChildType, Wrapped>>
175 {
176 public:
177  using OptimizeForTypeImpl<BaseType,
179 };
180 
181 } // namespace armnn
armnn::OptimizeForTypeImpl::~OptimizeForTypeImpl
~OptimizeForTypeImpl()=default
armnn::OptimizeForTypeImpl::Run
void Run(Graph &graph, Layer &base) const override
Definition: Optimization.hpp:38
armnn::Graph::EraseLayer
void EraseLayer(Iterator pos)
Deletes the layer at the specified position.
Definition: Graph.hpp:517
armnn::OptimizeForExclusiveConnectionImpl::Run
void Run(Graph &graph, BaseType &base) const
Definition: Optimization.hpp:136
Graph.hpp
armnn::Optimization::~Optimization
virtual ~Optimization()=default
armnn::Optimization::Optimization
Optimization()=default
armnn::OptimizeForTypeImpl
Wrapper Optimization base class that calls Wrapped::Run() for every layer of type BaseType.
Definition: Optimization.hpp:33
armnn::Layer::IsOutputUnconnected
bool IsOutputUnconnected()
Definition: Layer.hpp:270
armnn::Layer
Definition: Layer.hpp:230
armnn::OptimizeForConnection
Definition: Optimization.hpp:118
PolymorphicDowncast.hpp
LayersFwd.hpp
armnn::Optimization::Run
virtual void Run(Graph &graph, Layer &base) const =0
armnn::OptimizeForConnectionImpl::~OptimizeForConnectionImpl
~OptimizeForConnectionImpl()=default
armnn::OptimizeForType
Definition: Optimization.hpp:67
armnn::OptimizeForExclusiveConnectionImpl
Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType.
Definition: Optimization.hpp:131
armnn::Layer::GetType
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
Definition: Layer.hpp:286
armnn::Optimization
Definition: Optimization.hpp:15
armnn::OptimizeForExclusiveConnectionImpl::~OptimizeForExclusiveConnectionImpl
~OptimizeForExclusiveConnectionImpl()=default
armnn::OptimizeForTypeImpl< Layer, Wrapped >::Run
void Run(Graph &graph, Layer &base) const override
Definition: Optimization.hpp:57
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::OptimizeForConnectionImpl::Run
void Run(Graph &graph, BaseType &base) const
Definition: Optimization.hpp:84
armnn::OptimizeForConnectionImpl
Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType.
Definition: Optimization.hpp:79
armnn::Graph
Definition: Graph.hpp:30
armnn::OptimizeForExclusiveConnection
Definition: Optimization.hpp:173