// Dispatch fragment (the enclosing function header is not part of this excerpt).
// ConfigurationFunctionExecutorPtr names the type of the heuristic member functions:
// each returns a (GEMMLHSMatrixInfo, GEMMRHSMatrixInfo) pair and is invoked through
// `this->*func` at the end of the fragment.
55 using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (
// The member-function addresses below come in groups of three. The declarations that
// receive each group are not visible here; going by the name suffixes, every group is
// ordered f32, f16, u8.
59 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32,
60 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16,
61 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8);
// NOTE(review): this group pairs the G52 f32/f16 heuristics with the G7x u8 heuristic.
// Presumably deliberate reuse; confirm against the table declaration.
64 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32,
65 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16,
66 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8);
// NOTE(review): this group pairs the G7x f32/f16 heuristics with the G31 u8 heuristic.
// Presumably deliberate reuse; confirm against the table declaration.
69 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32,
70 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16,
71 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8);
74 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32,
75 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16,
76 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8);
79 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32,
80 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16,
81 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8);
// func starts out null. The code that selects a heuristic and assigns it sits between
// this declaration and the return below and is not visible in this excerpt.
// NOTE(review): confirm func is verified to be non-null before it is called.
83 ConfigurationFunctionExecutorPtr func =
nullptr;
// Calls the selected heuristic on this object, forwarding m, n, k and b unchanged.
104 return (this->*func)(m, n, k,
b);
// F32 heuristic for the "G7x" configuration.
//
// Takes the GEMM dimensions m, n, k and b as unsigned ints and returns a
// (GEMMLHSMatrixInfo, GEMMRHSMatrixInfo) pair built by configure_lhs_rhs_info.
// Only m and n appear in the visible calls; k and b are not referenced in any
// visible line of this function.
//
// Three return sites are visible. The conditions that choose between them are not
// part of this excerpt.
// NOTE(review): the arguments after (m, n) are presumably m0, n0, k0, v0, h0 followed
// by lhs_interleave, rhs_interleave, lhs_transpose, rhs_transpose and an optional
// trailing export-to-image flag; confirm against configure_lhs_rhs_info's declaration.
107 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32(
108 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// First visible candidate: block values 1, 2, 16, 1, 4 with the sixth flag explicitly false.
117 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4,
false,
true,
false,
true,
false);
// Second visible candidate: block values 1, 4, 16, 1, 8 with the sixth flag explicitly false.
121 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 8,
false,
true,
false,
true,
false);
// Third visible candidate: block values 4, 4, 4, 1, 4; only five flags are passed here,
// so any sixth parameter takes whatever default the callee declares.
126 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4,
false,
true,
false,
true);
// 8-bit heuristic for the "G31" configuration.
//
// Takes the GEMM dimensions m, n, k and b as unsigned ints and returns a
// (GEMMLHSMatrixInfo, GEMMRHSMatrixInfo) pair built by configure_lhs_rhs_info.
// Only m and n appear in the visible lines. The branch conditions are not part of
// this excerpt.
//
// The four flag arguments are written as the integer literals 0/1 here, whereas the
// other heuristics in this file spell them false/true.
130 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8(
131 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// h0 is n / 2, floored at 1 so it is never zero.
// NOTE(review): the divisor is 2 while the fourth argument of the call below is 4;
// other heuristics in this file divide n by the same value they pass as the fourth
// argument. Confirm this mismatch is intended.
138 const unsigned int h0 = std::max(n / 2, 1U);
139 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, 0, 1, 0, 1);
// Second h0: n / 4 clamped into the range [1, 256], computed in int.
143 const int h0 = std::max(std::min(
static_cast<int>(n / 4),
static_cast<int>(256)),
static_cast<int>(1));
// NOTE(review): the two returns below pass identical arguments, so whatever condition
// separates them (not visible here) does not change the result.
146 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, 0, 1, 0, 1);
150 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, 0, 1, 0, 1);
// F32 heuristic for the "G76" configuration.
//
// Takes the GEMM dimensions m, n, k and b as unsigned ints and returns a
// (GEMMLHSMatrixInfo, GEMMRHSMatrixInfo) pair. Some paths return a
// configure_lhs_rhs_info result directly; the remaining paths fill a "buf" candidate
// and an "img" candidate and return one of the two at the end.
// Most branch conditions are not part of this excerpt.
// NOTE(review): the arguments after (m, n) are presumably m0, n0, k0, v0, h0 followed
// by interleave/transpose flags and an optional trailing export-to-image flag; confirm
// against configure_lhs_rhs_info's declaration.
155 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32(
156 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// Storage for the two candidate configurations returned at the bottom of the function.
161 GEMMLHSMatrixInfo lhs_info_buf;
162 GEMMRHSMatrixInfo rhs_info_buf;
163 GEMMLHSMatrixInfo lhs_info_img;
164 GEMMRHSMatrixInfo rhs_info_img;
// Workload estimate in unsigned integer arithmetic: (m * n * b) / 16 against 2048.
// NOTE(review): m * n * b is an unsigned int product and wraps on overflow for very
// large shapes, which would make a huge workload look small.
166 const bool is_workload_big = ((m * n *
b) / 16) >= 2048;
// h0 is n / 4, floored at 1 so it is never zero.
172 const unsigned int h0 = std::max(n / 4, 1U);
173 return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, h0,
false,
true,
false,
true,
false);
// h0 is n / 2, floored at 1. It feeds both returns below, which differ only in the
// fifth argument (16 versus 8).
177 const unsigned int h0 = std::max(n / 2, 1U);
180 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0,
false,
true,
false,
true,
false);
184 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0,
false,
true,
false,
true,
false);
// h0 for the buffer candidate: n / 4 clamped into the range [1, 16], computed in int.
190 const int h0 = std::max(std::min(
static_cast<int>(n / 4),
static_cast<int>(16)),
static_cast<int>(1));
// Two alternative buffer candidates (4, 4, 4 versus 2, 4, 8). Neither passes a sixth flag.
193 std::tie(lhs_info_buf, rhs_info_buf) =
194 configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0,
false,
true,
false,
true);
198 std::tie(lhs_info_buf, rhs_info_buf) =
199 configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0,
false,
true,
false,
true);
// h0 for the image candidate: the same [1, 16] clamp of n / 4 as above, recomputed.
204 const int h0 = std::max(std::min(
static_cast<int>(n / 4),
static_cast<int>(16)),
static_cast<int>(1));
// Two alternative image candidates. Both pass true as the sixth flag.
// NOTE(review): the 4, 4, 4 variant passes false as its fifth flag while the 2, 4, 8
// variant and both buffer variants pass true; confirm this asymmetry is intended.
207 std::tie(lhs_info_img, rhs_info_img) =
208 configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0,
false,
true,
false,
false,
true);
212 std::tie(lhs_info_img, rhs_info_img) =
213 configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0,
false,
true,
false,
true,
true);
// Describes a tensor of shape (n, k, b) with one channel of F32. The code that consumes
// tensor_rhs_info is not visible in this excerpt.
216 const TensorInfo tensor_rhs_info(TensorShape(n, k,
b), 1,
DataType::F32);
// use_cl_image2d is false when m == 1, or when (m * n * b) / 16 < 2048 and n < 128;
// otherwise it is true. The "? false : true" is the logical negation of that condition.
// The workload term repeats the expression already evaluated for is_workload_big.
221 const bool use_cl_image2d = ((m == 1) || ((((m * n *
b) / 16) < 2048) && n < 128)) ? false :
true;
// Final selection between the image and buffer candidates. The condition is not visible
// here; presumably it depends on use_cl_image2d.
225 return std::make_pair(lhs_info_img, rhs_info_img);
229 return std::make_pair(lhs_info_buf, rhs_info_buf);
// F32 heuristic for the "G52" configuration.
//
// Takes the GEMM dimensions m, n, k and b as unsigned ints and returns a
// (GEMMLHSMatrixInfo, GEMMRHSMatrixInfo) pair. Some paths return a
// configure_lhs_rhs_info result directly; others build an image candidate and a buffer
// candidate and hand both to a selection call whose name is not visible in this excerpt.
// Apart from the single workload comparison shown below, the branch conditions are
// not visible either.
233 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32(
234 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// Workload estimate m * n * b / 20, computed in float so the product cannot wrap.
236 const float workload = (
static_cast<float>(m) *
static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
// Ratio n / k in float. NOTE(review): assumes k > 0; with k == 0 a float division
// yields inf or NaN instead of trapping.
237 const float r_nk =
static_cast<float>(n) /
static_cast<float>(k);
// Storage for the buffer and image candidates filled in below.
239 GEMMLHSMatrixInfo lhs_info_buf;
240 GEMMRHSMatrixInfo rhs_info_buf;
241 GEMMLHSMatrixInfo lhs_info_img;
242 GEMMRHSMatrixInfo rhs_info_img;
248 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16,
false,
true,
false,
true,
false);
// Image and buffer candidates with identical block values (1, 4, 8, 1, 16); they differ
// only in the sixth flag (true for img, false for buf).
252 std::tie(lhs_info_img, rhs_info_img) =
253 configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16,
false,
true,
false,
true,
true);
254 std::tie(lhs_info_buf, rhs_info_buf) =
255 configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16,
false,
true,
false,
true,
false);
// Tail of a call whose beginning is not visible: it receives the buffer pair, then
// n, k, b and DataType::F32.
258 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F32);
// Workloads up to 274.4 take the direct return below.
263 if (workload <= 274.4000f)
265 return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 16,
false,
false,
false,
true,
false);
// Image and buffer candidates with identical block values (4, 4, 4, 1, 2); they differ
// only in the sixth flag (true for img, false for buf).
269 std::tie(lhs_info_img, rhs_info_img) =
270 configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2,
false,
false,
false,
true,
true);
271 std::tie(lhs_info_buf, rhs_info_buf) =
272 configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2,
false,
false,
false,
true,
false);
// Tail of a call whose beginning is not visible: same argument pattern as above.
275 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F32);
// F32 heuristic for the "G51" configuration.
//
// Takes the GEMM dimensions m, n, k and b as unsigned ints and returns a
// (GEMMLHSMatrixInfo, GEMMRHSMatrixInfo) pair built by configure_lhs_rhs_info.
// Only m and n appear in the visible lines. The condition separating the two returns
// is not part of this excerpt.
280 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32(
281 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// n0 is 2 for n below 1280 and 4 otherwise.
288 const unsigned int n0 = n < 1280 ? 2 : 4;
// h0 is n / n0, floored at 1 so it is never zero.
289 const unsigned int h0 = std::max(n / n0, 1U);
290 return configure_lhs_rhs_info(m, n, 1, n0, 4, 1, h0,
false,
true,
false,
true);
// Alternative path: fixed values 4, 4, 4, 1, 2.
294 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2,
false,
true,
false,
true);
// F16 heuristic for the "G7x" configuration.
//
// Takes the GEMM dimensions m, n, k and b as unsigned ints and returns a
// (GEMMLHSMatrixInfo, GEMMRHSMatrixInfo) pair built by configure_lhs_rhs_info.
// Only m and n appear in the visible lines. The conditions separating the three
// returns are not part of this excerpt.
298 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16(
299 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// h0 is n / 4, floored at 1; the divisor matches the fourth argument of the call.
308 const unsigned int h0 = std::max(n / 4, 1U);
309 return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0,
false,
true,
false,
true);
// h0 is n / 2, floored at 1; the divisor matches the fourth argument of the call.
313 const unsigned int h0 = std::max(n / 2, 1U);
314 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0,
false,
true,
false,
true);
// Alternative path: fixed values 4, 4, 4, 1, 4.
319 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4,
false,
true,
false,
true);
// F16 heuristic for the "G52" configuration.
//
// Takes the GEMM dimensions m, n, k and b as unsigned ints and returns a
// (GEMMLHSMatrixInfo, GEMMRHSMatrixInfo) pair. Some paths return a
// configure_lhs_rhs_info result directly; others build an image candidate next to a
// buffer candidate and hand both to a selection call whose name is not visible in
// this excerpt. Only three of the branch conditions are visible below.
323 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16(
324 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// Ratio m / n in float. NOTE(review): assumes n > 0.
326 const float r_mn =
static_cast<float>(m) /
static_cast<float>(n);
// Workload estimate m * n * b / 20, computed in float so the product cannot wrap.
327 const float workload = (
static_cast<float>(m) *
static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
// Ratios m / k and n / k in float. NOTE(review): both assume k > 0. No use of r_mk or
// r_nk is visible in this excerpt; presumably the hidden conditions read them.
328 const float r_mk =
static_cast<float>(m) /
static_cast<float>(k);
329 const float r_nk =
static_cast<float>(n) /
static_cast<float>(k);
// Storage for the buffer and image candidates filled in below.
331 GEMMLHSMatrixInfo lhs_info_buf;
332 GEMMRHSMatrixInfo rhs_info_buf;
333 GEMMLHSMatrixInfo lhs_info_img;
334 GEMMRHSMatrixInfo rhs_info_img;
// Buffer candidate for the first group of paths: 1, 4, 16, 1, 16 with the sixth flag false.
338 std::tie(lhs_info_buf, rhs_info_buf) =
339 configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16,
false,
true,
false,
false,
false);
345 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32,
false,
true,
false,
true,
false);
// Image counterpart of the buffer candidate above: same values, sixth flag true.
349 std::tie(lhs_info_img, rhs_info_img) =
350 configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16,
false,
true,
false,
false,
true);
// Tail of a call whose beginning is not visible: it receives the buffer pair, then
// n, k, b and DataType::F16.
352 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F16);
// The next return / image candidate / call tail repeat the three statements above
// with identical arguments, on a different (not visible) branch.
359 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32,
false,
true,
false,
true,
false);
363 std::tie(lhs_info_img, rhs_info_img) =
364 configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16,
false,
true,
false,
false,
true);
366 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F16);
// Buffer candidate for the second group of paths: 5, 8, 4, 1, 2 with all flags false.
372 std::tie(lhs_info_buf, rhs_info_buf) =
373 configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2,
false,
false,
false,
false,
false);
// Workloads up to 362.6 take the direct return below.
375 if (workload <= 362.6000f)
377 return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16,
false,
false,
false,
true,
false);
// Visible thresholds on the m / n ratio and on the workload; their nesting is implied
// by the listing order but the braces are not part of this excerpt.
381 if (r_mn <= 22.6067f)
383 if (workload <= 708.8000f)
// Image candidate 5, 4, 4, 1, 2 with the sixth flag true.
385 std::tie(lhs_info_img, rhs_info_img) =
386 configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2,
false,
false,
false,
false,
true);
// Tail of a call whose beginning is not visible: same argument pattern as above.
388 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F16);
392 return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 16,
false,
false,
false,
false,
false);
399 return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16,
false,
false,
false,
true,
false);
// Same image candidate and call tail as the 5, 4, 4, 1, 2 case above, on another branch.
403 std::tie(lhs_info_img, rhs_info_img) =
404 configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2,
false,
false,
false,
false,
true);
406 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F16);
// F16 heuristic for the "G76" configuration.
//
// Takes the GEMM dimensions m, n, k and b as unsigned ints and returns a
// (GEMMLHSMatrixInfo, GEMMRHSMatrixInfo) pair. The visible structure is a chain of
// workload thresholds (691.6, 4155.2, 7449.6, 16300.8). Smaller workloads return a
// configure_lhs_rhs_info result directly; larger ones build an image candidate and a
// buffer candidate and hand both to a selection call whose name is not visible in
// this excerpt. Braces and else-branches are not part of the excerpt either.
413 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16(
414 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// Early return reached under a condition that is not visible here.
420 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32,
false,
true,
false,
true,
false);
// Ratio m / n in float. NOTE(review): assumes n > 0. No use of r_mn is visible in this
// excerpt; presumably a hidden condition reads it.
424 const float r_mn =
static_cast<float>(m) /
static_cast<float>(n);
// Workload estimate m * n * b / 20, computed in float so the product cannot wrap.
425 const float workload = (
static_cast<float>(m) *
static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
427 if (workload <= 7449.60f)
429 if (workload <= 691.60f)
431 return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 8,
false,
false,
false,
false,
false);
435 if (workload <= 4155.20f)
437 return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16,
false,
false,
false,
false,
false);
441 return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 32,
false,
false,
false,
false,
false);
447 if (workload <= 16300.80f)
// Candidates local to this branch: image 8, 4, 4, 1, 1 (sixth flag true) against
// buffer 5, 2, 8, 1, 16 (all flags false).
451 GEMMLHSMatrixInfo lhs_info_buf;
452 GEMMRHSMatrixInfo rhs_info_buf;
453 GEMMLHSMatrixInfo lhs_info_img;
454 GEMMRHSMatrixInfo rhs_info_img;
456 std::tie(lhs_info_img, rhs_info_img) =
457 configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1,
false,
true,
false,
false,
true);
458 std::tie(lhs_info_buf, rhs_info_buf) =
459 configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16,
false,
false,
false,
false,
false);
// Tail of a call whose beginning is not visible: it receives the buffer pair, then
// n, k, b and DataType::F16.
462 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F16);
466 return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16,
false,
false,
false,
false,
false);
// Candidates local to the last branch: image 5, 4, 4, 1, 2 (sixth flag true) against
// the same buffer values 5, 2, 8, 1, 16.
471 GEMMLHSMatrixInfo lhs_info_buf;
472 GEMMRHSMatrixInfo rhs_info_buf;
473 GEMMLHSMatrixInfo lhs_info_img;
474 GEMMRHSMatrixInfo rhs_info_img;
476 std::tie(lhs_info_img, rhs_info_img) =
477 configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2,
false,
true,
false,
false,
true);
478 std::tie(lhs_info_buf, rhs_info_buf) =
479 configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16,
false,
false,
false,
false,
false);
// Tail of a call whose beginning is not visible: same argument pattern as above.
482 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F16);
// F16 heuristic for the "G51" configuration.
//
// Takes the GEMM dimensions m, n, k and b as unsigned ints and returns a
// (GEMMLHSMatrixInfo, GEMMRHSMatrixInfo) pair built by configure_lhs_rhs_info.
// Only m and n appear in the visible lines. The condition separating the two returns
// is not part of this excerpt.
488 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16(
489 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// n0 is 2 for n below 1280 and 4 otherwise.
496 const unsigned int n0 = n < 1280 ? 2 : 4;
// h0 is n / n0, floored at 1 so it is never zero.
497 const unsigned int h0 = std::max(n / n0, 1U);
498 return configure_lhs_rhs_info(m, n, 1, n0, 8, 1, h0,
false,
true,
false,
true);
// Alternative path: fixed values 4, 4, 4, 1, 2.
502 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2,
false,
true,
false,
true);
// 8-bit heuristic for the "G7x" configuration.
//
// Takes the GEMM dimensions m, n, k and b as unsigned ints and returns a
// (GEMMLHSMatrixInfo, GEMMRHSMatrixInfo) pair built by configure_lhs_rhs_info.
// Only m and n appear in the visible lines. Four return sites are visible; the
// conditions choosing between them are not part of this excerpt.
506 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8(
507 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// h0 is n / 2, floored at 1; the divisor matches the fourth argument of the call.
516 const unsigned int h0 = std::max(n / 2, 1U);
517 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0,
false,
true,
false,
true);
// h0 is n / 4, floored at 1; the divisor matches the fourth argument of the call.
521 const unsigned int h0 = std::max(n / 4, 1U);
522 return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0,
false,
true,
false,
true);
// h0 is n / 2 clamped into the range [1, 128], computed in int. It feeds both returns
// below (1, 2, 4 versus 4, 2, 16).
527 const int h0 = std::max(std::min(
static_cast<int>(n / 2),
static_cast<int>(128)),
static_cast<int>(1));
530 return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, h0,
false,
true,
false,
true);
534 return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0,
false,
true,
false,
true);
// 8-bit heuristic for the "G76" configuration.
//
// Takes the GEMM dimensions m, n, k and b as unsigned ints and returns a
// (GEMMLHSMatrixInfo, GEMMRHSMatrixInfo) pair built by configure_lhs_rhs_info.
// Only m and n appear in the visible lines. The condition separating the two returns
// is not part of this excerpt.
539 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8(
540 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// h0 is n / 2, floored at 1; the divisor matches the fourth argument of the call.
547 const unsigned int h0 = std::max(n / 2, 1U);
548 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0,
false,
true,
false,
true);
// Alternative path: fixed values 4, 4, 16, 1, 2.
552 return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, 2,
false,
true,
false,
true);
// 8-bit heuristic for the "G51" configuration.
//
// Takes the GEMM dimensions m, n, k and b as unsigned ints and returns a
// (GEMMLHSMatrixInfo, GEMMRHSMatrixInfo) pair built by configure_lhs_rhs_info.
// Only m and n appear in the visible lines. The condition separating the two returns
// is not part of this excerpt.
556 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8(
557 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// h0 is n / 2, floored at 1 so it is never zero.
// NOTE(review): the divisor is 2 while the fourth argument of the call below is 4;
// other heuristics in this file divide n by the same value they pass as the fourth
// argument. Confirm this mismatch is intended.
564 const unsigned int h0 = std::max(n / 2, 1U);
565 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0,
false,
true,
false,
true);
// h0 is n / 2, floored at 1; here the divisor matches the fourth argument of the call.
569 const unsigned int h0 = std::max(n / 2, 1U);
570 return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0,
false,
true,
false,
true);