53 using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (
ClGemmDefaultConfigReshapedBifrost::*)(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b);
56 &ClGemmDefaultConfigReshapedBifrost::configure_G7x_f16,
57 &ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8);
60 &ClGemmDefaultConfigReshapedBifrost::configure_G52_f16,
61 &ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8);
64 &ClGemmDefaultConfigReshapedBifrost::configure_G76_f16,
65 &ClGemmDefaultConfigReshapedBifrost::configure_G76_u8);
67 ConfigurationFunctionExecutorPtr func =
nullptr;
83 return (this->*func)(m, n, k,
b);
86 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G7x_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
93 return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16,
true,
false,
false,
true);
97 return configure_lhs_rhs_info(m, n, 5, 4, 4, 2, 16,
false,
true,
false,
true);
101 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G7x_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
108 return configure_lhs_rhs_info(m, n, 4, 2, 8, 8, 2,
true,
true,
true,
false);
112 return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2,
true,
true,
true,
false);
116 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
125 return configure_lhs_rhs_info(m, n, 4, 2, 16, 2, 2,
true,
false,
false,
true);
129 return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2,
true,
false,
false,
true);
136 return configure_lhs_rhs_info(m, n, 4, 2, 8, 2, 2,
true,
false,
false,
true);
140 return configure_lhs_rhs_info(m, n, 6, 4, 4, 2, 2,
true,
true,
false,
true);
145 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G52_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
147 const float r_mn =
static_cast<float>(m) / static_cast<float>(n);
148 const float workload = (
static_cast<float>(m) * static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
149 const float r_mk =
static_cast<float>(m) / static_cast<float>(k);
150 const float r_nk =
static_cast<float>(n) / static_cast<float>(k);
157 if(workload <= 274.4000f)
163 return configure_lhs_rhs_info(m, n, 4, 2, 4, 4, 4,
false,
true,
true,
false,
false);
167 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
true);
168 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
false);
171 std::make_pair(lhs_info_buf, rhs_info_buf),
177 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
true);
178 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
false);
181 std::make_pair(lhs_info_buf, rhs_info_buf),
189 if(workload <= 542.4000f)
191 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
true);
192 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
false);
195 std::make_pair(lhs_info_buf, rhs_info_buf),
200 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
true,
true,
false,
true,
true);
201 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
true,
true,
false,
true,
false);
204 std::make_pair(lhs_info_buf, rhs_info_buf),
212 if(workload <= 11767.6001f)
214 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
true);
215 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
false);
218 std::make_pair(lhs_info_buf, rhs_info_buf),
223 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
true,
true,
false,
true,
true);
224 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
true,
true,
false,
true,
false);
227 std::make_pair(lhs_info_buf, rhs_info_buf),
233 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
true);
234 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
false);
237 std::make_pair(lhs_info_buf, rhs_info_buf),
244 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G52_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
248 const float workload = (
static_cast<float>(m) * static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
250 if(workload <= 323.4000f)
252 return configure_lhs_rhs_info(m, n, 2, 2, 8, 4, 8,
false,
false,
false,
true,
false);
256 return configure_lhs_rhs_info(m, n, 4, 8, 4, 2, 2,
true,
true,
true,
false,
false);
260 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G76_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
273 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16,
true,
false,
false,
true);
277 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 2, 8, 16,
false,
false,
false,
true);
282 if((m / 4) * (n / 4) >= 2560)
285 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 8,
true,
true,
true,
false,
true);
290 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 1,
true,
true,
true,
false,
true);
298 const bool use_cl_image2d = (n <= 4) ?
false :
true;
302 return std::make_pair(lhs_info_img, rhs_info_img);
306 return std::make_pair(lhs_info_buf, rhs_info_buf);
310 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G76_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
312 const float workload = (
static_cast<float>(m) * static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
313 const float r_mk =
static_cast<float>(m) / static_cast<float>(k);
315 if(workload <= 1595.2000f)
319 if(workload <= 870.4000f)
321 return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 2,
true,
false,
true,
false,
false);
325 return configure_lhs_rhs_info(m, n, 4, 2, 4, 2, 2,
false,
false,
true,
false,
false);
330 return configure_lhs_rhs_info(m, n, 4, 2, 4, 2, 2,
false,
false,
true,
false,
false);
335 return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2,
true,
true,
true,
false,
false);
339 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G76_u8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
346 return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1,
false,
false,
false,
true);
350 return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2,
false,
true,
false,
true);
Basic container for the OpenCL GEMM configuration functions.
bool dot8_supported(const cl::Device &device)
Helper function to check whether the cl_arm_integer_dot_product_int8 extension is supported...
Basic interface for the GEMM kernel configuration.
1 channel, 1 F32 per channel
static CLKernelLibrary & get()
Access the KernelLibrary singleton.
GEMM LHS (Left Hand Side) matrix information.
Manages all the OpenCL kernels compilation and caching, provides accessors for the OpenCL Context...
Copyright (c) 2017-2021 Arm Limited.
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > select_lhs_rhs_info(std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_img, std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_buf, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
Select GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.
ClGemmDefaultConfigReshapedBifrost(GPUTarget gpu)
Constructor.
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override
Given M, N, K and B, this method returns the GEMMLHSMatrixInfo and GEMMRHSMatrixInfo to be used...
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, const GEMMRHSMatrixInfo &rhs_info)
Utility function to validate the image2d OpenCL object support on the RHS reshaped matrix...
GEMM RHS (Right Hand Side) matrix information.
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
TensorShape compute_rhs_reshaped_shape(const ITensorInfo &a, const GEMMRHSMatrixInfo &rhs_info)
Calculate the Right Hand Side matrix reshaped shape.
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image)
Configure GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.
T get_function(DataType data_type)
Method to return the GEMM configuration function based on data type.
GPUTarget
Available GPU Targets.
UniqueGemmCommon< Top, Tret > gemm(const GemmArgs &args, const OutputStage &={})
Store the tensor's metadata.
DataType
Available data types.