// Pointer-to-member-function type shared by the per-data-type configuration
// routines: each takes (m, n, k, b) and returns the LHS/RHS matrix info pair.
45 using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (
CLGEMMDefaultConfigReshapedValhall::*)(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b);
// Lookup table from DataType to the matching G77 configuration routine.
// Only the F32 and F16 entries are visible in this excerpt; the listing skips
// the remaining initializer lines.
48 static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G77 =
50 {
DataType::F32, &CLGEMMDefaultConfigReshapedValhall::configure_G77_f32 },
51 {
DataType::F16, &CLGEMMDefaultConfigReshapedValhall::configure_G77_f16 },
// Dispatch only when a routine is registered for data_type, then invoke it on
// this object with the unmodified (m, n, k, b).
// NOTE(review): find() followed by operator[] looks the key up twice; the
// iterator returned by find() could be reused for the call.
// NOTE(review): the branch taken for an unregistered data_type is not visible
// in this excerpt.
62 if(gemm_configs_G77.find(data_type) != gemm_configs_G77.end())
64 return (this->*gemm_configs_G77[data_type])(m, n, k,
b);
// Returns the LHS/RHS matrix info pair for the F32 data type on G77 (target
// and data type taken from the function name and the dispatch table entry).
//
// Both visible return paths forward m and n unchanged to
// configure_lhs_rhs_info, whose parameter order in this listing is:
//   (m, n, m0, n0, k0, v0, h0,
//    lhs_interleave, rhs_interleave, lhs_transpose, rhs_transpose,
//    export_to_cl_image)
// Only four booleans are passed here, so export_to_cl_image is presumably
// defaulted - TODO confirm against the declaration.
// NOTE(review): the condition selecting between the two returns, and any use
// of k and b, are not visible in this excerpt.
73 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedValhall::configure_G77_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// m0=4, n0=2, k0=8, v0=16, h0=16; lhs_interleave and rhs_transpose set.
80 return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16,
true,
false,
false,
true);
// m0=5, n0=4, k0=4, v0=2, h0=16; rhs_interleave and rhs_transpose set.
84 return configure_lhs_rhs_info(m, n, 5, 4, 4, 2, 16,
false,
true,
false,
true);
// Returns the LHS/RHS matrix info pair for the F16 data type on G77 (target
// and data type taken from the function name and the dispatch table entry).
//
// The choice is driven by threshold comparisons on four derived floats:
// workload (m * n * b / 20) and the ratios m/n, m/k and n/k. Some paths return
// a configuration directly; others build an image candidate
// (lhs_info_img / rhs_info_img, export_to_cl_image = true) next to the default
// buffer candidate (lhs_info_buf / rhs_info_buf).
//
// Argument order of configure_lhs_rhs_info, per its declaration in this listing:
//   (m, n, m0, n0, k0, v0, h0,
//    lhs_interleave, rhs_interleave, lhs_transpose, rhs_transpose,
//    export_to_cl_image)
//
// NOTE(review): this excerpt omits braces, else lines and the declarations of
// the *_buf / *_img locals, so the nesting of the conditions below cannot be
// read from here. Each lone `std::make_pair(lhs_info_buf, rhs_info_buf),` line
// is presumably the info_buf argument of a select_lhs_rhs_info(info_img,
// info_buf, n, k, b, data_type) call - confirm against the full source.
88 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedValhall::configure_G77_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// Ratio of m to n.
93 const float r_mn =
static_cast<float>(m) / static_cast<float>(n);
// Product m * n * b scaled down by 20.
94 const float workload = (
static_cast<float>(m) * static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
// Ratio of m to k.
95 const float r_mk =
static_cast<float>(m) / static_cast<float>(k);
// Ratio of n to k.
96 const float r_nk =
static_cast<float>(n) / static_cast<float>(k);
// Default buffer candidate: m0=n0=k0=v0=h0=4, lhs_transpose set,
// export_to_cl_image = false.
103 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4,
false,
false,
true,
false,
false);
// The thresholds below are double literals, so the float operands are promoted
// to double for each comparison.
105 if(r_mk <= 0.11824845522642136)
107 if(workload <= 880.0)
// Buffer-only result: m0=2, n0=4, k0=4, v0=1, h0=4, lhs_transpose set.
109 return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4,
false,
false,
true,
false,
false);
113 if(r_nk <= 0.42521367967128754)
115 if(workload <= 1726.4000244140625)
// Buffer-only result: m0=4, n0=4, k0=4, v0=2, h0=2, lhs_transpose set.
117 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2,
false,
false,
true,
false,
false);
// Image candidate: v0=2, h0=1, rhs_interleave and lhs_transpose set,
// export_to_cl_image = true.
121 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
false,
true,
true,
false,
true);
124 std::make_pair(lhs_info_buf, rhs_info_buf),
130 if(workload <= 1241.6000366210938)
132 return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4,
false,
false,
true,
false,
false);
// Same arguments as the default buffer candidate built above.
136 return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4,
false,
false,
true,
false,
false);
143 if(workload <= 11404.7998046875)
145 if(r_mk <= 1.0126488208770752)
147 if(r_mn <= 2.545312523841858)
149 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
false,
true,
true,
false,
true);
152 std::make_pair(lhs_info_buf, rhs_info_buf),
157 return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4,
false,
false,
true,
false,
false);
162 if(workload <= 2881.199951171875)
// Image candidate without RHS interleave: v0=4, h0=2, lhs_transpose set,
// export_to_cl_image = true.
164 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
false,
false,
true,
false,
true);
167 std::make_pair(lhs_info_buf, rhs_info_buf),
172 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
false,
true,
true,
false,
true);
175 std::make_pair(lhs_info_buf, rhs_info_buf),
182 if(r_nk <= 0.5765306055545807)
184 if(r_mn <= 6.010416746139526)
186 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
false,
true,
true,
false,
true);
189 std::make_pair(lhs_info_buf, rhs_info_buf),
// Image candidate with lhs_interleave (rather than rhs_interleave) set.
194 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
true,
false,
true,
false,
true);
197 std::make_pair(lhs_info_buf, rhs_info_buf),
203 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
true,
false,
true,
false,
true);
206 std::make_pair(lhs_info_buf, rhs_info_buf),
// Returns the LHS/RHS matrix info pair for the "u8" case on G77 (per the
// function name; presumably the 8-bit quantized data types - confirm against
// the dispatch table, whose remaining entries are not visible in this excerpt).
//
// Both visible return paths forward m and n unchanged to
// configure_lhs_rhs_info, whose parameter order in this listing is:
//   (m, n, m0, n0, k0, v0, h0,
//    lhs_interleave, rhs_interleave, lhs_transpose, rhs_transpose,
//    export_to_cl_image)
// Only four booleans are passed here, so export_to_cl_image is presumably
// defaulted - TODO confirm against the declaration.
// NOTE(review): the condition selecting between the two returns, and any use
// of k and b, are not visible in this excerpt.
213 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedValhall::configure_G77_u8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// m0=4, n0=2, k0=16, v0=4, h0=1; only rhs_transpose set.
220 return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1,
false,
false,
false,
true);
// m0=4, n0=4, k0=16, v0=2, h0=2; rhs_interleave and rhs_transpose set.
224 return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2,
false,
true,
false,
true);
Basic interface for the GEMM kernel configuration.
#define ARM_COMPUTE_ERROR(msg)
Prints the given message, then throws a std::runtime_error.
1 channel, 1 F32 per channel
GEMM LHS (Left Hand Side) matrix information.
Copyright (c) 2017-2021 Arm Limited.
1 channel, 1 F16 per channel
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override
Given M, N, K and B, this method returns the GEMMLHSMatrixInfo and GEMMRHSMatrixInfo to be used...
#define ARM_COMPUTE_UNUSED(...)
Avoids unused-variable warnings.
GEMM RHS (Right Hand Side) matrix information.
quantized, asymmetric fixed-point 8-bit number unsigned
quantized, symmetric fixed-point 8-bit number
quantized, symmetric per channel fixed-point 8-bit number
GPUTarget
Available GPU Targets.
Manages all the OpenCL kernels compilation and caching, provides accessors for the OpenCL Context...
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image)
Configure GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.
CLGEMMDefaultConfigReshapedValhall(GPUTarget gpu)
Constructor.
quantized, asymmetric fixed-point 8-bit number signed
DataType
Available data types.
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > select_lhs_rhs_info(std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_img, std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_buf, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
Select GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.