49 using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (
53 &ClGemmDefaultConfigReshapedValhall::configure_G77_f32, &ClGemmDefaultConfigReshapedValhall::configure_G77_f16,
54 &ClGemmDefaultConfigReshapedValhall::configure_G77_u8);
57 &ClGemmDefaultConfigReshapedValhall::configure_G78_f32, &ClGemmDefaultConfigReshapedValhall::configure_G78_f16,
58 &ClGemmDefaultConfigReshapedValhall::configure_G77_u8);
60 ConfigurationFunctionExecutorPtr func =
nullptr;
74 return (this->*func)(m, n, k,
b);
// Builds the GEMM LHS/RHS reshape configuration for the G77 F32 path
// (purpose inferred from the function name and the dispatch table that
// references it -- TODO confirm).
// Only two return paths are visible in this excerpt; the condition that
// selects between them is elided, so which inputs it depends on cannot be
// read from here. Neither k nor b appears in the visible lines.
// NOTE(review): both calls pass 11 arguments to configure_lhs_rhs_info,
// while the F16 variants in this file pass 12; presumably the trailing
// parameter has a default value -- verify against the helper's declaration.
77 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
78 ClGemmDefaultConfigReshapedValhall::configure_G77_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
85 return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, 1, 0, 0, 1);
89 return configure_lhs_rhs_info(m, n, 5, 4, 4, 2, 16, 0, 1, 0, 1);
// Builds the GEMM LHS/RHS reshape configuration for the G77 F16 path
// (purpose inferred from the function name -- TODO confirm).
// The branches test four float features derived from the problem size:
//   r_mn = m / n, r_mk = m / k, r_nk = n / k, workload = m * n * b / 20.
// Braces, else-arms and the heads of several calls are elided from this
// excerpt, so the exact nesting of the conditions below cannot be read here.
93 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
94 ClGemmDefaultConfigReshapedValhall::configure_G77_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
99 const float r_mn =
static_cast<float>(m) /
static_cast<float>(n);
// Problem-size measure used as the main split feature; the reason for the
// /20 divisor is not shown in this excerpt.
100 const float workload = (
static_cast<float>(m) *
static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
101 const float r_mk =
static_cast<float>(m) /
static_cast<float>(k);
102 const float r_nk =
static_cast<float>(n) /
static_cast<float>(k);
// Fallback configuration (last argument 0), computed once up front. The
// call tails below each receive this pair together with n, k, b and
// DataType::F16.
109 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 0);
// NOTE(review): the thresholds in this function are double literals, so each
// comparison promotes the float feature to double. The G78 variants in this
// file use float (f-suffixed) literals instead.
111 if (r_mk <= 0.11824845522642136)
113 if (workload <= 880.0)
115 return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, 0, 0, 1, 0, 0);
119 if (r_nk <= 0.42521367967128754)
121 if (workload <= 1726.4000244140625)
123 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 0);
// Candidate *_img configuration (last argument 1). The three lines after it
// are only the tail of a call whose head is elided. Presumably that call
// chooses between this candidate and the *_buf fallback (TODO confirm
// against the helper). The same pattern repeats in the branches below.
127 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);
130 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F16);
135 if (workload <= 1241.6000366210938)
137 return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, 0, 0, 1, 0, 0);
141 return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 0);
148 if (workload <= 11404.7998046875)
150 if (r_mk <= 1.0126488208770752)
152 if (r_mn <= 2.545312523841858)
154 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);
157 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F16);
161 return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, 0, 0, 1, 0, 0);
166 if (workload <= 2881.199951171875)
168 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, 0, 0, 1, 0, 1);
171 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F16);
175 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);
178 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F16);
184 if (r_nk <= 0.5765306055545807)
186 if (r_mn <= 6.010416746139526)
188 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);
191 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F16);
195 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 1, 0, 1, 0, 1);
198 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F16);
203 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 1, 0, 1, 0, 1);
206 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F16);
// Builds the GEMM LHS/RHS reshape configuration for the G78 F32 path
// (purpose inferred from the function name -- TODO confirm).
// The function returns directly from its leaves, with no fallback-selection
// call. The visible conditions test workload, r_mn and r_mk against float
// literals. Many leaves return identical argument lists; because the
// braces/else-arms between them are elided, it cannot be determined here
// whether any condition is redundant.
212 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
213 ClGemmDefaultConfigReshapedValhall::configure_G78_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
215 const float r_mn =
static_cast<float>(m) /
static_cast<float>(n);
216 const float r_mk =
static_cast<float>(m) /
static_cast<float>(k);
// NOTE(review): r_nk is not referenced in any visible line of this excerpt.
// It may be used in an elided condition; if not, it can be removed.
217 const float r_nk =
static_cast<float>(n) /
static_cast<float>(k);
218 const float workload = (
static_cast<float>(m) *
static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
220 if (workload <= 1288.0000f)
222 if (workload <= 505.6000f)
228 return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
232 return configure_lhs_rhs_info(m, n, 2, 2, 4, 2, 2, 0, 0, 1, 0, 0);
237 return configure_lhs_rhs_info(m, n, 2, 2, 4, 2, 2, 0, 0, 1, 0, 0);
246 return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
250 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
259 if (workload <= 1089.6000f)
261 return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
265 return configure_lhs_rhs_info(m, n, 2, 4, 8, 2, 4, 0, 0, 1, 0, 1);
270 return configure_lhs_rhs_info(m, n, 2, 4, 16, 4, 4, 0, 0, 1, 0, 1);
275 return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
282 if (workload <= 5434.4001f)
284 if (workload <= 1603.2000f)
286 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
292 if (r_mn <= 16.1016f)
294 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
298 if (workload <= 2750.0000f)
300 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
306 return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
310 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
319 return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
327 return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
331 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
336 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
344 if (r_mk <= 25.7500f)
352 return configure_lhs_rhs_info(m, n, 8, 4, 4, 4, 2, 0, 0, 1, 0, 1);
356 return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
361 return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
366 if (workload <= 11174.3999f)
370 return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
374 if (workload <= 7185.5999f)
376 return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
380 return configure_lhs_rhs_info(m, n, 8, 4, 4, 4, 2, 0, 0, 1, 0, 1);
386 if (workload <= 17917.5000f)
390 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
394 return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
399 if (workload <= 34449.6016f)
401 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
405 return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 4, 0, 0, 1, 0, 1);
413 if (r_mk <= 331.1111f)
415 if (workload <= 53397.5996f)
417 if (r_mn <= 57.8063f)
419 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
423 return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
430 return configure_lhs_rhs_info(m, n, 8, 4, 4, 4, 2, 0, 0, 1, 0, 1);
434 return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
440 if (workload <= 38070.4004f)
442 return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
446 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
// Builds the GEMM LHS/RHS reshape configuration for the G78 F16 path
// (purpose inferred from the function name -- TODO confirm).
// Every visible leaf passes the same leading block size (8, 4, 4) and a
// trailing 1; only the fourth and fifth tile arguments vary between leaves.
// Only workload conditions are visible in this excerpt.
454 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
455 ClGemmDefaultConfigReshapedValhall::configure_G78_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// NOTE(review): r_mn and r_nk are not referenced in any visible line; they
// are presumably used by elided conditions -- verify.
457 const float r_mn =
static_cast<float>(m) /
static_cast<float>(n);
458 const float r_nk =
static_cast<float>(n) /
static_cast<float>(k);
459 const float workload = (
static_cast<float>(m) *
static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
461 if (workload <= 801.6000f)
463 return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
469 if (workload <= 3296.0000f)
471 return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
// Same arguments as the return just above; the condition separating them is
// elided, so whether it is redundant cannot be determined here.
477 return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
481 return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 4, 0, 0, 1, 0, 1);
487 if (workload <= 5068.8000f)
489 return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
495 if (workload <= 12630.0000f)
497 return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
501 return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 1, 0, 0, 1, 0, 1);
506 if (workload <= 178790.3984f)
508 return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
512 return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
// Builds the GEMM LHS/RHS reshape configuration for the u8 path. This file's
// dispatch tables register it in the u8 slot for both G77 and G78; no
// G78-specific u8 variant appears in this excerpt.
// Two return paths are visible; the condition selecting between them is
// elided. Both calls pass 11 arguments to configure_lhs_rhs_info --
// presumably the trailing parameter is defaulted (TODO confirm).
520 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
521 ClGemmDefaultConfigReshapedValhall::configure_G77_u8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
528 return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1, 0, 0, 0, 1);
532 return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, 0, 1, 0, 1);