48 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b,
bool is_rhs_constant);
51 static std::map<DataType, FunctionExecutorPtr> gemm_default_configs = {
60 static std::map<DataType, FunctionExecutorPtr> gemm_g71_configs = {
69 static std::map<DataType, FunctionExecutorPtr> gemm_g52_configs = {
78 static std::map<DataType, FunctionExecutorPtr> gemm_g76_configs = {
91 if (gemm_g71_configs.find(
data_type) != gemm_g71_configs.end())
93 return (this->*gemm_g71_configs[
data_type])(params.
m, params.
n, params.
k, params.
b,
98 if (gemm_g76_configs.find(
data_type) != gemm_g76_configs.end())
100 return (this->*gemm_g76_configs[
data_type])(params.
m, params.
n, params.
k, params.
b,
105 if (gemm_g52_configs.find(
data_type) != gemm_g52_configs.end())
107 return (this->*gemm_g52_configs[
data_type])(params.
m, params.
n, params.
k, params.
b,
112 if (gemm_default_configs.find(
data_type) != gemm_default_configs.end())
114 return (this->*gemm_default_configs[
data_type])(params.
m, params.
n, params.
k, params.
b,
122 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b,
bool is_rhs_constant)
130 if ((m > 1) && (n < 16))
140 if ((k > 256) && (m > 4))
142 constexpr
float alpha = 3.2f;
143 constexpr
float fact0 = 1.51f;
144 constexpr
float fact1 = 1.66f;
145 constexpr
float ops = 12.0f;
146 const float scale = k > 1024 ? 1.07f : 1.0f;
147 gemm_type = (alpha + ((n * fact0) / ops) < ((fact1 * n *
scale) / ops))
157 const auto workload =
static_cast<float>((m * n) / 20.0f);
167 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b,
bool is_rhs_constant)
189 unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b,
bool is_rhs_constant)
204 CLGEMMDefaultTypeBifrost::g76_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b,
bool is_rhs_constant)
208 if (!is_rhs_constant)
262 CLGEMMDefaultTypeBifrost::g52_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b,
bool is_rhs_constant)
266 if (!is_rhs_constant)
276 const float r_mn =
static_cast<float>(m) /
static_cast<float>(n);
277 const float r_mk =
static_cast<float>(m) /
static_cast<float>(k);
278 const float r_nk =
static_cast<float>(n) /
static_cast<float>(k);
279 const float r_mnk =
static_cast<float>(m) / (
static_cast<float>(n) *
static_cast<float>(k));
287 if (r_mnk <= 77.5833f)
310 if (r_mnk <= 193.0000f)
336 if (r_mn <= 17.7370f)
338 if (r_mnk <= 1391.2875f)
346 if (r_mnk <= 470.0000f)
360 if (r_mnk <= 9040.5000f)
390 CLGEMMDefaultTypeBifrost::g76_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b,
bool is_rhs_constant)
394 if (!is_rhs_constant)
404 const float r_mn =
static_cast<float>(m) /
static_cast<float>(n);
405 const float r_nk =
static_cast<float>(n) /
static_cast<float>(k);
413 if (r_nk <= 0.4990234375f)
439 if (r_mn <= 0.04475911520421505f)
453 CLGEMMDefaultTypeBifrost::g52_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b,
bool is_rhs_constant)
455 if (!is_rhs_constant)
567 CLGEMMDefaultTypeBifrost::g71_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b,
bool is_rhs_constant)