49 using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (
CLGEMMDefaultConfigReshapedBifrost::*)(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b);
52 &CLGEMMDefaultConfigReshapedBifrost::configure_G7x_f16,
53 &CLGEMMDefaultConfigReshapedBifrost::configure_G7x_u8);
56 &CLGEMMDefaultConfigReshapedBifrost::configure_G52_f16,
57 &CLGEMMDefaultConfigReshapedBifrost::configure_G7x_u8);
60 &CLGEMMDefaultConfigReshapedBifrost::configure_G76_f16,
61 &CLGEMMDefaultConfigReshapedBifrost::configure_G76_u8);
63 ConfigurationFunctionExecutorPtr
func =
nullptr;
79 return (this->*func)(m, n, k,
b);
82 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedBifrost::configure_G7x_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
89 return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16,
true,
false,
false,
true);
93 return configure_lhs_rhs_info(m, n, 5, 4, 4, 2, 16,
false,
true,
false,
true);
97 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedBifrost::configure_G7x_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
104 return configure_lhs_rhs_info(m, n, 4, 2, 8, 8, 2,
true,
true,
true,
false);
108 return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2,
true,
true,
true,
false);
112 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedBifrost::configure_G7x_u8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
121 return configure_lhs_rhs_info(m, n, 4, 2, 16, 2, 2,
true,
false,
false,
true);
125 return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2,
true,
false,
false,
true);
132 return configure_lhs_rhs_info(m, n, 4, 2, 8, 2, 2,
true,
false,
false,
true);
136 return configure_lhs_rhs_info(m, n, 6, 4, 4, 2, 2,
true,
true,
false,
true);
141 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedBifrost::configure_G52_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
143 const float r_mn = static_cast<float>(m) / static_cast<float>(n);
144 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(
b)) / 20.0f;
145 const float r_mk = static_cast<float>(m) / static_cast<float>(k);
146 const float r_nk = static_cast<float>(n) / static_cast<float>(k);
148 GEMMLHSMatrixInfo lhs_info_buf;
149 GEMMRHSMatrixInfo rhs_info_buf;
150 GEMMLHSMatrixInfo lhs_info_img;
151 GEMMRHSMatrixInfo rhs_info_img;
153 if(workload <= 274.4000f)
159 return configure_lhs_rhs_info(m, n, 4, 2, 4, 4, 4,
false,
true,
true,
false,
false);
163 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
true);
164 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
false);
167 std::make_pair(lhs_info_buf, rhs_info_buf),
173 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
true);
174 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
false);
177 std::make_pair(lhs_info_buf, rhs_info_buf),
185 if(workload <= 542.4000f)
187 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
true);
188 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
false);
191 std::make_pair(lhs_info_buf, rhs_info_buf),
196 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
true,
true,
false,
true,
true);
197 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
true,
true,
false,
true,
false);
200 std::make_pair(lhs_info_buf, rhs_info_buf),
208 if(workload <= 11767.6001f)
210 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
true);
211 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
false);
214 std::make_pair(lhs_info_buf, rhs_info_buf),
219 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
true,
true,
false,
true,
true);
220 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1,
true,
true,
false,
true,
false);
223 std::make_pair(lhs_info_buf, rhs_info_buf),
229 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
true);
230 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2,
true,
true,
false,
true,
false);
233 std::make_pair(lhs_info_buf, rhs_info_buf),
240 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedBifrost::configure_G52_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
244 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(
b)) / 20.0f;
246 if(workload <= 323.4000f)
248 return configure_lhs_rhs_info(m, n, 2, 2, 8, 4, 8,
false,
false,
false,
true,
false);
252 return configure_lhs_rhs_info(m, n, 4, 8, 4, 2, 2,
true,
true,
true,
false,
false);
256 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedBifrost::configure_G76_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
261 GEMMLHSMatrixInfo lhs_info_buf;
262 GEMMRHSMatrixInfo rhs_info_buf;
263 GEMMLHSMatrixInfo lhs_info_img;
264 GEMMRHSMatrixInfo rhs_info_img;
269 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16,
true,
false,
false,
true);
273 std::tie(lhs_info_buf, rhs_info_buf) =
configure_lhs_rhs_info(m, n, 4, 4, 2, 8, 16,
false,
false,
false,
true);
278 if((m / 4) * (n / 4) >= 2560)
281 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 8,
true,
true,
true,
false,
true);
286 std::tie(lhs_info_img, rhs_info_img) =
configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 1,
true,
true,
true,
false,
true);
289 const TensorInfo tensor_rhs_info(TensorShape(n, k,
b), 1,
DataType::F32);
294 const bool use_cl_image2d = (n <= 4) ?
false :
true;
298 return std::make_pair(lhs_info_img, rhs_info_img);
302 return std::make_pair(lhs_info_buf, rhs_info_buf);
306 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedBifrost::configure_G76_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
308 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(
b)) / 20.0f;
309 const float r_mk = static_cast<float>(m) / static_cast<float>(k);
311 if(workload <= 1595.2000f)
315 if(workload <= 870.4000f)
317 return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 2,
true,
false,
true,
false,
false);
321 return configure_lhs_rhs_info(m, n, 4, 2, 4, 2, 2,
false,
false,
true,
false,
false);
326 return configure_lhs_rhs_info(m, n, 4, 2, 4, 2, 2,
false,
false,
true,
false,
false);
331 return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2,
true,
true,
true,
false,
false);
335 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedBifrost::configure_G76_u8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
342 return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1,
false,
false,
false,
true);
346 return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2,
false,
true,
false,
true);
Basic interface for the GEMM kernel configuration.
bool dot8_supported(const cl::Device &device)
Helper function to check whether the cl_arm_integer_dot_product_int8 extension is supported.
Basic container for the OpenCL GEMM configuration functions.
Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, const GEMMRHSMatrixInfo &rhs_info)
Utility function to validate the image2d OpenCL object support on the RHS reshaped matrix.
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override
Given M, N, K and B, this method returns the GEMMLHSMatrixInfo and GEMMRHSMatrixInfo to be used.
1 channel, 1 F32 per channel
static CLKernelLibrary & get()
Access the KernelLibrary singleton.
Copyright (c) 2017-2021 Arm Limited.
Bifrost based OpenCL GEMMReshaped configuration.
T get_function(DataType data_type)
Method to return the GEMM configuration function based on data type.
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
TensorShape compute_rhs_reshaped_shape(const ITensorInfo &a, const GEMMRHSMatrixInfo &rhs_info)
Calculate the Right Hand Side matrix reshaped shape.
CLGEMMDefaultConfigReshapedBifrost(GPUTarget gpu)
Constructor.
GPUTarget
Available GPU Targets.
Manages all the OpenCL kernels compilation and caching, provides accessors for the OpenCL Context.
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image)
Configure GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.
DataType
Available data types.
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > select_lhs_rhs_info(std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_img, std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_buf, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
Select GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.