12 #include <fmt/format.h>
65 fmt::format(
"Failed to create numpy header info at {}",
70 if (compare_result != 0) {
71 throw armnn::Exception(fmt::format(
"Numpy does not contain magic string. Can not parse invalid numpy {}",
83 if (*(
reinterpret_cast<char *
>(&i)) == 1)
100 if (*(
reinterpret_cast<char *
>(&i)) == 1)
126 std::string substringStart,
127 std::string substringEnd,
128 bool removeStartChar = 0,
129 bool includeEndChar = 0)
131 size_t startPos = fullString.find(substringStart);
132 size_t endPos = fullString.find(substringEnd, startPos);
133 if (startPos == std::string::npos || endPos == std::string::npos)
142 startPos+= removeStartChar;
143 endPos += includeEndChar;
144 return fullString.substr(startPos, endPos - startPos);
149 std::istringstream shapeStringStream(shapeString);
151 while(getline(shapeStringStream, token,
','))
153 header.
m_Shape.push_back(
static_cast<uint32_t
>(std::stoi(token)));
160 ifStream.read(stringBuffer, headerInfo.
m_HeaderLen);
174 std::string::iterator whitespaceSubstringStart = std::remove(header.
m_HeaderString.begin(),
196 fortranOrderString =
getSubstring(fortranOrderString,
":",
"e", 1, 1);
197 header.
m_FortranOrder = fortranOrderString.find(
"True") != std::string::npos ? true :
false;
205 shapeString =
getSubstring(shapeString,
"(",
")", 1, 0);
210 inline void ReadData(std::ifstream& ifStream, T* tensor,
const unsigned int& numElements)
212 ifStream.read(
reinterpret_cast<char *
>(tensor),
sizeof(T) * numElements);
218 if(descr.find(
"f4") != std::string::npos || descr.find(
"f8") != std::string::npos)
222 else if (descr.find(
"f2") != std::string::npos)
226 else if (descr.find(
"i8") != std::string::npos)
230 else if (descr.find(
"i4") != std::string::npos)
234 else if (descr.find(
"i2") != std::string::npos)
238 else if (descr.find(
"i1") != std::string::npos)
242 else if (descr.find(
"u1") != std::string::npos)
258 return "f" + std::to_string(
sizeof(
float));
273 throw armnn::Exception(fmt::format(
"ArmNN to Numpy data type:{} not supported. {}",
278 template <
typename T>
281 if(descr.find(
"f4") != std::string::npos || descr.find(
"f8") != std::string::npos)
283 return std::is_same<T, float>::value;
285 else if (descr.find(
"i8") != std::string::npos)
287 return std::is_same<T, int64_t>::value;
289 else if (descr.find(
"i4") != std::string::npos)
291 return std::is_same<T, int32_t>::value;
293 else if (descr.find(
"i2") != std::string::npos)
295 return std::is_same<T, int16_t>::value;
297 else if (descr.find(
"i1") != std::string::npos)
299 return std::is_same<T, int8_t>::value;
301 else if (descr.find(
"u1") != std::string::npos)
303 return std::is_same<T, uint8_t>::value;
314 unsigned int numEls = 1;
327 const T*
const array,
328 const unsigned int numElements,
332 std::ofstream out(outputTensorFileName, std::ofstream::binary);
337 std::string shapeStr =
"(";
340 shapeStr = shapeStr + std::to_string(shape[i]) +
", ";
342 shapeStr = shapeStr + std::to_string(shape[shape.
GetNumDimensions()-1]) +
")";
345 std::string endianChar = (*(
reinterpret_cast<char *
>(&i))) ?
"<" :
">";
347 std::string fortranOrder =
"False";
348 std::string headerStr =
"{'descr': '" + endianChar + dataTypeStr +
349 "', 'fortran_order': " + fortranOrder +
350 ", 'shape': " + shapeStr +
", }";
361 uint8_t major_version = 1;
364 if (length >= 255 * 255)
371 size_t padding_length = 16 - length % 16;
372 std::string padding(padding_length,
' ');
376 out.put(major_version);
380 if (major_version == 1)
382 auto header_len =
static_cast<uint16_t
>(headerStr.length() + padding.length() + 1);
384 std::array<uint8_t, 2> header_len_16{
static_cast<uint8_t
>((header_len >> 0) & 0xff),
385 static_cast<uint8_t
>((header_len >> 8) & 0xff)};
386 out.write(
reinterpret_cast<char *
>(header_len_16.data()), 2);
390 auto header_len =
static_cast<uint32_t
>(headerStr.length() + padding.length() + 1);
392 std::array<uint8_t, 4> header_len_32{
393 static_cast<uint8_t
>((header_len >> 0) & 0xff),
static_cast<uint8_t
>((header_len >> 8) & 0xff),
394 static_cast<uint8_t
>((header_len >> 16) & 0xff),
static_cast<uint8_t
>((header_len >> 24) & 0xff)};
395 out.write(
reinterpret_cast<char *
>(header_len_32.data()), 4);
398 out << headerStr << padding <<
'\n';
402 out.write(
reinterpret_cast<const char *
>(array),
sizeof(T) * numElements);