ArmNN
 25.11
Loading...
Searching...
No Matches
MovePermuteUp.hpp
Go to the documentation of this file.
1//
2// Copyright © 2017-2018,2020,2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "Optimization.hpp"
8
11
12namespace armnn
13{
14namespace optimizations
15{
17{
18public:
19 /// Run for every connection between a base Layer (any) and a child PermuteLayer. If the type
20 /// of the base layer allows it, it moves the permutation to the inputs of the base layer.
21 /// I.e., adds equivalent permutations before the inputs of the base layer and moves the
22 /// connections in the output of the child permute layer to the output of the base layer.
23 void Run(Graph& graph, InputSlot& connection) const
24 {
25 OutputSlot& baseOutput = *connection.GetConnectedOutputSlot();
26
27 if (baseOutput.GetNumConnections() == 1U)
28 {
29 Layer& base = baseOutput.GetOwningLayer();
30
31 if (CanMovePermuteToInputs(base))
32 {
33 auto permute = PolymorphicDowncast<PermuteLayer*>(&connection.GetOwningLayer());
34 const PermutationVector& perm = permute->GetPermutation();
35
36 // Inserts an equivalent permute before every input of the base layer.
37 for (auto baseInput = base.BeginInputSlots(); baseInput != base.EndInputSlots(); ++baseInput)
38 {
39 // Inserts a new permute layer.
40 const std::string name = std::string("moved_up-") + permute->GetName();
41 PermuteLayer& permLayer = *graph.InsertNewLayer<PermuteLayer>(*baseInput, perm, name.c_str());
42
43 // Sets output tensor info for the new layer.
44 OutputSlot& parentOutput = *permLayer.GetInputSlot(0).GetConnectedOutputSlot();
45 const TensorInfo permOutInfo = armnnUtils::Permuted(parentOutput.GetTensorInfo(), perm);
46 permLayer.GetOutputHandler().SetTensorInfo(permOutInfo);
47 }
48
49 // Bypasses permute. It will be removed as it's left unconnected.
50 permute->GetOutputSlot().MoveAllConnections(base.GetOutputSlot());
51 }
52 }
53 }
54
55protected:
56 MovePermuteUpImpl() = default;
57 ~MovePermuteUpImpl() = default;
58
59private:
60 static bool CanMovePermuteToInputs(const Layer& base)
61 {
62 switch (base.GetType())
63 {
70 return true;
72 {
74 return (descriptor->m_Operation == BinaryOperation::Add ||
75 descriptor->m_Operation == BinaryOperation::Mul);
76 }
77 default:
78 return false;
79 }
80 }
81};
82
84
85} // namespace optimizations
86} // namespace armnn
LayerT * InsertNewLayer(InputSlot &insertBefore, Args &&... args)
Inserts a new layer between the output slot currently connected to insertBefore and insertBefore itself.
Definition Graph.hpp:481
Layer & GetOwningLayer() const
Definition Layer.hpp:53
const OutputSlot * GetConnectedOutputSlot() const
Definition Layer.hpp:56
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition Layer.hpp:337
std::vector< InputSlot >::iterator EndInputSlots()
Definition Layer.hpp:263
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
Definition Layer.hpp:339
const OutputHandler & GetOutputHandler(unsigned int i=0) const
Definition Layer.hpp:245
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
Definition Layer.hpp:286
virtual const BaseDescriptor & GetParameters() const override
If the layer has a descriptor return it.
Definition Layer.hpp:378
std::vector< InputSlot >::iterator BeginInputSlots()
Definition Layer.hpp:262
void SetTensorInfo(const TensorInfo &tensorInfo)
Sets the TensorInfo used by this output handler.
unsigned int GetNumConnections() const override
Definition Layer.hpp:158
Layer & GetOwningLayer() const
Definition Layer.hpp:132
const TensorInfo & GetTensorInfo() const override
Definition Layer.cpp:100
This layer represents a permutation operation.
void Run(Graph &graph, InputSlot &connection) const
Run for every connection between a base Layer (any) and a child PermuteLayer.
OptimizeForConnection< Layer, PermuteLayer, MovePermuteUpImpl > MovePermuteUp
Copyright (c) 2021 ARM Limited and Contributors.
DestType PolymorphicDowncast(SourceType *value)
Polymorphic downcast for built-in pointers only.
armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings)
Definition Permute.cpp:125