12 #include "common/include/ProfilingGuid.hpp"
14 #include <tosa_serialization_handler.h>
16 using namespace armnn;
81 std::vector<int32_t> returnShape;
84 returnShape.push_back(
static_cast<int32_t
>(shape[i]));
90 static std::string GenerateUniqueName(
const Layer& layer, uint32_t layerSlot)
92 std::string guid = std::to_string(layer.
GetGuid());
93 std::string slotAndGuid = std::to_string(layerSlot) +
"_" + guid;
98 return "input_" + guid;
100 return "output" + slotAndGuid;
102 return "constant_" + guid;
104 return "intermediate" + slotAndGuid;
116 return GenerateUniqueName(connectedLayer, connectedOutputSlotIdx);
128 return GenerateUniqueName(connectedLayer, layerSlot);
132 return GenerateUniqueName(layer, layerSlot);
137 static int uniqueTosaMappingID = 0;
140 return std::to_string(++uniqueTosaMappingID);
149 return "DType_UNKNOWN";
153 return "DType_UINT8";
159 return "DType_INT16";
161 return "DType_INT32";
163 return "DType_INT48";
167 return "DType_UINT16";
173 return "DType_SHAPE";
186 return "Op_AVG_POOL2D";
188 return "Op_MAX_POOL2D";
199 case Op_DEPTHWISE_CONV2D:
200 return "Op_DEPTHWISE_CONV2D";
201 case Op_FULLY_CONNECTED:
202 return "Op_FULLY_CONNECTED";
205 case Op_TRANSPOSE_CONV2D:
206 return "Op_TRANSPOSE_CONV2D";
210 return "Op_RESERVED";
215 case Op_ARITHMETIC_RIGHT_SHIFT:
216 return "Op_ARITHMETIC_RIGHT_SHIFT";
218 return "Op_BITWISE_AND";
220 return "Op_BITWISE_OR";
222 return "Op_BITWISE_XOR";
226 return "Op_LOGICAL_AND";
227 case Op_LOGICAL_LEFT_SHIFT:
228 return "Op_LOGICAL_LEFT_SHIFT";
229 case Op_LOGICAL_RIGHT_SHIFT:
230 return "Op_LOGICAL_RIGHT_SHIFT";
232 return "Op_LOGICAL_OR";
234 return "Op_LOGICAL_XOR";
250 return "Op_BITWISE_NOT";
262 return "Op_LOGICAL_NOT";
266 return "Op_RECIPROCAL";
275 case Op_GREATER_EQUAL:
276 return "Op_GREATER_EQUAL";
278 return "Op_REDUCE_ANY";
280 return "Op_REDUCE_ALL";
282 return "Op_REDUCE_MAX";
284 return "Op_REDUCE_MIN";
285 case Op_REDUCE_PRODUCT:
286 return "Op_REDUCE_PRODUCT";
288 return "Op_REDUCE_SUM";
300 return "Op_TRANSPOSE";
314 return "Op_IDENTITY";
320 return "Op_WHILE_LOOP";
335 tosa_err_t
error = tosa_err_t::TOSA_OK;
336 std::vector<uint8_t> uint8Data;
337 auto tensorInfo = tensorHandle->GetTensorInfo();
339 switch (tensorInfo.GetDataType())
343 std::vector<float> data(tensorInfo.GetNumElements());
344 memcpy(data.data(), tensorHandle->Map(
true), tensorInfo.GetNumBytes());
346 error = TosaSerializationHandler::ConvertF32toU8(data, uint8Data);
351 std::vector<float> data(tensorInfo.GetNumElements());
352 memcpy(data.data(), tensorHandle->Map(
true), tensorInfo.GetNumBytes());
354 error = TosaSerializationHandler::ConvertF16toU8(data, uint8Data);
360 std::vector<int8_t> data(tensorInfo.GetNumElements());
361 memcpy(data.data(), tensorHandle->Map(
true), tensorInfo.GetNumBytes());
363 error = TosaSerializationHandler::ConvertI8toU8(data, uint8Data);
368 memcpy(uint8Data.data(), tensorHandle->Map(
true), tensorInfo.GetNumBytes());
373 std::vector<int16_t> data(tensorInfo.GetNumElements());
374 memcpy(data.data(), tensorHandle->Map(
true), tensorInfo.GetNumBytes());
376 error = TosaSerializationHandler::ConvertI16toU8(data, uint8Data);
381 std::vector<int32_t> data(tensorInfo.GetNumElements());
382 memcpy(data.data(), tensorHandle->Map(
true), tensorInfo.GetNumBytes());
384 error = TosaSerializationHandler::ConvertI32toU8(data, uint8Data);
389 throw armnn::Exception(
"SetConstantTensorData: An unsupported data type was encountered.");
393 if(
error != tosa_err_t::TOSA_OK)
395 throw armnn::Exception(
"SetConstantTensorData: An error occurred when converting constant data");
398 tensorHandle->Unmap();
404 const std::vector<int32_t>& shape)
406 std::vector<uint8_t> uint8Data;
407 tosa_err_t
error = tosa_err_t::TOSA_OK;
409 unsigned int numElements = 1;
414 throw armnn::Exception(
"CreateConstTosaData: negative shape elements unhandled.");
416 numElements = numElements *
static_cast<unsigned int>(s);
421 case DType::DType_FP32:
423 std::vector<float> data(numElements, *
static_cast<const float*
>(value));
424 error = TosaSerializationHandler::ConvertF32toU8(data, uint8Data);
427 case DType::DType_FP16:
429 std::vector<float> data(numElements, *
static_cast<const float*
>(value));
430 error = TosaSerializationHandler::ConvertF16toU8(data, uint8Data);
433 case DType::DType_INT48:
435 std::vector<int64_t> data(numElements, *
static_cast<const int64_t*
>(value));
436 error = TosaSerializationHandler::ConvertI48toU8(data, uint8Data);
439 case DType::DType_INT32:
441 std::vector<int32_t> data(numElements, *
static_cast<const int32_t*
>(value));
442 error = TosaSerializationHandler::ConvertI32toU8(data, uint8Data);
445 case DType::DType_INT16:
447 std::vector<int16_t> data(numElements, *
static_cast<const int16_t*
>(value));
448 error = TosaSerializationHandler::ConvertI16toU8(data, uint8Data);
451 case DType::DType_INT8:
453 std::vector<int8_t> data(numElements, *
static_cast<const int8_t*
>(value));
454 error = TosaSerializationHandler::ConvertI8toU8(data, uint8Data);
457 case DType::DType_UINT8:
459 const int8_t* copy_data =
static_cast<const int8_t*
>(value);
460 uint8Data.assign(copy_data, copy_data + numElements);
463 case DType::DType_INT4:
465 std::vector<int8_t> data(numElements, *
static_cast<const int8_t*
>(value));
466 error = TosaSerializationHandler::ConvertI4toU8(data, uint8Data);
469 case DType::DType_BOOL:
471 std::vector<bool> data(numElements, *
static_cast<const bool*
>(value));
472 error = TosaSerializationHandler::ConvertBooltoU8(data, uint8Data);
477 throw armnn::Exception(
"CreateConstTosaData: An unsupported data type was encountered.");
481 if(
error != tosa_err_t::TOSA_OK)
483 throw armnn::Exception(
"CreateConstTosaData: An error occurred when converting constant data");
493 const std::vector<int32_t>& shape,
494 TosaSerializationOperator*& op,
495 TosaSerializationTensor*& tensor)
497 std::vector<uint8_t> uint8Data =
CreateConstTosaData(
static_cast<const void *
>(&value), dtype, shape);
499 op =
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {outputName});
502 tensor =
new TosaSerializationTensor(outputName, shape, dtype, uint8Data);