Compute Library 23.08
CLGEMMDefaultTypeBifrost.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2020-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
29 
30 #include <map>
31 #include <utility>
32 
33 namespace arm_compute
34 {
35 namespace cl_gemm
36 {
39 {
 // Constructor body of CLGEMMDefaultTypeBifrost(GPUTarget gpu); it is intentionally empty.
 // NOTE(review): the signature and any initializer list (source lines 37-38) are not shown in
 // this extracted listing; presumably the GPU target is forwarded to the ICLGEMMKernelSelection
 // base, which provides the _target member read by select_kernel — confirm.
40 }
41 
43 {
44  // _target could be used in the future to have a dedicated heuristic for each GPU IP
45  ARM_COMPUTE_UNUSED(_target);
46 
47  using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMDefaultTypeBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
48 
49  // Default configurations for Bifrost architectures
50  static std::map<DataType, FunctionExecutorPtr> gemm_default_configs =
51  {
52  { DataType::F32, &CLGEMMDefaultTypeBifrost::default_f32 },
53  { DataType::F16, &CLGEMMDefaultTypeBifrost::default_f16 },
54  { DataType::QASYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
55  { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeBifrost::default_q8 },
56  { DataType::QSYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
57  { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeBifrost::default_q8 }
58  };
59 
60  // Mali-G71 configurations
61  static std::map<DataType, FunctionExecutorPtr> gemm_g71_configs =
62  {
63  { DataType::F32, &CLGEMMDefaultTypeBifrost::default_f32 },
64  { DataType::F16, &CLGEMMDefaultTypeBifrost::g71_f16 },
65  { DataType::QASYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
66  { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeBifrost::default_q8 },
67  { DataType::QSYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
68  { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeBifrost::default_q8 }
69  };
70 
71  // Mali-G52 configurations
72  static std::map<DataType, FunctionExecutorPtr> gemm_g52_configs =
73  {
74  { DataType::F32, &CLGEMMDefaultTypeBifrost::g52_f32 },
75  { DataType::F16, &CLGEMMDefaultTypeBifrost::g52_f16 },
76  { DataType::QASYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
77  { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeBifrost::default_q8 },
78  { DataType::QSYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
79  { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeBifrost::default_q8 }
80  };
81 
82  // Mali-G76 configurations
83  static std::map<DataType, FunctionExecutorPtr> gemm_g76_configs =
84  {
85  { DataType::F32, &CLGEMMDefaultTypeBifrost::g76_f32 },
86  { DataType::F16, &CLGEMMDefaultTypeBifrost::g76_f16 },
87  { DataType::QASYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
88  { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeBifrost::default_q8 },
89  { DataType::QSYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
90  { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeBifrost::default_q8 }
91  };
92 
93  const DataType data_type = params.data_type;
94 
95  switch(_target)
96  {
97  case GPUTarget::G71:
98  if(gemm_g71_configs.find(data_type) != gemm_g71_configs.end())
99  {
100  return (this->*gemm_g71_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
101  }
102  ARM_COMPUTE_ERROR("Not supported data type");
103  case GPUTarget::G76:
104  if(gemm_g76_configs.find(data_type) != gemm_g76_configs.end())
105  {
106  return (this->*gemm_g76_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
107  }
108  ARM_COMPUTE_ERROR("Not supported data type");
109  case GPUTarget::G52:
110  if(gemm_g52_configs.find(data_type) != gemm_g52_configs.end())
111  {
112  return (this->*gemm_g52_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
113  }
114  ARM_COMPUTE_ERROR("Not supported data type");
115  default:
116  if(gemm_default_configs.find(data_type) != gemm_default_configs.end())
117  {
118  return (this->*gemm_default_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
119  }
120  ARM_COMPUTE_ERROR("Not supported data type");
121  }
122 }
123 
// Generic Bifrost heuristic for F32 GEMM.
// It is the F32 entry of both the default table and the Mali-G71 table in select_kernel.
// NOTE(review): several statements of this function are not present in this extracted listing:
// the declaration/initial value of gemm_type and the assignments in the simple branches.
// The comments below describe only the visible code.
124 CLGEMMKernelType CLGEMMDefaultTypeBifrost::default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
125 {
127 
129 
130  if(is_rhs_constant)
131  {
132  if((m > 1) && (n < 16))
133  {
135  }
136  else if(m == 1)
137  {
139  }
140  else
141  {
142  if((k > 256) && (m > 4))
143  {
 // Simple cost comparison: choose RESHAPED when alpha + n * fact0 / ops is below
 // fact1 * n * scale / ops, otherwise RESHAPED_ONLY_RHS.
 // scale is 1.07 for k > 1024 and 1.0 otherwise.
 // NOTE(review): the two sides are presumably estimated costs of RESHAPED and
 // RESHAPED_ONLY_RHS respectively; the constants look empirically tuned — confirm.
144  constexpr float alpha = 3.2f;
145  constexpr float fact0 = 1.51f;
146  constexpr float fact1 = 1.66f;
147  constexpr float ops = 12.0f;
148  const float scale = k > 1024 ? 1.07f : 1.0f;
149  gemm_type = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) ? CLGEMMKernelType::RESHAPED : CLGEMMKernelType::RESHAPED_ONLY_RHS;
150  }
151  else
152  {
154  }
155  }
156 
 // NOTE(review): m * n is evaluated in unsigned int before the float division, so it wraps for
 // very large m and n. static_cast<float>(m) * n / 20.0f would avoid that; the outer
 // static_cast<float> is redundant because the expression is already float.
157  const auto workload = static_cast<float>((m * n) / 20.0f);
158 
160  }
161 
162  return gemm_type;
163 }
164 
// Generic Bifrost heuristic for F16 GEMM (the F16 entry of the default table in select_kernel).
// Only m and is_rhs_constant influence the choice; n, k and b are explicitly marked unused.
// NOTE(review): the return statement of each branch is not present in this extracted listing.
165 CLGEMMKernelType CLGEMMDefaultTypeBifrost::default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
166 {
167  ARM_COMPUTE_UNUSED(n, k, b);
168 
169  if(is_rhs_constant)
170  {
 // Constant RHS: the choice depends only on whether the LHS has a single row.
171  if(m == 1)
172  {
174  }
175  else
176  {
178  }
179  }
180  else
181  {
183  }
184 }
185 
// Heuristic for 8-bit quantized GEMM on every Bifrost table.
// select_kernel maps QASYMM8, QASYMM8_SIGNED, QSYMM8 and QSYMM8_PER_CHANNEL to it.
// The shape (m, n, k, b) is ignored; only is_rhs_constant selects the branch.
// NOTE(review): the return statement of each branch is not present in this extracted listing.
186 CLGEMMKernelType CLGEMMDefaultTypeBifrost::default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
187 {
188  ARM_COMPUTE_UNUSED(m, n, k, b);
189 
190  if(is_rhs_constant)
191  {
193  }
194  else
195  {
197  }
198 }
199 
// Mali-G76 heuristic for F32 GEMM (the F32 entry of the G76 table in select_kernel).
// Special cases for a non-constant RHS and for m == 1 come first.
// After them, a decision tree on k, n and m picks the kernel type.
// NOTE(review): the leaf return statements (and any line marking b unused) are not present in
// this extracted listing. The thresholds look empirically tuned — confirm before editing them.
200 CLGEMMKernelType CLGEMMDefaultTypeBifrost::g76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
201 {
203 
204  if(!is_rhs_constant)
205  {
207  }
208  if(m == 1)
209  {
211  }
 // Short reduction dimension: split on n only.
212  if(k <= 496)
213  {
214  if(n <= 544)
215  {
217  }
218  else
219  {
221  }
222  }
223  else
224  {
 // Longer k: narrow bands of k (<= 552, <= 588) are further split on m.
225  if(k <= 588)
226  {
227  if(k <= 552)
228  {
229  if(m <= 148)
230  {
232  }
233  else
234  {
235  if(m <= 278)
236  {
238  }
239  else
240  {
242  }
243  }
244  }
245  else
246  {
248  }
249  }
250  else
251  {
253  }
254  }
255 }
256 
// Mali-G52 heuristic for F32 GEMM (the F32 entry of the G52 table in select_kernel).
// Special cases for a non-constant RHS and for m == 1 come first.
// After them, a decision tree on the shape ratios m/n, m/k, n/k and m/(n*k) picks the kernel type.
// NOTE(review): the leaf return statements are not present in this extracted listing. The
// four-decimal thresholds look like the output of an offline-trained decision tree — confirm
// before hand-tuning them.
257 CLGEMMKernelType CLGEMMDefaultTypeBifrost::g52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
258 {
260 
261  if(!is_rhs_constant)
262  {
264  }
265 
266  if(m == 1)
267  {
269  }
270 
 // Shape ratios used by the tree. They are computed in float, so n == 0 or k == 0 would yield
 // inf/NaN rather than trapping; callers are presumably expected to pass non-zero sizes.
271  const float r_mn = static_cast<float>(m) / static_cast<float>(n);
272  const float r_mk = static_cast<float>(m) / static_cast<float>(k);
273  const float r_nk = static_cast<float>(n) / static_cast<float>(k);
274  const float r_mnk = static_cast<float>(m) / (static_cast<float>(n) * static_cast<float>(k));
275 
 // Top-level split: m not much larger than n (m/n <= ~1.55) versus tall LHS matrices.
276  if(r_mn <= 1.5469f)
277  {
278  if(r_mk <= 0.8766f)
279  {
280  if(r_mk <= 0.0211f)
281  {
282  if(r_mnk <= 77.5833f)
283  {
285  }
286  else
287  {
289  }
290  }
291  else
292  {
293  if(r_nk <= 0.0832f)
294  {
296  }
297  else
298  {
300  }
301  }
302  }
303  else
304  {
305  if(r_mnk <= 193.0000f)
306  {
307  if(r_mn <= 0.9948f)
308  {
309  if(r_mk <= 2.5453f)
310  {
312  }
313  else
314  {
316  }
317  }
318  else
319  {
321  }
322  }
323  else
324  {
326  }
327  }
328  }
329  else
330  {
331  if(r_mn <= 17.7370f)
332  {
333  if(r_mnk <= 1391.2875f)
334  {
335  if(r_mk <= 2.9724f)
336  {
338  }
339  else
340  {
341  if(r_mnk <= 470.0000f)
342  {
344  }
345  else
346  {
348  }
349  }
350  }
351  else
352  {
353  if(r_nk <= 0.1381f)
354  {
355  if(r_mnk <= 9040.5000f)
356  {
358  }
359  else
360  {
362  }
363  }
364  else
365  {
366  if(r_mn <= 5.6790f)
367  {
369  }
370  else
371  {
373  }
374  }
375  }
376  }
377  else
378  {
380  }
381  }
382 }
383 
// Mali-G76 heuristic for F16 GEMM (the F16 entry of the G76 table in select_kernel).
// Special cases for a non-constant RHS and for m == 1 come first.
// After them, a decision tree on k, m and the ratios n/k and m/n picks the kernel type.
// NOTE(review): the leaf return statements are not present in this extracted listing.
384 CLGEMMKernelType CLGEMMDefaultTypeBifrost::g76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
385 {
387 
388  if(!is_rhs_constant)
389  {
391  }
392 
393  if(m == 1)
394  {
396  }
397 
 // Float ratios used by the tree (n == 0 or k == 0 would produce inf/NaN, not a trap).
398  const float r_mn = static_cast<float>(m) / static_cast<float>(n);
399  const float r_nk = static_cast<float>(n) / static_cast<float>(k);
400 
401  if(k <= 212)
402  {
404  }
405  else
406  {
 // Long reduction dimension: split first on whether n is below about half of k.
407  if(r_nk <= 0.4990234375f)
408  {
409  if(k <= 1392)
410  {
412  }
413  else
414  {
415  if(m <= 325)
416  {
418  }
419  else
420  {
422  }
423  }
424  }
425  else
426  {
427  if(k <= 471)
428  {
430  }
431  else
432  {
433  if(r_mn <= 0.04475911520421505f)
434  {
436  }
437  else
438  {
440  }
441  }
442  }
443  }
444 }
445 
// Mali-G52 heuristic for F16 GEMM (the F16 entry of the G52 table in select_kernel).
// Special cases for a non-constant RHS and for m == 1 come first.
// After them, a decision tree on n, m, b (batch) and k picks the kernel type.
// The unsigned operands are compared against float literals, so they are converted to float.
// A half-integer threshold such as n <= 63.5000f therefore behaves like the integer cut n <= 63.
// NOTE(review): the leaf return statements are not present in this extracted listing. The
// thresholds look like the output of an offline-trained decision tree — confirm before editing.
446 CLGEMMKernelType CLGEMMDefaultTypeBifrost::g52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
447 {
448  if(!is_rhs_constant)
449  {
451  }
452 
453  if(m == 1)
454  {
456  }
457 
 // Narrow RHS (n <= 127).
458  if(n <= 127.0000f)
459  {
460  if(n <= 63.5000f)
461  {
463  }
464  else
465  {
466  if(m <= 3616.0000f)
467  {
468  if(b <= 18.5000f)
469  {
470  if(m <= 2970.5000f)
471  {
473  }
474  else
475  {
476  if(k <= 104.0000f)
477  {
479  }
480  else
481  {
483  }
484  }
485  }
486  else
487  {
489  }
490  }
491  else
492  {
494  }
495  }
496  }
497  else
498  {
 // Wider RHS (n > 127).
499  if(m <= 12.5000f)
500  {
502  }
503  else
504  {
505  if(k <= 104.0000f)
506  {
507  if(b <= 18.5000f)
508  {
509  if(m <= 490.0000f)
510  {
511  if(n <= 272.0000f)
512  {
514  }
515  else
516  {
518  }
519  }
520  else
521  {
523  }
524  }
525  else
526  {
528  }
529  }
530  else
531  {
532  if(m <= 226.0000f)
533  {
534  if(n <= 140.0000f)
535  {
536  if(m <= 179.5000f)
537  {
539  }
540  else
541  {
543  }
544  }
545  else
546  {
548  }
549  }
550  else
551  {
553  }
554  }
555  }
556  }
557 }
558 
// Mali-G71 heuristic for F16 GEMM (the F16 entry of the G71 table in select_kernel).
// The branches test only is_rhs_constant and, for a constant RHS, whether m == 1.
// NOTE(review): source lines 561-563 (presumably the unused-parameter markers) and the return
// statements of each branch are not present in this extracted listing.
559 CLGEMMKernelType CLGEMMDefaultTypeBifrost::g71_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
560 {
564 
565  if(is_rhs_constant)
566  {
567  if(m == 1)
568  {
570  }
571  else
572  {
574  }
575  }
576  else
577  {
579  }
580 }
581 } // namespace cl_gemm
582 } // namespace arm_compute
arm_compute::DataType::QSYMM8_PER_CHANNEL
@ QSYMM8_PER_CHANNEL
quantized, symmetric per channel fixed-point 8-bit number
arm_compute::CLGEMMKernelSelectionParams
OpenCL GEMM kernel selection parameters.
Definition: CLTypes.h:44
arm_compute::CLGEMMKernelSelectionParams::data_type
DataType data_type
Data type.
Definition: CLTypes.h:51
arm_compute::DataType::QASYMM8
@ QASYMM8
quantized, asymmetric fixed-point 8-bit number, unsigned
arm_compute::DataType::QSYMM8
@ QSYMM8
quantized, symmetric fixed-point 8-bit number
ARM_COMPUTE_ERROR
#define ARM_COMPUTE_ERROR(msg)
Prints the given message, then throws a std::runtime_error.
Definition: Error.h:353
arm_compute::CLGEMMKernelSelectionParams::n
unsigned int n
Number of columns for the rhs matrix.
Definition: CLTypes.h:47
arm_compute::cl_gemm::CLGEMMDefaultTypeBifrost::CLGEMMDefaultTypeBifrost
CLGEMMDefaultTypeBifrost(GPUTarget gpu)
Constructor.
Definition: CLGEMMDefaultTypeBifrost.cpp:37
arm_compute::test::validation::k
const unsigned int k
Definition: GEMMMatrixMultiplyNative.cpp:361
ClGemmHelpers.h
arm_compute::CLGEMMKernelType::NATIVE
@ NATIVE
Native GEMM kernel with configurable block size.
arm_compute::CLGEMMKernelSelectionParams::m
unsigned int m
Number of rows for the lhs matrix.
Definition: CLTypes.h:46
CLKernelLibrary.h
Manages the compilation and caching of all OpenCL kernels, and provides accessors for the OpenCL context.
arm_compute::CLGEMMKernelSelectionParams::is_rhs_constant
bool is_rhs_constant
True if the content of the rhs matrix is constant.
Definition: CLTypes.h:50
arm_compute::CLGEMMKernelType
CLGEMMKernelType
OpenCL GEMM kernel types.
Definition: CLTypes.h:31
arm_compute::test::validation::m
const unsigned int m
Definition: GEMMMatrixMultiplyNative.cpp:359
arm_compute::GPUTarget::G52
@ G52
arm_compute::mlgo::parser::gemm_type
GEMMType gemm_type(TokenStream &in, bool &valid)
Definition: MLGOParser.cpp:567
arm_compute::DataType::QASYMM8_SIGNED
@ QASYMM8_SIGNED
quantized, asymmetric fixed-point 8-bit number, signed
arm_compute::CLGEMMKernelSelectionParams::b
unsigned int b
Batch size.
Definition: CLTypes.h:49
arm_compute::cl_gemm::CLGEMMDefaultTypeBifrost::select_kernel
CLGEMMKernelType select_kernel(const CLGEMMKernelSelectionParams &params) override
Given the input parameters passed through CLGEMMKernelSelectionParams, this method returns the CLGEMMKernelType to use.
Definition: CLGEMMDefaultTypeBifrost.cpp:42
ARM_COMPUTE_UNUSED
#define ARM_COMPUTE_UNUSED(...)
Used to avoid unused-variable warnings.
Definition: Error.h:152
arm_compute::cl_gemm::ICLGEMMKernelSelection
Basic interface for the GEMM kernel selection.
Definition: ICLGEMMKernelSelection.h:36
arm_compute::test::validation::data_type
data_type
Definition: Cast.cpp:223
arm_compute::GPUTarget
GPUTarget
Available GPU Targets.
Definition: GPUTarget.h:34
arm_compute::GPUTarget::G76
@ G76
arm_compute::CLGEMMKernelSelectionParams::k
unsigned int k
Number of rows for the rhs matrix.
Definition: CLTypes.h:48
CLGEMMDefaultTypeBifrost.h
arm_compute::test::validation::b
SimpleTensor< float > b
Definition: DFT.cpp:157
arm_compute::test::validation::scale
NEScale scale
Definition: Scale.cpp:272
arm_compute
Copyright (c) 2017-2023 Arm Limited.
Definition: introduction.dox:24
arm_compute::DataType::F16
@ F16
16-bit floating-point number
arm_compute::CLGEMMKernelType::RESHAPED
@ RESHAPED
Reshaped GEMM kernel where both lhs and rhs matrices are reshaped.
arm_compute::GPUTarget::G71
@ G71
arm_compute::DataType::F32
@ F32
32-bit floating-point number
arm_compute::CLGEMMKernelType::RESHAPED_ONLY_RHS
@ RESHAPED_ONLY_RHS
Reshaped GEMM kernel where only the rhs matrix is reshaped.
arm_compute::DataType
DataType
Available data types.
Definition: CoreTypes.h:82
arm_compute::test::validation::n
const unsigned int n
Definition: GEMMMatrixMultiplyNative.cpp:360
arm_compute::cl_gemm::CLGEMMDefaultTypeBifrost
Bifrost based OpenCL GEMMKernel selection.
Definition: CLGEMMDefaultTypeBifrost.h:34
CLHelpers.h