33 #include "utils/Utils.h"
96 const std::set<ConvolutionPaddingMode> available_padding_modes
102 const std::set<arm_compute::graph::ConvolutionMethod> supported_convolution_methods
110 const std::set<DataLayout> supported_data_layouts
120 padding_mode->
set_help(
"Set padding mode");
121 help->set_help(
"Show this help message");
122 width->set_help(
"Set Input dimension width");
123 height->set_help(
"Set Input dimension height");
124 channels->set_help(
"Set Input dimension channels");
125 batch->set_help(
"Set Input dimension batch");
126 weights_width->set_help(
"Set weights_dimensions width");
127 weights_height->set_help(
"Set weights_dimensions height");
128 OFM->set_help(
"Set OFM");
129 padding_top->set_help(
"Set padding top");
130 padding_bottom->set_help(
"Set padding bottom");
131 padding_left->set_help(
"Set padding left");
132 padding_right->set_help(
"Set padding right");
133 stride_x->set_help(
"Set padding stride x");
134 stride_y->set_help(
"Set padding stride y");
135 conv_mode->
set_help(
"Set convolution method");
136 scale->set_help(
"Quantization scale from QASYMM8");
137 offset->set_help(
"Quantization offset from QASYMM8");
138 weights_scale->set_help(
"Quantization scale from QASYMM8");
139 weights_offset->set_help(
"Quantization offset from QASYMM8");
140 output_scale->set_help(
"Quantization scale from QASYMM8");
141 output_offset->set_help(
"Quantization offset from QASYMM8");
142 input_npy->set_help(
"Use input .npy instead");
143 output_npy->set_help(
"Use .npy as a reference");
144 input_range_low->set_help(
"Lower bound for input randomization range");
145 input_range_high->set_help(
"Lower bound for input randomization range");
146 weights_range_low->set_help(
"Lower bound for input randomization range");
147 weights_range_high->set_help(
"Lower bound for input randomization range");
161 common_params.
input.
fm = channels->value();
164 common_params.
input.
npy = input_npy->value();
171 common_params.
weights.
npy = weights_npy->value();
176 common_params.
bias.
npy = bias_npy->value();
179 common_params.
output.
npy = output_npy->value();
194 void print_parameters(::std::ostream &os,
const ExampleParams &common_params)
override
198 os <<
"Data type : " << common_params.
data_type << std::endl;
199 os <<
"Input dimensions(X,Y, Channels, Batch) : (" << common_params.
input.
width <<
"," << common_params.
input.
height <<
"," << common_params.
input.
fm <<
"," << common_params.
input.
batch <<
")"
201 os <<
"Weight dimensions(X,Y, Channels(same as input), OFM) : (" << common_params.
weights.
width <<
"," << common_params.
weights.
height <<
"," << common_params.
input.
fm <<
"," <<
202 common_params.
weights.
fm <<
")" << std::endl;
211 ConvolutionOptions(
const ConvolutionOptions &) =
delete;
213 ConvolutionOptions &operator=(
const ConvolutionOptions &) =
delete;
215 ConvolutionOptions(ConvolutionOptions &&) noexcept(true) =
default;
217 ConvolutionOptions &operator=(ConvolutionOptions &&) noexcept(true) =
default;
219 ~ConvolutionOptions() override =
default;
256 template <typename D>
260 using BaseClassType::BaseClassType;
261 using BaseClassType::_params;
271 1, _params.output.quant_info);
274 float relative_tolerance()
override
276 const std::map<arm_compute::graph::Target, const std::map<DataType, float>> relative_tolerance
302 return relative_tolerance.at(_params.common_params.target).at(_params.data_type);
306 float absolute_tolerance()
override
308 const std::map<Target, const std::map<DataType, float>> absolute_tolerance
326 return absolute_tolerance.at(_params.common_params.target).at(_params.data_type);
329 float tolerance_number()
override
331 const std::map<Target, const std::map<DataType, float>> absolute_tolerance
349 return absolute_tolerance.at(_params.common_params.target).at(_params.data_type);
355 class GraphConvolutionValidateExample final :
public GraphValidateExample<ConvolutionLayer, ConvolutionOptions, ConvolutionVerifyAccessor>
360 GraphConvolutionValidateExample()
393 int main(
int argc,
char **argv)
395 return arm_compute::utils::run_example<GraphConvolutionValidateExample>(argc, argv);