ArmNN
 24.08
MovePermuteUpImpl Class Reference

#include <MovePermuteUp.hpp>

Public Member Functions

void Run (Graph &graph, InputSlot &connection) const
 Runs for every connection between a base Layer (of any type) and a child PermuteLayer. See the detailed description below.
 

Protected Member Functions

 MovePermuteUpImpl ()=default
 
 ~MovePermuteUpImpl ()=default
 

Detailed Description

Definition at line 16 of file MovePermuteUp.hpp.

Constructor & Destructor Documentation

◆ MovePermuteUpImpl()

MovePermuteUpImpl ( )
protected, default

◆ ~MovePermuteUpImpl()

~MovePermuteUpImpl ( )
protected, default

Member Function Documentation

◆ Run()

void Run ( Graph & graph,
InputSlot & connection
) const
inline

Runs for every connection between a base Layer (of any type) and a child PermuteLayer.

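If the base layer's output has exactly one connection (the child permute layer) and the base layer's type allows it, the permutation is moved to the inputs of the base layer. That is, an equivalent permute layer is inserted before each input of the base layer, and every connection on the output of the child permute layer is moved to the output of the base layer. The original permute layer is thereby left unconnected, so it will be removed.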

Definition at line 23 of file MovePermuteUp.hpp.

24  {
25  OutputSlot& baseOutput = *connection.GetConnectedOutputSlot();
26 
27  if (baseOutput.GetNumConnections() == 1U)
28  {
29  Layer& base = baseOutput.GetOwningLayer();
30 
31  if (CanMovePermuteToInputs(base))
32  {
33  auto permute = PolymorphicDowncast<PermuteLayer*>(&connection.GetOwningLayer());
34  const PermutationVector& perm = permute->GetPermutation();
35 
36  // Inserts an equivalent permute before every input of the base layer.
37  for (auto baseInput = base.BeginInputSlots(); baseInput != base.EndInputSlots(); ++baseInput)
38  {
39  // Inserts a new permute layer.
40  const std::string name = std::string("moved_up-") + permute->GetName();
41  PermuteLayer& permLayer = *graph.InsertNewLayer<PermuteLayer>(*baseInput, perm, name.c_str());
42 
43  // Sets output tensor info for the new layer.
44  OutputSlot& parentOutput = *permLayer.GetInputSlot(0).GetConnectedOutputSlot();
45  const TensorInfo permOutInfo = armnnUtils::Permuted(parentOutput.GetTensorInfo(), perm);
46  permLayer.GetOutputHandler().SetTensorInfo(permOutInfo);
47  }
48 
49  // Bypasses permute. It will be removed as it's left unconnected.
50  permute->GetOutputSlot().MoveAllConnections(base.GetOutputSlot());
51  }
52  }
53  }

References Layer::BeginInputSlots(), Layer::EndInputSlots(), InputSlot::GetConnectedOutputSlot(), Layer::GetInputSlot(), OutputSlot::GetNumConnections(), Layer::GetOutputHandler(), Layer::GetOutputSlot(), InputSlot::GetOwningLayer(), OutputSlot::GetOwningLayer(), Graph::InsertNewLayer(), armnnUtils::Permuted(), and OutputHandler::SetTensorInfo().


The documentation for this class was generated from the following file: MovePermuteUp.hpp
armnnUtils::Permuted
armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings)
Definition: Permute.cpp:125