35 #include "tests/datasets/ShapeDatasets.h" 40 #include "tests/validation/fixtures/GEMMFixture.h" 61 using CLGEMMMatrixMultiplyReshapedFixture = GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;
66 GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;
70 using CLGEMMMatrixMultiplyReshaped3DFixture = GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;
75 GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;
85 constexpr
float abs_tolerance_f16_mixed_precision(0.01f);
88 constexpr
float abs_tolerance_f16(0.01f);
328 input0_info ,input1_info, input2_info,
output_info, lhs_info, rhs_info, gemm_info,
expected)
331 &input1_info.clone()->set_is_resizable(
true),
332 &input2_info.clone()->set_is_resizable(
true),
333 &output_info.clone()->set_is_resizable(
true),1.f,1.f,
347 m0_values_precommit),
348 n0_values_precommit),
349 k0_values_precommit),
350 v0_values_precommit),
351 h0_values_precommit),
354 framework::dataset::
make("export_to_cl_image_rhs", false)),
357 beta_values_precommit),
358 broadcast_bias_values),
359 lhs_transpose_values),
382 beta_values_nightly),
383 broadcast_bias_values),
384 lhs_transpose_values),
398 m0_values_precommit),
399 n0_values_precommit),
400 k0_values_precommit),
401 v0_values_precommit),
402 h0_values_precommit),
408 beta_values_precommit),
409 lhs_transpose_values),
433 beta_values_nightly),
434 lhs_transpose_values),
563 input0_info ,input1_info, input2_info,
output_info, lhs_info, rhs_info, gemm_info,
expected)
566 &input1_info.clone()->set_is_resizable(
true),
567 &input2_info.clone()->set_is_resizable(
true),
568 &output_info.clone()->set_is_resizable(
true),1.f,1.f,
580 m0_values_precommit),
581 n0_values_precommit),
582 k0_values_precommit),
583 v0_values_precommit),
584 h0_values_precommit),
590 beta_values_precommit),
591 broadcast_bias_values),
592 lhs_transpose_values),
615 n0_export_to_cl_image_values_nightly),
616 k0_export_to_cl_image_values_nightly),
624 beta_values_nightly),
625 broadcast_bias_values),
626 lhs_transpose_values),
648 m0_values_precommit),
649 n0_values_precommit),
650 k0_values_precommit),
651 v0_values_precommit),
652 h0_values_precommit),
658 beta_values_precommit),
659 lhs_transpose_values),
682 n0_export_to_cl_image_values_nightly),
683 k0_export_to_cl_image_values_nightly),
691 beta_values_nightly),
692 lhs_transpose_values),
717 m0_values_precommit),
718 n0_values_precommit),
719 k0_values_precommit),
720 v0_values_precommit),
721 h0_values_precommit),
727 beta_values_precommit),
728 broadcast_bias_values),
729 lhs_transpose_values),
752 beta_values_nightly),
753 broadcast_bias_values),
754 lhs_transpose_values),
768 m0_values_precommit),
769 n0_values_precommit),
770 k0_values_precommit),
771 v0_values_precommit),
772 h0_values_precommit),
778 beta_values_precommit),
779 lhs_transpose_values),
803 beta_values_nightly),
804 lhs_transpose_values),
934 input0_info ,input1_info, input2_info,
output_info, lhs_info, rhs_info, gemm_info,
expected)
937 &input1_info.clone()->set_is_resizable(
true),
938 &input2_info.clone()->set_is_resizable(
true),
939 &output_info.clone()->set_is_resizable(
true),1.f,1.f,
951 m0_values_precommit),
952 n0_values_precommit),
953 k0_values_precommit),
954 v0_values_precommit),
955 h0_values_precommit),
961 beta_values_precommit),
962 broadcast_bias_values),
963 lhs_transpose_values),
986 n0_export_to_cl_image_values_nightly),
987 k0_export_to_cl_image_values_nightly),
995 beta_values_nightly),
996 broadcast_bias_values),
997 lhs_transpose_values),
1003 validate(
CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1019 m0_values_precommit),
1020 n0_values_precommit),
1021 k0_values_precommit),
1022 v0_values_precommit),
1023 h0_values_precommit),
1028 a_values_precommit),
1029 beta_values_precommit),
1030 lhs_transpose_values),
1036 validate(
CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1053 n0_export_to_cl_image_values_nightly),
1054 k0_export_to_cl_image_values_nightly),
1062 beta_values_nightly),
1063 lhs_transpose_values),
1069 validate(
CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1088 m0_values_precommit),
1089 n0_values_precommit),
1090 k0_values_precommit),
1091 v0_values_precommit),
1092 h0_values_precommit),
1095 framework::dataset::
make("export_to_cl_image_rhs", false)),
1097 a_values_precommit),
1098 beta_values_precommit),
1099 broadcast_bias_values),
1100 lhs_transpose_values),
1104 validate(
CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
1123 beta_values_nightly),
1124 broadcast_bias_values),
1125 lhs_transpose_values),
1129 validate(
CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
1139 m0_values_precommit),
1140 n0_values_precommit),
1141 k0_values_precommit),
1142 v0_values_precommit),
1143 h0_values_precommit),
1148 a_values_precommit),
1149 beta_values_precommit),
1150 lhs_transpose_values),
1154 validate(
CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
1174 beta_values_nightly),
1175 lhs_transpose_values),
1179 validate(
CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
bool image2d_from_buffer_supported(const cl::Device &device)
Helper function to check whether the cl_khr_image2d_from_buffer extension is supported.
static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info)
Static function to check if given info will lead to a valid configuration of CLGEMMMatrixMultiplyResh...
Descriptor used by the GEMM kernels.
GEMMMatrixMultiplyInterleavedTransposed3DValidationFixture< CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped > CLGEMMMatrixMultiplyReshaped3DFixture
half_float::half half
16-bit floating point type
1 channel, 1 F32 per channel
ARM_COMPUTE_EXPECT(has_error==expected, framework::LogLevel::ERRORS)
static CLKernelLibrary & get()
Access the KernelLibrary singleton.
GEMM LHS (Left Hand Side) matrix information.
std::enable_if< is_container< T >::value, ContainerDataset< T > >::type make(std::string name, T &&values)
Helper function to create a ContainerDataset.
Activation Layer Information class.
#define ARM_COMPUTE_TEST_INFO(INFO)
Copyright (c) 2017-2021 Arm Limited.
GEMMMatrixMultiplyReshapedValidationFixture< CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true > CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture
1 channel, 1 F16 per channel
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(framework::dataset::make("InputInfo", { TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8), TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QASYMM8), TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QSYMM16), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QSYMM16), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QSYMM16), }), framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F16), TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8), TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QASYMM8), TensorInfo(TensorShape(30U, 11U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QSYMM16, QuantizationInfo(1.f/32768.f, 0)), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QSYMM16, QuantizationInfo(1.f/32768.f, 0)), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QSYMM16, QuantizationInfo(1.f/32768.f, 0)), })), framework::dataset::make("ActivationInfo", { ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::SQRT), })), framework::dataset::make("Expected", { false, true, true, true, false, false, true, true, false })), input_info, output_info, act_info, expected)
DatasetMode
Possible dataset modes.
GEMMMatrixMultiplyReshaped3DValidationFixture< CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true > CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture
GEMM RHS (Right Hand Side) matrix information.
TEST_SUITE_END() FIXTURE_DATA_TEST_CASE(RunSmall
[CLActivationLayer Test snippet]
quantized, asymmetric fixed-point 8-bit number unsigned
Accessor implementation for CLTensor objects.
CLSynthetizeFunction< CLGEMMReshapeRHSMatrixKernel > CLGEMMReshapeRHSMatrix
TEST_SUITE(U8_to_S8) FIXTURE_DATA_TEST_CASE(RunSmall
validate(CLAccessor(output_state), expected_output)
void ARM_COMPUTE_PRINT_INFO()
Lower and Upper Bounded Rectifier ( )
FIXTURE_DATA_TEST_CASE(RunSmall, CLAbsLayerFixture< half >, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)))
Class reprensenting a relative tolerance value.
GEMMMatrixMultiplyInterleavedTransposedValidationFixture< CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped > CLGEMMMatrixMultiplyReshapedFixture
Store the tensor's metadata.
zip(zip(framework::dataset::make("Weights", { TensorInfo(TensorShape(32U, 13U, 2U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U, 1U), 1, DataType::F32), }), framework::dataset::make("MVBGInfo",{ TensorInfo(TensorShape(2U), 1, DataType::F32), TensorInfo(TensorShape(2U), 1, DataType::F16), TensorInfo(TensorShape(5U), 1, DataType::F32), })), framework::dataset::make("Expected", { true, false, false}))
DataType
Available data types.
constexpr float abs_tolerance_f32(0.0001f)
F32 Absolute tolerance value for comparing reference's output against implementation's output for flo...
combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)))