ArmNN
 25.11
Loading...
Searching...
No Matches
PermuteAsReshape.hpp
Go to the documentation of this file.
1//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "Optimization.hpp"
8
9namespace armnn
10{
11namespace optimizations
12{
13
15{
16public:
17 /// Run for every PermuteLayer. Replaces it with a ReshapeLayer if they are equivalent.
18 void Run(Graph& graph, PermuteLayer& permute) const
19 {
20 if (IsReshape(permute))
21 {
22 const TensorInfo& outInfo = permute.GetOutputHandler().GetTensorInfo();
23
24 const std::string name = std::string("as_reshape-") + permute.GetName();
25 const ReshapeDescriptor descriptor{outInfo.GetShape()};
26 // Inserts NewLayer so layers don't need to be re-sorted.
27 auto reshape = graph.InsertNewLayer<ReshapeLayer>(permute.GetInputSlot(0), descriptor, name.c_str());
28
29 // Bypass permute. It will be deleted since it's left unconnected.
30 permute.GetOutputSlot().MoveAllConnections(reshape->GetOutputSlot());
31 }
32 }
33
34protected:
37
38private:
39 static bool IsReshape(const PermuteLayer& layer)
40 {
41 const TensorShape& outShape = layer.GetOutputHandler().GetTensorInfo().GetShape();
42 const PermutationVector& permutation = layer.GetPermutation();
43
44 const unsigned int numDimensions = permutation.GetSize();
45
46 unsigned int lastGtOne = 0;
47 while ((lastGtOne < numDimensions) && (outShape[(permutation[lastGtOne])] == 1U))
48 {
49 ++lastGtOne;
50 }
51
52 bool isReshape = true;
53 for (unsigned int i = lastGtOne + 1U; isReshape && (i < numDimensions); ++i)
54 {
55 if (outShape[permutation[i]] > 1U)
56 {
57 isReshape = permutation[lastGtOne] < permutation[i];
58 lastGtOne = i;
59 }
60 }
61
62 return isReshape;
63 }
64};
65
67
68} // namespace optimizations
69} // namespace armnn
LayerT * InsertNewLayer(InputSlot &insertBefore, Args &&... args)
Inserts a new layer between the output slot currently connected to insertBefore and insertBefore itself.
Definition Graph.hpp:481
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition Layer.hpp:337
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
Definition Layer.hpp:339
const OutputHandler & GetOutputHandler(unsigned int i=0) const
Definition Layer.hpp:245
const char * GetName() const override
Returns the name of the layer.
Definition Layer.hpp:332
const TensorInfo & GetTensorInfo() const
Gets the matching TensorInfo for the output.
void MoveAllConnections(OutputSlot &destination)
Moves all connections to another OutputSlot.
Definition Layer.cpp:156
SizeType GetSize() const
Definition Types.hpp:359
This layer represents a permutation operation.
const PermutationVector & GetPermutation() const
This layer represents a reshape operation.
const TensorShape & GetShape() const
Definition Tensor.hpp:193
void Run(Graph &graph, PermuteLayer &permute) const
Run for every PermuteLayer. Replaces it with a ReshapeLayer if they are equivalent.
OptimizeForType< PermuteLayer, PermuteAsReshapeImpl > PermuteAsReshape
Copyright (c) 2021 ARM Limited and Contributors.
A ReshapeDescriptor for the ReshapeLayer.