// Tail of a parameter list followed by body statements of a substring helper
// (presumably the getSubstring called further down — its name and the
// fullString parameter are outside this excerpt). substringStart and
// substringEnd are the delimiters searched for in fullString; the two bool
// flags default to 0 and are added to the found positions as 0 or 1.
126 std::string substringStart,
127 std::string substringEnd,
128 bool removeStartChar = 0,
129 bool includeEndChar = 0)
// Position of the first occurrence of the start delimiter.
131 size_t startPos = fullString.find(substringStart);
// Position of the first occurrence of the end delimiter at or after startPos.
132 size_t endPos = fullString.find(substringEnd, startPos);
// True when either delimiter was not found; the statement this condition
// guards is not visible in this excerpt.
133 if (startPos == std::string::npos || endPos == std::string::npos)
// Each flag converts to 0 or 1: removeStartChar moves the start forward by one
// character (only one, even if substringStart is longer), includeEndChar moves
// the end forward by one character.
142 startPos+= removeStartChar;
143 endPos += includeEndChar;
// NOTE(review): the length endPos - startPos is unsigned. If substringEnd
// matches exactly at startPos while removeStartChar is set and includeEndChar
// is not, startPos ends up at endPos + 1 and the subtraction wraps, so substr()
// returns everything to the end of the string — confirm no caller hits this.
144 return fullString.substr(startPos, endPos - startPos);
// Read headerInfo.m_HeaderLen bytes from ifStream into stringBuffer; the length
// is cast to std::streamsize, the count type that read() takes.
// NOTE(review): no check of the stream state after this read is visible in the
// excerpt — confirm a short read is handled by the surrounding code.
160 ifStream.read(stringBuffer,
static_cast<std::streamsize
>(headerInfo.
m_HeaderLen));
// std::remove shifts the characters it keeps towards the front of
// header.m_HeaderString and returns an iterator to the start of the leftover
// tail. The remaining arguments of this call (and any erase of that tail) are
// outside this excerpt; the variable name suggests whitespace is what is being
// removed — TODO confirm against the full statement.
174 std::string::iterator whitespaceSubstringStart = std::remove(header.
m_HeaderString.begin(),
// Narrow fortranOrderString using getSubstring with ":" as start delimiter and
// "e" as end delimiter, both flags set to 1. Going by the substring helper's
// visible logic, that drops the ':' itself and keeps the terminating 'e'.
196 fortranOrderString =
getSubstring(fortranOrderString,
":",
"e", 1, 1);
// m_FortranOrder becomes true exactly when the extracted text contains "True".
// The trailing `? true : false` is redundant: the comparison is already a bool.
197 header.
m_FortranOrder = fortranOrderString.find(
"True") != std::string::npos ? true :
false;
// Narrow shapeString to the text between "(" and ")": flag 1 skips the opening
// parenthesis, flag 0 leaves the closing parenthesis out.
205 shapeString =
getSubstring(shapeString,
"(",
")", 1, 0);
// Chain that looks for a NumPy type code inside descr and returns whether the
// template parameter T is the C++ type paired with that code. Matching is by
// substring, so any other characters in descr (such as a byte-order prefix)
// are ignored. The enclosing function header, the braces and the case where no
// code matches are outside this excerpt.
// NOTE(review): "f8" is NumPy's 8-byte float, yet it is paired with float just
// like "f4" — confirm this is intentional rather than a missing double case.
283 if(descr.find(
"f4") != std::string::npos || descr.find(
"f8") != std::string::npos)
285 return std::is_same<T, float>::value;
// "i8": 8-byte signed integer code, paired with int64_t.
287 else if (descr.find(
"i8") != std::string::npos)
289 return std::is_same<T, int64_t>::value;
// "i4": paired with int32_t.
291 else if (descr.find(
"i4") != std::string::npos)
293 return std::is_same<T, int32_t>::value;
// "i2": paired with int16_t.
295 else if (descr.find(
"i2") != std::string::npos)
297 return std::is_same<T, int16_t>::value;
// "i1": paired with int8_t.
299 else if (descr.find(
"i1") != std::string::npos)
301 return std::is_same<T, int8_t>::value;
// "u1": the only unsigned code handled in the visible chain, paired with uint8_t.
303 else if (descr.find(
"u1") != std::string::npos)
305 return std::is_same<T, uint8_t>::value;
// Fragments of a routine that writes an array to a file with a NumPy-style
// header. Visible parameters: a pointer to the element data and the element
// count (the function name and the other parameters are outside this excerpt).
329 const T*
const array,
330 const unsigned int numElements,
// Open the output file in binary mode.
334 std::ofstream out(outputTensorFileName, std::ofstream::binary);
// Build the shape as text: "(" first, then dimension i followed by ", " (the
// loop around that append is not visible), then the last dimension and ")".
// NOTE(review): as visible, a one-dimensional shape would come out as "(N)"
// with no trailing comma, which is not a Python 1-tuple — confirm readers
// accept it. Also confirm GetNumDimensions() is never 0 here, since the
// index GetNumDimensions()-1 is computed unconditionally.
339 std::string shapeStr =
"(";
342 shapeStr = shapeStr + std::to_string(shape[i]) +
", ";
344 shapeStr = shapeStr + std::to_string(shape[shape.
GetNumDimensions()-1]) +
")";
// Byte-order marker: inspects the first byte of i through a char pointer and
// picks "<" when it is non-zero, ">" otherwise. i is declared on a line outside
// this excerpt — presumably an integer holding 1, which would make "<" mean
// little-endian; TODO confirm.
347 std::string endianChar = (*(
reinterpret_cast<char *
>(&i))) ?
"<" :
">";
// The fortran_order field is always written as "False".
349 std::string fortranOrder =
"False";
// Header text in Python dict-literal form with the keys 'descr',
// 'fortran_order' and 'shape'. dataTypeStr comes from outside this excerpt.
350 std::string headerStr =
"{'descr': '" + endianChar + dataTypeStr +
351 "', 'fortran_order': " + fortranOrder +
352 ", 'shape': " + shapeStr +
", }";
// Format version byte; starts at 1.
363 uint8_t major_version = 1;
// length is computed outside this excerpt, as is the statement this condition
// guards. NOTE(review): 255 * 255 is 65025, slightly below the 65535 maximum a
// 2-byte length field can hold — confirm the threshold is intended.
366 if (length >= 255 * 255)
// Number of spaces appended after the header text: 16 - length % 16. When
// length is already a multiple of 16 this yields 16, not 0.
373 size_t padding_length = 16 - length % 16;
374 std::string padding(padding_length,
' ');
// Write the major version as a single byte.
378 out.put(major_version);
// Version 1: the header length is stored in 2 bytes. The + 1 accounts for the
// '\n' written after the padding further down.
382 if (major_version == 1)
384 auto header_len =
static_cast<uint16_t
>(headerStr.length() + padding.length() + 1);
// Split the length into bytes, least significant byte first.
386 std::array<uint8_t, 2> header_len_16{
static_cast<uint8_t
>((header_len >> 0) & 0xff),
387 static_cast<uint8_t
>((header_len >> 8) & 0xff)};
388 out.write(
reinterpret_cast<char *
>(header_len_16.data()), 2);
// Presumably the alternative branch for other versions (the else itself is not
// visible): the same length is stored in 4 bytes, least significant byte first.
392 auto header_len =
static_cast<uint32_t
>(headerStr.length() + padding.length() + 1);
394 std::array<uint8_t, 4> header_len_32{
395 static_cast<uint8_t
>((header_len >> 0) & 0xff),
static_cast<uint8_t
>((header_len >> 8) & 0xff),
396 static_cast<uint8_t
>((header_len >> 16) & 0xff),
static_cast<uint8_t
>((header_len >> 24) & 0xff)};
397 out.write(
reinterpret_cast<char *
>(header_len_32.data()), 4);
// Emit the header text, the space padding and a terminating newline.
400 out << headerStr << padding <<
'\n';
// Emit the element data as raw bytes: sizeof(T) * numElements bytes starting
// at array, in the host's in-memory representation.
404 out.write(
reinterpret_cast<const char *
>(array),
sizeof(T) * numElements);