35 #include "tests/datasets/ShapeDatasets.h"
40 #include "tests/validation/fixtures/GEMMFixture.h"
62 using CLGEMMMatrixMultiplyReshapedFixture = GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;
67 GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;
71 using CLGEMMMatrixMultiplyReshaped3DFixture = GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;
76 GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;
// Absolute tolerance handed to validate() as its last argument, together with
// rel_tolerance_f16_mixed_precision and a 0.f middle argument, in the checks
// later in this file.
// NOTE(review): presumably applies to FP16 mixed-precision results - confirm
// against the enclosing test suites, which are not visible from this excerpt.
86 constexpr
float abs_tolerance_f16_mixed_precision(0.01f);
// Absolute tolerance handed to validate() as its last argument, together with
// rel_tolerance_f16 and a 0.f middle argument, in the checks later in this file.
// NOTE(review): presumably applies to plain FP16 results - confirm against the
// enclosing test suites, which are not visible from this excerpt.
89 constexpr
float abs_tolerance_f16(0.01f);
225 GEMMLHSMatrixInfo(4,4,1,
false,
true),
226 GEMMLHSMatrixInfo(4,4,1,
false,
true),
227 GEMMLHSMatrixInfo(4,4,1,
false,
true),
228 GEMMLHSMatrixInfo(4,2,4,
false,
false),
229 GEMMLHSMatrixInfo(4,2,4,
false,
false),
230 GEMMLHSMatrixInfo(4,4,1,
false,
true),
231 GEMMLHSMatrixInfo(4,4,1,
false,
true),
232 GEMMLHSMatrixInfo(4,4,1,
false,
true),
236 GEMMRHSMatrixInfo(4,4,1,
true,
true,
false),
237 GEMMRHSMatrixInfo(4,4,1,
true,
true,
false),
238 GEMMRHSMatrixInfo(4,4,1,
true,
true,
false),
239 GEMMRHSMatrixInfo(2,2,1,
true,
false,
false),
240 GEMMRHSMatrixInfo(2,2,1,
true,
false,
false),
241 GEMMRHSMatrixInfo(4,4,1,
true,
true,
false),
242 GEMMRHSMatrixInfo(4,4,1,
true,
true,
false),
243 GEMMRHSMatrixInfo(4,4,2,
true,
false,
false),
257 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
260 GEMMLHSMatrixInfo(4,4,1,
false,
true),
261 GEMMRHSMatrixInfo(4,4,1,
true,
true,
false),
272 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
275 GEMMLHSMatrixInfo(4,4,1,
false,
true),
276 GEMMRHSMatrixInfo(4,4,1,
true,
true,
false),
290 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
293 GEMMLHSMatrixInfo(4,4,1,
false,
true),
294 GEMMRHSMatrixInfo(4,4,1,
true,
true,
false),
306 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
309 GEMMLHSMatrixInfo(4,4,1,
false,
true),
310 GEMMRHSMatrixInfo(4,4,1,
true,
true,
false),
321 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
324 GEMMLHSMatrixInfo(4,4,1,
false,
true),
325 GEMMRHSMatrixInfo(4,4,2,
true,
false,
false),
330 input0_info ,input1_info, input2_info,
output_info, lhs_info, rhs_info, gemm_info,
expected)
333 &input1_info.clone()->set_is_resizable(
true),
334 &input2_info.clone()->set_is_resizable(
true),
335 &
output_info.clone()->set_is_resizable(
true),1.f,1.f,
353 v0_values_precommit),
354 h0_values_precommit),
357 framework::dataset::
make("export_to_cl_image_rhs", false)),
360 beta_values_precommit),
361 broadcast_bias_values),
362 lhs_transpose_values),
393 beta_values_nightly),
394 broadcast_bias_values),
395 lhs_transpose_values),
420 v0_values_precommit),
421 h0_values_precommit),
427 beta_values_precommit),
428 lhs_transpose_values),
460 beta_values_nightly),
461 lhs_transpose_values),
508 GEMMLHSMatrixInfo(4, 4, 1,
false,
true),
509 GEMMLHSMatrixInfo(4, 8, 1,
false,
true),
510 GEMMLHSMatrixInfo(4, 4, 1,
false,
true),
511 GEMMLHSMatrixInfo(4, 2, 1,
false,
false),
512 GEMMLHSMatrixInfo(4, 4, 1,
false,
false),
516 GEMMRHSMatrixInfo(4, 4, 1,
true,
true,
true),
517 GEMMRHSMatrixInfo(4, 8, 1,
true,
true,
true),
518 GEMMRHSMatrixInfo(8, 4, 1,
true,
true,
true),
519 GEMMRHSMatrixInfo(4, 2, 1,
true,
false,
true),
520 GEMMRHSMatrixInfo(2, 4, 1,
true,
false,
true),
529 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
543 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
557 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
572 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
586 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
599 input0_info ,input1_info, input2_info,
output_info, lhs_info, rhs_info, gemm_info,
expected)
602 &input1_info.clone()->set_is_resizable(
true),
603 &input2_info.clone()->set_is_resizable(
true),
604 &
output_info.clone()->set_is_resizable(
true),1.f,1.f,
619 v0_values_precommit),
620 h0_values_precommit),
626 beta_values_precommit),
627 broadcast_bias_values),
628 lhs_transpose_values),
651 n0_export_to_cl_image_values_nightly),
652 k0_export_to_cl_image_values_nightly),
660 beta_values_nightly),
661 broadcast_bias_values),
662 lhs_transpose_values),
687 v0_values_precommit),
688 h0_values_precommit),
694 beta_values_precommit),
695 lhs_transpose_values),
718 n0_export_to_cl_image_values_nightly),
719 k0_export_to_cl_image_values_nightly),
727 beta_values_nightly),
728 lhs_transpose_values),
757 v0_values_precommit),
758 h0_values_precommit),
761 framework::dataset::
make("export_to_cl_image_rhs", false)),
764 beta_values_precommit),
765 broadcast_bias_values),
766 lhs_transpose_values),
797 beta_values_nightly),
798 broadcast_bias_values),
799 lhs_transpose_values),
824 v0_values_precommit),
825 h0_values_precommit),
831 beta_values_precommit),
832 lhs_transpose_values),
864 beta_values_nightly),
865 lhs_transpose_values),
912 GEMMLHSMatrixInfo(4, 4, 1,
false,
true),
913 GEMMLHSMatrixInfo(4, 8, 1,
false,
true),
914 GEMMLHSMatrixInfo(4, 4, 1,
false,
true),
915 GEMMLHSMatrixInfo(4, 2, 1,
false,
false),
916 GEMMLHSMatrixInfo(4, 4, 1,
false,
false),
920 GEMMRHSMatrixInfo(4, 4, 1,
true,
true,
true),
921 GEMMRHSMatrixInfo(4, 8, 1,
true,
true,
true),
922 GEMMRHSMatrixInfo(8, 4, 1,
true,
true,
true),
923 GEMMRHSMatrixInfo(4, 2, 1,
true,
false,
true),
924 GEMMRHSMatrixInfo(2, 4, 1,
true,
false,
true),
933 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
947 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
961 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
976 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
990 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
1003 input0_info ,input1_info, input2_info,
output_info, lhs_info, rhs_info, gemm_info,
expected)
1006 &input1_info.clone()->set_is_resizable(
true),
1007 &input2_info.clone()->set_is_resizable(
true),
1008 &
output_info.clone()->set_is_resizable(
true),1.f,1.f,
1023 v0_values_precommit),
1024 h0_values_precommit),
1029 a_values_precommit),
1030 beta_values_precommit),
1031 broadcast_bias_values),
1032 lhs_transpose_values),
1038 validate(
CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1055 n0_export_to_cl_image_values_nightly),
1056 k0_export_to_cl_image_values_nightly),
1064 beta_values_nightly),
1065 broadcast_bias_values),
1066 lhs_transpose_values),
1072 validate(
CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1091 v0_values_precommit),
1092 h0_values_precommit),
1097 a_values_precommit),
1098 beta_values_precommit),
1099 lhs_transpose_values),
1105 validate(
CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1122 n0_export_to_cl_image_values_nightly),
1123 k0_export_to_cl_image_values_nightly),
1131 beta_values_nightly),
1132 lhs_transpose_values),
1138 validate(
CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1161 v0_values_precommit),
1162 h0_values_precommit),
1165 framework::dataset::
make("export_to_cl_image_rhs", false)),
1167 a_values_precommit),
1168 beta_values_precommit),
1169 broadcast_bias_values),
1170 lhs_transpose_values),
1176 validate(
CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
1201 beta_values_nightly),
1202 broadcast_bias_values),
1203 lhs_transpose_values),
1209 validate(
CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
1228 v0_values_precommit),
1229 h0_values_precommit),
1234 a_values_precommit),
1235 beta_values_precommit),
1236 lhs_transpose_values),
1242 validate(
CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
1268 beta_values_nightly),
1269 lhs_transpose_values),
1275 validate(
CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);