54 using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (
58 &ClGemmDefaultConfigReshapedBifrost::configure_G7x_f32, &ClGemmDefaultConfigReshapedBifrost::configure_G7x_f16,
59 &ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8);
62 &ClGemmDefaultConfigReshapedBifrost::configure_G52_f32, &ClGemmDefaultConfigReshapedBifrost::configure_G52_f16,
63 &ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8);
66 &ClGemmDefaultConfigReshapedBifrost::configure_G76_f32, &ClGemmDefaultConfigReshapedBifrost::configure_G76_f16,
67 &ClGemmDefaultConfigReshapedBifrost::configure_G76_u8);
69 ConfigurationFunctionExecutorPtr func =
nullptr;
85 return (this->*func)(m, n, k,
b);
// Heuristic LHS/RHS reshape configuration for F32 GEMM on G7x GPUs.
// Returns one of two hard-coded candidates built by configure_lhs_rhs_info().
// NOTE(review): this excerpt is missing the braces and the condition that
// chooses between the two returns below -- confirm against the full file.
// NOTE(review): the positional arguments after (m, n) presumably follow
// configure_lhs_rhs_info's declared order (block sizes, then interleave /
// transpose flags) -- verify against its declaration.
// Only m and n are referenced in the visible lines; k and b are not.
88 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
89 ClGemmDefaultConfigReshapedBifrost::configure_G7x_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// First candidate configuration (selecting condition not visible here).
96 return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16,
true,
false,
false,
true);
// Second candidate configuration.
100 return configure_lhs_rhs_info(m, n, 5, 4, 4, 2, 16,
false,
true,
false,
true);
// Heuristic LHS/RHS reshape configuration for F16 GEMM on G7x GPUs.
// Returns one of two hard-coded candidates built by configure_lhs_rhs_info().
// The candidates differ only in their block-size arguments; the four
// trailing boolean flags are identical (true, true, true, false).
// NOTE(review): the branch condition and braces are missing from this
// excerpt -- confirm against the full file.
// Only m and n are referenced in the visible lines; k and b are not.
104 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
105 ClGemmDefaultConfigReshapedBifrost::configure_G7x_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// First candidate configuration (selecting condition not visible here).
112 return configure_lhs_rhs_info(m, n, 4, 2, 8, 8, 2,
true,
true,
true,
false);
// Second candidate configuration.
116 return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2,
true,
true,
true,
false);
// Heuristic LHS/RHS reshape configuration for 8-bit (u8) GEMM on G7x GPUs.
// Four hard-coded candidates are visible, each built by configure_lhs_rhs_info().
// NOTE(review): the nested conditions that pick among them (and all braces)
// are missing from this excerpt -- confirm against the full file.
// Only m and n are referenced in the visible lines; k and b are not.
120 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
121 ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// Candidate 1.
130 return configure_lhs_rhs_info(m, n, 4, 2, 16, 2, 2,
true,
false,
false,
true);
// Candidate 2.
134 return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2,
true,
false,
false,
true);
// Candidate 3.
141 return configure_lhs_rhs_info(m, n, 4, 2, 8, 2, 2,
true,
false,
false,
true);
// Candidate 4: the only visible candidate with the second flag set to true.
145 return configure_lhs_rhs_info(m, n, 6, 4, 4, 2, 2,
true,
true,
false,
true);
// Heuristic LHS/RHS reshape configuration for F32 GEMM on G52 GPUs.
// Walks a threshold tree over the shape-derived features computed below
// (workload, r_mn, r_mk). At most leaves it builds two candidate configurations
// that share every argument except the final boolean:
//   *_img uses true, *_buf uses false.
// NOTE(review): the last flag presumably enables exporting RHS to an OpenCL
// image -- verify against configure_lhs_rhs_info's declaration.
// NOTE(review): the calls that consume each img/buf pair are truncated in this
// excerpt. Only their trailing arguments are visible:
//   std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32
// Confirm the called function against the full file.
// NOTE(review): the threshold constants (274.4, 21.1667, 17.3926, 542.4,
// 11767.6001) are unexplained here; presumably they come from offline tuning.
150 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
151 ClGemmDefaultConfigReshapedBifrost::configure_G52_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// Shape-derived features used by the thresholds below. Float division is
// used, so these are real-valued ratios.
153 const float r_mn =
static_cast<float>(m) /
static_cast<float>(n);
// m * n * b scaled down by 20.
154 const float workload = (
static_cast<float>(m) *
static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
155 const float r_mk =
static_cast<float>(m) /
static_cast<float>(k);
// NOTE(review): r_nk is not referenced in the visible lines; its use (if any)
// is on lines omitted from this excerpt.
156 const float r_nk =
static_cast<float>(n) /
static_cast<float>(k);
// Candidate configurations, filled via std::tie at each leaf of the tree.
158 GEMMLHSMatrixInfo lhs_info_buf;
159 GEMMRHSMatrixInfo rhs_info_buf;
160 GEMMLHSMatrixInfo lhs_info_img;
161 GEMMRHSMatrixInfo rhs_info_img;
// Small-workload branch.
163 if (workload <= 274.4000f)
167 if (r_mn <= 21.1667f)
// This leaf returns a single configuration directly (no img/buf pair).
169 return configure_lhs_rhs_info(m, n, 4, 2, 4, 4, 4,
false,
true,
true,
false,
false);
173 std::tie(lhs_info_img, rhs_info_img) =
174 configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
true);
175 std::tie(lhs_info_buf, rhs_info_buf) =
176 configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
false);
179 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F32);
184 std::tie(lhs_info_img, rhs_info_img) =
185 configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
true);
186 std::tie(lhs_info_buf, rhs_info_buf) =
187 configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
false);
190 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F32);
// Larger-workload branches, split on the m/k ratio.
195 if (r_mk <= 17.3926f)
197 if (workload <= 542.4000f)
199 std::tie(lhs_info_img, rhs_info_img) =
200 configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
true);
201 std::tie(lhs_info_buf, rhs_info_buf) =
202 configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
false);
205 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F32);
// Same as the preceding leaf except for the 4th/5th block arguments (2, 1).
209 std::tie(lhs_info_img, rhs_info_img) =
210 configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
true,
true,
false,
true,
true);
211 std::tie(lhs_info_buf, rhs_info_buf) =
212 configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
true,
true,
false,
true,
false);
215 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F32);
222 if (workload <= 11767.6001f)
224 std::tie(lhs_info_img, rhs_info_img) =
225 configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
true);
226 std::tie(lhs_info_buf, rhs_info_buf) =
227 configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
false);
230 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F32);
234 std::tie(lhs_info_img, rhs_info_img) =
235 configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
true,
true,
false,
true,
true);
236 std::tie(lhs_info_buf, rhs_info_buf) =
237 configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
true,
true,
false,
true,
false);
240 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F32);
245 std::tie(lhs_info_img, rhs_info_img) =
246 configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
true);
247 std::tie(lhs_info_buf, rhs_info_buf) =
248 configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
false);
251 std::make_pair(lhs_info_buf, rhs_info_buf), n, k,
b,
DataType::F32);
// Heuristic LHS/RHS reshape configuration for F16 GEMM on G52 GPUs.
// Picks between two hard-coded configurations using a single threshold on
// workload (m * n * b / 20). The excerpt suggests the first return applies
// below the threshold and the second otherwise.
// NOTE(review): the braces and any else-branch framing are missing from this
// excerpt -- confirm against the full file. k is not referenced in the
// visible lines.
257 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
258 ClGemmDefaultConfigReshapedBifrost::configure_G52_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
262 const float workload = (
static_cast<float>(m) *
static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
264 if (workload <= 323.4000f)
266 return configure_lhs_rhs_info(m, n, 2, 2, 8, 4, 8,
false,
false,
false,
true,
false);
270 return configure_lhs_rhs_info(m, n, 4, 8, 4, 2, 2,
true,
true,
true,
false,
false);
// Heuristic LHS/RHS reshape configuration for F32 GEMM on G76 GPUs.
// Builds two candidates:
//  - a "buffer" configuration (lhs_info_buf/rhs_info_buf), and
//  - an "image" configuration (lhs_info_img/rhs_info_img), built with a
//    trailing `true` argument.
// It then returns either the image pair or the buffer pair.
// NOTE(review): the conditions choosing the buffer candidate, and the logic
// linking use_cl_image2d / tensor_rhs_info to the final return, are on lines
// missing from this excerpt -- confirm against the full file.
274 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
275 ClGemmDefaultConfigReshapedBifrost::configure_G76_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
280 GEMMLHSMatrixInfo lhs_info_buf;
281 GEMMRHSMatrixInfo rhs_info_buf;
282 GEMMLHSMatrixInfo lhs_info_img;
283 GEMMRHSMatrixInfo rhs_info_img;
// Buffer candidate, option A (selecting condition not visible here).
288 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16,
true,
false,
false,
true);
// Buffer candidate, option B.
292 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 2, 8, 16,
false,
false,
false,
true);
// Image candidate: chosen by a size test on m and n. m and n are unsigned,
// so m / 4 and n / 4 truncate before the multiply.
297 if ((m / 4) * (n / 4) >= 2560)
300 std::tie(lhs_info_img, rhs_info_img) =
301 configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 8,
true,
true,
true,
false,
true);
306 std::tie(lhs_info_img, rhs_info_img) =
307 configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 1,
true,
true,
true,
false,
true);
// RHS tensor descriptor (n, k, b), F32.
// NOTE(review): presumably used to validate image2d support; its consumer is
// not visible here.
310 const TensorInfo tensor_rhs_info(TensorShape(n, k,
b), 1,
DataType::F32);
// The image path is disabled when n <= 4.
315 const bool use_cl_image2d = (n <= 4) ?
false :
true;
319 return std::make_pair(lhs_info_img, rhs_info_img);
323 return std::make_pair(lhs_info_buf, rhs_info_buf);
// Heuristic LHS/RHS reshape configuration for F16 GEMM on G76 GPUs.
// Chooses among hard-coded configurations using thresholds on workload
// (m * n * b / 20). Two of the visible candidates are identical:
// the ones originally at lines 343 and 348.
// NOTE(review): r_mk is computed but not referenced in the visible lines; the
// condition(s) that use it, and all braces, are missing from this excerpt --
// confirm against the full file.
327 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
328 ClGemmDefaultConfigReshapedBifrost::configure_G76_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
330 const float workload = (
static_cast<float>(m) *
static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
331 const float r_mk =
static_cast<float>(m) /
static_cast<float>(k);
333 if (workload <= 1595.2000f)
337 if (workload <= 870.4000f)
339 return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 2,
true,
false,
true,
false,
false);
343 return configure_lhs_rhs_info(m, n, 4, 2, 4, 2, 2,
false,
false,
true,
false,
false);
348 return configure_lhs_rhs_info(m, n, 4, 2, 4, 2, 2,
false,
false,
true,
false,
false);
// Visible fallback candidate, presumably for larger workloads.
353 return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2,
true,
true,
true,
false,
false);
// Heuristic LHS/RHS reshape configuration for 8-bit (u8) GEMM on G76 GPUs.
// Returns one of two hard-coded candidates built by configure_lhs_rhs_info().
// NOTE(review): the selecting condition and braces are missing from this
// excerpt -- confirm against the full file.
// Only m and n are referenced in the visible lines; k and b are not.
357 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
358 ClGemmDefaultConfigReshapedBifrost::configure_G76_u8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// First candidate configuration.
365 return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1,
false,
false,
false,
true);
// Second candidate configuration.
369 return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2,
false,
true,
false,
true);