24 #ifndef SRC_DYNAMIC_FUSION_SKETCH_UTILS_DEPENDENCYGRAPH
25 #define SRC_DYNAMIC_FUSION_SKETCH_UTILS_DEPENDENCYGRAPH
37 namespace experimental
39 namespace dynamic_fusion
44 bool is_in(
const T &v,
const std::vector<T> &vec)
62 using AdjList = std::map<Id, std::vector<Id>>;
98 const std::vector<TensorId> &inputs,
99 const std::vector<TensorId> &outputs,
100 bool is_output =
false)
const
103 if (all_ops().empty())
111 if (_last_op_available)
113 auto use_input_from_last_op =
false;
115 for (
auto src_tensor : inputs)
117 const auto src_ops = _adj_src_ops.find(src_tensor);
119 if (src_ops != _adj_src_ops.end())
123 if (!src_ops->second.empty())
125 const auto src_op = src_ops->second[0];
127 if (src_op == _last_op)
129 if (use_input_from_last_op)
136 use_input_from_last_op =
true;
148 if (!use_input_from_last_op)
157 for (
auto dst_tensor : outputs)
159 if (_adj_dst_ops.find(dst_tensor) != _adj_dst_ops.end())
174 const std::vector<TensorId> &inputs,
175 const std::vector<TensorId> &outputs,
176 bool is_output =
false)
178 const auto success =
add_operator(op, inputs, outputs, is_output);
192 const std::vector<TensorId> &inputs,
193 const std::vector<TensorId> &outputs,
194 bool is_output =
false)
196 if (operator_exists(op))
200 _adj_src_tensors[op] = {};
201 _adj_dst_tensors[op] = {};
202 for (
auto in_tensor : inputs)
206 link_input(op, in_tensor);
208 for (
auto out_tensor : outputs)
211 if (path_exists_from_tensor_to_op(out_tensor, op))
218 link_output(op, out_tensor);
224 _last_op_available =
true;
240 std::vector<OpPack> ops_seq;
241 std::set<Id> done_ops;
242 std::set<Id> done_tensors;
246 for (
auto tensor : input_tensors)
248 done_tensors.insert(
tensor);
250 for (
auto op : _adj_dst_ops.at(
tensor))
252 build_operators_sequence_from_op(op, ops_seq, done_ops, done_tensors);
271 return std::make_tuple(g0._adj_src_tensors, g0._adj_dst_tensors, g0._adj_src_ops, g0._adj_dst_ops) ==
272 std::make_tuple(g1._adj_src_tensors, g1._adj_dst_tensors, g1._adj_src_ops, g1._adj_dst_ops);
276 return _adj_src_ops.at(
tensor);
280 return _adj_dst_ops.at(
tensor);
288 std::vector<TensorId> tensors{};
289 std::transform(std::begin(_adj_src_ops),
std::end(_adj_src_ops), std::back_inserter(tensors),
290 [](
const auto &it) {
return it.first; });
299 std::vector<TensorId> tensors;
300 for (
auto tensor_src_ops : _adj_src_ops)
302 if (tensor_src_ops.second.empty())
304 tensors.push_back(tensor_src_ops.first);
315 std::vector<TensorId> tensors;
316 for (
auto tensor_dst_ops : _adj_dst_ops)
318 if (tensor_dst_ops.second.empty())
320 tensors.push_back(tensor_dst_ops.first);
331 std::vector<TensorId> tensors;
335 for (
auto src_tensor : _adj_src_ops)
337 if (!src_tensor.second.empty())
339 const auto dst_tensor = _adj_dst_ops.find(src_tensor.first);
340 if (dst_tensor != _adj_dst_ops.end())
342 if (!dst_tensor->second.empty())
344 tensors.push_back(src_tensor.first);
358 std::vector<OperatorId> ops{};
359 const auto op_list = all_ops();
361 for (
auto op : op_list)
363 if (src_ops(op).empty())
365 ops.emplace_back(op);
375 if (!tensor_exists(in_tensor))
377 insert_new_tensor(in_tensor);
380 _adj_src_tensors[op].push_back(in_tensor);
381 _adj_dst_ops[in_tensor].push_back(op);
386 if (!tensor_exists(out_tensor))
388 insert_new_tensor(out_tensor);
391 _adj_dst_tensors[op].push_back(out_tensor);
392 _adj_src_ops[out_tensor].push_back(op);
395 std::vector<OperatorId> src_ops(
OperatorId op)
const
398 std::vector<OperatorId> ops{};
399 for (
TensorId src_tensor : src_tensors(op))
401 ops.insert(ops.end(), std::begin(_adj_src_ops.at(src_tensor)),
std::end(_adj_src_ops.at(src_tensor)));
405 std::vector<OperatorId> dst_ops(
OperatorId op)
const
408 std::vector<OperatorId> ops{};
409 for (
TensorId dst_tensor : _adj_dst_tensors.at(op))
411 ops.insert(ops.end(), std::begin(_adj_dst_ops.at(dst_tensor)),
std::end(_adj_dst_ops.at(dst_tensor)));
421 std::vector<TensorId> src_tensors(
OperatorId op)
const
424 return _adj_src_tensors.at(op);
431 std::vector<TensorId> dst_tensors(
OperatorId op)
const
434 return _adj_dst_tensors.at(op);
440 std::vector<OperatorId> all_ops()
const
442 std::vector<OperatorId> ops{};
443 std::transform(std::begin(_adj_src_tensors),
std::end(_adj_src_tensors), std::back_inserter(ops),
444 [](
const auto &it) {
return it.first; });
453 for (
auto src_tensor : _adj_src_tensors.at(op))
455 auto &dst_ops = _adj_dst_ops.at(src_tensor);
456 dst_ops.erase(std::remove(std::begin(dst_ops),
std::end(dst_ops), op),
std::end(dst_ops));
458 for (
auto dst_tensor : _adj_dst_tensors.at(op))
460 auto &src_ops = _adj_src_ops.at(dst_tensor);
461 src_ops.erase(std::remove(std::begin(src_ops),
std::end(src_ops), op),
std::end(src_ops));
467 if (_adj_src_ops.at(
t).empty() && _adj_dst_ops.at(
t).empty())
469 _adj_src_ops.erase(
t);
470 _adj_dst_ops.erase(
t);
473 _adj_src_tensors.erase(op);
474 _adj_dst_tensors.erase(op);
478 _adj_src_ops[
tensor] = {};
479 _adj_dst_ops[
tensor] = {};
483 return _adj_src_ops.find(
tensor) != _adj_src_ops.end() && _adj_dst_ops.find(
tensor) != _adj_dst_ops.end();
487 return _adj_src_tensors.find(op) != _adj_src_tensors.end() &&
488 _adj_dst_tensors.find(op) != _adj_dst_tensors.end();
492 if (!operator_exists(op) || !tensor_exists(
tensor))
496 const auto op_inputs = src_tensors(op);
497 return std::find(op_inputs.begin(), op_inputs.end(),
tensor) != op_inputs.end();
501 if (!operator_exists(op) || !tensor_exists(
tensor))
505 const auto op_outputs = dst_tensors(op);
506 return std::find(op_outputs.begin(), op_outputs.end(),
tensor) != op_outputs.end();
510 return is_src_tensor_of(op,
tensor) || is_dst_tensor_of(op,
tensor);
520 return dst_ops(op).empty();
// Builds a list of operators by walking the result of all_ops() and appending
// operators to `ops`.
// NOTE(review): the embedded line numbering jumps from 527 to 531, so a statement
// guarding the emplace_back below (presumably a "has no downstream operators"
// filter) and the final return are not visible in this view - confirm against
// the complete file before relying on this description.
522 std::vector<OperatorId> get_dst_ops()
const
524 std::vector<OperatorId> ops{};
// Snapshot of every operator id currently registered in the graph.
525 const auto op_list = all_ops();
527 for (
auto op : op_list)
// Append the operator to the result (any filtering condition is not visible here).
531 ops.emplace_back(op);
538 if (!tensor_exists(src_tensor) || !operator_exists(dst_op))
544 if (path_exists_from_op_to_op(child_op, dst_op))
554 if (!operator_exists(src_op) || !operator_exists(dst_op))
558 if (src_op == dst_op)
562 if (
is_in(src_op, get_dst_ops()))
566 for (
auto child_tensor : dst_tensors(src_op))
568 if (path_exists_from_tensor_to_op(child_tensor, dst_op))
// Appends `op`, and operators reachable through its output tensors, to `ops_seq`,
// using `done_ops` / `done_tensors` as bookkeeping of what has been handled.
//   op           - operator to start from (reassigned below when following a chain)
//   ops_seq      - output sequence; receives one OpPack per emitted operator
//   done_ops     - operators already handled; consulted before processing `op`
//   done_tensors - tensors already available; extended with the outputs of `op`
// NOTE(review): several control-flow lines (early returns, the enclosing loop,
// else-branches and the update of done_ops) are not visible in this view; the
// comments below describe only the statements that are shown.
576 void build_operators_sequence_from_op(
Id op,
577 std::vector<OpPack> &ops_seq,
578 std::set<Id> &done_ops,
579 std::set<Id> &done_tensors)
const
// Test whether `op` is already in done_ops (the guarded action is not visible here).
584 if (done_ops.find(op) != done_ops.end())
// Copy of the input tensors recorded for `op`.
592 const auto src_tensors = _adj_src_tensors.at(op);
594 for (
auto src : src_tensors)
// Detect an input tensor that is not yet in done_tensors
// (presumably the operator cannot be emitted yet - TODO confirm).
596 if (done_tensors.find(
src) == done_tensors.end())
// Copy of the output tensors recorded for `op`.
603 const auto dst_tensors = _adj_dst_tensors.at(op);
// Record the operator together with its inputs and outputs in the sequence.
607 OpPack
pack{op, src_tensors, dst_tensors};
608 ops_seq.push_back(
pack);
// Every output of `op` is now marked as done.
610 done_tensors.insert(dst_tensors.begin(), dst_tensors.end());
// Exactly one output with exactly one consumer: continue with that consumer by
// reassigning `op` rather than making a recursive call.
614 if (dst_tensors.size() == 1 && _adj_dst_ops.at(dst_tensors[0]).size() == 1)
616 op = _adj_dst_ops.at(dst_tensors[0])[0];
// Recurse into every consumer of every output tensor.
620 for (
auto dst_tensor : dst_tensors)
622 const auto dst_ops = _adj_dst_ops.at(dst_tensor);
624 for (
auto dst_op : dst_ops)
626 build_operators_sequence_from_op(dst_op, ops_seq, done_ops, done_tensors);
641 bool _last_op_available{
false};