ArmNN
 25.11
Loading...
Searching...
No Matches
LayerSupportRules.hpp
Go to the documentation of this file.
1//
2// Copyright © 2017, 2024 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <algorithm>
9
10namespace armnn
11{
12
14{
15 if (!weightsType)
16 {
17 return weightsType;
18 }
19
20 switch(weightsType.value())
21 {
24 return weightsType;
30 default:
31 throw InvalidArgumentException("GetBiasTypeFromWeightsType(): Unsupported data type.");
32 }
33 return armnn::EmptyOptional();
34}
35
36template<typename F>
37bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
38{
39 bool supported = rule();
40 if (!supported && reason)
41 {
42 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
43 }
44 return supported;
45}
46
/// Base type for support rules: holds a boolean outcome and exposes it through the call operator.
struct Rule
{
    /// Outcome of the rule; starts out as true until something assigns otherwise.
    bool m_Res = true;

    /// @return The stored outcome.
    bool operator()() const { return m_Res; }
};
56
57template<typename T>
59{
60 return true;
61}
62
63template<typename T, typename... Rest>
64bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
65{
66 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
67
68 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
69}
70
71struct TypesAreEqual : public Rule
72{
73 template<typename ... Ts>
74 TypesAreEqual(const Ts&... ts)
75 {
77 }
78};
79
81{
83 {
84 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
86 }
87};
88
89struct TypeAnyOf : public Rule
90{
91 template<typename Container>
92 TypeAnyOf(const TensorInfo& info, const Container& c)
93 {
94 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
95 {
96 return dt == info.GetDataType();
97 });
98 }
99};
100
101struct TypeIs : public Rule
102{
104 {
105 m_Res = dt == info.GetDataType();
106 }
107};
108
110{
112 {
113 m_Res = !info.IsQuantized() || !info.HasPerAxisQuantization();
114 }
115};
116
118{
119 BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
120 {
121 m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
122 }
123};
124
126{
127 template<typename Container>
129 {
130 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
131 {
132 return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
133 });
134 }
135};
136
137struct ShapesAreSameRank : public Rule
138{
139 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
140 {
142 }
143};
144
146{
147 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
148 {
149 m_Res = info0.GetNumElements() == info1.GetNumElements();
150 }
151};
152
154{
155 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
156 {
157 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
158 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
159 return sizeIn;
160 }
161
163 {
164 const TensorShape& shape0 = in0.GetShape();
165 const TensorShape& shape1 = in1.GetShape();
166 const TensorShape& outShape = out.GetShape();
167
168 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
169 {
170 unsigned int sizeOut = outShape[i];
171 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
172 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
173
174 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
175 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
176 }
177 }
178};
179
181{
182 TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
183 {
184 m_Res = info.GetNumDimensions() == expectedNumDimensions;
185 }
186};
187
189{
190 TensorNumDimensionsAreGreaterOrEqualTo(const TensorInfo& info, unsigned int numDimensionsToCompare)
191 {
192 m_Res = info.GetNumDimensions() >= numDimensionsToCompare;
193 }
194};
195
196} //namespace armnn
float GetQuantizationScale() const
Definition Tensor.cpp:461
const TensorShape & GetShape() const
Definition Tensor.hpp:193
int32_t GetQuantizationOffset() const
Definition Tensor.cpp:482
unsigned int GetNumElements() const
Definition Tensor.hpp:198
DataType GetDataType() const
Definition Tensor.hpp:200
unsigned int GetNumDimensions() const
Function that returns the tensor rank.
Definition Tensor.cpp:174
Copyright (c) 2021 ARM Limited and Contributors.
bool AllTypesAreEqualImpl(T)
DataType
Definition Types.hpp:49
armnn::Optional< armnn::DataType > GetBiasTypeFromWeightsType(armnn::Optional< armnn::DataType > weightsType)
bool CheckSupportRule(F rule, Optional< std::string & > reasonIfUnsupported, const char *reason)
BiasAndWeightsTypesCompatible(const TensorInfo &info, const Container &c)
BiasAndWeightsTypesMatch(const TensorInfo &biases, const TensorInfo &weights)
EmptyOptional is used to initialize the Optional class in case we want to have default value for an Optional in a function declaration.
Definition Optional.hpp:32
QuantizationParametersAreEqual(const TensorInfo &info0, const TensorInfo &info1)
bool operator()() const
unsigned int CalcInputSize(const TensorShape &in, const TensorShape &out, unsigned int idx)
ShapesAreBroadcastCompatible(const TensorInfo &in0, const TensorInfo &in1, const TensorInfo &out)
ShapesAreSameRank(const TensorInfo &info0, const TensorInfo &info1)
ShapesAreSameTotalSize(const TensorInfo &info0, const TensorInfo &info1)
TensorNumDimensionsAreCorrect(const TensorInfo &info, unsigned int expectedNumDimensions)
TensorNumDimensionsAreGreaterOrEqualTo(const TensorInfo &info, unsigned int numDimensionsToCompare)
TypeAnyOf(const TensorInfo &info, const Container &c)
TypeIs(const TensorInfo &info, DataType dt)
TypeNotPerAxisQuantized(const TensorInfo &info)
TypesAreEqual(const Ts &... ts)