Compute Library
 21.11
ClGemm.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2017-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
28 #include "arm_compute/core/Error.h"
32 #include "arm_compute/core/Log.h"
34 #include "arm_compute/core/Types.h"
35 #include "arm_compute/core/Utils.h"
40 
45 #include "src/gpu/cl/IClKernel.h"
49 
50 #include "src/common/utils/Log.h"
51 #include "support/Cast.h"
52 #include "utils/TypePrinter.h"
53 
54 namespace arm_compute
55 {
56 namespace opencl
57 {
59 using namespace arm_compute::cl_gemm;
60 using namespace arm_compute::experimental;
61 using namespace arm_compute::utils::cast;
62 using namespace arm_compute::opencl::kernels;
63 
64 namespace
65 {
66 inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type)
67 {
68  return kernel_type == CLGEMMKernelType::NATIVE ? false : true;
69 }
70 //Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type
71 inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run, bool constant_weights)
72 {
73  if(!constant_weights)
74  {
76  }
77 
78  auto gemm_kernel = auto_heuristics::select_mlgo_gemm_kernel(query, reshape_b_only_on_first_run);
79  if(bool(gemm_kernel))
80  {
81  if(validate_gemm_kernel(gemm_kernel.gemm_type))
82  {
83  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
84  return gemm_kernel.gemm_type;
85  }
86  }
87  gemm_kernel = auto_heuristics::select_default_gemm_kernel(query, reshape_b_only_on_first_run);
88  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
89  return gemm_kernel.gemm_type;
90 }
91 // Validate lhs_info and rhs_info for reshaped only rhs kernel
92 inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
93  const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info)
94 {
95  // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped only rhs kernel
96  TensorInfo tmp_b_info{};
97  // Validate reshape RHS kernel
98  auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
99  if(!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
100  {
101  return false;
102  }
103  // Validate mm kernel
104  gemm_kernel_info.lhs_info = lhs_info;
105  gemm_kernel_info.rhs_info = rhs_info;
106  gemm_kernel_info.has_pad_y = false;
107  if(!bool(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
108  {
109  return false;
110  }
111  gemm_kernel_info.has_pad_y = true;
112  if(!bool(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
113  {
114  return false;
115  }
116  return true;
117 }
118 
119 //Automatically select between mlgo (prioritized) and default heuristics for reshaped only rhs kernel configs
120 inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a,
121  const ITensorInfo *b,
122  const ITensorInfo *c, const ITensorInfo *output)
123 {
125  if(config)
126  {
127  if(validate_lhs_rhs_info_reshaped_only_rhs(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info))
128  {
129  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
130  return { config.lhs_info, config.rhs_info };
131  }
132  }
134  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
135  return { config.lhs_info, config.rhs_info };
136 }
137 
138 // Validate lhs_info and rhs_info for reshaped kernel
139 inline bool validate_lhs_rhs_info_reshaped(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
140  const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info, bool reinterpret_input_as_3d)
141 {
142  // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped kernel
143  TensorInfo tmp_a_info{};
144  TensorInfo tmp_b_info{};
145 
146  // Validate reshape LHS kernel
147  auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, reinterpret_input_as_3d)));
148  if(!bool(ClGemmReshapeLhsMatrixKernel::validate(a, &tmp_a_info, lhs_info, reinterpret_input_as_3d)))
149  {
150  return false;
151  }
152 
153  // Validate reshape RHS kernel
154  auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
155  if(!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
156  {
157  return false;
158  }
159  // Validate mm kernel
160  gemm_kernel_info.lhs_info = lhs_info;
161  gemm_kernel_info.rhs_info = rhs_info;
162  if(!bool(ClGemmMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
163  {
164  return false;
165  }
166  return true;
167 }
168 
169 //Automatically select between mlgo (prioritized) and default heuristics for reshaped kernel configs
170 inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a, const ITensorInfo *b,
171  const ITensorInfo *c, const ITensorInfo *output, bool reinterpret_input_as_3d)
172 {
174  if(config)
175  {
176  if(validate_lhs_rhs_info_reshaped(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info, reinterpret_input_as_3d))
177  {
178  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
179  return { config.lhs_info, config.rhs_info };
180  }
181  }
183  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
184  return { config.lhs_info, config.rhs_info };
185 }
186 } // namespace
187 
189  : _reshape_lhs_kernel(std::make_unique<ClGemmReshapeLhsMatrixKernel>()),
190  _reshape_rhs_kernel(std::make_unique<ClGemmReshapeRhsMatrixKernel>()),
191  _mm_native_kernel(std::make_unique<ClGemmMatrixMultiplyNativeKernel>()),
192  _mm_reshaped_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedKernel>()),
193  _mm_reshaped_only_rhs_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsKernel>()),
194  _mm_reshaped_only_rhs_fallback_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsKernel>()),
195  _tmp_a(),
196  _tmp_b(),
197  _reshape_b_only_on_first_run(false),
198  _gemm_kernel_type(CLGEMMKernelType::NATIVE),
199  _is_prepared(false),
200  _aux_mem(AuxTensorIdx::Count)
201 {
202 }
203 
204 void ClGemm::configure_native(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
205  const GEMMInfo &gemm_info)
206 {
208  bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
209  const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
210  const unsigned int n = b->dimension(0);
211  const unsigned int k = a->dimension(0);
212  const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
213  const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
214  const GPUTarget gpu_target = CLScheduler::get().target();
215  bool broadcast_bias = gemm_info.broadcast_bias();
216 
217  GEMMKernelInfo kernel_info;
218  kernel_info.m = m;
219  kernel_info.n = n;
220  kernel_info.k = k;
221  kernel_info.depth_output_gemm3d = depth_output_gemm3d;
222  kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
223  kernel_info.broadcast_bias = broadcast_bias;
224  kernel_info.activation_info = gemm_info.activation_info();
225  kernel_info.post_ops = gemm_info.post_ops();
226 
227  // Set the target for the kernels
228  _mm_native_kernel->set_target(gpu_target);
229 
231 
232  // Configure and tune matrix multiply kernel
233  _mm_native_kernel->configure(compile_context, a, b, c, output, alpha, beta, config.lhs_info, config.rhs_info, kernel_info);
234 }
235 
236 void ClGemm::configure_reshaped(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
237  const GEMMInfo &gemm_info)
238 {
240  bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
241  const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
242  const unsigned int n = b->dimension(0);
243  const unsigned int k = a->dimension(0);
244  const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
245  const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
246  const GPUTarget gpu_target = CLScheduler::get().target();
247  bool broadcast_bias = gemm_info.broadcast_bias();
248 
249  GEMMKernelInfo kernel_info;
250  kernel_info.m = m;
251  kernel_info.n = n;
252  kernel_info.k = k;
253  kernel_info.depth_output_gemm3d = depth_output_gemm3d;
254  kernel_info.reinterpret_input_as_3d = false;
255  kernel_info.broadcast_bias = broadcast_bias;
256  kernel_info.activation_info = gemm_info.activation_info();
257  kernel_info.post_ops = gemm_info.post_ops();
258 
259  // Set the target for the kernels
260  _reshape_lhs_kernel->set_target(gpu_target);
261  _mm_reshaped_kernel->set_target(gpu_target);
262 
263  GEMMLHSMatrixInfo lhs_info{};
264  GEMMRHSMatrixInfo rhs_info{};
265 
266  // Pick up the GEMM configuration
267  std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a, b,
268  c, output, gemm_info.reinterpret_input_as_3d());
269 
270  _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
271  _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);
272 
273  // Configure and tune matrix multiply kernel
274  _mm_reshaped_kernel->configure(compile_context, &_tmp_a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
275 
276  // Request memory for LHS and RHS reshape matrix
277  _aux_mem[LhsReshape] = MemoryInfo(offset_int_vec(LhsReshape), MemoryLifetime::Temporary, _tmp_a.total_size());
278  _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
279 }
280 
281 void ClGemm::configure_reshaped_only_rhs(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
282  const GEMMInfo &gemm_info)
283 {
285  bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
286  const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
287  const unsigned int n = b->dimension(0);
288  const unsigned int k = a->dimension(0);
289  const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
290  const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
291  const GPUTarget gpu_target = CLScheduler::get().target();
292  bool broadcast_bias = gemm_info.broadcast_bias();
293 
294  GEMMKernelInfo kernel_info;
295  kernel_info.m = m;
296  kernel_info.n = n;
297  kernel_info.k = k;
298  kernel_info.depth_output_gemm3d = depth_output_gemm3d;
299  kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
300  kernel_info.broadcast_bias = broadcast_bias;
301  kernel_info.activation_info = gemm_info.activation_info();
302  kernel_info.post_ops = gemm_info.post_ops();
303 
304  // Set the target for the kernels
305  _mm_reshaped_only_rhs_kernel->set_target(gpu_target);
306  _mm_reshaped_only_rhs_fallback_kernel->set_target(gpu_target);
307 
308  GEMMLHSMatrixInfo lhs_info{};
309  GEMMRHSMatrixInfo rhs_info{};
310 
311  // Pick up the GEMM configuration
312  std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a, b, c, output);
313 
314  // Transpose matrix
315  _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);
316 
317  // Configure two variants of CLGEMMMatrixMultiplyReshapedOnlyRHSKernel (has_pad_y = false/true)
318  // During the prepare stage we check the padding requirement for the lhs and dst tensors. If they do not have
319  // pad y, we dispatch CLGEMMMatrixMultiplyReshapedOnlyRHSKernel with has_pad_y = false
320 
321  // Configure matrix multiply kernel with no y padding support
322  kernel_info.has_pad_y = false;
323  _mm_reshaped_only_rhs_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
324 
325  // Configure matrix multiply kernel with y padding support
326  kernel_info.has_pad_y = true;
327  _mm_reshaped_only_rhs_fallback_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
328 
329  // Request memory for RHS reshape matrix
330  _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
331 }
332 
333 Status ClGemm::validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
334 {
335  ARM_COMPUTE_UNUSED(alpha);
336  ARM_COMPUTE_UNUSED(output);
337 
338  // Get the GPU target
339  const GPUTarget gpu_target = CLScheduler::get().target();
341  bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
342  const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
343  const unsigned int n = b->dimension(0);
344  const unsigned int k = a->dimension(0);
345  const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
346  const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
347  const bool broadcast_bias = gemm_info.broadcast_bias();
348 
349  GEMMKernelInfo kernel_info;
350  kernel_info.m = m;
351  kernel_info.n = n;
352  kernel_info.k = k;
353  kernel_info.depth_output_gemm3d = depth_output_gemm3d;
354  kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
355  kernel_info.broadcast_bias = broadcast_bias;
356  kernel_info.activation_info = gemm_info.activation_info();
357  kernel_info.post_ops = gemm_info.post_ops();
358 
360 
361  // Validate matrix multiply
362  ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyNativeKernel::validate(a, b, c, output, alpha, beta, config.lhs_info, config.rhs_info, kernel_info));
363 
364  return Status{};
365 }
366 
367 Status ClGemm::validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
368 {
369  ARM_COMPUTE_UNUSED(alpha);
370  ARM_COMPUTE_UNUSED(output);
371 
372  TensorInfo tmp_a_info{};
373  TensorInfo tmp_b_info{};
374 
375  // Get the GPU target
376  const GPUTarget gpu_target = CLScheduler::get().target();
378  bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
379  const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
380  const unsigned int n = b->dimension(0);
381  const unsigned int k = a->dimension(0);
382  const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
383  const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
384  const bool broadcast_bias = gemm_info.broadcast_bias();
385 
386  GEMMKernelInfo kernel_info;
387  kernel_info.m = m;
388  kernel_info.n = n;
389  kernel_info.k = k;
390  kernel_info.depth_output_gemm3d = depth_output_gemm3d;
391  kernel_info.reinterpret_input_as_3d = false;
392  kernel_info.broadcast_bias = broadcast_bias;
393  kernel_info.activation_info = gemm_info.activation_info();
394  kernel_info.post_ops = gemm_info.post_ops();
395 
398 
399  // Pick up the GEMM configuration
400  // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
401  const auto gemm_config = select_default_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
402  lhs_info = gemm_config.lhs_info;
403  rhs_info = gemm_config.rhs_info;
404 
405  auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
407 
408  auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
410 
411  // Validate matrix multiply
412  ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
413 
414  return Status{};
415 }
416 
417 Status ClGemm::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
418 {
419  ARM_COMPUTE_UNUSED(alpha);
420  ARM_COMPUTE_UNUSED(output);
421 
422  TensorInfo tmp_b_info{};
423 
424  // Get the GPU target
425  const GPUTarget gpu_target = CLScheduler::get().target();
426  const DataType data_type = a->data_type();
427  bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
428  const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
429  const unsigned int n = b->dimension(0);
430  const unsigned int k = a->dimension(0);
431  const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
432  const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
433  const bool broadcast_bias = gemm_info.broadcast_bias();
434 
435  GEMMKernelInfo kernel_info;
436  kernel_info.m = m;
437  kernel_info.n = n;
438  kernel_info.k = k;
439  kernel_info.depth_output_gemm3d = depth_output_gemm3d;
440  kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
441  kernel_info.broadcast_bias = broadcast_bias;
442  kernel_info.activation_info = gemm_info.activation_info();
443  kernel_info.post_ops = gemm_info.post_ops();
444 
447 
448  // Pick up the GEMM configuration
449  // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
450  const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
451  lhs_info = gemm_config.lhs_info;
452  rhs_info = gemm_config.rhs_info;
453 
454  auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
456 
457  // Validate matrix multiply
458  kernel_info.has_pad_y = false;
459  ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
460 
461  kernel_info.has_pad_y = true;
462  ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
463 
464  return Status{};
465 }
466 
467 void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
468 {
469  ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);
470 
471  // Perform validation step
472  ARM_COMPUTE_ERROR_THROW_ON(validate(a, b, c, output, alpha, beta, gemm_info));
473  ARM_COMPUTE_LOG_PARAMS(a, b, c, output, alpha, beta, gemm_info);
474 
475  // Check if we need to reshape the matrix B only on the first run
476  _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
477  _is_prepared = gemm_info.retain_internal_weights();
478 
479  bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
480  const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
481  const unsigned int n = b->dimension(0);
482  const unsigned int k = a->dimension(0);
483  const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
484 
485  // Select GEMMType
486  _gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery{ CLScheduler::get().target(), a->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run,
487  b->are_values_constant());
488 
489  const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
490 
491  ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;
492 
493  switch(_gemm_kernel_type)
494  {
496  {
497  configure_native(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
498  break;
499  }
501  {
502  configure_reshaped(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
503  break;
504  }
506  {
507  configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
508  break;
509  }
510  default:
511  {
512  ARM_COMPUTE_ERROR("GEMMType not supported");
513  }
514  }
515 }
516 
517 Status ClGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
518 {
519  // Get the GPU target
520  bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
521  const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
522  const unsigned int n = b->dimension(0);
523  const unsigned int k = a->dimension(0);
524  const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
525 
526  // Select GEMMType
527  CLGEMMKernelType gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery
528  {
529  CLScheduler::get().target(), a->data_type(), m, n, k, batch_size,
530  },
532 
533  const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
534 
535  const ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;
536 
537  switch(gemm_kernel_type)
538  {
540  {
541  ARM_COMPUTE_RETURN_ON_ERROR(validate_native(a, b, c_to_use, output, alpha, beta, gemm_info));
542  break;
543  }
545  {
546  ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped(a, b, c_to_use, output, alpha, beta, gemm_info));
547  break;
548  }
550  {
551  ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info));
552  break;
553  }
554  default:
555  {
556  ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported");
557  }
558  }
559 
560  return Status{};
561 }
562 
563 void ClGemm::run(ITensorPack &tensors)
564 {
565  const ITensor *lhs = tensors.get_const_tensor(ACL_SRC_0);
566  const ITensor *rhs = tensors.get_const_tensor(ACL_SRC_1);
567  ITensor *dst = tensors.get_tensor(ACL_DST);
568 
570 
571  CLAuxTensorHandler lhs_reshaped(offset_int_vec(LhsReshape), _tmp_a, tensors, true);
572  CLAuxTensorHandler rhs_reshaped(offset_int_vec(RhsReshape), _tmp_b, tensors, true);
573 
574  // Prepare the consts if needed
575  prepare(tensors);
576 
577  // Run matrix multiply kernel
578  switch(_gemm_kernel_type)
579  {
581  {
582  CLScheduler::get().enqueue_op(*_mm_native_kernel, tensors, true);
583  break;
584  }
586  {
587  // Run interleave kernel
588  ITensorPack reshape_lhs_pack{ { ACL_SRC, lhs }, { ACL_DST, lhs_reshaped.get() } };
589  CLScheduler::get().enqueue_op(*_reshape_lhs_kernel, reshape_lhs_pack, false);
590 
591  if(!_reshape_b_only_on_first_run)
592  {
593  // Run transpose kernel
594  ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
595  CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
596  }
597  // Copy original tensor pack and overwrite lhs and rhs with reshaped counterparts
598  ITensorPack gemm_reshaped_pack(tensors);
599  gemm_reshaped_pack.add_const_tensor(ACL_SRC_0, lhs_reshaped.get());
600  gemm_reshaped_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());
601 
602  if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED)
603  {
604  CLScheduler::get().enqueue_op(*_mm_reshaped_kernel, gemm_reshaped_pack, true);
605  }
606  break;
607  }
609  {
610  if(!_reshape_b_only_on_first_run)
611  {
612  // Run transpose kernel
613  ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
614  CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
615  }
616  // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement
617  // Check if the lhs or dst tensors have padding
618  const unsigned int cross_plane_pad_lhs = lhs->info()->padding().top + lhs->info()->padding().bottom;
619  const unsigned int cross_plane_pad_dst = dst->info()->padding().top + dst->info()->padding().bottom;
620  bool has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);
621 
622  // Copy original tensor pack and overwrite rhs with reshaped counterpart
623  ITensorPack gemm_reshaped_onlyrhs_pack(tensors);
624  gemm_reshaped_onlyrhs_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());
625 
626  if(has_pad_y)
627  {
628  CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_fallback_kernel, gemm_reshaped_onlyrhs_pack, true);
629  }
630  else
631  {
632  CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_kernel, gemm_reshaped_onlyrhs_pack, true);
633  }
634  break;
635  }
636  default:
637  {
638  ARM_COMPUTE_ERROR("GEMMType not supported");
639  }
640  }
641 }
642 
643 void ClGemm::prepare(ITensorPack &constants)
644 {
645  if(!_is_prepared)
646  {
647  const ITensor *src1 = constants.get_const_tensor(ACL_SRC_1);
648  ICLTensor *rhs_aux = utils::cast::polymorphic_downcast<ICLTensor *>(constants.get_tensor(offset_int_vec(RhsReshape)));
649 
650  // If memory for RHS is persistent and src1 is provided re-transform else assume that RHS is transformed
651  if((_aux_mem[AuxTensorIdx::RhsReshape].lifetime == MemoryLifetime::Persistent) && (src1 != nullptr && rhs_aux != nullptr) && rhs_aux)
652  {
653  ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Transforming RHS Matrix!");
654 
655  CLAuxTensorHandler rhs_reshaped(_tmp_b, *rhs_aux);
656  ARM_COMPUTE_ERROR_ON(rhs_reshaped.get()->cl_buffer().get() == nullptr);
657 
658  ITensorPack reshape_rhs_pack{ { ACL_SRC, src1 }, { ACL_DST, rhs_reshaped.get() } };
659  CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, true);
660  }
661  _is_prepared = true;
662  }
663 }
664 
666 {
667  return _aux_mem;
668 }
669 } // namespace opencl
670 } // namespace arm_compute
unsigned int top
top of the border
Definition: Types.h:377
static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info)
Static function to check if given info will lead to a valid configuration.
bool broadcast_bias
Flag used to broadcast the bias addition.
GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query)
Select gemm config based on default heuristics.
Descriptor used by the GEMM kernels.
virtual size_t dimension(size_t index) const =0
Return the size of the requested dimension.
void add_const_tensor(int id, const ITensor *tensor)
Add const tensor to the pack.
Definition: ITensorPack.cpp:49
SimpleTensor< float > b
Definition: DFT.cpp:157
static CLScheduler & get()
Access the scheduler singleton.
#define ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL(msg)
Log an information message to the logger with function name before the message.
Definition: Log.h:99
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
Definition: Error.h:352
GPUTarget target() const
Get the target GPU.
Definition: CLScheduler.cpp:45
unsigned int depth_output_gemm3d
Depth of the output tensor in case is reinterpreted as 3D.
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
Definition: Error.h:204
virtual DataType data_type() const =0
Data type used for each element of the tensor.
void configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
Initialise the kernel&#39;s inputs and output.
Definition: ClGemm.cpp:467
OpenCL kernel to multiply matrices when only the input matrix RHS (src1) has been reshaped...
A collection of adaptor functions that enable the auto selection between mlgo-based heuristics and de...
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true then an error message is printed and an exception thrown.
Definition: Error.h:466
GEMM LHS (Left Hand Side) matrix information.
Definition: Types.h:1938
Store the tensor&#39;s metadata.
Definition: ITensorInfo.h:40
#define ARM_COMPUTE_ERROR_THROW_ON(status)
Definition: Error.h:455
unsigned int bottom
bottom of the border
Definition: Types.h:379
Manages all the OpenCL kernels compilation and caching, provides accessors for the OpenCL Context...
Reshaped GEMM kernel where only the rhs matrix is reshaped.
int depth_output_gemm3d() const
Depth of the output when GEMM output is reinterpreted as 3D tensor.
Definition: Types.h:2059
Status class.
Definition: Error.h:52
ActivationLayerInfo activation_info
Activation function to perform after the matrix multiplication.
bool retain_internal_weights() const
Flag which specifies if the weights tensor has to be retained from previous run.
Definition: Types.h:2075
CLGEMMKernelType
OpenCL GEMM kernel types.
Definition: CLTypes.h:31
Reshaped GEMM kernel where both lhs and rhs matrices are reshaped.
Interface for CPU tensor.
Definition: ITensor.h:36
Copyright (c) 2017-2021 Arm Limited.
std::vector< MemoryInfo > MemoryRequirements
Definition: Types.h:132
GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query)
Select gemm config based on mlgo heuristics.
const DataType data_type
Definition: Im2Col.cpp:150
Interface to enqueue OpenCL kernels and get/set the OpenCL CommandQueue and ICLTuner.
const ITensor * get_const_tensor(int id) const
Get constant tensor of a given id.
Definition: ITensorPack.cpp:54
static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info)
Static function to check if given info will lead to a valid configuration.
unsigned int m
Number of LHS rows.
static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info)
Static function to check if given info will lead to a valid configuration.
#define ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE(fmt,...)
Log information level formatted message to the core system logger.
Definition: Log.h:99
unsigned int n
Number of RHS columns.
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
Definition: Error.h:152
TensorShape compute_lhs_reshaped_shape(const ITensorInfo &a, const GEMMLHSMatrixInfo &lhs_info, bool reinterpret_input_as_3d=false)
Calculate the Left Hand Side matrix reshaped shape.
GEMM RHS (Right Hand Side) matrix information.
Definition: Types.h:1953
static Status validate(const ITensorInfo *src, const ITensorInfo *dst, const GEMMRHSMatrixInfo &rhs_info)
Static function to check if given info will lead to a valid configuration.
size_t total_size() const override
Returns the total size of the tensor in bytes.
Definition: TensorInfo.h:250
void enqueue_op(ICLKernel &kernel, ITensorPack &tensors, bool flush=true)
Schedule the execution of the passed kernel if possible.
TensorShape compute_rhs_reshaped_shape(const ITensorInfo &a, const GEMMRHSMatrixInfo &rhs_info)
Calculate the Right Hand Side matrix reshaped shape.
bool auto_init_if_empty(ITensorInfo &info, const TensorShape &shape, int num_channels, DataType data_type, QuantizationInfo quantization_info=QuantizationInfo())
Auto initialize the tensor info (shape, number of channels and data type) if the current assignment i...
bool reinterpret_input_as_3d
Flag used to reinterpret the input as 3D.
virtual std::unique_ptr< T > clone() const =0
Provide a clone of the current object of class T.
virtual bool are_values_constant() const =0
Flag indicating whether the values of the tensor are constant, meaning that they cannot change on kernel...
static Status validate(const ITensorInfo *src, const ITensorInfo *dst, const GEMMLHSMatrixInfo &lhs_info, bool reinterpret_src_as_3d)
Static function to check if given info will lead to a valid configuration.
virtual ITensorInfo * info() const =0
Interface to be implemented by the child class to return the tensor's metadata.
GEMMTypeResult select_default_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
Select gemm type based on default heuristics.
virtual PaddingSize padding() const =0
Padding of tensor.
std::string to_string(const T &val)
Fallback method: try to use std::to_string:
Definition: TypePrinter.h:79
bool reinterpret_input_as_3d() const
Flag which specifies if the input tensor has to be reinterpreted as 3D.
Definition: Types.h:2067
bool broadcast_bias() const
Flag which specifies whether to broadcast the shape of the bias tensor.
Definition: Types.h:2123
CLCompileContext class.
const experimental::PostOpList< ITensorInfo * > & post_ops() const
Post operations to apply after the matrix multiplication.
Definition: Types.h:2163
OpenCL kernel to multiply matrices when neither of the input matrices have been reshaped.
bool has_pad_y
Flag used to indicate if the input/output tensors have internal pad on the y direction.
ITensor * get_tensor(int id)
Get tensor of a given id from the pack.
Definition: ITensorPack.cpp:64
#define ARM_COMPUTE_RETURN_ERROR_MSG(...)
An error is returned with the given description.
Definition: Error.h:194
Interface for OpenCL tensor.
Definition: ICLTensor.h:42
OpenCL kernel to reshape the RHS matrix when performing the matrix multiplication In particular...
GPUTarget
Available GPU Targets.
Definition: GPUTarget.h:34
Native GEMM kernel with configurable block size.
GEMMTypeResult select_mlgo_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
Select gemm type based on mlgo heuristics.
unsigned int k
Number of LHS columns or RHS rows.
bool is_zero(float a, float epsilon=0.00001f)
Checks if the input floating point number is 0.0f checking if the difference is within a range define...
Definition: float_ops.h:109
void run(ITensorPack &tensors) override
Run the kernels contained in the function.
Definition: ClGemm.cpp:563
OpenCL kernel to reshape the LHS matrix when performing the matrix multiplication.
experimental::PostOpList< ITensorInfo * > post_ops
(EXPERIMENTAL_POST_OPS) Specifies a list of post ops to be fused after the main op.
Tensor packing service.
Definition: ITensorPack.h:39
virtual const cl::Buffer & cl_buffer() const =0
Interface to be implemented by the child class to return a reference to the OpenCL buffer containing ...
#define ARM_COMPUTE_LOG_PARAMS(...)
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
Definition: Validate.h:157
Store the tensor&#39;s metadata.
Definition: TensorInfo.h:43
bool reshape_b_only_on_first_run() const
Flag which specifies if the reshape of matrix B should executed only for the first.
Definition: Types.h:2051
void prepare(ITensorPack &constants) override
Prepare the function for executing.
Definition: ClGemm.cpp:643
int offset_int_vec(int offset)
Definition: MemoryHelpers.h:38
GEMM information class.
Definition: Types.h:1974
GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &query)
Select gemm config based on mlgo heuristics.
static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
Static function to check if given info will lead to a valid configuration.
Definition: ClGemm.cpp:517
DataType
Available data types.
Definition: Types.h:79
ActivationLayerInfo activation_info() const
Activation layer to apply after the matrix multiplication.
Definition: Types.h:2147
ClGemm()
Constructor.
Definition: ClGemm.cpp:188
OpenCL kernel to multiply matrices when both the input matrices LHS (src0) and RHS (src1) have been r...
GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query)
Select gemm config based on default heuristics.
experimental::MemoryRequirements workspace() const override
Return the memory requirements required by the workspace.
Definition: ClGemm.cpp:665