Compute Library
 22.05
CrossLayerMemoryManagerHelpers.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2018-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
34 
36 #include "support/Cast.h"
37 
38 #include <algorithm>
39 #include <map>
40 
41 namespace arm_compute
42 {
43 namespace graph
44 {
45 namespace detail
46 {
47 namespace
48 {
49 using HandleCountPair = std::pair<ITensorHandle *, unsigned int>;
50 using HandleCounter = std::map<HandleCountPair::first_type, HandleCountPair::second_type>;
51 using TargetHandleCounter = std::map<Target, HandleCounter>;
52 
53 /** Holds managed IO tensor handles if a task */
54 struct TaskHandles
55 {
56  std::vector<std::pair<ITensorHandle *, IMemoryGroup *>> input_handles = {}; /**< Input handles to a task */
57  std::vector<std::pair<ITensorHandle *, IMemoryGroup *>> output_handles = {}; /**< Output handles of a task */
58 };
59 
60 /** Returns memory group depending on handle backend type
61  *
62  * @param[in] ctx Graph context
63  * @param[in] handle Tensor handle
64  *
65  * @return Memory groupb
66  */
67 IMemoryGroup *get_memory_group_from_handle(GraphContext &ctx, ITensorHandle *handle)
68 {
69  ARM_COMPUTE_ERROR_ON(handle == nullptr);
70  return ctx.memory_management_ctx(handle->target())->cross_group.get();
71 }
72 
73 /** Get handles of const tensors of graph
74  *
75  * @param[in] g Graph
76  *
77  * @return Handles of const tensors of graph
78  */
79 std::set<ITensorHandle *> get_const_handles(const Graph &g)
80 {
81  std::set<NodeType> const_node_types = { NodeType::Input, NodeType::Output, NodeType::Const };
82 
83  std::set<ITensorHandle *> const_tensors;
84 
85  auto &nodes = g.nodes();
86  for(auto &node : nodes)
87  {
88  // If its a const node:
89  if(node != nullptr && const_node_types.find(node->type()) != std::end(const_node_types))
90  {
91  // TODO (geopin01) : Create IO iterator wrappers
92  // Add all its inputs / outputs to the list of constant handles
93  for(unsigned int i = 0; i < node->num_inputs(); ++i)
94  {
95  if(node->input(i) != nullptr)
96  {
97  const_tensors.insert(node->input(i)->handle()->parent_handle());
98  }
99  }
100  for(unsigned int i = 0; i < node->num_outputs(); ++i)
101  {
102  if(node->output(i) != nullptr)
103  {
104  const_tensors.insert(node->output(i)->handle()->parent_handle());
105  }
106  }
107  }
108  }
109 
110  return const_tensors;
111 }
112 
113 /** Builds a list of all the transition handles (Handles that are used to link two nodes)
114  *
115  * @param[in] ctx Graph context
116  * @param[in] task Workload task
117  * @param[in] const_tensors Constant tensors
118  *
119  * @return List of transition handles
120  */
121 TaskHandles get_transition_handles(GraphContext &ctx,
122  ExecutionTask &task,
123  const std::set<ITensorHandle *> &const_tensors)
124 {
125  ARM_COMPUTE_ERROR_ON(task.node == nullptr || (task.task == nullptr && !is_utility_node(task.node)));
126  INode &node = *task.node;
127 
128  TaskHandles transition_handles;
129 
130  // Add input handles
131  for(unsigned int i = 0; i < node.input_edges().size(); ++i)
132  {
133  Edge *input_edge = node.input_edge(i);
134  // If this input is the output of another node
135  if(input_edge != nullptr && input_edge->tensor() != nullptr && const_tensors.find(input_edge->tensor()->handle()->parent_handle()) == std::end(const_tensors))
136  {
137  // Then add it to the list of transition buffers
138  ITensorHandle *tensor_handle = input_edge->tensor()->handle()->parent_handle();
139  IMemoryGroup *mm_group = get_memory_group_from_handle(ctx, tensor_handle);
140  transition_handles.input_handles.emplace_back(std::make_pair(tensor_handle, mm_group));
141  }
142  }
143 
144  // Add output handles
145  for(unsigned int i = 0; i < node.num_outputs(); ++i)
146  {
147  Tensor *output_tensor = node.output(i);
148  // If this output is used as an input for another node
149  if(output_tensor != nullptr && const_tensors.find(output_tensor->handle()->parent_handle()) == std::end(const_tensors))
150  {
151  ITensorHandle *tensor_handle = output_tensor->handle()->parent_handle();
152  IMemoryGroup *mm_group = get_memory_group_from_handle(ctx, tensor_handle);
153  transition_handles.output_handles.emplace_back(std::make_pair(tensor_handle, mm_group));
154  }
155  }
156 
157  return transition_handles;
158 }
159 
160 /** Counts handles refcount for each input handle of each target
161  *
162  * @param[in] task Execution task containing the managed handles
163  * @param[in,out] handle_counter Data structure that keeps the handles reference count
164  */
165 void count_input_handles_per_target(const TaskHandles &task_handles, TargetHandleCounter &handle_counter)
166 {
167  for(const auto &handle : task_handles.input_handles)
168  {
169  ITensorHandle *key = handle.first;
170  HandleCounter &target_counter = handle_counter[key->target()];
171  if(target_counter.find(key) == std::end(target_counter))
172  {
173  target_counter.emplace(std::make_pair(key, 1));
174  }
175  else
176  {
177  ++target_counter[key];
178  }
179  }
180 }
181 
182 /** Calculates the lifetime of each tensor handle
183  *
184  * @param[in, out] tasks_handles Tensor handles for each task
185  * @param[in] hc Data structure that keeps the handles reference count
186  */
187 void configure_handle_lifetime(std::vector<TaskHandles> &tasks_handles, const HandleCounter &hc)
188 {
189  // Identify max number of tensors in flight
190  HandleCounter tensors_in_flight;
191 
192  // Acquires the given handles and sets them as in flight if they aren't already
193  auto acquire = [&](std::vector<std::pair<ITensorHandle *, IMemoryGroup *>> &handles)
194  {
195  for(auto &handle : handles)
196  {
197  ITensorHandle *parent_handle = handle.first;
198  ARM_COMPUTE_ERROR_ON(parent_handle == nullptr);
199  // If the tensor is not already in flight:
200  if(tensors_in_flight.find(parent_handle) == std::end(tensors_in_flight))
201  {
202  ARM_COMPUTE_ERROR_ON(hc.find(parent_handle) == std::end(hc));
203  // Then add it to the list of in flight tensors
204  tensors_in_flight.insert(std::make_pair(parent_handle, hc.at(parent_handle)));
205  // Start of allocation's lifetime
206  parent_handle->manage(handle.second);
207  }
208  }
209  };
210 
211  for(auto &task_handle : tasks_handles)
212  {
213  // Marking all the input and output tensors of the task as in flight
214  acquire(task_handle.input_handles);
215  acquire(task_handle.output_handles);
216 
217  // Releasing the input tensors
218  for(auto &input_handle : task_handle.input_handles)
219  {
220  ITensorHandle *ihandle = input_handle.first;
221  ARM_COMPUTE_ERROR_ON(ihandle == nullptr);
222  ARM_COMPUTE_ERROR_ON(tensors_in_flight.find(ihandle) == std::end(tensors_in_flight));
223  --tensors_in_flight[ihandle];
224  if(tensors_in_flight[ihandle] <= 0)
225  {
226  // Remove tensor for tensors in flight
227  tensors_in_flight.erase(ihandle);
228  // End of allocation's lifetime
229  ihandle->allocate();
230  }
231  }
232  }
233 }
234 } // namespace
235 
237 {
238  // Get const tensors (un-managed)
239  std::set<ITensorHandle *> const_tensors = get_const_handles(g);
240 
241  std::vector<TaskHandles> tasks_handles;
242  TargetHandleCounter target_handle_count;
243 
244  // Count handles
245  for(auto &task : workload.tasks)
246  {
247  // Populates IO handles
248  tasks_handles.push_back(get_transition_handles(ctx, task, const_tensors));
249 
250  // Count handles
251  count_input_handles_per_target(tasks_handles.back(), target_handle_count);
252  }
253 
254  // Setup memory managers
255  for(auto &hc : target_handle_count)
256  {
257  MemoryManagerContext *mm_ctx = ctx.memory_management_ctx(hc.first);
258  if(mm_ctx != nullptr)
259  {
260  if(mm_ctx->cross_mm != nullptr && mm_ctx->cross_group != nullptr)
261  {
262  // Manage and allocate tensors
263  configure_handle_lifetime(tasks_handles, hc.second);
264  }
265  }
266  }
267 }
268 } // namespace detail
269 } // namespace graph
270 } // namespace arm_compute
std::vector< std::pair< ITensorHandle *, IMemoryGroup * > > input_handles
Input handles to a task.
void configure_transition_manager(Graph &g, GraphContext &ctx, ExecutionWorkload &workload)
Configures transition manager and execution workload.
std::shared_ptr< arm_compute::IMemoryGroup > cross_group
Cross-function memory group.
Definition: GraphContext.h:45
bool is_utility_node(INode *node)
Definition: Utils.h:37
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true then an error message is printed and an exception thrown.
Definition: Error.h:466
Copyright (c) 2017-2022 Arm Limited.
std::vector< ExecutionTask > tasks
Execution workload.
Definition: Workload.h:102
void end(TokenStream &in, bool &valid)
Definition: MLGOParser.cpp:290
Graph class.
Definition: Graph.h:53
std::shared_ptr< arm_compute::IMemoryManager > cross_mm
Cross-function memory manager.
Definition: GraphContext.h:44
Contains structs required for memory management.
Definition: GraphContext.h:40
MemoryManagerContext * memory_management_ctx(Target target)
Gets a memory manager context for a given target.
std::vector< std::pair< ITensorHandle *, IMemoryGroup * > > output_handles
Output handles of a task.