Compute Library 21.02 — examples/graph_mobilenet.cpp (documentation page for this file).
1 /*
2  * Copyright (c) 2017-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/graph.h"
27 #include "utils/GraphUtils.h"
28 #include "utils/Utils.h"
29 
30 using namespace arm_compute;
31 using namespace arm_compute::utils;
32 using namespace arm_compute::graph::frontend;
33 using namespace arm_compute::graph_utils;
34 
35 /** Example demonstrating how to implement MobileNet's network using the Compute Library's graph API */
36 class GraphMobilenetExample : public Example
37 {
38 public:
39  GraphMobilenetExample()
40  : cmd_parser(), common_opts(cmd_parser), common_params(), graph(0, "MobileNetV1")
41  {
42  // Add model id option
43  model_id_opt = cmd_parser.add_option<SimpleOption<int>>("model-id", 0);
44  model_id_opt->set_help("Mobilenet model id (0: 1.0_224, else: 0.75_160");
45  }
46  GraphMobilenetExample(const GraphMobilenetExample &) = delete;
47  GraphMobilenetExample &operator=(const GraphMobilenetExample &) = delete;
48  ~GraphMobilenetExample() override = default;
49  bool do_setup(int argc, char **argv) override
50  {
51  // Parse arguments
52  cmd_parser.parse(argc, argv);
53  cmd_parser.validate();
54 
55  // Consume common parameters
56  common_params = consume_common_graph_parameters(common_opts);
57 
58  // Return when help menu is requested
59  if(common_params.help)
60  {
61  cmd_parser.print_help(argv[0]);
62  return false;
63  }
64 
65  // Print parameter values
66  std::cout << common_params << std::endl;
67 
68  // Get model parameters
69  int model_id = model_id_opt->value();
70 
71  // Create input descriptor
72  unsigned int spatial_size = (model_id == 0 || common_params.data_type == DataType::QASYMM8) ? 224 : 160;
73 
74  // Create input descriptor
75  const TensorShape tensor_shape = permute_shape(TensorShape(spatial_size, spatial_size, 3U, 1U), DataLayout::NCHW, common_params.data_layout);
76  TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(common_params.data_layout);
77 
78  // Set graph hints
79  graph << common_params.target
80  << common_params.fast_math_hint;
81 
82  // Create core graph
83  if(arm_compute::is_data_type_float(common_params.data_type))
84  {
85  create_graph_float(input_descriptor, model_id);
86  }
87  else
88  {
89  create_graph_qasymm(input_descriptor);
90  }
91 
92  // Create common tail
93  graph << ReshapeLayer(TensorShape(1001U)).set_name("Reshape")
94  << SoftmaxLayer().set_name("Softmax")
95  << OutputLayer(get_output_accessor(common_params, 5));
96 
97  // Finalize graph
98  GraphConfig config;
99  config.num_threads = common_params.threads;
100  config.use_tuner = common_params.enable_tuner;
101  config.tuner_mode = common_params.tuner_mode;
102  config.tuner_file = common_params.tuner_file;
103  config.mlgo_file = common_params.mlgo_file;
104  config.mlgo_file = common_params.mlgo_file;
105 
106  graph.finalize(common_params.target, config);
107 
108  return true;
109  }
110  void do_run() override
111  {
112  // Run graph
113  graph.run();
114  }
115 
116 private:
117  CommandLineParser cmd_parser;
118  CommonGraphOptions common_opts;
119  SimpleOption<int> *model_id_opt{ nullptr };
120  CommonGraphParams common_params;
121  Stream graph;
122 
123  void create_graph_float(TensorDescriptor &input_descriptor, int model_id)
124  {
125  float depth_scale = (model_id == 0) ? 1.f : 0.75;
126  std::string model_path = (model_id == 0) ? "/cnn_data/mobilenet_v1_1_224_model/" : "/cnn_data/mobilenet_v1_075_160_model/";
127 
128  // Create a preprocessor object
129  std::unique_ptr<IPreprocessor> preprocessor = std::make_unique<TFPreproccessor>();
130 
131  // Get trainable parameters data path
132  std::string data_path = common_params.data_path;
133 
134  // Add model path to data path
135  if(!data_path.empty())
136  {
137  data_path += model_path;
138  }
139 
140  graph << InputLayer(input_descriptor,
141  get_input_accessor(common_params, std::move(preprocessor), false))
142  << ConvolutionLayer(
143  3U, 3U, 32U * depth_scale,
144  get_weights_accessor(data_path, "Conv2d_0_weights.npy", DataLayout::NCHW),
145  std::unique_ptr<arm_compute::graph::ITensorAccessor>(nullptr),
146  PadStrideInfo(2, 2, 0, 1, 0, 1, DimensionRoundingType::FLOOR))
147  .set_name("Conv2d_0")
149  get_weights_accessor(data_path, "Conv2d_0_BatchNorm_moving_mean.npy"),
150  get_weights_accessor(data_path, "Conv2d_0_BatchNorm_moving_variance.npy"),
151  get_weights_accessor(data_path, "Conv2d_0_BatchNorm_gamma.npy"),
152  get_weights_accessor(data_path, "Conv2d_0_BatchNorm_beta.npy"),
153  0.001f)
154  .set_name("Conv2d_0/BatchNorm")
156  graph << get_dwsc_node_float(data_path, "Conv2d_1", 64 * depth_scale, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0));
157  graph << get_dwsc_node_float(data_path, "Conv2d_2", 128 * depth_scale, PadStrideInfo(2, 2, 0, 1, 0, 1, DimensionRoundingType::CEIL), PadStrideInfo(1, 1, 0, 0));
158  graph << get_dwsc_node_float(data_path, "Conv2d_3", 128 * depth_scale, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::CEIL), PadStrideInfo(1, 1, 0, 0));
159  graph << get_dwsc_node_float(data_path, "Conv2d_4", 256 * depth_scale, PadStrideInfo(2, 2, 0, 1, 0, 1, DimensionRoundingType::CEIL), PadStrideInfo(1, 1, 0, 0));
160  graph << get_dwsc_node_float(data_path, "Conv2d_5", 256 * depth_scale, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::CEIL), PadStrideInfo(1, 1, 0, 0));
161  graph << get_dwsc_node_float(data_path, "Conv2d_6", 512 * depth_scale, PadStrideInfo(2, 2, 0, 1, 0, 1, DimensionRoundingType::CEIL), PadStrideInfo(1, 1, 0, 0));
162  graph << get_dwsc_node_float(data_path, "Conv2d_7", 512 * depth_scale, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::CEIL), PadStrideInfo(1, 1, 0, 0));
163  graph << get_dwsc_node_float(data_path, "Conv2d_8", 512 * depth_scale, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::CEIL), PadStrideInfo(1, 1, 0, 0));
164  graph << get_dwsc_node_float(data_path, "Conv2d_9", 512 * depth_scale, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::CEIL), PadStrideInfo(1, 1, 0, 0));
165  graph << get_dwsc_node_float(data_path, "Conv2d_10", 512 * depth_scale, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::CEIL), PadStrideInfo(1, 1, 0, 0));
166  graph << get_dwsc_node_float(data_path, "Conv2d_11", 512 * depth_scale, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::CEIL), PadStrideInfo(1, 1, 0, 0));
167  graph << get_dwsc_node_float(data_path, "Conv2d_12", 1024 * depth_scale, PadStrideInfo(2, 2, 0, 1, 0, 1, DimensionRoundingType::CEIL), PadStrideInfo(1, 1, 0, 0));
168  graph << get_dwsc_node_float(data_path, "Conv2d_13", 1024 * depth_scale, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::CEIL), PadStrideInfo(1, 1, 0, 0));
169  graph << PoolingLayer(PoolingLayerInfo(PoolingType::AVG, common_params.data_layout)).set_name("Logits/AvgPool_1a")
170  << ConvolutionLayer(
171  1U, 1U, 1001U,
172  get_weights_accessor(data_path, "Logits_Conv2d_1c_1x1_weights.npy", DataLayout::NCHW),
173  get_weights_accessor(data_path, "Logits_Conv2d_1c_1x1_biases.npy"),
174  PadStrideInfo(1, 1, 0, 0))
175  .set_name("Logits/Conv2d_1c_1x1");
176  }
177 
178  void create_graph_qasymm(TensorDescriptor &input_descriptor)
179  {
180  // Get trainable parameters data path
181  std::string data_path = common_params.data_path;
182 
183  // Add model path to data path
184  if(!data_path.empty())
185  {
186  data_path += "/cnn_data/mobilenet_qasymm8_model/";
187  }
188 
189  // Quantization info taken from the AndroidNN QASYMM8 MobileNet example
190  const QuantizationInfo in_quant_info = QuantizationInfo(0.0078125f, 128);
191 
192  const std::vector<QuantizationInfo> conv_weights_quant_info =
193  {
194  QuantizationInfo(0.02182667888700962f, 151), // conv0
195  QuantizationInfo(0.004986600950360298f, 74) // conv14
196  };
197  const std::vector<QuantizationInfo> conv_out_quant_info =
198  {
199  QuantizationInfo(0.023528477177023888f, 0), // conv0
200  QuantizationInfo(0.16609922051429749f, 66) // conv14
201  };
202 
203  const std::vector<QuantizationInfo> depth_weights_quant_info =
204  {
205  QuantizationInfo(0.29219913482666016f, 110), // dwsc1
206  QuantizationInfo(0.40277284383773804f, 130), // dwsc2
207  QuantizationInfo(0.06053730100393295f, 160), // dwsc3
208  QuantizationInfo(0.01675807684659958f, 123), // dwsc4
209  QuantizationInfo(0.04105526953935623f, 129), // dwsc5
210  QuantizationInfo(0.013460792601108551f, 122), // dwsc6
211  QuantizationInfo(0.036934755742549896f, 132), // dwsc7
212  QuantizationInfo(0.042609862983226776f, 94), // dwsc8
213  QuantizationInfo(0.028358859941363335f, 127), // dwsc9
214  QuantizationInfo(0.024329448118805885f, 134), // dwsc10
215  QuantizationInfo(0.019366811960935593f, 106), // dwsc11
216  QuantizationInfo(0.007835594937205315f, 126), // dwsc12
217  QuantizationInfo(0.12616927921772003f, 211) // dwsc13
218  };
219 
220  const std::vector<QuantizationInfo> point_weights_quant_info =
221  {
222  QuantizationInfo(0.030420949682593346f, 121), // dwsc1
223  QuantizationInfo(0.015148180536925793f, 104), // dwsc2
224  QuantizationInfo(0.013755458407104015f, 94), // dwsc3
225  QuantizationInfo(0.007601846940815449f, 151), // dwsc4
226  QuantizationInfo(0.006431614048779011f, 122), // dwsc5
227  QuantizationInfo(0.00917122047394514f, 109), // dwsc6
228  QuantizationInfo(0.005300046876072884f, 140), // dwsc7
229  QuantizationInfo(0.0049632852897048f, 127), // dwsc8
230  QuantizationInfo(0.007770895957946777f, 89), // dwsc9
231  QuantizationInfo(0.009658650495111942f, 99), // dwsc10
232  QuantizationInfo(0.005446993745863438f, 153), // dwsc11
233  QuantizationInfo(0.00817922968417406f, 130), // dwsc12
234  QuantizationInfo(0.018048152327537537f, 95) // dwsc13
235  };
236 
237  graph << InputLayer(input_descriptor.set_quantization_info(in_quant_info),
238  get_input_accessor(common_params, nullptr, false))
239  << ConvolutionLayer(
240  3U, 3U, 32U,
241  get_weights_accessor(data_path, "Conv2d_0_weights.npy"),
242  get_weights_accessor(data_path, "Conv2d_0_bias.npy"),
243  PadStrideInfo(2U, 2U, 0U, 1U, 0U, 1U, DimensionRoundingType::FLOOR),
244  1, conv_weights_quant_info.at(0), conv_out_quant_info.at(0))
245  .set_name("Conv2d_0")
247  graph << get_dwsc_node_qasymm(data_path, "Conv2d_1", 64U, PadStrideInfo(1U, 1U, 1U, 1U), PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(0), point_weights_quant_info.at(0));
248  graph << get_dwsc_node_qasymm(data_path, "Conv2d_2", 128U, PadStrideInfo(2U, 2U, 0U, 1U, 0U, 1U, DimensionRoundingType::FLOOR), PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(1),
249  point_weights_quant_info.at(1));
250  graph << get_dwsc_node_qasymm(data_path, "Conv2d_3", 128U, PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U, DimensionRoundingType::FLOOR), PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(2),
251  point_weights_quant_info.at(2));
252  graph << get_dwsc_node_qasymm(data_path, "Conv2d_4", 256U, PadStrideInfo(2U, 2U, 0U, 1U, 0U, 1U, DimensionRoundingType::FLOOR), PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(3),
253  point_weights_quant_info.at(3));
254  graph << get_dwsc_node_qasymm(data_path, "Conv2d_5", 256U, PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U, DimensionRoundingType::FLOOR), PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(4),
255  point_weights_quant_info.at(4));
256  graph << get_dwsc_node_qasymm(data_path, "Conv2d_6", 512U, PadStrideInfo(2U, 2U, 0U, 1U, 0U, 1U, DimensionRoundingType::FLOOR), PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(5),
257  point_weights_quant_info.at(5));
258  graph << get_dwsc_node_qasymm(data_path, "Conv2d_7", 512U, PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U, DimensionRoundingType::FLOOR), PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(6),
259  point_weights_quant_info.at(6));
260  graph << get_dwsc_node_qasymm(data_path, "Conv2d_8", 512U, PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U, DimensionRoundingType::FLOOR), PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(7),
261  point_weights_quant_info.at(7));
262  graph << get_dwsc_node_qasymm(data_path, "Conv2d_9", 512U, PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U, DimensionRoundingType::FLOOR), PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(8),
263  point_weights_quant_info.at(8));
264  graph << get_dwsc_node_qasymm(data_path, "Conv2d_10", 512U, PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U, DimensionRoundingType::FLOOR), PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(9),
265  point_weights_quant_info.at(9));
266  graph << get_dwsc_node_qasymm(data_path, "Conv2d_11", 512U, PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U, DimensionRoundingType::FLOOR), PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(10),
267  point_weights_quant_info.at(10));
268  graph << get_dwsc_node_qasymm(data_path, "Conv2d_12", 1024U, PadStrideInfo(2U, 2U, 0U, 1U, 0U, 1U, DimensionRoundingType::FLOOR), PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(11),
269  point_weights_quant_info.at(11));
270  graph << get_dwsc_node_qasymm(data_path, "Conv2d_13", 1024U, PadStrideInfo(1U, 1U, 1U, 1U, 1U, 1U, DimensionRoundingType::FLOOR), PadStrideInfo(1U, 1U, 0U, 0U), depth_weights_quant_info.at(12),
271  point_weights_quant_info.at(12))
272  << PoolingLayer(PoolingLayerInfo(PoolingType::AVG, common_params.data_layout)).set_name("Logits/AvgPool_1a")
273  << ConvolutionLayer(
274  1U, 1U, 1001U,
275  get_weights_accessor(data_path, "Logits_Conv2d_1c_1x1_weights.npy"),
276  get_weights_accessor(data_path, "Logits_Conv2d_1c_1x1_bias.npy"),
277  PadStrideInfo(1U, 1U, 0U, 0U), 1, conv_weights_quant_info.at(1), conv_out_quant_info.at(1))
278  .set_name("Logits/Conv2d_1c_1x1");
279  }
280 
281  ConcatLayer get_dwsc_node_float(const std::string &data_path, std::string &&param_path,
282  unsigned int conv_filt,
283  PadStrideInfo dwc_pad_stride_info, PadStrideInfo conv_pad_stride_info)
284  {
285  std::string total_path = param_path + "_";
286  SubStream sg(graph);
288  3U, 3U,
289  get_weights_accessor(data_path, total_path + "depthwise_depthwise_weights.npy", DataLayout::NCHW),
290  std::unique_ptr<arm_compute::graph::ITensorAccessor>(nullptr),
291  dwc_pad_stride_info)
292  .set_name(total_path + "depthwise/depthwise")
294  get_weights_accessor(data_path, total_path + "depthwise_BatchNorm_moving_mean.npy"),
295  get_weights_accessor(data_path, total_path + "depthwise_BatchNorm_moving_variance.npy"),
296  get_weights_accessor(data_path, total_path + "depthwise_BatchNorm_gamma.npy"),
297  get_weights_accessor(data_path, total_path + "depthwise_BatchNorm_beta.npy"),
298  0.001f)
299  .set_name(total_path + "depthwise/BatchNorm")
301  << ConvolutionLayer(
302  1U, 1U, conv_filt,
303  get_weights_accessor(data_path, total_path + "pointwise_weights.npy", DataLayout::NCHW),
304  std::unique_ptr<arm_compute::graph::ITensorAccessor>(nullptr),
305  conv_pad_stride_info)
306  .set_name(total_path + "pointwise/Conv2D")
308  get_weights_accessor(data_path, total_path + "pointwise_BatchNorm_moving_mean.npy"),
309  get_weights_accessor(data_path, total_path + "pointwise_BatchNorm_moving_variance.npy"),
310  get_weights_accessor(data_path, total_path + "pointwise_BatchNorm_gamma.npy"),
311  get_weights_accessor(data_path, total_path + "pointwise_BatchNorm_beta.npy"),
312  0.001f)
313  .set_name(total_path + "pointwise/BatchNorm")
315 
316  return ConcatLayer(std::move(sg));
317  }
318 
319  ConcatLayer get_dwsc_node_qasymm(const std::string &data_path, std::string &&param_path,
320  const unsigned int conv_filt,
321  PadStrideInfo dwc_pad_stride_info, PadStrideInfo conv_pad_stride_info,
322  QuantizationInfo depth_weights_quant_info, QuantizationInfo point_weights_quant_info)
323  {
324  std::string total_path = param_path + "_";
325  SubStream sg(graph);
326 
328  3U, 3U,
329  get_weights_accessor(data_path, total_path + "depthwise_weights.npy"),
330  get_weights_accessor(data_path, total_path + "depthwise_bias.npy"),
331  dwc_pad_stride_info, 1, std::move(depth_weights_quant_info))
332  .set_name(total_path + "depthwise/depthwise")
334  << ConvolutionLayer(
335  1U, 1U, conv_filt,
336  get_weights_accessor(data_path, total_path + "pointwise_weights.npy"),
337  get_weights_accessor(data_path, total_path + "pointwise_bias.npy"),
338  conv_pad_stride_info, 1, std::move(point_weights_quant_info))
339  .set_name(total_path + "pointwise/Conv2D")
341 
342  return ConcatLayer(std::move(sg));
343  }
344 };
345 
346 /** Main program for MobileNetV1
347  *
348  * Model is based on:
349  * https://arxiv.org/abs/1704.04861
350  * "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications"
351  * Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam
352  *
353  * Provenance: download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz
354  * download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160.tgz
355  *
356  * @note To list all the possible arguments execute the binary appended with the --help option
357  *
358  * @param[in] argc Number of arguments
359  * @param[in] argv Arguments
360  */
361 int main(int argc, char **argv)
362 {
363  return arm_compute::utils::run_example<GraphMobilenetExample>(argc, argv);
364 }
Graph configuration structure Device target types.
Definition: Types.h:80
Shape of a tensor.
Definition: TensorShape.h:39
CLTunerMode tuner_mode
Tuner mode to be used by the CL tuner.
Definition: Types.h:87
std::unique_ptr< graph::ITensorAccessor > get_input_accessor(const arm_compute::utils::CommonGraphParams &graph_parameters, std::unique_ptr< IPreprocessor > preprocessor=nullptr, bool bgr=true)
Generates appropriate input accessor according to the specified graph parameters. ...
Definition: GraphUtils.h:497
void consume_common_graph_parameters(CommonGraphValidateOptions &options, CommonParams &common_params)
Consumes the common graph options and creates a structure containing any information from them.
Includes all the Graph headers at once.
Common command line options used to configure the graph examples.
Class to parse command line arguments.
Activation Layer Information class.
Definition: Types.h:1550
Copyright (c) 2017-2021 Arm Limited.
std::string mlgo_file
Filename to load MLGO heuristics from.
Definition: Types.h:90
std::string tuner_file
File to load/store tuning values from.
Definition: Types.h:89
Quantization information.
quantized, asymmetric fixed-point 8-bit number unsigned
Pooling Layer Information struct.
Definition: Types.h:1214
Abstract Example class.
Definition: Utils.h:78
Padding and stride information class.
Definition: Types.h:722
int main(int argc, char **argv)
Main program for MobileNetV1.
TensorDescriptor & set_quantization_info(QuantizationInfo tensor_quant_info)
Sets tensor descriptor quantization info.
Num samples, channels, height, width.
TensorShape permute_shape(TensorShape tensor_shape, DataLayout in_data_layout, DataLayout out_data_layout)
Permutes a given tensor shape given the input and output data layout.
Definition: GraphUtils.h:664
TensorDescriptor & set_layout(DataLayout data_layout)
Sets tensor descriptor data layout.
Structure holding all the common graph parameters.
std::unique_ptr< graph::ITensorAccessor > get_output_accessor(const arm_compute::utils::CommonGraphParams &graph_parameters, size_t top_n=5, bool is_validation=false, std::ostream &output_stream=std::cout)
Generates appropriate output accessor according to the specified graph parameters.
Definition: GraphUtils.h:543
bool use_tuner
Use a tuner in tunable backends.
Definition: Types.h:85
std::unique_ptr< graph::ITensorAccessor > get_weights_accessor(const std::string &path, const std::string &data_file, DataLayout file_layout=DataLayout::NCHW)
Generates appropriate weights accessor according to the specified path.
Definition: GraphUtils.h:475
int num_threads
Number of threads to use (thread capable backends), if 0 the backend will auto-initialize, if -1 the backend will stay as it is.
Definition: Types.h:88
Stream frontend class to construct simple graphs in a stream fashion.
Definition: Stream.h:45
ILayer & set_name(std::string name)
Sets the name of the layer.
Definition: ILayer.h:55
void set_help(std::string help)
Set the help message for the option.
Definition: Option.h:125
bool is_data_type_float(DataType dt)
Check if a given data type is of floating point type.
Definition: Utils.h:1148