30 #include "tests/datasets/DirectConvolutionLayerDataset.h"
31 #include "tests/datasets/ShapeDatasets.h"
36 #include "tests/validation/fixtures/DirectConvolutionLayerFixture.h"
54 constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1);
65 const auto data_all_kernels =
concat(
concat(data_ksize_one, data_ksize_three), data_ksize_five);
67 const auto data =
combine(datasets::SmallDirectConvolutionShapes(),
combine(data_strides, data_all_kernels));
68 const auto data9x9 =
combine(datasets::SmallDirectConvolutionShapes(),
combine(data_strides, data_ksize_nine));
69 const auto data_small =
combine(datasets::SmallDirectConvolutionShapes(),
combine(data_strides_small, data_ksize_one_small));
70 const auto data_small9x9 =
combine(datasets::SmallDirectConvolutionShapes(),
combine(data_strides_small, data_ksize_nine_small));
89 { ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 0.5f) });
106 auto src = create_tensor<CLTensor>(src_shape,
dt);
107 auto weights = create_tensor<CLTensor>(weights_shape,
dt);
116 src.allocator()->allocate();
117 weights.allocator()->allocate();
118 dst.allocator()->allocate();
129 library->fill_tensor_value(ref_src, 1.f);
130 library->fill_tensor_value(ref_weights, 1.f);
132 library->fill_tensor_value(ref_bias, 0.f);
133 auto ref_dst = reference::convolution_layer<float>(ref_src, ref_weights, ref_bias,
dst_shape,
conv_info);
167 src.allocator()->allocate();
168 weights.allocator()->allocate();
169 dst.allocator()->allocate();
180 library->fill_tensor_value(ref_src, 1.f);
181 library->fill_tensor_value(ref_weights, 1.f);
183 library->fill_tensor_value(ref_bias, 0.f);
184 auto ref_dst = reference::convolution_layer<float>(ref_src, ref_weights, ref_bias,
dst_shape,
conv_info);
224 PadStrideInfo(1, 1, 0, 0),
225 PadStrideInfo(1, 1, 0, 0),
226 PadStrideInfo(1, 1, 0, 0),
227 PadStrideInfo(1, 1, 0, 0),
228 PadStrideInfo(1, 1, 0, 0),
229 PadStrideInfo(1, 1, 0, 0),
233 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
234 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
235 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
236 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
237 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
238 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
239 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)
250 template <
typename T>
252 template <
typename T>
254 template <
typename T>
256 template <
typename T>
258 template <
typename T>
260 template <
typename T>
267 framework::dataset::
make("InputInfo", {
288 PadStrideInfo(1, 1, 0, 0),
289 PadStrideInfo(1, 1, 0, 0),
290 PadStrideInfo(3, 3, 0, 0),
294 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
295 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
296 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
334 framework::dataset::make(
"ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::IDENTITY) )),
389 framework::dataset::make(
"ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::IDENTITY) )),
415 validate(CLAccessor(_target), _reference, tolerance_qasymm8);
434 validate(CLAccessor(_target), _reference, tolerance_qasymm8);
450 validate(CLAccessor(_target), _reference, tolerance_qasymm8);
472 validate(CLAccessor(_target), _reference, tolerance_qasymm8);
491 validate(CLAccessor(_target), _reference, tolerance_qasymm8);
507 validate(CLAccessor(_target), _reference, tolerance_qasymm8);
515 framework::dataset::
make("InputInfo", {
536 PadStrideInfo(1, 1, 0, 0),
537 PadStrideInfo(1, 1, 0, 0),
538 PadStrideInfo(3, 3, 0, 0)
542 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
543 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
544 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)
644 validate(CLAccessor(_target), _reference, tolerance_qasymm8);
662 validate(CLAccessor(_target), _reference, tolerance_qasymm8);
682 validate(CLAccessor(_target), _reference, tolerance_qasymm8);
700 validate(CLAccessor(_target), _reference, tolerance_qasymm8);
720 validate(CLAccessor(_target), _reference, tolerance_qasymm8);
723 combine(datasets::DirectConvolutionLayerDataset(),
733 combine(datasets::DirectConvolutionLayerDataset(),
740 validate(CLAccessor(_target), _reference, tolerance_qasymm8);
762 validate(CLAccessor(_target), _reference, tolerance_qasymm8);
782 validate(CLAccessor(_target), _reference, tolerance_qasymm8);
802 validate(CLAccessor(_target), _reference, tolerance_qasymm8);
806 combine(datasets::DirectConvolutionLayerDataset(),
817 combine(datasets::DirectConvolutionLayerDataset(),
824 validate(CLAccessor(_target), _reference, tolerance_qasymm8);