Compute Library
 22.11
ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2020-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
32 
35 
36 #include <utility>
37 
38 namespace arm_compute
39 {
40 namespace opencl
41 {
42 namespace kernels
43 {
44 namespace gemm
45 {
47 
49  : IClGemmKernelConfig(gpu)
50 {
51 }
52 
53 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
54 {
55  using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (ClGemmDefaultConfigReshapedRhsOnlyValhall::*)(unsigned int m, unsigned int n, unsigned int k,
56  unsigned int b);
57 
58  CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G77(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32,
59  &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16,
60  &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
61 
62  CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G78(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32,
63  &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16,
64  &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
65 
66  CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G715(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32,
67  &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16,
68  &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
69 
70  ConfigurationFunctionExecutorPtr func = nullptr;
71 
72  switch(_target)
73  {
74  case GPUTarget::G78:
75  func = configs_G78.get_function(data_type);
76  break;
77  case GPUTarget::G715:
78  case GPUTarget::G615:
79  func = configs_G715.get_function(data_type);
80  break;
81  case GPUTarget::G77:
82  default:
83  func = configs_G77.get_function(data_type);
84  break;
85  }
86 
87  ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM");
88  return (this->*func)(m, n, k, b);
89 }
90 
91 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
92 {
93  if(m == 1)
94  {
95  const float r_mn = static_cast<float>(m) / static_cast<float>(n);
96  const float r_mk = static_cast<float>(m) / static_cast<float>(k);
97 
98  if(r_mk <= 0.0064484127797186375)
99  {
100  if(r_mn <= 0.0028273810748942196)
101  {
102  GEMMLHSMatrixInfo lhs_info_buf;
103  GEMMRHSMatrixInfo rhs_info_buf;
104  GEMMLHSMatrixInfo lhs_info_img;
105  GEMMRHSMatrixInfo rhs_info_img;
106 
107  const unsigned int h0 = std::max(n / 4, 1U);
108  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, 0, 1, 0, 0, 1);
109  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, 0, 1, 0, 1, 0);
110 
111  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
112  std::make_pair(lhs_info_buf, rhs_info_buf),
113  n, k, b, DataType::F32);
114  }
115  else
116  {
117  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, 0, 1, 0, 0, 0);
118  }
119  }
120  else
121  {
122  if(r_mk <= 0.020312500186264515)
123  {
124  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, 0, 1, 0, 0, 0);
125  }
126  else
127  {
128  return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, 0, 1, 0, 1, 0);
129  }
130  }
131  }
132  else
133  {
134  const float r_mn = static_cast<float>(m) / static_cast<float>(n);
135  const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
136  const float r_mk = static_cast<float>(m) / static_cast<float>(k);
137 
138  if(workload <= 1999.2000122070312)
139  {
140  if(workload <= 747.1999816894531)
141  {
142  return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0);
143  }
144  else
145  {
146  GEMMLHSMatrixInfo lhs_info_buf;
147  GEMMRHSMatrixInfo rhs_info_buf;
148  GEMMLHSMatrixInfo lhs_info_img;
149  GEMMRHSMatrixInfo rhs_info_img;
150  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, 0, 0, 0, 1, 1);
151  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0);
152 
153  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
154  std::make_pair(lhs_info_buf, rhs_info_buf),
155  n, k, b, DataType::F32);
156  }
157  }
158  else
159  {
160  if(r_mn <= 0.03348214365541935)
161  {
162  if(r_mk <= 0.028125000186264515)
163  {
164  return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0);
165  }
166  else
167  {
168  GEMMLHSMatrixInfo lhs_info_buf;
169  GEMMRHSMatrixInfo rhs_info_buf;
170  GEMMLHSMatrixInfo lhs_info_img;
171  GEMMRHSMatrixInfo rhs_info_img;
172  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, 0, 0, 0, 1, 1);
173  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0);
174 
175  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
176  std::make_pair(lhs_info_buf, rhs_info_buf),
177  n, k, b, DataType::F32);
178  }
179  }
180  else
181  {
182  GEMMLHSMatrixInfo lhs_info_buf;
183  GEMMRHSMatrixInfo rhs_info_buf;
184  GEMMLHSMatrixInfo lhs_info_img;
185  GEMMRHSMatrixInfo rhs_info_img;
186  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, 0, 1, 0, 0, 1);
187  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, 0, 1, 0, 1, 0);
188 
189  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
190  std::make_pair(lhs_info_buf, rhs_info_buf),
191  n, k, b, DataType::F32);
192  }
193  }
194  }
195 }
196 
197 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
198 {
201 
202  if(m == 1)
203  {
204  const unsigned int h0 = std::max(n / 2, 1U);
205  if(n <= 836.0)
206  {
207  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, 0, 1, 0, 1, 0);
208  }
209  else
210  {
211  return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, 0, 1, 0, 1, 0);
212  }
213  }
214  else if(m < 128)
215  {
216  const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
217  if(k >= 512)
218  {
219  return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 0);
220  }
221  else
222  {
223  return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, 0, 1, 0, 0);
224  }
225  }
226  else
227  {
228  const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
229  if(n >= 64)
230  {
231  return configure_lhs_rhs_info(m, n, 4, 8, 4, 1, h0, 0, 1, 0, 0);
232  }
233  else
234  {
235  if(k >= 512)
236  {
237  return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 0);
238  }
239  else
240  {
241  return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, 0, 1, 0, 0);
242  }
243  }
244  }
245 }
246 
247 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
248 {
251 
252  if(m == 1)
253  {
254  const unsigned int h0 = std::max(n / 2, 1U);
255  return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, 0, 1, 0, 1);
256  }
257  else
258  {
259  const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
260  if(m >= 28)
261  {
262  return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, 0, 1, 0, 1);
263  }
264  else
265  {
266  return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 1);
267  }
268  }
269 }
270 
271 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
272 {
273  const float r_mn = static_cast<float>(m) / static_cast<float>(n);
274  const float r_mk = static_cast<float>(m) / static_cast<float>(k);
275  const float r_nk = static_cast<float>(n) / static_cast<float>(k);
276  const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
277 
278  if(m == 1)
279  {
280  if(workload <= 278.7000f)
281  {
282  if(workload <= 7.5000f)
283  {
284  return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
285  }
286  else
287  {
288  if(r_mn <= 0.0031f)
289  {
290  if(workload <= 256.6000f)
291  {
292  if(workload <= 16.7500f)
293  {
294  if(r_nk <= 1.6671f)
295  {
296  return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
297  }
298  else
299  {
300  return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
301  }
302  }
303  else
304  {
305  return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
306  }
307  }
308  else
309  {
310  return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
311  }
312  }
313  else
314  {
315  if(r_mk <= 0.0027f)
316  {
317  if(r_mk <= 0.0014f)
318  {
319  return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
320  }
321  else
322  {
323  if(workload <= 8.9500f)
324  {
325  return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
326  }
327  else
328  {
329  return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
330  }
331  }
332  }
333  else
334  {
335  if(workload <= 14.1500f)
336  {
337  return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
338  }
339  else
340  {
341  if(r_mk <= 0.0041f)
342  {
343  return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
344  }
345  else
346  {
347  return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
348  }
349  }
350  }
351  }
352  }
353  }
354  else
355  {
356  if(workload <= 363.7000f)
357  {
358  if(r_mk <= 0.0031f)
359  {
360  return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 32, 0, 1, 0, 1, 0);
361  }
362  else
363  {
364  return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, 32, 0, 1, 0, 1, 0);
365  }
366  }
367  else
368  {
369  return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 32, 0, 1, 0, 1, 0);
370  }
371  }
372  }
373  else
374  {
375  if(workload <= 1384.8000f)
376  {
377  if(workload <= 704.0000f)
378  {
379  return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 32, 0, 1, 0, 1, 0);
380  }
381  else
382  {
383  return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 4, 0, 0, 0, 1, 1);
384  }
385  }
386  else
387  {
388  if(workload <= 16761.6006f)
389  {
390  if(r_mn <= 187.1250f)
391  {
392  return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, 0, 0, 0, 1, 1);
393  }
394  else
395  {
396  return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 4, 0, 0, 0, 1, 1);
397  }
398  }
399  else
400  {
401  if(r_mk <= 432.4630f)
402  {
403  return configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 16, 0, 0, 0, 1, 1);
404  }
405  else
406  {
407  return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 16, 0, 1, 0, 1, 1);
408  }
409  }
410  }
411  }
412 }
413 
414 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
415 {
416  const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
417  const float r_mn = static_cast<float>(m) / static_cast<float>(n);
418  const float r_mk = static_cast<float>(m) / static_cast<float>(k);
419  const float r_nk = static_cast<float>(n) / static_cast<float>(k);
420 
421  if(m == 1)
422  {
423  if(r_mn <= 0.0045f)
424  {
425  if(workload <= 278.7000f)
426  {
427  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, 0, 0, 0, 1, 1);
428  }
429  else
430  {
431  return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 32, 0, 0, 1, 0, 0);
432  }
433  }
434  else
435  {
436  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, 0, 0, 1, 0, 0);
437  }
438  }
439  else
440  {
441  if(workload <= 1384.8000f)
442  {
443  if(r_nk <= 0.8333f)
444  {
445  if(r_mk <= 0.9119f)
446  {
447  return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 4, 0, 1, 0, 1, 1);
448  }
449  else
450  {
451  if(r_nk <= 0.1181f)
452  {
453  return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 32, 0, 0, 1, 0, 0);
454  }
455  else
456  {
457  return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0);
458  }
459  }
460  }
461  else
462  {
463  if(r_mk <= 1.0013f)
464  {
465  return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 1);
466  }
467  else
468  {
469  return configure_lhs_rhs_info(m, n, 5, 4, 8, 1, 4, 0, 1, 1, 0, 1);
470  }
471  }
472  }
473  else
474  {
475  if(workload <= 11404.7998f)
476  {
477  if(r_mk <= 2.2884f)
478  {
479  if(r_nk <= 0.9286f)
480  {
481  return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 4, 0, 1, 1, 0, 1);
482  }
483  else
484  {
485  return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 1);
486  }
487  }
488  else
489  {
490  return configure_lhs_rhs_info(m, n, 5, 4, 8, 1, 4, 0, 1, 1, 0, 1);
491  }
492  }
493  else
494  {
495  if(r_nk <= 1.1926f)
496  {
497  if(r_mn <= 1385.7917f)
498  {
499  return configure_lhs_rhs_info(m, n, 6, 4, 8, 1, 4, 0, 1, 1, 0, 1);
500  }
501  else
502  {
503  return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 32, 0, 1, 1, 0, 0);
504  }
505  }
506  else
507  {
508  return configure_lhs_rhs_info(m, n, 6, 4, 8, 1, 32, 0, 1, 1, 0, 1);
509  }
510  }
511  }
512  }
513 }
514 
515 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
516 {
517  unsigned int best_m0;
518  unsigned int best_n0;
519 
520  if(is_mmul_kernel_preferred(m, n, k, b, DataType::F32, best_m0, best_n0))
521  {
522  return configure_lhs_rhs_info(m, n, best_m0, best_n0, 1, 1, 4, false, true, false, false, true);
523  }
524  else
525  {
526  return configure_G77_f32(m, n, k, b);
527  }
528 }
529 
530 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
531 {
532  unsigned int best_m0;
533  unsigned int best_n0;
534 
535  if(is_mmul_kernel_preferred(m, n, k, b, DataType::F16, best_m0, best_n0))
536  {
537  return configure_lhs_rhs_info(m, n, best_m0, best_n0, 1, 1, 4, false, true, false, false, true);
538  }
539  else
540  {
541  return configure_G78_f16(m, n, k, b);
542  }
543 }
544 } // namespace gemm
545 } // namespace kernels
546 } // namespace opencl
547 } // namespace arm_compute
Basic container for the OpenCL GEMM configuration functions.
SimpleTensor< float > b
Definition: DFT.cpp:157
Basic interface for the GEMM kernel configuration.
1 channel, 1 F32 per channel
GEMM LHS (Left Hand Side) matrix information.
Definition: Types.h:2303
Manages all the OpenCL kernels compilation and caching, provides accessors for the OpenCL Context...
Copyright (c) 2017-2022 Arm Limited.
1 channel, 1 F16 per channel
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > select_lhs_rhs_info(std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_img, std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_buf, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
Select GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
Definition: Error.h:152
GEMM RHS (Right Hand Side) matrix information.
Definition: Types.h:2318
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
Definition: Error.h:456
UniqueGemmCommon< Top, Tret > gemm(const GemmArgs &args, const OutputStage &os)
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image)
Configure GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.
T get_function(DataType data_type)
Method to return the GEMM configuration function based on data type.
GPUTarget
Available GPU Targets.
Definition: GPUTarget.h:34
bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const unsigned int k, const unsigned int b, const DataType data_type, unsigned int &best_m0, unsigned int &best_n0)
Determine if the MMUL kernels should be preferred.
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override
Given M, N, K and B, this method returns the GEMMLHSMatrixInfo and GEMMRHSMatrixInfo to be used...
DataType
Available data types.
Definition: Types.h:79