Compute Library
 21.02
TopologicalSort.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2018-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
27 
#include "support/Iterable.h"

#include <list>
#include <queue>
#include <stack>
#include <vector>
32 
33 namespace arm_compute
34 {
35 namespace graph
36 {
37 namespace detail
38 {
39 /** Checks if all the input dependencies of a node have been visited
40  *
41  * @param[in] node Node to check
42  * @param[in] visited Vector that contains the visited information
43  *
44  * @return True if all inputs dependencies have been visited else false
45  */
46 inline bool all_inputs_are_visited(const INode *node, const std::vector<bool> &visited)
47 {
48  ARM_COMPUTE_ERROR_ON(node == nullptr);
49  const Graph *graph = node->graph();
50  ARM_COMPUTE_ERROR_ON(graph == nullptr);
51 
52  bool are_all_visited = true;
53  for(const auto &input_edge_id : node->input_edges())
54  {
55  if(input_edge_id != EmptyNodeID)
56  {
57  const Edge *input_edge = graph->edge(input_edge_id);
58  ARM_COMPUTE_ERROR_ON(input_edge == nullptr);
59  ARM_COMPUTE_ERROR_ON(input_edge->producer() == nullptr);
60  if(!visited[input_edge->producer_id()])
61  {
62  are_all_visited = false;
63  break;
64  }
65  }
66  }
67 
68  return are_all_visited;
69 }
70 } // namespace detail
71 
72 std::vector<NodeID> bfs(Graph &g)
73 {
74  std::vector<NodeID> bfs_order_vector;
75 
76  // Created visited vector
77  std::vector<bool> visited(g.nodes().size(), false);
78 
79  // Create BFS queue
80  std::list<NodeID> queue;
81 
82  // Push inputs and mark as visited
83  for(auto &input : g.nodes(NodeType::Input))
84  {
85  if(input != EmptyNodeID)
86  {
87  visited[input] = true;
88  queue.push_back(input);
89  }
90  }
91 
92  // Push const nodes and mark as visited
93  for(auto &const_node : g.nodes(NodeType::Const))
94  {
95  if(const_node != EmptyNodeID)
96  {
97  visited[const_node] = true;
98  queue.push_back(const_node);
99  }
100  }
101 
102  // Iterate over vector and edges
103  while(!queue.empty())
104  {
105  // Dequeue a node from queue and process
106  NodeID n = queue.front();
107  bfs_order_vector.push_back(n);
108  queue.pop_front();
109 
110  const INode *node = g.node(n);
111  ARM_COMPUTE_ERROR_ON(node == nullptr);
112  for(const auto &eid : node->output_edges())
113  {
114  const Edge *e = g.edge(eid);
115  ARM_COMPUTE_ERROR_ON(e == nullptr);
116  if(!visited[e->consumer_id()] && detail::all_inputs_are_visited(e->consumer(), visited))
117  {
118  visited[e->consumer_id()] = true;
119  queue.push_back(e->consumer_id());
120  }
121  }
122  }
123 
124  return bfs_order_vector;
125 }
126 
127 std::vector<NodeID> dfs(Graph &g)
128 {
129  std::vector<NodeID> dfs_order_vector;
130 
131  // Created visited vector
132  std::vector<bool> visited(g.nodes().size(), false);
133 
134  // Create DFS stack
135  std::stack<NodeID> stack;
136 
137  // Push inputs and mark as visited
138  for(auto &input : g.nodes(NodeType::Input))
139  {
140  if(input != EmptyNodeID)
141  {
142  visited[input] = true;
143  stack.push(input);
144  }
145  }
146 
147  // Push const nodes and mark as visited
148  for(auto &const_node : g.nodes(NodeType::Const))
149  {
150  if(const_node != EmptyNodeID)
151  {
152  visited[const_node] = true;
153  stack.push(const_node);
154  }
155  }
156 
157  // Iterate over vector and edges
158  while(!stack.empty())
159  {
160  // Pop a node from stack and process
161  NodeID n = stack.top();
162  dfs_order_vector.push_back(n);
163  stack.pop();
164 
165  // Mark node as visited
166  if(!visited[n])
167  {
168  visited[n] = true;
169  }
170 
171  const INode *node = g.node(n);
172  ARM_COMPUTE_ERROR_ON(node == nullptr);
173  // Reverse iterate to push branches from right to left and pop on the opposite order
174  for(const auto &eid : arm_compute::utils::iterable::reverse_iterate(node->output_edges()))
175  {
176  const Edge *e = g.edge(eid);
177  ARM_COMPUTE_ERROR_ON(e == nullptr);
178  if(!visited[e->consumer_id()] && detail::all_inputs_are_visited(e->consumer(), visited))
179  {
180  stack.push(e->consumer_id());
181  }
182  }
183  }
184 
185  return dfs_order_vector;
186 }
187 } // namespace graph
188 } // namespace arm_compute
INode * consumer() const
Returns consumer node.
Definition: Edge.h:92
const std::set< EdgeID > & output_edges() const
Returns output edge set.
Definition: INode.cpp:132
const Graph * graph() const
Returns node's Graph.
Definition: INode.cpp:112
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true then an error message is printed and an exception thrown.
Definition: Error.h:466
Copyright (c) 2017-2021 Arm Limited.
Node interface.
Definition: INode.h:45
const std::vector< EdgeID > & input_edges() const
Returns input edge set.
Definition: INode.cpp:127
NodeID producer_id() const
Returns producer node id.
Definition: Edge.h:68
std::vector< NodeID > dfs(Graph &g)
Depth first search traversal.
Graph class.
Definition: Graph.h:53
unsigned int NodeID
Definition: Types.h:66
const std::vector< NodeID > & nodes(NodeType type)
Returns the graph nodes of the given type.
Definition: Graph.cpp:174
Graph Edge.
Definition: Edge.h:39
constexpr NodeID EmptyNodeID
Constant NodeID specifying an equivalent of a null node.
Definition: Types.h:73
const INode * node(NodeID id) const
Get node object given its id.
Definition: Graph.cpp:204
const Edge * edge(EdgeID id) const
Get edge object given its id.
Definition: Graph.cpp:214
NodeID consumer_id() const
Returns sink node id.
Definition: Edge.h:76
std::vector< NodeID > bfs(Graph &g)
Breadth first search traversal.
reverse_iterable< T > reverse_iterate(T &val)
Creates a reverse iterable for a given type.
Definition: Iterable.h:101
bool all_inputs_are_visited(const INode *node, const std::vector< bool > &visited)
Checks if all the input dependencies of a node have been visited.
INode * producer() const
Returns producer node.
Definition: Edge.h:84