30 namespace experimental
32 namespace dynamic_fusion
36 std::vector<DependencyGraph::TensorId> get_tensor_ids(
const std::vector<const ITensorInfo *> tensors)
38 std::vector<DependencyGraph::TensorId> tensor_ids{};
39 std::transform(std::begin(tensors),
std::end(tensors), std::back_inserter(tensor_ids),
40 [](
const auto &
t) {
return t->id(); });
47 : _id{
id}, _operator_type{operator_type}, _tensors{tensors}
58 return _operator_type;
68 const auto src_tensor_ids = get_tensor_ids(op.
tensors().get_const_src_tensors());
69 const auto dst_tensor_ids = get_tensor_ids(op.
tensors().get_const_dst_tensors());
96 if (_operators.size() > 0)
100 const auto first_dst_tensor = root_dst_tensors[0];
101 const auto dst_tensors = op.
tensors().get_const_dst_tensors();
102 for (
const auto &
t : root_dst_tensors)
109 for (
const auto &
t : dst_tensors)
118 if (_operators.size() > 0)
122 const auto first_dst_tensor_layout = root_dst_tensors[0]->data_layout();
123 const auto dst_tensors = op.
tensors().get_const_dst_tensors();
124 for (
const auto &
t : root_dst_tensors)
126 if (
t->data_layout() != first_dst_tensor_layout)
131 for (
const auto &
t : dst_tensors)
133 if (
t->data_layout() != first_dst_tensor_layout)
144 const auto src_tensor_ids = get_tensor_ids(op.
tensors().get_const_src_tensors());
145 const auto dst_tensor_ids = get_tensor_ids(op.
tensors().get_const_dst_tensors());
147 _operators[op.
id()] = op;
152 auto new_id =
static_cast<OperatorId>(_operators.size());
153 return Operator{new_id, operator_type, tensors};
163 return &_operators.at(roots[0]);