ArmNN
 25.11
Loading...
Searching...
No Matches
OptimizeConsecutiveReshapes.hpp
Go to the documentation of this file.
1//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "Optimization.hpp"
8
9namespace armnn
10{
11namespace optimizations
12{
13
15{
16public:
17 /// Run for every connection between a base ReshapeLayer and a child ReshapeLayer.
18 /// Inserts an equivalent ReshapeLayer that bypasses both for that connection.
19 void Run(Graph& graph, InputSlot& connection) const
20 {
21 Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
22 Layer& child = connection.GetOwningLayer();
23
26
27 OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
28
29 const TensorInfo& inInfo = parentOut->GetTensorInfo();
30 const TensorInfo& outInfo = child.GetOutputHandler().GetTensorInfo();
31
32 // This Optimization is only appropriate when the base ReshapeLayer is connected to the child ReshapeLayer
33 // and no other Layer.
34 if (base.GetOutputSlot(0).GetNumConnections() > 1)
35 {
36 return;
37 }
38
39 if (inInfo.GetShape() != outInfo.GetShape())
40 {
41 // Inserts equivalent reshape before base layer.
42 const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName();
43 const ReshapeDescriptor descriptor{outInfo.GetShape()};
44 auto& newReshape = *graph.InsertNewLayer<ReshapeLayer>(base.GetInputSlot(0), descriptor, name.c_str());
45
46 // Parent is now the new layer.
47 parentOut = &newReshape.GetOutputSlot();
48 }
49
50 // Moves connections in child output to parent layer.
51 // Child layer will be removed as it's left unconnected.
52 // Base layer will be removed if left unconnected.
53 child.GetOutputSlot().MoveAllConnections(*parentOut);
54 }
55
56protected:
59};
60
62
63} // namespace optimizations
64} // namespace armnn
#define ARMNN_ASSERT(COND)
Definition Assert.hpp:14
LayerT * InsertNewLayer(InputSlot &insertBefore, Args &&... args)
Inserts a new layer between the output slot currently connected to insertBefore and insertBefore itself.
Definition Graph.hpp:481
Layer & GetOwningLayer() const
Definition Layer.hpp:53
const OutputSlot * GetConnectedOutputSlot() const
Definition Layer.hpp:56
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition Layer.hpp:337
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
Definition Layer.hpp:339
const OutputHandler & GetOutputHandler(unsigned int i=0) const
Definition Layer.hpp:245
const char * GetName() const override
Returns the name of the layer.
Definition Layer.hpp:332
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
Definition Layer.hpp:286
const TensorInfo & GetTensorInfo() const
Gets the matching TensorInfo for the output.
void MoveAllConnections(OutputSlot &destination)
Moves all connections to another OutputSlot.
Definition Layer.cpp:156
unsigned int GetNumConnections() const override
Definition Layer.hpp:158
Layer & GetOwningLayer() const
Definition Layer.hpp:132
const TensorInfo & GetTensorInfo() const override
Definition Layer.cpp:100
This layer represents a reshape operation.
const TensorShape & GetShape() const
Definition Tensor.hpp:193
void Run(Graph &graph, InputSlot &connection) const
Run for every connection between a base ReshapeLayer and a child ReshapeLayer.
OptimizeForConnection< ReshapeLayer, ReshapeLayer, OptimizeConsecutiveReshapesImpl > OptimizeConsecutiveReshapes
Copyright (c) 2021 ARM Limited and Contributors.
A ReshapeDescriptor for the ReshapeLayer.