56 using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (
60 &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32,
61 &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16,
62 &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
65 &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32,
66 &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16,
67 &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
70 &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32,
71 &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G710_f16,
72 &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
75 &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32,
76 &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16,
77 &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
79 ConfigurationFunctionExecutorPtr func =
nullptr;
101 return (this->*func)(m, n, k,
b);
104 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32(
105 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
109 const float r_mn =
static_cast<float>(m) /
static_cast<float>(n);
110 const float r_mk =
static_cast<float>(m) /
static_cast<float>(k);
112 if (r_mk <= 0.0064484127797186375)
114 if (r_mn <= 0.0028273810748942196)
121 const unsigned int h0 = std::max(n / 4, 1U);
122 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, 0, 1, 0, 0, 1);
123 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, 0, 1, 0, 1, 0);
126 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F32);
130 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, 0, 1, 0, 0, 0);
135 if (r_mk <= 0.020312500186264515)
137 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, 0, 1, 0, 0, 0);
141 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, 0, 1, 0, 1, 0);
147 const float r_mn =
static_cast<float>(m) /
static_cast<float>(n);
148 const float workload = (
static_cast<float>(m) *
static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
149 const float r_mk =
static_cast<float>(m) /
static_cast<float>(k);
151 if (workload <= 1999.2000122070312)
153 if (workload <= 747.1999816894531)
155 return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0);
159 GEMMLHSMatrixInfo lhs_info_buf;
160 GEMMRHSMatrixInfo rhs_info_buf;
161 GEMMLHSMatrixInfo lhs_info_img;
162 GEMMRHSMatrixInfo rhs_info_img;
163 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, 0, 0, 0, 1, 1);
164 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0);
167 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F32);
172 if (r_mn <= 0.03348214365541935)
174 if (r_mk <= 0.028125000186264515)
176 return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0);
180 GEMMLHSMatrixInfo lhs_info_buf;
181 GEMMRHSMatrixInfo rhs_info_buf;
182 GEMMLHSMatrixInfo lhs_info_img;
183 GEMMRHSMatrixInfo rhs_info_img;
184 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, 0, 0, 0, 1, 1);
185 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0);
188 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F32);
193 GEMMLHSMatrixInfo lhs_info_buf;
194 GEMMRHSMatrixInfo rhs_info_buf;
195 GEMMLHSMatrixInfo lhs_info_img;
196 GEMMRHSMatrixInfo rhs_info_img;
197 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, 0, 1, 0, 0, 1);
198 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, 0, 1, 0, 1, 0);
201 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F32);
207 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16(
208 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
211 {1, 8984, 640, 1, 1, 8, 8, 1, 0, 1, 1, 1, 1, 0}, {1, 420, 392, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0},
212 {1, 644, 5288, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0}, {1, 6512, 6404, 1, 1, 4, 8, 1, 0, 1, 0, 1, 0, 0},
213 {1, 5304, 640, 1, 1, 4, 4, 1, 0, 1, 0, 1, 1, 0}, {1, 1352, 1520, 1, 1, 2, 8, 1, 0, 1, 1, 1, 1, 0},
214 {1, 4096, 25088, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0}, {1, 732, 8988, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0}};
216 const GeMMConfigsMatrix configs_mnkb_n_small_best = {{102400, 4, 96, 1, 2, 2, 16, 1, 4, 1, 1, 1, 1, 0},
217 {102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 1, 1},
218 {16384, 4, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 1, 1},
219 {16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 1}};
221 const GeMMConfigsMatrix configs_mnkb_n_small_fallback = {{102400, 4, 96, 1, 2, 2, 16, 1, 4, 1, 1, 1, 1, 0},
222 {102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 0},
223 {16384, 4, 128, 1, 2, 2, 16, 1, 2, 1, 1, 1, 1, 0},
224 {16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 0}};
227 {25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}, {25584, 16, 68, 1, 4, 4, 8, 1, 16, 1, 1, 1, 0, 1},
228 {369664, 32, 28, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1}, {65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0},
229 {23036, 56, 736, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, {90968, 40, 600, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1},
230 {8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, {50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0},
231 {16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, {12604, 60, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1},
232 {29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0}, {12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0},
233 {2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0}, {3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0}};
236 {25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}, {25584, 16, 68, 1, 2, 4, 8, 1, 4, 1, 1, 1, 0, 0},
237 {369664, 32, 28, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0}, {65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0},
238 {23036, 56, 736, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {90968, 40, 600, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0},
239 {8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0}, {50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0},
240 {16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0}, {12604, 60, 160, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0},
241 {29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0}, {12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0},
242 {2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0}, {3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0}};
245 {24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0},
246 {49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 1},
247 {49, 1024, 1024, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1},
251 {24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0},
252 {49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0},
253 {49, 1024, 1024, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0},
257 {72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0}, {268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0},
258 {180, 420, 952, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, {1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0},
259 {272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {196, 512, 512, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1},
260 {24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0}, {24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0}};
263 {72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0}, {268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0},
264 {180, 420, 952, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0}, {1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0},
265 {272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {196, 512, 512, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0},
266 {24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0}, {24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0}};
269 {3136, 64, 64, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {4096, 48, 32, 36, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1},
270 {688, 92, 68, 32, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {24, 464, 412, 24, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0},
271 {112, 184, 144, 28, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {5776, 64, 32, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0},
272 {1568, 64, 40, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {2920, 64, 64, 24, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}};
275 {3136, 64, 64, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {4096, 48, 32, 36, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0},
276 {688, 92, 68, 32, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {24, 464, 412, 24, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0},
277 {112, 184, 144, 28, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {5776, 64, 32, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0},
278 {1568, 64, 40, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {2920, 64, 64, 24, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}};
285 constexpr
float ratio_m_gt_n = 10.f;
286 constexpr
float ratio_n_gt_m = 0.1f;
287 constexpr
unsigned int n_small_thr = 4;
288 const float ratio =
static_cast<float>(m) /
static_cast<float>(n);
293 configs_best_to_use = &configs_1nkb_best;
294 configs_fallback_to_use = &configs_1nkb_best;
296 else if (n <= n_small_thr && ratio > ratio_m_gt_n)
298 configs_best_to_use = &configs_mnkb_n_small_best;
299 configs_fallback_to_use = &configs_mnkb_n_small_fallback;
301 else if (ratio > ratio_m_gt_n)
303 configs_best_to_use = &configs_mnkb_m_gt_n_best;
304 configs_fallback_to_use = &configs_mnkb_m_gt_n_fallback;
306 else if (ratio < ratio_n_gt_m)
308 configs_best_to_use = &configs_mnkb_n_gt_m_best;
309 configs_fallback_to_use = &configs_mnkb_n_gt_m_fallback;
313 configs_best_to_use = &configs_mnkb_squared_best;
314 configs_fallback_to_use = &configs_mnkb_squared_fallback;
319 configs_best_to_use = &configs_mnkb_best_batched;
320 configs_fallback_to_use = &configs_mnkb_fallback_batched;
323 GEMMLHSMatrixInfo lhs_info0;
324 GEMMRHSMatrixInfo rhs_info0;
325 GEMMLHSMatrixInfo lhs_info1;
326 GEMMRHSMatrixInfo rhs_info1;
329 std::tie(lhs_info1, rhs_info1) =
find_lhs_rhs_info(*configs_fallback_to_use, m, n, k,
b);
331 return select_lhs_rhs_info(std::make_pair(lhs_info0, rhs_info0), std::make_pair(lhs_info1, rhs_info1), n, k,
b,
335 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8(
336 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
343 const unsigned int h0 = std::max(n / 2, 1U);
344 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, 0, 1, 0, 1);
348 const int h0 = std::max(std::min(
static_cast<int>(n / 4),
static_cast<int>(256)),
static_cast<int>(1));
351 return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, 0, 1, 0, 1);
355 return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 1);
360 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32(
361 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
363 const float r_mn =
static_cast<float>(m) /
static_cast<float>(n);
364 const float r_mk =
static_cast<float>(m) /
static_cast<float>(k);
365 const float r_nk =
static_cast<float>(n) /
static_cast<float>(k);
366 const float workload = (
static_cast<float>(m) *
static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
370 if (workload <= 278.7000f)
372 if (workload <= 7.5000f)
374 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
380 if (workload <= 256.6000f)
382 if (workload <= 16.7500f)
386 return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
390 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
395 return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
400 return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
409 return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
413 if (workload <= 8.9500f)
415 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
419 return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
425 if (workload <= 14.1500f)
427 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
433 return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
437 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
446 if (workload <= 363.7000f)
450 return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 32, 0, 1, 0, 1, 0);
454 return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, 32, 0, 1, 0, 1, 0);
459 return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 32, 0, 1, 0, 1, 0);
465 if (workload <= 1384.8000f)
467 if (workload <= 704.0000f)
469 return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 32, 0, 1, 0, 1, 0);
473 return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 4, 0, 0, 0, 1, 1);
478 if (workload <= 16761.6006f)
480 if (r_mn <= 187.1250f)
482 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, 0, 0, 0, 1, 1);
486 return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 4, 0, 0, 0, 1, 1);
491 if (r_mk <= 432.4630f)
493 return configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 16, 0, 0, 0, 1, 1);
497 return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 16, 0, 1, 0, 1, 1);
504 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16(
505 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
507 const float workload = (
static_cast<float>(m) *
static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
508 const float r_mn =
static_cast<float>(m) /
static_cast<float>(n);
509 const float r_mk =
static_cast<float>(m) /
static_cast<float>(k);
510 const float r_nk =
static_cast<float>(n) /
static_cast<float>(k);
515 {1, 8984, 640, 1, 1, 4, 2, 1, 0, 1, 0, 1, 1, 0}, {1, 420, 392, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0},
516 {1, 644, 5288, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}, {1, 6512, 6404, 1, 1, 2, 2, 1, 0, 1, 0, 1, 1, 0},
517 {1, 5304, 640, 1, 1, 2, 2, 1, 0, 1, 0, 1, 0, 0}, {1, 1352, 1520, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0},
518 {1, 4096, 25088, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}, {1, 732, 8988, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}};
524 if (workload <= 1384.8000f)
530 return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 4, 0, 1, 0, 1, 1);
536 return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 32, 0, 0, 1, 0, 0);
540 return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0);
548 return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 1);
552 return configure_lhs_rhs_info(m, n, 5, 4, 8, 1, 4, 0, 1, 1, 0, 1);
558 if (workload <= 11404.7998f)
564 return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 4, 0, 1, 1, 0, 1);
568 return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 1);
573 return configure_lhs_rhs_info(m, n, 5, 4, 8, 1, 4, 0, 1, 1, 0, 1);
580 if (r_mn <= 1385.7917f)
582 return configure_lhs_rhs_info(m, n, 6, 4, 8, 1, 4, 0, 1, 1, 0, 1);
586 return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 32, 0, 1, 1, 0, 0);
591 return configure_lhs_rhs_info(m, n, 6, 4, 8, 1, 32, 0, 1, 1, 0, 1);
598 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32(
599 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
601 unsigned int best_m0;
602 unsigned int best_n0;
606 return configure_lhs_rhs_info(m, n, best_m0, best_n0, 1, 1, 4,
false,
true,
false,
false,
true);
610 return configure_G77_f32(m, n, k,
b);
614 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G710_f16(
615 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
618 {1, 8984, 640, 1, 1, 2, 2, 1, 0, 1, 0, 1, 0, 0}, {1, 420, 392, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0},
619 {1, 644, 5288, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0}, {1, 6512, 6404, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0},
620 {1, 5304, 640, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}, {1, 1352, 1520, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0},
621 {1, 4096, 25088, 1, 1, 2, 8, 1, 0, 1, 0, 1, 1, 0}, {1, 732, 8988, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0}};
623 const GeMMConfigsMatrix configs_mnkb_n_small_best = {{102400, 4, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0},
624 {102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0},
625 {16384, 4, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0},
626 {16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0}};
629 {25584, 88, 16, 1, 4, 8, 4, 1, 4, 1, 1, 1, 0, 0}, {25584, 16, 68, 1, 2, 4, 16, 1, 8, 1, 1, 1, 0, 1},
630 {369664, 32, 28, 1, 2, 8, 4, 1, 128, 1, 1, 1, 0, 0}, {65792, 44, 24, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0},
631 {23036, 56, 736, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, {90968, 40, 600, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1},
632 {8944, 32, 776, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, {2688, 136, 1492, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1},
633 {50176, 64, 300, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1}, {16544, 104, 160, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1},
634 {12604, 60, 160, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, {3728, 96, 196, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1},
635 {29584, 32, 28, 1, 2, 8, 4, 1, 16, 1, 1, 1, 0, 0}, {12544, 32, 27, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0},
639 {25584, 88, 16, 1, 4, 8, 4, 1, 4, 1, 1, 1, 0, 0}, {25584, 16, 68, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0},
640 {369664, 32, 28, 1, 2, 8, 4, 1, 128, 1, 1, 1, 0, 0}, {65792, 44, 24, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0},
641 {23036, 56, 736, 1, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0}, {90968, 40, 600, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 0},
642 {8944, 32, 776, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0}, {2688, 136, 1492, 1, 4, 4, 8, 1, 8, 1, 1, 1, 0, 0},
643 {50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0}, {16544, 104, 160, 1, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0},
644 {12604, 60, 160, 1, 2, 8, 8, 1, 8, 1, 1, 1, 0, 0}, {3728, 96, 196, 1, 2, 8, 8, 1, 64, 1, 1, 1, 0, 0},
645 {29584, 32, 28, 1, 2, 8, 4, 1, 16, 1, 1, 1, 0, 0}, {12544, 32, 27, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0},
648 const GeMMConfigsMatrix configs_mnkb_n_gt_m_best = {{24, 488, 88, 1, 2, 2, 8, 1, 8, 1, 1, 1, 1, 0},
649 {49, 1024, 512, 1, 2, 4, 8, 1, 8, 1, 1, 1, 1, 0},
650 {49, 1024, 1024, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0}};
652 const GeMMConfigsMatrix configs_mnkb_n_gt_m_fallback = {{24, 488, 88, 1, 2, 2, 8, 1, 8, 1, 1, 1, 1, 0},
653 {49, 1024, 512, 1, 2, 4, 8, 1, 8, 1, 1, 1, 1, 0},
654 {49, 1024, 1024, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0}};
657 {24, 88, 236, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0}, {24, 88, 488, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0},
658 {72, 92, 136, 1, 2, 2, 8, 1, 32, 1, 1, 1, 1, 0}, {268, 824, 5076, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1},
659 {180, 420, 952, 1, 4, 4, 8, 1, 16, 1, 1, 1, 0, 1}, {1000, 152, 304, 1, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0},
660 {272, 400, 2116, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, {196, 512, 512, 1, 5, 2, 8, 1, 4, 1, 1, 1, 1, 1},
664 {24, 88, 236, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0}, {24, 88, 488, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0},
665 {72, 92, 136, 1, 2, 2, 8, 1, 32, 1, 1, 1, 1, 0}, {268, 824, 5076, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0},
666 {180, 420, 952, 1, 5, 2, 8, 1, 8, 1, 1, 1, 1, 0}, {1000, 152, 304, 1, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0},
667 {272, 400, 2116, 1, 2, 8, 4, 1, 4, 1, 1, 1, 0, 0}, {196, 512, 512, 1, 5, 2, 8, 1, 8, 1, 1, 1, 1, 0},
671 {3136, 64, 64, 36, 4, 8, 4, 1, 16, 1, 1, 1, 0, 1}, {4096, 48, 32, 36, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1},
672 {688, 92, 68, 32, 4, 8, 4, 1, 32, 1, 1, 1, 0, 1}, {24, 464, 412, 24, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1},
673 {112, 184, 144, 28, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, {5776, 64, 32, 36, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1},
674 {1568, 64, 40, 36, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1}, {2920, 64, 64, 24, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1}};
677 {3136, 64, 64, 36, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}, {4096, 48, 32, 36, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0},
678 {688, 92, 68, 32, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0}, {24, 464, 412, 24, 2, 8, 4, 1, 32, 1, 1, 1, 0, 0},
679 {112, 184, 144, 28, 4, 4, 8, 1, 8, 1, 1, 1, 0, 0}, {5776, 64, 32, 36, 2, 8, 8, 1, 32, 1, 1, 1, 0, 0},
680 {1568, 64, 40, 36, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0}, {2920, 64, 64, 24, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}};
687 constexpr
float ratio_m_gt_n = 10.f;
688 constexpr
float ratio_n_gt_m = 0.1f;
689 constexpr
unsigned int n_small_thr = 4;
690 const float ratio =
static_cast<float>(m) /
static_cast<float>(n);
695 configs_best_to_use = &configs_1nkb_best;
696 configs_fallback_to_use = &configs_1nkb_best;
698 else if (n <= n_small_thr && ratio > ratio_m_gt_n)
700 configs_best_to_use = &configs_mnkb_n_small_best;
701 configs_fallback_to_use = &configs_mnkb_n_small_best;
703 else if (ratio > ratio_m_gt_n)
705 configs_best_to_use = &configs_mnkb_m_gt_n_best;
706 configs_fallback_to_use = &configs_mnkb_m_gt_n_fallback;
708 else if (ratio < ratio_n_gt_m)
710 configs_best_to_use = &configs_mnkb_n_gt_m_best;
711 configs_fallback_to_use = &configs_mnkb_n_gt_m_fallback;
715 configs_best_to_use = &configs_mnkb_squared_best;
716 configs_fallback_to_use = &configs_mnkb_squared_fallback;
721 configs_best_to_use = &configs_mnkb_best_batched;
722 configs_fallback_to_use = &configs_mnkb_fallback_batched;
725 GEMMLHSMatrixInfo lhs_info0;
726 GEMMRHSMatrixInfo rhs_info0;
727 GEMMLHSMatrixInfo lhs_info1;
728 GEMMRHSMatrixInfo rhs_info1;
731 std::tie(lhs_info1, rhs_info1) =
find_lhs_rhs_info(*configs_fallback_to_use, m, n, k,
b);
733 return select_lhs_rhs_info(std::make_pair(lhs_info0, rhs_info0), std::make_pair(lhs_info1, rhs_info1), n, k,
b,
737 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16(
738 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
740 unsigned int best_m0;
741 unsigned int best_n0;
745 return configure_lhs_rhs_info(m, n, best_m0, best_n0, 1, 1, 4,
false,
true,
false,
false,
true);
749 return configure_G78_f16(m, n, k,
b);