ArmNN
 24.08
TosaOperatorUtils.hpp File Reference
#include <Layer.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>
#include "common/include/ProfilingGuid.hpp"
#include <tosa_serialization_handler.h>
Include dependency graph for TosaOperatorUtils.hpp:
This graph shows which files directly or indirectly include this file:

Go to the source code of this file.

Functions

DType ArmNNToDType (const DataType &type)
 
DataType DtypeToArmNN (const DType type)
 
std::vector< int32_t > GetTosaTensorShape (const TensorShape &shape)
 
std::string GenerateUniqueInputName (const armnn::InputSlot &slot)
 
std::string GenerateUniqueOutputName (const Layer &layer, uint32_t layerSlot=0)
 
std::string GetUniqueTosaMappingID ()
 
std::string TosaDTypeToString (DType tosaDType)
 
std::string TosaOpToString (Op tosaOp)
 
std::vector< uint8_t > ConvertConstantTensorDataToBuffer (const std::shared_ptr< ConstTensorHandle > &tensorHandle)
 
std::vector< uint8_t > CreateConstTosaData (const void *value, DType dtype, const std::vector< int32_t > &shape)
 
template<typename T >
void CreateConstTosaOperator (const std::string &outputName, const T value, DType dtype, const std::vector< int32_t > &shape, TosaSerializationOperator *&op, TosaSerializationTensor *&tensor)
 

Variables

const std::string mainName = "main"
 

Function Documentation

◆ ArmNNToDType()

// Definition at line 22 of file TosaOperatorUtils.hpp.
/// Translates an ArmNN tensor data type into the matching TOSA serialization DType.
///
/// @param type ArmNN data type to translate.
/// @return The corresponding TOSA DType. DType_UNKNOWN is returned for every type
///         without a TOSA counterpart, including DataType::Signed64 (TOSA only
///         offers a 48-bit integer, DType_INT48).
inline DType ArmNNToDType(const DataType& type)
{
    DType result = DType_UNKNOWN;

    switch (type)
    {
        case DataType::Float16:  result = DType_FP16;  break;
        case DataType::BFloat16: result = DType_BF16;  break;
        case DataType::Float32:  result = DType_FP32;  break;
        case DataType::QAsymmU8: result = DType_UINT8; break;
        // Both signed 8-bit quantized flavours share the TOSA INT8 type.
        case DataType::QSymmS8:
        case DataType::QAsymmS8: result = DType_INT8;  break;
        case DataType::QSymmS16: result = DType_INT16; break;
        case DataType::Signed32: result = DType_INT32; break;
        case DataType::Boolean:  result = DType_BOOL;  break;
        // No 64-bit signed TOSA type exists, so Signed64 keeps the DType_UNKNOWN fallback.
        case DataType::Signed64:
        default:                 break;
    }

    return result;
}

References armnn::BFloat16, armnn::Boolean, armnn::Float16, armnn::Float32, armnn::QAsymmS8, armnn::QAsymmU8, armnn::QSymmS16, armnn::QSymmS8, armnn::Signed32, and armnn::Signed64.

Referenced by ConvertAvgPool2DIgnoreValueToTosaOperator(), ConvertBatchMatMulToTosaOperator(), ConvertConv2dToTosaOperator(), ConvertDepthwiseConv2dToTosaOperator(), ConvertElementwiseBinaryToTosaOperator(), ConvertFullyConnectedToTosaOperator(), ConvertPooling2DToTosaOperator(), ConvertQuantizeToTosaOperator(), and ConvertSoftmaxToTosaOperator().

◆ ConvertConstantTensorDataToBuffer()

std::vector<uint8_t> ConvertConstantTensorDataToBuffer ( const std::shared_ptr< ConstTensorHandle > &  tensorHandle)
inline

Definition at line 333 of file TosaOperatorUtils.hpp.

334 {
335  tosa_err_t error = tosa_err_t::TOSA_OK;
336  std::vector<uint8_t> uint8Data;
337  auto tensorInfo = tensorHandle->GetTensorInfo();
338 
339  switch (tensorInfo.GetDataType())
340  {
341  case DataType::Float32:
342  {
343  std::vector<float> data(tensorInfo.GetNumElements());
344  memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
345 
346  error = TosaSerializationHandler::ConvertF32toU8(data, uint8Data);
347  break;
348  }
349  case DataType::Float16:
350  {
351  std::vector<float> data(tensorInfo.GetNumElements());
352  memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
353 
354  error = TosaSerializationHandler::ConvertF16toU8(data, uint8Data);
355  break;
356  }
357  case DataType::QSymmS8:
358  case DataType::QAsymmS8:
359  {
360  std::vector<int8_t> data(tensorInfo.GetNumElements());
361  memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
362 
363  error = TosaSerializationHandler::ConvertI8toU8(data, uint8Data);
364  break;
365  }
366  case DataType::QAsymmU8:
367  {
368  memcpy(uint8Data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
369  break;
370  }
371  case DataType::QSymmS16:
372  {
373  std::vector<int16_t> data(tensorInfo.GetNumElements());
374  memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
375 
376  error = TosaSerializationHandler::ConvertI16toU8(data, uint8Data);
377  break;
378  }
379  case DataType::Signed32:
380  {
381  std::vector<int32_t> data(tensorInfo.GetNumElements());
382  memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
383 
384  error = TosaSerializationHandler::ConvertI32toU8(data, uint8Data);
385  break;
386  }
387  default:
388  {
389  throw armnn::Exception("SetConstantTensorData: An unsupported data type was encountered.");
390  }
391  }
392 
393  if(error != tosa_err_t::TOSA_OK)
394  {
395  throw armnn::Exception("SetConstantTensorData: An error occurred when converting constant data");
396  }
397 
398  tensorHandle->Unmap();
399  return uint8Data;
400 }

References armnn::error, armnn::Float16, armnn::Float32, armnn::QAsymmS8, armnn::QAsymmU8, armnn::QSymmS16, armnn::QSymmS8, and armnn::Signed32.

Referenced by ConvertConstantToTosaOperator().

◆ CreateConstTosaData()

std::vector<uint8_t> CreateConstTosaData ( const void *  value,
DType  dtype,
const std::vector< int32_t > &  shape 
)
inline

Definition at line 402 of file TosaOperatorUtils.hpp.

405 {
406  std::vector<uint8_t> uint8Data;
407  tosa_err_t error = tosa_err_t::TOSA_OK;
408 
409  unsigned int numElements = 1;
410  for (auto s : shape)
411  {
412  if (s < 0)
413  {
414  throw armnn::Exception("CreateConstTosaData: negative shape elements unhandled.");
415  }
416  numElements = numElements * static_cast<unsigned int>(s);
417  }
418 
419  switch (dtype)
420  {
421  case DType::DType_FP32:
422  {
423  std::vector<float> data(numElements, *static_cast<const float*>(value));
424  error = TosaSerializationHandler::ConvertF32toU8(data, uint8Data);
425  break;
426  }
427  case DType::DType_FP16:
428  {
429  std::vector<float> data(numElements, *static_cast<const float*>(value));
430  error = TosaSerializationHandler::ConvertF16toU8(data, uint8Data);
431  break;
432  }
433  case DType::DType_INT48:
434  {
435  std::vector<int64_t> data(numElements, *static_cast<const int64_t*>(value));
436  error = TosaSerializationHandler::ConvertI48toU8(data, uint8Data);
437  break;
438  }
439  case DType::DType_INT32:
440  {
441  std::vector<int32_t> data(numElements, *static_cast<const int32_t*>(value));
442  error = TosaSerializationHandler::ConvertI32toU8(data, uint8Data);
443  break;
444  }
445  case DType::DType_INT16:
446  {
447  std::vector<int16_t> data(numElements, *static_cast<const int16_t*>(value));
448  error = TosaSerializationHandler::ConvertI16toU8(data, uint8Data);
449  break;
450  }
451  case DType::DType_INT8:
452  {
453  std::vector<int8_t> data(numElements, *static_cast<const int8_t*>(value));
454  error = TosaSerializationHandler::ConvertI8toU8(data, uint8Data);
455  break;
456  }
457  case DType::DType_UINT8:
458  {
459  const int8_t* copy_data = static_cast<const int8_t*>(value);
460  uint8Data.assign(copy_data, copy_data + numElements);
461  break;
462  }
463  case DType::DType_INT4:
464  {
465  std::vector<int8_t> data(numElements, *static_cast<const int8_t*>(value));
466  error = TosaSerializationHandler::ConvertI4toU8(data, uint8Data);
467  break;
468  }
469  case DType::DType_BOOL:
470  {
471  std::vector<bool> data(numElements, *static_cast<const bool*>(value));
472  error = TosaSerializationHandler::ConvertBooltoU8(data, uint8Data);
473  break;
474  }
475  default:
476  {
477  throw armnn::Exception("CreateConstTosaData: An unsupported data type was encountered.");
478  }
479  }
480 
481  if(error != tosa_err_t::TOSA_OK)
482  {
483  throw armnn::Exception("CreateConstTosaData: An error occurred when converting constant data");
484  }
485 
486  return uint8Data;
487 }

References armnn::error.

Referenced by CreateConstTosaOperator().

◆ CreateConstTosaOperator()

void CreateConstTosaOperator ( const std::string &  outputName,
const T  value,
DType  dtype,
const std::vector< int32_t > &  shape,
TosaSerializationOperator *&  op,
TosaSerializationTensor *&  tensor 
)
inline

Definition at line 490 of file TosaOperatorUtils.hpp.

496 {
497  std::vector<uint8_t> uint8Data = CreateConstTosaData(static_cast<const void *>(&value), dtype, shape);
498 
499  op = new TosaSerializationOperator(Op_CONST, Attribute_NONE, nullptr, {}, {outputName});
500  ARMNN_THROW_MSG_IF_FALSE(op, armnn::Exception, "CreateConstTosaOperator: failed to created operator");
501 
502  tensor = new TosaSerializationTensor(outputName, shape, dtype, uint8Data);
503  ARMNN_THROW_MSG_IF_FALSE(tensor, armnn::Exception, "CreateConstTosaOperator: failed to created tensor");
504 }

References ARMNN_THROW_MSG_IF_FALSE, and CreateConstTosaData().

◆ DtypeToArmNN()

// Definition at line 52 of file TosaOperatorUtils.hpp.
/// Translates a TOSA serialization DType back into an ArmNN data type.
///
/// DType_INT8 maps to DataType::QSymmS8 and DType_UINT8 to DataType::QAsymmU8.
///
/// @param type TOSA element type to translate.
/// @return The corresponding ArmNN DataType.
/// @throws armnn::Exception for any DType that has no ArmNN counterpart.
inline DataType DtypeToArmNN(const DType type)
{
    switch (type)
    {
        case DType_FP16:  return DataType::Float16;
        case DType_BF16:  return DataType::BFloat16;
        case DType_FP32:  return DataType::Float32;
        case DType_UINT8: return DataType::QAsymmU8;
        case DType_INT8:  return DataType::QSymmS8;
        case DType_INT16: return DataType::QSymmS16;
        case DType_INT32: return DataType::Signed32;
        case DType_BOOL:  return DataType::Boolean;
        default:          break;
    }

    // Every supported DType returned above; anything reaching here is unmapped.
    throw armnn::Exception("DtypeToArmNN: Unsupported tosa::DType in ArmNN.");
}

References armnn::BFloat16, armnn::Boolean, armnn::Float16, armnn::Float32, armnn::QAsymmU8, armnn::QSymmS16, armnn::QSymmS8, and armnn::Signed32.

◆ GenerateUniqueInputName()

std::string GenerateUniqueInputName ( const armnn::InputSlot & slot)
inline

◆ GenerateUniqueOutputName()

std::string GenerateUniqueOutputName ( const Layer & layer,
uint32_t  layerSlot = 0 
)
inline

◆ GetTosaTensorShape()

◆ GetUniqueTosaMappingID()

◆ TosaDTypeToString()

std::string TosaDTypeToString ( DType  tosaDType)
inline

Definition at line 144 of file TosaOperatorUtils.hpp.

145 {
146  switch (tosaDType)
147  {
148  case DType_UNKNOWN:
149  return "DType_UNKNOWN";
150  case DType_BOOL:
151  return "DType_BOOL";
152  case DType_UINT8:
153  return "DType_UINT8";
154  case DType_INT4:
155  return "DType_INT4";
156  case DType_INT8:
157  return "DType_INT8";
158  case DType_INT16:
159  return "DType_INT16";
160  case DType_INT32:
161  return "DType_INT32";
162  case DType_INT48:
163  return "DType_INT48";
164  case DType_FP32:
165  return "DType_FP32";
166  case DType_UINT16:
167  return "DType_UINT16";
168  case DType_FP16:
169  return "DType_FP16";
170  case DType_BF16:
171  return "DType_BF16";
172  case DType_SHAPE:
173  return "DType_SHAPE";
174  }
175  return "";
176 }

◆ TosaOpToString()

std::string TosaOpToString ( Op  tosaOp)
inline

Definition at line 179 of file TosaOperatorUtils.hpp.

180 {
181  switch (tosaOp)
182  {
183  case Op_ADD:
184  return "Op_ADD";
185  case Op_AVG_POOL2D:
186  return "Op_AVG_POOL2D";
187  case Op_MAX_POOL2D:
188  return "Op_MAX_POOL2D";
189  case Op_PAD:
190  return "Op_PAD";
191  case Op_UNKNOWN:
192  return "Op_UNKNOWN";
193  case Op_ARGMAX:
194  return "Op_ARGMAX";
195  case Op_CONV2D:
196  return "Op_CONV2D";
197  case Op_CONV3D:
198  return "Op_CONV3D";
199  case Op_DEPTHWISE_CONV2D:
200  return "Op_DEPTHWISE_CONV2D";
201  case Op_FULLY_CONNECTED:
202  return "Op_FULLY_CONNECTED";
203  case Op_MATMUL:
204  return "Op_MATMUL";
205  case Op_TRANSPOSE_CONV2D:
206  return "Op_TRANSPOSE_CONV2D";
207  case Op_CLAMP:
208  return "Op_CLAMP";
209  case Op_RESERVED:
210  return "Op_RESERVED";
211  case Op_SIGMOID:
212  return "Op_SIGMOID";
213  case Op_TANH:
214  return "Op_TANH";
215  case Op_ARITHMETIC_RIGHT_SHIFT:
216  return "Op_ARITHMETIC_RIGHT_SHIFT";
217  case Op_BITWISE_AND:
218  return "Op_BITWISE_AND";
219  case Op_BITWISE_OR:
220  return "Op_BITWISE_OR";
221  case Op_BITWISE_XOR:
222  return "Op_BITWISE_XOR";
223  case Op_INTDIV:
224  return "Op_INTDIV";
225  case Op_LOGICAL_AND:
226  return "Op_LOGICAL_AND";
227  case Op_LOGICAL_LEFT_SHIFT:
228  return "Op_LOGICAL_LEFT_SHIFT";
229  case Op_LOGICAL_RIGHT_SHIFT:
230  return "Op_LOGICAL_RIGHT_SHIFT";
231  case Op_LOGICAL_OR:
232  return "Op_LOGICAL_OR";
233  case Op_LOGICAL_XOR:
234  return "Op_LOGICAL_XOR";
235  case Op_MAXIMUM:
236  return "Op_MAXIMUM";
237  case Op_MINIMUM:
238  return "Op_MINIMUM";
239  case Op_MUL:
240  return "Op_MUL";
241  case Op_POW:
242  return "Op_POW";
243  case Op_SUB:
244  return "Op_SUB";
245  case Op_TABLE:
246  return "Op_TABLE";
247  case Op_ABS:
248  return "Op_ABS";
249  case Op_BITWISE_NOT:
250  return "Op_BITWISE_NOT";
251  case Op_CEIL:
252  return "Op_CEIL";
253  case Op_CLZ:
254  return "Op_CLZ";
255  case Op_EXP:
256  return "Op_EXP";
257  case Op_FLOOR:
258  return "Op_FLOOR";
259  case Op_LOG:
260  return "Op_LOG";
261  case Op_LOGICAL_NOT:
262  return "Op_LOGICAL_NOT";
263  case Op_NEGATE:
264  return "Op_NEGATE";
265  case Op_RECIPROCAL:
266  return "Op_RECIPROCAL";
267  case Op_RSQRT:
268  return "Op_RSQRT";
269  case Op_SELECT:
270  return "Op_SELECT";
271  case Op_EQUAL:
272  return "Op_EQUAL";
273  case Op_GREATER:
274  return "Op_GREATER";
275  case Op_GREATER_EQUAL:
276  return "Op_GREATER_EQUAL";
277  case Op_REDUCE_ANY:
278  return "Op_REDUCE_ANY";
279  case Op_REDUCE_ALL:
280  return "Op_REDUCE_ALL";
281  case Op_REDUCE_MAX:
282  return "Op_REDUCE_MAX";
283  case Op_REDUCE_MIN:
284  return "Op_REDUCE_MIN";
285  case Op_REDUCE_PRODUCT:
286  return "Op_REDUCE_PRODUCT";
287  case Op_REDUCE_SUM:
288  return "Op_REDUCE_SUM";
289  case Op_CONCAT:
290  return "Op_CONCAT";
291  case Op_RESHAPE:
292  return "Op_RESHAPE";
293  case Op_REVERSE:
294  return "Op_REVERSE";
295  case Op_SLICE:
296  return "Op_SLICE";
297  case Op_TILE:
298  return "Op_TILE";
299  case Op_TRANSPOSE:
300  return "Op_TRANSPOSE";
301  case Op_GATHER:
302  return "Op_GATHER";
303  case Op_SCATTER:
304  return "Op_SCATTER";
305  case Op_RESIZE:
306  return "Op_RESIZE";
307  case Op_CAST:
308  return "Op_CAST";
309  case Op_RESCALE:
310  return "Op_RESCALE";
311  case Op_CONST:
312  return "Op_CONST";
313  case Op_IDENTITY:
314  return "Op_IDENTITY";
315  case Op_CUSTOM:
316  return "Op_CUSTOM";
317  case Op_COND_IF:
318  return "Op_COND_IF";
319  case Op_WHILE_LOOP:
320  return "Op_WHILE_LOOP";
321  case Op_FFT2D:
322  return "Op_FFT2D";
323  case Op_RFFT2D:
324  return "Op_RFFT2D";
325  case Op_ERF:
326  return "Op_ERF";
327  case Op_DIM: // = Op_MAX
328  return "Op_DIM";
329  }
330  return "";
331 }

Variable Documentation

◆ mainName

const std::string mainName = "main"
armnn::InputSlot::GetOwningLayer
Layer & GetOwningLayer() const
Definition: Layer.hpp:53
armnn::Layer::GetOutputSlot
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
Definition: Layer.hpp:339
armnn::BoostLogSeverityMapping::error
@ error
armnn::Layer
Definition: Layer.hpp:230
armnn::OutputSlot::CalculateIndexOnOwner
unsigned int CalculateIndexOnOwner() const override
Definition: Layer.cpp:172
armnn::OutputSlot::GetOwningLayer
Layer & GetOwningLayer() const
Definition: Layer.hpp:132
armnn::TensorShape::GetNumDimensions
unsigned int GetNumDimensions() const
Function that returns the tensor rank.
Definition: Tensor.cpp:174
armnn::Exception
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
ARMNN_THROW_MSG_IF_FALSE
#define ARMNN_THROW_MSG_IF_FALSE(_cond, _except, _str)
Definition: Exceptions.hpp:206
CreateConstTosaData
std::vector< uint8_t > CreateConstTosaData(const void *value, DType dtype, const std::vector< int32_t > &shape)
Definition: TosaOperatorUtils.hpp:402
armnn::Layer::GetType
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
Definition: Layer.hpp:286
armnn::InputSlot::GetConnectedOutputSlot
const OutputSlot * GetConnectedOutputSlot() const
Definition: Layer.hpp:56
armnn::OutputSlot::GetConnection
const InputSlot * GetConnection(unsigned int index) const override
Definition: Layer.cpp:83