21 #include <unordered_map>
22 #include <unordered_set>
/// Downcasts a Layer pointer to LayerType* via PolymorphicDowncast.
/// Presumably this is PtrCast: its address is passed as the element transform of the
/// input/output layer iterators (`&(PtrCast<const InputLayer>)`, `&(PtrCast<const OutputLayer>)`).
/// NOTE(review): the signature (and the declaration of `layer`) is elided in this excerpt.
33 template <
typename LayerType>
36 return PolymorphicDowncast<LayerType*>(layer);
/// Iterates over every layer in m_Layers, handing each one to a functor of type Func.
/// Presumably this is ForEachLayer (called as `other.ForEachLayer(...)` in the move
/// assignment) — TODO confirm; the signature and the functor call are elided here.
39 template <
typename Func>
// The loop header has no increment clause: the successor is captured in `next`
// before the current layer is visited. NOTE(review): this looks intended to let the
// functor remove or reparent the current layer without invalidating the loop
// iterator — confirm against the elided loop body.
42 for (
auto it = m_Layers.begin(); it != m_Layers.end(); )
44 auto next = std::next(it);
// Begin/end accessors for the typed input- and output-layer views of m_Graph.
// Each result is brace-initialised from an m_Graph.m_Layers iterator plus a transform
// (PtrCast<const InputLayer> or PtrCast<const OutputLayer>) that yields typed pointers.
// NOTE(review): most of these accessors are elided in this excerpt, including how the
// begin/end positions are computed for the partially shown returns.
67 return {
m_Graph.m_Layers.begin(), &(PtrCast<const InputLayer>) };
73 &(PtrCast<const InputLayer>) };
87 &(PtrCast<const OutputLayer>) };
92 return {
m_Graph.m_Layers.end(), &(PtrCast<const OutputLayer>) };
/// Constructs a graph. m_LayersInOrder starts as true.
/// @param shapeInferenceMethod  Shape-inference flag. How it maps onto the
///        m_ShapeInferenceMethod member is elided in this excerpt — TODO confirm.
/// @param allowExpandedDims     Stored in m_AllowExpandedDims. The layer factory later
///        applies it to each new layer via SetAllowExpandedDims.
98 Graph(
bool shapeInferenceMethod =
false,
bool allowExpandedDims =
false)
99 : m_LayersInOrder(true)
100 , m_AllowExpandedDims(allowExpandedDims)
// Presumably the move-constructor body: delegates to move assignment.
112 *
this = std::move(other);
// Move assignment moves the binding-id sets, the ordering flag, the observers and the
// profiler out of `other`. It copies the two configuration values, then reparents each
// of other's layers into this graph. Every layer is inserted before m_Layers.end(), so
// other's iteration order is preserved.
// NOTE(review): any self-assignment guard and the return statement are elided here.
// Removal from `other` is presumably done inside Reparent — confirm in
// LayerInGraphBase::Reparent.
117 m_InputIds = std::move(other.m_InputIds);
118 m_OutputIds = std::move(other.m_OutputIds);
119 m_LayersInOrder = std::move(other.m_LayersInOrder);
120 m_Views = std::move(other.m_Views);
121 m_Profiler = std::move(other.m_Profiler);
122 m_AllowExpandedDims = other.m_AllowExpandedDims;
123 m_ShapeInferenceMethod = other.m_ShapeInferenceMethod;
124 other.ForEachLayer([
this](
Layer* otherLayer)
126 otherLayer->
Reparent(*
this, m_Layers.end());
// Member template declarations parameterised on a layer type LayerT.
// The three variadic ones take constructor arguments (Args...) for LayerT. Their names
// and parameters are elided here — presumably AddLayer and the two InsertNewLayer
// overloads defined out of line further down. TODO confirm.
// The last one takes only LayerT; its declaration is elided.
148 template <
typename LayerT,
typename... Args>
153 template <
typename LayerT,
typename... Args>
157 template <
typename LayerT,
typename... Args>
165 template <
typename LayerT>
// Observer registration: appends `observable` to the list kept for notifyOnEvent.
// std::map::operator[] default-inserts an empty list the first time an event is seen.
220 m_Views[notifyOnEvent].emplace_back(observable);
// Observer removal: std::list::remove erases every entry equal to `observable`.
// operator[] here also inserts an empty list if the event had no observers yet.
224 m_Views[notifyOnEvent].remove(observable);
/// Returns the profiler associated with this graph (presumably m_Profiler; the
/// definition is out of line and not visible here).
230 const std::shared_ptr<IProfiler>&
GetProfiler()
const;
// Forward declaration. The class is defined after Graph as Graph::LayerInGraphBase.
235 template <
typename LayerT>
236 class LayerInGraphBase;
// Template header whose declaration is elided — presumably the forward declaration of
// LayerInGraph, which is defined later as Graph::LayerInGraph.
238 template <
typename LayerT>
// Iterator-adjustment helpers used to choose where a new layer goes. LayerInGraph calls
// ForwardToEndOfInputsAndConstants(RewindToBeginOfOutputs(insertBefore)).
// First loop: advances `it` toward m_Layers.end() while the partly elided condition
// holds. Per the helper's name this is presumably "current layer is Input/Constant" —
// TODO confirm.
251 while ((it != m_Layers.end()) &&
// Second loop: steps `it` back while the layer just before it is an Output layer.
// Inserting before the result therefore places a new layer ahead of any trailing run
// of Output layers.
261 while ((it != m_Layers.begin()) && ((*std::prev(it))->GetType() ==
LayerType::Output))
/// Notifies every observer registered for `event` by calling Update(graphState) on it.
/// m_Views[event] default-inserts an empty list when no observer was registered, so
/// the loop simply does nothing in that case.
268 void NotifyObservables(
GraphEvent event, Layer* graphState)
271 for (
auto& observable : m_Views[event])
273 observable->Update(graphState);
// Binding ids of the graph's Input layers. LayerInGraph<InputLayer> inserts its id on
// construction (capturing whether the id was new) and erases it again later.
277 std::unordered_set<LayerBindingId> m_InputIds;
// Same bookkeeping for Output layers, maintained by LayerInGraph<OutputLayer>.
278 std::unordered_set<LayerBindingId> m_OutputIds;
// Maps each layer to its iterator in m_Layers. LayerInGraphBase's insert code fills it
// and Remove erases from it; GetPosInGraph looks layers up here.
279 std::unordered_map<const Layer*, Iterator> m_PosInGraphMap;
// `mutable`, so it may be updated from const member functions. It starts true in the
// constructor, can be cleared by the layer factory, and is carried over on move.
286 mutable bool m_LayersInOrder;
// Applied to each layer created by the layer factory via SetAllowExpandedDims.
288 bool m_AllowExpandedDims;
// Observers registered per GraphEvent; NotifyObservables walks the list for an event.
290 std::map<const GraphEvent, std::list<IGraphObservable*>> m_Views;
// Profiler associated with this graph; moved across on move assignment.
293 std::shared_ptr<IProfiler> m_Profiler;
/// Helper taking a layer and one of its input-slot indices. Presumably it builds an
/// error message for an unconnected input; the definition is not visible here.
297 void ConstructErrorMessageForUnconnectedInputs(Layer*
const layer,
298 unsigned int slotIndex);
/// Mixin that adds graph-membership bookkeeping on top of a concrete layer type.
/// It derives from LayerT and keeps a pointer (m_Graph) to the owning Graph.
304 template <
typename LayerT>
305 class Graph::LayerInGraphBase :
public LayerT
// Forwards `args` to the LayerT constructor, records the owning graph, then inserts
// this layer into it before `insertBefore`.
308 template <
typename... Args>
309 LayerInGraphBase(
Graph& graph,
Iterator insertBefore, Args&&... args)
310 : LayerT(
std::forward<Args>(args)...), m_Graph(&graph)
312 Insert(*m_Graph, insertBefore);
// Moves this layer into destGraph before `insertBefore`, then repoints m_Graph.
// NOTE(review): the lines between Insert and the m_Graph update are elided —
// presumably they remove the layer from its previous graph; confirm.
319 void Reparent(
Graph& destGraph,
Iterator insertBefore)
override
321 Insert(destGraph, insertBefore);
324 m_Graph = &destGraph;
// Presumably the body of Insert (its header is elided). It places `this` into
// graph.m_Layers before insertBefore and records the resulting iterator in
// graph.m_PosInGraphMap. unordered_map::emplace does not overwrite an existing
// entry for this layer.
330 graph.m_PosInGraphMap.emplace(
this, graph.m_Layers.emplace(insertBefore,
this));
// Removes this layer from graph.m_Layers (position found via GetPosInGraph) and from
// graph.m_PosInGraphMap. numErased is presumably checked in the elided lines.
333 void Remove(
Graph& graph)
335 auto layerIt = graph.GetPosInGraph(*
this);
336 graph.m_Layers.erase(layerIt);
338 const size_t numErased = graph.m_PosInGraphMap.erase(
this);
/// Final, graph-owned layer wrapper for a generic LayerT. It builds on LayerInGraphBase,
/// which links the layer into the graph on construction. Specialisations for
/// ConstantLayer, InputLayer and OutputLayer are defined separately.
348 template <
typename LayerT>
349 class Graph::LayerInGraph final :
public LayerInGraphBase<LayerT>
// Constructor without an explicit position. The insertion iterator it passes to the
// base is elided in this excerpt.
352 template <
typename... Args>
353 LayerInGraph(
Graph& graph, Args&&... args)
354 : LayerInGraphBase<LayerT>(graph,
357 std::forward<Args>(args)...)
// Constructor with a requested position. The iterator is first rewound to before any
// trailing Output layers (RewindToBeginOfOutputs). It is then adjusted by
// ForwardToEndOfInputsAndConstants (presumably skipping leading Input/Constant
// layers — TODO confirm) before being handed to the base.
360 template <
typename... Args>
361 LayerInGraph(
Graph& graph,
Iterator insertBefore, Args&&... args)
362 : LayerInGraphBase<LayerT>(graph,
364 graph.ForwardToEndOfInputsAndConstants(graph.RewindToBeginOfOutputs(insertBefore)),
365 std::forward<Args>(args)...)
/// LayerInGraph specialisation for ConstantLayer.
371 class Graph::LayerInGraph<
ConstantLayer> final :
public LayerInGraphBase<ConstantLayer>
// Primary constructor. Its parameters and the insertion position it uses are elided.
374 template <
typename... Args>
379 std::forward<Args>(args)...)
// Secondary constructor that delegates to the one above with only `graph` and `args`.
// NOTE(review): its parameter list is elided. If it accepts an insertBefore iterator,
// that position is not forwarded — confirm this is intended.
381 template <
typename... Args>
384 : LayerInGraph(graph,
std::forward<Args>(args)...)
/// LayerInGraph specialisation for InputLayer. In addition to the base bookkeeping, it
/// tracks the layer's binding id in the owning graph's m_InputIds set.
392 class Graph::LayerInGraph<
InputLayer> final :
public LayerInGraphBase<InputLayer>
// Primary constructor (parameters elided).
395 template <
typename... Args>
400 std::forward<Args>(args)...)
// Records the binding id. isNewId is false when the id is already present in
// m_InputIds; how that case is handled is elided.
402 const bool isNewId = m_Graph->m_InputIds.emplace(GetBindingId()).second;
// Secondary constructor delegating to the one above (parameter list elided).
408 template <
typename... Args>
411 : LayerInGraph(graph,
std::forward<Args>(args)...)
// Unregisters the binding id, presumably in the destructor (elided). numErased is
// presumably checked to be 1.
416 const size_t numErased = m_Graph->m_InputIds.erase(GetBindingId());
/// LayerInGraph specialisation for OutputLayer. In addition to the base bookkeeping, it
/// tracks the layer's binding id in the owning graph's m_OutputIds set.
424 class Graph::LayerInGraph<
OutputLayer> final :
public LayerInGraphBase<OutputLayer>
// Constructor (parameters elided).
427 template <
typename... Args>
432 std::forward<Args>(args)...)
// Records the binding id. isNewId is false when the id is already present in
// m_OutputIds; how that case is handled is elided.
434 const bool isNewId = m_Graph->m_OutputIds.emplace(GetBindingId()).second;
// Unregisters the binding id, presumably in the destructor (elided). numErased is
// presumably checked to be 1.
442 const size_t numErased = m_Graph->m_OutputIds.erase(GetBindingId());
// Looks up `layer` in m_PosInGraphMap. This is presumably the body of GetPosInGraph,
// which LayerInGraphBase::Remove calls. The not-found handling and the return are elided.
450 auto it = m_PosInGraphMap.find(&layer);
/// Creates a new LayerT, wrapped in LayerInGraph<LayerT>, in this graph, forwarding
/// `args` to its constructor. Presumably this is AddLayer; the name and parameters
/// are elided. The new layer inherits the graph's shape-inference method and
/// allow-expanded-dims setting.
/// NOTE(review): the layer is allocated with raw `new`. It links itself into m_Layers
/// on construction (LayerInGraphBase), so the graph presumably owns it — confirm where
/// it is deleted.
455 template <
typename LayerT,
typename... Args>
// This can only preserve or clear the ordering flag, never set it. The rest of the
// condition is elided.
458 m_LayersInOrder = m_LayersInOrder &&
460 LayerT*
const layer =
new LayerInGraph<LayerT>(*
this, std::forward<Args>(args)...);
462 layer->SetShapeInferenceMethod(m_ShapeInferenceMethod);
463 layer->SetAllowExpandedDims(m_AllowExpandedDims);
/// Creates a new LayerT and inserts it in front of the input slot `insertBefore`.
/// Presumably the InsertNewLayer(InputSlot&, ...) overload; the signature is elided.
470 template <
typename LayerT,
typename... Args>
// The position in m_Layers depends on whether the slot has a parent output
// (parentOut). The rest of the conditional is elided.
475 const Iterator pos = (parentOut !=
nullptr)
// Constructs the layer at the chosen position in m_Layers.
// NOTE(review): unlike the factory above, no SetShapeInferenceMethod /
// SetAllowExpandedDims call is visible here; confirm whether the elided lines apply them.
478 LayerT*
const layer =
new LayerInGraph<LayerT>(*
this, pos, std::forward<Args>(args)...);
// Hooks the new layer into the slot's connection via InputSlot::Insert. Its semantics
// are not visible here; presumably it splices the layer between the slot and its source.
479 insertBefore.
Insert(*layer);
/// Creates a new LayerT and inserts it after the output slot `insertAfter`.
/// Presumably the InsertNewLayer(OutputSlot&, ...) overload; the signature and the
/// computation of `pos` are elided.
486 template <
typename LayerT,
typename... Args>
// Constructs the layer at `pos` in m_Layers.
492 LayerT*
const layer =
new LayerInGraph<LayerT>(*
this, pos, std::forward<Args>(args)...);
// Connects insertAfter to the new layer's first input slot. The elided lines
// presumably move insertAfter's existing connections to the new layer's output first.
497 insertAfter.
Connect(layer->GetInputSlot(0));
511 template <
typename LayerT>