Compute Library
 22.05
graph_fully_connected.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/graph.h"
25 
26 #include "tests/NEON/Accessor.h"
30 
32 #include "utils/GraphUtils.h"
33 #include "utils/Utils.h"
34 
35 #include "ValidateExample.h"
36 #include "graph_validate_utils.h"
37 
38 #include <utility>
39 
40 using namespace arm_compute::utils;
41 using namespace arm_compute::graph::frontend;
42 using namespace arm_compute::graph_utils;
43 using namespace arm_compute::graph;
44 using namespace arm_compute;
45 using namespace arm_compute::test;
46 using namespace arm_compute::test::validation;
47 
48 namespace
49 {
50 /** Fully connected command line options used to configure the graph examples
51  *
52  * (Similar to common options)
53  * The options in this object get populated when "parse()" is called on the parser used to construct it.
54  * The expected workflow is:
55  *
56  * CommandLineParser parser;
57  * CommonOptions options( parser );
58  * parser.parse(argc, argv);
59  */
60 class FullyConnectedOptions final : public CommonGraphValidateOptions
61 {
62 public:
63  explicit FullyConnectedOptions(CommandLineParser &parser) noexcept
65  width(parser.add_option<SimpleOption<int>>("width", 3)),
66  batch(parser.add_option<SimpleOption<int>>("batch", 1)),
67  input_scale(parser.add_option<SimpleOption<float>>("input_scale", 1.0f)),
68  input_offset(parser.add_option<SimpleOption<int>>("input_offset", 0)),
69  weights_scale(parser.add_option<SimpleOption<float>>("weights_scale", 1.0f)),
70  weights_offset(parser.add_option<SimpleOption<int>>("weights_offset", 0)),
71  output_scale(parser.add_option<SimpleOption<float>>("output_scale", 1.0f)),
72  output_offset(parser.add_option<SimpleOption<int>>("output_offset", 0)),
73  num_outputs(parser.add_option<SimpleOption<int>>("num_outputs", 1)),
74  input_range_low(parser.add_option<SimpleOption<uint64_t>>("input_range_low")),
75  input_range_high(parser.add_option<SimpleOption<uint64_t>>("input_range_high")),
76  weights_range_low(parser.add_option<SimpleOption<uint64_t>>("weights_range_low")),
77  weights_range_high(parser.add_option<SimpleOption<uint64_t>>("weights_range_high"))
78  {
79  width->set_help("Set Input dimension width");
80  batch->set_help("Set Input dimension batch");
81  input_scale->set_help("Quantization scale from QASYMM8");
82  input_offset->set_help("Quantization offset from QASYMM8");
83  weights_scale->set_help("Quantization scale from QASYMM8");
84  weights_offset->set_help("Quantization offset from QASYMM8");
85  output_scale->set_help("Quantization scale from QASYMM8");
86  output_offset->set_help("Quantization offset from QASYMM8");
87  num_outputs->set_help("Number of outputs.");
88  input_range_low->set_help("Lower bound for input randomization range");
89  input_range_high->set_help("Lower bound for input randomization range");
90  weights_range_low->set_help("Lower bound for input randomization range");
91  weights_range_high->set_help("Lower bound for input randomization range");
92  }
93 
94  /** Fill out the supplied parameters with user supplied parameters
95  *
96  * @param[out] os Output stream.
97  * @param[in] common_params Example parameters to output
98  *
99  * @return None.
100  */
101  void consume_parameters(ExampleParams &common_params)
102  {
103  common_params.input.width = width->value();
104  common_params.input.batch = batch->value();
105  common_params.input.quant_info = QuantizationInfo(input_scale->value(), input_offset->value());
106  common_params.input.range_low = input_range_low->value();
107  common_params.input.range_high = input_range_high->value();
108 
109  common_params.weights.quant_info = QuantizationInfo(weights_scale->value(), weights_offset->value());
110  common_params.weights.range_low = weights_range_low->value();
111  common_params.weights.range_high = weights_range_high->value();
112 
113  common_params.output.quant_info = QuantizationInfo(output_scale->value(), output_offset->value());
114 
115  common_params.data_type = data_type->value();
116  common_params.fully_connected.num_outputs = num_outputs->value();
117  }
118 
119  void print_parameters(::std::ostream &os, const ExampleParams &common_params) override
120  {
121  os << "Threads : " << common_params.common_params.threads << std::endl;
122  os << "Target : " << common_params.common_params.target << std::endl;
123  os << "Data type : " << common_params.data_type << std::endl;
124  os << "Input dimensions(X,Y, Channels, Batch) : (" << common_params.input.width << "," << common_params.input.height << "," << common_params.input.fm << "," << common_params.input.batch << ")"
125  << std::endl;
126  os << "Number of outputs : " << common_params.fully_connected.num_outputs << std::endl;
127  }
128 
129  /** Prevent instances of this class from being copied (As this class contains pointers) */
130  FullyConnectedOptions(const FullyConnectedOptions &) = delete;
131  /** Prevent instances of this class from being copied (As this class contains pointers) */
132  FullyConnectedOptions &operator=(const FullyConnectedOptions &) = delete;
133  /** Allow instances of this class to be moved */
134  FullyConnectedOptions(FullyConnectedOptions &&) noexcept(true) = default;
135  /** Allow instances of this class to be moved */
136  FullyConnectedOptions &operator=(FullyConnectedOptions &&) noexcept(true) = default;
137  /** Default destructor */
138  ~FullyConnectedOptions() override = default;
139 
140 private:
141  SimpleOption<int> *width; /**< Input width */
142  SimpleOption<int> *batch; /**< Input batch */
143  SimpleOption<float> *input_scale; /**< Input Quantization scale from QASSYMM8 */
144  SimpleOption<int> *input_offset; /**< Input Quantization offset from QASSYMM8 */
145  SimpleOption<float> *weights_scale; /**< Weights Quantization scale from QASSYMM8 */
146  SimpleOption<int> *weights_offset; /**< Weights Quantization offset from QASSYMM8 */
147  SimpleOption<float> *output_scale; /**< Output Quantization scale from QASSYMM8 */
148  SimpleOption<int> *output_offset; /**< Output Quantization offset from QASSYMM8 */
149  SimpleOption<int> *num_outputs; /**< Number of outputs. */
150  SimpleOption<uint64_t> *input_range_low; /**< Lower bound for input randomization range */
151  SimpleOption<uint64_t> *input_range_high; /**< Upper bound for input randomization range */
152  SimpleOption<uint64_t> *weights_range_low; /**< Lower bound for weights randomization range */
153  SimpleOption<uint64_t> *weights_range_high; /**< Upper bound for weights randomization range */
154 };
155 
156 /** Fully Connected Layer Graph example validation accessor class */
157 template <typename D>
158 class FullyConnectedVerifyAccessor final : public VerifyAccessor<D>
159 {
160  using BaseClassType = VerifyAccessor<D>;
161  using BaseClassType::BaseClassType;
162  using BaseClassType::_params;
163  using TBias = typename std::conditional<std::is_same<typename std::decay<D>::type, uint8_t>::value, int32_t, D>::type;
164 
165  // Inherited methods overriden:
166  void create_tensors(arm_compute::test::SimpleTensor<D> &src,
169  ITensor &tensor) override
170  {
171  // Calculate Tensor shapes for verification
172  const TensorShape input_shape = TensorShape(_params.input.width, _params.input.height, _params.input.fm, _params.input.batch);
173  const TensorDescriptor input_descriptor = TensorDescriptor(input_shape, _params.data_type, _params.input.quant_info);
174  const TensorDescriptor weights_descriptor = FullyConnectedLayerNode::compute_weights_descriptor(input_descriptor,
175  _params.fully_connected.num_outputs,
176  _params.fully_connected.info,
177  _params.weights.quant_info);
178  const TensorDescriptor output_desciptor = FullyConnectedLayerNode::compute_output_descriptor(input_descriptor, _params.fully_connected.num_outputs, _params.output.quant_info);
179 
180  //Create Input tensors
181  src = SimpleTensor<D> { input_descriptor.shape, _params.data_type, 1, input_descriptor.quant_info };
182  weights = SimpleTensor<D> { weights_descriptor.shape, _params.data_type, 1, weights_descriptor.quant_info };
183  bias = SimpleTensor<TBias> { TensorShape(tensor.info()->tensor_shape().x()), _params.data_type, 1, _params.input.quant_info };
184  }
185 
186  TensorShape output_shape(ITensor &tensor) override
187  {
188  ARM_COMPUTE_UNUSED(tensor);
189 
190  const TensorShape input_shape = TensorShape(_params.input.width, _params.input.height, _params.input.fm, _params.input.batch);
191  const TensorDescriptor input_descriptor = TensorDescriptor(input_shape, _params.data_type, _params.input.quant_info);
192  const TensorDescriptor output_desciptor = FullyConnectedLayerNode::compute_output_descriptor(input_descriptor, _params.fully_connected.num_outputs, _params.output.quant_info);
193 
194  return output_desciptor.shape;
195  }
196 
200  const arm_compute::TensorShape &output_shape) override
201  {
202  return reference::fully_connected_layer<D>(src, weights, bias, output_shape, _params.output.quant_info);
203  }
204 
205  float relative_tolerance() override
206  {
207  const std::map<arm_compute::graph::Target, const std::map<DataType, float>> relative_tolerance
208  {
209  {
211  { { DataType::F16, 0.2f },
212  { DataType::F32, 0.05f },
213  { DataType::QASYMM8, 1.0f }
214  }
215  },
216  {
218  { { DataType::F16, 0.2f },
219  { DataType::F32, 0.01f },
220  { DataType::QASYMM8, 1.0f }
221  }
222  }
223  };
224 
225  return relative_tolerance.at(_params.common_params.target).at(_params.data_type);
226  }
227 
228  float absolute_tolerance() override
229  {
230  const std::map<Target, const std::map<DataType, float>> absolute_tolerance
231  {
232  {
233  Target::CL,
234  { { DataType::F16, 0.0f },
235  { DataType::F32, 0.0001f },
236  { DataType::QASYMM8, 1.0f }
237  }
238  },
239  {
240  Target::NEON,
241  { { DataType::F16, 0.3f },
242  { DataType::F32, 0.1f },
243  { DataType::QASYMM8, 1.0f }
244  }
245  }
246  };
247 
248  return absolute_tolerance.at(_params.common_params.target).at(_params.data_type);
249  }
250 
251  float tolerance_number() override
252  {
253  const std::map<Target, const std::map<DataType, float>> absolute_tolerance
254  {
255  {
256  Target::CL,
257  { { DataType::F16, 0.07f },
258  { DataType::F32, 0.07f },
259  { DataType::QASYMM8, 0.0f }
260  }
261  },
262  {
263  Target::NEON,
264  { { DataType::F16, 0.07f },
265  { DataType::F32, 0.0f },
266  { DataType::QASYMM8, 0.0f }
267  }
268  }
269  };
270 
271  return absolute_tolerance.at(_params.common_params.target).at(_params.data_type);
272  }
273 };
274 
275 } // namespace
276 
277 class GraphFullyConnectedValidateExample final : public GraphValidateExample<FullyConnectedLayer, FullyConnectedOptions, FullyConnectedVerifyAccessor>
278 {
280 
281 public:
282  GraphFullyConnectedValidateExample()
283  : GraphValidateExample("Fully_connected Graph example")
284  {
285  }
286 
287  FullyConnectedLayer GraphFunctionLayer(ExampleParams &params) override
288  {
289  const PixelValue lower = PixelValue(params.input.range_low, params.data_type, params.input.quant_info);
290  const PixelValue upper = PixelValue(params.input.range_high, params.data_type, params.input.quant_info);
291 
292  const PixelValue weights_lower = PixelValue(params.weights.range_low, params.data_type, params.weights.quant_info);
293  const PixelValue weights_upper = PixelValue(params.weights.range_high, params.data_type, params.weights.quant_info);
294 
296  get_random_accessor(weights_lower, weights_upper, 1),
297  get_random_accessor(lower, upper, 2),
298  params.fully_connected.info, params.weights.quant_info, params.output.quant_info);
299  }
300 };
301 
302 /** Main program for Graph fully_connected test
303  *
304  * @param[in] argc Number of arguments
305  * @param[in] argv Arguments ( Input dimensions [width, batch]
306  * Fully connected [num_outputs,type]
307  * Verification[tolerance_number,absolute_tolerance,relative_tolerance] )
308  *
309  */
310 int main(int argc, char **argv)
311 {
312  return arm_compute::utils::run_example<GraphFullyConnectedValidateExample>(argc, argv);
313 }
Arm® Neon™ capable target device.
Class describing the value of a pixel for any image format.
Definition: PixelValue.h:34
Shape of a tensor.
Definition: TensorShape.h:39
1 channel, 1 F32 per channel
CommonGraphValidateOptions command line options used to configure the graph examples.
Includes all the Graph headers at once.
Class to parse command line arguments.
decltype(strategy::transforms) typedef type
Interface for CPU tensor.
Definition: ITensor.h:36
SimpleTensor< float > src
Definition: DFT.cpp:155
Copyright (c) 2017-2022 Arm Limited.
std::unique_ptr< graph::ITensorAccessor > get_random_accessor(PixelValue lower, PixelValue upper, const std::random_device::result_type seed=0)
Generates appropriate random accessor.
Definition: GraphUtils.h:461
1 channel, 1 F16 per channel
T x() const
Alias to access the size of the first dimension.
Definition: Dimensions.h:87
Quantization information.
QuantizationInfo quant_info
Quantization info.
const auto input_shape
Validate test suite is to test ARM_COMPUTE_RETURN_ON_* macros we use to check the validity of given a...
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
Definition: Error.h:152
virtual const TensorShape & tensor_shape() const =0
Size for each dimension of the tensor.
quantized, asymmetric fixed-point 8-bit number unsigned
Structure holding all the graph Example parameters.
virtual ITensorInfo * info() const =0
Interface to be implemented by the child class to return the tensor's metadata.
Simple tensor object that stores elements in a consecutive chunk of memory.
Definition: SimpleTensor.h:58
Graph example validation accessor class.
arm_compute::graph::frontend::Stream graph
OpenCL capable target device.
int main(int argc, char **argv)
Main program for Graph fully_connected test.
const int32_t * bias