ArmNN
 25.11
Loading...
Searching...
No Matches
Optimization.hpp
Go to the documentation of this file.
1//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "Graph.hpp"
8#include "LayersFwd.hpp"
9
11
12namespace armnn
13{
14
16{
17public:
18 Optimization() = default;
19 virtual ~Optimization() = default;
20 virtual void Run(Graph& graph, Layer& base) const = 0;
21protected:
22};
23
24// Wrappers
25// The implementation of the following wrappers makes use of the CRTP C++ idiom
26// (curiously recurring template pattern).
27// For details, see https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
28
29/// Wrapper Optimization base class that calls Wrapped::Run() for every layer of type BaseType.
30/// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
31/// after applying each optimization.
32template <typename BaseType, typename Wrapped>
33class OptimizeForTypeImpl : public armnn::Optimization, public Wrapped
34{
35public:
36 using Wrapped::Wrapped;
37
38 void Run(Graph& graph, Layer& base) const override
39 {
40 if (base.GetType() == LayerEnumOf<BaseType>())
41 {
42 Wrapped::Run(graph, *PolymorphicDowncast<BaseType*>(&base));
43 }
44 }
45
46protected:
48};
49
50/// Specialization that calls Wrapped::Run() for any layer type.
51template <typename Wrapped>
52class OptimizeForTypeImpl<Layer, Wrapped> : public armnn::Optimization, public Wrapped
53{
54public:
55 using Wrapped::Wrapped;
56
57 void Run(Graph& graph, Layer& base) const override
58 {
59 Wrapped::Run(graph, base);
60 }
61
62protected:
64};
65
66template <typename BaseType, typename Wrapped>
67class OptimizeForType final : public OptimizeForTypeImpl<BaseType, Wrapped>
68{
69public:
70 using OptimizeForTypeImpl<BaseType, Wrapped>::OptimizeForTypeImpl;
71};
72
73/// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType.
74/// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
75/// after applying each optimization.
76/// - Wrapped class mustn't affect existing connections in the same output. It might add new ones.
77/// - Children layers are removed if left unconnected after applying the wrapped optimization.
78template <typename BaseType, typename ChildType, typename Wrapped>
79class OptimizeForConnectionImpl : public Wrapped
80{
81public:
82 using Wrapped::Wrapped;
83
84 void Run(Graph& graph, BaseType& base) const
85 {
86 for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output)
87 {
88 for (auto&& childInput : output->GetConnections())
89 {
90 if (childInput->GetOwningLayer().GetType() == LayerEnumOf<ChildType>())
91 {
92 Wrapped::Run(graph, *childInput);
93 }
94 }
95
96 // Removes unconnected children.
97 for (unsigned int i = 0; i < output->GetNumConnections();)
98 {
99 Layer* child = &output->GetConnection(i)->GetOwningLayer();
100
101 if (child->IsOutputUnconnected())
102 {
103 graph.EraseLayer(child);
104 }
105 else
106 {
107 ++i;
108 }
109 }
110 }
111 }
112
113protected:
115};
116
// Optimization built by stacking OptimizeForTypeImpl (forwards only layers of type BaseType) on top
// of OptimizeForConnectionImpl (visits each BaseType -> ChildType connection and applies Wrapped).
// NOTE(review): original lines 118 and 122 are missing from this extract — presumably the class
// head and an inherited-constructor using-declaration; restore them from the upstream header.
117template <typename BaseType, typename ChildType, typename Wrapped>
119 : public OptimizeForTypeImpl<BaseType, OptimizeForConnectionImpl<BaseType, ChildType, Wrapped>>
120{
121public:
123};
124
125/// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType.
126/// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
127/// after applying each optimization.
128/// - Wrapped class mustn't affect existing connections in the same output. It might add new ones.
129/// - Children layers are removed if left unconnected after applying the wrapped optimization.
130template <typename BaseType, typename ChildType, typename Wrapped>
132{
133public:
134 using Wrapped::Wrapped;
135
136 void Run(Graph& graph, BaseType& base) const
137 {
138 for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output)
139 {
140 if (output->GetNumConnections() == 1)
141 {
142 for (auto&& childInput : output->GetConnections())
143 {
144 if (childInput->GetOwningLayer().GetType() == LayerEnumOf<ChildType>())
145 {
146 Wrapped::Run(graph, *childInput);
147 }
148 }
149
150 // Removes unconnected children.
151 for (unsigned int i = 0; i < output->GetNumConnections();)
152 {
153 Layer* child = &output->GetConnection(i)->GetOwningLayer();
154
155 if (child->IsOutputUnconnected())
156 {
157 graph.EraseLayer(child);
158 }
159 else
160 {
161 ++i;
162 }
163 }
164 }
165 }
166 }
167
168protected:
170};
171
// Optimization built by stacking OptimizeForTypeImpl (forwards only layers of type BaseType) on top
// of OptimizeForExclusiveConnectionImpl (applies Wrapped to a BaseType -> ChildType connection only
// when it is the single connection of its output slot).
// NOTE(review): original lines 173 and 178 are missing from this extract — presumably the class
// head and the remainder of the using-declaration begun on line 177, which is incomplete as shown.
172template <typename BaseType, typename ChildType, typename Wrapped>
174 : public OptimizeForTypeImpl<BaseType, OptimizeForExclusiveConnectionImpl<BaseType, ChildType, Wrapped>>
175{
176public:
177 using OptimizeForTypeImpl<BaseType,
179};
180
181} // namespace armnn
void EraseLayer(Iterator pos)
Deletes the layer at the specified position.
Definition Graph.hpp:517
bool IsOutputUnconnected()
Definition Layer.hpp:270
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
Definition Layer.hpp:286
virtual void Run(Graph &graph, Layer &base) const =0
Optimization()=default
virtual ~Optimization()=default
Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType.
void Run(Graph &graph, BaseType &base) const
Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType.
void Run(Graph &graph, BaseType &base) const
void Run(Graph &graph, Layer &base) const override
Wrapper Optimization base class that calls Wrapped::Run() for every layer of type BaseType.
void Run(Graph &graph, Layer &base) const override
Copyright (c) 2021 ARM Limited and Contributors.
constexpr LayerType LayerEnumOf(const T *=nullptr)
DestType PolymorphicDowncast(SourceType *value)
Polymorphic downcast for built-in pointers only.