29 #include "utils/Utils.h"
// Example that builds and runs a ResNet-12-style network with the Arm Compute
// Library graph API, following the project's Example interface.
// NOTE(review): this chunk is a partial/garbled extraction — the stray leading
// numbers are the original file's line numbers, and several constructor lines
// (member declarations, option registration via cmd_parser) are missing here.
// Confirm against the full source before modifying.
36 class GraphResNet12Example :
public Example
// Constructor: hooks the common example options into the command-line parser
// and registers the model input width/height options (the add_option calls
// themselves are elided in this extraction; only the init-list and the
// set_help calls are visible).
39 GraphResNet12Example()
41 common_opts(cmd_parser),
// Width/height option pointers start null; presumably assigned by elided
// cmd_parser.add_option(...) calls — TODO confirm.
42 model_input_width(nullptr),
43 model_input_height(nullptr),
// Attach user-facing help strings to the two geometry options.
51 model_input_width->
set_help(
"Input image width.");
52 model_input_height->
set_help(
"Input image height.");
// Non-copyable: the example owns parser/graph state that must not be duplicated.
54 GraphResNet12Example(
const GraphResNet12Example &) =
delete;
55 GraphResNet12Example &operator=(
const GraphResNet12Example &) =
delete;
// Rule of five completed with a defaulted virtual-override destructor.
56 ~GraphResNet12Example()
override =
default;
// Parse and validate command-line arguments, construct the ResNet-12 graph
// (stem convolution + four residual blocks + head), and finalize it for the
// selected backend target.
// NOTE(review): many interior lines (early returns, macro guards, most
// stream-expression operands) are missing from this extraction — the visible
// lines are a skeleton only.
57 bool do_setup(
int argc,
char **argv)
override
// Parse all registered options and sanity-check their values.
60 cmd_parser.parse(argc, argv);
61 cmd_parser.validate();
// If --help was requested, print usage (the early-return path is elided).
67 if (common_params.help)
69 cmd_parser.print_help(argv[0]);
// Input geometry comes from the user-supplied width/height options.
74 const unsigned int image_width = model_input_width->value();
75 const unsigned int image_height = model_input_height->
value();
// Message for an elided unsupported-type guard — presumably an
// ARM_COMPUTE_EXIT_ON_MSG-style check rejecting QASYMM8; confirm upstream.
79 "QASYMM8 not supported for this graph");
// Echo the effective parameters so runs are reproducible from the log.
82 std::cout << common_params << std::endl;
83 std::cout <<
"Image width: " << image_width << std::endl;
84 std::cout <<
"Image height: " << image_height << std::endl;
// Weights are read from <data_path>/cnn_data/resnet12_model/.
87 const std::string data_path = common_params.data_path;
88 const std::string model_path =
"/cnn_data/resnet12_model/";
// TensorFlow-style input preprocessing for the image accessor.
91 std::unique_ptr<IPreprocessor> preprocessor = std::make_unique<TFPreproccessor>();
// Input tensor shape derived from image geometry and the chosen data layout
// (the shape-construction arguments are elided here).
94 const TensorShape tensor_shape =
96 common_params.data_layout);
// Stem: backend/fast-math hints, an input convolution (pad 4) and a ReLU.
// Several operands of this stream expression are missing in this extraction.
103 graph << common_params.target << common_params.fast_math_hint
108 PadStrideInfo(1, 1, 4, 4))
110 <<
ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
// Body: four structurally identical residual blocks.
113 add_residual_block(data_path,
"block1", weights_layout);
114 add_residual_block(data_path,
"block2", weights_layout);
115 add_residual_block(data_path,
"block3", weights_layout);
116 add_residual_block(data_path,
"block4", weights_layout);
// Head: further conv/activation stages (conv nodes elided) ending in TANH
// and a scaled LINEAR activation, then a dummy sink as the graph output.
121 <<
ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
126 <<
ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
131 <<
ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH))
133 <<
ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LINEAR, 0.58f, 0.5f))
135 <<
OutputLayer(std::make_unique<DummyAccessor>(0));
// Propagate tuner / MLGO settings from the common options into the graph
// config, then finalize (allocate + configure) for the chosen target.
140 config.
use_tuner = common_params.enable_tuner;
143 config.
mlgo_file = common_params.mlgo_file;
145 graph.finalize(common_params.target, config);
// Execute the finalized graph. NOTE(review): the body (original lines
// ~151-163) is entirely missing from this extraction — presumably it calls
// graph.run() like sibling examples; confirm against the full source.
150 void do_run()
override
// Append one residual unit to the graph: two 3x3-style conv + BatchNorm
// stages (pad 1) on a branch, combined with a second branch via an
// element-wise Add. Node names are derived from `name` ("blockN").
// NOTE(review): the lines creating the `left`/`right` sub-streams and the
// conv/BN node arguments are elided in this extraction.
//
// @param data_path      Base path to the weights data files.
// @param name           Unit name used for both file prefixes and node names.
// @param weights_layout Data layout of the weight tensors.
164 void add_residual_block(
const std::string &data_path,
const std::string &
name,
DataLayout weights_layout)
// File-name prefix for this unit's weights: "<data_path><name>_".
166 std::stringstream unit_path_ss;
167 unit_path_ss << data_path <<
name <<
"_";
// Graph-node name prefix for this unit: "<name>/".
168 std::stringstream unit_name_ss;
169 unit_name_ss <<
name <<
"/";
171 std::string unit_path = unit_path_ss.str();
172 std::string unit_name = unit_name_ss.str();
// First conv (pad 1) + BatchNorm + ReLU on the residual branch
// (the node constructors feeding these set_name calls are elided).
180 PadStrideInfo(1, 1, 1, 1))
181 .
set_name(unit_name +
"conv1/convolution")
187 .
set_name(unit_name +
"conv1/BatchNorm")
188 <<
ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
// Second conv (pad 1) + BatchNorm + ReLU, mirroring the first stage.
194 PadStrideInfo(1, 1, 1, 1))
195 .
set_name(unit_name +
"conv2/convolution")
201 .
set_name(unit_name +
"conv2/BatchNorm")
202 <<
ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
203 .
set_name(unit_name +
"conv2/Relu");
// Merge the two sub-streams with an element-wise Add — the skip
// connection that makes this a residual block.
205 graph <<
EltwiseLayer(std::move(left), std::move(right), EltwiseOperation::Add).
set_name(unit_name +
"add");
// Entry point: delegate argument handling and the setup/run lifecycle to the
// library's example runner, instantiating this file's example class.
// NOTE(review): the closing brace falls outside this extraction.
221 int main(
int argc,
char **argv)
223 return arm_compute::utils::run_example<GraphResNet12Example>(argc, argv);