Compute Library
 22.05
gemm_implementation.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2018-2020, 2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
25 #include "arm_gemm.hpp"
26 
27 #include <cstdint>
28 #include <functional>
29 
30 namespace arm_gemm {
31 
32 /* Structure describing an implementation. For each supported combination
33  * of types, a static list of these structures is built up to describe the
34  * implementations available.
35  */
36 template<typename Top, typename Tret, class OutputStage = Nothing>
39  const char * name;
40  std::function<bool(const GemmArgs &, const OutputStage &)> is_supported = {};
41  std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate = {};
42  std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate = {};
43 
44  bool do_is_supported(const GemmArgs &args, const OutputStage &os) const {
45  if (is_supported != nullptr) {
46  return is_supported(args, os);
47  } else {
48  return true;
49  }
50  }
51 
52  uint64_t do_cycle_estimate(const GemmArgs &args, const OutputStage &os) const {
53  if (cycle_estimate != nullptr) {
54  return cycle_estimate(args, os);
55  } else {
56  return 0;
57  }
58  }
59 
60  GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const OutputStage &os) const {
61  return instantiate(args, os);
62  }
63 
65  std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate,
66  std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) {
67  GemmImplementation impl(m,n);
68 
72 
73  return impl;
74  }
75 
76  GemmImplementation(const GemmImplementation &) = default;
78 
79  GemmImplementation(GemmMethod m, const char * n) : method(m), name(n) {}
80 
81  GemmImplementation(GemmMethod m, const char *n,
82  std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended,
83  std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) :
84  method(m), name(n), is_supported(is_supported),
85  cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ),
87 };
88 
89 /* Slightly different version of above for straightforward GEMMs with no
90  * output stage, so the std::functions there don't have to deal with the
91  * unnecessary second argument. */
92 template<typename Top, typename Tret>
93 struct GemmImplementation<Top, Tret, Nothing> {
95  const char * name;
96  std::function<bool(const GemmArgs &)> is_supported = {};
97  std::function<uint64_t(const GemmArgs &)> cycle_estimate = {};
98  std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate = {};
99 
100  bool do_is_supported(const GemmArgs &args, const Nothing &) const {
101  if (is_supported != nullptr) {
102  return is_supported(args);
103  } else {
104  return true;
105  }
106  }
107 
108  uint64_t do_cycle_estimate(const GemmArgs &args, const Nothing &) const {
109  if (cycle_estimate != nullptr) {
110  return cycle_estimate(args);
111  } else {
112  return 0;
113  }
114  }
115 
116  GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const Nothing &) const {
117  return instantiate(args);
118  }
119 
120  static GemmImplementation with_estimate(GemmMethod m, const char *n,
121  std::function<bool(const GemmArgs &)> is_supported, std::function<uint64_t(const GemmArgs &)> cycle_estimate,
122  std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) {
123  GemmImplementation impl(m,n);
124 
128 
129  return impl;
130  }
131 
132  GemmImplementation(const GemmImplementation &) = default;
133  GemmImplementation & operator= (const GemmImplementation &) = default;
134 
135  GemmImplementation(GemmMethod m, const char * n) : method(m), name(n) {}
136 
137  GemmImplementation(GemmMethod m, const char *n,
138  std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended,
139  std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) :
140  method(m), name(n), is_supported(is_supported),
141  cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ),
143 };
144 
145 /* "Main" function implemented for each valid combination of types.
146  * Returns a list of GEMM implementation descriptors for processing by the
147  * other functions, ended by an implementation with
148  * method==GemmMethod::DEFAULT. */
149 template<typename Top, typename Tret, class OutputStage = Nothing>
151 
152 /*
153  * Select a GEMM implementation for the given arguments.
154  *
155  * The logic here returns the method on the list which supports the
156  * requested problem parameters, matches the provided filters (method and/or
157  * name string match) and offers the lowest cycle estimate. A cycle
158  * estimate of '0' is treated as a special value, causing the corresponding
159  * method to be selected immediately.
160  *
161  * If no method supports the requested parameters and passes the filters,
162  * this function returns false and doesn't touch the provided pointer
163  * reference.
164  */
165 template<typename Top, typename Tret, class OutputStage>
166 bool find_implementation(const GemmArgs &args, const OutputStage &os, const GemmImplementation<Top, Tret, OutputStage> * &impl) {
167  auto gemms = gemm_implementation_list<Top, Tret, OutputStage>();
168  const GemmConfig *cfg = args._cfg;
169 
170  const GemmImplementation<Top, Tret, OutputStage> *saved_impl = nullptr;
171  uint64_t best_estimate = 0;
172 
173  for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) {
174  /* Skip if this implementation doesn't support these args. */
175  if (!i->do_is_supported(args, os)) {
176  continue;
177  }
178 
179  /* Skip if a specific method is requested and this is a different one. */
180  if (cfg && cfg->method != GemmMethod::DEFAULT && i->method != cfg->method) {
181  continue;
182  }
183 
184  /* Skip if a filter is to be applied and it doesn't match. */
185  if (cfg && cfg->filter != "" && !strstr(i->name, cfg->filter.c_str())) {
186  continue;
187  }
188 
189  /* Test the cycle estimate */
190  uint64_t estimate = i->do_cycle_estimate(args, os);
191 
192  /* Short circuit - if the estimate is zero, return this one immediately. */
193  if (estimate==0) {
194  impl=i;
195  return true;
196  }
197 
198  /* Otherwise, remember this is our best so far if we don't yet have
199  * a valid candidate, or we beat the estimate. */
200  if ((saved_impl == nullptr) || (estimate < best_estimate)) {
201  saved_impl = i;
202  best_estimate = estimate;
203  }
204  }
205 
206  /* Return whichever method gave the best estimate. */
207  if (saved_impl != nullptr) {
208  impl = saved_impl;
209  return true;
210  }
211 
212  return false;
213 }
214 
215 template<typename Top, typename Tret, class OutputStage>
216 std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage &os) {
217  std::vector<KernelDescription> res;
218 
219  /* Find out what the default implementation in so we can set the flag accordingly later. */
220  const GemmImplementation<Top, Tret, OutputStage> *default_impl;
221  find_implementation(args, os, default_impl);
222 
223  auto gemms = gemm_implementation_list<Top, Tret, OutputStage>();
224 
225  for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) {
226  /* Check that this implementation supports the presented problem. */
227 
228  if (!i->do_is_supported(args, os)) {
229  continue;
230  }
231 
232  res.push_back(KernelDescription(i->method, i->name, i==default_impl, i->do_cycle_estimate(args, os)));
233  }
234 
235  return res;
236 }
237 
238 template<typename Top, typename Tret, class OutputStage>
239 bool has_opt_gemm(const GemmArgs &args, const OutputStage &os) {
241  return find_implementation<Top, Tret, OutputStage>(args, os, impl);
242 }
243 
244 template<typename Top, typename Tret, class OutputStage>
245 UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage &os) {
247 
248  if (find_implementation<Top, Tret, OutputStage>(args, os, impl)) {
249  return UniqueGemmCommon<Top, Tret>(impl->do_instantiate(args, os));
250  }
251 
252  return UniqueGemmCommon<Top, Tret>(nullptr);
253 }
254 
255 } // namespace arm_gemm
std::function< GemmCommon< Top, Tret > *(const GemmArgs &, const OutputStage &)> instantiate
std::vector< KernelDescription > get_compatible_kernels(const GemmArgs &args, const OutputStage &os)
std::function< bool(const GemmArgs &, const OutputStage &)> is_supported
GemmCommon< Top, Tret > * do_instantiate(const GemmArgs &args, const OutputStage &os) const
const GemmConfig * _cfg
Definition: arm_gemm.hpp:115
bool find_implementation(const GemmArgs &args, const OutputStage &os, const GemmImplementation< Top, Tret, OutputStage > *&impl)
bool do_is_supported(const GemmArgs &args, const OutputStage &os) const
static GemmImplementation with_estimate(GemmMethod m, const char *n, std::function< bool(const GemmArgs &)> is_supported, std::function< uint64_t(const GemmArgs &)> cycle_estimate, std::function< GemmCommon< Top, Tret > *(const GemmArgs &)> instantiate)
GemmImplementation(GemmMethod m, const char *n)
bool do_is_supported(const GemmArgs &args, const Nothing &) const
uint64_t do_cycle_estimate(const GemmArgs &args, const Nothing &) const
UniqueGemmCommon< Top, Tret > gemm(const GemmArgs &args, const OutputStage &os)
GemmCommon< Top, Tret > * do_instantiate(const GemmArgs &args, const Nothing &) const
std::unique_ptr< GemmCommon< Top, Tret > > UniqueGemmCommon
Definition: arm_gemm.hpp:174
GemmImplementation & operator=(const GemmImplementation &)=default
bool has_opt_gemm(const GemmArgs &args, const OutputStage &os)
uint64_t do_cycle_estimate(const GemmArgs &args, const OutputStage &os) const
const GemmImplementation< Top, Tret, OutputStage > * gemm_implementation_list()
GemmImplementation(GemmMethod m, const char *n, std::function< bool(const GemmArgs &)> is_supported, std::function< bool(const GemmArgs &)> is_recommended, std::function< GemmCommon< Top, Tret > *(const GemmArgs &)> instantiate)
static GemmImplementation with_estimate(GemmMethod m, const char *n, std::function< bool(const GemmArgs &, const OutputStage &)> is_supported, std::function< uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate, std::function< GemmCommon< Top, Tret > *(const GemmArgs &, const OutputStage &)> instantiate)
GemmImplementation(const GemmImplementation &)=default
std::function< uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate
GemmImplementation(GemmMethod m, const char *n, std::function< bool(const GemmArgs &, const OutputStage &)> is_supported, std::function< bool(const GemmArgs &, const OutputStage &)> is_recommended, std::function< GemmCommon< Top, Tret > *(const GemmArgs &, const OutputStage &)> instantiate)