16#include <unordered_set>
47 PartialSubgraph* GetRepresentative()
50 if (m_Parent ==
nullptr)
56 PartialSubgraph* result = m_Parent->GetRepresentative();
65 void MergeWith(PartialSubgraph* other)
67 if (m_Parent ==
nullptr)
69 other = other->GetRepresentative();
81 for (PartialSubgraph* a : m_Antecedents)
83 size_t numErased = a->m_Dependants.erase(
this);
86 throw armnn::Exception(
"number of dependents erased must only be 1.");
88 a->m_Dependants.insert(m_Parent);
90 for (PartialSubgraph* a : m_Dependants)
92 size_t numErased = a->m_Antecedents.erase(
this);
95 throw armnn::Exception(
"number of antecedents erased must only be 1.");
98 a->m_Antecedents.insert(m_Parent);
103 m_Parent->m_Antecedents.insert(m_Antecedents.begin(), m_Antecedents.end());
104 m_Antecedents.clear();
105 m_Parent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
106 m_Dependants.clear();
111 GetRepresentative()->MergeWith(other);
116 bool IsMergedWith(PartialSubgraph* other)
118 return GetRepresentative() == other->GetRepresentative();
122 void AddDirectAntecedent(PartialSubgraph* antecedent)
124 if (m_Parent ==
nullptr)
126 antecedent = antecedent->GetRepresentative();
128 m_Antecedents.insert(antecedent);
131 m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end());
133 for (PartialSubgraph* d : m_Dependants)
135 d->m_Antecedents.insert(antecedent);
136 d->m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end());
141 antecedent->m_Dependants.insert(
this);
142 antecedent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
143 for (PartialSubgraph* a : antecedent->m_Antecedents)
145 a->m_Dependants.insert(
this);
146 a->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
152 GetRepresentative()->AddDirectAntecedent(antecedent);
157 bool HasAntecedent(PartialSubgraph* antecedent)
159 if (m_Parent ==
nullptr)
161 antecedent = antecedent->GetRepresentative();
163 return m_Antecedents.count(antecedent) > 0;
168 return GetRepresentative()->HasAntecedent(antecedent);
174 PartialSubgraph* m_Parent;
176 std::unordered_set<PartialSubgraph*> m_Antecedents;
178 std::unordered_set<PartialSubgraph*> m_Dependants;
182struct LayerSelectionInfo
184 using LayerInfoContainer = std::map<IConnectableLayer*, LayerSelectionInfo>;
185 using LayerInfoQueue = std::queue<LayerSelectionInfo*>;
189 , m_Subgraph{nullptr}
190 , m_IsSelected{selector(*layer)}
191 , m_IsProcessed(false)
195 bool IsInputLayer()
const
200 void CollectNonSelectedInputs(LayerSelectionInfo::LayerInfoContainer& layerInfos,
207 OutputSlot* parentLayerOutputSlot = slot->GetConnectedOutputSlot();
209 if (!parentLayerOutputSlot)
211 throw armnn::NullPointerException(
"The input slots must be connected here.");
214 if (parentLayerOutputSlot)
216 Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer();
217 auto parentInfo = layerInfos.find(&parentLayer);
218 if (parentInfo == layerInfos.end() ||
219 !m_Subgraph->IsMergedWith(parentInfo->second.m_Subgraph.get()))
222 InputSlot* inputSlot = &(*slot);
223 if (std::find(inputSlots.begin(), inputSlots.end(), inputSlot) == inputSlots.end())
225 inputSlots.push_back(inputSlot);
232 void CollectNonSelectedOutputSlots(LayerSelectionInfo::LayerInfoContainer& layerInfos,
239 for (InputSlot* childLayerInputSlot : slot->GetConnections())
241 Layer& childLayer = childLayerInputSlot->GetOwningLayer();
242 auto childInfo = layerInfos.find(&childLayer);
243 if (childInfo == layerInfos.end() ||
244 !m_Subgraph->IsMergedWith(childInfo->second.m_Subgraph.get()))
247 OutputSlot* outputSlot = &(*slot);
248 if (std::find(outputSlots.begin(), outputSlots.end(), outputSlot) == outputSlots.end())
250 outputSlots.push_back(outputSlot);
257 IConnectableLayer* m_Layer;
261 std::shared_ptr<PartialSubgraph> m_Subgraph;
276template<
typename Delegate>
278 LayerSelectionInfo& layerInfo,
290 Layer& inputLayer = connectedInput->GetOwningLayer();
292 auto parentInfo = layerInfos.find(&inputLayer);
293 if (parentInfo != layerInfos.end())
295 function(parentInfo->second);
300template<
typename Delegate>
302 LayerSelectionInfo& layerInfo,
309 for (
auto& output : outputSlot.GetConnections())
311 Layer& childLayer = output->GetOwningLayer();
313 auto childInfo = layerInfos.find(&childLayer);
314 if (childInfo != layerInfos.end())
316 function(childInfo->second);
322void AssignSplitId(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo)
328 if (layerInfo.m_IsSelected == parentInfo.m_IsSelected)
340 bool dependenciesOk =
true;
341 ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& otherParentInfo)
345 if (otherParentInfo.m_Subgraph->HasAntecedent(parentInfo.m_Subgraph.get()))
347 dependenciesOk =
false;
355 if (layerInfo.m_Subgraph ==
nullptr)
357 layerInfo.m_Subgraph = parentInfo.m_Subgraph;
363 layerInfo.m_Subgraph->MergeWith(parentInfo.m_Subgraph.get());
370 if (layerInfo.m_Subgraph ==
nullptr)
372 layerInfo.m_Subgraph = std::make_shared<PartialSubgraph>();
380 if (!layerInfo.m_Subgraph->IsMergedWith(parentInfo.m_Subgraph.get()))
382 layerInfo.m_Subgraph->AddDirectAntecedent(parentInfo.m_Subgraph.get());
391 [&ready](LayerSelectionInfo& parentInfo)
393 if (!parentInfo.m_IsProcessed)
404 LayerSelectionInfo::LayerInfoContainer layerInfos;
406 LayerSelectionInfo::LayerInfoQueue processQueue;
408 for (
auto& layer : subgraphLayers)
412 LayerSelectionInfo& layerInfo = emplaced.first->second;
415 if (layerInfo.IsInputLayer())
417 processQueue.push(&layerInfo);
422 for (
auto& inputSlot : subgraphInputSlots)
425 auto emplaced = layerInfos.emplace(&layer, LayerSelectionInfo{&layer, selector});
426 LayerSelectionInfo& layerInfo = emplaced.first->second;
428 processQueue.push(&layerInfo);
431 while (!processQueue.empty())
433 LayerSelectionInfo& layerInfo = *processQueue.front();
437 if (!layerInfo.m_IsProcessed)
443 processQueue.push(&layerInfo);
451 ForEachLayerOutput(layerInfos, layerInfo, [&processQueue](LayerSelectionInfo& childInfo)
453 processQueue.push(&childInfo);
457 layerInfo.m_IsProcessed =
true;
462 using SelectionInfoPtrs = std::vector<LayerSelectionInfo*>;
463 std::map<PartialSubgraph*, SelectionInfoPtrs> splitMap;
464 for (
auto&
info : layerInfos)
466 if (
info.second.m_IsSelected)
468 auto it = splitMap.find(
info.second.m_Subgraph->GetRepresentative());
469 if (it == splitMap.end())
472 std::make_pair(
info.second.m_Subgraph->GetRepresentative(), SelectionInfoPtrs{&info.second}));
476 it->second.push_back(&
info.second);
483 for (
auto& splitGraph : splitMap)
488 for (
auto&& infoPtr : splitGraph.second)
490 infoPtr->CollectNonSelectedInputs(layerInfos, inputs);
491 infoPtr->CollectNonSelectedOutputSlots(layerInfos, outputs);
492 layers.push_back(infoPtr->m_Layer);
499 auto* castA = PolymorphicDowncast<const InputSlot*>(a);
500 auto* castB = PolymorphicDowncast<const InputSlot*>(b);
501 const LayerGuid guidA = castA->GetOwningLayer().GetGuid();
502 const LayerGuid guidB = castB->GetOwningLayer().GetGuid();
507 else if (guidA == guidB)
509 return (castA->GetSlotIndex() < castB->GetSlotIndex());
515 auto* castA = PolymorphicDowncast<const OutputSlot*>(a);
516 auto* castB = PolymorphicDowncast<const OutputSlot*>(b);
517 const LayerGuid guidA = castA->GetOwningLayer().GetGuid();
518 const LayerGuid guidB = castB->GetOwningLayer().GetGuid();
523 else if (guidA == guidB)
532 result.emplace_back(std::make_unique<SubgraphView>(std::move(layers),
534 std::move(outputs)));
540 std::sort(result.begin(), result.end(), [](
const SubgraphView::SubgraphViewPtr& a,
541 const SubgraphView::SubgraphViewPtr& b)
543 return a->GetIConnectableLayers().front()->GetGuid() < b->GetIConnectableLayers().front()->GetGuid();
Base class for all ArmNN exceptions so that users can filter to just those.
Interface for a layer that is connectable to other layers via InputSlots and OutputSlots.
virtual LayerGuid GetGuid() const =0
Returns the unique id of the layer.
An output connection slot for a layer.
virtual unsigned int CalculateIndexOnOwner() const =0
const std::vector< OutputSlot > & GetOutputSlots() const
const std::vector< InputSlot > & GetInputSlots() const
The SubgraphView class represents a subgraph of a Graph.
std::vector< IOutputSlot * > IOutputSlots
std::vector< IInputSlot * > IInputSlots
const IConnectableLayers & GetIConnectableLayers() const
std::list< IConnectableLayer * > IConnectableLayers
const IInputSlots & GetIInputSlots() const
std::function< bool(const Layer &)> LayerSelectorFunction
static Subgraphs SelectSubgraphs(Graph &graph, const LayerSelectorFunction &selector)
Selects subgraphs from a graph based on the selector function and the algorithm.
std::vector< SubgraphView::SubgraphViewPtr > Subgraphs
Copyright (c) 2021 ARM Limited and Contributors.
void AssignSplitId(LayerSelectionInfo::LayerInfoContainer &layerInfos, LayerSelectionInfo &layerInfo)
bool IsReadyForSplitAssignment(LayerSelectionInfo::LayerInfoContainer &layerInfos, LayerSelectionInfo &layerInfo)
DestType PolymorphicDowncast(SourceType *value)
Polymorphic downcast for build in pointers only.
void ForEachLayerInput(LayerSelectionInfo::LayerInfoContainer &layerInfos, LayerSelectionInfo &layerInfo, Delegate function)
void ForEachLayerOutput(LayerSelectionInfo::LayerInfoContainer &layerInfos, LayerSelectionInfo &layerInfo, Delegate function)
void IgnoreUnused(Ts &&...)