Compute Library
 22.11
gemm_implementation.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2018-2020, 2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
25 #include "arm_gemm.hpp"
26 
27 #include "kernel_weight_format.hpp"
28 
29 #include <cstdint>
30 #include <functional>
31 
32 namespace arm_gemm {
33 
34 /* Structure describing an implementation. For each supported combination
35  * of types, a static list of these structures is built up to describe the
36  * implementations available.
37  */
38 template<typename Top, typename Tret, class OutputStage = Nothing>
41  const char * name;
43  std::function<bool(const GemmArgs &, const OutputStage &)> is_supported = {};
44  std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate = {};
45  std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate = {};
46 
47  bool do_is_supported(const GemmArgs &args, const OutputStage &os) const {
48  // Check supplied is_supported() function first.
49  if (is_supported != nullptr && !is_supported(args, os)) {
50  return false;
51  }
52 
53  // Check weight format is appropriate.
54  if (args._fixed_format == false) {
55  // Can't return a fixed format kernel if we weren't asked for one.
56  return (kernel_weight_format == KernelWeightFormat::NON_FIXED);
57  } else {
58  // Fixed format kernel requested: if this is a non-fixed format kernel we can't use it.
59  if (kernel_weight_format == KernelWeightFormat::NON_FIXED) {
60  return false;
61  }
62 
63  // If there's no config, or the config says ANY then this one is OK.
64  if (!args._cfg || args._cfg->weight_format == WeightFormat::ANY) {
65  return true;
66  }
67 
68  // If we get here it means there is a config and it specifies a format. Check it matches this kernel.
69  // NOTE: this will execute SVE instructions if it's an SVE kernel, so it's important that is_supported()
70  // was called above first.
71  return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Top)));
72  }
73  }
74 
75  uint64_t do_cycle_estimate(const GemmArgs &args, const OutputStage &os) const {
76  if (cycle_estimate != nullptr) {
77  return cycle_estimate(args, os);
78  } else {
79  return 0;
80  }
81  }
82 
83  GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const OutputStage &os) const {
84  return instantiate(args, os);
85  }
86 
88  std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate,
89  std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) {
90  GemmImplementation impl(m,n);
91 
95 
96  return impl;
97  }
98 
99  GemmImplementation(const GemmImplementation &) = default;
100  GemmImplementation & operator= (const GemmImplementation &) = default;
101 
102  GemmImplementation(GemmMethod m, const char * n) : method(m), name(n) {}
103 
104  GemmImplementation(GemmMethod m, const char *n,
105  std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended,
106  std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) :
107  method(m), name(n), is_supported(is_supported),
108  cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ),
110 
112  std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended,
113  std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) :
114  method(m), name(n), kernel_weight_format(kwf), is_supported(is_supported),
115  cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ),
117 };
118 
119 /* Slightly different version of above for straightforward GEMMs with no
120  * output stage, so the std::functions there don't have to deal with the
121  * unnecessary second argument. */
122 template<typename Top, typename Tret>
123 struct GemmImplementation<Top, Tret, Nothing> {
125  const char * name;
126  const KernelWeightFormat kernel_weight_format = KernelWeightFormat::NON_FIXED;
127  std::function<bool(const GemmArgs &)> is_supported = {};
128  std::function<uint64_t(const GemmArgs &)> cycle_estimate = {};
129  std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate = {};
130 
131  bool do_is_supported(const GemmArgs &args, const Nothing &) const {
132  // Check supplied is_supported() function first.
133  if (is_supported != nullptr && !is_supported(args)) {
134  return false;
135  }
136 
137  // Check weight format is appropriate.
138  if (args._fixed_format == false) {
139  // Can't return a fixed format kernel if we weren't asked for one.
140  return (kernel_weight_format == KernelWeightFormat::NON_FIXED);
141  } else {
142  // Fixed format kernel requested: if this is a non-fixed format kernel we can't use it.
143  if (kernel_weight_format == KernelWeightFormat::NON_FIXED) {
144  return false;
145  }
146 
147  // If there's no config, or the config says ANY then this one is OK.
148  if (!args._cfg || args._cfg->weight_format == WeightFormat::ANY) {
149  return true;
150  }
151 
152  // If we get here it means there is a config and it specifies a format. Check it matches this kernel.
153  // NOTE: this will execute SVE instructions if it's an SVE kernel, so it's important that is_supported()
154  // was called above first.
155  return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Top)));
156  }
157  }
158 
159  uint64_t do_cycle_estimate(const GemmArgs &args, const Nothing &) const {
160  if (cycle_estimate != nullptr) {
161  return cycle_estimate(args);
162  } else {
163  return 0;
164  }
165  }
166 
167  GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const Nothing &) const {
168  return instantiate(args);
169  }
170 
171  static GemmImplementation with_estimate(GemmMethod m, const char *n,
172  std::function<bool(const GemmArgs &)> is_supported, std::function<uint64_t(const GemmArgs &)> cycle_estimate,
173  std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) {
174  GemmImplementation impl(m,n);
175 
179 
180  return impl;
181  }
182 
184  std::function<bool(const GemmArgs &)> is_supported, std::function<uint64_t(const GemmArgs &)> cycle_estimate,
185  std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) {
186  GemmImplementation impl(m,n,f);
187 
191 
192  return impl;
193  }
194 
195  GemmImplementation(const GemmImplementation &) = default;
196  GemmImplementation & operator= (const GemmImplementation &) = default;
197 
198  GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat f=KernelWeightFormat::NON_FIXED) : method(m), name(n), kernel_weight_format(f) {}
199 
200  GemmImplementation(GemmMethod m, const char *n,
201  std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended,
202  std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) :
203  method(m), name(n), is_supported(is_supported),
204  cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ),
206 
208  std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended,
209  std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) :
210  method(m), name(n), kernel_weight_format(kwf), is_supported(is_supported),
211  cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ),
213 };
214 
215 /* "Main" function implemented for each valid combination of types.
216  * Returns a list of GEMM implementation descriptors for processing by the
217  * other functions, ended by an implementation with
218  * method==GemmMethod::DEFAULT. */
219 template<typename Top, typename Tret, class OutputStage = Nothing>
221 
222 /*
223  * Select a GEMM implementation for the given arguments.
224  *
225  * The logic here returns the method on the list which supports the
226  * requested problem parameters, matches the provided filters (method and/or
227  * name string match) and offers the lowest cycle estimate. A cycle
228  * estimate of '0' is treated as a special value, causing the corresponding
229  * method to be selected immediately.
230  *
231  * If no method supports the requested parameters and passes the filters,
232  * this function returns false and doesn't touch the provided pointer
233  * reference.
234  */
235 template<typename Top, typename Tret, class OutputStage>
236 bool find_implementation(const GemmArgs &args, const OutputStage &os, const GemmImplementation<Top, Tret, OutputStage> * &impl) {
237  auto gemms = gemm_implementation_list<Top, Tret, OutputStage>();
238  const GemmConfig *cfg = args._cfg;
239 
240  const GemmImplementation<Top, Tret, OutputStage> *saved_impl = nullptr;
241  uint64_t best_estimate = 0;
242 
243  for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) {
244  /* Skip if this implementation doesn't support these args. */
245  if (!i->do_is_supported(args, os)) {
246  continue;
247  }
248 
249  /* Skip if a specific method is requested and this is a different one. */
250  if (cfg && cfg->method != GemmMethod::DEFAULT && i->method != cfg->method) {
251  continue;
252  }
253 
254  /* Skip if a filter is to be applied and it doesn't match. */
255  if (cfg && cfg->filter != "" && !strstr(i->name, cfg->filter.c_str())) {
256  continue;
257  }
258 
259  /* Test the cycle estimate */
260  uint64_t estimate = i->do_cycle_estimate(args, os);
261 
262  /* Short circuit - if the estimate is zero, return this one immediately. */
263  if (estimate==0) {
264  impl=i;
265  return true;
266  }
267 
268  /* Otherwise, remember this is our best so far if we don't yet have
269  * a valid candidate, or we beat the estimate. */
270  if ((saved_impl == nullptr) || (estimate < best_estimate)) {
271  saved_impl = i;
272  best_estimate = estimate;
273  }
274  }
275 
276  /* Return whichever method gave the best estimate. */
277  if (saved_impl != nullptr) {
278  impl = saved_impl;
279  return true;
280  }
281 
282  return false;
283 }
284 
285 template<typename Top, typename Tret, class OutputStage>
286 std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage &os) {
287  std::vector<KernelDescription> res;
288 
289  /* Find out what the default implementation in so we can set the flag accordingly later. */
290  const GemmImplementation<Top, Tret, OutputStage> *default_impl;
291  find_implementation(args, os, default_impl);
292 
293  auto gemms = gemm_implementation_list<Top, Tret, OutputStage>();
294 
295  for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) {
296  /* Check that this implementation supports the presented problem. */
297 
298  if (!i->do_is_supported(args, os)) {
299  continue;
300  }
301 
302  res.push_back(KernelDescription(i->method, i->name, i==default_impl, i->do_cycle_estimate(args, os)));
303  }
304 
305  return res;
306 }
307 
308 template<typename Top, typename Tret, class OutputStage>
309 bool has_opt_gemm(WeightFormat &wf, const GemmArgs &args, const OutputStage &os) {
311  const bool success = find_implementation<Top, Tret, OutputStage>(args, os, impl);
312  if (success)
313  wf = UniqueGemmCommon<Top, Tret>(impl->do_instantiate(args, os))->get_config().weight_format;
314  return success;
315 }
316 
317 template<typename Top, typename Tret, class OutputStage>
318 UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage &os) {
320 
321  if (find_implementation<Top, Tret, OutputStage>(args, os, impl)) {
322  return UniqueGemmCommon<Top, Tret>(impl->do_instantiate(args, os));
323  }
324 
325  return UniqueGemmCommon<Top, Tret>(nullptr);
326 }
327 
328 template<typename Top, typename Tret, class OutputStage>
329 KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage &os) {
331 
332  if (find_implementation<Top, Tret>(args, os, impl)) {
333  return KernelDescription(impl->method, impl->name);
334  }
335 
336  /* This shouldn't happen - there should always be at least one valid implementation. */
337  return KernelDescription();
338 }
339 
340 
341 } // namespace arm_gemm
GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat kwf, std::function< bool(const GemmArgs &)> is_supported, std::function< bool(const GemmArgs &)> is_recommended, std::function< GemmCommon< Top, Tret > *(const GemmArgs &)> instantiate)
KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage &os)
std::function< GemmCommon< Top, Tret > *(const GemmArgs &, const OutputStage &)> instantiate
std::vector< KernelDescription > get_compatible_kernels(const GemmArgs &args, const OutputStage &os)
std::function< bool(const GemmArgs &, const OutputStage &)> is_supported
WeightFormat get_weight_format(const KernelWeightFormat, size_t)
Definition: misc.cpp:40
const KernelWeightFormat kernel_weight_format
bool has_opt_gemm(WeightFormat &wf, const GemmArgs &args, const OutputStage &os)
GemmCommon< Top, Tret > * do_instantiate(const GemmArgs &args, const OutputStage &os) const
const GemmConfig * _cfg
Definition: arm_gemm.hpp:157
bool find_implementation(const GemmArgs &args, const OutputStage &os, const GemmImplementation< Top, Tret, OutputStage > *&impl)
bool do_is_supported(const GemmArgs &args, const OutputStage &os) const
static GemmImplementation with_estimate(GemmMethod m, const char *n, std::function< bool(const GemmArgs &)> is_supported, std::function< uint64_t(const GemmArgs &)> cycle_estimate, std::function< GemmCommon< Top, Tret > *(const GemmArgs &)> instantiate)
GemmImplementation(GemmMethod m, const char *n)
bool do_is_supported(const GemmArgs &args, const Nothing &) const
uint64_t do_cycle_estimate(const GemmArgs &args, const Nothing &) const
WeightFormat weight_format
Definition: arm_gemm.hpp:112
UniqueGemmCommon< Top, Tret > gemm(const GemmArgs &args, const OutputStage &os)
GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat f=KernelWeightFormat::NON_FIXED)
GemmCommon< Top, Tret > * do_instantiate(const GemmArgs &args, const Nothing &) const
std::unique_ptr< GemmCommon< Top, Tret > > UniqueGemmCommon
Definition: arm_gemm.hpp:216
GemmImplementation & operator=(const GemmImplementation &)=default
GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat kwf, std::function< bool(const GemmArgs &, const OutputStage &)> is_supported, std::function< bool(const GemmArgs &, const OutputStage &)> is_recommended, std::function< GemmCommon< Top, Tret > *(const GemmArgs &, const OutputStage &)> instantiate)
static GemmImplementation with_estimate(GemmMethod m, const char *n, KernelWeightFormat f, std::function< bool(const GemmArgs &)> is_supported, std::function< uint64_t(const GemmArgs &)> cycle_estimate, std::function< GemmCommon< Top, Tret > *(const GemmArgs &)> instantiate)
uint64_t do_cycle_estimate(const GemmArgs &args, const OutputStage &os) const
const GemmImplementation< Top, Tret, OutputStage > * gemm_implementation_list()
GemmImplementation(GemmMethod m, const char *n, std::function< bool(const GemmArgs &)> is_supported, std::function< bool(const GemmArgs &)> is_recommended, std::function< GemmCommon< Top, Tret > *(const GemmArgs &)> instantiate)
static GemmImplementation with_estimate(GemmMethod m, const char *n, std::function< bool(const GemmArgs &, const OutputStage &)> is_supported, std::function< uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate, std::function< GemmCommon< Top, Tret > *(const GemmArgs &, const OutputStage &)> instantiate)
GemmImplementation(const GemmImplementation &)=default
std::function< uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate
GemmImplementation(GemmMethod m, const char *n, std::function< bool(const GemmArgs &, const OutputStage &)> is_supported, std::function< bool(const GemmArgs &, const OutputStage &)> is_recommended, std::function< GemmCommon< Top, Tret > *(const GemmArgs &, const OutputStage &)> instantiate)