#include <armnn/INetwork.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Exceptions.hpp>
#include <armnn/Logging.hpp>
#include <armnn/utility/NumericCast.hpp>

#if defined(ARMNN_ONNX_PARSER)
#include <armnnOnnxParser/IOnnxParser.hpp>
#endif
#if defined(ARMNN_SERIALIZER)
#include <armnnSerializer/ISerializer.hpp>
#endif
#if defined(ARMNN_TF_LITE_PARSER)
#include <armnnTfLiteParser/ITfLiteParser.hpp>
#endif

#define CXXOPTS_VECTOR_DELIMITER '.'
#include <cxxopts/cxxopts.hpp>

#include <fmt/format.h>

#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>
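
// Reads comma-separated dimension values from the stream, ignores tokens that
// are not valid numbers, and packs the rest into an armnn::TensorShape. Used
// below to turn --input-tensor-shape strings into shapes for the parsers.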
armnn::TensorShape ParseTensorShape(std::istream& stream)
{
    std::vector<unsigned int> result;
    std::string line;
    while (std::getline(stream, line))
    {
        // Split each line on commas and convert every token to a dimension value.
        std::istringstream lineStream(line);
        for (std::string token; std::getline(lineStream, token, ','); )
        {
            try
            {
                result.push_back(armnn::numeric_cast<unsigned int>(std::stoi(token)));
            }
            catch (const std::exception&)
            {
                ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
            }
        }
    }
    return armnn::TensorShape(armnn::numeric_cast<unsigned int>(result.size()), result.data());
}
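
// Fills the supplied output arguments from the command line. Returns EXIT_SUCCESS
// when all mandatory options are present and consistent, and also decides from the
// format string whether the model file is binary or text.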
int ParseCommandLineArgs(int argc, char* argv[],
                         std::string& modelFormat,
                         std::string& modelPath,
                         std::vector<std::string>& inputNames,
                         std::vector<std::string>& inputTensorShapeStrs,
                         std::vector<std::string>& outputNames,
                         std::string& outputPath,
                         bool& isModelBinary)
{
    cxxopts::Options options("ArmNNConverter",
                             "Convert a neural network model from provided file to ArmNN format.");

    try
    {
        // Build the list of supported formats from the parsers compiled in.
        std::string modelFormatDescription("Format of the model file");
#if defined(ARMNN_ONNX_PARSER)
        modelFormatDescription += ", onnx-binary, onnx-text";
#endif
#if defined(ARMNN_TF_PARSER)
        modelFormatDescription += ", tensorflow-binary, tensorflow-text";
#endif
#if defined(ARMNN_TF_LITE_PARSER)
        modelFormatDescription += ", tflite-binary";
#endif
        modelFormatDescription += ".";

        options.add_options()
            ("help", "Display usage information")
            ("f,model-format", modelFormatDescription, cxxopts::value<std::string>(modelFormat))
            ("m,model-path", "Path to model file.", cxxopts::value<std::string>(modelPath))
            ("i,input-name", "Identifier of the input tensors in the network. "
                             "Each input must be specified separately.",
                             cxxopts::value<std::vector<std::string>>(inputNames))
            ("s,input-tensor-shape",
                             "The shape of the input tensor in the network as a flat array of integers, "
                             "separated by comma. Each input shape must be specified separately after the input name. "
                             "This parameter is optional, depending on the network.",
                             cxxopts::value<std::vector<std::string>>(inputTensorShapeStrs))
            ("o,output-name", "Identifier of the output tensor in the network.",
                             cxxopts::value<std::vector<std::string>>(outputNames))
            ("output-path", "Path to serialize the network to.", cxxopts::value<std::string>(outputPath));
    }
    catch (const std::exception& e)
    {
        std::cerr << e.what() << std::endl << options.help() << std::endl;
        return EXIT_FAILURE;
    }

    try
    {
        cxxopts::ParseResult result = options.parse(argc, argv);

        if (result.count("help"))
        {
            std::cerr << options.help() << std::endl;
            return EXIT_SUCCESS;
        }

        // Every option in this list must be given exactly once.
        std::string mandatorySingleParameters[] = { "model-format", "model-path", "output-name", "output-path" };
        bool somethingsMissing = false;
        for (auto param : mandatorySingleParameters)
        {
            if (result.count(param) != 1)
            {
                std::cerr << "Parameter \'--" << param << "\' is required but missing." << std::endl;
                somethingsMissing = true;
            }
        }

        // At least one input-name is required.
        if (result.count("input-name") == 0)
        {
            std::cerr << "Parameter \'--input-name\' must be specified at least once." << std::endl;
            somethingsMissing = true;
        }

        // If input shapes are given there must be one per input name.
        if (result.count("input-tensor-shape") > 0)
        {
            if (result.count("input-tensor-shape") != result.count("input-name"))
            {
                std::cerr << "When specifying \'input-tensor-shape\' a matching number of \'input-name\' parameters "
                             "must be specified." << std::endl;
                somethingsMissing = true;
            }
        }

        if (somethingsMissing)
        {
            std::cerr << options.help() << std::endl;
            return EXIT_FAILURE;
        }
    }
    catch (const cxxopts::exceptions::exception& e)
    {
        std::cerr << e.what() << std::endl << std::endl;
        return EXIT_FAILURE;
    }

    // Decide whether the model file is binary or text from the format string.
    if (modelFormat.find("bin") != std::string::npos)
    {
        isModelBinary = true;
    }
    else if (modelFormat.find("text") != std::string::npos)
    {
        isModelBinary = false;
    }
    else
    {
        ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'";
        return EXIT_FAILURE;
    }

    return EXIT_SUCCESS;
}

// Tag type used to select the CreateNetwork overload that matches a given parser.
template <typename T>
struct ParserType
{
    typedef T parserType;
};

// Loads a model with one of the ArmNN parsers and serializes the resulting
// armnn::INetwork to the requested output path.
class ArmnnConverter
{
public:
    ArmnnConverter(const std::string& modelPath,
                   const std::vector<std::string>& inputNames,
                   const std::vector<armnn::TensorShape>& inputShapes,
                   const std::vector<std::string>& outputNames,
                   const std::string& outputPath,
                   bool isModelBinary)
    // armnn::INetworkPtr is a unique_ptr with a custom deleter, so it is
    // initialised explicitly with a null network and a no-op deleter.
    : m_NetworkPtr(armnn::INetworkPtr(nullptr, [](armnn::INetwork*){})),
      m_ModelPath(modelPath),
      m_InputNames(inputNames),
      m_InputShapes(inputShapes),
      m_OutputNames(outputNames),
      m_OutputPath(outputPath),
      m_IsModelBinary(isModelBinary) {}
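
    // Serialize writes the loaded network to m_OutputPath with the ArmNN
    // serializer and returns false if no network has been created yet.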
    bool Serialize()
    {
#if defined(ARMNN_SERIALIZER)
        if (m_NetworkPtr.get() == nullptr)
        {
            return false;
        }

        auto serializer(armnnSerializer::ISerializer::Create());
        serializer->Serialize(*m_NetworkPtr);

        std::ofstream file(m_OutputPath, std::ios::out | std::ios::binary);
        bool retVal = serializer->SaveSerializedToStream(file);
        return retVal;
#else
        return false;
#endif
    }

    template <typename IParser>
    bool CreateNetwork()
    {
        return CreateNetwork(ParserType<IParser>());
    }
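
    // CreateNetwork<IParser>() forwards to one of the private overloads below
    // via the ParserType tag, so parsers with different interfaces can each
    // have their own implementation.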

private:
    armnn::INetworkPtr              m_NetworkPtr;
    std::string                     m_ModelPath;
    std::vector<std::string>        m_InputNames;
    std::vector<armnn::TensorShape> m_InputShapes;
    std::vector<std::string>        m_OutputNames;
    std::string                     m_OutputPath;
    bool                            m_IsModelBinary;

    // Generic implementation for parsers whose CreateNetworkFrom*File overloads
    // take the input shape map and the requested outputs.
    template <typename IParser>
    bool CreateNetwork(ParserType<IParser>)
    {
        auto parser(IParser::Create());

        // Map each input name to its user-supplied shape, if any were given.
        std::map<std::string, armnn::TensorShape> inputShapes;
        if (!m_InputShapes.empty())
        {
            const size_t numInputShapes   = m_InputShapes.size();
            const size_t numInputBindings = m_InputNames.size();
            if (numInputShapes < numInputBindings)
            {
                throw armnn::Exception(fmt::format(
                    "Not every input has its tensor shape specified: expected={0}, got={1}",
                    numInputBindings, numInputShapes));
            }

            for (size_t i = 0; i < numInputShapes; i++)
            {
                inputShapes[m_InputNames[i]] = m_InputShapes[i];
            }
        }

        m_NetworkPtr = (m_IsModelBinary ?
            parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str(), inputShapes, m_OutputNames) :
            parser->CreateNetworkFromTextFile(m_ModelPath.c_str(), inputShapes, m_OutputNames));

        return m_NetworkPtr.get() != nullptr;
    }
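
    // The TfLite parser reads binary flatbuffer models only and takes neither
    // input shapes nor output names, so it gets a dedicated overload that just
    // validates the optional shape arguments.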
#if defined(ARMNN_TF_LITE_PARSER)
    bool CreateNetwork(ParserType<armnnTfLiteParser::ITfLiteParser>)
    {
        auto parser(armnnTfLiteParser::ITfLiteParser::Create());

        if (!m_InputShapes.empty())
        {
            const size_t numInputShapes   = m_InputShapes.size();
            const size_t numInputBindings = m_InputNames.size();
            if (numInputShapes < numInputBindings)
            {
                throw armnn::Exception(fmt::format(
                    "Not every input has its tensor shape specified: expected={0}, got={1}",
                    numInputBindings, numInputShapes));
            }
        }

        m_NetworkPtr = parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str());

        return m_NetworkPtr.get() != nullptr;
    }
#endif

#if defined(ARMNN_ONNX_PARSER)
    bool CreateNetwork(ParserType<armnnOnnxParser::IOnnxParser>)
    {
        auto parser(armnnOnnxParser::IOnnxParser::Create());

        if (!m_InputShapes.empty())
        {
            const size_t numInputShapes   = m_InputShapes.size();
            const size_t numInputBindings = m_InputNames.size();
            if (numInputShapes < numInputBindings)
            {
                throw armnn::Exception(fmt::format(
                    "Not every input has its tensor shape specified: expected={0}, got={1}",
                    numInputBindings, numInputShapes));
            }
        }

        m_NetworkPtr = (m_IsModelBinary ?
            parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str()) :
            parser->CreateNetworkFromTextFile(m_ModelPath.c_str()));

        return m_NetworkPtr.get() != nullptr;
    }
#endif
};
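
// main wires everything together: parse the command line, convert the shape
// strings, load the model with the parser selected by --model-format, and
// serialize the resulting network to --output-path.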
int main(int argc, char* argv[])
{
#if (!defined(ARMNN_ONNX_PARSER) \
  && !defined(ARMNN_TF_PARSER) \
  && !defined(ARMNN_TF_LITE_PARSER))
    ARMNN_LOG(fatal) << "Not built with any of the supported parsers Onnx, Tensorflow, or TfLite.";
    return EXIT_FAILURE;
#endif

#if !defined(ARMNN_SERIALIZER)
    ARMNN_LOG(fatal) << "Not built with Serializer support.";
    return EXIT_FAILURE;
#endif

    std::string modelFormat;
    std::string modelPath;

    std::vector<std::string> inputNames;
    std::vector<std::string> inputTensorShapeStrs;
    std::vector<armnn::TensorShape> inputTensorShapes;

    std::vector<std::string> outputNames;
    std::string outputPath;

    bool isModelBinary = true;

    if (ParseCommandLineArgs(
            argc, argv, modelFormat, modelPath, inputNames, inputTensorShapeStrs, outputNames, outputPath,
            isModelBinary) != EXIT_SUCCESS)
    {
        return EXIT_FAILURE;
    }

    // Convert each --input-tensor-shape string into an armnn::TensorShape.
    for (const std::string& shapeStr : inputTensorShapeStrs)
    {
        if (!shapeStr.empty())
        {
            std::stringstream ss(shapeStr);

            try
            {
                armnn::TensorShape shape = ParseTensorShape(ss);
                inputTensorShapes.push_back(shape);
            }
            catch (const std::exception& e)
            {
                ARMNN_LOG(fatal) << "Cannot create tensor shape: " << e.what();
                return EXIT_FAILURE;
            }
        }
    }

    ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary);

    try
    {
        if (modelFormat.find("onnx") != std::string::npos)
        {
#if defined(ARMNN_ONNX_PARSER)
            if (!converter.CreateNetwork<armnnOnnxParser::IOnnxParser>())
            {
                ARMNN_LOG(fatal) << "Failed to load model from file";
                return EXIT_FAILURE;
            }
#else
            ARMNN_LOG(fatal) << "Not built with Onnx parser support.";
            return EXIT_FAILURE;
#endif
        }
        else if (modelFormat.find("tflite") != std::string::npos)
        {
#if defined(ARMNN_TF_LITE_PARSER)
            if (!isModelBinary)
            {
                ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat
                                 << "'. Only 'binary' format supported for tflite files";
                return EXIT_FAILURE;
            }

            if (!converter.CreateNetwork<armnnTfLiteParser::ITfLiteParser>())
            {
                ARMNN_LOG(fatal) << "Failed to load model from file";
                return EXIT_FAILURE;
            }
#else
            ARMNN_LOG(fatal) << "Not built with TfLite parser support.";
            return EXIT_FAILURE;
#endif
        }
        else
        {
            ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'";
            return EXIT_FAILURE;
        }
    }
    catch (const armnn::Exception& e)
    {
        ARMNN_LOG(fatal) << "Failed to load model from file: " << e.what();
        return EXIT_FAILURE;
    }

    if (!converter.Serialize())
    {
        ARMNN_LOG(fatal) << "Failed to serialize model";
        return EXIT_FAILURE;
    }

    return EXIT_SUCCESS;
}