35 namespace experimental
37 namespace dynamic_fusion
56 if (!_components.empty() &&
72 if (!_components.empty())
76 const auto first_dst_tensor = root_dst_tensors[0];
77 const auto dst_tensors = component->
tensors().get_const_dst_tensors();
78 for (
const auto &
t : root_dst_tensors)
85 for (
const auto &
t : dst_tensors)
94 if (!_components.empty())
98 const auto first_dst_tensor_layout = root_dst_tensors[0]->data_layout();
99 const auto dst_tensors = component->
tensors().get_const_dst_tensors();
100 for (
const auto &
t : root_dst_tensors)
102 if (
t->data_layout() != first_dst_tensor_layout)
107 for (
const auto &
t : dst_tensors)
109 if (
t->data_layout() != first_dst_tensor_layout)
125 _components.push_back(component);
138 std::set<const ITensorInfo *> output_tensors;
139 std::map<const ITensorInfo *, std::vector<const ITensorInfo *>> possible_tile_map;
140 std::map<const ITensorInfo *, int32_t> tile_usages;
142 for (
auto component : _components)
144 const auto tensors = component->tensors();
145 const auto src_tensors = tensors.get_const_src_tensors();
146 const auto dst_tensors = tensors.get_const_dst_tensors();
149 for (
auto tensor : src_tensors)
151 const auto output_tensors_it = output_tensors.find(
tensor);
153 if (output_tensors_it != output_tensors.end())
157 output_tensors.erase(output_tensors_it);
158 _interm_tensors.insert(
tensor);
160 else if (_interm_tensors.find(
tensor) == _interm_tensors.end())
162 _input_tensors.insert(
tensor);
165 possible_tile_map.emplace(
tensor, std::vector<const ITensorInfo *>());
169 for (
auto tensor : dst_tensors)
174 output_tensors.insert(
tensor);
177 possible_tile_map.emplace(
tensor, std::vector<const ITensorInfo *>());
181 const auto component_type = component->type();
186 const auto dst_tensor = dst_tensors[0];
187 const auto &
dst_shape = dst_tensor->tensor_shape();
188 const auto &dst_type = dst_tensor->data_type();
190 tile_usages[dst_tensor] = 0;
192 for (
auto src_tensor : src_tensors)
194 const auto &src_shape = src_tensor->tensor_shape();
195 const auto &src_type = src_tensor->data_type();
197 if (src_shape ==
dst_shape && src_type == dst_type)
199 const auto tile_usages_it = tile_usages.find(src_tensor);
207 ++tile_usages_it->second;
210 possible_tile_map[dst_tensor].push_back(src_tensor);
217 for (
auto tensor : dst_tensors)
225 for (
auto tensor : _input_tensors)
230 for (
auto component : _components)
232 const auto dst_tensors = component->tensors().get_const_dst_tensors();
234 for (
auto tensor : dst_tensors)
236 const auto target_tiles = possible_tile_map.at(
tensor);
239 for (
auto target : target_tiles)
241 const auto num_usage = tile_usages[target];
247 _tile_map[
tensor] = _tile_map.at(target);
254 for (
auto tensor : output_tensors)
261 for (
auto tensor_tile : _tile_map)
263 if (tensor_tile.first == tensor_tile.second && _interm_tensors.find(tensor_tile.first) != _interm_tensors.end())
265 _tiles.push_back(tensor_tile.first);
269 std::set_union(_input_tensors.begin(), _input_tensors.end(), output_tensors.begin(), output_tensors.end(),
270 std::back_inserter(_argument_tensors));
271 _any_output_tensor = *output_tensors.begin();
284 if (_tile_map.find(
tensor) != _tile_map.end())
286 return _tile_map.at(
tensor);
295 return _any_output_tensor;
301 return _argument_tensors;
310 return _components[0];
316 return _interm_tensors.find(
tensor) != _interm_tensors.end();
322 return _input_tensors.find(
tensor) != _input_tensors.end();
327 return _components.size();
331 return _components.empty();
335 return _components[index];
339 return _components[index];
343 return _components.begin();
347 return _components.end();
351 return _components.cbegin();
355 return _components.cend();
359 return _components.cbegin();
363 return _components.cend();