ArmNN
 25.11
Loading...
Searching...
No Matches
ConvertConstPermuteLayersToConstLayers.hpp
Go to the documentation of this file.
1//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include "Optimization.hpp"
#include <armnnUtils/Permute.hpp>
#include <ResolveType.hpp>
11
12namespace armnn
13{
14namespace optimizations
15{
16
18{
19public:
20 void Run(Graph& graph, InputSlot& connection) const
21 {
22 Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
23 Layer& child = connection.GetOwningLayer();
24
27
28 if (base.GetDataType() == child.GetDataType())
29 {
30 switch (base.GetDataType())
31 {
33 ReplaceConstPermuteLayer<DataType::Float16>(graph,
36 break;
38 ReplaceConstPermuteLayer<DataType::Float32>(graph,
41 break;
43 ReplaceConstPermuteLayer<DataType::QAsymmU8>(graph,
46 break;
48 ReplaceConstPermuteLayer<DataType::Signed32>(graph,
51 break;
53 ReplaceConstPermuteLayer<DataType::QSymmS16>(graph,
56 break;
58 ReplaceConstPermuteLayer<DataType::QSymmS8>(graph,
61 break;
63 ReplaceConstPermuteLayer<DataType::QAsymmS8>(graph,
66 break;
68 ReplaceConstPermuteLayer<DataType::BFloat16>(graph,
71 break;
73 ReplaceConstPermuteLayer<DataType::Signed64>(graph,
76 break;
78 ReplaceConstPermuteLayer<DataType::Boolean>(graph,
81 break;
82 }
83 }
84 }
85protected:
88private:
89 template<armnn::DataType ArmnnType,
91 static void ReplaceConstPermuteLayer(Graph& graph,
92 ConstantLayer* constantLayer,
93 PermuteLayer* permuteLayer)
94 {
95 IgnoreUnused(graph);
96 /**
97 * This optimisation is to find situations where a constant set of inputs is being provided to a Permute
98 * layer. In this case we don't want the overhead of Permuting the values on every inference, instead we
99 * want to Permute them once and store them in a Const layer to be used everytime as they will not change.
100 */
101 TensorInfo outputPermuteInfo = permuteLayer->GetOutputSlot(0).GetTensorInfo();
102 std::vector<T> newValues(outputPermuteInfo.GetNumElements());
103 armnnUtils::Permute(outputPermuteInfo.GetShape(), permuteLayer->GetPermutation(),
104 constantLayer->m_LayerOutput->Map(true), newValues.data(),
105 GetDataTypeSize(outputPermuteInfo.GetDataType()));
106
107 TensorInfo newInfo = outputPermuteInfo;
108 newInfo.SetConstant(true);
109 ConstTensor newInput(newInfo, newValues);
110 constantLayer->m_LayerOutput.reset(new ScopedTensorHandle(newInput));
111
112 // Moves connections in permute output to the constant layer.
113 // Permute layer will be removed if left unconnected.
114 permuteLayer->GetOutputSlot().MoveAllConnections(constantLayer->GetOutputSlot());
115
116 // Updating the output tensor
117 constantLayer->GetOutputSlot(0).SetTensorInfo(newInfo);
118 ARMNN_ASSERT(constantLayer->GetOutputSlot(0).GetTensorInfo().IsConstant() == true);
119 }
120};
121
125
126} // namespace optimizations
127} // namespace armnn
#define ARMNN_ASSERT(COND)
Definition Assert.hpp:14
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition Tensor.hpp:330
A layer that the constant data can be bound to.
std::shared_ptr< ConstTensorHandle > m_LayerOutput
Layer & GetOwningLayer() const
Definition Layer.hpp:53
const OutputSlot * GetConnectedOutputSlot() const
Definition Layer.hpp:56
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
Definition Layer.hpp:339
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
Definition Layer.hpp:286
DataType GetDataType() const
Definition Layer.cpp:345
void MoveAllConnections(OutputSlot &destination)
Moves all connections to another OutputSlot.
Definition Layer.cpp:156
void SetTensorInfo(const TensorInfo &tensorInfo) override
Definition Layer.cpp:95
Layer & GetOwningLayer() const
Definition Layer.hpp:132
const TensorInfo & GetTensorInfo() const override
Definition Layer.cpp:100
This layer represents a permutation operation.
const PermutationVector & GetPermutation() const
const TensorShape & GetShape() const
Definition Tensor.hpp:193
unsigned int GetNumElements() const
Definition Tensor.hpp:198
void SetConstant(const bool IsConstant=true)
Marks the data corresponding to this tensor info as constant.
Definition Tensor.cpp:518
bool IsConstant() const
Definition Tensor.cpp:513
DataType GetDataType() const
Definition Tensor.hpp:200
OptimizeForConnection< ConstantLayer, PermuteLayer, ConvertConstPermuteLayersToConstLayers > FusePermuteIntoConstLayer
Copyright (c) 2021 ARM Limited and Contributors.
typename ResolveTypeImpl< DT >::Type ResolveType
constexpr unsigned int GetDataTypeSize(DataType dataType)
DestType PolymorphicDowncast(SourceType *value)
Polymorphic downcast for build in pointers only.
DataType
Definition Types.hpp:49
void IgnoreUnused(Ts &&...)