ArmNN
 25.11
Loading...
Searching...
No Matches
OptimizeConsecutiveReshapesImpl Class Reference

#include <OptimizeConsecutiveReshapes.hpp>

Public Member Functions

void Run (Graph &graph, InputSlot &connection) const
 Run for every connection between a base ReshapeLayer and a child ReshapeLayer.

Protected Member Functions

 OptimizeConsecutiveReshapesImpl ()=default
 ~OptimizeConsecutiveReshapesImpl ()=default

Detailed Description

Definition at line 14 of file OptimizeConsecutiveReshapes.hpp.

Constructor & Destructor Documentation

◆ OptimizeConsecutiveReshapesImpl()

OptimizeConsecutiveReshapesImpl ( )
protected, default

◆ ~OptimizeConsecutiveReshapesImpl()

~OptimizeConsecutiveReshapesImpl ( )
protected, default

Member Function Documentation

◆ Run()

void Run ( Graph & graph,
InputSlot & connection ) const
inline

Run for every connection between a base ReshapeLayer and a child ReshapeLayer.

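Inserts an equivalent ReshapeLayer that bypasses both for that connection. If the combined reshape leaves the shape unchanged, no layer is inserted and the child's consumers are connected directly to the layer feeding the base. Nothing is done when the base ReshapeLayer's output has consumers other than the child.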

Definition at line 19 of file OptimizeConsecutiveReshapes.hpp.

20 { // Collapses a Reshape(base) -> Reshape(child) pair into at most one ReshapeLayer.
21 Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
22 Layer& child = connection.GetOwningLayer();
23 // Both ends of the connection are expected to be ReshapeLayers. NOTE(review): ARMNN_ASSERT may expand to nothing in some builds (see Assert.hpp), so this is not guaranteed to be a runtime check.
24 ARMNN_ASSERT(base.GetType() == LayerType::Reshape);
25 ARMNN_ASSERT(child.GetType() == LayerType::Reshape);
26 // Producer feeding base's input 0. It is dereferenced below without a null check, so base input 0 is assumed to be connected.
27 OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
28 // Net effect of the pair: the parent's tensor shape goes in, the child's output shape comes out.
29 const TensorInfo& inInfo = parentOut->GetTensorInfo();
30 const TensorInfo& outInfo = child.GetOutputHandler().GetTensorInfo();
31
32 // This Optimization is only appropriate when the base ReshapeLayer is connected to the child ReshapeLayer
33 // and no other Layer.
34 if (base.GetOutputSlot(0).GetNumConnections() > 1)
35 {
36 return;
37 }
38 // If the pair changes the shape, substitute one ReshapeLayer targeting outInfo's shape; otherwise the pair is a no-op and consumers attach straight to the parent.
39 if (inInfo.GetShape() != outInfo.GetShape())
40 {
41 // Inserts equivalent reshape before base layer.
42 const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName();
43 const ReshapeDescriptor descriptor{outInfo.GetShape()};
44 auto& newReshape = *graph.InsertNewLayer<ReshapeLayer>(base.GetInputSlot(0), descriptor, name.c_str());
45 // newReshape is inserted at base's input 0, so its output slot becomes the new producer for the child's consumers.
46 // Parent is now the new layer.
47 parentOut = &newReshape.GetOutputSlot();
48 }
49
50 // Moves connections in child output to parent layer.
51 // Child layer will be removed as it's left unconnected.
52 // Base layer will be removed if left unconnected.
53 child.GetOutputSlot().MoveAllConnections(*parentOut);
54 }
#define ARMNN_ASSERT(COND)
Definition Assert.hpp:14

References ARMNN_ASSERT, InputSlot::GetConnectedOutputSlot(), Layer::GetInputSlot(), Layer::GetName(), OutputSlot::GetNumConnections(), Layer::GetOutputHandler(), Layer::GetOutputSlot(), InputSlot::GetOwningLayer(), OutputSlot::GetOwningLayer(), TensorInfo::GetShape(), OutputHandler::GetTensorInfo(), OutputSlot::GetTensorInfo(), Layer::GetType(), Graph::InsertNewLayer(), OutputSlot::MoveAllConnections(), and armnn::Reshape.


The documentation for this class was generated from the following file: OptimizeConsecutiveReshapes.hpp