27 #include "tests/datasets/LargeMatMulDataset.h"
28 #include "tests/datasets/MatMulDataset.h"
29 #include "tests/datasets/SmallMatMulDataset.h"
33 #include "tests/validation/fixtures/dynamic_fusion/gpu/cl/MatMulKernelFixture.h"
52 constexpr
float abs_tolerance_f16(
76 class DFMatMulDataset final :
public datasets::MatMulDataset
97 using MatMulConfigurationPair = std::pair<MatMulKernelInfo, bool>;
99 const std::vector<MatMulConfigurationPair> supported_block_sizes = {
103 {MatMulKernelInfo(
false,
true, 0, 1, 1),
false},
104 {MatMulKernelInfo(
false,
true, 3, 11, 1),
false},
105 {MatMulKernelInfo(
false,
true, 3, 7, 1),
false},
106 {MatMulKernelInfo(
false,
true, 3, 3, 12),
false},
107 {MatMulKernelInfo(
false,
true, 3, 3, 6),
false},
108 {MatMulKernelInfo(
false,
true, 5, 1, 2),
true}, {MatMulKernelInfo(
false,
true, 3, 3, 3),
true},
109 {MatMulKernelInfo(
false,
true, 2, 4, 8),
true},
115 auto context = GpuWorkloadContext{&cl_compile_ctx};
121 const ITensorInfo *lhs_info =
context.create_tensor_info(TensorInfo(TensorShape(100U, 100U), 1,
DataType::F32));
122 const ITensorInfo *rhs_info =
context.create_tensor_info(TensorInfo(TensorShape(100U, 100U), 1,
DataType::F32));
124 for (
auto &pair : supported_block_sizes)
126 MatMulAttributes matmul_attr{};
127 matmul_attr.adj_lhs(pair.first.adj_lhs);
128 matmul_attr.adj_rhs(pair.first.adj_rhs);
130 GpuMatMulSettings matmul_settings{};
131 matmul_settings.m0(pair.first.m0);
132 matmul_settings.n0(pair.first.n0);
133 matmul_settings.k0(pair.first.k0);
144 auto context = GpuWorkloadContext{&cl_compile_ctx};
148 using ShapeConfigurationTuple = std::tuple<TensorShape, TensorShape, bool>;
149 const std::vector<ShapeConfigurationTuple> shape_configurations = {
150 {TensorShape(5U, 1U), TensorShape(3U, 5U),
true},
151 {TensorShape(10U, 12U), TensorShape(3U, 10U),
true},
152 {TensorShape(8U, 4U), TensorShape(2U, 8U),
true},
153 {TensorShape(8U, 4U), TensorShape(2U, 5U),
false},
154 {TensorShape(5U, 0U), TensorShape(2U, 5U),
false},
155 {TensorShape(5U, 4U, 3U, 4U, 5U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U),
true},
156 {TensorShape(5U, 4U, 3U, 4U, 5U, 1U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U),
false},
157 {TensorShape(5U, 4U, 3U, 4U, 9U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U),
161 for (
auto &tuple : shape_configurations)
163 const bool expected = std::get<2>(tuple);
165 for (
bool adj_lhs : {
false})
167 for (
bool adj_rhs : {
true})
169 TensorShape lhs_shape = std::get<0>(tuple);
170 TensorShape rhs_shape = std::get<1>(tuple);
182 const ITensorInfo *lhs_info =
context.create_tensor_info(TensorInfo(lhs_shape, 1,
DataType::F32));
183 const ITensorInfo *rhs_info =
context.create_tensor_info(TensorInfo(rhs_shape, 1,
DataType::F32));
185 MatMulAttributes matmul_attr{};
186 matmul_attr.adj_lhs(adj_lhs);
187 matmul_attr.adj_rhs(adj_rhs);
189 GpuMatMulSettings matmul_settings{};
190 matmul_settings.m0(1);
191 matmul_settings.n0(1);
192 matmul_settings.k0(1);
204 using DataTypeConfigurationTuple = std::tuple<DataType, DataType, DataType, bool>;
205 const std::vector<DataTypeConfigurationTuple> data_type_configurations = {
228 auto context = GpuWorkloadContext{&cl_compile_ctx};
231 const TensorShape
shape = TensorShape(10U, 10U);
232 MatMulAttributes matmul_attr{};
233 matmul_attr.adj_lhs(
false);
234 matmul_attr.adj_rhs(
false);
235 GpuMatMulSettings matmul_settings{};
236 matmul_settings.m0(1);
237 matmul_settings.n0(1);
238 matmul_settings.k0(1);
240 for (
auto &tuple : data_type_configurations)
242 const bool expected = std::get<3>(tuple);
244 const ITensorInfo *lhs_info =
context.create_tensor_info(TensorInfo(
shape, 1, std::get<0>(tuple)));
245 const ITensorInfo *rhs_info =
context.create_tensor_info(TensorInfo(
shape, 1, std::get<1>(tuple)));
254 template <typename T>
255 using DynamicFusionGpuMatmulFixture = DynamicFusionGpuMatMulValidationFixture<CLTensor, CLAccessor, GpuMatMul, T>;
261 DynamicFusionGpuMatmulFixture<
float>,
264 framework::dataset::
make("TransposeA", {
false}),
277 DynamicFusionGpuMatmulFixture<float>,
296 DynamicFusionGpuMatmulFixture<
half>,
299 framework::dataset::
make("TransposeA", {
false}),