ArmNN
 25.11
Loading...
Searching...
No Matches
TransposeAsReshape.hpp
Go to the documentation of this file.
1//
2// Copyright © 2020 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "Optimization.hpp"
8
9namespace armnn
10{
11namespace optimizations
12{
13
15{
16public:
17 /// Run for every TransposeLayer. Replaces it with a ReshapeLayer if they are equivalent.
18 void Run(Graph& graph, TransposeLayer& transpose) const
19 {
20 if (IsReshape(transpose))
21 {
22 const TensorInfo& outInfo = transpose.GetOutputHandler().GetTensorInfo();
23
24 const std::string name = std::string("as_reshape-") + transpose.GetName();
25 const ReshapeDescriptor descriptor{outInfo.GetShape()};
26 // Inserts NewLayer so layers don't need to be re-sorted.
27 auto reshape = graph.InsertNewLayer<ReshapeLayer>(transpose.GetInputSlot(0), descriptor, name.c_str());
28
29 // Bypass transpose. It will be deleted since it's left unconnected.
30 transpose.GetOutputSlot().MoveAllConnections(reshape->GetOutputSlot());
31 }
32 }
33
34protected:
37
38private:
39 static bool IsReshape(const TransposeLayer& layer)
40 {
41 const TensorShape& outShape = layer.GetOutputHandler().GetTensorInfo().GetShape();
42 const PermutationVector& permutation = layer.GetPermutation();
43
44 const unsigned int numDimensions = permutation.GetSize();
45 std::map<unsigned int, unsigned int> permuteMappings;
46 for (unsigned int i = 0; i < permutation.GetSize(); ++i)
47 {
48 permuteMappings[permutation[i]] = i;
49 }
50
51 std::vector<unsigned int> permuteVector;
52 for (unsigned int i = 0; i < permutation.GetSize(); ++i)
53 {
54 permuteVector.push_back(permuteMappings.at(i));
55 }
56
57 unsigned int lastGtOne = 0;
58 while ((lastGtOne < numDimensions) && (outShape[(permuteVector[lastGtOne])] == 1U))
59 {
60 ++lastGtOne;
61 }
62
63 bool isReshape = true;
64 for (unsigned int i = lastGtOne + 1U; isReshape && (i < numDimensions); ++i)
65 {
66 if (outShape[permuteVector[i]] > 1U)
67 {
68 isReshape = permuteVector[lastGtOne] < permuteVector[i];
69 lastGtOne = i;
70 }
71 }
72
73 return isReshape;
74 }
75};
76
78
79} // namespace optimizations
80} // namespace armnn
LayerT * InsertNewLayer(InputSlot &insertBefore, Args &&... args)
Inserts a new layer between the output slot currently connected to insertBefore and insertBefore itself.
Definition Graph.hpp:481
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition Layer.hpp:337
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
Definition Layer.hpp:339
const OutputHandler & GetOutputHandler(unsigned int i=0) const
Definition Layer.hpp:245
const char * GetName() const override
Returns the name of the layer.
Definition Layer.hpp:332
const TensorInfo & GetTensorInfo() const
Gets the matching TensorInfo for the output.
void MoveAllConnections(OutputSlot &destination)
Moves all connections to another OutputSlot.
Definition Layer.cpp:156
SizeType GetSize() const
Definition Types.hpp:359
This layer represents a reshape operation.
const TensorShape & GetShape() const
Definition Tensor.hpp:193
This layer represents a transpose operation.
const PermutationVector & GetPermutation() const
void Run(Graph &graph, TransposeLayer &transpose) const
Run for every TransposeLayer. Replaces it with a ReshapeLayer if they are equivalent.
OptimizeForType< TransposeLayer, TransposeAsReshapeImpl > TransposeAsReshape
Copyright (c) 2021 ARM Limited and Contributors.
A ReshapeDescriptor for the ReshapeLayer.