33 #include "utils/Utils.h"
// Set of ConvolutionPaddingMode values; the initializer is outside this excerpt.
// NOTE(review): presumably the padding modes offered on the command line - confirm.
96 const std::set<ConvolutionPaddingMode> available_padding_modes
// Set of graph DepthwiseConvolutionMethod values; the initializer is outside this excerpt.
// NOTE(review): presumably the methods selectable through the convolution-method option - confirm.
102 const std::set<arm_compute::graph::DepthwiseConvolutionMethod> supported_convolution_methods
// Set of DataLayout values; the initializer is outside this excerpt.
// NOTE(review): presumably the layouts selectable through a data-layout option - confirm.
109 const std::set<DataLayout> supported_data_layouts
// Help text attached to each command-line option.
// Each string has to describe the option it is set on: the stride options are
// not padding options, the *_range_high options are upper bounds, and the
// weights_range_* options apply to the weights rather than to the input.
120 padding_mode->set_help("Set padding mode");
// Input tensor dimensions.
121 width->set_help("Set Input dimension width");
122 height->set_help("Set Input dimension height");
123 channels->set_help("Set Input dimension channels");
124 batch->set_help("Set Input dimension batch");
// Weights tensor dimensions.
125 weights_width->set_help("Set weights_dimensions width");
126 weights_height->set_help("Set weights_dimensions height");
// Padding on each border.
127 padding_top->set_help("Set padding top");
128 padding_bottom->set_help("Set padding bottom");
129 padding_left->set_help("Set padding left");
130 padding_right->set_help("Set padding right");
// Strides (previously described as "padding stride").
131 stride_x->set_help("Set stride x");
132 stride_y->set_help("Set stride y");
133 conv_mode->set_help("Set convolution method");
// Quantization parameters.
135 scale->set_help("Quantization scale from QASYMM8");
136 offset->set_help("Quantization offset from QASYMM8");
137 output_scale->set_help("Quantization scale from QASYMM8");
138 output_offset->set_help("Quantization offset from QASYMM8");
// .npy inputs/references.
139 input_npy->set_help("Use input .npy instead");
140 output_npy->set_help("Use .npy as a reference");
// Input randomization range (low/high are distinct bounds).
141 input_range_low->set_help("Lower bound for input randomization range");
142 input_range_high->set_help("Upper bound for input randomization range");
143 weights_scale->set_help("Quantization scale from QASYMM8");
144 weights_offset->set_help("Quantization offset from QASYMM8");
// Weights randomization range (these options are about the weights, not the input).
145 weights_range_low->set_help("Lower bound for weights randomization range");
146 weights_range_high->set_help("Upper bound for weights randomization range");
147 depth_multiplier->set_help("Depth multiplier");
// Copies parsed option values into common_params. Only some of the
// assignments are visible in this excerpt.
// The input's fm field is taken from the `channels` option.
161 common_params.
input.
fm = channels->value();
// The npy field of the input, weights, bias and output entries is taken from
// the option of the matching name.
164 common_params.
input.
npy = input_npy->value();
170 common_params.
weights.
npy = weights_npy->value();
175 common_params.
bias.
npy = bias_npy->value();
178 common_params.
output.
npy = output_npy->value();
/** Write the example parameters to an output stream (overrides a base-class method).
 *
 * @param[in,out] os            Stream that receives the text.
 * @param[in]     common_params Parameters whose values are printed.
 */
194 void print_parameters(::std::ostream &os,
const ExampleParams &common_params)
override
// Data type line, ended with std::endl.
198 os <<
"Data type : " << common_params.
data_type << std::endl;
// Input dimensions, printed in the order width, height, fm, batch.
// The rest of this statement is outside this excerpt.
199 os <<
"Input dimensions(X,Y, Channels, Batch) : (" << common_params.
input.
width <<
"," << common_params.
input.
height <<
"," << common_params.
input.
fm <<
"," << common_params.
input.
batch <<
")"
// Weight dimensions: weights width and height, followed by the input's fm as
// the channel count (the label says "same as input").
// The rest of this statement is outside this excerpt.
201 os <<
"Weight dimensions(X,Y, Channels(same as input)) : (" << common_params.
weights.
width <<
"," << common_params.
weights.
height <<
"," << common_params.
input.
fm <<
","
/** Copy construction is disabled. */
212 DepthConvolutionOptions(const DepthConvolutionOptions &) = delete;
/** Copy assignment is disabled. */
214 DepthConvolutionOptions &operator=(const DepthConvolutionOptions &) = delete;
/** Move construction uses the compiler-generated implementation and is declared non-throwing. */
216 DepthConvolutionOptions(DepthConvolutionOptions &&) noexcept = default;
/** Move assignment uses the compiler-generated implementation and is declared non-throwing. */
218 DepthConvolutionOptions &operator=(DepthConvolutionOptions &&) noexcept = default;
/** Destructor uses the compiler-generated implementation and overrides the base-class destructor. */
220 ~DepthConvolutionOptions() override = default;
/** Verification accessor for the depthwise convolution example.
 *
 * Final class derived from VerifyAccessor<D>. Only fragments of its body are
 * visible in this excerpt.
 *
 * @tparam D Template argument forwarded to VerifyAccessor.
 */
257 template <typename D>
258 class DepthConvolutionVerifyAccessor final : public
VerifyAccessor<D>
// Inherit BaseClassType's constructors and bring its _params member into scope.
262 using BaseClassType::BaseClassType;
263 using BaseClassType::_params;
// Trailing arguments of a call whose beginning is outside this excerpt.
274 _params.convolution.depth_multiplier,
276 _params.output.quant_info);
/** Relative tolerance for the configured target and data type.
 *
 * Looked up with std::map::at, which throws std::out_of_range if the target
 * or data type has no entry.
 */
279 float relative_tolerance()
override
// Table keyed by target, then by data type; its initializer is outside this excerpt.
281 const std::map<arm_compute::graph::Target, const std::map<DataType, float>> relative_tolerance
299 return relative_tolerance.at(_params.common_params.target).at(_params.data_type);
/** Absolute tolerance for the configured target and data type.
 *
 * Looked up with std::map::at, which throws std::out_of_range if the target
 * or data type has no entry.
 */
302 float absolute_tolerance()
override
// Table keyed by target, then by data type; its initializer is outside this excerpt.
304 const std::map<Target, const std::map<DataType, float>> absolute_tolerance
322 return absolute_tolerance.at(_params.common_params.target).at(_params.data_type);
/** Tolerance number for the configured target and data type.
 *
 * Looked up with std::map::at, which throws std::out_of_range if the target
 * or data type has no entry.
 */
325 float tolerance_number()
override
// NOTE(review): this local table is also named absolute_tolerance, the same
// name used inside absolute_tolerance(); looks like a copy-paste name. The
// initializer is outside this excerpt - confirm it holds tolerance numbers.
327 const std::map<Target, const std::map<DataType, float>> absolute_tolerance
345 return absolute_tolerance.at(_params.common_params.target).at(_params.data_type);
/** Validation example for the depthwise convolution layer.
 *
 * Final class that instantiates GraphValidateExample with
 * DepthwiseConvolutionLayer, DepthConvolutionOptions and
 * DepthConvolutionVerifyAccessor.
 */
351 class GraphDepthwiseConvolutionValidateExample final :
public GraphValidateExample<DepthwiseConvolutionLayer, DepthConvolutionOptions, DepthConvolutionVerifyAccessor>
// Default constructor; its initializer list and body are outside this excerpt.
356 GraphDepthwiseConvolutionValidateExample()
/** Program entry point.
 *
 * @param[in] argc Number of command-line arguments.
 * @param[in] argv Command-line argument values.
 *
 * @return The value returned by arm_compute::utils::run_example for
 *         GraphDepthwiseConvolutionValidateExample.
 */
389 int main(
int argc,
char **argv)
// Forward the command line to run_example, which is instantiated for this example class.
391 return arm_compute::utils::run_example<GraphDepthwiseConvolutionValidateExample>(argc, argv);