27 #include "depthwise.hpp"
// Implementation-record template. Parameters: input, weight and output element
// types (weight and output default to the input type) and an output stage
// type (defaults to Nothing). The struct head and the `method`/`name` members
// are not visible in this view.
37 template <
typename TInput,
typename TWeight = TInput,
typename TOutput = TInput,
class OutputStage = Nothing>
// Predicate: reports whether this implementation can handle the given
// depthwise arguments and output stage.
42 std::function<bool(
const DepthwiseArgs &,
const OutputStage &)>
is_supported;
// Cost estimate used to rank supported implementations; find_implementation
// prefers the smallest value.
// NOTE(review): find_implementation calls get_is_supported()/get_cycle_estimate()
// rather than these members directly; those accessors are not visible here.
43 std::function<uint64_t(
const DepthwiseArgs &,
const OutputStage &)>
cycle_estimate;
// Produces a DepthwiseCommon kernel object for the given arguments; presumably
// invoked from get_instance (whose body is partly elided here) — confirm.
44 std::function<DepthwiseCommon<TInput, TWeight, TOutput> *(
const DepthwiseArgs &,
const OutputStage &)>
initialise;
// Returns a DepthwiseCommon object for `args`/`os` after tagging it with this
// implementation's name via set_name().
// NOTE(review): the lines between the signature and set_name() are not visible;
// presumably `impl` comes from initialise(args, os) — confirm.
// Ownership: depthwise() wraps the returned raw pointer in UniqueDepthwiseCommon.
56 DepthwiseCommon<TInput, TWeight, TOutput> *
get_instance(
const DepthwiseArgs &
args,
const OutputStage &os)
const
59 impl->set_name(std::string(
name));
// NOTE(review): the declaration governed by this template header is not visible
// here. It is presumably depthwise_implementation_list(), which
// find_implementation iterates — confirm. Its defaults match the
// implementation-record template: weight and output default to TInput, and
// the output stage defaults to Nothing.
67 template <
typename TInput,
typename TWeight = TInput,
typename TOutput = TInput,
class OutputStage = Nothing>
// Picks the best-suited depthwise implementation for (args, os).
// It walks the list returned by depthwise_implementation_list() up to the entry
// whose method is DepthwiseMethod::DEFAULT, which acts as the end marker.
// Entries are skipped when they are unsupported, or when they conflict with an
// optional args.config: either a forced method, or a filter string that must
// appear in the implementation's name. Among the remaining entries, the lowest
// cycle estimate wins.
// The final return reports whether `selected` was set.
// NOTE(review): the third parameter (the selection output) and several body
// lines are not visible in this view.
70 template <
typename TInput,
typename TWeight = TInput,
typename TOutput = TInput,
class OutputStage = Nothing>
72 const DepthwiseArgs &
args,
73 const OutputStage &os,
// Start above any real estimate so the first candidate always compares lower.
78 uint64_t best_cycle_estimate = UINT64_MAX;
80 const auto *impl = depthwise_implementation_list<TInput, TWeight, TOutput, OutputStage>();
// The list is terminated by an entry whose method is DEFAULT.
81 for (; impl->method != DepthwiseMethod::DEFAULT; impl++)
// NOTE(review): has_cfg/cfg are loop-invariant and could be hoisted above the loop.
83 const bool has_cfg = (
args.config !=
nullptr);
84 const auto &cfg =
args.config;
// Rejection conditions (the enclosing if/continue is not visible here):
// unsupported; config forces a different non-DEFAULT method; config filter is
// non-empty and is not a substring of impl->name.
87 !impl->get_is_supported(
args, os) ||
88 (has_cfg && cfg->method != DepthwiseMethod::DEFAULT && cfg->method != impl->method) ||
89 (has_cfg && cfg->filter !=
"" && !std::strstr(impl->name, cfg->filter.c_str()))
95 const auto cycle_estimate = impl->get_cycle_estimate(
args, os);
// A zero estimate is handled specially; the branch body is not visible here.
97 if (cycle_estimate == 0)
// The strict '<' means a later entry with an equal estimate does not replace
// the current choice.
103 if (selected ==
nullptr || cycle_estimate < best_cycle_estimate)
106 best_cycle_estimate = cycle_estimate;
110 return (selected !=
nullptr);
// Builds a list of KernelDescription entries for implementations that support
// (args, os), flagging the one find_implementation would select.
// NOTE(review): this function's name and signature lines are not visible
// (presumably get_compatible_kernels — confirm). The return value of
// find_implementation is ignored here, so default_impl's initial value matters
// when nothing is selected; that initialization is not visible.
113 template <
typename TInput,
typename TWeight,
typename TOutput,
class OutputStage>
116 std::vector<KernelDescription> kerns;
// Find the implementation that would be chosen by default.
120 find_implementation<TInput, TWeight, TOutput, OutputStage>(
args, os, default_impl);
// Walk the full DEFAULT-terminated list. The visible check skips only
// unsupported entries; args.config filtering, as done in find_implementation,
// is not visibly applied here.
122 for (
auto impl = depthwise_implementation_list<TInput, TWeight, TOutput, OutputStage>();
123 impl->method != DepthwiseMethod::DEFAULT; impl++)
125 if (!impl->get_is_supported(
args, os))
// Each description records the method, the name, and whether this entry is
// the default pick; any other fields are not visible here.
131 impl->method, impl->name, impl == default_impl,
// Public factory: selects an implementation via find_implementation.
// On success it returns the kernel produced by that implementation's
// get_instance(); otherwise it returns a UniqueDepthwiseCommon holding nullptr,
// so callers must check the result.
// NOTE(review): the declaration of `impl` is not visible in this view.
139 template <
typename TInput,
typename TWeight,
typename TOutput,
class OutputStage>
140 UniqueDepthwiseCommon<TInput, TWeight, TOutput>
depthwise(
const DepthwiseArgs &
args,
const OutputStage &os)
143 const bool success = find_implementation<TInput, TWeight, TOutput, OutputStage>(
args, os, impl);
144 return UniqueDepthwiseCommon<TInput, TWeight, TOutput>(success ? impl->
get_instance(
args, os) :
nullptr);