50 static std::map<DataType, FunctionExecutorPtr> gemm_default_configs =
61 static std::map<DataType, FunctionExecutorPtr> gemm_g77_configs =
72 static std::map<DataType, FunctionExecutorPtr> gemm_g78_configs =
87 if(gemm_g78_configs.find(data_type) != gemm_g78_configs.end())
89 return (this->*gemm_g78_configs[data_type])(params.
m, params.
n, params.
k, params.
b, params.
is_rhs_constant);
93 if(gemm_g77_configs.find(data_type) != gemm_g77_configs.end())
95 return (this->*gemm_g77_configs[data_type])(params.
m, params.
n, params.
k, params.
b, params.
is_rhs_constant);
99 if(gemm_default_configs.find(data_type) != gemm_default_configs.end())
101 return (this->*gemm_default_configs[data_type])(params.
m, params.
n, params.
k, params.
b, params.
is_rhs_constant);
107 CLGEMMKernelType CLGEMMDefaultTypeValhall::default_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b,
bool is_rhs_constant)
114 CLGEMMKernelType CLGEMMDefaultTypeValhall::default_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b,
bool is_rhs_constant)
121 CLGEMMKernelType CLGEMMDefaultTypeValhall::g77_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b,
bool is_rhs_constant)
133 const float r_mn =
static_cast<float>(m) / static_cast<float>(n);
134 const float r_mk =
static_cast<float>(m) / static_cast<float>(k);
135 const float r_nk =
static_cast<float>(n) / static_cast<float>(k);
136 const float workload = (
static_cast<float>(m) * static_cast<float>(n) *
static_cast<float>(
b)) / 20.0f;
138 if(r_mk <= 0.6817956566810608)
140 if(workload <= 801.6000061035156)
146 if(r_mn <= 0.0839829258620739)
152 if(r_mk <= 0.24917218834161758)
158 if(workload <= 2551.75)
164 if(workload <= 5061.574951171875)
179 if(r_mk <= 4.849947690963745)
181 if(workload <= 17618.4501953125)
183 if(workload <= 5224.699951171875)
189 if(r_nk <= 0.7933054566383362)
201 if(workload <= 20275.2001953125)
207 if(r_mk <= 3.07421875)
225 CLGEMMKernelType CLGEMMDefaultTypeValhall::default_q8(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b,
bool is_rhs_constant)
239 CLGEMMKernelType CLGEMMDefaultTypeValhall::g78_f32(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b,
bool is_rhs_constant)
298 CLGEMMKernelType CLGEMMDefaultTypeValhall::g78_f16(
unsigned int m,
unsigned int n,
unsigned int k,
unsigned int b,
bool is_rhs_constant)
unsigned int m
Number of rows for the lhs matrix.
OpenCL GEMM kernel selection parameters.
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
1 channel, 1 F32 per channel
Manages all the OpenCL kernels compilation and caching, provides accessors for the OpenCL Context...
Reshaped GEMM kernel where only the rhs matrix is reshaped.
CLGEMMKernelType
OpenCL GEMM kernel types.
Reshaped GEMM kernel where both lhs and rhs matrices are reshaped.
Basic interface for the GEMM kernel selection.
Copyright (c) 2017-2021 Arm Limited.
1 channel, 1 F16 per channel
Valhall based OpenCL GEMMKernel selection.
CLGEMMKernelType select_kernel(const CLGEMMKernelSelectionParams &params) override
Given the input parameters passed through CLGEMMKernelSelectionParams, this method returns the CLGEMM...
CLGEMMDefaultTypeValhall(GPUTarget gpu)
Constructor.
Native GEMM kernel with fixed block size.
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
quantized, asymmetric fixed-point 8-bit number unsigned
bool is_rhs_constant
True if the content of the rhs matrix is constant.
quantized, symmetric fixed-point 8-bit number
quantized, symmetric per channel fixed-point 8-bit number
unsigned int b
Batch size.
GPUTarget
Available GPU Targets.
Native GEMM kernel with configurable block size.
DataType data_type
Data type.
unsigned int n
Number of columns for the rhs matrix.
quantized, asymmetric fixed-point 8-bit number signed
DataType
Available data types.
unsigned int k
Number of rows for the rhs matrix.