// GraphResNet12Example: graph example (derives from Example) that builds a
// ResNet12 network. NOTE(review): partial listing — the leading numbers are
// original source line numbers and missing numbers mark omitted lines
// (including the class body braces and the member declarations).
35 class GraphResNet12Example :
public Example 38 GraphResNet12Example()
// Constructor: initializes the parser, the common options (constructed from
// cmd_parser), both option pointers (nullptr) and the common params, and
// constructs `graph` with (0, "ResNet12") — presumably a graph id and a
// graph name; confirm against the graph type's constructor.
39 : cmd_parser(), common_opts(cmd_parser), model_input_width(nullptr), model_input_height(nullptr), common_params(), graph(0,
"ResNet12")
// NOTE(review): model_input_width/model_input_height are set to nullptr in
// the initializer list above and dereferenced below. This listing skips
// original lines 40-44, which presumably create these options (e.g. through
// cmd_parser). Confirm; otherwise set_help() goes through a null pointer.
45 model_input_width->
set_help(
"Input image width.");
46 model_input_height->
set_help(
"Input image height.");
48 GraphResNet12Example(
const GraphResNet12Example &) =
delete;
49 GraphResNet12Example &operator=(
const GraphResNet12Example &) =
delete;
50 ~GraphResNet12Example()
override =
default;
// do_setup: parses and validates the command line, prints the run
// configuration, builds the ResNet12 graph stream and finalizes it for
// common_params.target.
// NOTE(review): this is a partial listing. The leading numbers are original
// source line numbers. Gaps (e.g. 56-60, 64-67, 85-92, 95-99) mark omitted
// lines, so the return statements, the declarations of weights_layout and
// config, and most layer constructors are not visible here.
51 bool do_setup(
int argc,
char **argv)
override 54 cmd_parser.parse(argc, argv);
55 cmd_parser.validate();
// If help was requested, print the usage text for this program.
61 if(common_params.help)
63 cmd_parser.print_help(argv[0]);
// Input image dimensions come from the width/height command-line options.
68 const unsigned int image_width = model_input_width->value();
69 const unsigned int image_height = model_input_height->
value();
// Echo the common parameters and the chosen image size.
75 std::cout << common_params << std::endl;
76 std::cout <<
"Image width: " << image_width << std::endl;
77 std::cout <<
"Image height: " << image_height << std::endl;
// Trainable-parameter directory taken from the common options.
// NOTE(review): model_path is not referenced in any visible line. Confirm
// whether the omitted lines use it or whether it is dead.
80 const std::string data_path = common_params.data_path;
81 const std::string model_path =
"/cnn_data/resnet12_model/";
// Input preprocessor; presumably passed to the input accessor in the
// omitted lines 85-92 — confirm.
84 std::unique_ptr<IPreprocessor> preprocessor = std::make_unique<TFPreproccessor>();
// Stream the target and the fast-math hint into the graph. The leading
// layers follow, mostly in omitted lines; only a PadStrideInfo tail is visible.
93 graph << common_params.target
94 << common_params.fast_math_hint
100 PadStrideInfo(1, 1, 4, 4))
// Four residual units named block1..block4 (see add_residual_block).
104 add_residual_block(data_path,
"block1", weights_layout);
105 add_residual_block(data_path,
"block2", weights_layout);
106 add_residual_block(data_path,
"block3", weights_layout);
107 add_residual_block(data_path,
"block4", weights_layout);
// Trailing layers: their constructors are omitted; only the PadStrideInfo
// arguments are visible.
113 PadStrideInfo(1, 1, 1, 1))
120 PadStrideInfo(1, 1, 1, 1))
127 PadStrideInfo(1, 1, 4, 4))
// The graph output is bound to a DummyAccessor constructed with 0.
131 <<
OutputLayer(std::make_unique<DummyAccessor>(0));
// Graph configuration: the tuner flag and the MLGO heuristics file come
// from common_params (the rest of config's setup is omitted).
136 config.
use_tuner = common_params.enable_tuner;
139 config.
mlgo_file = common_params.mlgo_file;
// Finalize the graph for the selected target with this configuration.
141 graph.finalize(common_params.target, config);
// do_run: override of the example's run hook. NOTE(review): its body
// (original lines 147-158) is omitted from this listing.
//
// add_residual_block: appends one residual unit called `name` to `graph`.
// The visible layers are named "<name>/conv1/convolution",
// "<name>/conv1/BatchNorm", "<name>/conv2/convolution" and
// "<name>/conv2/BatchNorm". An element-wise Add named "<name>/add" combines
// `left` and `right`, which are declared in omitted lines.
//   data_path      - directory prefix; combined with `name` into unit_path
//                    (presumably the prefix of this unit's weight files — confirm).
//   name           - unit name; also used as the "<name>/" prefix of layer names.
//   weights_layout - presumably the layout of the stored weights, used by
//                    weight accessors in the omitted lines — confirm.
146 void do_run()
override 160 void add_residual_block(
const std::string &data_path,
const std::string &
name,
DataLayout weights_layout)
// Build the path prefix "<data_path><name>_" and the layer-name prefix "<name>/".
162 std::stringstream unit_path_ss;
163 unit_path_ss << data_path << name <<
"_";
164 std::stringstream unit_name_ss;
165 unit_name_ss << name <<
"/";
167 std::string unit_path = unit_path_ss.str();
168 std::string unit_name = unit_name_ss.str();
// NOTE(review): lines 169-176 (and others between the numbers below) are
// omitted; only the tails of the layer constructors are visible.
177 PadStrideInfo(1, 1, 1, 1))
178 .
set_name(unit_name +
"conv1/convolution")
185 .
set_name(unit_name +
"conv1/BatchNorm")
192 PadStrideInfo(1, 1, 1, 1))
193 .
set_name(unit_name +
"conv2/convolution")
200 .
set_name(unit_name +
"conv2/BatchNorm")
// Add `left` and `right` element-wise; the layer is named "<name>/add".
203 graph <<
EltwiseLayer(std::move(left), std::move(right), EltwiseOperation::Add).
set_name(unit_name +
"add");
219 int main(
int argc,
char **argv)
221 return arm_compute::utils::run_example<GraphResNet12Example>(argc, argv);
Graph configuration structure. Device target types.
int main(int argc, char **argv)
Main program for ResNet12.
CLTunerMode tuner_mode
Tuner mode to be used by the CL tuner.
std::unique_ptr< graph::ITensorAccessor > get_input_accessor(const arm_compute::utils::CommonGraphParams &graph_parameters, std::unique_ptr< IPreprocessor > preprocessor=nullptr, bool bgr=true)
Generates an appropriate input accessor according to the specified graph parameters.
void consume_common_graph_parameters(CommonGraphValidateOptions &options, CommonParams &common_params)
Consumes the common graph options and creates a structure containing any information.
Includes all the Graph headers at once.
Common command line options used to configure the graph examples.
Class to parse command line arguments.
std::string mlgo_file
Filename to load MLGO heuristics from.
std::string tuner_file
File to load tuning values from or store them to.
#define ARM_COMPUTE_EXIT_ON_MSG(cond, msg)
If the condition is true, the given message is printed and the program exits.
const T & value() const
Get the option value.
Number of samples, channels, height, width.
TensorShape permute_shape(TensorShape tensor_shape, DataLayout in_data_layout, DataLayout out_data_layout)
Permutes a tensor shape according to the given input and output data layouts.
bool is_data_type_quantized_asymmetric(DataType dt)
Check if a given data type is of asymmetric quantized type.
TensorDescriptor & set_layout(DataLayout data_layout)
Sets tensor descriptor data layout.
Structure holding all the common graph parameters.
bool use_tuner
Use a tuner in tunable backends.
std::unique_ptr< graph::ITensorAccessor > get_weights_accessor(const std::string &path, const std::string &data_file, DataLayout file_layout=DataLayout::NCHW)
Generates an appropriate weights accessor according to the specified path.
int num_threads
Number of threads to use (thread capable backends), if 0 the backend will auto-initialize, if -1 the backend will stay as it is.
Stream frontend class to construct simple graphs in a stream fashion.
Batch normalization layer.
DataLayout
[DataLayout enum definition]
ILayer & set_name(std::string name)
Sets the name of the layer.
void set_help(std::string help)
Set the help message for the option.