Compute Library 22.11
ClGemm.cpp
/*
 * Copyright (c) 2017-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
25 
28 #include "arm_compute/core/Error.h"
32 #include "arm_compute/core/Log.h"
34 #include "arm_compute/core/Types.h"
35 #include "arm_compute/core/Utils.h"
40 
45 #include "src/gpu/cl/IClKernel.h"
49 
50 #include "src/common/utils/Log.h"
51 #include "support/Cast.h"
52 #include "utils/TypePrinter.h"
53 
namespace arm_compute
{
namespace opencl
{
using namespace arm_compute::misc::shape_calculator;
using namespace arm_compute::cl_gemm;
using namespace arm_compute::experimental;
using namespace arm_compute::utils::cast;
using namespace arm_compute::opencl::kernels;

namespace
{
inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type)
{
    return kernel_type != CLGEMMKernelType::NATIVE;
}

// Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type
inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run, bool constant_weights)
{
    if(!constant_weights)
    {
        return CLGEMMKernelType::NATIVE;
    }

    auto gemm_kernel = auto_heuristics::select_mlgo_gemm_kernel(query, reshape_b_only_on_first_run);
    if(bool(gemm_kernel))
    {
        if(validate_gemm_kernel(gemm_kernel.gemm_type))
        {
            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
            return gemm_kernel.gemm_type;
        }
    }
    gemm_kernel = auto_heuristics::select_default_gemm_kernel(query, reshape_b_only_on_first_run);
    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
    return gemm_kernel.gemm_type;
}
// Validate lhs_info and rhs_info for reshaped only rhs kernel
inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
                                                    const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info)
{
    // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped only rhs kernel
    TensorInfo tmp_b_info{};
    // Validate reshape RHS kernel
    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    if(!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
    {
        return false;
    }
    // Validate mm kernel
    gemm_kernel_info.lhs_info  = lhs_info;
    gemm_kernel_info.rhs_info  = rhs_info;
    gemm_kernel_info.has_pad_y = false;
    if(!bool(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
    {
        return false;
    }
    gemm_kernel_info.has_pad_y = true;
    if(!bool(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
    {
        return false;
    }
    return true;
}
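
// Both the has_pad_y = false and has_pad_y = true variants are validated above because the actual
// choice between them is only made at run time, once the y padding of the lhs/dst tensors is known
// (see ClGemm::run()).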

// Automatically select between mlgo (prioritized) and default heuristics for reshaped only rhs kernel configs
inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a,
                                                                                                 const ITensorInfo *b,
                                                                                                 const ITensorInfo *c, const ITensorInfo *output)
{
    auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(query);
    if(config)
    {
        if(validate_lhs_rhs_info_reshaped_only_rhs(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info))
        {
            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
            return { config.lhs_info, config.rhs_info };
        }
    }
    config = auto_heuristics::select_default_gemm_config_reshaped_only_rhs(query);
    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
    return { config.lhs_info, config.rhs_info };
}

// Validate lhs_info and rhs_info for reshaped kernel
inline bool validate_lhs_rhs_info_reshaped(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
                                           const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info, bool reinterpret_input_as_3d)
{
    // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped kernel
    TensorInfo tmp_a_info{};
    TensorInfo tmp_b_info{};

    // Validate reshape LHS kernel
    auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, reinterpret_input_as_3d)));
    if(!bool(ClGemmReshapeLhsMatrixKernel::validate(a, &tmp_a_info, lhs_info, reinterpret_input_as_3d)))
    {
        return false;
    }

    // Validate reshape RHS kernel
    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    if(!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
    {
        return false;
    }
    // Validate mm kernel
    gemm_kernel_info.lhs_info = lhs_info;
    gemm_kernel_info.rhs_info = rhs_info;
    if(!bool(ClGemmMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
    {
        return false;
    }
    return true;
}

// Automatically select between mlgo (prioritized) and default heuristics for reshaped kernel configs
inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a, const ITensorInfo *b,
                                                                                        const ITensorInfo *c, const ITensorInfo *output, bool reinterpret_input_as_3d)
{
    auto config = auto_heuristics::select_mlgo_gemm_config_reshaped(query);
    if(config)
    {
        if(validate_lhs_rhs_info_reshaped(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info, reinterpret_input_as_3d))
        {
            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
            return { config.lhs_info, config.rhs_info };
        }
    }
    config = auto_heuristics::select_default_gemm_config_reshaped(query);
    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
    return { config.lhs_info, config.rhs_info };
}
} // namespace
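
// The helpers above implement a two-stage fallback: the mlgo-trained heuristic is queried first and its
// suggestion is validated; if either step fails, the hand-tuned default heuristic is used instead. As a
// purely illustrative sketch (the sizes below are hypothetical, not taken from the library), a query for
// an F32 GEMM with M = 32, N = 64, K = 128 and batch size 4 would be built and dispatched as:
//
//     auto_heuristics::CommonQuery query{ CLScheduler::get().target(), DataType::F32, 32, 64, 128, 4 };
//     CLGEMMKernelType type = auto_select_gemm_kernel(query, /* reshape_b_only_on_first_run */ true, /* constant_weights */ true);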

ClGemm::ClGemm()
    : _reshape_lhs_kernel(std::make_unique<ClGemmReshapeLhsMatrixKernel>()),
      _reshape_rhs_kernel(std::make_unique<ClGemmReshapeRhsMatrixKernel>()),
      _mm_native_kernel(std::make_unique<ClGemmMatrixMultiplyNativeKernel>()),
      _mm_reshaped_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedKernel>()),
      _mm_reshaped_only_rhs_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsKernel>()),
      _mm_reshaped_only_rhs_mmul_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel>()),
      _tmp_a(),
      _tmp_b(),
      _reshape_b_only_on_first_run(false),
      _gemm_kernel_type(CLGEMMKernelType::NATIVE),
      _is_prepared(false),
      _aux_mem(AuxTensorIdx::Count)
{
}

void ClGemm::configure_native(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                              const GEMMInfo &gemm_info)
{
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    bool               broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    // Set the target for the kernels
    _mm_native_kernel->set_target(gpu_target);

    auto config = auto_heuristics::select_mlgo_gemm_config_native(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });

    // Configure and tune matrix multiply kernel
    _mm_native_kernel->configure(compile_context, a, b, c, output, alpha, beta, config.lhs_info, config.rhs_info, kernel_info);
}
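
// Note on the dimension mapping used above: Compute Library shapes are ordered [x, y, z, w], so for a
// 2D LHS of logical size M x K, dimension(0) is K and dimension(1) is M. As a purely illustrative
// example (hypothetical shapes): an LHS of shape [128, 32, 4] with an RHS of shape [64, 128, 4]
// yields m = 32, n = 64, k = 128 and batch_size = 4.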

void ClGemm::configure_reshaped(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                                const GEMMInfo &gemm_info)
{
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    bool               broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = false;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    // Set the target for the kernels
    _reshape_lhs_kernel->set_target(gpu_target);
    _mm_reshaped_kernel->set_target(gpu_target);

    GEMMLHSMatrixInfo lhs_info{};
    GEMMRHSMatrixInfo rhs_info{};

    // Pick up the GEMM configuration
    std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a, b,
                                                                    c, output, gemm_info.reinterpret_input_as_3d());

    _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
    _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);

    // Configure and tune matrix multiply kernel
    _mm_reshaped_kernel->configure(compile_context, &_tmp_a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);

    // Request memory for LHS and RHS reshape matrix
    _aux_mem[LhsReshape] = MemoryInfo(offset_int_vec(LhsReshape), MemoryLifetime::Temporary, _tmp_a.total_size());
    _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
}
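
// The RHS reshape buffer is requested as Persistent when B is reshaped only on the first run, so the
// transformed weights survive across calls to run(); otherwise it is a Temporary workspace whose memory
// can be reused between runs.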

void ClGemm::configure_reshaped_only_rhs(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                                         const GEMMInfo &gemm_info)
{
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    bool               broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    // Set the target for the kernels
    _mm_reshaped_only_rhs_kernel->set_target(gpu_target);

    GEMMLHSMatrixInfo lhs_info{};
    GEMMRHSMatrixInfo rhs_info{};

    // Pick up the GEMM configuration
    std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a, b, c, output);

    // Transpose matrix
    _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);

    // Configure two variants of CLGEMMMatrixMultiplyReshapedOnlyRHSKernel (has_pad_y = false/true)
    // At run time we check the padding requirement for the lhs and dst tensors. If they do not have
    // pad y, we dispatch CLGEMMMatrixMultiplyReshapedOnlyRHSKernel with has_pad_y = false

    // Configure matrix multiply kernel with no y padding support
    kernel_info.has_pad_y = false;
    _mm_reshaped_only_rhs_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);

    // Request memory for RHS reshape matrix
    _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
}

void ClGemm::configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                                              const GEMMInfo &gemm_info)
{
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    bool               broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    // Set the target for the kernels
    _mm_reshaped_only_rhs_mmul_kernel->set_target(gpu_target);

    GEMMLHSMatrixInfo lhs_info{};
    GEMMRHSMatrixInfo rhs_info{};

    // Pick up the GEMM configuration
    auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
    lhs_info         = gemm_config.lhs_info;
    rhs_info         = gemm_config.rhs_info;
    // Force H0 to 4 in order to use the MMUL extension
    rhs_info.h0 = 4;

    // Reshape Rhs matrix
    _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);

    // Configure matrix multiply kernel with no y padding support
    kernel_info.has_pad_y = false;
    _mm_reshaped_only_rhs_mmul_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);

    // Request memory for RHS reshape matrix
    _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
}
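
// h0 is the number of horizontal blocks of size (k0 x n0) stored on the same output row of the
// reshaped RHS matrix; the code above forces h0 = 4 so that the MMUL extension can be used.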

Status ClGemm::validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);

    // Get the GPU target
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    const DataType     data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const bool         broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    auto config = auto_heuristics::select_mlgo_gemm_config_native(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });

    // Validate matrix multiply
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyNativeKernel::validate(a, b, c, output, alpha, beta, config.lhs_info, config.rhs_info, kernel_info));

    return Status{};
}

Status ClGemm::validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);

    TensorInfo tmp_a_info{};
    TensorInfo tmp_b_info{};

    // Get the GPU target
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    const DataType     data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const bool         broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = false;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    GEMMLHSMatrixInfo lhs_info;
    GEMMRHSMatrixInfo rhs_info;

    // Pick up the GEMM configuration
    // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
    const auto gemm_config = select_default_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
    lhs_info               = gemm_config.lhs_info;
    rhs_info               = gemm_config.rhs_info;

    auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeLhsMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));

    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));

    // Validate matrix multiply
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    return Status{};
}

Status ClGemm::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);

    TensorInfo tmp_b_info{};

    // Get the GPU target
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    const DataType     data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const bool         broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    GEMMLHSMatrixInfo lhs_info;
    GEMMRHSMatrixInfo rhs_info;

    // Pick up the GEMM configuration
    // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
    const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
    lhs_info               = gemm_config.lhs_info;
    rhs_info               = gemm_config.rhs_info;

    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));

    // Validate matrix multiply
    kernel_info.has_pad_y = false;
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    kernel_info.has_pad_y = true;
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    return Status{};
}

Status ClGemm::validate_reshaped_only_rhs_mmul(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);
    TensorInfo tmp_b_info{};

    // Get the GPU target
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    const DataType     data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const bool         broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    GEMMLHSMatrixInfo lhs_info;
    GEMMRHSMatrixInfo rhs_info;

    // Pick up the GEMM configuration
    // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
    const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
    lhs_info               = gemm_config.lhs_info;
    rhs_info               = gemm_config.rhs_info;
    // Force H0 to 4 in order to use the MMUL extension
    rhs_info.h0 = 4;

    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));

    // Validate matrix multiply
    kernel_info.has_pad_y = false;
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    return Status{};
}

void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);

    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate(a, b, c, output, alpha, beta, gemm_info));
    ARM_COMPUTE_LOG_PARAMS(a, b, c, output, alpha, beta, gemm_info);

    // Check if we need to reshape the matrix B only on the first run
    _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
    _is_prepared                 = gemm_info.retain_internal_weights();

    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);

    // Select GEMMType
    _gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery{ CLScheduler::get().target(), a->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run,
                                                b->are_values_constant());

    const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);

    ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;

    switch(_gemm_kernel_type)
    {
        case CLGEMMKernelType::NATIVE:
        {
            configure_native(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        case CLGEMMKernelType::RESHAPED:
        {
            configure_reshaped(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS:
        {
            configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
        {
            configure_reshaped_only_rhs_mmul(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        default:
        {
            ARM_COMPUTE_ERROR("GEMMType not supported");
        }
    }
}
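
// Minimal usage sketch (illustrative only: the tensor infos, pack contents and compile context below are
// hypothetical, and workspace allocation/import is elided). The intended lifecycle is validate/configure
// once, then run() per execution, with prepare() performing the one-off RHS reshape on the first run:
//
//     ClGemm gemm;
//     ARM_COMPUTE_ERROR_THROW_ON(ClGemm::validate(&a_info, &b_info, nullptr, &dst_info, 1.f, 0.f, GEMMInfo{}));
//     gemm.configure(CLKernelLibrary::get().get_compile_context(), &a_info, &b_info, nullptr, &dst_info, 1.f, 0.f, GEMMInfo{});
//     ITensorPack pack{ { ACL_SRC_0, &a }, { ACL_SRC_1, &b }, { ACL_DST, &dst } }; // plus aux tensors from workspace()
//     gemm.run(pack);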

Status ClGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);

    // Select GEMMType
    CLGEMMKernelType gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery
    {
        CLScheduler::get().target(), a->data_type(), m, n, k, batch_size,
    },
    gemm_info.reshape_b_only_on_first_run(), b->are_values_constant());

    const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);

    const ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;

    switch(gemm_kernel_type)
    {
        case CLGEMMKernelType::NATIVE:
        {
            ARM_COMPUTE_RETURN_ON_ERROR(validate_native(a, b, c_to_use, output, alpha, beta, gemm_info));
            break;
        }
        case CLGEMMKernelType::RESHAPED:
        {
            ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped(a, b, c_to_use, output, alpha, beta, gemm_info));
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS:
        {
            ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info));
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
        {
            ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs_mmul(a, b, c_to_use, output, alpha, beta, gemm_info));
            break;
        }
        default:
        {
            ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported");
        }
    }

    return Status{};
}

void ClGemm::run(ITensorPack &tensors)
{
    const ITensor *lhs = tensors.get_const_tensor(ACL_SRC_0);
    const ITensor *rhs = tensors.get_const_tensor(ACL_SRC_1);
    ITensor       *dst = tensors.get_tensor(ACL_DST);

    ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, dst);

    CLAuxTensorHandler lhs_reshaped(offset_int_vec(LhsReshape), _tmp_a, tensors, true);
    CLAuxTensorHandler rhs_reshaped(offset_int_vec(RhsReshape), _tmp_b, tensors, true);

    // Prepare the consts if needed
    prepare(tensors);

    // Run matrix multiply kernel
    switch(_gemm_kernel_type)
    {
        case CLGEMMKernelType::NATIVE:
        {
            CLScheduler::get().enqueue_op(*_mm_native_kernel, tensors, true);
            break;
        }
        case CLGEMMKernelType::RESHAPED:
        {
            // Run interleave kernel
            ITensorPack reshape_lhs_pack{ { ACL_SRC, lhs }, { ACL_DST, lhs_reshaped.get() } };
            CLScheduler::get().enqueue_op(*_reshape_lhs_kernel, reshape_lhs_pack, false);

            if(!_reshape_b_only_on_first_run)
            {
                // Run transpose kernel
                ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
                CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
            }
            // Copy original tensor pack and overwrite lhs and rhs with reshaped counterparts
            ITensorPack gemm_reshaped_pack(tensors);
            gemm_reshaped_pack.add_const_tensor(ACL_SRC_0, lhs_reshaped.get());
            gemm_reshaped_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());

            if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED)
            {
                CLScheduler::get().enqueue_op(*_mm_reshaped_kernel, gemm_reshaped_pack, true);
            }
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS:
        {
            if(!_reshape_b_only_on_first_run)
            {
                // Run transpose kernel
                ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
                CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
            }
            // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement
            // Check if the lhs or dst tensors have padding
            const unsigned int cross_plane_pad_lhs = lhs->info()->padding().top + lhs->info()->padding().bottom;
            const unsigned int cross_plane_pad_dst = dst->info()->padding().top + dst->info()->padding().bottom;
            bool               has_pad_y           = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);

            // Copy original tensor pack and overwrite rhs with reshaped counterpart
            ITensorPack gemm_reshaped_onlyrhs_pack(tensors);
            gemm_reshaped_onlyrhs_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());

            if(has_pad_y)
            {
                ARM_COMPUTE_ERROR_ON(has_pad_y);
            }
            else
            {
                CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_kernel, gemm_reshaped_onlyrhs_pack, true);
            }
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
        {
            if(!_reshape_b_only_on_first_run)
            {
                // Run transpose kernel
                ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
                CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
            }
            // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement
            // Check if the lhs or dst tensors have padding
            const unsigned int cross_plane_pad_lhs = lhs->info()->padding().top + lhs->info()->padding().bottom;
            const unsigned int cross_plane_pad_dst = dst->info()->padding().top + dst->info()->padding().bottom;
            bool               has_pad_y           = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);

            // Copy original tensor pack and overwrite rhs with reshaped counterpart
            ITensorPack gemm_reshaped_onlyrhs_pack(tensors);
            gemm_reshaped_onlyrhs_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());

            if(has_pad_y)
            {
                ARM_COMPUTE_ERROR_ON(has_pad_y);
            }
            else
            {
                CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_mmul_kernel, gemm_reshaped_onlyrhs_pack, true);
            }
            break;
        }
        default:
        {
            ARM_COMPUTE_ERROR("GEMMType not supported");
        }
    }
}

void ClGemm::prepare(ITensorPack &constants)
{
    if(!_is_prepared)
    {
        const ITensor *src1    = constants.get_const_tensor(ACL_SRC_1);
        ICLTensor     *rhs_aux = utils::cast::polymorphic_downcast<ICLTensor *>(constants.get_tensor(offset_int_vec(RhsReshape)));

        // If memory for the RHS is persistent and src1 is provided, re-transform; otherwise assume the RHS has already been transformed
        if((_aux_mem[AuxTensorIdx::RhsReshape].lifetime == MemoryLifetime::Persistent) && (src1 != nullptr) && (rhs_aux != nullptr))
        {
            ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Transforming RHS Matrix!");

            CLAuxTensorHandler rhs_reshaped(_tmp_b, *rhs_aux);
            ARM_COMPUTE_ERROR_ON(rhs_reshaped.get()->cl_buffer().get() == nullptr);

            ITensorPack reshape_rhs_pack{ { ACL_SRC, src1 }, { ACL_DST, rhs_reshaped.get() } };
            CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, true);
        }
        _is_prepared = true;
    }
}
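
// Note that configure() seeds _is_prepared with gemm_info.retain_internal_weights(), so when internal
// weights are retained the one-off RHS transformation above is skipped entirely.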

experimental::MemoryRequirements ClGemm::workspace() const
{
    return _aux_mem;
}
} // namespace opencl
} // namespace arm_compute