Compute Library
 23.11
ClGemmDefaultConfigReshapedBifrost.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
32 
34 
35 #include <utility>
36 
37 namespace arm_compute
38 {
39 namespace opencl
40 {
41 namespace kernels
42 {
43 namespace gemm
44 {
46 
48 {
49 }
50 
51 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure(
52  unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
53 {
54  using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (
55  ClGemmDefaultConfigReshapedBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
56 
58  &ClGemmDefaultConfigReshapedBifrost::configure_G7x_f32, &ClGemmDefaultConfigReshapedBifrost::configure_G7x_f16,
59  &ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8);
60 
62  &ClGemmDefaultConfigReshapedBifrost::configure_G52_f32, &ClGemmDefaultConfigReshapedBifrost::configure_G52_f16,
63  &ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8);
64 
66  &ClGemmDefaultConfigReshapedBifrost::configure_G76_f32, &ClGemmDefaultConfigReshapedBifrost::configure_G76_f16,
67  &ClGemmDefaultConfigReshapedBifrost::configure_G76_u8);
68 
69  ConfigurationFunctionExecutorPtr func = nullptr;
70 
71  switch (_target)
72  {
73  case GPUTarget::G76:
74  func = configs_G76.get_function(data_type);
75  break;
76  case GPUTarget::G52:
77  func = configs_G52.get_function(data_type);
78  break;
79  default:
80  func = configs_G7x.get_function(data_type);
81  break;
82  }
83 
84  ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM");
85  return (this->*func)(m, n, k, b);
86 }
87 
88 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
89 ClGemmDefaultConfigReshapedBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
90 {
93 
94  if (n <= 4)
95  {
96  return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, true, false, false, true);
97  }
98  else
99  {
100  return configure_lhs_rhs_info(m, n, 5, 4, 4, 2, 16, false, true, false, true);
101  }
102 }
103 
104 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
105 ClGemmDefaultConfigReshapedBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
106 {
109 
110  if (n <= 4)
111  {
112  return configure_lhs_rhs_info(m, n, 4, 2, 8, 8, 2, true, true, true, false);
113  }
114  else
115  {
116  return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2, true, true, true, false);
117  }
118 }
119 
120 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
121 ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
122 {
125 
126  if (dot8_supported(CLKernelLibrary::get().get_device()))
127  {
128  if (n <= 4)
129  {
130  return configure_lhs_rhs_info(m, n, 4, 2, 16, 2, 2, true, false, false, true);
131  }
132  else
133  {
134  return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, true, false, false, true);
135  }
136  }
137  else
138  {
139  if (n <= 4)
140  {
141  return configure_lhs_rhs_info(m, n, 4, 2, 8, 2, 2, true, false, false, true);
142  }
143  else
144  {
145  return configure_lhs_rhs_info(m, n, 6, 4, 4, 2, 2, true, true, false, true);
146  }
147  }
148 }
149 
150 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
151 ClGemmDefaultConfigReshapedBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
152 {
153  const float r_mn = static_cast<float>(m) / static_cast<float>(n);
154  const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
155  const float r_mk = static_cast<float>(m) / static_cast<float>(k);
156  const float r_nk = static_cast<float>(n) / static_cast<float>(k);
157 
158  GEMMLHSMatrixInfo lhs_info_buf;
159  GEMMRHSMatrixInfo rhs_info_buf;
160  GEMMLHSMatrixInfo lhs_info_img;
161  GEMMRHSMatrixInfo rhs_info_img;
162 
163  if (workload <= 274.4000f)
164  {
165  if (r_nk <= 0.7461f)
166  {
167  if (r_mn <= 21.1667f)
168  {
169  return configure_lhs_rhs_info(m, n, 4, 2, 4, 4, 4, false, true, true, false, false);
170  }
171  else
172  {
173  std::tie(lhs_info_img, rhs_info_img) =
174  configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
175  std::tie(lhs_info_buf, rhs_info_buf) =
176  configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
177 
178  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
179  std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32);
180  }
181  }
182  else
183  {
184  std::tie(lhs_info_img, rhs_info_img) =
185  configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
186  std::tie(lhs_info_buf, rhs_info_buf) =
187  configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
188 
189  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
190  std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32);
191  }
192  }
193  else
194  {
195  if (r_mk <= 17.3926f)
196  {
197  if (workload <= 542.4000f)
198  {
199  std::tie(lhs_info_img, rhs_info_img) =
200  configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
201  std::tie(lhs_info_buf, rhs_info_buf) =
202  configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
203 
204  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
205  std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32);
206  }
207  else
208  {
209  std::tie(lhs_info_img, rhs_info_img) =
210  configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, true);
211  std::tie(lhs_info_buf, rhs_info_buf) =
212  configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, false);
213 
214  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
215  std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32);
216  }
217  }
218  else
219  {
220  if (r_nk <= 0.5463f)
221  {
222  if (workload <= 11767.6001f)
223  {
224  std::tie(lhs_info_img, rhs_info_img) =
225  configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
226  std::tie(lhs_info_buf, rhs_info_buf) =
227  configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
228 
229  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
230  std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32);
231  }
232  else
233  {
234  std::tie(lhs_info_img, rhs_info_img) =
235  configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, true);
236  std::tie(lhs_info_buf, rhs_info_buf) =
237  configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, false);
238 
239  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
240  std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32);
241  }
242  }
243  else
244  {
245  std::tie(lhs_info_img, rhs_info_img) =
246  configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
247  std::tie(lhs_info_buf, rhs_info_buf) =
248  configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
249 
250  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
251  std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32);
252  }
253  }
254  }
255 }
256 
257 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
258 ClGemmDefaultConfigReshapedBifrost::configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
259 {
261 
262  const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
263 
264  if (workload <= 323.4000f)
265  {
266  return configure_lhs_rhs_info(m, n, 2, 2, 8, 4, 8, false, false, false, true, false);
267  }
268  else
269  {
270  return configure_lhs_rhs_info(m, n, 4, 8, 4, 2, 2, true, true, true, false, false);
271  }
272 }
273 
274 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
275 ClGemmDefaultConfigReshapedBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
276 {
279 
280  GEMMLHSMatrixInfo lhs_info_buf;
281  GEMMRHSMatrixInfo rhs_info_buf;
282  GEMMLHSMatrixInfo lhs_info_img;
283  GEMMRHSMatrixInfo rhs_info_img;
284 
285  // Get lhs_info/rhs_info in case of OpenCL buffer
286  if (n <= 4)
287  {
288  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, true, false, false, true);
289  }
290  else
291  {
292  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 2, 8, 16, false, false, false, true);
293  }
294 
295  // Get lhs_info/rhs_info in case of OpenCL image
296  // Condition on the GPU workload
297  if ((m / 4) * (n / 4) >= 2560)
298  {
299  // Big workload
300  std::tie(lhs_info_img, rhs_info_img) =
301  configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 8, true, true, true, false, true);
302  }
303  else
304  {
305  // Small workload
306  std::tie(lhs_info_img, rhs_info_img) =
307  configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 1, true, true, true, false, true);
308  }
309 
310  const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32);
311  const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img);
312  const TensorInfo tensor_reshaped_info(shape, 1, DataType::F32);
313 
314  // In case of vector by matrix with few work-items, we use the OpenCL buffer rather than the OpenCL image2d
315  const bool use_cl_image2d = (n <= 4) ? false : true;
316 
317  if (bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d)
318  {
319  return std::make_pair(lhs_info_img, rhs_info_img);
320  }
321  else
322  {
323  return std::make_pair(lhs_info_buf, rhs_info_buf);
324  }
325 }
326 
327 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
328 ClGemmDefaultConfigReshapedBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
329 {
330  const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
331  const float r_mk = static_cast<float>(m) / static_cast<float>(k);
332 
333  if (workload <= 1595.2000f)
334  {
335  if (r_mk <= 2.1044f)
336  {
337  if (workload <= 870.4000f)
338  {
339  return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 2, true, false, true, false, false);
340  }
341  else
342  {
343  return configure_lhs_rhs_info(m, n, 4, 2, 4, 2, 2, false, false, true, false, false);
344  }
345  }
346  else
347  {
348  return configure_lhs_rhs_info(m, n, 4, 2, 4, 2, 2, false, false, true, false, false);
349  }
350  }
351  else
352  {
353  return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2, true, true, true, false, false);
354  }
355 }
356 
357 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
358 ClGemmDefaultConfigReshapedBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
359 {
362 
363  if (n <= 4)
364  {
365  return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1, false, false, false, true);
366  }
367  else
368  {
369  return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, false, true, false, true);
370  }
371 }
372 } // namespace gemm
373 } // namespace kernels
374 } // namespace opencl
375 } // namespace arm_compute
arm_compute::opencl::kernels::gemm::configure_lhs_rhs_info
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image)
Configure GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.
Definition: ClGemmHelpers.cpp:42
arm_compute::dot8_supported
bool dot8_supported(const cl::Device &device)
Helper function to check whether the cl_arm_integer_dot_product_int8 extension is supported.
Definition: CLHelpers.cpp:242
arm_gemm::gemm
UniqueGemmCommon< Top, Tret > gemm(const GemmArgs &args, const OutputStage &os)
Definition: gemm_implementation.hpp:320
arm_compute::opencl::kernels::gemm::IClGemmKernelConfig
Basic interface for the GEMM kernel configuration.
Definition: IClGemmKernelConfig.h:92
arm_compute::opencl::kernels::gemm::validate_image2d_support_on_rhs
Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, const GEMMRHSMatrixInfo &rhs_info)
Utility function to validate the image2d OpenCL object support on the RHS reshaped matrix.
Definition: ClGemmHelpers.cpp:121
TensorInfo.h
ClGemmHelpers.h
arm_compute::CLKernelLibrary::get
static CLKernelLibrary & get()
Access the KernelLibrary singleton.
Definition: CLKernelLibrary.cpp:41
CLKernelLibrary.h
Manages all the OpenCL kernels compilation and caching, provides accessors for the OpenCL Context.
arm_compute::opencl::kernels::gemm::ClGemmDefaultConfigReshapedBifrost::configure
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override
Given M, N, K and B, this method returns the GEMMLHSMatrixInfo and GEMMRHSMatrixInfo to be used.
Definition: ClGemmDefaultConfigReshapedBifrost.cpp:51
arm_compute::test::validation::shape
shape
Definition: DFT.cpp:115
arm_compute::opencl::kernels::gemm::ClGemmDefaultConfigReshapedBifrost::ClGemmDefaultConfigReshapedBifrost
ClGemmDefaultConfigReshapedBifrost(GPUTarget gpu)
Constructor.
Definition: ClGemmDefaultConfigReshapedBifrost.cpp:47
arm_compute::GPUTarget::G52
@ G52
ARM_COMPUTE_ERROR_ON_MSG
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
Definition: Error.h:456
GPUTarget.h
arm_compute::opencl::kernels::gemm::CLGEMMConfigArray
Basic container for the OpenCL GEMM configuration functions.
Definition: IClGemmKernelConfig.h:43
arm_compute::opencl::kernels::gemm::ClGemmDefaultConfigReshapedBifrost
Bifrost based OpenCL GEMMReshaped configuration.
Definition: ClGemmDefaultConfigReshapedBifrost.h:38
ClGemmDefaultConfigReshapedBifrost.h
ARM_COMPUTE_UNUSED
#define ARM_COMPUTE_UNUSED(...)
To avoid unused-variable warnings.
Definition: Error.h:151
arm_compute::test::validation::data_type
data_type
Definition: Cast.cpp:222
arm_compute::GPUTarget
GPUTarget
Available GPU Targets.
Definition: GPUTarget.h:34
arm_compute::misc::shape_calculator::compute_rhs_reshaped_shape
TensorShape compute_rhs_reshaped_shape(const ITensorInfo &a, const GEMMRHSMatrixInfo &rhs_info)
Calculate the Right Hand Side matrix reshaped shape.
Definition: ShapeCalculator.h:233
ShapeCalculator.h
arm_compute::GPUTarget::G76
@ G76
arm_compute::test::validation::b
SimpleTensor< float > b
Definition: DFT.cpp:157
arm_compute
Copyright (c) 2017-2023 Arm Limited.
Definition: introduction.dox:24
TensorShape.h
arm_compute::opencl::kernels::gemm::select_lhs_rhs_info
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > select_lhs_rhs_info(std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_img, std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_buf, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
Select GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.
Definition: ClGemmHelpers.cpp:76
arm_compute::DataType::F32
@ F32
32-bit floating-point number
arm_compute::misc::shape_calculator
Definition: ShapeCalculator.h:41
arm_compute::DataType
DataType
Available data types.
Definition: CoreTypes.h:83
CLHelpers.h
arm_compute::opencl::kernels::gemm::CLGEMMConfigArray::get_function
T get_function(DataType data_type)
Method to return the GEMM configuration function based on data type.
Definition: IClGemmKernelConfig.h:70