Compute Library
 21.05
Graph.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2018-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_GRAPH_GRAPH_H
25 #define ARM_COMPUTE_GRAPH_GRAPH_H
26 
27 #include "arm_compute/graph/Edge.h"
31 
32 #include "support/Mutex.h"
34 
35 #include <map>
36 #include <memory>
37 #include <string>
38 #include <utility>
39 #include <vector>
40 
41 #ifndef BARE_METAL
42 #include <thread>
43 #endif /* BARE_METAL */
44 
45 namespace arm_compute
46 {
47 namespace graph
48 {
49 /** Graph class
50  *
51  * Represents a multiple source - multiple sink directed graph
52  */
53 class Graph final
54 {
55 public:
56  Graph() = default;
57  /** Constructor
58  *
59  * @param[in] id Graph identification number. Can be used to differentiate between graphs. Default value 0
60  * @param[in] name Graph name. Default value empty string
61  */
62  Graph(GraphID id, std::string name);
63  /** Prevent instances of this class from being copied (As this class contains pointers) */
64  Graph(const Graph &) = delete;
65  /** Prevent instances of this class from being copy assigned (As this class contains pointers) */
66  Graph &operator=(const Graph &) = delete;
67  /** Prevent instances of this class from being moved (As this class contains non movable objects) */
68  Graph(Graph &&) = delete;
69  /** Prevent instances of this class from being moved (As this class contains non movable objects) */
70  Graph &operator=(Graph &&) = delete;
71  /** Adds a node to the graph
72  *
73  * @note Models a single output node
74  *
75  * @tparam NT Node operation
76  * @tparam Ts Arguments to operation
77  *
78  * @param[in] args Node arguments
79  *
80  * @return ID of the node
81  */
82  template <typename NT, typename... Ts>
83  NodeID add_node(Ts &&... args);
84  /** Remove the node with the given ID
85  *
86  * @param[in] nid ID of the node to remove
87  *
88  * @return True if the removal took place else false
89  */
90  bool remove_node(NodeID nid);
91  /** Adds a connection between two nodes
92  *
93  * @param[in] source ID of the source node
94  * @param[in] source_idx Output index of the source node
95  * @param[in] sink ID of the sink node
96  * @param[in] sink_idx Input index of the sink node
97  *
98  * @return ID of this connection
99  */
100  EdgeID add_connection(NodeID source, size_t source_idx, NodeID sink, size_t sink_idx);
101  /** Removes an edge (connection)
102  *
103  * @param[in] eid Connection to remove
104  *
105  * @return True if the removal took place else false
106  */
107  bool remove_connection(EdgeID eid);
108  /** Returns graph name
109  *
110  * @return Graph name
111  */
112  std::string name() const;
113  /** Returns graph id
114  *
115  * @return Graph id
116  */
117  GraphID id() const;
118  /** Returns graph input nodes
119  *
120  * @param[in] type Type of nodes to return
121  *
122  * @return vector containing the graph node of given type
123  */
124  const std::vector<NodeID> &nodes(NodeType type);
125  /** Returns nodes of graph
126  *
127  * @warning Nodes can be nullptr if they have been removed during the mutation steps of the graph
128  *
129  * @return Nodes of graph
130  */
131  std::vector<std::unique_ptr<INode>> &nodes();
132  /** Returns nodes of graph
133  *
134  * @warning Nodes can be nullptr if they have been removed during the mutation steps of the graph
135  *
136  * @return Nodes of graph
137  */
138  const std::vector<std::unique_ptr<INode>> &nodes() const;
139  /** Returns edges of graph
140  *
141  * @warning Edges can be nullptr if they have been removed during the mutation steps of the graph
142  *
143  * @return Edges of graph
144  */
145  const std::vector<std::unique_ptr<Edge>> &edges() const;
146  /** Returns tensors of graph
147  *
148  * @warning Tensor can be nullptr if they have been removed during the mutation steps of the graph
149  *
150  * @return Tensors of graph
151  */
152  std::vector<std::unique_ptr<Tensor>> &tensors();
153  /** Returns tensors of graph
154  *
155  * @warning Tensor can be nullptr if they have been removed during the mutation steps of the graph
156  *
157  * @return Tensors of graph
158  */
159  const std::vector<std::unique_ptr<Tensor>> &tensors() const;
160  /** Get node object given its id
161  *
162  * @warning Can be nullptr if node was removed during the mutation steps of the graph
163  *
164  * @param[in] id Node ID
165  *
166  * @return The actual node object
167  */
168  const INode *node(NodeID id) const;
169  /** Get node object given its id
170  *
171  * @warning Can be nullptr if node was removed during the mutation steps of the graph
172  *
173  * @param[in] id Node ID
174  *
175  * @return The actual node object
176  */
177  INode *node(NodeID id);
178  /** Get edge object given its id
179  *
180  * @warning Can be nullptr if node was removed during the mutation steps of the graph
181  *
182  * @param[in] id Edge ID
183  *
184  * @return The actual edge object
185  */
186  const Edge *edge(EdgeID id) const;
187  /** Get edge object given its id
188  *
189  * @warning Can be nullptr if node was removed during the mutation steps of the graph
190  *
191  * @param[in] id Edge ID
192  *
193  * @return The actual edge object
194  */
195  Edge *edge(EdgeID id);
196  /** Get tensor object given its id
197  *
198  * @warning Can be nullptr if tensor was removed during the mutation steps of the graph
199  *
200  * @param[in] id Tensor ID
201  *
202  * @return The actual tensor object
203  */
204  const Tensor *tensor(TensorID id) const;
205  /** Get tensor object given its id
206  *
207  * @warning Can be nullptr if tensor was removed during the mutation steps of the graph
208  *
209  * @param[in] id Tensor ID
210  *
211  * @return The actual tensor object
212  */
213  Tensor *tensor(TensorID id);
214 
215 private:
216  /** Creates a tensor object
217  *
218  * @param[in] desc Tensor descriptor
219  *
220  * @return Tensor ID
221  */
222  TensorID create_tensor(const TensorDescriptor &desc = TensorDescriptor());
223 
224 private:
225  GraphID _id = GraphID(0); /**< Graph id */
226  std::string _name = {}; /**< Graph name */
227  std::vector<std::unique_ptr<INode>> _nodes = {}; /**< Graph nodes */
228  std::vector<std::unique_ptr<Edge>> _edges = {}; /**< Graph edges */
229  std::vector<std::unique_ptr<Tensor>> _tensors = {}; /**< Graph tensors */
230  std::map<NodeType, std::vector<NodeID>> _tagged_nodes = {}; /**< Graph nodes map with the node type as key */
231  arm_compute::Mutex _mtx = {}; /**< Mutex used for graph construction */
232 };
233 
234 template <typename NT, typename... Ts>
235 inline NodeID Graph::add_node(Ts &&... args)
236 {
238 
239  // Create node
240  NodeID nid = _nodes.size();
241  auto node = std::make_unique<NT>(std::forward<Ts>(args)...);
242  node->set_graph(this);
243  node->set_id(nid);
244 
245  // Keep track of input nodes
246  _tagged_nodes[node->type()].push_back(nid);
247 
248  // Associate a new tensor with each output
249  for(auto &output : node->_outputs)
250  {
251  output = create_tensor();
252  }
253 
254  // Propagate node shape if possible
256 
257  // Add node to the graph nodes
258  _nodes.push_back(std::move(node));
259 
260  return nid;
261 }
262 } // namespace graph
263 } // namespace arm_compute
264 #endif /* ARM_COMPUTE_GRAPH_GRAPH_H */
GraphID id() const
Returns graph id.
Definition: Graph.cpp:169
bool remove_connection(EdgeID eid)
Removes an edge (connection)
Definition: Graph.cpp:118
const std::vector< std::unique_ptr< Edge > > & edges() const
Returns edges of graph.
Definition: Graph.cpp:189
NodeID add_node(Ts &&... args)
Adds a node to the graph.
Definition: Graph.h:235
Graph & operator=(const Graph &)=delete
Prevent instances of this class from being copy assigned (As this class contains pointers)
std::mutex Mutex
Wrapper of Mutex data-object.
Definition: Mutex.h:33
decltype(strategy::transforms) typedef type
std::string name() const
Returns graph name.
Definition: Graph.cpp:164
Copyright (c) 2017-2021 Arm Limited.
std::vector< std::unique_ptr< INode > > & nodes()
Returns nodes of graph.
Definition: Graph.cpp:179
std::vector< std::unique_ptr< Tensor > > & tensors()
Returns tensors of graph.
Definition: Graph.cpp:194
Node interface.
Definition: INode.h:45
void set_graph(Graph *g)
Sets the graph that this node is registered to.
Definition: INode.cpp:50
bool remove_node(NodeID nid)
Remove the node with the given ID.
Definition: Graph.cpp:35
NodeType
Supported nodes.
Definition: Types.h:148
EdgeID add_connection(NodeID source, size_t source_idx, NodeID sink, size_t sink_idx)
Adds a connection between two nodes.
Definition: Graph.cpp:69
void set_id(NodeID id)
Sets the node id.
Definition: INode.cpp:56
unsigned int EdgeID
Definition: Types.h:69
Graph class.
Definition: Graph.h:53
unsigned int NodeID
Definition: Types.h:68
Graph Edge.
Definition: Edge.h:39
const INode * node(NodeID id) const
Get node object given its id.
Definition: Graph.cpp:204
const Edge * edge(EdgeID id) const
Get edge object given its id.
Definition: Graph.cpp:214
virtual bool forward_descriptors()=0
Forwards descriptor information to outputs if possible.
virtual NodeType type() const =0
Returns node's type.
std::lock_guard< Mutex > lock_guard
Wrapper of lock_guard data-object.
Definition: Mutex.h:37
unsigned int TensorID
Definition: Types.h:67
const Tensor * tensor(TensorID id) const
Get tensor object given its id.
Definition: Graph.cpp:224
Tensor object.
Definition: Tensor.h:41
unsigned int GraphID
Definition: Types.h:66