ArmNN
 24.08
PermuteAsReshape.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Optimization.hpp"
8 
9 namespace armnn
10 {
11 namespace optimizations
12 {
13 
15 {
16 public:
17  /// Run for every PermuteLayer. Replaces it with a ReshapeLayer if they are equivalent.
18  void Run(Graph& graph, PermuteLayer& permute) const
19  {
20  if (IsReshape(permute))
21  {
22  const TensorInfo& outInfo = permute.GetOutputHandler().GetTensorInfo();
23 
24  const std::string name = std::string("as_reshape-") + permute.GetName();
25  const ReshapeDescriptor descriptor{outInfo.GetShape()};
26  // Inserts NewLayer so layers don't need to be re-sorted.
27  auto reshape = graph.InsertNewLayer<ReshapeLayer>(permute.GetInputSlot(0), descriptor, name.c_str());
28 
29  // Bypass permute. It will be deleted since it's left unconnected.
30  permute.GetOutputSlot().MoveAllConnections(reshape->GetOutputSlot());
31  }
32  }
33 
34 protected:
35  PermuteAsReshapeImpl() = default; // Defaulted constructor; declared under `protected:`, so not publicly constructible. NOTE(review): presumably only instantiated through a derived wrapper - confirm.
36  ~PermuteAsReshapeImpl() = default; // Defaulted, non-virtual destructor; also declared under `protected:`.
37 
38 private:
39  static bool IsReshape(const PermuteLayer& layer)
40  {
41  const TensorShape& outShape = layer.GetOutputHandler().GetTensorInfo().GetShape();
42  const PermutationVector& permutation = layer.GetPermutation();
43 
44  const unsigned int numDimensions = permutation.GetSize();
45 
46  unsigned int lastGtOne = 0;
47  while ((lastGtOne < numDimensions) && (outShape[(permutation[lastGtOne])] == 1U))
48  {
49  ++lastGtOne;
50  }
51 
52  bool isReshape = true;
53  for (unsigned int i = lastGtOne + 1U; isReshape && (i < numDimensions); ++i)
54  {
55  if (outShape[permutation[i]] > 1U)
56  {
57  isReshape = permutation[lastGtOne] < permutation[i];
58  lastGtOne = i;
59  }
60  }
61 
62  return isReshape;
63  }
64 };
65 
67 
68 } // namespace optimizations
69 } // namespace armnn
armnn::TensorInfo
Definition: Tensor.hpp:152
armnn::optimizations::PermuteAsReshapeImpl::Run
void Run(Graph &graph, PermuteLayer &permute) const
Run for every PermuteLayer. Replaces it with a ReshapeLayer if they are equivalent.
Definition: PermuteAsReshape.hpp:18
armnn::Layer::GetOutputSlot
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
Definition: Layer.hpp:339
Optimization.hpp
armnn::Layer::GetInputSlot
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition: Layer.hpp:337
armnn::Layer::GetName
const char * GetName() const override
Returns the name of the layer.
Definition: Layer.hpp:332
armnn::TensorShape
Definition: Tensor.hpp:20
armnn::ReshapeLayer
This layer represents a reshape operation.
Definition: ReshapeLayer.hpp:15
armnn::optimizations::PermuteAsReshapeImpl::PermuteAsReshapeImpl
PermuteAsReshapeImpl()=default
armnn::ReshapeDescriptor
A ReshapeDescriptor for the ReshapeLayer.
Definition: Descriptors.hpp:1023
armnn::Layer::GetOutputHandler
const OutputHandler & GetOutputHandler(unsigned int i=0) const
Definition: Layer.hpp:245
armnn::PermutationVector
Definition: Types.hpp:314
armnn::optimizations::PermuteAsReshapeImpl::~PermuteAsReshapeImpl
~PermuteAsReshapeImpl()=default
armnn::OutputSlot::MoveAllConnections
void MoveAllConnections(OutputSlot &destination)
Moves all connections to another OutputSlot.
Definition: Layer.cpp:156
armnn::OptimizeForType
Definition: Optimization.hpp:67
armnn::PermutationVector::GetSize
SizeType GetSize() const
Definition: Types.hpp:357
armnn::TensorInfo::GetShape
const TensorShape & GetShape() const
Definition: Tensor.hpp:193
armnn::optimizations::PermuteAsReshapeImpl
Definition: PermuteAsReshape.hpp:14
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::PermuteLayer
This layer represents a permutation operation.
Definition: PermuteLayer.hpp:15
armnn::OutputHandler::GetTensorInfo
const TensorInfo & GetTensorInfo() const
Gets the matching TensorInfo for the output.
Definition: OutputHandler.hpp:42
armnn::Graph
Definition: Graph.hpp:30
armnn::Graph::InsertNewLayer
LayerT * InsertNewLayer(InputSlot &insertBefore, Args &&... args)
Inserts a new layer between the output slot currently connected to insertBefore and insertBefore itself.
Definition: Graph.hpp:481
armnn::PermuteLayer::GetPermutation
const PermutationVector & GetPermutation() const
Definition: PermuteLayer.hpp:38