Compute Library 22.05
ClGemm.cpp
/*
 * Copyright (c) 2017-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/gpu/cl/operators/ClGemm.h"

#include "arm_compute/core/Error.h"
#include "arm_compute/core/Log.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"

#include "src/gpu/cl/IClKernel.h"

#include "src/common/utils/Log.h"
#include "support/Cast.h"
#include "utils/TypePrinter.h"

namespace arm_compute
{
namespace opencl
{
using namespace arm_compute::cl_gemm;
using namespace arm_compute::experimental;
using namespace arm_compute::utils::cast;
using namespace arm_compute::opencl::kernels;

namespace
{
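// A gemm kernel type coming from the mlgo heuristics is accepted only if it is not NATIVE;
// otherwise the caller falls back to the default heuristics.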
inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type)
{
    return kernel_type == CLGEMMKernelType::NATIVE ? false : true;
}
// Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type
inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run, bool constant_weights)
{
    if(!constant_weights)
    {
        return CLGEMMKernelType::NATIVE;
    }

    auto gemm_kernel = auto_heuristics::select_mlgo_gemm_kernel(query, reshape_b_only_on_first_run);
    if(bool(gemm_kernel))
    {
        if(validate_gemm_kernel(gemm_kernel.gemm_type))
        {
            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
            return gemm_kernel.gemm_type;
        }
    }
    gemm_kernel = auto_heuristics::select_default_gemm_kernel(query, reshape_b_only_on_first_run);
    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
    return gemm_kernel.gemm_type;
}
// Validate lhs_info and rhs_info for reshaped only rhs kernel
inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
                                                    const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info)
{
    // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped only rhs kernel
    TensorInfo tmp_b_info{};
    // Validate reshape RHS kernel
    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    if(!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
    {
        return false;
    }
    // Validate mm kernel
    gemm_kernel_info.lhs_info = lhs_info;
    gemm_kernel_info.rhs_info = rhs_info;
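    // Validate the matrix multiply kernel for both has_pad_y = false and true, since the
    // run-time padding of the lhs and dst tensors is not known at this point.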
    gemm_kernel_info.has_pad_y = false;
    if(!bool(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
    {
        return false;
    }
    gemm_kernel_info.has_pad_y = true;
    if(!bool(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
    {
        return false;
    }
    return true;
}

// Automatically select between mlgo (prioritized) and default heuristics for reshaped only rhs kernel configs
inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a,
                                                                                                 const ITensorInfo *b,
                                                                                                 const ITensorInfo *c, const ITensorInfo *output)
{
    auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(query);
    if(config)
    {
        if(validate_lhs_rhs_info_reshaped_only_rhs(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info))
        {
            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
            return { config.lhs_info, config.rhs_info };
        }
    }
    config = auto_heuristics::select_default_gemm_config_reshaped_only_rhs(query);
    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
    return { config.lhs_info, config.rhs_info };
}

// Validate lhs_info and rhs_info for reshaped kernel
inline bool validate_lhs_rhs_info_reshaped(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
                                           const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info, bool reinterpret_input_as_3d)
{
    // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped kernel
    TensorInfo tmp_a_info{};
    TensorInfo tmp_b_info{};

    // Validate reshape LHS kernel
    auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, reinterpret_input_as_3d)));
    if(!bool(ClGemmReshapeLhsMatrixKernel::validate(a, &tmp_a_info, lhs_info, reinterpret_input_as_3d)))
    {
        return false;
    }

    // Validate reshape RHS kernel
    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    if(!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
    {
        return false;
    }
    // Validate mm kernel
    gemm_kernel_info.lhs_info = lhs_info;
    gemm_kernel_info.rhs_info = rhs_info;
    if(!bool(ClGemmMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
    {
        return false;
    }
    return true;
}

// Automatically select between mlgo (prioritized) and default heuristics for reshaped kernel configs
inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a, const ITensorInfo *b,
                                                                                        const ITensorInfo *c, const ITensorInfo *output, bool reinterpret_input_as_3d)
{
    auto config = auto_heuristics::select_mlgo_gemm_config_reshaped(query);
    if(config)
    {
        if(validate_lhs_rhs_info_reshaped(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info, reinterpret_input_as_3d))
        {
            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
            return { config.lhs_info, config.rhs_info };
        }
    }
    config = auto_heuristics::select_default_gemm_config_reshaped(query);
    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
    return { config.lhs_info, config.rhs_info };
}
} // namespace

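// ClGemm configures one of three kernel paths chosen by the heuristics above: NATIVE (no
// reshaping), RESHAPED (both lhs and rhs reshaped) and RESHAPED_ONLY_RHS (only rhs reshaped).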
ClGemm::ClGemm()
    : _reshape_lhs_kernel(std::make_unique<ClGemmReshapeLhsMatrixKernel>()),
      _reshape_rhs_kernel(std::make_unique<ClGemmReshapeRhsMatrixKernel>()),
      _mm_native_kernel(std::make_unique<ClGemmMatrixMultiplyNativeKernel>()),
      _mm_reshaped_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedKernel>()),
      _mm_reshaped_only_rhs_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsKernel>()),
      _tmp_a(),
      _tmp_b(),
      _reshape_b_only_on_first_run(false),
      _gemm_kernel_type(CLGEMMKernelType::NATIVE),
      _is_prepared(false),
      _aux_mem(AuxTensorIdx::Count)
{
}

void ClGemm::configure_native(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                              const GEMMInfo &gemm_info)
{
    DataType data_type = a->data_type();
    bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n = b->dimension(0);
    const unsigned int k = a->dimension(0);
    const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
    const GPUTarget gpu_target = CLScheduler::get().target();
    bool broadcast_bias = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m = m;
    kernel_info.n = n;
    kernel_info.k = k;
    kernel_info.depth_output_gemm3d = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias = broadcast_bias;
    kernel_info.activation_info = gemm_info.activation_info();
    kernel_info.post_ops = gemm_info.post_ops();

    // Set the target for the kernels
    _mm_native_kernel->set_target(gpu_target);

    auto config = auto_heuristics::select_mlgo_gemm_config_native(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });

    // Configure and tune matrix multiply kernel
    _mm_native_kernel->configure(compile_context, a, b, c, output, alpha, beta, config.lhs_info, config.rhs_info, kernel_info);
}

void ClGemm::configure_reshaped(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                                const GEMMInfo &gemm_info)
{
    DataType data_type = a->data_type();
    bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n = b->dimension(0);
    const unsigned int k = a->dimension(0);
    const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
    const GPUTarget gpu_target = CLScheduler::get().target();
    bool broadcast_bias = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m = m;
    kernel_info.n = n;
    kernel_info.k = k;
    kernel_info.depth_output_gemm3d = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = false;
    kernel_info.broadcast_bias = broadcast_bias;
    kernel_info.activation_info = gemm_info.activation_info();
    kernel_info.post_ops = gemm_info.post_ops();

    // Set the target for the kernels
    _reshape_lhs_kernel->set_target(gpu_target);
    _mm_reshaped_kernel->set_target(gpu_target);

    GEMMLHSMatrixInfo lhs_info{};
    GEMMRHSMatrixInfo rhs_info{};

    // Pick up the GEMM configuration
    std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a, b,
                                                                    c, output, gemm_info.reinterpret_input_as_3d());

    _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
    _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);

    // Configure and tune matrix multiply kernel
    _mm_reshaped_kernel->configure(compile_context, &_tmp_a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);

    // Request memory for LHS and RHS reshape matrix
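    // The LHS reshape buffer is always temporary, while the RHS reshape buffer stays alive
    // across runs when B is only reshaped on the first run (see prepare()).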
    _aux_mem[LhsReshape] = MemoryInfo(offset_int_vec(LhsReshape), MemoryLifetime::Temporary, _tmp_a.total_size());
    _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
}

void ClGemm::configure_reshaped_only_rhs(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                                         const GEMMInfo &gemm_info)
{
    DataType data_type = a->data_type();
    bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n = b->dimension(0);
    const unsigned int k = a->dimension(0);
    const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
    const GPUTarget gpu_target = CLScheduler::get().target();
    bool broadcast_bias = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m = m;
    kernel_info.n = n;
    kernel_info.k = k;
    kernel_info.depth_output_gemm3d = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias = broadcast_bias;
    kernel_info.activation_info = gemm_info.activation_info();
    kernel_info.post_ops = gemm_info.post_ops();

    // Set the target for the kernels
    _mm_reshaped_only_rhs_kernel->set_target(gpu_target);

    GEMMLHSMatrixInfo lhs_info{};
    GEMMRHSMatrixInfo rhs_info{};

    // Pick up the GEMM configuration
    std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a, b, c, output);

    // Transpose matrix
    _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);

    // Configure matrix multiply kernel with no y padding support (has_pad_y = false).
    // At run time we check the padding requirement for the lhs and dst tensors; the kernel
    // is dispatched only if neither has pad on the y direction.
    kernel_info.has_pad_y = false;
    _mm_reshaped_only_rhs_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);

    // Request memory for RHS reshape matrix
    _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
}

Status ClGemm::validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);

    // Get the GPU target
    const GPUTarget gpu_target = CLScheduler::get().target();
    const DataType data_type = a->data_type();
    bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n = b->dimension(0);
    const unsigned int k = a->dimension(0);
    const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
    const bool broadcast_bias = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m = m;
    kernel_info.n = n;
    kernel_info.k = k;
    kernel_info.depth_output_gemm3d = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias = broadcast_bias;
    kernel_info.activation_info = gemm_info.activation_info();
    kernel_info.post_ops = gemm_info.post_ops();

    auto config = auto_heuristics::select_mlgo_gemm_config_native(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });

    // Validate matrix multiply
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyNativeKernel::validate(a, b, c, output, alpha, beta, config.lhs_info, config.rhs_info, kernel_info));

    return Status{};
}

Status ClGemm::validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);

    TensorInfo tmp_a_info{};
    TensorInfo tmp_b_info{};

    // Get the GPU target
    const GPUTarget gpu_target = CLScheduler::get().target();
    const DataType data_type = a->data_type();
    bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n = b->dimension(0);
    const unsigned int k = a->dimension(0);
    const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
    const bool broadcast_bias = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m = m;
    kernel_info.n = n;
    kernel_info.k = k;
    kernel_info.depth_output_gemm3d = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = false;
    kernel_info.broadcast_bias = broadcast_bias;
    kernel_info.activation_info = gemm_info.activation_info();
    kernel_info.post_ops = gemm_info.post_ops();

    GEMMLHSMatrixInfo lhs_info;
    GEMMRHSMatrixInfo rhs_info;

    // Pick up the GEMM configuration
    // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
    const auto gemm_config = select_default_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
    lhs_info = gemm_config.lhs_info;
    rhs_info = gemm_config.rhs_info;

    auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeLhsMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));

    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));

    // Validate matrix multiply
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    return Status{};
}

Status ClGemm::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);

    TensorInfo tmp_b_info{};

    // Get the GPU target
    const GPUTarget gpu_target = CLScheduler::get().target();
    const DataType data_type = a->data_type();
    bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n = b->dimension(0);
    const unsigned int k = a->dimension(0);
    const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
    const bool broadcast_bias = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m = m;
    kernel_info.n = n;
    kernel_info.k = k;
    kernel_info.depth_output_gemm3d = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias = broadcast_bias;
    kernel_info.activation_info = gemm_info.activation_info();
    kernel_info.post_ops = gemm_info.post_ops();

    GEMMLHSMatrixInfo lhs_info;
    GEMMRHSMatrixInfo rhs_info;

    // Pick up the GEMM configuration
    // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
    const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
    lhs_info = gemm_config.lhs_info;
    rhs_info = gemm_config.rhs_info;

    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));

    // Validate matrix multiply
    kernel_info.has_pad_y = false;
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    kernel_info.has_pad_y = true;
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    return Status{};
}

void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);

    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate(a, b, c, output, alpha, beta, gemm_info));
    ARM_COMPUTE_LOG_PARAMS(a, b, c, output, alpha, beta, gemm_info);

    // Check if we need to reshape the matrix B only on the first run
    _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
    _is_prepared = gemm_info.retain_internal_weights();
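    // When internal weights are retained, the reshaped RHS from a previous run is considered
    // valid and prepare() becomes a no-op.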

    bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n = b->dimension(0);
    const unsigned int k = a->dimension(0);
    const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);

    // Select GEMMType
    _gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery{ CLScheduler::get().target(), a->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run,
                                                b->are_values_constant());

    const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);

    ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;

    switch(_gemm_kernel_type)
    {
        case CLGEMMKernelType::NATIVE:
        {
            configure_native(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        case CLGEMMKernelType::RESHAPED:
        {
            configure_reshaped(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS:
        {
            configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        default:
        {
            ARM_COMPUTE_ERROR("GEMMType not supported");
        }
    }
}

Status ClGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    // Get the GPU target
    bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n = b->dimension(0);
    const unsigned int k = a->dimension(0);
    const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);

    // Select GEMMType
    CLGEMMKernelType gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery
    {
        CLScheduler::get().target(), a->data_type(), m, n, k, batch_size,
    },
    gemm_info.reshape_b_only_on_first_run(), b->are_values_constant());

    const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);

    const ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;

    switch(gemm_kernel_type)
    {
        case CLGEMMKernelType::NATIVE:
        {
            ARM_COMPUTE_RETURN_ON_ERROR(validate_native(a, b, c_to_use, output, alpha, beta, gemm_info));
            break;
        }
        case CLGEMMKernelType::RESHAPED:
        {
            ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped(a, b, c_to_use, output, alpha, beta, gemm_info));
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS:
        {
            ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info));
            break;
        }
        default:
        {
            ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported");
        }
    }

    return Status{};
}

void ClGemm::run(ITensorPack &tensors)
{
    const ITensor *lhs = tensors.get_const_tensor(ACL_SRC_0);
    const ITensor *rhs = tensors.get_const_tensor(ACL_SRC_1);
    ITensor *dst = tensors.get_tensor(ACL_DST);

    ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst);

    CLAuxTensorHandler lhs_reshaped(offset_int_vec(LhsReshape), _tmp_a, tensors, true);
    CLAuxTensorHandler rhs_reshaped(offset_int_vec(RhsReshape), _tmp_b, tensors, true);

    // Prepare the consts if needed
    prepare(tensors);

    // Run matrix multiply kernel
    switch(_gemm_kernel_type)
    {
        case CLGEMMKernelType::NATIVE:
        {
            CLScheduler::get().enqueue_op(*_mm_native_kernel, tensors, true);
            break;
        }
        case CLGEMMKernelType::RESHAPED:
        {
            // Run interleave kernel
            ITensorPack reshape_lhs_pack{ { ACL_SRC, lhs }, { ACL_DST, lhs_reshaped.get() } };
            CLScheduler::get().enqueue_op(*_reshape_lhs_kernel, reshape_lhs_pack, false);

            if(!_reshape_b_only_on_first_run)
            {
                // Run transpose kernel
                ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
                CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
            }
            // Copy original tensor pack and overwrite lhs and rhs with reshaped counterparts
            ITensorPack gemm_reshaped_pack(tensors);
            gemm_reshaped_pack.add_const_tensor(ACL_SRC_0, lhs_reshaped.get());
            gemm_reshaped_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());

            if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED)
            {
                CLScheduler::get().enqueue_op(*_mm_reshaped_kernel, gemm_reshaped_pack, true);
            }
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS:
        {
            if(!_reshape_b_only_on_first_run)
            {
                // Run transpose kernel
                ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
                CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
            }
            // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement
            // Check if the lhs or dst tensors have padding
            const unsigned int cross_plane_pad_lhs = lhs->info()->padding().top + lhs->info()->padding().bottom;
            const unsigned int cross_plane_pad_dst = dst->info()->padding().top + dst->info()->padding().bottom;
            bool has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);

            // Copy original tensor pack and overwrite rhs with reshaped counterpart
            ITensorPack gemm_reshaped_onlyrhs_pack(tensors);
            gemm_reshaped_onlyrhs_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());

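            // The mm kernel was configured with has_pad_y = false, so cross-plane (y) padding
            // on the lhs or dst tensor at this point is a programming error.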
            if(has_pad_y)
            {
                ARM_COMPUTE_ERROR_ON(has_pad_y);
            }
            else
            {
                CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_kernel, gemm_reshaped_onlyrhs_pack, true);
            }
            break;
        }
        default:
        {
            ARM_COMPUTE_ERROR("GEMMType not supported");
        }
    }
}

void ClGemm::prepare(ITensorPack &constants)
{
    if(!_is_prepared)
    {
        const ITensor *src1 = constants.get_const_tensor(ACL_SRC_1);
        ICLTensor *rhs_aux = utils::cast::polymorphic_downcast<ICLTensor *>(constants.get_tensor(offset_int_vec(RhsReshape)));

        // If memory for RHS is persistent and src1 is provided, re-transform; otherwise assume that RHS is already transformed
        if((_aux_mem[AuxTensorIdx::RhsReshape].lifetime == MemoryLifetime::Persistent) && (src1 != nullptr && rhs_aux != nullptr) && rhs_aux)
        {
            ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Transforming RHS Matrix!");

            CLAuxTensorHandler rhs_reshaped(_tmp_b, *rhs_aux);
            ARM_COMPUTE_ERROR_ON(rhs_reshaped.get()->cl_buffer().get() == nullptr);

            ITensorPack reshape_rhs_pack{ { ACL_SRC, src1 }, { ACL_DST, rhs_reshaped.get() } };
            CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, true);
        }
        _is_prepared = true;
    }
}

experimental::MemoryRequirements ClGemm::workspace() const
{
    return _aux_mem;
}
} // namespace opencl
} // namespace arm_compute
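
Usage note (editor's sketch, not part of the original file): the fragment below shows how a
caller might drive this operator, assuming already-allocated CLTensor objects lhs, rhs and dst
(hypothetical names) whose metadata describes dst = alpha * lhs * rhs. Workspace import for
the reshaped paths is elided here; a real caller would allocate the auxiliary tensors reported
by workspace() and add them to the pack under their offset_int_vec() ids.

    using namespace arm_compute;

    opencl::ClGemm gemm;
    GEMMInfo gemm_info{};

    // Static validation first, then configuration; both operate on tensor metadata only.
    ARM_COMPUTE_ERROR_THROW_ON(opencl::ClGemm::validate(lhs.info(), rhs.info(), nullptr, dst.info(), 1.f, 0.f, gemm_info));
    gemm.configure(CLKernelLibrary::get().get_compile_context(), lhs.info(), rhs.info(), nullptr, dst.info(), 1.f, 0.f, gemm_info);

    // Pack the run-time tensors under the ids the operator expects and execute.
    ITensorPack pack{ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs }, { ACL_DST, &dst } };
    gemm.run(pack);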