ArmNN
 25.02
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
TypesUtils.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2018-2024 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
#include <armnn/BackendId.hpp>
#include <armnn/Exceptions.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include <stdint.h>
#include <cmath>
#include <ostream>
#include <set>
#include <sstream>
#include <type_traits>
17 
18 namespace armnn
19 {
20 
21 constexpr char const* GetStatusAsCString(Status status)
22 {
23  switch (status)
24  {
25  case armnn::Status::Success: return "Status::Success";
26  case armnn::Status::Failure: return "Status::Failure";
27  default: return "Unknown";
28  }
29 }
30 
31 constexpr char const* GetActivationFunctionAsCString(ActivationFunction activation)
32 {
33  switch (activation)
34  {
35  case ActivationFunction::Sigmoid: return "Sigmoid";
36  case ActivationFunction::TanH: return "TanH";
37  case ActivationFunction::Linear: return "Linear";
38  case ActivationFunction::ReLu: return "ReLu";
39  case ActivationFunction::BoundedReLu: return "BoundedReLu";
40  case ActivationFunction::SoftReLu: return "SoftReLu";
41  case ActivationFunction::LeakyReLu: return "LeakyReLu";
42  case ActivationFunction::Abs: return "Abs";
43  case ActivationFunction::Sqrt: return "Sqrt";
44  case ActivationFunction::Square: return "Square";
45  case ActivationFunction::Elu: return "Elu";
46  case ActivationFunction::HardSwish: return "HardSwish";
47  case ActivationFunction::Gelu: return "Gelu";
48  default: return "Unknown";
49  }
50 }
51 
52 constexpr char const* GetArgMinMaxFunctionAsCString(ArgMinMaxFunction function)
53 {
54  switch (function)
55  {
56  case ArgMinMaxFunction::Max: return "Max";
57  case ArgMinMaxFunction::Min: return "Min";
58  default: return "Unknown";
59  }
60 }
61 
62 constexpr char const* GetComparisonOperationAsCString(ComparisonOperation operation)
63 {
64  switch (operation)
65  {
66  case ComparisonOperation::Equal: return "Equal";
67  case ComparisonOperation::Greater: return "Greater";
68  case ComparisonOperation::GreaterOrEqual: return "GreaterOrEqual";
69  case ComparisonOperation::Less: return "Less";
70  case ComparisonOperation::LessOrEqual: return "LessOrEqual";
71  case ComparisonOperation::NotEqual: return "NotEqual";
72  default: return "Unknown";
73  }
74 }
75 
76 constexpr char const* GetBinaryOperationAsCString(BinaryOperation operation)
77 {
78  switch (operation)
79  {
80  case BinaryOperation::Add: return "Add";
81  case BinaryOperation::Div: return "Div";
82  case BinaryOperation::Maximum: return "Maximum";
83  case BinaryOperation::Minimum: return "Minimum";
84  case BinaryOperation::Mul: return "Mul";
85  case BinaryOperation::Power: return "Power";
86  case BinaryOperation::SqDiff: return "SqDiff";
87  case BinaryOperation::Sub: return "Sub";
88  case BinaryOperation::FloorDiv: return "FloorDiv";
89  default: return "Unknown";
90  }
91 }
92 
93 constexpr char const* GetUnaryOperationAsCString(UnaryOperation operation)
94 {
95  switch (operation)
96  {
97  case UnaryOperation::Abs: return "Abs";
98  case UnaryOperation::Ceil: return "Ceil";
99  case UnaryOperation::Exp: return "Exp";
100  case UnaryOperation::Sqrt: return "Sqrt";
101  case UnaryOperation::Rsqrt: return "Rsqrt";
102  case UnaryOperation::Neg: return "Neg";
103  case UnaryOperation::Log: return "Log";
104  case UnaryOperation::LogicalNot: return "LogicalNot";
105  case UnaryOperation::Sin: return "Sin";
106  default: return "Unknown";
107  }
108 }
109 
111 {
112  switch (operation)
113  {
114  case LogicalBinaryOperation::LogicalAnd: return "LogicalAnd";
115  case LogicalBinaryOperation::LogicalOr: return "LogicalOr";
116  default: return "Unknown";
117  }
118 }
119 
120 constexpr char const* GetFusedTypeAsCString(FusedKernelType type)
121 {
122  switch (type)
123  {
124  case FusedKernelType::AddMulAdd: return "AddMulAdd";
125  default: return "Unknown";
126  }
127 }
128 
129 constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)
130 {
131  switch (pooling)
132  {
133  case PoolingAlgorithm::Average: return "Average";
134  case PoolingAlgorithm::Max: return "Max";
135  case PoolingAlgorithm::L2: return "L2";
136  default: return "Unknown";
137  }
138 }
139 
141 {
142  switch (rounding)
143  {
144  case OutputShapeRounding::Ceiling: return "Ceiling";
145  case OutputShapeRounding::Floor: return "Floor";
146  default: return "Unknown";
147  }
148 }
149 
150 constexpr char const* GetPaddingMethodAsCString(PaddingMethod method)
151 {
152  switch (method)
153  {
154  case PaddingMethod::Exclude: return "Exclude";
155  case PaddingMethod::IgnoreValue: return "IgnoreValue";
156  default: return "Unknown";
157  }
158 }
159 
160 constexpr char const* GetPaddingModeAsCString(PaddingMode mode)
161 {
162  switch (mode)
163  {
164  case PaddingMode::Constant: return "Exclude";
165  case PaddingMode::Symmetric: return "Symmetric";
166  case PaddingMode::Reflect: return "Reflect";
167  default: return "Unknown";
168  }
169 }
170 
171 constexpr char const* GetReduceOperationAsCString(ReduceOperation reduce_operation)
172 {
173  switch (reduce_operation)
174  {
175  case ReduceOperation::Sum: return "Sum";
176  case ReduceOperation::Max: return "Max";
177  case ReduceOperation::Mean: return "Mean";
178  case ReduceOperation::Min: return "Min";
179  case ReduceOperation::Prod: return "Prod";
180  default: return "Unknown";
181  }
182 }
183 constexpr unsigned int GetDataTypeSize(DataType dataType)
184 {
185  switch (dataType)
186  {
187  case DataType::BFloat16:
188  case DataType::Float16: return 2U;
189  case DataType::Float32:
190  case DataType::Signed32: return 4U;
191  case DataType::Signed64: return 8U;
192  case DataType::QAsymmU8: return 1U;
193  case DataType::QAsymmS8: return 1U;
194  case DataType::QSymmS8: return 1U;
195  case DataType::QSymmS16: return 2U;
196  case DataType::Boolean: return 1U;
197  default: return 0U;
198  }
199 }
200 
/// Compile-time-capable comparison of a C string against a string literal.
/// Compares all N characters of strB, including its terminating NUL, so strA must match
/// exactly (not merely have strB as a prefix). Stops at the first mismatch, so strA is
/// never read beyond its own terminator.
/// @param strA - The string to test.
/// @param strB - The literal to compare against.
/// @return - true if the first N characters of strA equal those of strB.
template <unsigned N>
constexpr bool StrEqual(const char* strA, const char (&strB)[N])
{
    for (unsigned idx = 0; idx < N; ++idx)
    {
        if (strA[idx] != strB[idx])
        {
            return false;
        }
    }
    return true;
}
211 
212 /// Deprecated function that will be removed together with
213 /// the Compute enum
214 constexpr armnn::Compute ParseComputeDevice(const char* str)
215 {
216  if (armnn::StrEqual(str, "CpuAcc"))
217  {
218  return armnn::Compute::CpuAcc;
219  }
220  else if (armnn::StrEqual(str, "CpuRef"))
221  {
222  return armnn::Compute::CpuRef;
223  }
224  else if (armnn::StrEqual(str, "GpuAcc"))
225  {
226  return armnn::Compute::GpuAcc;
227  }
228  else
229  {
231  }
232 }
233 
234 constexpr const char* GetDataTypeName(DataType dataType)
235 {
236  switch (dataType)
237  {
238  case DataType::Float16: return "Float16";
239  case DataType::Float32: return "Float32";
240  case DataType::Signed64: return "Signed64";
241  case DataType::QAsymmU8: return "QAsymmU8";
242  case DataType::QAsymmS8: return "QAsymmS8";
243  case DataType::QSymmS8: return "QSymmS8";
244  case DataType::QSymmS16: return "QSymm16";
245  case DataType::Signed32: return "Signed32";
246  case DataType::Boolean: return "Boolean";
247  case DataType::BFloat16: return "BFloat16";
248 
249  default:
250  return "Unknown";
251  }
252 }
253 
254 constexpr const char* GetDataLayoutName(DataLayout dataLayout)
255 {
256  switch (dataLayout)
257  {
258  case DataLayout::NCHW: return "NCHW";
259  case DataLayout::NHWC: return "NHWC";
260  case DataLayout::NDHWC: return "NDHWC";
261  case DataLayout::NCDHW: return "NCDHW";
262  default: return "Unknown";
263  }
264 }
265 
267 {
268  switch (channel)
269  {
270  case NormalizationAlgorithmChannel::Across: return "Across";
271  case NormalizationAlgorithmChannel::Within: return "Within";
272  default: return "Unknown";
273  }
274 }
275 
277 {
278  switch (method)
279  {
280  case NormalizationAlgorithmMethod::LocalBrightness: return "LocalBrightness";
281  case NormalizationAlgorithmMethod::LocalContrast: return "LocalContrast";
282  default: return "Unknown";
283  }
284 }
285 
286 constexpr const char* GetResizeMethodAsCString(ResizeMethod method)
287 {
288  switch (method)
289  {
290  case ResizeMethod::Bilinear: return "Bilinear";
291  case ResizeMethod::NearestNeighbor: return "NearestNeighbour";
292  default: return "Unknown";
293  }
294 }
295 
296 constexpr const char* GetMemBlockStrategyTypeName(MemBlockStrategyType memBlockStrategyType)
297 {
298  switch (memBlockStrategyType)
299  {
300  case MemBlockStrategyType::SingleAxisPacking: return "SingleAxisPacking";
301  case MemBlockStrategyType::MultiAxisPacking: return "MultiAxisPacking";
302  default: return "Unknown";
303  }
304 }
305 
/// Type trait that is true when T is a 2-byte floating point type (i.e. a half-precision float).
/// @tparam T - The type to inspect.
template<typename T>
struct IsHalfType
    : std::integral_constant<bool, std::is_floating_point<T>::value && sizeof(T) == 2>
{};
310 
/// Compile-time check for whether T can hold quantized values.
/// @tparam T - The candidate storage type.
/// @return - true for any integral type (std::is_integral), false otherwise.
template<typename T>
constexpr bool IsQuantizedType()
{
    using IntegralTrait = std::is_integral<T>;
    return IntegralTrait::value;
}
316 
317 constexpr bool IsQuantized8BitType(DataType dataType)
318 {
319  return dataType == DataType::QAsymmU8 ||
320  dataType == DataType::QAsymmS8 ||
321  dataType == DataType::QSymmS8;
322 }
323 
324 constexpr bool IsQuantizedType(DataType dataType)
325 {
326  return dataType == DataType::QSymmS16 || IsQuantized8BitType(dataType);
327 }
328 
329 inline std::ostream& operator<<(std::ostream& os, Status stat)
330 {
331  os << GetStatusAsCString(stat);
332  return os;
333 }
334 
335 
336 inline std::ostream& operator<<(std::ostream& os, const armnn::TensorShape& shape)
337 {
338  os << "[";
340  {
341  for (uint32_t i = 0; i < shape.GetNumDimensions(); ++i)
342  {
343  if (i != 0)
344  {
345  os << ",";
346  }
347  if (shape.GetDimensionSpecificity(i))
348  {
349  os << shape[i];
350  }
351  else
352  {
353  os << "?";
354  }
355  }
356  }
357  else
358  {
359  os << "Dimensionality Not Specified";
360  }
361  os << "]";
362  return os;
363 }
364 
/// Quantize a floating point value into the integer type QuantizedType.
/// @param value - The value to quantize.
/// @param scale - The scale (must be non-zero).
/// @param offset - The offset (zero point): the quantized value that represents 0.0f.
/// @return - The quantized value calculated as round(value/scale)+offset.
/// NOTE(review): the definition is out of line (TypesUtils.cpp); which QuantizedType
/// instantiations exist, and how out-of-range or NaN inputs are handled, are decided there — confirm before relying on them.
///
template<typename QuantizedType>
QuantizedType Quantize(float value, float scale, int32_t offset);

/// Dequantize a value of integer type QuantizedType back into a floating point value.
/// @param value - The value to dequantize.
/// @param scale - The scale (must be non-zero).
/// @param offset - The offset (zero point) that was added during quantization.
/// @return - The dequantized value calculated as (value-offset)*scale.
/// NOTE(review): the definition is out of line (TypesUtils.cpp); which QuantizedType
/// instantiations exist is decided there.
///
template <typename QuantizedType>
float Dequantize(QuantizedType value, float scale, int32_t offset);
382 
383 inline void VerifyTensorInfoDataType(const armnn::TensorInfo & info, armnn::DataType dataType)
384 {
385  if (info.GetDataType() != dataType)
386  {
387  std::stringstream ss;
388  ss << "Unexpected datatype:" << armnn::GetDataTypeName(info.GetDataType())
389  << " for tensor:" << info.GetShape()
390  << ". The type expected to be: " << armnn::GetDataTypeName(dataType);
391  throw armnn::Exception(ss.str());
392  }
393 }
394 
395 } //namespace armnn
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:47
unsigned int GetNumDimensions() const
Function that returns the tensor rank.
Definition: Tensor.cpp:174
bool GetDimensionSpecificity(unsigned int i) const
Returns whether the size of the given dimension has been specified.
Definition: Tensor.cpp:211
Dimensionality GetDimensionality() const
Function that returns the tensor's dimensionality.
Definition: Tensor.hpp:92
Copyright (c) 2021 ARM Limited and Contributors.
constexpr char const * GetLogicalBinaryOperationAsCString(LogicalBinaryOperation operation)
Definition: TypesUtils.hpp:110
constexpr char const * GetPaddingMethodAsCString(PaddingMethod method)
Definition: TypesUtils.hpp:150
PaddingMode
The padding mode controls whether the padding should be filled with constant values (Constant),...
Definition: Types.hpp:202
constexpr char const * GetStatusAsCString(Status status)
Definition: TypesUtils.hpp:21
UnaryOperation
Definition: Types.hpp:126
ComparisonOperation
Definition: Types.hpp:110
LogicalBinaryOperation
Definition: Types.hpp:120
PaddingMethod
The padding method modifies the output of pooling layers.
Definition: Types.hpp:190
@ Exclude
The padding fields don't count and are ignored.
@ IgnoreValue
The padding fields count, but are ignored.
constexpr char const * GetReduceOperationAsCString(ReduceOperation reduce_operation)
Definition: TypesUtils.hpp:171
constexpr char const * GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)
Definition: TypesUtils.hpp:129
FusedKernelType
Definition: Types.hpp:268
ActivationFunction
Definition: Types.hpp:87
@ BoundedReLu
min(a, max(b, input)) ReLu1 & ReLu6.
std::ostream & operator<<(std::ostream &os, const std::vector< Compute > &compute)
Deprecated function that will be removed together with the Compute enum.
Definition: BackendId.hpp:47
constexpr char const * GetArgMinMaxFunctionAsCString(ArgMinMaxFunction function)
Definition: TypesUtils.hpp:52
constexpr bool StrEqual(const char *strA, const char(&strB)[N])
Definition: TypesUtils.hpp:202
constexpr armnn::Compute ParseComputeDevice(const char *str)
Deprecated function that will be removed together with the Compute enum.
Definition: TypesUtils.hpp:214
Status
enumeration
Definition: Types.hpp:43
constexpr char const * GetOutputShapeRoundingAsCString(OutputShapeRounding rounding)
Definition: TypesUtils.hpp:140
constexpr char const * GetUnaryOperationAsCString(UnaryOperation operation)
Definition: TypesUtils.hpp:93
MemBlockStrategyType
Definition: Types.hpp:255
constexpr const char * GetDataTypeName(DataType dataType)
Definition: TypesUtils.hpp:234
float Dequantize(QuantizedType value, float scale, int32_t offset)
Dequantize an 8-bit data type into a floating point data type.
Definition: TypesUtils.cpp:48
constexpr char const * GetFusedTypeAsCString(FusedKernelType type)
Definition: TypesUtils.hpp:120
PoolingAlgorithm
Definition: Types.hpp:152
void VerifyTensorInfoDataType(const armnn::TensorInfo &info, armnn::DataType dataType)
Definition: TypesUtils.hpp:383
ResizeMethod
Definition: Types.hpp:168
constexpr unsigned int GetDataTypeSize(DataType dataType)
Definition: TypesUtils.hpp:183
constexpr char const * GetActivationFunctionAsCString(ActivationFunction activation)
Definition: TypesUtils.hpp:31
constexpr const char * GetMemBlockStrategyTypeName(MemBlockStrategyType memBlockStrategyType)
Definition: TypesUtils.hpp:296
constexpr char const * GetComparisonOperationAsCString(ComparisonOperation operation)
Definition: TypesUtils.hpp:62
QuantizedType Quantize(float value, float scale, int32_t offset)
Quantize a floating point data type into an 8-bit data type.
Definition: TypesUtils.cpp:30
ReduceOperation
Definition: Types.hpp:159
NormalizationAlgorithmChannel
Definition: Types.hpp:209
BinaryOperation
Definition: Types.hpp:139
DataLayout
Definition: Types.hpp:63
constexpr bool IsQuantizedType()
Definition: TypesUtils.hpp:312
constexpr const char * GetNormalizationAlgorithmMethodAsCString(NormalizationAlgorithmMethod method)
Definition: TypesUtils.hpp:276
NormalizationAlgorithmMethod
Definition: Types.hpp:215
@ LocalContrast
Jarrett 2009: Local Contrast Normalization.
@ LocalBrightness
Krizhevsky 2012: Local Brightness Normalization.
DataType
Definition: Types.hpp:49
constexpr bool IsQuantized8BitType(DataType dataType)
Definition: TypesUtils.hpp:317
constexpr char const * GetPaddingModeAsCString(PaddingMode mode)
Definition: TypesUtils.hpp:160
constexpr const char * GetResizeMethodAsCString(ResizeMethod method)
Definition: TypesUtils.hpp:286
OutputShapeRounding
Definition: Types.hpp:223
Compute
The Compute enum is now deprecated and it is now being replaced by BackendId.
Definition: BackendId.hpp:22
@ CpuAcc
CPU Execution: NEON: ArmCompute.
@ CpuRef
CPU Execution: Reference C++ kernels.
@ GpuAcc
GPU Execution: OpenCL: ArmCompute.
ArgMinMaxFunction
Definition: Types.hpp:104
constexpr const char * GetNormalizationAlgorithmChannelAsCString(NormalizationAlgorithmChannel channel)
Definition: TypesUtils.hpp:266
constexpr char const * GetBinaryOperationAsCString(BinaryOperation operation)
Definition: TypesUtils.hpp:76
constexpr const char * GetDataLayoutName(DataLayout dataLayout)
Definition: TypesUtils.hpp:254