36 #include "tests/datasets/ShapeDatasets.h" 41 #include "tests/validation/fixtures/GEMMFixture.h" 63 using CLGEMMMatrixMultiplyReshapedFixture = GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;
68 GEMMMatrixMultiplyReshapedWithPostOpsValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;
73 GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;
78 GEMMMatrixMultiplyReshapedWithPostOpsValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;
82 using CLGEMMMatrixMultiplyReshaped3DFixture = GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;
87 GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;
97 constexpr
float abs_tolerance_f16_mixed_precision(0.01f);
100 constexpr
float abs_tolerance_f16(0.01f);
193 std::make_tuple(
true,
true,
false),
203 std::make_tuple(
false,
true,
true),
214 std::make_tuple(
false,
false,
true),
225 std::make_tuple(
false,
false,
true),
237 std::make_tuple(
false,
false,
false),
280 &reshaped_input1_info.
clone()->set_is_resizable(
true),
281 &input2_info.
clone()->set_is_resizable(
true),
282 &output_info.
clone()->set_is_resizable(
true),1.f,1.f,
445 &input1_info.clone()->set_is_resizable(
true),
446 &input2_info.clone()->set_is_resizable(
true),
447 &output_info.clone()->set_is_resizable(
true),1.f,1.f,
457 const unsigned int m = 17;
458 const unsigned int n = 1;
459 const unsigned int k = 13;
460 const unsigned int batch = 2;
463 auto post_op_arg1_info = post_op_arg_info.
clone();
472 post_op_arg1_info.get(),
482 const unsigned int m = 17;
483 const unsigned int n = 1;
484 const unsigned int k = 13;
485 const unsigned int batch = 2;
497 const unsigned int m = 22;
498 const unsigned int n = 16;
499 const unsigned int k = 15;
500 const unsigned int batch = 3;
513 const unsigned int m = 22;
514 const unsigned int n = 16;
515 const unsigned int k = 15;
516 const unsigned int batch = 3;
524 const unsigned int m = 22;
525 const unsigned int n = 16;
526 const unsigned int k = 15;
527 const unsigned int batch = 3;
538 const unsigned int m = 22;
539 const unsigned int n = 16;
540 const unsigned int k = 15;
541 const unsigned int batch = 3;
552 const unsigned int m = 22;
553 const unsigned int n = 16;
554 const unsigned int k = 15;
555 const unsigned int batch = 3;
574 m0_values_precommit),
575 n0_values_precommit),
576 k0_values_precommit),
577 v0_values_precommit),
578 h0_values_precommit),
581 framework::dataset::
make("export_to_cl_image_rhs", false)),
584 beta_values_precommit),
585 broadcast_bias_values),
586 lhs_transpose_values),
617 beta_values_nightly),
618 broadcast_bias_values),
619 lhs_transpose_values),
641 m0_values_precommit),
642 n0_values_precommit),
643 k0_values_precommit),
644 v0_values_precommit),
645 h0_values_precommit),
651 beta_values_precommit),
652 lhs_transpose_values),
684 beta_values_nightly),
685 lhs_transpose_values),
707 m0_values_precommit),
708 n0_values_precommit),
709 k0_values_precommit),
710 v0_values_precommit),
711 h0_values_precommit),
712 framework::dataset::
make("interleave_lhs", {
false })),
717 beta_values_precommit),
719 lhs_transpose_values),
864 &input1_info.clone()->set_is_resizable(
true),
865 &input2_info.clone()->set_is_resizable(
true),
866 &output_info.clone()->set_is_resizable(
true),1.f,1.f,
878 m0_values_precommit),
879 n0_values_precommit),
880 k0_values_precommit),
881 v0_values_precommit),
882 h0_values_precommit),
888 beta_values_precommit),
889 broadcast_bias_values),
890 lhs_transpose_values),
913 n0_export_to_cl_image_values_nightly),
914 k0_export_to_cl_image_values_nightly),
922 beta_values_nightly),
923 broadcast_bias_values),
924 lhs_transpose_values),
946 m0_values_precommit),
947 n0_values_precommit),
948 k0_values_precommit),
949 v0_values_precommit),
950 h0_values_precommit),
956 beta_values_precommit),
957 lhs_transpose_values),
980 n0_export_to_cl_image_values_nightly),
981 k0_export_to_cl_image_values_nightly),
989 beta_values_nightly),
990 lhs_transpose_values),
1012 m0_values_precommit),
1013 n0_values_precommit),
1014 k0_values_precommit),
1015 v0_values_precommit),
1016 h0_values_precommit),
1017 framework::dataset::
make("interleave_lhs", {
false })),
1021 a_values_precommit),
1022 beta_values_precommit),
1024 lhs_transpose_values),
1054 m0_values_precommit),
1055 n0_values_precommit),
1056 k0_values_precommit),
1057 v0_values_precommit),
1058 h0_values_precommit),
1061 framework::dataset::
make("export_to_cl_image_rhs", false)),
1063 a_values_precommit),
1064 beta_values_precommit),
1065 broadcast_bias_values),
1066 lhs_transpose_values),
1072 validate(
CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1097 beta_values_nightly),
1098 broadcast_bias_values),
1099 lhs_transpose_values),
1105 validate(
CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1121 m0_values_precommit),
1122 n0_values_precommit),
1123 k0_values_precommit),
1124 v0_values_precommit),
1125 h0_values_precommit),
1130 a_values_precommit),
1131 beta_values_precommit),
1132 lhs_transpose_values),
1138 validate(
CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1164 beta_values_nightly),
1165 lhs_transpose_values),
1171 validate(
CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1188 m0_values_precommit),
1189 n0_values_precommit),
1190 k0_values_precommit),
1191 v0_values_precommit),
1192 h0_values_precommit),
1193 framework::dataset::
make("interleave_lhs", {
false })),
1197 a_values_precommit),
1198 beta_values_precommit),
1200 lhs_transpose_values),
1208 validate(
CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1345 &input1_info.clone()->set_is_resizable(
true),
1346 &input2_info.clone()->set_is_resizable(
true),
1347 &output_info.clone()->set_is_resizable(
true),1.f,1.f,
1359 m0_values_precommit),
1360 n0_values_precommit),
1361 k0_values_precommit),
1362 v0_values_precommit),
1363 h0_values_precommit),
1368 a_values_precommit),
1369 beta_values_precommit),
1370 broadcast_bias_values),
1371 lhs_transpose_values),
1377 validate(
CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1394 n0_export_to_cl_image_values_nightly),
1395 k0_export_to_cl_image_values_nightly),
1403 beta_values_nightly),
1404 broadcast_bias_values),
1405 lhs_transpose_values),
1411 validate(
CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1427 m0_values_precommit),
1428 n0_values_precommit),
1429 k0_values_precommit),
1430 v0_values_precommit),
1431 h0_values_precommit),
1436 a_values_precommit),
1437 beta_values_precommit),
1438 lhs_transpose_values),
1444 validate(
CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1461 n0_export_to_cl_image_values_nightly),
1462 k0_export_to_cl_image_values_nightly),
1470 beta_values_nightly),
1471 lhs_transpose_values),
1477 validate(
CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1493 m0_values_precommit),
1494 n0_values_precommit),
1495 k0_values_precommit),
1496 v0_values_precommit),
1497 h0_values_precommit),
1498 framework::dataset::
make("interleave_lhs", {
false })),
1502 a_values_precommit),
1503 beta_values_precommit),
1505 lhs_transpose_values),
1513 validate(
CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1535 m0_values_precommit),
1536 n0_values_precommit),
1537 k0_values_precommit),
1538 v0_values_precommit),
1539 h0_values_precommit),
1542 framework::dataset::
make("export_to_cl_image_rhs", false)),
1544 a_values_precommit),
1545 beta_values_precommit),
1546 broadcast_bias_values),
1547 lhs_transpose_values),
1553 validate(
CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
1578 beta_values_nightly),
1579 broadcast_bias_values),
1580 lhs_transpose_values),
1586 validate(
CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
1602 m0_values_precommit),
1603 n0_values_precommit),
1604 k0_values_precommit),
1605 v0_values_precommit),
1606 h0_values_precommit),
1611 a_values_precommit),
1612 beta_values_precommit),
1613 lhs_transpose_values),
1619 validate(
CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
1645 beta_values_nightly),
1646 lhs_transpose_values),
1652 validate(
CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
1669 m0_values_precommit),
1670 n0_values_precommit),
1671 k0_values_precommit),
1672 v0_values_precommit),
1673 h0_values_precommit),
1674 framework::dataset::
make("interleave_lhs", {
false })),
1678 a_values_precommit),
1679 beta_values_precommit),
1681 lhs_transpose_values),
1689 validate(
CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
GEMMMatrixMultiplyReshapedWithPostOpsValidationFixture< CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped > CLGEMMMatrixMultiplyReshapedWithPostOpsFixture
bool image2d_from_buffer_supported(const cl::Device &device)
Helper function to check whether the cl_khr_image2d_from_buffer extension is supported.
experimental::PostOpList< ITensorInfo * > post_ops
std::unique_ptr< ITensorInfo > clone() const override
Provide a clone of the current object of class T.
Descriptor used by the GEMM kernels.
GEMMMatrixMultiplyReshaped3DValidationFixture< CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped > CLGEMMMatrixMultiplyReshaped3DFixture
half_float::half half
16-bit floating point type
1 channel, 1 F32 per channel
ARM_COMPUTE_EXPECT(has_error==expected, framework::LogLevel::ERRORS)
static CLKernelLibrary & get()
Access the KernelLibrary singleton.
GEMM LHS (Left Hand Side) matrix information.
std::enable_if< is_container< T >::value, ContainerDataset< T > >::type make(std::string name, T &&values)
Helper function to create a ContainerDataset.
Activation Layer Information class.
#define ARM_COMPUTE_TEST_INFO(INFO)
Copyright (c) 2017-2022 Arm Limited.
GEMMMatrixMultiplyReshapedValidationFixture< CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true > CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture
1 channel, 1 F16 per channel
CLSynthetizeOperator< opencl::kernels::ClGemmReshapeRhsMatrixKernel > CLGEMMReshapeRHSMatrix
static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info)
Static function to check if given info will lead to a valid configuration.
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(framework::dataset::make("InputInfo", { TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8), TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QASYMM8), TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QSYMM16), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QSYMM16), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QSYMM16), }), framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F16), TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8), TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QASYMM8), TensorInfo(TensorShape(30U, 11U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QSYMM16, QuantizationInfo(1.f/32768.f, 0)), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QSYMM16, QuantizationInfo(1.f/32768.f, 0)), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QSYMM16, QuantizationInfo(1.f/32768.f, 0)), })), framework::dataset::make("ActivationInfo", { ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::SQRT), })), 
framework::dataset::make("Expected", { false, true, true, true, false, false, true, true, false })), input_info, output_info, act_info, expected)
DatasetMode
Possible dataset modes.
GEMMMatrixMultiplyReshaped3DValidationFixture< CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true > CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture
TensorShape compute_lhs_reshaped_shape(const ITensorInfo &a, const GEMMLHSMatrixInfo &lhs_info, bool reinterpret_input_as_3d=false)
Calculate the Left Hand Side matrix reshaped shape.
GEMM RHS (Right Hand Side) matrix information.
TEST_SUITE_END() FIXTURE_DATA_TEST_CASE(RunSmall
[CLActivationLayer Test snippet]
quantized, asymmetric fixed-point 8-bit number unsigned
Accessor implementation for CLTensor objects.
GEMMMatrixMultiplyReshapedWithPostOpsValidationFixture< CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true > CLGEMMMatrixMultiplyReshapedMixedPrecisionWithPostOpsFixture
TensorShape compute_rhs_reshaped_shape(const ITensorInfo &a, const GEMMRHSMatrixInfo &rhs_info)
Calculate the Right Hand Side matrix reshaped shape.
validate(CLAccessor(output_state), expected_output)
void ARM_COMPUTE_PRINT_INFO()
Lower and Upper Bounded Rectifier ( $f(x) = \min(a, \max(b, x))$ )
FIXTURE_DATA_TEST_CASE(RunSmall, CLAbsLayerFixture< half >, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)))
Class representing a relative tolerance value.
GEMMMatrixMultiplyReshapedValidationFixture< CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped > CLGEMMMatrixMultiplyReshapedFixture
Store the tensor's metadata.
TEST_CASE(FusedActivation, framework::DatasetMode::ALL)
Validate fused activation expecting the following behaviours:
zip(zip(framework::dataset::make("Weights", { TensorInfo(TensorShape(32U, 13U, 2U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U, 1U), 1, DataType::F32), }), framework::dataset::make("MVBGInfo",{ TensorInfo(TensorShape(2U), 1, DataType::F32), TensorInfo(TensorShape(2U), 1, DataType::F16), TensorInfo(TensorShape(5U), 1, DataType::F32), })), framework::dataset::make("Expected", { true, false, false}))
TEST_SUITE(QASYMM8_to_F32) FIXTURE_DATA_TEST_CASE(RunSmall
DataType
Available data types.
constexpr float abs_tolerance_f32(0.0001f)
F32 Absolute tolerance value for comparing reference's output against implementation's output for floating point data types.
A sequence of PostOps that can be appended to the end of other operators.
combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)))
(EXPERIMENTAL_POST_OPS) Implementation of specific IPostOps