49 using HandleCountPair = std::pair<ITensorHandle *, unsigned int>;
50 using HandleCounter = std::map<HandleCountPair::first_type, HandleCountPair::second_type>;
51 using TargetHandleCounter = std::map<Target, HandleCounter>;
56 std::vector<std::pair<ITensorHandle *, IMemoryGroup *>>
input_handles = {};
67 IMemoryGroup *get_memory_group_from_handle(GraphContext &ctx, ITensorHandle *handle)
70 return ctx.memory_management_ctx(handle->target())->cross_group.get();
// Returns the parent handles of every input and output tensor of the graph's "const" nodes,
// i.e. the non-null nodes whose type() is found in const_node_types. get_transition_handles()
// skips every handle in this set, so these handles never become transition handles.
//
// @param[in] g Graph whose nodes are scanned
//
// @return Set of parent tensor handles of the const nodes' inputs and outputs
//
// NOTE(review): the declaration of const_node_types (listing lines 80-82) is absent from this
// listing, so which node types count as "const" cannot be read from here — confirm upstream.
79 std::set<ITensorHandle *> get_const_handles(
const Graph &g)
83 std::set<ITensorHandle *> const_tensors;
85 auto &nodes = g.nodes();
86 for(
auto &node : nodes)
// Only non-null nodes whose type is one of the const node types contribute handles
89 if(node !=
nullptr && const_node_types.find(node->type()) !=
std::end(const_node_types))
// Record the parent handle of each non-null input tensor
93 for(
unsigned int i = 0; i < node->num_inputs(); ++i)
95 if(node->input(i) !=
nullptr)
97 const_tensors.insert(node->input(i)->handle()->parent_handle());
// Record the parent handle of each non-null output tensor
100 for(
unsigned int i = 0; i < node->num_outputs(); ++i)
102 if(node->output(i) !=
nullptr)
104 const_tensors.insert(node->output(i)->handle()->parent_handle());
110 return const_tensors;
121 TaskHandles get_transition_handles(GraphContext &ctx,
123 const std::set<ITensorHandle *> &const_tensors)
126 INode &node = *task.node;
128 TaskHandles transition_handles;
131 for(
unsigned int i = 0; i < node.input_edges().size(); ++i)
133 Edge *input_edge = node.input_edge(i);
135 if(input_edge !=
nullptr && input_edge->tensor() !=
nullptr && const_tensors.find(input_edge->tensor()->handle()->parent_handle()) ==
std::end(const_tensors))
138 ITensorHandle *tensor_handle = input_edge->tensor()->handle()->parent_handle();
139 IMemoryGroup *mm_group = get_memory_group_from_handle(ctx, tensor_handle);
140 transition_handles.input_handles.emplace_back(std::make_pair(tensor_handle, mm_group));
145 for(
unsigned int i = 0; i < node.num_outputs(); ++i)
147 Tensor *output_tensor = node.output(i);
149 if(output_tensor !=
nullptr && const_tensors.find(output_tensor->handle()->parent_handle()) ==
std::end(const_tensors))
151 ITensorHandle *tensor_handle = output_tensor->handle()->parent_handle();
152 IMemoryGroup *mm_group = get_memory_group_from_handle(ctx, tensor_handle);
153 transition_handles.output_handles.emplace_back(std::make_pair(tensor_handle, mm_group));
157 return transition_handles;
165 void count_input_handles_per_target(
const TaskHandles &task_handles, TargetHandleCounter &handle_counter)
167 for(
const auto &handle : task_handles.input_handles)
169 ITensorHandle *key = handle.first;
170 HandleCounter &target_counter = handle_counter[key->target()];
171 if(target_counter.find(key) ==
std::end(target_counter))
173 target_counter.emplace(std::make_pair(key, 1));
177 ++target_counter[key];
// Walks the tasks in order and configures the lifetime of every transition handle:
//  - on a handle's first appearance (as an input or output of a task) it is recorded as
//    "in flight" with its input reference count from hc, and parent_handle->manage() is
//    called with the memory group paired to it;
//  - after each task, the count of each of the task's input handles is decremented, and the
//    handle is removed from the in-flight map once the count reaches zero.
//
// @param[in,out] tasks_handles Per-task transition handles, in task order
// @param[in]     hc            Input reference count of each handle (the caller passes one target's counts)
//
// NOTE(review): hc only counts input occurrences and covers a single target, while
// tasks_handles holds the handles of every target; any handle missing from hc makes
// hc.at() below throw std::out_of_range.
187 void configure_handle_lifetime(std::vector<TaskHandles> &tasks_handles,
const HandleCounter &hc)
190 HandleCounter tensors_in_flight;
// Starts tracking (and managing) each handle that is not already in flight
193 auto acquire = [&](std::vector<std::pair<ITensorHandle *, IMemoryGroup *>> &handles)
195 for(
auto &handle : handles)
197 ITensorHandle *parent_handle = handle.first;
200 if(tensors_in_flight.find(parent_handle) ==
std::end(tensors_in_flight))
204 tensors_in_flight.insert(std::make_pair(parent_handle, hc.at(parent_handle)));
206 parent_handle->manage(handle.second);
211 for(
auto &task_handle : tasks_handles)
// All inputs and outputs of the task are in flight while it runs
214 acquire(task_handle.input_handles);
215 acquire(task_handle.output_handles);
// Each input of this task consumes one reference
218 for(
auto &input_handle : task_handle.input_handles)
220 ITensorHandle *ihandle = input_handle.first;
// NOTE(review): counts are unsigned int, so "<= 0" is effectively "== 0"
223 --tensors_in_flight[ihandle];
224 if(tensors_in_flight[ihandle] <= 0)
227 tensors_in_flight.erase(ihandle);
// NOTE(review): the listing lines following erase() are missing here; whatever ends the
// handle's lifetime once its count reaches zero is not visible — confirm against upstream.
// NOTE(review): the signature owning this body is not in this listing; from the identifiers
// used (g, ctx, workload) it is presumably
// configure_transition_manager(Graph &g, GraphContext &ctx, ExecutionWorkload &workload) — confirm.
// Handles of the const nodes' tensors are excluded from lifetime management.
239 std::set<ITensorHandle *> const_tensors = get_const_handles(g);
241 std::vector<TaskHandles> tasks_handles;
242 TargetHandleCounter target_handle_count;
// For each task: collect its transition handles, then count its input handles per target
245 for(
auto &task : workload.
tasks)
248 tasks_handles.push_back(get_transition_handles(ctx, task, const_tensors));
251 count_input_handles_per_target(tasks_handles.back(), target_handle_count);
// For each target that has input-handle counts, configure handle lifetimes when mm_ctx is non-null
255 for(
auto &hc : target_handle_count)
// NOTE(review): the declaration of mm_ctx (listing line 257) and listing lines 259-262 are
// missing here; mm_ctx is presumably ctx.memory_management_ctx(hc.first) — confirm.
258 if(mm_ctx !=
nullptr)
263 configure_handle_lifetime(tasks_handles, hc.second);
std::vector< std::pair< ITensorHandle *, IMemoryGroup * > > input_handles
Input handles to a task.
void configure_transition_manager(Graph &g, GraphContext &ctx, ExecutionWorkload &workload)
Configures transition manager and execution workload.
std::shared_ptr< arm_compute::IMemoryGroup > cross_group
Cross-function memory group.
bool is_utility_node(INode *node)
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true then an error message is printed and an exception thrown.
Copyright (c) 2017-2021 Arm Limited.
std::vector< ExecutionTask > tasks
Execution workload.
void end(TokenStream &in, bool &valid)
std::shared_ptr< arm_compute::IMemoryManager > cross_mm
Cross-function memory manager.
Contains structs required for memory management.
MemoryManagerContext * memory_management_ctx(Target target)
Gets a memory manager context for a given target.
std::vector< std::pair< ITensorHandle *, IMemoryGroup * > > output_handles
Output handles of a task.