Compute Library
 22.11
ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
33 #include <utility>
34 
35 namespace arm_compute
36 {
37 namespace opencl
38 {
39 namespace kernels
40 {
41 namespace gemm
42 {
44 
46  : IClGemmKernelConfig(gpu)
47 {
48 }
49 
50 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
51 {
52  using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (ClGemmDefaultConfigReshapedRhsOnlyBifrost::*)(unsigned int m, unsigned int n, unsigned int k,
53  unsigned int b);
54 
55  CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G51(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32,
56  &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16,
57  &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8);
58 
59  CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G52(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32,
60  &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16,
61  &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8);
62 
63  CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G31(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32,
64  &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16,
65  &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8);
66 
67  CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G76(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32,
68  &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16,
69  &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8);
70 
71  CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G7x(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32,
72  &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16,
73  &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8);
74 
75  ConfigurationFunctionExecutorPtr func = nullptr;
76  switch(_target)
77  {
78  case GPUTarget::G76:
79  func = configs_G76.get_function(data_type);
80  break;
81  case GPUTarget::G51:
82  func = configs_G51.get_function(data_type);
83  break;
84  case GPUTarget::G52:
85  func = configs_G52.get_function(data_type);
86  break;
87  case GPUTarget::G31:
88  func = configs_G31.get_function(data_type);
89  break;
90  default:
91  func = configs_G7x.get_function(data_type);
92  break;
93  }
94 
95  ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM");
96  return (this->*func)(m, n, k, b);
97 }
98 
99 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
100 {
103 
104  if(m == 1)
105  {
106  if(n <= 2548)
107  {
108  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, false, true, false, true, false);
109  }
110  else
111  {
112  return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 8, false, true, false, true, false);
113  }
114  }
115  else
116  {
117  return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true);
118  }
119 }
120 
121 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
122 {
125 
126  if(m == 1)
127  {
128  const unsigned int h0 = std::max(n / 2, 1U);
129  return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, 0, 1, 0, 1);
130  }
131  else
132  {
133  const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
134  if(m >= 28)
135  {
136  return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, 0, 1, 0, 1);
137  }
138  else
139  {
140  return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, 0, 1, 0, 1);
141  }
142  }
143 }
144 
145 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
146 {
149 
150  GEMMLHSMatrixInfo lhs_info_buf;
151  GEMMRHSMatrixInfo rhs_info_buf;
152  GEMMLHSMatrixInfo lhs_info_img;
153  GEMMRHSMatrixInfo rhs_info_img;
154 
155  const bool is_workload_big = ((m * n * b) / 16) >= 2048;
156 
157  if(m == 1)
158  {
159  if(n >= 8192)
160  {
161  const unsigned int h0 = std::max(n / 4, 1U);
162  return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, h0, false, true, false, true, false);
163  }
164  else
165  {
166  const unsigned int h0 = std::max(n / 2, 1U);
167  if(n <= 204)
168  {
169  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true, false);
170  }
171  else
172  {
173  return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true, false);
174  }
175  }
176  }
177  else
178  {
179  const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
180  if(is_workload_big)
181  {
182  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, true);
183  }
184  else
185  {
186  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true);
187  }
188  }
189 
190  // Get lhs_info/rhs_info in case of OpenCL image
191  const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
192  if(is_workload_big)
193  {
194  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false, true);
195  }
196  else
197  {
198  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true, true);
199  }
200 
201  const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32);
202  const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img);
203  const TensorInfo tensor_reshaped_info(shape, 1, DataType::F32);
204 
205  // In case of vector by matrix or small workloads, we use the OpenCL buffer rather than the OpenCL image2d
206  const bool use_cl_image2d = ((m == 1) || ((((m * n * b) / 16) < 2048) && n < 128)) ? false : true;
207 
208  if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d)
209  {
210  return std::make_pair(lhs_info_img, rhs_info_img);
211  }
212  else
213  {
214  return std::make_pair(lhs_info_buf, rhs_info_buf);
215  }
216 }
217 
218 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
219 {
220  const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
221  const float r_nk = static_cast<float>(n) / static_cast<float>(k);
222 
223  GEMMLHSMatrixInfo lhs_info_buf;
224  GEMMRHSMatrixInfo rhs_info_buf;
225  GEMMLHSMatrixInfo lhs_info_img;
226  GEMMRHSMatrixInfo rhs_info_img;
227 
228  if(m == 1)
229  {
230  if(r_nk <= 0.4664f)
231  {
232  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16, false, true, false, true, false);
233  }
234  else
235  {
236  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, true);
237  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, false);
238 
239  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
240  std::make_pair(lhs_info_buf, rhs_info_buf),
241  n, k, b, DataType::F32);
242  }
243  }
244  else
245  {
246  if(workload <= 274.4000f)
247  {
248  return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 16, false, false, false, true, false);
249  }
250  else
251  {
252  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, true);
253  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, false);
254 
255  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
256  std::make_pair(lhs_info_buf, rhs_info_buf),
257  n, k, b, DataType::F32);
258  }
259  }
260 }
261 
262 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
263 {
266 
267  if(m == 1)
268  {
269  const unsigned int n0 = n < 1280 ? 2 : 4;
270  const unsigned int h0 = std::max(n / n0, 1U);
271  return configure_lhs_rhs_info(m, n, 1, n0, 4, 1, h0, false, true, false, true);
272  }
273  else
274  {
275  return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true);
276  }
277 }
278 
279 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
280 {
283 
284  if(m == 1)
285  {
286  if(n > 2048)
287  {
288  const unsigned int h0 = std::max(n / 4, 1U);
289  return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true);
290  }
291  else
292  {
293  const unsigned int h0 = std::max(n / 2, 1U);
294  return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true);
295  }
296  }
297  else
298  {
299  return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true);
300  }
301 }
302 
303 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
304 {
305  const float r_mn = static_cast<float>(m) / static_cast<float>(n);
306  const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
307  const float r_mk = static_cast<float>(m) / static_cast<float>(k);
308  const float r_nk = static_cast<float>(n) / static_cast<float>(k);
309 
310  GEMMLHSMatrixInfo lhs_info_buf;
311  GEMMRHSMatrixInfo rhs_info_buf;
312  GEMMLHSMatrixInfo lhs_info_img;
313  GEMMRHSMatrixInfo rhs_info_img;
314 
315  if(m == 1)
316  {
317  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, false);
318 
319  if(r_mk <= 0.0026f)
320  {
321  if(r_nk <= 0.4664f)
322  {
323  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
324  }
325  else
326  {
327  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
328  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
329  std::make_pair(lhs_info_buf, rhs_info_buf),
330  n, k, b, DataType::F16);
331  }
332  }
333  else
334  {
335  if(r_mk <= 0.0148f)
336  {
337  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
338  }
339  else
340  {
341  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
342  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
343  std::make_pair(lhs_info_buf, rhs_info_buf),
344  n, k, b, DataType::F16);
345  }
346  }
347  }
348  else
349  {
350  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2, false, false, false, false, false);
351 
352  if(workload <= 362.6000f)
353  {
354  return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false);
355  }
356  else
357  {
358  if(r_mn <= 22.6067f)
359  {
360  if(workload <= 708.8000f)
361  {
362  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
363  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
364  std::make_pair(lhs_info_buf, rhs_info_buf),
365  n, k, b, DataType::F16);
366  }
367  else
368  {
369  return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 16, false, false, false, false, false);
370  }
371  }
372  else
373  {
374  if(r_nk <= 0.0917f)
375  {
376  return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false);
377  }
378  else
379  {
380  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
381  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
382  std::make_pair(lhs_info_buf, rhs_info_buf),
383  n, k, b, DataType::F16);
384  }
385  }
386  }
387  }
388 }
389 
390 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
391 {
393 
394  if(m == 1)
395  {
396  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
397  }
398  else
399  {
400  const float r_mn = static_cast<float>(m) / static_cast<float>(n);
401  const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
402 
403  if(workload <= 7449.60f)
404  {
405  if(workload <= 691.60f)
406  {
407  return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 8, false, false, false, false, false);
408  }
409  else
410  {
411  if(workload <= 4155.20f)
412  {
413  return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
414  }
415  else
416  {
417  return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 32, false, false, false, false, false);
418  }
419  }
420  }
421  else
422  {
423  if(workload <= 16300.80f)
424  {
425  if(r_mn <= 44.56f)
426  {
427  GEMMLHSMatrixInfo lhs_info_buf;
428  GEMMRHSMatrixInfo rhs_info_buf;
429  GEMMLHSMatrixInfo lhs_info_img;
430  GEMMRHSMatrixInfo rhs_info_img;
431 
432  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, false, true, false, false, true);
433  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
434 
435  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
436  std::make_pair(lhs_info_buf, rhs_info_buf),
437  n, k, b, DataType::F16);
438  }
439  else
440  {
441  return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
442  }
443  }
444  else
445  {
446  GEMMLHSMatrixInfo lhs_info_buf;
447  GEMMRHSMatrixInfo rhs_info_buf;
448  GEMMLHSMatrixInfo lhs_info_img;
449  GEMMRHSMatrixInfo rhs_info_img;
450 
451  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true);
452  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
453 
454  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
455  std::make_pair(lhs_info_buf, rhs_info_buf),
456  n, k, b, DataType::F16);
457  }
458  }
459  }
460 }
461 
462 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
463 {
466 
467  if(m == 1)
468  {
469  const unsigned int n0 = n < 1280 ? 2 : 4;
470  const unsigned int h0 = std::max(n / n0, 1U);
471  return configure_lhs_rhs_info(m, n, 1, n0, 8, 1, h0, false, true, false, true);
472  }
473  else
474  {
475  return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true);
476  }
477 }
478 
479 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
480 {
483 
484  if(dot8_supported(CLKernelLibrary::get().get_device()))
485  {
486  if(m == 1)
487  {
488  const unsigned int h0 = std::max(n / 2, 1U);
489  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true);
490  }
491  else
492  {
493  const unsigned int h0 = std::max(n / 4, 1U);
494  return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, false, true, false, true);
495  }
496  }
497  else
498  {
499  const int h0 = std::max(std::min(static_cast<int>(n / 2), static_cast<int>(128)), static_cast<int>(1));
500  if(m == 1)
501  {
502  return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, h0, false, true, false, true);
503  }
504  else
505  {
506  return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true);
507  }
508  }
509 }
510 
511 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
512 {
515 
516  if(m == 1)
517  {
518  const unsigned int h0 = std::max(n / 2, 1U);
519  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true);
520  }
521  else
522  {
523  return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, 2, false, true, false, true);
524  }
525 }
526 
527 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
528 {
531 
532  if(m == 1)
533  {
534  const unsigned int h0 = std::max(n / 2, 1U);
535  return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, false, true, false, true);
536  }
537  else
538  {
539  const unsigned int h0 = std::max(n / 2, 1U);
540  return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true);
541  }
542 }
543 
544 } // namespace gemm
545 } // namespace kernels
546 } // namespace opencl
547 } // namespace arm_compute
Basic container for the OpenCL GEMM configuration functions.
Shape of a tensor.
Definition: TensorShape.h:39
bool dot8_supported(const cl::Device &device)
Helper function to check whether the cl_arm_integer_dot_product_int8 extension is supported...
Definition: CLHelpers.cpp:241
SimpleTensor< float > b
Definition: DFT.cpp:157
Basic interface for the GEMM kernel configuration.
1 channel, 1 F32 per channel
static CLKernelLibrary & get()
Access the KernelLibrary singleton.
GEMM LHS (Left Hand Side) matrix information.
Definition: Types.h:2303
Manages all the OpenCL kernels compilation and caching, provides accessors for the OpenCL Context...
Copyright (c) 2017-2022 Arm Limited.
1 channel, 1 F16 per channel
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > select_lhs_rhs_info(std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_img, std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_buf, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
Select GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
Definition: Error.h:152
Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, const GEMMRHSMatrixInfo &rhs_info)
Utility function to validate the image2d OpenCL object support on the RHS reshaped matrix...
GEMM RHS (Right Hand Side) matrix information.
Definition: Types.h:2318
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
Definition: Error.h:456
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override
Given M, N, K and B, this method returns the GEMMLHSMatrixInfo and GEMMRHSMatrixInfo to be used...
TensorShape compute_rhs_reshaped_shape(const ITensorInfo &a, const GEMMRHSMatrixInfo &rhs_info)
Calculate the Right Hand Side matrix reshaped shape.
UniqueGemmCommon< Top, Tret > gemm(const GemmArgs &args, const OutputStage &os)
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image)
Configure GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.
T get_function(DataType data_type)
Method to return the GEMM configuration function based on data type.
GPUTarget
Available GPU Targets.
Definition: GPUTarget.h:34
Store the tensor's metadata.
Definition: TensorInfo.h:43
DataType
Available data types.
Definition: Types.h:79