16 #include <unordered_set>
// Returns the representative PartialSubgraph for this node.
// The visible lines test whether m_Parent is null and, on the other path, ask the parent
// for its representative (a recursive walk up the m_Parent chain).
// NOTE(review): the branch bodies and return statements are not visible in this excerpt;
// presumably `this` is returned when there is no parent - confirm against the full file.
47 PartialSubgraph* GetRepresentative()
50 if (m_Parent ==
nullptr)
// Parent exists: delegate the lookup to it.
56 PartialSubgraph* result = m_Parent->GetRepresentative();
// Merges this partial subgraph with `other`.
// On the path guarded by m_Parent == nullptr, the visible code:
//   1. resolves `other` to its representative,
//   2. replaces `this` with m_Parent in the dependant/antecedent sets of every neighbour,
//   3. moves this node's own antecedent and dependant sets into m_Parent and clears them.
// The last visible line forwards the merge to this node's representative.
// NOTE(review): the assignment of m_Parent, any early-out for merging with itself, the uses
// of numErased and the else keyword are not visible in this excerpt - confirm in the full file.
65 void MergeWith(PartialSubgraph* other)
67 if (m_Parent ==
nullptr)
// Always merge into the representative of the other subgraph, not an arbitrary member.
69 other = other->GetRepresentative();
// Each antecedent currently lists `this` as a dependant; swap that entry for m_Parent.
81 for (PartialSubgraph* a : m_Antecedents)
83 size_t numErased = a->m_Dependants.erase(
this);
86 a->m_Dependants.insert(m_Parent);
// Each dependant currently lists `this` as an antecedent; swap that entry for m_Parent.
88 for (PartialSubgraph* a : m_Dependants)
90 size_t numErased = a->m_Antecedents.erase(
this);
93 a->m_Antecedents.insert(m_Parent);
// Hand this node's dependency sets over to m_Parent, then empty the local copies.
98 m_Parent->m_Antecedents.insert(m_Antecedents.begin(), m_Antecedents.end());
99 m_Antecedents.clear();
100 m_Parent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
101 m_Dependants.clear();
// Presumably the else branch (this node already has a parent): let the representative merge.
106 GetRepresentative()->MergeWith(other);
// Returns true when this partial subgraph and `other` resolve to the same representative.
111 bool IsMergedWith(PartialSubgraph* other)
113 return GetRepresentative() == other->GetRepresentative();
// Records `antecedent` as an antecedent of this partial subgraph and propagates the
// relationship in both directions:
//   - this node and each of its dependants gain `antecedent` plus antecedent's own antecedents;
//   - `antecedent` and each of its antecedents gain this node plus this node's dependants.
// The last visible line forwards the call to this node's representative.
// NOTE(review): braces and the else keyword are not visible in this excerpt; the grouping of
// statements under the m_Parent == nullptr test is presumed - confirm in the full file.
117 void AddDirectAntecedent(PartialSubgraph* antecedent)
119 if (m_Parent ==
nullptr)
// Work with the representative of the antecedent's subgraph.
121 antecedent = antecedent->GetRepresentative();
// This node now depends on the antecedent ...
123 m_Antecedents.insert(antecedent);
// ... and on everything the antecedent already depends on.
126 m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end());
// Everything that depends on this node inherits the same new antecedents.
128 for (PartialSubgraph* d : m_Dependants)
130 d->m_Antecedents.insert(antecedent);
131 d->m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end());
// Mirror the relationship: the antecedent gains this node and this node's dependants.
136 antecedent->m_Dependants.insert(
this);
137 antecedent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
// The antecedent's own antecedents gain the same dependants.
138 for (PartialSubgraph* a : antecedent->m_Antecedents)
140 a->m_Dependants.insert(
this);
141 a->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
// Presumably the else branch (this node already has a parent): delegate to the representative.
147 GetRepresentative()->AddDirectAntecedent(antecedent);
// Reports whether `antecedent` is recorded as an antecedent of this partial subgraph.
// With no parent, the argument is resolved to its representative and looked up in
// m_Antecedents; the last visible line forwards the query to this node's representative.
// NOTE(review): the else keyword separating the two returns is not visible in this excerpt.
152 bool HasAntecedent(PartialSubgraph* antecedent)
154 if (m_Parent ==
nullptr)
156 antecedent = antecedent->GetRepresentative();
// Membership test on the antecedent set.
158 return m_Antecedents.count(antecedent) > 0;
163 return GetRepresentative()->HasAntecedent(antecedent);
// Link followed by GetRepresentative(); the methods above treat nullptr as "this node has
// no parent".
169 PartialSubgraph* m_Parent;
// Partial subgraphs recorded as antecedents of this one (queried by HasAntecedent).
171 std::unordered_set<PartialSubgraph*> m_Antecedents;
// Partial subgraphs that record this one as an antecedent (the reverse relation).
173 std::unordered_set<PartialSubgraph*> m_Dependants;
// Per-layer bookkeeping record used by the selection code in this file. Members referenced
// elsewhere in this excerpt include m_Layer, m_Subgraph, m_IsSelected and m_IsProcessed
// (their declarations are not visible here).
177 struct LayerSelectionInfo
// Maps a layer pointer to its LayerSelectionInfo record.
179 using LayerInfoContainer = std::map<IConnectableLayer*, LayerSelectionInfo>;
// FIFO work queue of pointers to LayerSelectionInfo records.
180 using LayerInfoQueue = std::queue<LayerSelectionInfo*>;
// Const predicate on a LayerSelectionInfo. The selection loop in this file uses it to decide
// whether a layer is pushed onto the processing queue up front.
// NOTE(review): the body is not visible in this excerpt, so the exact criterion for an
// "input layer" cannot be stated here - confirm against the full file.
190 bool IsInputLayer()
const
// Appends to `inputSlots` every input slot of m_Layer whose connected parent layer is either
// absent from `layerInfos` or belongs to a partial subgraph that is not merged with
// m_Subgraph. A slot already present in `inputSlots` is not added a second time.
// NOTE(review): the second parameter line is not visible in this excerpt; `inputSlots` is
// presumably an output container of InputSlot pointers - confirm against the full file.
195 void CollectNonSelectedInputs(LayerSelectionInfo::LayerInfoContainer& layerInfos,
// Walk every input slot of the layer (the increment clause is not visible in this excerpt).
198 for (
auto&& slot = PolymorphicDowncast<Layer*>(
m_Layer)->BeginInputSlots();
199 slot != PolymorphicDowncast<Layer*>(
m_Layer)->EndInputSlots();
// Find the output slot feeding this input; it is asserted to be non-null.
202 OutputSlot* parentLayerOutputSlot = slot->GetConnectedOutputSlot();
203 ARMNN_ASSERT_MSG(parentLayerOutputSlot !=
nullptr,
"The input slots must be connected here.");
// Runtime guard in addition to the assertion above.
204 if (parentLayerOutputSlot)
206 Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer();
207 auto parentInfo = layerInfos.find(&parentLayer);
// The parent is outside this subgraph if it is unknown or sits in a different merged group.
208 if (parentInfo == layerInfos.end() ||
209 !
m_Subgraph->IsMergedWith(parentInfo->second.m_Subgraph.get()))
// Record the slot once only (linear search for an existing entry).
212 InputSlot* inputSlot = &(*slot);
213 if (std::find(inputSlots.begin(), inputSlots.end(), inputSlot) == inputSlots.end())
215 inputSlots.push_back(inputSlot);
// Appends to `outputSlots` every output slot of m_Layer that has at least one connected child
// layer which is either absent from `layerInfos` or belongs to a partial subgraph that is not
// merged with m_Subgraph. A slot already present in `outputSlots` is not added a second time.
// NOTE(review): the second parameter line is not visible in this excerpt; `outputSlots` is
// presumably an output container of OutputSlot pointers - confirm against the full file.
222 void CollectNonSelectedOutputSlots(LayerSelectionInfo::LayerInfoContainer& layerInfos,
// Walk every output slot of the layer (the increment clause is not visible in this excerpt).
225 for (
auto&& slot = PolymorphicDowncast<Layer*>(
m_Layer)->BeginOutputSlots();
226 slot != PolymorphicDowncast<Layer*>(
m_Layer)->EndOutputSlots();
// Inspect every consumer connected to this output slot.
229 for (InputSlot* childLayerInputSlot : slot->GetConnections())
231 Layer& childLayer = childLayerInputSlot->GetOwningLayer();
232 auto childInfo = layerInfos.find(&childLayer);
// The child is outside this subgraph if it is unknown or sits in a different merged group.
233 if (childInfo == layerInfos.end() ||
234 !
m_Subgraph->IsMergedWith(childInfo->second.m_Subgraph.get()))
// Record the slot once only (linear search for an existing entry).
237 OutputSlot* outputSlot = &(*slot);
238 if (std::find(outputSlots.begin(), outputSlots.end(), outputSlot) == outputSlots.end())
240 outputSlots.push_back(outputSlot);
// Template helper that invokes `function` on the LayerSelectionInfo of each layer feeding
// into layerInfo.m_Layer, skipping parent layers that have no entry in `layerInfos`.
// NOTE(review): the line carrying the function name, the remaining parameters and the loop
// over the layer's input slots are not visible in this excerpt; this is presumably the
// ForEachLayerInput helper called elsewhere in the file - confirm against the full file.
266 template<
typename Delegate>
268 LayerSelectionInfo& layerInfo,
271 Layer& layer = *PolymorphicDowncast<Layer*>(layerInfo.m_Layer);
// Resolve the output slot connected to the current input slot, then the layer that owns it.
275 auto connectedInput = PolymorphicDowncast<OutputSlot*>(inputSlot.GetConnection());
277 Layer& inputLayer = connectedInput->GetOwningLayer();
// Only parents tracked in layerInfos are passed to the callback.
279 auto parentInfo = layerInfos.find(&inputLayer);
280 if (parentInfo != layerInfos.end())
282 function(parentInfo->second);
// Template helper that invokes `function` on the LayerSelectionInfo of each layer consuming
// an output of layerInfo.m_Layer, skipping child layers that have no entry in `layerInfos`.
// NOTE(review): the line carrying the function name, the remaining parameters and the loop
// over the layer's output slots are not visible in this excerpt; this is presumably the
// ForEachLayerOutput helper called elsewhere in the file - confirm against the full file.
287 template<
typename Delegate>
289 LayerSelectionInfo& layerInfo,
292 Layer& layer = *PolymorphicDowncast<Layer*>(layerInfo.m_Layer);
// Visit every connection leaving the current output slot.
296 for (
auto& output : outputSlot.GetConnections())
298 Layer& childLayer = output->GetOwningLayer();
// Only children tracked in layerInfos are passed to the callback.
300 auto childInfo = layerInfos.find(&childLayer);
301 if (childInfo != layerInfos.end())
303 function(childInfo->second);
// Decides which PartialSubgraph `layerInfo` belongs to, based on its parent layers.
// Visible logic, per parent (`parentInfo`):
//   - when the layer and the parent have the same m_IsSelected value, check every other parent:
//     if one of them already has the parent's subgraph as an antecedent, dependenciesOk
//     is cleared; the layer then either adopts the parent's subgraph (no subgraph yet) or
//     merges its subgraph with the parent's;
//   - when the layer still has no subgraph, a fresh PartialSubgraph is created for it;
//   - when the layer's subgraph is not merged with the parent's, the parent's subgraph is
//     recorded as a direct antecedent.
// NOTE(review): the enclosing loop/lambda that supplies `parentInfo`, the test on
// dependenciesOk and the brace structure are not visible in this excerpt, so the exact nesting
// of the steps above is presumed - confirm against the full file.
309 void AssignSplitId(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo)
// Same selection state as the parent: candidate for sharing the parent's subgraph.
315 if (layerInfo.m_IsSelected == parentInfo.m_IsSelected)
// Look at the other parents to see whether joining would conflict with existing dependencies.
327 bool dependenciesOk = true;
328 ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& otherParentInfo)
332 if (otherParentInfo.m_Subgraph->HasAntecedent(parentInfo.m_Subgraph.get()))
334 dependenciesOk = false;
// No subgraph assigned yet: share the parent's. Otherwise merge the two subgraphs.
342 if (layerInfo.m_Subgraph ==
nullptr)
344 layerInfo.m_Subgraph = parentInfo.m_Subgraph;
350 layerInfo.m_Subgraph->MergeWith(parentInfo.m_Subgraph.get());
// Still unassigned: start a new partial subgraph for this layer.
357 if (layerInfo.m_Subgraph ==
nullptr)
359 layerInfo.m_Subgraph = std::make_shared<PartialSubgraph>();
// Parent ended up in a different subgraph: record the dependency on it.
367 if (!layerInfo.m_Subgraph->IsMergedWith(parentInfo.m_Subgraph.get()))
369 layerInfo.m_Subgraph->AddDirectAntecedent(parentInfo.m_Subgraph.get());
// Callback applied to a parent's LayerSelectionInfo; it captures `ready` by reference and
// tests whether that parent has not been processed yet.
// NOTE(review): the enclosing function and the statement executed when the test succeeds are
// not visible in this excerpt - presumably `ready` is cleared; confirm against the full file.
378 [&ready](LayerSelectionInfo& parentInfo)
380 if (!parentInfo.m_IsProcessed)
// Selection routine returning SubgraphViewSelector::Subgraphs. Visible phases:
//   1. build a LayerSelectionInfo per layer and seed the work queue;
//   2. drain the queue, processing layers and enqueueing their children;
//   3. group the selected layers by the representative of their partial subgraph;
//   4. for each group, collect boundary input slots, boundary output slots and layers,
//      and build a SubgraphView from them;
//   5. sort the resulting views by the GUID of each view's first layer.
// NOTE(review): the line carrying the function name and parameters, and many statements in
// between, are not visible in this excerpt; comments below describe only the visible lines.
388 SubgraphViewSelector::Subgraphs
391 LayerSelectionInfo::LayerInfoContainer layerInfos;
393 LayerSelectionInfo::LayerInfoQueue processQueue;
// Register every layer; layers reported as input layers are queued immediately.
395 for (
auto& layer : subgraphLayers)
398 auto emplaced = layerInfos.emplace(layer, LayerSelectionInfo{PolymorphicDowncast<Layer*>(layer), selector});
399 LayerSelectionInfo& layerInfo = emplaced.first->second;
402 if (layerInfo.IsInputLayer())
404 processQueue.push(&layerInfo);
// Also register and queue the layers that own the subgraph's input slots.
409 for (
auto& inputSlot : subgraphInputSlots)
411 Layer& layer = PolymorphicDowncast<InputSlot*>(inputSlot)->GetOwningLayer();
412 auto emplaced = layerInfos.emplace(&layer, LayerSelectionInfo{&layer, selector});
413 LayerSelectionInfo& layerInfo = emplaced.first->second;
415 processQueue.push(&layerInfo);
// Work-queue loop over the registered layers.
418 while (!processQueue.empty())
420 LayerSelectionInfo& layerInfo = *processQueue.front();
// Only layers not yet processed are handled.
424 if (!layerInfo.m_IsProcessed)
// Re-queue the same layer (presumably when its parents are not ready yet - the guarding
// condition is not visible in this excerpt).
430 processQueue.push(&layerInfo);
// Queue every child layer that is tracked in layerInfos.
438 ForEachLayerOutput(layerInfos, layerInfo, [&processQueue](LayerSelectionInfo& childInfo)
440 processQueue.push(&childInfo);
444 layerInfo.m_IsProcessed =
true;
// Group the selected layers by the representative of the partial subgraph they ended up in.
449 using SelectionInfoPtrs = std::vector<LayerSelectionInfo*>;
450 std::map<PartialSubgraph*, SelectionInfoPtrs> splitMap;
451 for (
auto&
info : layerInfos)
453 if (
info.second.m_IsSelected)
455 auto it = splitMap.find(
info.second.m_Subgraph->GetRepresentative());
// First layer seen for this representative: create the group with this layer in it.
456 if (it == splitMap.end())
459 std::make_pair(
info.second.m_Subgraph->GetRepresentative(), SelectionInfoPtrs{&info.second}));
// Representative already has a group: append to it.
463 it->second.push_back(&
info.second);
// Turn each group into a SubgraphView.
470 for (
auto& splitGraph : splitMap)
// Gather the group's boundary input slots, boundary output slots and member layers.
475 for (
auto&& infoPtr : splitGraph.second)
477 infoPtr->CollectNonSelectedInputs(layerInfos, inputs);
478 infoPtr->CollectNonSelectedOutputSlots(layerInfos, outputs);
479 layers.push_back(infoPtr->m_Layer);
// Input-slot ordering: compares owning-layer GUIDs; when they are equal, the slot index decides.
// (The enclosing sort call and the first comparison branch are not visible in this excerpt.)
486 auto* castA = PolymorphicDowncast<const InputSlot*>(a);
487 auto* castB = PolymorphicDowncast<const InputSlot*>(b);
488 const LayerGuid guidA = castA->GetOwningLayer().GetGuid();
489 const LayerGuid guidB = castB->GetOwningLayer().GetGuid();
494 else if (guidA == guidB)
496 return (castA->GetSlotIndex() < castB->GetSlotIndex());
// Output-slot ordering: same shape as above; the equal-GUID tie-break is not visible here.
502 auto* castA = PolymorphicDowncast<const OutputSlot*>(a);
503 auto* castB = PolymorphicDowncast<const OutputSlot*>(b);
504 const LayerGuid guidA = castA->GetOwningLayer().GetGuid();
505 const LayerGuid guidB = castB->GetOwningLayer().GetGuid();
510 else if (guidA == guidB)
// Build the view for this group; the collected containers are moved into it.
519 result.emplace_back(std::make_unique<SubgraphView>(std::move(layers),
521 std::move(outputs)));
// Order the resulting views by the GUID of the first layer in each view.
527 std::sort(result.begin(), result.end(), [](
const SubgraphView::SubgraphViewPtr& a,
528 const SubgraphView::SubgraphViewPtr& b)
530 return a->GetIConnectableLayers().front()->GetGuid() < b->GetIConnectableLayers().front()->GetGuid();