56 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16,
57 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8);
60 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16,
61 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8);
64 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16,
65 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8);
68 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16,
69 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8);
72 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16,
73 &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8);
75 ConfigurationFunctionExecutorPtr func =
nullptr;
96 return (this->*func)(m, n, k,
b);
99 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
108 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4,
false,
true,
false,
true,
false);
112 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 8,
false,
true,
false,
true,
false);
117 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4,
false,
true,
false,
true);
121 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
128 const unsigned int h0 = std::max(n / 2, 1
U);
129 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, 0, 1, 0, 1);
133 const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
136 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, 0, 1, 0, 1);
140 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, 0, 1, 0, 1);
145 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
155 const bool is_workload_big = ((m * n *
b) / 16) >= 2048;
161 const unsigned int h0 = std::max(n / 4, 1
U);
162 return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, h0,
false,
true,
false,
true,
false);
166 const unsigned int h0 = std::max(n / 2, 1
U);
169 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0,
false,
true,
false,
true,
false);
173 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0,
false,
true,
false,
true,
false);
179 const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
182 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0,
false,
true,
false,
true);
186 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0,
false,
true,
false,
true);
191 const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
194 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0,
false,
true,
false,
false,
true);
198 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0,
false,
true,
false,
true,
true);
206 const bool use_cl_image2d = ((m == 1) || ((((m * n * b) / 16) < 2048) && n < 128)) ? false :
true;
210 return std::make_pair(lhs_info_img, rhs_info_img);
214 return std::make_pair(lhs_info_buf, rhs_info_buf);
218 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
220 const float workload = (
static_cast<float>(m) * static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
221 const float r_nk =
static_cast<float>(n) / static_cast<float>(k);
232 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16,
false,
true,
false,
true,
false);
236 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16,
false,
true,
false,
true,
true);
237 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16,
false,
true,
false,
true,
false);
240 std::make_pair(lhs_info_buf, rhs_info_buf),
246 if(workload <= 274.4000f)
248 return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 16,
false,
false,
false,
true,
false);
252 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2,
false,
false,
false,
true,
true);
253 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2,
false,
false,
false,
true,
false);
256 std::make_pair(lhs_info_buf, rhs_info_buf),
262 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
269 const unsigned int n0 = n < 1280 ? 2 : 4;
270 const unsigned int h0 = std::max(n / n0, 1
U);
271 return configure_lhs_rhs_info(m, n, 1, n0, 4, 1, h0,
false,
true,
false,
true);
275 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2,
false,
true,
false,
true);
279 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
288 const unsigned int h0 = std::max(n / 4, 1
U);
289 return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0,
false,
true,
false,
true);
293 const unsigned int h0 = std::max(n / 2, 1
U);
294 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0,
false,
true,
false,
true);
299 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4,
false,
true,
false,
true);
303 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
305 const float r_mn =
static_cast<float>(m) / static_cast<float>(n);
306 const float workload = (
static_cast<float>(m) * static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
307 const float r_mk =
static_cast<float>(m) / static_cast<float>(k);
308 const float r_nk =
static_cast<float>(n) / static_cast<float>(k);
317 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16,
false,
true,
false,
false,
false);
323 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32,
false,
true,
false,
true,
false);
327 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16,
false,
true,
false,
false,
true);
329 std::make_pair(lhs_info_buf, rhs_info_buf),
337 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32,
false,
true,
false,
true,
false);
341 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16,
false,
true,
false,
false,
true);
343 std::make_pair(lhs_info_buf, rhs_info_buf),
350 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2,
false,
false,
false,
false,
false);
352 if(workload <= 362.6000f)
354 return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16,
false,
false,
false,
true,
false);
360 if(workload <= 708.8000f)
362 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2,
false,
false,
false,
false,
true);
364 std::make_pair(lhs_info_buf, rhs_info_buf),
369 return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 16,
false,
false,
false,
false,
false);
376 return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16,
false,
false,
false,
true,
false);
380 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2,
false,
false,
false,
false,
true);
382 std::make_pair(lhs_info_buf, rhs_info_buf),
390 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
396 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32,
false,
true,
false,
true,
false);
400 const float r_mn =
static_cast<float>(m) / static_cast<float>(n);
401 const float workload = (
static_cast<float>(m) * static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
403 if(workload <= 7449.60f)
405 if(workload <= 691.60f)
407 return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 8,
false,
false,
false,
false,
false);
411 if(workload <= 4155.20f)
413 return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16,
false,
false,
false,
false,
false);
417 return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 32,
false,
false,
false,
false,
false);
423 if(workload <= 16300.80f)
432 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1,
false,
true,
false,
false,
true);
433 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16,
false,
false,
false,
false,
false);
436 std::make_pair(lhs_info_buf, rhs_info_buf),
441 return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16,
false,
false,
false,
false,
false);
451 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2,
false,
true,
false,
false,
true);
452 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16,
false,
false,
false,
false,
false);
455 std::make_pair(lhs_info_buf, rhs_info_buf),
462 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
469 const unsigned int n0 = n < 1280 ? 2 : 4;
470 const unsigned int h0 = std::max(n / n0, 1
U);
471 return configure_lhs_rhs_info(m, n, 1, n0, 8, 1, h0,
false,
true,
false,
true);
475 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2,
false,
true,
false,
true);
479 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
488 const unsigned int h0 = std::max(n / 2, 1
U);
489 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0,
false,
true,
false,
true);
493 const unsigned int h0 = std::max(n / 4, 1
U);
494 return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0,
false,
true,
false,
true);
499 const int h0 = std::max(std::min(static_cast<int>(n / 2), static_cast<int>(128)), static_cast<int>(1));
502 return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, h0,
false,
true,
false,
true);
506 return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0,
false,
true,
false,
true);
511 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
518 const unsigned int h0 = std::max(n / 2, 1
U);
519 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0,
false,
true,
false,
true);
523 return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, 2,
false,
true,
false,
true);
527 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
534 const unsigned int h0 = std::max(n / 2, 1
U);
535 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0,
false,
true,
false,
true);
539 const unsigned int h0 = std::max(n / 2, 1
U);
540 return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0,
false,
true,
false,
true);
Basic container for the OpenCL GEMM configuration functions.
bool dot8_supported(const cl::Device &device)
Helper function to check whether the cl_arm_integer_dot_product_int8 extension is supported...
ClGemmDefaultConfigReshapedRhsOnlyBifrost(GPUTarget gpu)
Constructor.
Basic interface for the GEMM kernel configuration.
1 channel, 1 F32 per channel
static CLKernelLibrary & get()
Access the KernelLibrary singleton.
GEMM LHS (Left Hand Side) matrix information.
Manages all the OpenCL kernels compilation and caching, provides accessors for the OpenCL Context...
Copyright (c) 2017-2021 Arm Limited.
1 channel, 1 F16 per channel
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > select_lhs_rhs_info(std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_img, std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_buf, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
Select GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, const GEMMRHSMatrixInfo &rhs_info)
Utility function to validate the image2d OpenCL object support on the RHS reshaped matrix...
GEMM RHS (Right Hand Side) matrix information.
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override
Given M, N, K and B, this method returns the GEMMLHSMatrixInfo and GEMMRHSMatrixInfo to be used...
TensorShape compute_rhs_reshaped_shape(const ITensorInfo &a, const GEMMRHSMatrixInfo &rhs_info)
Calculate the Right Hand Side matrix reshaped shape.
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image)
Configure GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.
T get_function(DataType data_type)
Method to return the GEMM configuration function based on data type.
GPUTarget
Available GPU Targets.
UniqueGemmCommon< Top, Tret > gemm(const GemmArgs &args, const OutputStage &={})
Store the tensor's metadata.
DataType
Available data types.