33 #include "utils/Utils.h" 96 const std::set<ConvolutionPaddingMode> available_padding_modes
102 const std::set<arm_compute::graph::ConvolutionMethod> supported_convolution_methods
110 const std::set<DataLayout> supported_data_layouts
120 padding_mode->
set_help(
"Set padding mode");
121 help->set_help(
"Show this help message");
122 width->set_help(
"Set Input dimension width");
123 height->set_help(
"Set Input dimension height");
124 channels->set_help(
"Set Input dimension channels");
125 batch->set_help(
"Set Input dimension batch");
128 OFM->set_help(
"Set OFM");
129 padding_top->set_help(
"Set padding top");
130 padding_bottom->set_help(
"Set padding bottom");
131 padding_left->set_help(
"Set padding left");
132 padding_right->set_help(
"Set padding right");
133 stride_x->set_help(
"Set padding stride x");
134 stride_y->set_help(
"Set padding stride y");
135 conv_mode->
set_help(
"Set convolution method");
136 scale->set_help(
"Quantization scale from QASYMM8");
137 offset->set_help(
"Quantization offset from QASYMM8");
138 weights_scale->set_help(
"Quantization scale from QASYMM8");
139 weights_offset->set_help(
"Quantization offset from QASYMM8");
140 output_scale->set_help(
"Quantization scale from QASYMM8");
141 output_offset->set_help(
"Quantization offset from QASYMM8");
142 input_npy->set_help(
"Use input .npy instead");
143 output_npy->set_help(
"Use .npy as a reference");
144 input_range_low->set_help(
"Lower bound for input randomization range");
145 input_range_high->set_help(
"Lower bound for input randomization range");
146 weights_range_low->set_help(
"Lower bound for input randomization range");
147 weights_range_high->set_help(
"Lower bound for input randomization range");
161 common_params.
input.
fm = channels->value();
164 common_params.
input.
npy = input_npy->value();
171 common_params.
weights.
npy = weights_npy->value();
176 common_params.
bias.
npy = bias_npy->value();
179 common_params.
output.
npy = output_npy->value();
194 void print_parameters(::std::ostream &os,
const ExampleParams &common_params)
override 198 os <<
"Data type : " << common_params.
data_type << std::endl;
199 os <<
"Input dimensions(X,Y, Channels, Batch) : (" << common_params.
input.
width <<
"," << common_params.
input.
height <<
"," << common_params.
input.
fm <<
"," << common_params.
input.
batch <<
")" 201 os <<
"Weight dimensions(X,Y, Channels(same as input), OFM) : (" << common_params.
weights.
width <<
"," << common_params.
weights.
height <<
"," << common_params.
input.
fm <<
"," <<
202 common_params.
weights.
fm <<
")" << std::endl;
211 ConvolutionOptions(
const ConvolutionOptions &) =
delete;
213 ConvolutionOptions &operator=(
const ConvolutionOptions &) =
delete;
215 ConvolutionOptions(ConvolutionOptions &&) noexcept(true) =
default;
217 ConvolutionOptions &operator=(ConvolutionOptions &&) noexcept(true) = default;
219 ~ConvolutionOptions() override = default;
256 template <typename D>
260 using BaseClassType::BaseClassType;
261 using BaseClassType::_params;
271 1, _params.output.quant_info);
274 float relative_tolerance()
override 276 const std::map<arm_compute::graph::Target, const std::map<DataType, float>> relative_tolerance
302 return relative_tolerance.at(_params.common_params.target).at(_params.data_type);
306 float absolute_tolerance()
override 308 const std::map<Target, const std::map<DataType, float>> absolute_tolerance
326 return absolute_tolerance.at(_params.common_params.target).at(_params.data_type);
329 float tolerance_number()
override 331 const std::map<Target, const std::map<DataType, float>> absolute_tolerance
349 return absolute_tolerance.at(_params.common_params.target).at(_params.data_type);
355 class GraphConvolutionValidateExample final :
public GraphValidateExample<ConvolutionLayer, ConvolutionOptions, ConvolutionVerifyAccessor>
360 GraphConvolutionValidateExample()
393 int main(
int argc,
char **argv)
395 return arm_compute::utils::run_example<GraphConvolutionValidateExample>(argc, argv);
int padding_top
Padding graph parameters.
PadStrideInfo calculate_convolution_padding(ExampleParams params)
Calculate stride information.
__global uchar * offset(const Image *img, int x, int y)
Get the pointer position of a Image.
Arm® Neon™ capable target device.
Class describing the value of a pixel for any image format.
arm_compute::graph::ConvolutionMethod convolution_method
ConvolutionParams convolution
1 channel, 1 F32 per channel
CommonGraphValidateOptions command line options used to configure the graph examples.
Includes all the Graph headers at once.
Class to parse command line arguments.
ConvolutionMethod
Available ConvolutionMethod.
decltype(strategy::transforms) typedef type
std::unique_ptr< graph::ITensorAccessor > get_accessor(const TensorParams &tensor, PixelValue lower, PixelValue upper, const std::random_device::result_type seed=0)
Generates appropriate accessor according to the specified graph parameters.
SimpleTensor< float > src
Copyright (c) 2017-2023 Arm Limited.
arm_compute::graph::Target target
1 channel, 1 F16 per channel
Quantization information.
arm_compute::DataType data_type
quantized, asymmetric fixed-point 8-bit number unsigned
Structure holding all the graph Example parameters.
Padding and stride information class.
Num samples, channels, height, width.
Simple tensor object that stores elements in a consecutive chunk of memory.
Graph example validation accessor class.
Class for specifying the size of an image or rectangle.
Num samples, height, width, channels.
Implementation of a simple option that accepts a value from a fixed set.
ConvolutionPaddingMode padding_mode
const size_t weights_width
const size_t weights_height
Default approach using internal heuristics.
Winograd based convolution.
FrameworkParams common_params
const T & value() const
Get the selected value.
arm_compute::graph::frontend::Stream graph
OpenCL capable target device.
arm_compute::DataLayout data_layout
DataLayout
[DataLayout enum definition]
int main(int argc, char **argv)
Main program for Graph Convolution test.
void set_help(std::string help)
Set the help message for the option.
QuantizationInfo quant_info