Compute Library
 21.02
CLGEMMDefaultConfigReshapedRHSOnlyBifrost.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
33 
34 #include <map>
35 #include <utility>
36 
37 namespace arm_compute
38 {
39 namespace cl_gemm
40 {
42 
45 {
46 }
47 
48 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
49 {
50  using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMDefaultConfigReshapedRHSOnlyBifrost::*)(unsigned int m, unsigned int n, unsigned int k,
51  unsigned int b);
52 
53  // Configurations for Mali-G51
54  static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G51 =
55  {
56  { DataType::F32, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_f32 },
57  { DataType::F16, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_f16 },
58  { DataType::QASYMM8, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_u8 },
59  { DataType::QSYMM8, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_u8 },
60  { DataType::QASYMM8_SIGNED, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_u8 },
61  { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_u8 }
62  };
63 
64  // Configurations for Mali-G52
65  static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G52 =
66  {
67  { DataType::F32, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G52_f32 },
68  { DataType::F16, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G52_f16 },
69  { DataType::QASYMM8, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_u8 },
70  { DataType::QSYMM8, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_u8 },
71  { DataType::QASYMM8_SIGNED, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_u8 },
72  { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_u8 }
73  };
74 
75  // Configurations for Mali-G76
76  static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
77  {
78  { DataType::F32, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_f32 },
79  { DataType::F16, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_f16 },
80  { DataType::QASYMM8, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_u8 },
81  { DataType::QSYMM8, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_u8 },
82  { DataType::QASYMM8_SIGNED, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_u8 },
83  { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_u8 }
84  };
85 
86  // Configurations for Mali-G7x
87  static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G7x =
88  {
89  { DataType::F32, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_f32 },
90  { DataType::F16, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_f16 },
91  { DataType::QASYMM8, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_u8 },
92  { DataType::QSYMM8, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_u8 },
93  { DataType::QASYMM8_SIGNED, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_u8 },
94  { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_u8 }
95  };
96 
97  switch(_target)
98  {
99  case GPUTarget::G76:
100  if(gemm_configs_G76.find(data_type) != gemm_configs_G76.end())
101  {
102  return (this->*gemm_configs_G76[data_type])(m, n, k, b);
103  }
104  else
105  {
106  ARM_COMPUTE_ERROR("Not supported data type");
107  }
108  case GPUTarget::G52:
109  if(gemm_configs_G52.find(data_type) != gemm_configs_G52.end())
110  {
111  return (this->*gemm_configs_G52[data_type])(m, n, k, b);
112  }
113  else
114  {
115  ARM_COMPUTE_ERROR("Not supported data type");
116  }
117  case GPUTarget::G51:
118  if(gemm_configs_G51.find(data_type) != gemm_configs_G51.end())
119  {
120  return (this->*gemm_configs_G51[data_type])(m, n, k, b);
121  }
122  else
123  {
124  ARM_COMPUTE_ERROR("Not supported data type");
125  }
126  default:
127  if(gemm_configs_G7x.find(data_type) != gemm_configs_G7x.end())
128  {
129  return (this->*gemm_configs_G7x[data_type])(m, n, k, b);
130  }
131  else
132  {
133  ARM_COMPUTE_ERROR("Not supported data type");
134  }
135  }
136 }
137 
138 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
139 {
142 
143  if(m == 1)
144  {
145  if(n <= 2548)
146  {
147  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, false, true, false, true, false);
148  }
149  else
150  {
151  return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 8, false, true, false, true, false);
152  }
153  }
154  else
155  {
156  return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true);
157  }
158 }
159 
160 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
161 {
164 
165  GEMMLHSMatrixInfo lhs_info_buf;
166  GEMMRHSMatrixInfo rhs_info_buf;
167  GEMMLHSMatrixInfo lhs_info_img;
168  GEMMRHSMatrixInfo rhs_info_img;
169 
170  const bool is_workload_big = ((m * n * b) / 16) >= 2048;
171 
172  if(m == 1)
173  {
174  if(n >= 8192)
175  {
176  const unsigned int h0 = std::max(n / 4, 1U);
177  return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, h0, false, true, false, true, false);
178  }
179  else
180  {
181  const unsigned int h0 = std::max(n / 2, 1U);
182  if(n <= 204)
183  {
184  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true, false);
185  }
186  else
187  {
188  return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true, false);
189  }
190  }
191  }
192  else
193  {
194  const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
195  if(is_workload_big)
196  {
197  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, true);
198  }
199  else
200  {
201  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true);
202  }
203  }
204 
205  // Get lhs_info/rhs_info in case of OpenCL image
206  const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
207  if(is_workload_big)
208  {
209  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false, true);
210  }
211  else
212  {
213  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true, true);
214  }
215 
216  const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32);
217  const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img);
218  const TensorInfo tensor_reshaped_info(shape, 1, DataType::F32);
219 
220  // In case of vector by matrix or small workloads, we use the OpenCL buffer rather than the OpenCL image2d
221  const bool use_cl_image2d = ((m == 1) || ((((m * n * b) / 16) < 2048) && n < 128)) ? false : true;
222 
223  if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d)
224  {
225  return std::make_pair(lhs_info_img, rhs_info_img);
226  }
227  else
228  {
229  return std::make_pair(lhs_info_buf, rhs_info_buf);
230  }
231 }
232 
233 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
234 {
235  const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
236  const float r_nk = static_cast<float>(n) / static_cast<float>(k);
237 
238  GEMMLHSMatrixInfo lhs_info_buf;
239  GEMMRHSMatrixInfo rhs_info_buf;
240  GEMMLHSMatrixInfo lhs_info_img;
241  GEMMRHSMatrixInfo rhs_info_img;
242 
243  if(m == 1)
244  {
245  if(r_nk <= 0.4664f)
246  {
247  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16, false, true, false, true, false);
248  }
249  else
250  {
251  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, true);
252  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, false);
253 
254  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
255  std::make_pair(lhs_info_buf, rhs_info_buf),
256  n, k, b, DataType::F32);
257  }
258  }
259  else
260  {
261  if(workload <= 274.4000f)
262  {
263  return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 16, false, false, false, true, false);
264  }
265  else
266  {
267  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, true);
268  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, false);
269 
270  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
271  std::make_pair(lhs_info_buf, rhs_info_buf),
272  n, k, b, DataType::F32);
273  }
274  }
275 }
276 
277 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
278 {
281 
282  if(m == 1)
283  {
284  const unsigned int n0 = n < 1280 ? 2 : 4;
285  const unsigned int h0 = std::max(n / n0, 1U);
286  return configure_lhs_rhs_info(m, n, 1, n0, 4, 1, h0, false, true, false, true);
287  }
288  else
289  {
290  return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true);
291  }
292 }
293 
294 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
295 {
298 
299  if(m == 1)
300  {
301  if(n > 2048)
302  {
303  const unsigned int h0 = std::max(n / 4, 1U);
304  return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true);
305  }
306  else
307  {
308  const unsigned int h0 = std::max(n / 2, 1U);
309  return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true);
310  }
311  }
312  else
313  {
314  return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true);
315  }
316 }
317 
318 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
319 {
320  const float r_mn = static_cast<float>(m) / static_cast<float>(n);
321  const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
322  const float r_mk = static_cast<float>(m) / static_cast<float>(k);
323  const float r_nk = static_cast<float>(n) / static_cast<float>(k);
324 
325  GEMMLHSMatrixInfo lhs_info_buf;
326  GEMMRHSMatrixInfo rhs_info_buf;
327  GEMMLHSMatrixInfo lhs_info_img;
328  GEMMRHSMatrixInfo rhs_info_img;
329 
330  if(m == 1)
331  {
332  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, false);
333 
334  if(r_mk <= 0.0026f)
335  {
336  if(r_nk <= 0.4664f)
337  {
338  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
339  }
340  else
341  {
342  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
343  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
344  std::make_pair(lhs_info_buf, rhs_info_buf),
345  n, k, b, DataType::F16);
346  }
347  }
348  else
349  {
350  if(r_mk <= 0.0148f)
351  {
352  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
353  }
354  else
355  {
356  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
357  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
358  std::make_pair(lhs_info_buf, rhs_info_buf),
359  n, k, b, DataType::F16);
360  }
361  }
362  }
363  else
364  {
365  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2, false, false, false, false, false);
366 
367  if(workload <= 362.6000f)
368  {
369  return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false);
370  }
371  else
372  {
373  if(r_mn <= 22.6067f)
374  {
375  if(workload <= 708.8000f)
376  {
377  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
378  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
379  std::make_pair(lhs_info_buf, rhs_info_buf),
380  n, k, b, DataType::F16);
381  }
382  else
383  {
384  return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 16, false, false, false, false, false);
385  }
386  }
387  else
388  {
389  if(r_nk <= 0.0917f)
390  {
391  return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false);
392  }
393  else
394  {
395  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
396  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
397  std::make_pair(lhs_info_buf, rhs_info_buf),
398  n, k, b, DataType::F16);
399  }
400  }
401  }
402  }
403 }
404 
405 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
406 {
408 
409  if(m == 1)
410  {
411  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
412  }
413  else
414  {
415  const float r_mn = static_cast<float>(m) / static_cast<float>(n);
416  const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
417 
418  if(workload <= 7449.60f)
419  {
420  if(workload <= 691.60f)
421  {
422  return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 8, false, false, false, false, false);
423  }
424  else
425  {
426  if(workload <= 4155.20f)
427  {
428  return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
429  }
430  else
431  {
432  return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 32, false, false, false, false, false);
433  }
434  }
435  }
436  else
437  {
438  if(workload <= 16300.80f)
439  {
440  if(r_mn <= 44.56f)
441  {
442  GEMMLHSMatrixInfo lhs_info_buf;
443  GEMMRHSMatrixInfo rhs_info_buf;
444  GEMMLHSMatrixInfo lhs_info_img;
445  GEMMRHSMatrixInfo rhs_info_img;
446 
447  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, false, true, false, false, true);
448  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
449 
450  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
451  std::make_pair(lhs_info_buf, rhs_info_buf),
452  n, k, b, DataType::F16);
453  }
454  else
455  {
456  return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
457  }
458  }
459  else
460  {
461  GEMMLHSMatrixInfo lhs_info_buf;
462  GEMMRHSMatrixInfo rhs_info_buf;
463  GEMMLHSMatrixInfo lhs_info_img;
464  GEMMRHSMatrixInfo rhs_info_img;
465 
466  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true);
467  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
468 
469  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
470  std::make_pair(lhs_info_buf, rhs_info_buf),
471  n, k, b, DataType::F16);
472  }
473  }
474  }
475 }
476 
477 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
478 {
481 
482  if(m == 1)
483  {
484  const unsigned int n0 = n < 1280 ? 2 : 4;
485  const unsigned int h0 = std::max(n / n0, 1U);
486  return configure_lhs_rhs_info(m, n, 1, n0, 8, 1, h0, false, true, false, true);
487  }
488  else
489  {
490  return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true);
491  }
492 }
493 
494 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
495 {
498 
499  if(dot8_supported(CLKernelLibrary::get().get_device()))
500  {
501  if(m == 1)
502  {
503  const unsigned int h0 = std::max(n / 2, 1U);
504  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true);
505  }
506  else
507  {
508  const unsigned int h0 = std::max(n / 4, 1U);
509  return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, false, true, false, true);
510  }
511  }
512  else
513  {
514  const int h0 = std::max(std::min(static_cast<int>(n / 2), static_cast<int>(128)), static_cast<int>(1));
515  if(m == 1)
516  {
517  return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, h0, false, true, false, true);
518  }
519  else
520  {
521  return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true);
522  }
523  }
524 }
525 
526 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
527 {
530 
531  if(m == 1)
532  {
533  const unsigned int h0 = std::max(n / 2, 1U);
534  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true);
535  }
536  else
537  {
538  return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, 2, false, true, false, true);
539  }
540 }
541 
542 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
543 {
546 
547  if(m == 1)
548  {
549  const unsigned int h0 = std::max(n / 2, 1U);
550  return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, false, true, false, true);
551  }
552  else
553  {
554  const unsigned int h0 = std::max(n / 2, 1U);
555  return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true);
556  }
557 }
558 
559 } // namespace cl_gemm
560 } // namespace arm_compute
Shape of a tensor.
Definition: TensorShape.h:39
Basic interface for the GEMM kernel configuration.
bool dot8_supported(const cl::Device &device)
Helper function to check whether the cl_arm_integer_dot_product_int8 extension is supported...
Definition: CLHelpers.cpp:239
SimpleTensor< float > b
Definition: DFT.cpp:157
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
Definition: Error.h:352
Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, const GEMMRHSMatrixInfo &rhs_info)
Utility function to validate the image2d OpenCL object support on the RHS reshaped matrix...
1 channel, 1 F32 per channel
static CLKernelLibrary & get()
Access the KernelLibrary singleton.
GEMM LHS (Left Hand Side) matrix information.
Definition: Types.h:1968
Copyright (c) 2017-2021 Arm Limited.
1 channel, 1 F16 per channel
const DataType data_type
Definition: Im2Col.cpp:150
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
Definition: Error.h:152
GEMM RHS (Right Hand Side) matrix information.
Definition: Types.h:1983
quantized, asymmetric fixed-point 8-bit number unsigned
TensorShape compute_rhs_reshaped_shape(const ITensorInfo &a, const GEMMRHSMatrixInfo &rhs_info)
Calculate the Right Hand Side matrix reshaped shape.
quantized, symmetric fixed-point 8-bit number
quantized, symmetric per channel fixed-point 8-bit number
GPUTarget
Available GPU Targets.
Definition: GPUTarget.h:34
Manages all the OpenCL kernels compilation and caching, provides accessors for the OpenCL Context...
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image)
Configure GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.
Store the tensor's metadata.
Definition: TensorInfo.h:45
quantized, asymmetric fixed-point 8-bit number signed
DataType
Available data types.
Definition: Types.h:77
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override
Given M, N, K and B, this method returns the GEMMLHSMatrixInfo and GEMMRHSMatrixInfo to be used...
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > select_lhs_rhs_info(std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_img, std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_buf, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
Select GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.