21 #include <unordered_map>
22 #include <unordered_set>
// Fragment of a cast helper templated on LayerType (presumably the PtrCast
// helper whose address is taken by the layer-range accessors; confirm).
// The visible statement returns `layer` converted to LayerType* through
// PolymorphicDowncast. The helper's signature is not visible in this excerpt.
33 template <
typename LayerType>
36 return PolymorphicDowncast<LayerType*>(layer);
// Fragment of a member template parameterised on a callable type Func
// (presumably ForEachLayer; its signature is not visible in this excerpt).
// The loop walks m_Layers from begin() to end() and has no increment
// expression in the for-header.
39 template <
typename Func>
42 for (
auto it = m_Layers.begin(); it != m_Layers.end(); )
// The successor iterator is computed up front with std::next(it).
// NOTE(review): presumably so the callback may remove or move the current
// layer without invalidating the traversal; confirm against the full body.
44 auto next = std::next(it);
// Fragments of the accessors that expose the graph's input and output layers.
// Each visible return value pairs an iterator into m_Graph.m_Layers with a
// pointer to a PtrCast instantiation (const InputLayer or const OutputLayer).
// The first visible return starts at m_Layers.begin() with the InputLayer cast.
67 return {
m_Graph.m_Layers.begin(), &(PtrCast<const InputLayer>) };
73 &(PtrCast<const InputLayer>) };
87 &(PtrCast<const OutputLayer>) };
// The last visible return uses m_Layers.end() with the OutputLayer cast.
92 return {
m_Graph.m_Layers.end(), &(PtrCast<const OutputLayer>) };
// Graph constructor. Both flags default to false.
// The visible part of the initializer list sets m_LayersInOrder to true and
// stores allowExpandedDims in m_AllowExpandedDims. How shapeInferenceMethod is
// consumed is not visible in this excerpt.
98 Graph(
bool shapeInferenceMethod =
false,
bool allowExpandedDims =
false)
99 : m_LayersInOrder(true)
100 , m_AllowExpandedDims(allowExpandedDims)
// Takes over the state of `other` by delegating to move assignment.
// (Presumably the body of the move constructor; its header is not visible.)
112 *
this = std::move(other);
// Body fragment of the move-assignment operator (header not visible).
// Moves the binding-id sets, the ordering flag, the observer table and the
// profiler out of `other`. m_LayersInOrder is a bool, so std::move on it is
// equivalent to a plain copy.
117 m_InputIds = std::move(other.m_InputIds);
118 m_OutputIds = std::move(other.m_OutputIds);
119 m_LayersInOrder = std::move(other.m_LayersInOrder);
120 m_Views = std::move(other.m_Views);
121 m_Profiler = std::move(other.m_Profiler);
// These two configuration members are copied rather than moved.
122 m_AllowExpandedDims = other.m_AllowExpandedDims;
123 m_ShapeInferenceMethod = other.m_ShapeInferenceMethod;
// Re-home every layer of `other` into this graph by calling Reparent with
// this graph's m_Layers.end() as the insertion position.
124 other.ForEachLayer([
this](
Layer* otherLayer)
126 otherLayer->
Reparent(*
this, m_Layers.end());
// Two checks on what `other` still holds after reparenting. Their bodies are
// not visible here.
// NOTE(review): presumably these report an error if `other` is not left
// empty; confirm against the full source.
129 if (!other.m_PosInGraphMap.empty())
134 if (!other.m_Layers.empty())
155 template <
typename LayerT,
typename... Args>
160 template <
typename LayerT,
typename... Args>
164 template <
typename LayerT,
typename... Args>
172 template <
typename LayerT>
// Observer registration: appends `observable` to the list stored in m_Views
// under the notifyOnEvent key (enclosing function header not visible).
227 m_Views[notifyOnEvent].emplace_back(observable);
// Observer de-registration: removes `observable` from the list stored under
// notifyOnEvent (enclosing function header not visible).
231 m_Views[notifyOnEvent].remove(observable);
// Declaration only: const accessor returning a const reference to a
// std::shared_ptr<IProfiler>. Presumably it returns m_Profiler; the definition
// is not visible in this excerpt.
237 const std::shared_ptr<IProfiler>&
GetProfiler()
const;
// Forward declaration of the nested class template LayerInGraphBase, which is
// defined later in this file as Graph::LayerInGraphBase.
242 template <
typename LayerT>
243 class LayerInGraphBase;
245 template <
typename LayerT>
// Loop-condition fragments of two iterator-adjusting helpers (presumably
// ForwardToEndOfInputsAndConstants and RewindToBeginOfOutputs, both called by
// the LayerInGraph constructor; confirm).
// The first loop is bounded by m_Layers.end(). The remainder of its condition
// is not visible in this excerpt.
258 while ((it != m_Layers.end()) &&
// The second loop continues while `it` is not at m_Layers.begin() and the
// element immediately before `it` has type LayerType::Output.
268 while ((it != m_Layers.begin()) && ((*std::prev(it))->GetType() ==
LayerType::Output))
// Notifies every observer registered for `event` by calling Update on it
// with `graphState`.
// NOTE(review): m_Views is a std::map and is indexed with operator[], so
// notifying an event that has no registered observers inserts an empty list
// for that key as a side effect.
275 void NotifyObservables(
GraphEvent event, Layer* graphState)
278 for (
auto& observable : m_Views[event])
280 observable->Update(graphState);
// Binding ids registered by the InputLayer wrapper (inserted on construction
// and erased again by that specialisation).
284 std::unordered_set<LayerBindingId> m_InputIds;
// Binding ids registered by the OutputLayer wrapper, maintained the same way.
285 std::unordered_set<LayerBindingId> m_OutputIds;
// Maps each layer's address to its iterator inside m_Layers. Entries are
// added and removed by LayerInGraphBase::Insert and Remove.
286 std::unordered_map<const Layer*, Iterator> m_PosInGraphMap;
// Ordering flag. It is initialised to true by the constructor and AND-ed with
// a further condition when layers are added. Declared mutable.
293 mutable bool m_LayersInOrder;
// Copied onto each created layer through SetAllowExpandedDims.
295 bool m_AllowExpandedDims;
// Per-event lists of observers, iterated by NotifyObservables.
297 std::map<const GraphEvent, std::list<IGraphObservable*>> m_Views;
// Shared profiler handle; presumably the object returned by GetProfiler().
300 std::shared_ptr<IProfiler> m_Profiler;
// Declaration only: takes a layer pointer and a slot index.
// NOTE(review): judging by the name, this presumably builds or reports an
// error for an input slot that has no connection. The definition is not
// visible in this excerpt, so confirm.
304 void ConstructErrorMessageForUnconnectedInputs(Layer*
const layer,
305 unsigned int slotIndex);
// Wrapper that publicly derives from a concrete layer type LayerT and keeps
// the layer registered in a Graph's m_Layers list and m_PosInGraphMap index.
311 template <
typename LayerT>
312 class Graph::LayerInGraphBase :
public LayerT
// Constructor: forwards `args` to the LayerT constructor, stores the address
// of `graph` in m_Graph and inserts this layer before `insertBefore`.
315 template <
typename... Args>
316 LayerInGraphBase(
Graph& graph,
Iterator insertBefore, Args&&... args)
317 : LayerT(
std::forward<Args>(args)...), m_Graph(&graph)
319 Insert(*m_Graph, insertBefore);
// Reparent (override): inserts this layer into destGraph before
// `insertBefore` and then points m_Graph at destGraph. Any statements between
// those two steps are not visible in this excerpt.
326 void Reparent(
Graph& destGraph,
Iterator insertBefore)
override
328 Insert(destGraph, insertBefore);
331 m_Graph = &destGraph;
// Insert body: emplaces `this` into graph.m_Layers before `insertBefore` and
// records the resulting list iterator in graph.m_PosInGraphMap, keyed by `this`.
337 graph.m_PosInGraphMap.emplace(
this, graph.m_Layers.emplace(insertBefore,
this));
// Remove: looks up this layer's position via GetPosInGraph, erases it from
// graph.m_Layers, then erases its entry from graph.m_PosInGraphMap.
// numErased holds the count returned by that erase; how it is used is not
// visible here.
340 void Remove(
Graph& graph)
342 auto layerIt = graph.GetPosInGraph(*
this);
343 graph.m_Layers.erase(layerIt);
345 const size_t numErased = graph.m_PosInGraphMap.erase(
this);
// Generic in-graph wrapper for a layer type LayerT. It is final and derives
// from LayerInGraphBase<LayerT>.
357 template <
typename LayerT>
358 class Graph::LayerInGraph final :
public LayerInGraphBase<LayerT>
// Constructor without an explicit position: forwards `graph` and `args` to
// the base. The position argument it passes is not visible in this excerpt.
361 template <
typename... Args>
362 LayerInGraph(
Graph& graph, Args&&... args)
363 : LayerInGraphBase<LayerT>(graph,
366 std::forward<Args>(args)...)
// Constructor with an explicit position: `insertBefore` is first passed
// through RewindToBeginOfOutputs and the result through
// ForwardToEndOfInputsAndConstants; that adjusted iterator is handed to the base.
369 template <
typename... Args>
370 LayerInGraph(
Graph& graph,
Iterator insertBefore, Args&&... args)
371 : LayerInGraphBase<LayerT>(graph,
373 graph.ForwardToEndOfInputsAndConstants(graph.RewindToBeginOfOutputs(insertBefore)),
374 std::forward<Args>(args)...)
// Specialisation of LayerInGraph for ConstantLayer. It is final and derives
// from LayerInGraphBase<ConstantLayer>.
380 class Graph::LayerInGraph<
ConstantLayer> final :
public LayerInGraphBase<ConstantLayer>
// First constructor template: forwards `args` onward. Its signature and the
// position it passes to the base are not visible in this excerpt.
383 template <
typename... Args>
388 std::forward<Args>(args)...)
// Second constructor template: delegates to the LayerInGraph(graph, args...)
// constructor. Its own parameter list is not visible here.
390 template <
typename... Args>
393 : LayerInGraph(graph,
std::forward<Args>(args)...)
// Specialisation of LayerInGraph for InputLayer. It is final and derives
// from LayerInGraphBase<InputLayer>.
401 class Graph::LayerInGraph<
InputLayer> final :
public LayerInGraphBase<InputLayer>
// First constructor template: forwards `args` onward (signature not visible),
// then registers this layer's binding id in the owning graph's m_InputIds.
// isNewId is true only when the id was not already in the set. How it is
// checked is not visible in this excerpt.
404 template <
typename... Args>
409 std::forward<Args>(args)...)
411 const bool isNewId = m_Graph->m_InputIds.emplace(GetBindingId()).second;
// Second constructor template: delegates to the LayerInGraph(graph, args...)
// constructor. Its own parameter list is not visible here.
417 template <
typename... Args>
420 : LayerInGraph(graph,
std::forward<Args>(args)...)
// Un-registers the binding id from m_InputIds. numErased is the number of
// entries removed. The enclosing function is not visible (presumably the
// destructor; confirm).
425 const size_t numErased = m_Graph->m_InputIds.erase(GetBindingId());
// Specialisation of LayerInGraph for OutputLayer. It is final and derives
// from LayerInGraphBase<OutputLayer>.
432 class Graph::LayerInGraph<
OutputLayer> final :
public LayerInGraphBase<OutputLayer>
// Constructor template: forwards `args` onward (signature not visible), then
// registers this layer's binding id in the owning graph's m_OutputIds.
// isNewId is true only when the id was not already in the set. How it is
// checked is not visible in this excerpt.
435 template <
typename... Args>
440 std::forward<Args>(args)...)
442 const bool isNewId = m_Graph->m_OutputIds.emplace(GetBindingId()).second;
// Un-registers the binding id from m_OutputIds. numErased is the number of
// entries removed. The enclosing function is not visible (presumably the
// destructor; confirm).
450 const size_t numErased = m_Graph->m_OutputIds.erase(GetBindingId());
// Fragment of a position lookup (presumably GetPosInGraph; header not
// visible). It searches m_PosInGraphMap for the address of `layer`.
457 auto it = m_PosInGraphMap.find(&layer);
// Branch taken when the layer is not present in the map. Its handling is not
// visible in this excerpt.
458 if (it == m_PosInGraphMap.end())
// Fragment of the layer-creation template (presumably AddLayer; its signature
// is not visible in this excerpt).
465 template <
typename LayerT,
typename... Args>
// The ordering flag can only stay true or become false here: it is AND-ed
// with a condition whose right-hand side is not visible.
468 m_LayersInOrder = m_LayersInOrder &&
// Allocates the wrapper with a raw `new`. The LayerInGraph constructor inserts
// it into this graph's bookkeeping. How ownership is released is not visible.
470 LayerT*
const layer =
new LayerInGraph<LayerT>(*
this, std::forward<Args>(args)...);
// Propagate the graph-wide settings onto the new layer.
472 layer->SetShapeInferenceMethod(m_ShapeInferenceMethod);
473 layer->SetAllowExpandedDims(m_AllowExpandedDims);
// Fragment of a layer-insertion template that works relative to
// `insertBefore` (presumably InsertNewLayer taking an input slot; its
// signature is not visible in this excerpt).
480 template <
typename LayerT,
typename... Args>
// The list position depends on whether parentOut is non-null. Both
// alternatives of the conditional are not visible here.
485 const Iterator pos = (parentOut !=
nullptr)
// Creates the new layer at `pos` (raw `new`; the wrapper constructor inserts
// it into this graph).
488 LayerT*
const layer =
new LayerInGraph<LayerT>(*
this, pos, std::forward<Args>(args)...);
// Hooks the new layer in by calling Insert on `insertBefore`.
489 insertBefore.
Insert(*layer);
// Fragment of a layer-insertion template that works relative to
// `insertAfter` (presumably InsertNewLayer taking an output slot; its
// signature and the computation of `pos` are not visible in this excerpt).
496 template <
typename LayerT,
typename... Args>
// Creates the new layer at `pos` (raw `new`; the wrapper constructor inserts
// it into this graph).
502 LayerT*
const layer =
new LayerInGraph<LayerT>(*
this, pos, std::forward<Args>(args)...);
// Branch taken when the new layer does not have exactly one input slot. Its
// handling is not visible here.
504 if (layer->GetNumInputSlots() != 1)
// Connects `insertAfter` to the new layer's first (index 0) input slot.
510 insertAfter.
Connect(layer->GetInputSlot(0));
524 template <
typename LayerT>