Compute Library 22.11
ClGemmDefaultConfigReshapedValhall.cpp
/*
 * Copyright (c) 2020-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <utility>

namespace arm_compute
{
namespace opencl
{
namespace kernels
{
namespace gemm
{
ClGemmDefaultConfigReshapedValhall::ClGemmDefaultConfigReshapedValhall(GPUTarget gpu)
    : IClGemmKernelConfig(gpu)
{
}

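// Entry point of the heuristic: pick the per-data-type configuration function for the detected
// GPU target (separate tables for G77 and G78, with the G77 table as the fallback) and invoke it
// with the GEMM dimensions M, N, K and the batch size B.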
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
{
    using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (ClGemmDefaultConfigReshapedValhall::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);

    CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G77(&ClGemmDefaultConfigReshapedValhall::configure_G77_f32,
                                                                    &ClGemmDefaultConfigReshapedValhall::configure_G77_f16,
                                                                    &ClGemmDefaultConfigReshapedValhall::configure_G77_u8);

    CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G78(&ClGemmDefaultConfigReshapedValhall::configure_G78_f32,
                                                                    &ClGemmDefaultConfigReshapedValhall::configure_G78_f16,
                                                                    &ClGemmDefaultConfigReshapedValhall::configure_G77_u8);

    ConfigurationFunctionExecutorPtr func = nullptr;

    switch(_target)
    {
        case GPUTarget::G78:
            func = configs_G78.get_function(data_type);
            break;
        case GPUTarget::G77:
        default:
            func = configs_G77.get_function(data_type);
            break;
    }

    ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not supported for GEMM");
    return (this->*func)(m, n, k, b);
}

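// All of the configuration functions below return through the helper
//   configure_lhs_rhs_info(m, n, m0, n0, k0, v0, h0,
//                          lhs_interleave, rhs_interleave, lhs_transpose, rhs_transpose, export_to_cl_image)
// i.e. the (m0, n0, k0) block sizes, the V0/H0 reshape factors, the interleave/transpose flags and,
// when present, whether the RHS block can be exported to a cl_image. Calls that omit the trailing
// argument presumably rely on its default value.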
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
    ARM_COMPUTE_UNUSED(k);
    ARM_COMPUTE_UNUSED(b);

    if(n <= 4)
    {
        return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, 1, 0, 0, 1);
    }
    else
    {
        return configure_lhs_rhs_info(m, n, 5, 4, 4, 2, 16, 0, 1, 0, 1);
    }
}

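// FP16 heuristic for G77: a decision tree over the shape ratios r_mk = M/K, r_nk = N/K, r_mn = M/N
// and a workload score proportional to M * N * B. Several leaves build both a buffer-based and a
// cl_image-based configuration and let select_lhs_rhs_info() choose between them (typically falling
// back to the buffer variant when the image path cannot be used).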
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
    const float r_mn = static_cast<float>(m) / static_cast<float>(n);
    const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
    const float r_mk = static_cast<float>(m) / static_cast<float>(k);
    const float r_nk = static_cast<float>(n) / static_cast<float>(k);

    GEMMLHSMatrixInfo lhs_info_buf;
    GEMMRHSMatrixInfo rhs_info_buf;
    GEMMLHSMatrixInfo lhs_info_img;
    GEMMRHSMatrixInfo rhs_info_img;

    std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 0);

    if(r_mk <= 0.11824845522642136)
    {
        if(workload <= 880.0)
        {
            return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, 0, 0, 1, 0, 0);
        }
        else
        {
            if(r_nk <= 0.42521367967128754)
            {
                if(workload <= 1726.4000244140625)
                {
                    return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 0);
                }
                else
                {
                    std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);

                    return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
                                               std::make_pair(lhs_info_buf, rhs_info_buf),
                                               n, k, b, DataType::F16);
                }
            }
            else
            {
                if(workload <= 1241.6000366210938)
                {
                    return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, 0, 0, 1, 0, 0);
                }
                else
                {
                    return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 0);
                }
            }
        }
    }
    else
    {
        if(workload <= 11404.7998046875)
        {
            if(r_mk <= 1.0126488208770752)
            {
                if(r_mn <= 2.545312523841858)
                {
                    std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);

                    return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
                                               std::make_pair(lhs_info_buf, rhs_info_buf),
                                               n, k, b, DataType::F16);
                }
                else
                {
                    return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, 0, 0, 1, 0, 0);
                }
            }
            else
            {
                if(workload <= 2881.199951171875)
                {
                    std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, 0, 0, 1, 0, 1);

                    return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
                                               std::make_pair(lhs_info_buf, rhs_info_buf),
                                               n, k, b, DataType::F16);
                }
                else
                {
                    std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);

                    return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
                                               std::make_pair(lhs_info_buf, rhs_info_buf),
                                               n, k, b, DataType::F16);
                }
            }
        }
        else
        {
            if(r_nk <= 0.5765306055545807)
            {
                if(r_mn <= 6.010416746139526)
                {
                    std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);

                    return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
                                               std::make_pair(lhs_info_buf, rhs_info_buf),
                                               n, k, b, DataType::F16);
                }
                else
                {
                    std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 1, 0, 1, 0, 1);

                    return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
                                               std::make_pair(lhs_info_buf, rhs_info_buf),
                                               n, k, b, DataType::F16);
                }
            }
            else
            {
                std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 1, 0, 1, 0, 1);

                return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
                                           std::make_pair(lhs_info_buf, rhs_info_buf),
                                           n, k, b, DataType::F16);
            }
        }
    }
}

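// FP32 heuristic for G78: a deeper decision tree over the same shape ratios and workload score,
// returning different block sizes, reshape factors and RHS export settings for each region of the
// shape space.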
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
    const float r_mn = static_cast<float>(m) / static_cast<float>(n);
    const float r_mk = static_cast<float>(m) / static_cast<float>(k);
    const float r_nk = static_cast<float>(n) / static_cast<float>(k);
    const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;

    if(workload <= 1288.0000f)
    {
        if(workload <= 505.6000f)
        {
            if(r_mn <= 0.4466f)
            {
                if(r_nk <= 0.2384f)
                {
                    return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
                }
                else
                {
                    return configure_lhs_rhs_info(m, n, 2, 2, 4, 2, 2, 0, 0, 1, 0, 0);
                }
            }
            else
            {
                return configure_lhs_rhs_info(m, n, 2, 2, 4, 2, 2, 0, 0, 1, 0, 0);
            }
        }
        else
        {
            if(r_mn <= 0.2250f)
            {
                if(r_mn <= 0.1599f)
                {
                    return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
                }
                else
                {
                    return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                }
            }
            else
            {
                if(r_mk <= 0.7609f)
                {
                    if(r_mn <= 2.5453f)
                    {
                        if(workload <= 1089.6000f)
                        {
                            return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
                        }
                        else
                        {
                            return configure_lhs_rhs_info(m, n, 2, 4, 8, 2, 4, 0, 0, 1, 0, 1);
                        }
                    }
                    else
                    {
                        return configure_lhs_rhs_info(m, n, 2, 4, 16, 4, 4, 0, 0, 1, 0, 1);
                    }
                }
                else
                {
                    return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
                }
            }
        }
    }
    else
    {
        if(workload <= 5434.4001f)
        {
            if(workload <= 1603.2000f)
            {
                return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
            }
            else
            {
                if(r_nk <= 0.6192f)
                {
                    if(r_mn <= 16.1016f)
                    {
                        return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                    }
                    else
                    {
                        if(workload <= 2750.0000f)
                        {
                            return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                        }
                        else
                        {
                            if(r_mk <= 6.3151f)
                            {
                                return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
                            }
                            else
                            {
                                return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                            }
                        }
                    }
                }
                else
                {
                    if(r_mk <= 0.0387f)
                    {
                        return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
                    }
                    else
                    {
                        if(r_mk <= 2.5859f)
                        {
                            if(r_mk <= 0.2734f)
                            {
                                return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
                            }
                            else
                            {
                                return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                            }
                        }
                        else
                        {
                            return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                        }
                    }
                }
            }
        }
        else
        {
            if(r_mk <= 25.7500f)
            {
                if(r_mk <= 0.3615f)
                {
                    if(r_mn <= 0.0913f)
                    {
                        if(r_mk <= 0.0683f)
                        {
                            return configure_lhs_rhs_info(m, n, 8, 4, 4, 4, 2, 0, 0, 1, 0, 1);
                        }
                        else
                        {
                            return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
                        }
                    }
                    else
                    {
                        return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                    }
                }
                else
                {
                    if(workload <= 11174.3999f)
                    {
                        if(r_mk <= 0.8047f)
                        {
                            return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                        }
                        else
                        {
                            if(workload <= 7185.5999f)
                            {
                                return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
                            }
                            else
                            {
                                return configure_lhs_rhs_info(m, n, 8, 4, 4, 4, 2, 0, 0, 1, 0, 1);
                            }
                        }
                    }
                    else
                    {
                        if(workload <= 17917.5000f)
                        {
                            if(r_mk <= 1.5078f)
                            {
                                return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                            }
                            else
                            {
                                return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
                            }
                        }
                        else
                        {
                            if(workload <= 34449.6016f)
                            {
                                return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                            }
                            else
                            {
                                return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 4, 0, 0, 1, 0, 1);
                            }
                        }
                    }
                }
            }
            else
            {
                if(r_mk <= 331.1111f)
                {
                    if(workload <= 53397.5996f)
                    {
                        if(r_mn <= 57.8063f)
                        {
                            return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                        }
                        else
                        {
                            return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
                        }
                    }
                    else
                    {
                        if(r_nk <= 0.9211f)
                        {
                            return configure_lhs_rhs_info(m, n, 8, 4, 4, 4, 2, 0, 0, 1, 0, 1);
                        }
                        else
                        {
                            return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
                        }
                    }
                }
                else
                {
                    if(workload <= 38070.4004f)
                    {
                        return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
                    }
                    else
                    {
                        return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                    }
                }
            }
        }
    }
}

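// FP16 heuristic for G78: every leaf uses the same 8x4x4 (m0 x n0 x k0) blocking with the RHS
// exported to a cl_image; only the V0/H0 reshape factors change with the workload score and the
// shape ratios.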
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
    const float r_mn = static_cast<float>(m) / static_cast<float>(n);
    const float r_nk = static_cast<float>(n) / static_cast<float>(k);
    const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;

    if(workload <= 801.6000f)
    {
        return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
    }
    else
    {
        if(r_mn <= 0.1211f)
        {
            if(workload <= 3296.0000f)
            {
                return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
            }
            else
            {
                if(r_nk <= 1.0625f)
                {
                    return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                }
                else
                {
                    return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 4, 0, 0, 1, 0, 1);
                }
            }
        }
        else
        {
            if(workload <= 5068.8000f)
            {
                return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
            }
            else
            {
                if(r_nk <= 0.2361f)
                {
                    if(workload <= 12630.0000f)
                    {
                        return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
                    }
                    else
                    {
                        return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 1, 0, 0, 1, 0, 1);
                    }
                }
                else
                {
                    if(workload <= 178790.3984f)
                    {
                        return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                    }
                    else
                    {
                        return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
                    }
                }
            }
        }
    }
}

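// Quantized (u8) path, shared by the G77 and G78 dispatch tables above: a simple split on N with a
// wider k0 block (16) than the floating-point variants.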
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
    ARM_COMPUTE_UNUSED(k);
    ARM_COMPUTE_UNUSED(b);

    if(n <= 4)
    {
        return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1, 0, 0, 0, 1);
    }
    else
    {
        return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, 0, 1, 0, 1);
    }
}
} // namespace gemm
} // namespace kernels
} // namespace opencl
} // namespace arm_compute
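A minimal usage sketch of the heuristic above (not part of the file): construct the configuration
object for a Valhall target and call configure() with the GEMM dimensions. The include path is an
assumption based on the class name, and the shape values are arbitrary examples.

    #include "src/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedValhall.h" // assumed path

    using namespace arm_compute;
    using namespace arm_compute::opencl::kernels::gemm;

    std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> pick_reshaped_config()
    {
        // Query an FP16 configuration for a GEMM with M = 256, N = 256, K = 128 and batch size 1.
        ClGemmDefaultConfigReshapedValhall config(GPUTarget::G78);
        return config.configure(256U, 256U, 128U, 1U, DataType::F16);
    }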