// Pointer-to-member-function type for CLGEMMDefaultConfigReshapedBifrost
// configuration methods: each such method takes the four unsigned values
// (m, n, k, b) and returns a (GEMMLHSMatrixInfo, GEMMRHSMatrixInfo) pair.
// The per-GPU lookup tables in this file map a DataType to a value of this
// type, which is then invoked through `this->*`.
50 using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (
CLGEMMDefaultConfigReshapedBifrost::*)(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b);
53 static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
55 {
DataType::F32, &CLGEMMDefaultConfigReshapedBifrost::configure_G76_f32 },
56 {
DataType::F16, &CLGEMMDefaultConfigReshapedBifrost::configure_G76_f16 },
64 static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G52 =
66 {
DataType::F32, &CLGEMMDefaultConfigReshapedBifrost::configure_G52_f32 },
67 {
DataType::F16, &CLGEMMDefaultConfigReshapedBifrost::configure_G52_f16 },
75 static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G7x =
77 {
DataType::F32, &CLGEMMDefaultConfigReshapedBifrost::configure_G7x_f32 },
78 {
DataType::F16, &CLGEMMDefaultConfigReshapedBifrost::configure_G7x_f16 },
88 if(gemm_configs_G76.find(data_type) != gemm_configs_G76.end())
90 return (this->*gemm_configs_G76[data_type])(m, n, k,
b);
97 if(gemm_configs_G7x.find(data_type) != gemm_configs_G7x.end())
99 return (this->*gemm_configs_G7x[data_type])(m, n, k,
b);
108 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedBifrost::configure_G7x_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
115 return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16,
true,
false,
false,
true);
119 return configure_lhs_rhs_info(m, n, 5, 4, 4, 2, 16,
false,
true,
false,
true);
123 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedBifrost::configure_G7x_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
130 return configure_lhs_rhs_info(m, n, 4, 2, 8, 8, 2,
true,
true,
true,
false);
134 return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2,
true,
true,
true,
false);
138 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedBifrost::configure_G7x_u8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
147 return configure_lhs_rhs_info(m, n, 4, 2, 16, 2, 2,
true,
false,
false,
true);
151 return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2,
true,
false,
false,
true);
158 return configure_lhs_rhs_info(m, n, 4, 2, 8, 2, 2,
true,
false,
false,
true);
162 return configure_lhs_rhs_info(m, n, 6, 4, 4, 2, 2,
true,
true,
false,
true);
167 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedBifrost::configure_G52_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
169 const float r_mn =
static_cast<float>(m) / static_cast<float>(n);
170 const float workload = (
static_cast<float>(m) * static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
171 const float r_mk =
static_cast<float>(m) / static_cast<float>(k);
172 const float r_nk =
static_cast<float>(n) / static_cast<float>(k);
179 if(workload <= 274.4000f)
185 return configure_lhs_rhs_info(m, n, 4, 2, 4, 4, 4,
false,
true,
true,
false,
false);
189 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
true);
190 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
false);
193 std::make_pair(lhs_info_buf, rhs_info_buf),
199 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
true);
200 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
false);
203 std::make_pair(lhs_info_buf, rhs_info_buf),
211 if(workload <= 542.4000f)
213 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
true);
214 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
false);
217 std::make_pair(lhs_info_buf, rhs_info_buf),
222 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
true,
true,
false,
true,
true);
223 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
true,
true,
false,
true,
false);
226 std::make_pair(lhs_info_buf, rhs_info_buf),
234 if(workload <= 11767.6001f)
236 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
true);
237 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
false);
240 std::make_pair(lhs_info_buf, rhs_info_buf),
245 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
true,
true,
false,
true,
true);
246 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
true,
true,
false,
true,
false);
249 std::make_pair(lhs_info_buf, rhs_info_buf),
255 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
true);
256 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
false);
259 std::make_pair(lhs_info_buf, rhs_info_buf),
266 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedBifrost::configure_G52_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
270 const float workload = (
static_cast<float>(m) * static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
272 if(workload <= 323.4000f)
274 return configure_lhs_rhs_info(m, n, 2, 2, 8, 4, 8,
false,
false,
false,
true,
false);
278 return configure_lhs_rhs_info(m, n, 4, 8, 4, 2, 2,
true,
true,
true,
false,
false);
282 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedBifrost::configure_G76_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
295 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16,
true,
false,
false,
true);
299 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 2, 8, 16,
false,
false,
false,
true);
304 if((m / 4) * (n / 4) >= 2560)
307 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 8,
true,
true,
true,
false,
true);
312 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 1,
true,
true,
true,
false,
true);
320 const bool use_cl_image2d = (n <= 4) ?
false :
true;
324 return std::make_pair(lhs_info_img, rhs_info_img);
328 return std::make_pair(lhs_info_buf, rhs_info_buf);
332 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedBifrost::configure_G76_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
334 const float workload = (
static_cast<float>(m) * static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
335 const float r_mk =
static_cast<float>(m) / static_cast<float>(k);
337 if(workload <= 1595.2000f)
341 if(workload <= 870.4000f)
343 return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 2,
true,
false,
true,
false,
false);
347 return configure_lhs_rhs_info(m, n, 4, 2, 4, 2, 2,
false,
false,
true,
false,
false);
352 return configure_lhs_rhs_info(m, n, 4, 2, 4, 2, 2,
false,
false,
true,
false,
false);
357 return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2,
true,
true,
true,
false,
false);
361 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedBifrost::configure_G76_u8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
368 return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1,
false,
false,
false,
true);
372 return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2,
false,
true,
false,
true);
Basic interface for the GEMM kernel configuration.
bool dot8_supported(const cl::Device &device)
Helper function to check whether the cl_arm_integer_dot_product_int8 extension is supported...
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, const GEMMRHSMatrixInfo &rhs_info)
Utility function to validate the image2d OpenCL object support on the RHS reshaped matrix...
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override
Given M, N, K and B, this method returns the GEMMLHSMatrixInfo and GEMMRHSMatrixInfo to be used...
1 channel, 1 F32 per channel
static CLKernelLibrary & get()
Access the KernelLibrary singleton.
GEMM LHS (Left Hand Side) matrix information.
Copyright (c) 2017-2021 Arm Limited.
1 channel, 1 F16 per channel
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
GEMM RHS (Right Hand Side) matrix information.
quantized, asymmetric fixed-point 8-bit number unsigned
TensorShape compute_rhs_reshaped_shape(const ITensorInfo &a, const GEMMRHSMatrixInfo &rhs_info)
Calculate the Right Hand Side matrix reshaped shape.
CLGEMMDefaultConfigReshapedBifrost(GPUTarget gpu)
Constructor.
quantized, symmetric fixed-point 8-bit number
quantized, symmetric per channel fixed-point 8-bit number
GPUTarget
Available GPU Targets.
Manages all the OpenCL kernels compilation and caching, provides accessors for the OpenCL Context...
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image)
Configure GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.
Store the tensor's metadata.
quantized, asymmetric fixed-point 8-bit number signed
DataType
Available data types.
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > select_lhs_rhs_info(std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_img, std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_buf, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
Select GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.