ArmNN
 24.08
OptimizeConsecutiveReshapesImpl Class Reference

#include <OptimizeConsecutiveReshapes.hpp>

Public Member Functions

void Run (Graph &graph, InputSlot &connection) const
 Run for every connection between a base ReshapeLayer and a child ReshapeLayer (see Member Function Documentation below).
 

Protected Member Functions

 OptimizeConsecutiveReshapesImpl ()=default
 
 ~OptimizeConsecutiveReshapesImpl ()=default
 

Detailed Description

Definition at line 14 of file OptimizeConsecutiveReshapes.hpp.

Constructor & Destructor Documentation

◆ OptimizeConsecutiveReshapesImpl()

OptimizeConsecutiveReshapesImpl ( )
protected, default

◆ ~OptimizeConsecutiveReshapesImpl()

~OptimizeConsecutiveReshapesImpl ( )
protected, default

Member Function Documentation

◆ Run()

void Run (Graph &graph, InputSlot &connection) const
inline

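Run for every connection between a base ReshapeLayer and a child ReshapeLayer.

Bypasses both layers for that connection. If the shape of the tensor feeding the base ReshapeLayer differs from the child ReshapeLayer's output shape, an equivalent ReshapeLayer (named "merged-<base>-with-<child>") is inserted before the base layer. The child's consumers are then moved onto that new layer. If the two shapes are equal, no new layer is inserted. Instead, the child's consumers are connected directly to the output slot that feeds the base layer. The optimization is skipped entirely when the base layer's output slot has more than one connection.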

Definition at line 19 of file OptimizeConsecutiveReshapes.hpp.

20  {
21  Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
22  Layer& child = connection.GetOwningLayer();
23 
24  ARMNN_ASSERT(base.GetType() == LayerType::Reshape);
25  ARMNN_ASSERT(child.GetType() == LayerType::Reshape);
26 
27  OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
28 
29  const TensorInfo& inInfo = parentOut->GetTensorInfo();
30  const TensorInfo& outInfo = child.GetOutputHandler().GetTensorInfo();
31 
32  // This Optimization is only appropriate when the base ReshapeLayer is connected to the child ReshapeLayer
33  // and no other Layer.
34  if (base.GetOutputSlot(0).GetNumConnections() > 1)
35  {
36  return;
37  }
38 
39  if (inInfo.GetShape() != outInfo.GetShape())
40  {
41  // Inserts equivalent reshape before base layer.
42  const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName();
43  const ReshapeDescriptor descriptor{outInfo.GetShape()};
44  auto& newReshape = *graph.InsertNewLayer<ReshapeLayer>(base.GetInputSlot(0), descriptor, name.c_str());
45 
46  // Parent is now the new layer.
47  parentOut = &newReshape.GetOutputSlot();
48  }
49 
50  // Moves connections in child output to parent layer.
51  // Child layer will be removed as it's left unconnected.
52  // Base layer will be removed if left unconnected.
53  child.GetOutputSlot().MoveAllConnections(*parentOut);
54  }

References ARMNN_ASSERT, InputSlot::GetConnectedOutputSlot(), Layer::GetInputSlot(), Layer::GetName(), OutputSlot::GetNumConnections(), Layer::GetOutputHandler(), Layer::GetOutputSlot(), InputSlot::GetOwningLayer(), OutputSlot::GetOwningLayer(), TensorInfo::GetShape(), OutputHandler::GetTensorInfo(), OutputSlot::GetTensorInfo(), Layer::GetType(), Graph::InsertNewLayer(), OutputSlot::MoveAllConnections(), and armnn::Reshape.


The documentation for this class was generated from the following file:
ARMNN_ASSERT: #define ARMNN_ASSERT(COND) (Definition: Assert.hpp:14)
armnn::LayerType::Reshape: enumerator Reshape