// Lookup table named for G51: maps a DataType to the member function of
// CLGEMMDefaultConfigReshapedRHSOnlyBifrost that builds the configuration.
// Both 8-bit quantized entries (QASYMM8, QSYMM8) use configure_G51_u8.
// NOTE(review): this listing is fragmentary; the outer braces of the
// initializer and the enclosing function header are not visible here.
54 static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G51 =
56 {
DataType::F32, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_f32 },
57 {
DataType::F16, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_f16 },
58 {
DataType::QASYMM8, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_u8 },
59 {
DataType::QSYMM8, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_u8 },
// Lookup table named for G52: maps a DataType to the member function that
// builds the configuration. F32 and F16 have G52-specific routines.
// NOTE(review): the two 8-bit quantized entries point at configure_G7x_u8,
// not at a G52-specific routine; confirm this reuse is intentional.
65 static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G52 =
67 {
DataType::F32, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G52_f32 },
68 {
DataType::F16, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G52_f16 },
69 {
DataType::QASYMM8, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_u8 },
70 {
DataType::QSYMM8, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_u8 },
// Lookup table named for G76: maps a DataType to the member function that
// builds the configuration. Both 8-bit quantized entries (QASYMM8, QSYMM8)
// use configure_G76_u8.
76 static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
78 {
DataType::F32, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_f32 },
79 {
DataType::F16, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_f16 },
80 {
DataType::QASYMM8, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_u8 },
81 {
DataType::QSYMM8, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_u8 },
// Lookup table named for G7x: maps a DataType to the member function that
// builds the configuration. Both 8-bit quantized entries (QASYMM8, QSYMM8)
// use configure_G7x_u8.
87 static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G7x =
89 {
DataType::F32, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_f32 },
90 {
DataType::F16, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_f16 },
91 {
DataType::QASYMM8, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_u8 },
92 {
DataType::QSYMM8, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_u8 },
// Dispatch step: each visible `if` checks that data_type has an entry in one
// of the lookup tables and, if so, calls the mapped member function on this
// object with (m, n, k, b) and returns its result.
// NOTE(review): the code that decides which table applies, and what happens
// when data_type has no entry, is not visible in this listing.
// NOTE(review): find() followed by operator[] looks the key up twice; keeping
// the iterator returned by find() would avoid the second lookup.
100 if(gemm_configs_G76.find(data_type) != gemm_configs_G76.end())
102 return (this->*gemm_configs_G76[data_type])(m, n, k,
b);
109 if(gemm_configs_G52.find(data_type) != gemm_configs_G52.end())
111 return (this->*gemm_configs_G52[data_type])(m, n, k,
b);
118 if(gemm_configs_G51.find(data_type) != gemm_configs_G51.end())
120 return (this->*gemm_configs_G51[data_type])(m, n, k,
b);
127 if(gemm_configs_G7x.find(data_type) != gemm_configs_G7x.end())
129 return (this->*gemm_configs_G7x[data_type])(m, n, k,
b);
// F32 configuration routine referenced by the gemm_configs_G7x table.
// Three alternative results are visible; the conditions that choose between
// them are missing from this listing.
// configure_lhs_rhs_info argument order: m, n, m0, n0, k0, v0, h0,
// lhs_interleave, rhs_interleave, lhs_transpose, rhs_transpose, export_to_cl_image.
138 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// m0=1, n0=2, k0=16, h0=4; RHS interleaved and transposed; export_to_cl_image=false.
147 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4,
false,
true,
false,
true,
false);
// m0=1, n0=4, k0=16, h0=8; RHS interleaved and transposed; export_to_cl_image=false.
151 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 8,
false,
true,
false,
true,
false);
// m0=4, n0=4, k0=4, h0=4. The export_to_cl_image argument is not passed here —
// presumably it has a default in the declaration; confirm.
156 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4,
false,
true,
false,
true);
// F32 configuration routine referenced by the gemm_configs_G76 table.
// The visible statements build buffer-based (lhs_info_buf/rhs_info_buf) and
// image-based (lhs_info_img/rhs_info_img) candidates and return one pair.
// NOTE(review): declarations of the *_buf / *_img locals and most branch
// conditions are missing from this listing.
// configure_lhs_rhs_info argument order: m, n, m0, n0, k0, v0, h0,
// lhs_interleave, rhs_interleave, lhs_transpose, rhs_transpose, export_to_cl_image.
160 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// "Big" means (m * n * b) / 16 is at least 2048. The product is computed in
// unsigned int, so it wraps for very large shapes — NOTE(review): confirm the
// expected range of m, n and b.
170 const bool is_workload_big = ((m * n *
b) / 16) >= 2048;
// h0 = n / 4, but never below 1.
176 const unsigned int h0 = std::max(n / 4, 1
U);
177 return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, h0,
false,
true,
false,
true,
false);
// h0 = n / 2, but never below 1; used by the two returns that follow.
181 const unsigned int h0 = std::max(n / 2, 1
U);
184 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0,
false,
true,
false,
true,
false);
188 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0,
false,
true,
false,
true,
false);
// h0 = n / 4 limited to the range [1, 16], computed as int.
194 const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
// Buffer candidates (export_to_cl_image not passed): 4x4x4 or 2x4x8 blocks.
197 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0,
false,
true,
false,
true);
201 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0,
false,
true,
false,
true);
// Same [1, 16] limit on n / 4 for the image candidates.
206 const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
// Image candidates (export_to_cl_image=true).
// NOTE(review): this first one passes rhs_transpose=false while the second
// passes true — confirm the difference is intentional.
209 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0,
false,
true,
false,
false,
true);
213 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0,
false,
true,
false,
true,
true);
// Image path is disabled when m == 1, or when (m * n * b) / 16 < 2048 and n < 128.
221 const bool use_cl_image2d = ((m == 1) || ((((m * n * b) / 16) < 2048) && n < 128)) ? false :
true;
// Presumably the image pair is returned when use_cl_image2d is set and the
// buffer pair otherwise — the selecting condition is not visible; confirm.
225 return std::make_pair(lhs_info_img, rhs_info_img);
229 return std::make_pair(lhs_info_buf, rhs_info_buf);
// F32 configuration routine referenced by the gemm_configs_G52 table.
// Decisions are driven by float heuristics computed from the GEMM shape.
// NOTE(review): most branch conditions, and the calls that the trailing
// `std::make_pair(...),` fragments belong to, are missing from this listing.
// configure_lhs_rhs_info argument order: m, n, m0, n0, k0, v0, h0,
// lhs_interleave, rhs_interleave, lhs_transpose, rhs_transpose, export_to_cl_image.
233 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G52_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// workload = m * n * b / 20, evaluated in float.
235 const float workload = (
static_cast<float>(m) * static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
// Ratio of n to k, evaluated in float.
236 const float r_nk =
static_cast<float>(n) / static_cast<float>(k);
247 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16,
false,
true,
false,
true,
false);
// Image and buffer candidates with identical block sizes (m0=1, n0=4, k0=8,
// h0=16); they differ only in export_to_cl_image.
251 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16,
false,
true,
false,
true,
true);
252 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16,
false,
true,
false,
true,
false);
// Argument fragment — presumably the info_buf argument of a
// select_lhs_rhs_info call whose other lines are not visible; confirm.
255 std::make_pair(lhs_info_buf, rhs_info_buf),
261 if(workload <= 274.4000f)
// m0=2, n0=2, k0=4, h0=16; RHS transposed but not interleaved.
263 return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 16,
false,
false,
false,
true,
false);
// Image and buffer candidates with identical block sizes (m0=4, n0=4, k0=4,
// h0=2); they differ only in export_to_cl_image.
267 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2,
false,
false,
false,
true,
true);
268 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2,
false,
false,
false,
true,
false);
// Argument fragment, as above; the enclosing call is not visible.
271 std::make_pair(lhs_info_buf, rhs_info_buf),
// F32 configuration routine referenced by the gemm_configs_G51 table.
// Two alternative results are visible; the condition that chooses between
// them is missing from this listing.
// configure_lhs_rhs_info argument order: m, n, m0, n0, k0, v0, h0,
// lhs_interleave, rhs_interleave, lhs_transpose, rhs_transpose, export_to_cl_image
// (the last argument is not passed in either call here).
277 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// n0 is 2 when n is below 1280 and 4 otherwise.
284 const unsigned int n0 = n < 1280 ? 2 : 4;
// h0 = n / n0, but never below 1.
285 const unsigned int h0 = std::max(n / n0, 1
U);
286 return configure_lhs_rhs_info(m, n, 1, n0, 4, 1, h0,
false,
true,
false,
true);
// Fixed 4x4x4 blocks with h0=2.
290 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2,
false,
true,
false,
true);
// F16 configuration routine referenced by the gemm_configs_G7x table.
// Three alternative results are visible; the conditions that choose between
// them are missing from this listing.
// configure_lhs_rhs_info argument order: m, n, m0, n0, k0, v0, h0,
// lhs_interleave, rhs_interleave, lhs_transpose, rhs_transpose, export_to_cl_image
// (the last argument is not passed in any call here).
294 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// h0 = n / 4, but never below 1.
303 const unsigned int h0 = std::max(n / 4, 1
U);
304 return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0,
false,
true,
false,
true);
// h0 = n / 2, but never below 1.
308 const unsigned int h0 = std::max(n / 2, 1
U);
309 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0,
false,
true,
false,
true);
// Fixed 4x4x4 blocks with h0=4.
314 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4,
false,
true,
false,
true);
// F16 configuration routine referenced by the gemm_configs_G52 table.
// Decisions are driven by float heuristics computed from the GEMM shape.
// NOTE(review): most branch conditions, and the calls that the trailing
// `std::make_pair(...),` fragments belong to, are missing from this listing.
// configure_lhs_rhs_info argument order: m, n, m0, n0, k0, v0, h0,
// lhs_interleave, rhs_interleave, lhs_transpose, rhs_transpose, export_to_cl_image.
318 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G52_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// Shape heuristics, all evaluated in float: m/n, m*n*b/20, m/k and n/k.
320 const float r_mn =
static_cast<float>(m) / static_cast<float>(n);
321 const float workload = (
static_cast<float>(m) * static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
322 const float r_mk =
static_cast<float>(m) / static_cast<float>(k);
323 const float r_nk =
static_cast<float>(n) / static_cast<float>(k);
// Buffer candidate: m0=1, n0=4, k0=16, h0=16; RHS interleaved, not transposed.
332 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16,
false,
true,
false,
false,
false);
338 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32,
false,
true,
false,
true,
false);
// Image candidate with the same block sizes as the buffer candidate above.
342 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16,
false,
true,
false,
false,
true);
// Argument fragment — presumably the info_buf argument of a
// select_lhs_rhs_info call whose other lines are not visible; confirm.
344 std::make_pair(lhs_info_buf, rhs_info_buf),
352 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32,
false,
true,
false,
true,
false);
356 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16,
false,
true,
false,
false,
true);
358 std::make_pair(lhs_info_buf, rhs_info_buf),
// Buffer candidate: m0=5, n0=8, k0=4, h0=2; RHS neither interleaved nor transposed.
365 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2,
false,
false,
false,
false,
false);
367 if(workload <= 362.6000f)
369 return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16,
false,
false,
false,
true,
false);
375 if(workload <= 708.8000f)
// Image candidate: m0=5, n0=4, k0=4, h0=2.
377 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2,
false,
false,
false,
false,
true);
379 std::make_pair(lhs_info_buf, rhs_info_buf),
384 return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 16,
false,
false,
false,
false,
false);
391 return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16,
false,
false,
false,
true,
false);
395 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2,
false,
false,
false,
false,
true);
397 std::make_pair(lhs_info_buf, rhs_info_buf),
// F16 configuration routine referenced by the gemm_configs_G76 table.
// The visible conditions compare a float workload value against fixed
// thresholds; other conditions are missing from this listing.
// NOTE(review): the calls that the trailing `std::make_pair(...),` fragments
// belong to are not visible.
// configure_lhs_rhs_info argument order: m, n, m0, n0, k0, v0, h0,
// lhs_interleave, rhs_interleave, lhs_transpose, rhs_transpose, export_to_cl_image.
405 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// m0=1, n0=2, k0=16, h0=32; RHS interleaved and transposed.
411 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32,
false,
true,
false,
true,
false);
// Shape heuristics evaluated in float: m/n and m*n*b/20.
415 const float r_mn =
static_cast<float>(m) / static_cast<float>(n);
416 const float workload = (
static_cast<float>(m) * static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
418 if(workload <= 7449.60f)
420 if(workload <= 691.60f)
422 return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 8,
false,
false,
false,
false,
false);
426 if(workload <= 4155.20f)
428 return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16,
false,
false,
false,
false,
false);
432 return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 32,
false,
false,
false,
false,
false);
438 if(workload <= 16300.80f)
// Image candidate: m0=8, n0=4, k0=4, h0=1; RHS interleaved, not transposed.
447 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1,
false,
true,
false,
false,
true);
// Buffer candidate: m0=5, n0=2, k0=8, h0=16.
448 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16,
false,
false,
false,
false,
false);
// Argument fragment — presumably the info_buf argument of a
// select_lhs_rhs_info call whose other lines are not visible; confirm.
451 std::make_pair(lhs_info_buf, rhs_info_buf),
456 return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16,
false,
false,
false,
false,
false);
// Image candidate: m0=5, n0=4, k0=4, h0=2; RHS interleaved, not transposed.
466 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2,
false,
true,
false,
false,
true);
467 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16,
false,
false,
false,
false,
false);
470 std::make_pair(lhs_info_buf, rhs_info_buf),
// F16 configuration routine referenced by the gemm_configs_G51 table.
// Two alternative results are visible; the condition that chooses between
// them is missing from this listing.
// configure_lhs_rhs_info argument order: m, n, m0, n0, k0, v0, h0,
// lhs_interleave, rhs_interleave, lhs_transpose, rhs_transpose, export_to_cl_image
// (the last argument is not passed in either call here).
477 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// n0 is 2 when n is below 1280 and 4 otherwise.
484 const unsigned int n0 = n < 1280 ? 2 : 4;
// h0 = n / n0, but never below 1.
485 const unsigned int h0 = std::max(n / n0, 1
U);
486 return configure_lhs_rhs_info(m, n, 1, n0, 8, 1, h0,
false,
true,
false,
true);
// Fixed 4x4x4 blocks with h0=2.
490 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2,
false,
true,
false,
true);
// 8-bit configuration routine referenced by the gemm_configs_G7x and
// gemm_configs_G52 tables for QASYMM8 and QSYMM8.
// Four alternative results are visible; the conditions that choose between
// them are missing from this listing.
// configure_lhs_rhs_info argument order: m, n, m0, n0, k0, v0, h0,
// lhs_interleave, rhs_interleave, lhs_transpose, rhs_transpose, export_to_cl_image
// (the last argument is not passed in any call here).
494 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_u8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// h0 = n / 2, but never below 1.
503 const unsigned int h0 = std::max(n / 2, 1
U);
504 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0,
false,
true,
false,
true);
// h0 = n / 4, but never below 1.
508 const unsigned int h0 = std::max(n / 4, 1
U);
509 return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0,
false,
true,
false,
true);
// h0 = n / 2 limited to the range [1, 128], computed as int; used by the two
// returns that follow.
514 const int h0 = std::max(std::min(static_cast<int>(n / 2), static_cast<int>(128)), static_cast<int>(1));
517 return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, h0,
false,
true,
false,
true);
521 return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0,
false,
true,
false,
true);
// 8-bit configuration routine referenced by the gemm_configs_G76 table for
// QASYMM8 and QSYMM8.
// Two alternative results are visible; the condition that chooses between
// them is missing from this listing.
// configure_lhs_rhs_info argument order: m, n, m0, n0, k0, v0, h0,
// lhs_interleave, rhs_interleave, lhs_transpose, rhs_transpose, export_to_cl_image
// (the last argument is not passed in either call here).
526 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_u8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// h0 = n / 2, but never below 1.
533 const unsigned int h0 = std::max(n / 2, 1
U);
534 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0,
false,
true,
false,
true);
// Fixed m0=4, n0=4, k0=16 blocks with h0=2.
538 return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, 2,
false,
true,
false,
true);
// 8-bit configuration routine referenced by the gemm_configs_G51 table for
// QASYMM8 and QSYMM8.
// Two alternative results are visible; the condition that chooses between
// them is missing from this listing.
// configure_lhs_rhs_info argument order: m, n, m0, n0, k0, v0, h0,
// lhs_interleave, rhs_interleave, lhs_transpose, rhs_transpose, export_to_cl_image
// (the last argument is not passed in either call here).
542 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_u8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
// h0 = n / 2, but never below 1.
// NOTE(review): n0 is 4 in the call below while h0 divides n by 2; other
// routines in this listing divide by the same n0 they pass — confirm.
549 const unsigned int h0 = std::max(n / 2, 1
U);
550 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0,
false,
true,
false,
true);
// h0 = n / 2, but never below 1; here n0 is 2.
554 const unsigned int h0 = std::max(n / 2, 1
U);
555 return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0,
false,
true,
false,
true);
Basic interface for the GEMM kernel configuration.
bool dot8_supported(const cl::Device &device)
Helper function to check whether the cl_arm_integer_dot_product_int8 extension is supported...
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, const GEMMRHSMatrixInfo &rhs_info)
Utility function to validate the image2d OpenCL object support on the RHS reshaped matrix...
1 channel, 1 F32 per channel
CLGEMMDefaultConfigReshapedRHSOnlyBifrost(GPUTarget gpu)
Constructor.
static CLKernelLibrary & get()
Access the KernelLibrary singleton.
GEMM LHS (Left Hand Side) matrix information.
Copyright (c) 2017-2021 Arm Limited.
1 channel, 1 F16 per channel
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
GEMM RHS (Right Hand Side) matrix information.
quantized, asymmetric fixed-point 8-bit number unsigned
TensorShape compute_rhs_reshaped_shape(const ITensorInfo &a, const GEMMRHSMatrixInfo &rhs_info)
Calculate the Right Hand Side matrix reshaped shape.
quantized, symmetric fixed-point 8-bit number
quantized, symmetric per channel fixed-point 8-bit number
GPUTarget
Available GPU Targets.
Manages all the OpenCL kernels compilation and caching, provides accessors for the OpenCL Context...
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image)
Configure GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.
Store the tensor's metadata.
quantized, asymmetric fixed-point 8-bit number signed
DataType
Available data types.
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override
Given M, N, K and B, this method returns the GEMMLHSMatrixInfo and GEMMRHSMatrixInfo to be used...
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > select_lhs_rhs_info(std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_img, std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_buf, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
Select GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.