Compute Library
 21.05
CLGEMMDefaultConfigReshapedRHSOnlyBifrost.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
33 
34 #include <utility>
35 
36 namespace arm_compute
37 {
38 namespace cl_gemm
39 {
41 
44 {
45 }
46 
47 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
48 {
49  using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMDefaultConfigReshapedRHSOnlyBifrost::*)(unsigned int m, unsigned int n, unsigned int k,
50  unsigned int b);
51 
52  CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G51(&CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_f32,
53  &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_f16,
54  &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_u8);
55 
56  CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G52(&CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G52_f32,
57  &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G52_f16,
58  &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_u8);
59 
60  CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G76(&CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_f32,
61  &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_f16,
62  &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_u8);
63 
64  CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G7x(&CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_f32,
65  &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_f16,
66  &CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_u8);
67 
68  ConfigurationFunctionExecutorPtr func = nullptr;
69 
70  switch(_target)
71  {
72  case GPUTarget::G76:
73  func = configs_G76.get_function(data_type);
74  break;
75  case GPUTarget::G51:
76  func = configs_G51.get_function(data_type);
77  break;
78  case GPUTarget::G52:
79  func = configs_G52.get_function(data_type);
80  break;
81  default:
82  func = configs_G7x.get_function(data_type);
83  break;
84  }
85 
86  ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM");
87  return (this->*func)(m, n, k, b);
88 }
89 
90 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
91 {
94 
95  if(m == 1)
96  {
97  if(n <= 2548)
98  {
99  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, false, true, false, true, false);
100  }
101  else
102  {
103  return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 8, false, true, false, true, false);
104  }
105  }
106  else
107  {
108  return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true);
109  }
110 }
111 
112 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
113 {
116 
117  GEMMLHSMatrixInfo lhs_info_buf;
118  GEMMRHSMatrixInfo rhs_info_buf;
119  GEMMLHSMatrixInfo lhs_info_img;
120  GEMMRHSMatrixInfo rhs_info_img;
121 
122  const bool is_workload_big = ((m * n * b) / 16) >= 2048;
123 
124  if(m == 1)
125  {
126  if(n >= 8192)
127  {
128  const unsigned int h0 = std::max(n / 4, 1U);
129  return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, h0, false, true, false, true, false);
130  }
131  else
132  {
133  const unsigned int h0 = std::max(n / 2, 1U);
134  if(n <= 204)
135  {
136  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true, false);
137  }
138  else
139  {
140  return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true, false);
141  }
142  }
143  }
144  else
145  {
146  const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
147  if(is_workload_big)
148  {
149  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, true);
150  }
151  else
152  {
153  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true);
154  }
155  }
156 
157  // Get lhs_info/rhs_info in case of OpenCL image
158  const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
159  if(is_workload_big)
160  {
161  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false, true);
162  }
163  else
164  {
165  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true, true);
166  }
167 
168  const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32);
169  const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img);
170  const TensorInfo tensor_reshaped_info(shape, 1, DataType::F32);
171 
172  // In case of vector by matrix or small workloads, we use the OpenCL buffer rather than the OpenCL image2d
173  const bool use_cl_image2d = ((m == 1) || ((((m * n * b) / 16) < 2048) && n < 128)) ? false : true;
174 
175  if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d)
176  {
177  return std::make_pair(lhs_info_img, rhs_info_img);
178  }
179  else
180  {
181  return std::make_pair(lhs_info_buf, rhs_info_buf);
182  }
183 }
184 
185 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
186 {
187  const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
188  const float r_nk = static_cast<float>(n) / static_cast<float>(k);
189 
190  GEMMLHSMatrixInfo lhs_info_buf;
191  GEMMRHSMatrixInfo rhs_info_buf;
192  GEMMLHSMatrixInfo lhs_info_img;
193  GEMMRHSMatrixInfo rhs_info_img;
194 
195  if(m == 1)
196  {
197  if(r_nk <= 0.4664f)
198  {
199  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16, false, true, false, true, false);
200  }
201  else
202  {
203  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, true);
204  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, false);
205 
206  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
207  std::make_pair(lhs_info_buf, rhs_info_buf),
208  n, k, b, DataType::F32);
209  }
210  }
211  else
212  {
213  if(workload <= 274.4000f)
214  {
215  return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 16, false, false, false, true, false);
216  }
217  else
218  {
219  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, true);
220  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, false);
221 
222  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
223  std::make_pair(lhs_info_buf, rhs_info_buf),
224  n, k, b, DataType::F32);
225  }
226  }
227 }
228 
229 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
230 {
233 
234  if(m == 1)
235  {
236  const unsigned int n0 = n < 1280 ? 2 : 4;
237  const unsigned int h0 = std::max(n / n0, 1U);
238  return configure_lhs_rhs_info(m, n, 1, n0, 4, 1, h0, false, true, false, true);
239  }
240  else
241  {
242  return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true);
243  }
244 }
245 
246 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
247 {
250 
251  if(m == 1)
252  {
253  if(n > 2048)
254  {
255  const unsigned int h0 = std::max(n / 4, 1U);
256  return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true);
257  }
258  else
259  {
260  const unsigned int h0 = std::max(n / 2, 1U);
261  return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true);
262  }
263  }
264  else
265  {
266  return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true);
267  }
268 }
269 
270 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
271 {
272  const float r_mn = static_cast<float>(m) / static_cast<float>(n);
273  const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
274  const float r_mk = static_cast<float>(m) / static_cast<float>(k);
275  const float r_nk = static_cast<float>(n) / static_cast<float>(k);
276 
277  GEMMLHSMatrixInfo lhs_info_buf;
278  GEMMRHSMatrixInfo rhs_info_buf;
279  GEMMLHSMatrixInfo lhs_info_img;
280  GEMMRHSMatrixInfo rhs_info_img;
281 
282  if(m == 1)
283  {
284  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, false);
285 
286  if(r_mk <= 0.0026f)
287  {
288  if(r_nk <= 0.4664f)
289  {
290  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
291  }
292  else
293  {
294  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
295  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
296  std::make_pair(lhs_info_buf, rhs_info_buf),
297  n, k, b, DataType::F16);
298  }
299  }
300  else
301  {
302  if(r_mk <= 0.0148f)
303  {
304  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
305  }
306  else
307  {
308  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
309  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
310  std::make_pair(lhs_info_buf, rhs_info_buf),
311  n, k, b, DataType::F16);
312  }
313  }
314  }
315  else
316  {
317  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2, false, false, false, false, false);
318 
319  if(workload <= 362.6000f)
320  {
321  return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false);
322  }
323  else
324  {
325  if(r_mn <= 22.6067f)
326  {
327  if(workload <= 708.8000f)
328  {
329  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
330  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
331  std::make_pair(lhs_info_buf, rhs_info_buf),
332  n, k, b, DataType::F16);
333  }
334  else
335  {
336  return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 16, false, false, false, false, false);
337  }
338  }
339  else
340  {
341  if(r_nk <= 0.0917f)
342  {
343  return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false);
344  }
345  else
346  {
347  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
348  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
349  std::make_pair(lhs_info_buf, rhs_info_buf),
350  n, k, b, DataType::F16);
351  }
352  }
353  }
354  }
355 }
356 
357 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
358 {
360 
361  if(m == 1)
362  {
363  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
364  }
365  else
366  {
367  const float r_mn = static_cast<float>(m) / static_cast<float>(n);
368  const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
369 
370  if(workload <= 7449.60f)
371  {
372  if(workload <= 691.60f)
373  {
374  return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 8, false, false, false, false, false);
375  }
376  else
377  {
378  if(workload <= 4155.20f)
379  {
380  return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
381  }
382  else
383  {
384  return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 32, false, false, false, false, false);
385  }
386  }
387  }
388  else
389  {
390  if(workload <= 16300.80f)
391  {
392  if(r_mn <= 44.56f)
393  {
394  GEMMLHSMatrixInfo lhs_info_buf;
395  GEMMRHSMatrixInfo rhs_info_buf;
396  GEMMLHSMatrixInfo lhs_info_img;
397  GEMMRHSMatrixInfo rhs_info_img;
398 
399  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, false, true, false, false, true);
400  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
401 
402  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
403  std::make_pair(lhs_info_buf, rhs_info_buf),
404  n, k, b, DataType::F16);
405  }
406  else
407  {
408  return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
409  }
410  }
411  else
412  {
413  GEMMLHSMatrixInfo lhs_info_buf;
414  GEMMRHSMatrixInfo rhs_info_buf;
415  GEMMLHSMatrixInfo lhs_info_img;
416  GEMMRHSMatrixInfo rhs_info_img;
417 
418  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true);
419  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
420 
421  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
422  std::make_pair(lhs_info_buf, rhs_info_buf),
423  n, k, b, DataType::F16);
424  }
425  }
426  }
427 }
428 
429 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
430 {
433 
434  if(m == 1)
435  {
436  const unsigned int n0 = n < 1280 ? 2 : 4;
437  const unsigned int h0 = std::max(n / n0, 1U);
438  return configure_lhs_rhs_info(m, n, 1, n0, 8, 1, h0, false, true, false, true);
439  }
440  else
441  {
442  return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true);
443  }
444 }
445 
446 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
447 {
450 
451  if(dot8_supported(CLKernelLibrary::get().get_device()))
452  {
453  if(m == 1)
454  {
455  const unsigned int h0 = std::max(n / 2, 1U);
456  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true);
457  }
458  else
459  {
460  const unsigned int h0 = std::max(n / 4, 1U);
461  return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, false, true, false, true);
462  }
463  }
464  else
465  {
466  const int h0 = std::max(std::min(static_cast<int>(n / 2), static_cast<int>(128)), static_cast<int>(1));
467  if(m == 1)
468  {
469  return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, h0, false, true, false, true);
470  }
471  else
472  {
473  return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true);
474  }
475  }
476 }
477 
478 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
479 {
482 
483  if(m == 1)
484  {
485  const unsigned int h0 = std::max(n / 2, 1U);
486  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true);
487  }
488  else
489  {
490  return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, 2, false, true, false, true);
491  }
492 }
493 
494 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyBifrost::configure_G51_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
495 {
498 
499  if(m == 1)
500  {
501  const unsigned int h0 = std::max(n / 2, 1U);
502  return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, false, true, false, true);
503  }
504  else
505  {
506  const unsigned int h0 = std::max(n / 2, 1U);
507  return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true);
508  }
509 }
510 
511 } // namespace cl_gemm
512 } // namespace arm_compute
Basic interface for the GEMM kernel configuration.
bool dot8_supported(const cl::Device &device)
Helper function to check whether the cl_arm_integer_dot_product_int8 extension is supported.
Definition: CLHelpers.cpp:239
SimpleTensor< float > b
Definition: DFT.cpp:157
Basic container for the OpenCL GEMM configuration functions.
Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, const GEMMRHSMatrixInfo &rhs_info)
Utility function to validate the image2d OpenCL object support on the RHS reshaped matrix.
1 channel, 1 F32 per channel
static CLKernelLibrary & get()
Access the KernelLibrary singleton.
Copyright (c) 2017-2021 Arm Limited.
1 channel, 1 F16 per channel
T get_function(DataType data_type)
Method to return the GEMM configuration function based on data type.
const DataType data_type
Definition: Im2Col.cpp:150
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
Definition: Error.h:152
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
Definition: Error.h:456
TensorShape compute_rhs_reshaped_shape(const ITensorInfo &a, const GEMMRHSMatrixInfo &rhs_info)
Calculate the Right Hand Side matrix reshaped shape.
FloorUKernelPtr func
GPUTarget
Available GPU Targets.
Definition: GPUTarget.h:34
Manages all the OpenCL kernels compilation and caching, provides accessors for the OpenCL Context.
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image)
Configure GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.
DataType
Available data types.
Definition: Types.h:77
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override
Given M, N, K and B, this method returns the GEMMLHSMatrixInfo and GEMMRHSMatrixInfo to be used.
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > select_lhs_rhs_info(std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_img, std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_buf, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
Select GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.