24.02.1
|
Go to the documentation of this file.
38 template<
typename Top,
typename Tret,
class OutputStage = Nothing>
54 if (
args._fixed_format ==
false) {
105 std::function<
bool(
const GemmArgs &,
const OutputStage &)>
is_supported, std::function<
bool(
const GemmArgs &,
const OutputStage &)> is_recommended,
108 cycle_estimate( [is_recommended](const
GemmArgs &
args, const OutputStage &os) {
return (is_recommended ==
nullptr) ? 0 : (is_recommended(
args, os) ? 0 : UINT64_MAX); } ),
112 std::function<
bool(
const GemmArgs &,
const OutputStage &)>
is_supported, std::function<
bool(
const GemmArgs &,
const OutputStage &)> is_recommended,
115 cycle_estimate( [is_recommended](const
GemmArgs &
args, const OutputStage &os) {
return (is_recommended ==
nullptr) ? 0 : (is_recommended(
args, os) ? 0 : UINT64_MAX); } ),
122 template<
typename Top,
typename Tret>
138 if (
args._fixed_format ==
false) {
204 cycle_estimate( [is_recommended](const
GemmArgs &
args) -> uint64_t {
return (is_recommended ==
nullptr) ? 0 : (is_recommended(
args) ? 0 : UINT64_MAX); } ),
211 cycle_estimate( [is_recommended](const
GemmArgs &
args) -> uint64_t {
return (is_recommended ==
nullptr) ? 0 : (is_recommended(
args) ? 0 : UINT64_MAX); } ),
221 template<
typename Top,
typename Tret,
class OutputStage = Nothing>
237 template<
typename Top,
typename Tret,
class OutputStage>
239 auto gemms = gemm_implementation_list<Top, Tret, OutputStage>();
243 uint64_t best_estimate = 0;
247 if (!i->do_is_supported(
args, os)) {
257 if (cfg && cfg->
filter !=
"" && !strstr(i->name, cfg->
filter.c_str())) {
272 if ((saved_impl ==
nullptr) || (estimate < best_estimate)) {
274 best_estimate = estimate;
279 if (saved_impl !=
nullptr) {
287 template<
typename Top,
typename Tret,
class OutputStage>
289 std::vector<KernelDescription> res;
295 auto gemms = gemm_implementation_list<Top, Tret, OutputStage>();
300 if (!i->do_is_supported(
args, os)) {
310 template<
typename Top,
typename Tret,
class OutputStage>
313 const bool success = find_implementation<Top, Tret, OutputStage>(
args, os, impl);
319 template<
typename Top,
typename Tret,
class OutputStage>
323 if (find_implementation<Top, Tret, OutputStage>(
args, os, impl)) {
330 template<
typename Top,
typename Tret,
class OutputStage>
334 if (find_implementation<Top, Tret>(
args, os, impl)) {
KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage &os)
static GemmImplementation with_estimate(GemmMethod m, const char *n, std::function< bool(const GemmArgs &)> is_supported, std::function< uint64_t(const GemmArgs &)> cycle_estimate, std::function< GemmCommon< Top, Tret > *(const GemmArgs &)> instantiate)
uint64_t do_cycle_estimate(const GemmArgs &args, const OutputStage &os) const
GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat kwf, std::function< bool(const GemmArgs &, const OutputStage &)> is_supported, std::function< bool(const GemmArgs &, const OutputStage &)> is_recommended, std::function< GemmCommon< Top, Tret > *(const GemmArgs &, const OutputStage &)> instantiate)
UniqueGemmCommon< Top, Tret > gemm(const GemmArgs &args, const OutputStage &os)
uint64_t do_cycle_estimate(const GemmArgs &args, const Nothing &) const
bool do_is_supported(const GemmArgs &args, const OutputStage &os) const
GemmImplementation(GemmMethod m, const char *n)
bool do_is_supported(const GemmArgs &args, const Nothing &) const
std::function< GemmCommon< Top, Tret > *(const GemmArgs &, const OutputStage &)> instantiate
bool has_opt_gemm(WeightFormat &wf, const GemmArgs &args, const OutputStage &os)
static GemmImplementation with_estimate(GemmMethod m, const char *n, std::function< bool(const GemmArgs &, const OutputStage &)> is_supported, std::function< uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate, std::function< GemmCommon< Top, Tret > *(const GemmArgs &, const OutputStage &)> instantiate)
std::function< bool(const GemmArgs &, const OutputStage &)> is_supported
GemmImplementation(GemmMethod m, const char *n, std::function< bool(const GemmArgs &, const OutputStage &)> is_supported, std::function< bool(const GemmArgs &, const OutputStage &)> is_recommended, std::function< GemmCommon< Top, Tret > *(const GemmArgs &, const OutputStage &)> instantiate)
GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat f=KernelWeightFormat::NON_FIXED)
static GemmImplementation with_estimate(GemmMethod m, const char *n, KernelWeightFormat f, std::function< bool(const GemmArgs &)> is_supported, std::function< uint64_t(const GemmArgs &)> cycle_estimate, std::function< GemmCommon< Top, Tret > *(const GemmArgs &)> instantiate)
WeightFormat get_weight_format(const KernelWeightFormat, size_t)
std::function< uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate
GemmImplementation & operator=(const GemmImplementation &)=default
GemmCommon< Top, Tret > * do_instantiate(const GemmArgs &args, const Nothing &) const
const GemmImplementation< Top, Tret, OutputStage > * gemm_implementation_list()
const KernelWeightFormat kernel_weight_format
bool find_implementation(const GemmArgs &args, const OutputStage &os, const GemmImplementation< Top, Tret, OutputStage > *&impl)
GemmCommon< Top, Tret > * do_instantiate(const GemmArgs &args, const OutputStage &os) const
GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat kwf, std::function< bool(const GemmArgs &)> is_supported, std::function< bool(const GemmArgs &)> is_recommended, std::function< GemmCommon< Top, Tret > *(const GemmArgs &)> instantiate)
std::unique_ptr< GemmCommon< Top, Tret > > UniqueGemmCommon
GemmImplementation(const GemmImplementation &)=default
std::vector< KernelDescription > get_compatible_kernels(const GemmArgs &args, const OutputStage &os)
GemmImplementation(GemmMethod m, const char *n, std::function< bool(const GemmArgs &)> is_supported, std::function< bool(const GemmArgs &)> is_recommended, std::function< GemmCommon< Top, Tret > *(const GemmArgs &)> instantiate)