29 #include "utils/Utils.h"
37 class GraphMobilenetExample :
public Example
40 GraphMobilenetExample() : cmd_parser(), common_opts(cmd_parser), common_params(), graph(0,
"MobileNetV1")
44 model_id_opt->
set_help(
"Mobilenet model id (0: 1.0_224, else: 0.75_160");
46 GraphMobilenetExample(
const GraphMobilenetExample &) =
delete;
47 GraphMobilenetExample &operator=(
const GraphMobilenetExample &) =
delete;
48 ~GraphMobilenetExample()
override =
default;
49 bool do_setup(
int argc,
char **argv)
override
52 cmd_parser.parse(argc, argv);
53 cmd_parser.validate();
59 if (common_params.help)
61 cmd_parser.print_help(argv[0]);
66 std::cout << common_params << std::endl;
69 int model_id = model_id_opt->value();
72 unsigned int spatial_size = (model_id == 0 || common_params.data_type ==
DataType::QASYMM8) ? 224 : 160;
77 common_params.data_layout);
82 graph << common_params.
target << common_params.fast_math_hint;
87 create_graph_float(input_descriptor, model_id);
91 create_graph_qasymm(input_descriptor);
101 config.
use_tuner = common_params.enable_tuner;
104 config.
mlgo_file = common_params.mlgo_file;
106 graph.finalize(common_params.target, config);
110 void do_run()
override
125 float depth_scale = (model_id == 0) ? 1.f : 0.75;
126 std::string model_path =
127 (model_id == 0) ?
"/cnn_data/mobilenet_v1_1_224_model/" :
"/cnn_data/mobilenet_v1_075_160_model/";
130 std::unique_ptr<IPreprocessor> preprocessor = std::make_unique<TFPreproccessor>();
133 std::string data_path = common_params.
data_path;
136 if (!data_path.empty())
138 data_path += model_path;
144 std::unique_ptr<arm_compute::graph::ITensorAccessor>(
nullptr),
154 graph << get_dwsc_node_float(data_path,
"Conv2d_1", 64 * depth_scale,
PadStrideInfo(1, 1, 1, 1),
156 graph << get_dwsc_node_float(data_path,
"Conv2d_2", 128 * depth_scale,
159 graph << get_dwsc_node_float(data_path,
"Conv2d_3", 128 * depth_scale,
162 graph << get_dwsc_node_float(data_path,
"Conv2d_4", 256 * depth_scale,
165 graph << get_dwsc_node_float(data_path,
"Conv2d_5", 256 * depth_scale,
168 graph << get_dwsc_node_float(data_path,
"Conv2d_6", 512 * depth_scale,
171 graph << get_dwsc_node_float(data_path,
"Conv2d_7", 512 * depth_scale,
174 graph << get_dwsc_node_float(data_path,
"Conv2d_8", 512 * depth_scale,
177 graph << get_dwsc_node_float(data_path,
"Conv2d_9", 512 * depth_scale,
180 graph << get_dwsc_node_float(data_path,
"Conv2d_10", 512 * depth_scale,
183 graph << get_dwsc_node_float(data_path,
"Conv2d_11", 512 * depth_scale,
186 graph << get_dwsc_node_float(data_path,
"Conv2d_12", 1024 * depth_scale,
189 graph << get_dwsc_node_float(data_path,
"Conv2d_13", 1024 * depth_scale,
203 std::string data_path = common_params.
data_path;
206 if (!data_path.empty())
208 data_path +=
"/cnn_data/mobilenet_qasymm8_model/";
214 const std::vector<QuantizationInfo> conv_weights_quant_info = {
218 const std::vector<QuantizationInfo> conv_out_quant_info = {
223 const std::vector<QuantizationInfo> depth_weights_quant_info = {
239 const std::vector<QuantizationInfo> point_weights_quant_info = {
260 conv_weights_quant_info.at(0), conv_out_quant_info.at(0))
264 graph << get_dwsc_node_qasymm(data_path,
"Conv2d_1", 64U,
PadStrideInfo(1U, 1U, 1U, 1U),
265 PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(0),
266 point_weights_quant_info.at(0));
267 graph << get_dwsc_node_qasymm(
268 data_path,
"Conv2d_2", 128U,
PadStrideInfo(2U, 2U, 0U, 1U, 0U, 1U,
DimensionRoundingType::FLOOR),
269 PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(1), point_weights_quant_info.at(1));
270 graph << get_dwsc_node_qasymm(
271 data_path,
"Conv2d_3", 128U,
PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U,
DimensionRoundingType::FLOOR),
272 PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(2), point_weights_quant_info.at(2));
273 graph << get_dwsc_node_qasymm(
274 data_path,
"Conv2d_4", 256U,
PadStrideInfo(2U, 2U, 0U, 1U, 0U, 1U,
DimensionRoundingType::FLOOR),
275 PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(3), point_weights_quant_info.at(3));
276 graph << get_dwsc_node_qasymm(
277 data_path,
"Conv2d_5", 256U,
PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U,
DimensionRoundingType::FLOOR),
278 PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(4), point_weights_quant_info.at(4));
279 graph << get_dwsc_node_qasymm(
280 data_path,
"Conv2d_6", 512U,
PadStrideInfo(2U, 2U, 0U, 1U, 0U, 1U,
DimensionRoundingType::FLOOR),
281 PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(5), point_weights_quant_info.at(5));
282 graph << get_dwsc_node_qasymm(
283 data_path,
"Conv2d_7", 512U,
PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U,
DimensionRoundingType::FLOOR),
284 PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(6), point_weights_quant_info.at(6));
285 graph << get_dwsc_node_qasymm(
286 data_path,
"Conv2d_8", 512U,
PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U,
DimensionRoundingType::FLOOR),
287 PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(7), point_weights_quant_info.at(7));
288 graph << get_dwsc_node_qasymm(
289 data_path,
"Conv2d_9", 512U,
PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U,
DimensionRoundingType::FLOOR),
290 PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(8), point_weights_quant_info.at(8));
291 graph << get_dwsc_node_qasymm(
292 data_path,
"Conv2d_10", 512U,
PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U,
DimensionRoundingType::FLOOR),
293 PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(9), point_weights_quant_info.at(9));
294 graph << get_dwsc_node_qasymm(
295 data_path,
"Conv2d_11", 512U,
PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U,
DimensionRoundingType::FLOOR),
296 PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(10), point_weights_quant_info.at(10));
297 graph << get_dwsc_node_qasymm(
298 data_path,
"Conv2d_12", 1024U,
PadStrideInfo(2U, 2U, 0U, 1U, 0U, 1U,
DimensionRoundingType::FLOOR),
299 PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(11), point_weights_quant_info.at(11));
301 << get_dwsc_node_qasymm(
302 data_path,
"Conv2d_13", 1024U,
PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U,
DimensionRoundingType::FLOOR),
303 PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(12), point_weights_quant_info.at(12))
307 PadStrideInfo(1U, 1U, 0U, 0U), 1, conv_weights_quant_info.at(1),
308 conv_out_quant_info.at(1))
312 ConcatLayer get_dwsc_node_float(
const std::string &data_path,
313 std::string &¶m_path,
314 unsigned int conv_filt,
318 std::string total_path = param_path +
"_";
323 std::unique_ptr<arm_compute::graph::ITensorAccessor>(
nullptr), dwc_pad_stride_info)
324 .
set_name(total_path +
"depthwise/depthwise")
330 .
set_name(total_path +
"depthwise/BatchNorm")
332 .
set_name(total_path +
"depthwise/Relu6")
335 std::unique_ptr<arm_compute::graph::ITensorAccessor>(
nullptr), conv_pad_stride_info)
336 .
set_name(total_path +
"pointwise/Conv2D")
342 .
set_name(total_path +
"pointwise/BatchNorm")
344 .
set_name(total_path +
"pointwise/Relu6");
349 ConcatLayer get_dwsc_node_qasymm(
const std::string &data_path,
350 std::string &¶m_path,
351 const unsigned int conv_filt,
357 std::string total_path = param_path +
"_";
362 dwc_pad_stride_info, 1, std::move(depth_weights_quant_info))
363 .
set_name(total_path +
"depthwise/depthwise")
365 .
set_name(total_path +
"depthwise/Relu6")
368 1, std::move(point_weights_quant_info))
369 .
set_name(total_path +
"pointwise/Conv2D")
371 .
set_name(total_path +
"pointwise/Relu6");
392 int main(
int argc,
char **argv)
394 return arm_compute::utils::run_example<GraphMobilenetExample>(argc, argv);