48 using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (
ClGemmDefaultConfigNativeBifrost::*)(
unsigned int m,
unsigned int n,
unsigned int k,
52 &ClGemmDefaultConfigNativeBifrost::configure_G71_f32,
53 &ClGemmDefaultConfigNativeBifrost::configure_G71_u8);
56 &ClGemmDefaultConfigNativeBifrost::configure_G76_f32,
57 &ClGemmDefaultConfigNativeBifrost::configure_G76_u8);
60 &ClGemmDefaultConfigNativeBifrost::configure_default_f32,
61 &ClGemmDefaultConfigNativeBifrost::configure_default_u8);
63 ConfigurationFunctionExecutorPtr func =
nullptr;
79 return (this->*func)(m, n, k,
b);
82 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeBifrost::configure_G71_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
91 return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 1,
false,
false,
false,
false);
93 else if(n >= 2048 && n < 8192)
95 return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, 1,
false,
false,
false,
false);
99 return configure_lhs_rhs_info(m, n, 1, 8, 4, 1, 1,
false,
false,
false,
false);
104 return configure_lhs_rhs_info(m, n, 5, 4, 2, 1, 1,
false,
false,
false,
false);
108 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeBifrost::configure_G71_u8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
119 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 1,
false,
false,
false,
false);
121 else if(n >= 2048 && n < 16384)
123 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1,
false,
false,
false,
false);
127 return configure_lhs_rhs_info(m, n, 1, 8, 16, 1, 1,
false,
false,
false,
false);
134 return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 1,
false,
false,
false,
false);
138 return configure_lhs_rhs_info(m, n, 5, 2, 16, 1, 1,
false,
false,
false,
false);
148 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1,
false,
false,
false,
false);
152 return configure_lhs_rhs_info(m, n, 1, 8, 16, 1, 1,
false,
false,
false,
false);
157 return configure_lhs_rhs_info(m, n, 2, 8, 16, 1, 1,
false,
false,
false,
false);
162 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeBifrost::configure_G76_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
171 return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 1,
false,
false,
false,
false);
177 return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 1,
false,
false,
false,
false);
179 else if(k >= 2048 && k < 16384)
181 return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 1,
false,
false,
false,
false);
185 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 1,
false,
false,
false,
false);
191 return configure_lhs_rhs_info(m, n, 2, 8, 2, 1, 1,
false,
false,
false,
false);
195 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeBifrost::configure_G76_u8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
204 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 1,
false,
false,
false,
false);
206 else if(n >= 2048 && n < 16384)
208 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1,
false,
false,
false,
false);
212 return configure_lhs_rhs_info(m, n, 1, 8, 16, 1, 1,
false,
false,
false,
false);
219 return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 1,
false,
false,
false,
false);
223 return configure_lhs_rhs_info(m, n, 5, 2, 16, 1, 1,
false,
false,
false,
false);
228 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeBifrost::configure_default_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
233 return configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 1,
false,
false,
false,
false);
236 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeBifrost::configure_default_u8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b)
241 return configure_lhs_rhs_info(m, n, 5, 2, 16, 1, 1,
false,
false,
false,
false);
Basic container for the OpenCL GEMM configuration functions.
bool dot8_supported(const cl::Device &device)
Helper function to check whether the cl_arm_integer_dot_product_int8 extension is supported...
Basic interface for the GEMM kernel configuration.
ClGemmDefaultConfigNativeBifrost(GPUTarget gpu)
Constructor.
static CLKernelLibrary & get()
Access the KernelLibrary singleton.
Manages all the OpenCL kernels compilation and caching, provides accessors for the OpenCL Context...
Copyright (c) 2017-2021 Arm Limited.
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override
Given M, N, K and B, this method returns the GEMMLHSMatrixInfo and GEMMRHSMatrixInfo to be used...
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image)
Configure GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.
T get_function(DataType data_type)
Method to return the GEMM configuration function based on data type.
GPUTarget
Available GPU Targets.
UniqueGemmCommon< Top, Tret > gemm(const GemmArgs &args, const OutputStage &={})
DataType
Available data types.