Arm Compute Library 21.08 — source listing of ClGemm.cpp (OpenCL GEMM operator).
1 /*
2  * Copyright (c) 2017-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
28 #include "arm_compute/core/Error.h"
32 #include "arm_compute/core/Log.h"
34 #include "arm_compute/core/Types.h"
35 #include "arm_compute/core/Utils.h"
40 
41 #include "src/common/utils/Log.h"
49 
50 #include "support/Cast.h"
51 #include "utils/TypePrinter.h"
52 
53 namespace arm_compute
54 {
55 namespace opencl
56 {
58 using namespace arm_compute::cl_gemm;
59 using namespace arm_compute::experimental;
60 using namespace arm_compute::utils::cast;
61 using namespace arm_compute::opencl::kernels;
62 
63 namespace
64 {
65 inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type)
66 {
67  switch(kernel_type)
68  {
73  {
74  return true;
75  }
76  default:
77  {
78  return false;
79  }
80  }
81 }
82 //Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type
83 inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run, bool constant_weights)
84 {
85  if(!constant_weights)
86  {
88  }
89 
90  auto gemm_kernel = auto_heuristics::select_mlgo_gemm_kernel(query, reshape_b_only_on_first_run);
91  if(bool(gemm_kernel))
92  {
93  if(validate_gemm_kernel(gemm_kernel.gemm_type))
94  {
95  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
96  return gemm_kernel.gemm_type;
97  }
98  }
99  gemm_kernel = auto_heuristics::select_default_gemm_kernel(query, reshape_b_only_on_first_run);
100  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
101  return gemm_kernel.gemm_type;
102 }
103 // Validate lhs_info and rhs_info for reshaped only rhs kernel
104 inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
105  const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info)
106 {
107  // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped only rhs kernel
108  TensorInfo tmp_b_info{};
109  // Validate reshape RHS kernel
110  auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
111  if(!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
112  {
113  return false;
114  }
115  // Validate mm kernel
116  gemm_kernel_info.lhs_info = lhs_info;
117  gemm_kernel_info.rhs_info = rhs_info;
118  gemm_kernel_info.has_pad_y = false;
119  if(!bool(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
120  {
121  return false;
122  }
123  gemm_kernel_info.has_pad_y = true;
124  if(!bool(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
125  {
126  return false;
127  }
128  return true;
129 }
130 
131 //Automatically select between mlgo (prioritized) and default heuristics for reshaped only rhs kernel configs
132 inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a,
133  const ITensorInfo *b,
134  const ITensorInfo *c, const ITensorInfo *output)
135 {
137  if(config)
138  {
139  if(validate_lhs_rhs_info_reshaped_only_rhs(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info))
140  {
141  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
142  return { config.lhs_info, config.rhs_info };
143  }
144  }
146  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
147  return { config.lhs_info, config.rhs_info };
148 }
149 
150 // Validate lhs_info and rhs_info for reshaped kernel
151 inline bool validate_lhs_rhs_info_reshaped(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
152  const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info, bool reinterpret_input_as_3d)
153 {
154  // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped kernel
155  TensorInfo tmp_a_info{};
156  TensorInfo tmp_b_info{};
157 
158  // Validate reshape LHS kernel
159  auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, reinterpret_input_as_3d)));
160  if(!bool(ClGemmReshapeLhsMatrixKernel::validate(a, &tmp_a_info, lhs_info, reinterpret_input_as_3d)))
161  {
162  return false;
163  }
164 
165  // Validate reshape RHS kernel
166  auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
167  if(!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
168  {
169  return false;
170  }
171  // Validate mm kernel
172  gemm_kernel_info.lhs_info = lhs_info;
173  gemm_kernel_info.rhs_info = rhs_info;
174  if(!bool(ClGemmMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
175  {
176  return false;
177  }
178  return true;
179 }
180 
181 //Automatically select between mlgo (prioritized) and default heuristics for reshaped kernel configs
182 inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a, const ITensorInfo *b,
183  const ITensorInfo *c, const ITensorInfo *output, bool reinterpret_input_as_3d)
184 {
186  if(config)
187  {
188  if(validate_lhs_rhs_info_reshaped(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info, reinterpret_input_as_3d))
189  {
190  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
191  return { config.lhs_info, config.rhs_info };
192  }
193  }
195  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
196  return { config.lhs_info, config.rhs_info };
197 }
198 } // namespace
199 
201  : _mm_kernel(std::make_unique<ClGemmMatrixMultiplyKernel>()),
202  _reshape_lhs_kernel(std::make_unique<ClGemmReshapeLhsMatrixKernel>()),
203  _reshape_rhs_kernel(std::make_unique<ClGemmReshapeRhsMatrixKernel>()),
204  _mm_reshaped_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedKernel>()),
205  _mm_reshaped_only_rhs_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsKernel>()),
206  _mm_reshaped_only_rhs_fallback_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsKernel>()),
207  _tmp_a(),
208  _tmp_b(),
209  _reshape_b_only_on_first_run(false),
210  _gemm_kernel_type(CLGEMMKernelType::NATIVE_V1),
211  _is_prepared(false),
212  _aux_mem(AuxTensorIdx::Count)
213 {
214 }
215 
216 void ClGemm::configure_native_v1(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
217  const GEMMInfo &gemm_info)
218 {
219  const unsigned int m = gemm_info.reinterpret_input_as_3d() ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
220  const unsigned int n = b->dimension(0);
221  const unsigned int k = a->dimension(0);
222  const GPUTarget gpu_target = CLScheduler::get().target();
223 
224  // Set the target for the kernels
225  _mm_kernel->set_target(gpu_target);
226 
227  GEMMReshapeInfo reshape_info(m, n, k, 1, 1, gemm_info.depth_output_gemm3d(), gemm_info.reinterpret_input_as_3d(), gemm_info.broadcast_bias());
228 
229  // Configure and tune matrix multiply kernel
230  _mm_kernel->configure(compile_context, a, b, c, output, alpha, beta, false, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
231 
232  // Tune kernel statically
233  CLScheduler::get().tune_kernel_static(*_mm_kernel);
234 }
235 
236 void ClGemm::configure_reshaped_v1(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
237  const GEMMInfo &gemm_info)
238 {
239  bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
240  const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
241  const unsigned int n = b->dimension(0);
242  const unsigned int k = a->dimension(0);
243  const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
244  const GPUTarget gpu_target = CLScheduler::get().target();
245  int mult_transpose1xW_width = 1;
246  int mult_interleave4x4_height = 1;
247 
248  // Set the target for the kernels
249  _reshape_lhs_kernel->set_target(gpu_target);
250  _mm_kernel->set_target(gpu_target);
251 
252  if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
253  {
254  mult_transpose1xW_width = 4;
255  mult_interleave4x4_height = 2;
256  }
257 
258  GEMMRHSMatrixInfo rhs_info;
259  rhs_info.n0 = 16 / b->element_size();
260  rhs_info.k0 = 1;
261  rhs_info.h0 = mult_transpose1xW_width;
262  rhs_info.interleave = false;
263  rhs_info.transpose = false;
264 
265  GEMMLHSMatrixInfo lhs_info;
266  lhs_info.m0 = 4;
267  lhs_info.k0 = 4;
268  lhs_info.v0 = mult_interleave4x4_height;
269  lhs_info.interleave = true;
270  lhs_info.transpose = true;
271 
272  GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
273 
274  // Configure interleave kernel
275  _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, reinterpret_input_as_3d);
276 
277  // Configure transpose kernel
278  _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);
279 
280  // Configure and tune matrix multiply kernel
281  _mm_kernel->configure(compile_context, &_tmp_a, &_tmp_b, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
282 
283  CLScheduler::get().tune_kernel_static(*_mm_kernel);
284 
285  // Request memory for LHS and RHS reshape matrix
286  _aux_mem[LhsReshape] = MemoryInfo(offset_int_vec(LhsReshape), MemoryLifetime::Temporary, _tmp_a.total_size());
287  _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
288 }
289 
290 void ClGemm::configure_reshaped_v2(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
291  const GEMMInfo &gemm_info)
292 {
294  bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
295  const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
296  const unsigned int n = b->dimension(0);
297  const unsigned int k = a->dimension(0);
298  const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
299  const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
300  const GPUTarget gpu_target = CLScheduler::get().target();
301  bool broadcast_bias = gemm_info.broadcast_bias();
302 
303  GEMMKernelInfo kernel_info;
304  kernel_info.m = m;
305  kernel_info.n = n;
306  kernel_info.k = k;
307  kernel_info.depth_output_gemm3d = depth_output_gemm3d;
308  kernel_info.reinterpret_input_as_3d = false;
309  kernel_info.broadcast_bias = broadcast_bias;
310  kernel_info.activation_info = gemm_info.activation_info();
311 
312  // Set the target for the kernels
313  _reshape_lhs_kernel->set_target(gpu_target);
314  _mm_kernel->set_target(gpu_target);
315 
316  GEMMLHSMatrixInfo lhs_info{};
317  GEMMRHSMatrixInfo rhs_info{};
318 
319  // Pick up the GEMM configuration
320  std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a, b,
321  c, output, gemm_info.reinterpret_input_as_3d());
322 
323  _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
324  _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);
325 
326  // Configure and tune matrix multiply kernel
327  _mm_reshaped_kernel->configure(compile_context, &_tmp_a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
328 
329  // Request memory for LHS and RHS reshape matrix
330  _aux_mem[LhsReshape] = MemoryInfo(offset_int_vec(LhsReshape), MemoryLifetime::Temporary, _tmp_a.total_size());
331  _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
332 }
333 
334 void ClGemm::configure_reshaped_only_rhs(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
335  const GEMMInfo &gemm_info)
336 {
338  bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
339  const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
340  const unsigned int n = b->dimension(0);
341  const unsigned int k = a->dimension(0);
342  const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
343  const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
344  const GPUTarget gpu_target = CLScheduler::get().target();
345  bool broadcast_bias = gemm_info.broadcast_bias();
346 
347  GEMMKernelInfo kernel_info;
348  kernel_info.m = m;
349  kernel_info.n = n;
350  kernel_info.k = k;
351  kernel_info.depth_output_gemm3d = depth_output_gemm3d;
352  kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
353  kernel_info.broadcast_bias = broadcast_bias;
354  kernel_info.activation_info = gemm_info.activation_info();
355 
356  // Set the target for the kernels
357  _mm_kernel->set_target(gpu_target);
358 
359  GEMMLHSMatrixInfo lhs_info{};
360  GEMMRHSMatrixInfo rhs_info{};
361 
362  // Pick up the GEMM configuration
363  std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a, b, c, output);
364 
365  // Transpose matrix
366  _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);
367 
368  // Configure two variants of CLGEMMMatrixMultiplyReshapedOnlyRHSKernel (has_pad_y = false/true)
369  // During the prepare stage we check the padding requirement for the lhs and dst tensors. If they do not have
370  // pad y, we dispatch CLGEMMMatrixMultiplyReshapedOnlyRHSKernel with has_pad_y = false
371 
372  // Configure matrix multiply kernel with no y padding support
373  kernel_info.has_pad_y = false;
374  _mm_reshaped_only_rhs_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
375 
376  // Configure matrix multiply kernel with y padding support
377  kernel_info.has_pad_y = true;
378  _mm_reshaped_only_rhs_fallback_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
379 
380  // Request memory for RHS reshape matrix
381  _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
382 }
383 
384 Status ClGemm::validate_native_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
385 {
386  ARM_COMPUTE_UNUSED(alpha);
387  ARM_COMPUTE_UNUSED(output);
388 
389  // Get the GPU target
390  const GPUTarget gpu_target = CLScheduler::get().target();
391  bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
392  const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
393  const unsigned int n = b->dimension(0);
394  const unsigned int k = a->dimension(0);
395  const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
396 
397  const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d, gemm_info.broadcast_bias());
398 
399  // Validate matrix multiply
401  false, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info()));
402 
403  return Status{};
404 }
405 
406 Status ClGemm::validate_reshaped_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
407 {
408  ARM_COMPUTE_UNUSED(alpha);
409  ARM_COMPUTE_UNUSED(output);
410 
411  TensorInfo tmp_a_info{};
412  TensorInfo tmp_b_info{};
413 
414  // Get the GPU target
415  const GPUTarget gpu_target = CLScheduler::get().target();
416  const unsigned int m = gemm_info.reinterpret_input_as_3d() ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
417  const unsigned int n = b->dimension(0);
418  const unsigned int k = a->dimension(0);
419  int mult_transpose1xW_width = 1;
420  int mult_interleave4x4_height = 1;
421  const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
422 
423  if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
424  {
425  mult_transpose1xW_width = 4;
426  mult_interleave4x4_height = 2;
427  }
428 
429  GEMMRHSMatrixInfo rhs_info;
430  rhs_info.n0 = 16 / b->element_size();
431  rhs_info.k0 = 1;
432  rhs_info.h0 = mult_transpose1xW_width;
433  rhs_info.interleave = false;
434  rhs_info.transpose = false;
435 
436  GEMMLHSMatrixInfo lhs_info;
437  lhs_info.m0 = 4;
438  lhs_info.k0 = 4;
439  lhs_info.v0 = mult_interleave4x4_height;
440  lhs_info.interleave = true;
441  lhs_info.transpose = true;
442 
443  const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
444 
445  // Validate interleave kernel
446  auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
448 
449  // Validate transpose kernel
450  auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
452 
453  // Validate matrix multiply
454  ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta,
455  true, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info()));
456 
457  return Status{};
458 }
459 
460 Status ClGemm::validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
461 {
462  ARM_COMPUTE_UNUSED(alpha);
463  ARM_COMPUTE_UNUSED(output);
464 
465  TensorInfo tmp_a_info{};
466  TensorInfo tmp_b_info{};
467 
468  // Get the GPU target
469  const GPUTarget gpu_target = CLScheduler::get().target();
471  bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
472  const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
473  const unsigned int n = b->dimension(0);
474  const unsigned int k = a->dimension(0);
475  const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
476  const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
477  const bool broadcast_bias = gemm_info.broadcast_bias();
478 
479  GEMMKernelInfo kernel_info;
480  kernel_info.m = m;
481  kernel_info.n = n;
482  kernel_info.k = k;
483  kernel_info.depth_output_gemm3d = depth_output_gemm3d;
484  kernel_info.reinterpret_input_as_3d = false;
485  kernel_info.broadcast_bias = broadcast_bias;
486  kernel_info.activation_info = gemm_info.activation_info();
487 
488  GEMMLHSMatrixInfo lhs_info;
489  GEMMRHSMatrixInfo rhs_info;
490 
491  // Pick up the GEMM configuration
492  // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
493  const auto gemm_config = select_default_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
494  lhs_info = gemm_config.lhs_info;
495  rhs_info = gemm_config.rhs_info;
496 
497  auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
499 
500  auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
502 
503  // Validate matrix multiply
504  ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
505 
506  return Status{};
507 }
508 
509 Status ClGemm::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
510 {
511  ARM_COMPUTE_UNUSED(alpha);
512  ARM_COMPUTE_UNUSED(output);
513 
514  TensorInfo tmp_b_info{};
515 
516  // Get the GPU target
517  const GPUTarget gpu_target = CLScheduler::get().target();
518  const DataType data_type = a->data_type();
519  bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
520  const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
521  const unsigned int n = b->dimension(0);
522  const unsigned int k = a->dimension(0);
523  const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
524  const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
525  const bool broadcast_bias = gemm_info.broadcast_bias();
526 
527  GEMMKernelInfo kernel_info;
528  kernel_info.m = m;
529  kernel_info.n = n;
530  kernel_info.k = k;
531  kernel_info.depth_output_gemm3d = depth_output_gemm3d;
532  kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
533  kernel_info.broadcast_bias = broadcast_bias;
534  kernel_info.activation_info = gemm_info.activation_info();
535 
536  GEMMLHSMatrixInfo lhs_info;
537  GEMMRHSMatrixInfo rhs_info;
538 
539  // Pick up the GEMM configuration
540  // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
541  const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
542  lhs_info = gemm_config.lhs_info;
543  rhs_info = gemm_config.rhs_info;
544 
545  auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
547 
548  // Validate matrix multiply
549  kernel_info.has_pad_y = false;
550  ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
551 
552  kernel_info.has_pad_y = true;
553  ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
554 
555  return Status{};
556 }
557 
558 void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
559 {
560  ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);
561 
562  // Perform validation step
563  ARM_COMPUTE_ERROR_THROW_ON(validate(a, b, c, output, alpha, beta, gemm_info));
564 
565  // Check if we need to reshape the matrix B only on the first run
566  _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
567  _is_prepared = gemm_info.retain_internal_weights();
568 
569  bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
570  const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
571  const unsigned int n = b->dimension(0);
572  const unsigned int k = a->dimension(0);
573  const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
574 
575  // Select GEMMType
576  _gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery{ CLScheduler::get().target(), a->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run,
577  gemm_info.constant_weights());
578 
579  const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
580 
581  ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;
582 
583  switch(_gemm_kernel_type)
584  {
586  {
587  configure_native_v1(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
588  break;
589  }
591  {
592  configure_reshaped_v1(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
593  break;
594  }
596  {
597  configure_reshaped_v2(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
598  break;
599  }
601  {
602  configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
603  break;
604  }
605  default:
606  {
607  ARM_COMPUTE_ERROR("GEMMType not supported");
608  }
609  }
610 }
611 
612 Status ClGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
613 {
614  // Get the GPU target
615  bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
616  const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
617  const unsigned int n = b->dimension(0);
618  const unsigned int k = a->dimension(0);
619  const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
620 
621  // Select GEMMType
622  CLGEMMKernelType gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery
623  {
624  CLScheduler::get().target(), a->data_type(), m, n, k, batch_size,
625  },
626  gemm_info.reshape_b_only_on_first_run(), gemm_info.constant_weights());
627 
628  const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
629 
630  const ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;
631 
632  switch(gemm_kernel_type)
633  {
635  {
636  ARM_COMPUTE_RETURN_ON_ERROR(validate_native_v1(a, b, c_to_use, output, alpha, beta, gemm_info));
637  break;
638  }
640  {
641  ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_v1(a, b, c_to_use, output, alpha, beta, gemm_info));
642  break;
643  }
645  {
646  ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped(a, b, c_to_use, output, alpha, beta, gemm_info));
647  break;
648  }
650  {
651  ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info));
652  break;
653  }
654  default:
655  {
656  ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported");
657  }
658  }
659 
660  return Status{};
661 }
662 
663 void ClGemm::run(ITensorPack &tensors)
664 {
665  const ITensor *lhs = tensors.get_const_tensor(ACL_SRC_0);
666  const ITensor *rhs = tensors.get_const_tensor(ACL_SRC_1);
667  const ITensor *src2 = tensors.get_const_tensor(ACL_SRC_2);
668  ITensor *dst = tensors.get_tensor(ACL_DST);
669 
671 
672  CLAuxTensorHandler lhs_reshaped(offset_int_vec(LhsReshape), _tmp_a, tensors, true);
673  CLAuxTensorHandler rhs_reshaped(offset_int_vec(RhsReshape), _tmp_b, tensors, true);
674 
675  // Prepare the consts if needed
676  prepare(tensors);
677 
678  // Run matrix multiply kernel
679  switch(_gemm_kernel_type)
680  {
682  {
683  CLScheduler::get().enqueue_op(*_mm_kernel, tensors, true);
684  break;
685  }
688  {
689  // Run interleave kernel
690  ITensorPack reshape_lhs_pack{ { ACL_SRC, lhs }, { ACL_DST, lhs_reshaped.get() } };
691  CLScheduler::get().enqueue_op(*_reshape_lhs_kernel, reshape_lhs_pack, false);
692 
693  if(!_reshape_b_only_on_first_run)
694  {
695  // Run transpose kernel
696  ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
697  CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
698  }
699 
700  ITensorPack gemm_reshaped_pack{ { ACL_SRC_0, lhs_reshaped.get() }, { ACL_SRC_1, rhs_reshaped.get() }, { ACL_SRC_2, src2 }, { ACL_DST, dst } };
701 
702  if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED)
703  {
704  CLScheduler::get().enqueue_op(*_mm_reshaped_kernel, gemm_reshaped_pack, true);
705  }
706  else
707  {
708  CLScheduler::get().enqueue_op(*_mm_kernel, gemm_reshaped_pack, true);
709  }
710  break;
711  }
713  {
714  if(!_reshape_b_only_on_first_run)
715  {
716  // Run transpose kernel
717  ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
718  CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
719  }
720  // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement
721  // Check if the lhs or dst tensors have padding
722  const unsigned int cross_plane_pad_lhs = lhs->info()->padding().top + lhs->info()->padding().bottom;
723  const unsigned int cross_plane_pad_dst = dst->info()->padding().top + dst->info()->padding().bottom;
724  bool has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);
725 
726  ITensorPack gemm_reshaped_onlyrhs_pack{ { ACL_SRC_0, lhs }, { ACL_SRC_1, rhs_reshaped.get() }, { ACL_SRC_2, src2 }, { ACL_DST, dst } };
727  if(has_pad_y)
728  {
729  CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_fallback_kernel, gemm_reshaped_onlyrhs_pack, true);
730  }
731  else
732  {
733  CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_kernel, gemm_reshaped_onlyrhs_pack, true);
734  }
735  break;
736  }
737  default:
738  {
739  ARM_COMPUTE_ERROR("GEMMType not supported");
740  }
741  }
742 }
743 
744 void ClGemm::prepare(ITensorPack &constants)
745 {
746  if(!_is_prepared)
747  {
748  const ITensor *src1 = constants.get_const_tensor(ACL_SRC_1);
749  ICLTensor *rhs_aux = utils::cast::polymorphic_downcast<ICLTensor *>(constants.get_tensor(offset_int_vec(RhsReshape)));
750 
751  // If memory for RHS is persistent and src1 is provided re-transform else assume that RHS is transformed
752  if((_aux_mem[AuxTensorIdx::RhsReshape].lifetime == MemoryLifetime::Persistent) && (src1 != nullptr && rhs_aux != nullptr) && rhs_aux)
753  {
754  ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Transforming RHS Matrix!");
755 
756  CLAuxTensorHandler rhs_reshaped(_tmp_b, *rhs_aux);
757  ARM_COMPUTE_ERROR_ON(rhs_reshaped.get()->cl_buffer().get() == nullptr);
758 
759  ITensorPack reshape_rhs_pack{ { ACL_SRC, src1 }, { ACL_DST, rhs_reshaped.get() } };
760  CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, true);
761  }
762  _is_prepared = true;
763  }
764 }
765 
767 {
768  return _aux_mem;
769 }
770 } // namespace opencl
771 } // namespace arm_compute
unsigned int top
top of the border
Definition: Types.h:372
bool broadcast_bias
Flag used to broadcast the bias addition.
bool constant_weights() const
Flag which specifies if the values of the weights tensor are constant throughout multiple executions ...
Definition: Types.h:2120
GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query)
Select gemm config based on default heuristics.
Descriptor used by the GEMM kernels.
virtual size_t dimension(size_t index) const =0
Return the size of the requested dimension.
SimpleTensor< float > b
Definition: DFT.cpp:157
static CLScheduler & get()
Access the scheduler singleton.
#define ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL(msg)
Log an information message to the logger with function name before the message.
Definition: Log.h:98
unsigned int v0
Number of vertical blocks of size (m0xk0) stored on the same output row.
Definition: Types.h:1913
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
Definition: Error.h:352
GPUTarget target() const
Get the target GPU.
Definition: CLScheduler.cpp:45
unsigned int depth_output_gemm3d
Depth of the output tensor in case is reinterpreted as 3D.
GEMM reshape information class.
Definition: Types.h:1760
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
Definition: Error.h:204
virtual DataType data_type() const =0
Data type used for each element of the tensor.
void configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
Initialise the kernel's inputs and output.
Definition: ClGemm.cpp:558
OpenCL kernel to multiply matrices when only the input matrix RHS (src1) has been reshaped...
A collection of adaptor functions that enable the auto selection between mlgo-based heuristics and de...
unsigned int h0
Number of horizontal blocks of size (k0xn0) stored on the same output row.
Definition: Types.h:1928
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true then an error message is printed and an exception thrown.
Definition: Error.h:466
bool fp_mixed_precision() const
Flag which specifies if a wider accumulator should be used.
Definition: Types.h:2064
GEMM LHS (Left Hand Side) matrix information.
Definition: Types.h:1904
Store the tensor's metadata.
Definition: ITensorInfo.h:40
#define ARM_COMPUTE_ERROR_THROW_ON(status)
Definition: Error.h:455
unsigned int bottom
bottom of the border
Definition: Types.h:374
Manages all the OpenCL kernels compilation and caching, provides accessors for the OpenCL Context...
Reshaped GEMM kernel where only the rhs matrix is reshaped.
int depth_output_gemm3d() const
Depth of the output when GEMM output is reinterpreted as 3D tensor.
Definition: Types.h:2024
Status class.
Definition: Error.h:52
GPUTarget get_arch_from_target(GPUTarget target)
Helper function to get the GPU arch.
Definition: GPUTarget.cpp:193
ActivationLayerInfo activation_info
Activation function to perform after the matrix multiplication.
bool retain_internal_weights() const
Flag which specifies if the weights tensor has to be retained from previous run.
Definition: Types.h:2040
CLGEMMKernelType
OpenCL GEMM kernel types.
Definition: CLTypes.h:31
Reshaped GEMM kernel where both lhs and rhs matrices are reshaped.
Interface for CPU tensor.
Definition: ITensor.h:36
bool transpose
True if the (k0xn0) block has to be transposed before being stored.
Definition: Types.h:1929
bool interleave
True if the v0 (m0xk0) blocks have to be interleaved in the output row.
Definition: Types.h:1915
Copyright (c) 2017-2021 Arm Limited.
std::vector< MemoryInfo > MemoryRequirements
Definition: Types.h:113
bool transpose
True if the (m0xk0) block has to be transposed before being stored.
Definition: Types.h:1914
GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query)
Select gemm config based on mlgo heuristics.
Native GEMM kernel with fixed block size.
const DataType data_type
Definition: Im2Col.cpp:150
Interface to enqueue OpenCL kernels and get/set the OpenCL CommandQueue and ICLTuner.
const ITensor * get_const_tensor(int id) const
Get constant tensor of a given id.
Definition: ITensorPack.cpp:54
unsigned int k0
Number of partial accumulations performed by the matrix multiplication.
Definition: Types.h:1927
static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info)
Static function to check if given info will lead to a valid configuration.
unsigned int m
Number of LHS rows.
static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info)
Static function to check if given info will lead to a valid configuration.
std::string to_string(const ROIPoolingLayerInfo &pool_info)
Formatted output of the ROIPoolingInfo type.
Definition: TypePrinter.h:148
#define ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE(fmt,...)
Log information level formatted message to the core system logger.
Definition: Log.h:99
unsigned int n
Number of RHS columns.
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
Definition: Error.h:152
TensorShape compute_lhs_reshaped_shape(const ITensorInfo &a, const GEMMLHSMatrixInfo &lhs_info, bool reinterpret_input_as_3d=false)
Calculate the Left Hand Side matrix reshaped shape.
GEMM RHS (Right Hand Side) matrix information.
Definition: Types.h:1919
static Status validate(const ITensorInfo *src, const ITensorInfo *dst, const GEMMRHSMatrixInfo &rhs_info)
Static function to check if given info will lead to a valid configuration.
unsigned int n0
Number of columns processed by the matrix multiplication.
Definition: Types.h:1926
static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target, bool fp_mixed_precision=false, const ActivationLayerInfo &activation_info=ActivationLayerInfo())
Static function to check if given info will lead to a valid configuration.
size_t total_size() const override
Returns the total size of the tensor in bytes.
Definition: TensorInfo.h:250
void enqueue_op(ICLKernel &kernel, ITensorPack &tensors, bool flush=true)
Schedule the execution of the passed kernel if possible.
TensorShape compute_rhs_reshaped_shape(const ITensorInfo &a, const GEMMRHSMatrixInfo &rhs_info)
Calculate the Right Hand Side matrix reshaped shape.
bool auto_init_if_empty(ITensorInfo &info, const TensorShape &shape, int num_channels, DataType data_type, QuantizationInfo quantization_info=QuantizationInfo())
Auto initialize the tensor info (shape, number of channels and data type) if the current assignment i...
bool reinterpret_input_as_3d
Flag used to reinterpret the input as 3D.
virtual std::unique_ptr< T > clone() const =0
Provide a clone of the current object of class T.
static Status validate(const ITensorInfo *src, const ITensorInfo *dst, const GEMMLHSMatrixInfo &lhs_info, bool reinterpret_src_as_3d)
Static function to check if given info will lead to a valid configuration.
virtual ITensorInfo * info() const =0
Interface to be implemented by the child class to return the tensor&#39;s metadata.
virtual size_t element_size() const =0
Element size in bytes calculated as data_size() * num_channels()
GEMMTypeResult select_default_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
Select gemm type based on default heuristics.
virtual PaddingSize padding() const =0
Padding of tensor.
bool reinterpret_input_as_3d() const
Flag which specifies if the input tensor has to be reinterpreted as 3D.
Definition: Types.h:2032
bool broadcast_bias() const
Flag which specifies whether to broadcast the shape of the bias tensor.
Definition: Types.h:2080
CLCompileContext class.
bool has_pad_y
Flag used to indicate if the input/output tensors have internal pad on the y direction.
ITensor * get_tensor(int id)
Get tensor of a given id from the pack.
Definition: ITensorPack.cpp:64
#define ARM_COMPUTE_RETURN_ERROR_MSG(...)
An error is returned with the given description.
Definition: Error.h:194
Interface for OpenCL tensor.
Definition: ICLTensor.h:42
OpenCL kernel to reshape the RHS matrix when performing the matrix multiplication In particular...
GPUTarget
Available GPU Targets.
Definition: GPUTarget.h:34
GEMMTypeResult select_mlgo_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
Select gemm type based on mlgo heuristics.
unsigned int k
Number of LHS columns or RHS rows.
bool interleave
True if the h0 (k0xn0) blocks have to be interleaved in the output row.
Definition: Types.h:1930
bool is_zero(float a, float epsilon=0.00001f)
Checks if the input floating point number is 0.0f checking if the difference is within a range define...
Definition: float_ops.h:109
void run(ITensorPack &tensors) override
Run the kernels contained in the function.
Definition: ClGemm.cpp:663
OpenCL kernel to reshape the LHS matrix when performing the matrix multiplication.
OpenCL kernel to multiply two input matrices "A" and "B" and add a matrix "C" if provided.
Tensor packing service.
Definition: ITensorPack.h:39
virtual const cl::Buffer & cl_buffer() const =0
Interface to be implemented by the child class to return a reference to the OpenCL buffer containing ...
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
Definition: Validate.h:157
Store the tensor's metadata.
Definition: TensorInfo.h:43
bool reshape_b_only_on_first_run() const
Flag which specifies if the reshape of matrix B should be executed only for the first run.
Definition: Types.h:2016
unsigned int k0
Number of partial accumulations performed by the matrix multiplication.
Definition: Types.h:1912
void prepare(ITensorPack &constants) override
Prepare the function for executing.
Definition: ClGemm.cpp:744
int offset_int_vec(int offset)
Definition: MemoryHelpers.h:38
GEMM information class.
Definition: Types.h:1939
unsigned int m0
Number of rows processed by the matrix multiplication.
Definition: Types.h:1911
GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &query)
Select gemm config based on mlgo heuristics.
static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
Static function to check if given info will lead to a valid configuration.
Definition: ClGemm.cpp:612
void tune_kernel_static(ICLKernel &kernel)
Tunes OpenCL kernel.
Definition: CLScheduler.cpp:82
DataType
Available data types.
Definition: Types.h:77
ActivationLayerInfo activation_info() const
Activation layer to apply after the matrix multiplication.
Definition: Types.h:2104
ClGemm()
Constructor.
Definition: ClGemm.cpp:200
Reshaped GEMM kernel where both lhs and rhs matrices are reshaped.
OpenCL kernel to multiply matrices when both the input matrices LHS (src0) and RHS (src1) have been r...
GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query)
Select gemm config based on default heuristics.
experimental::MemoryRequirements workspace() const override
Return the memory requirements required by the workspace.
Definition: ClGemm.cpp:766