29 #include "tests/datasets/LargeMatMulDataset.h"
30 #include "tests/datasets/SmallMatMulDataset.h"
34 #include "tests/validation/fixtures/MatMulKernelFixture.h"
// Absolute tolerance passed to every validate(CLAccessor(_target), _reference, ...) call in this
// suite when comparing the CL kernel output against the reference implementation.
// NOTE(review): a value of 1 presumably permits an off-by-one difference in the quantized
// integer outputs (e.g. from rounding) — confirm against the reference implementation.
47 constexpr AbsoluteTolerance<float> tolerance_quant(1);
83 using MatMulConfigurationPair = std::pair<MatMulKernelInfo, bool>;
85 const std::vector<MatMulConfigurationPair> supported_block_sizes =
106 for(
auto &pair : supported_block_sizes)
118 using ShapeConfigurationTuple = std::tuple<TensorShape, TensorShape, TensorShape, bool>;
119 const std::vector<ShapeConfigurationTuple> shape_configurations =
123 { TensorShape(8U, 4U), TensorShape(2U, 8U), TensorShape(2U),
true },
124 { TensorShape(8U, 4U), TensorShape(2U, 5U), TensorShape(2U),
false },
125 { TensorShape(5U, 0U), TensorShape(2U, 5U), TensorShape(2U),
false },
126 { TensorShape(5U, 4U, 3U, 4U, 5U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), TensorShape(2U),
true },
127 { TensorShape(5U, 4U, 3U, 4U, 5U, 1U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), TensorShape(2U),
false },
128 { TensorShape(5U, 4U, 3U, 4U, 9U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), TensorShape(2U),
false },
129 { TensorShape(5U, 1U), TensorShape(3U, 5U), TensorShape(1U),
false },
130 { TensorShape(5U, 1U), TensorShape(3U, 5U), TensorShape(3U, 3U),
false },
133 for(
auto &tuple : shape_configurations)
135 const bool expected = std::get<3>(tuple);
147 TensorShape lhs_shape = std::get<0>(tuple);
148 TensorShape rhs_shape = std::get<1>(tuple);
149 TensorShape bia_shape = std::get<2>(tuple);
163 const TensorInfo bia_info = TensorInfo(bia_shape, 1,
DataType::S32);
166 MatMulKernelInfo matmul_kernel_info{ adj_lhs, adj_rhs, 1, 1, 1,
false };
177 using DataTypeConfigurationTuple = std::tuple<DataType, DataType, DataType, DataType, bool>;
178 const std::vector<DataTypeConfigurationTuple> data_type_configurations =
202 const TensorShape
shape = TensorShape(10U, 10U);
203 const TensorShape bia_shape = TensorShape(10U);
204 const MatMulKernelInfo matmul_kernel_info{
false,
false, 1, 1, 1,
false };
205 for(
auto &tuple : data_type_configurations)
207 const bool expected = std::get<4>(tuple);
209 const TensorInfo lhs_info(
shape, 1, std::get<0>(tuple));
210 const TensorInfo rhs_info(
shape, 1, std::get<1>(tuple));
211 const TensorInfo bia_info(bia_shape, 1, std::get<2>(tuple));
224 framework::dataset::
make("TransposeA", {
true,
false })),
233 validate(CLAccessor(_target), _reference, tolerance_quant);
245 validate(CLAccessor(_target), _reference, tolerance_quant);
257 validate(CLAccessor(_target), _reference, tolerance_quant);
270 validate(CLAccessor(_target), _reference, tolerance_quant);
283 validate(CLAccessor(_target), _reference, tolerance_quant);
296 validate(CLAccessor(_target), _reference, tolerance_quant);
309 validate(CLAccessor(_target), _reference, tolerance_quant);
324 validate(CLAccessor(_target), _reference, tolerance_quant);
330 framework::dataset::
make("TransposeA", {
true,
false })),
339 validate(CLAccessor(_target), _reference, tolerance_quant);
351 validate(CLAccessor(_target), _reference, tolerance_quant);
364 validate(CLAccessor(_target), _reference, tolerance_quant);
377 validate(CLAccessor(_target), _reference, tolerance_quant);
390 validate(CLAccessor(_target), _reference, tolerance_quant);
403 validate(CLAccessor(_target), _reference, tolerance_quant);