26 #include "pooling.hpp"
// NOTE(review): this excerpt is missing lines (the fused source line numbers
// are non-contiguous); only the visible members are documented here.
//
// Template header for a per-implementation descriptor parameterised on the
// input element type, the output element type and an output-stage type
// (which defaults to Nothing).
35 template <
typename TInput,
typename TOutput,
class OutputStage = Nothing>
// Callback taking the pooling arguments and output stage and returning bool;
// presumably reports whether this implementation can handle them — confirm.
40 std::function<bool(
const PoolingArgs &,
const OutputStage &)>
is_supported;
// Callback returning a uint64_t for the given arguments; presumably a cost
// (cycle) estimate used to rank candidates — TODO confirm.
41 std::function<uint64_t(
const PoolingArgs &,
const OutputStage &)>
cycle_estimate;
// Factory callback returning a raw PoolingCommon pointer for the given
// arguments. pooling() wraps the result of get_instance() in a
// UniquePoolingCommon, which suggests the caller takes ownership — verify.
42 std::function<PoolingCommon<TInput, TOutput> *(
const PoolingArgs &,
const OutputStage &)>
initialise;
// Const accessor producing a PoolingCommon instance for args/os. Its body is
// not visible here; presumably it forwards to `initialise` — confirm.
54 PoolingCommon<TInput, TOutput> *
get_instance(
const PoolingArgs &
args,
const OutputStage &os)
const
// Template header whose declaration is not visible in this excerpt.
// NOTE(review): presumably this declares pooling_implementation_list<>(),
// which the selection routine below instantiates with
// <TInput, TOutput, OutputStage> — confirm against the full file.
63 template <
typename TInput,
typename TOutput,
class OutputStage = Nothing>
// Fragment of the implementation-selection routine. The function name and
// several lines are missing from this excerpt. The routine walks the
// candidate list and tests each entry against the requested arguments.
66 template <
typename TInput,
typename TOutput,
class OutputStage = Nothing>
68 const PoolingArgs &
args,
69 const OutputStage &os,
// Walk the implementation table; an entry whose method is
// PoolingMethod::DEFAULT acts as the end-of-list sentinel.
74 const auto *impl = pooling_implementation_list<TInput, TOutput, OutputStage>();
75 for (; impl->method != PoolingMethod::DEFAULT; impl++)
// A configuration is optional; it is only consulted when one is provided.
77 if (
args.config !=
nullptr)
80 const auto cfg =
args.config;
// A non-empty filter string that does not occur as a substring of the
// implementation's name rules this entry out. The statement executed in
// that case is not visible here; presumably `continue` — confirm.
82 if (cfg->filter !=
"" && !std::strstr(impl->name, cfg->filter.c_str()))
// Ask the entry whether it supports these arguments. Note this calls
// get_is_supported(), not the is_supported member directly. The handling
// of a positive answer is not visible here; presumably the entry is
// selected and the routine reports success.
88 if (impl->get_is_supported(
args, os))
// Creates a pooling operator for the given arguments and output stage.
// find_implementation is asked to pick a candidate and write it into `impl`
// (the declaration of `impl` is not visible in this excerpt). On success the
// candidate's get_instance() result is wrapped in a UniquePoolingCommon;
// otherwise a UniquePoolingCommon holding nullptr is returned.
97 template <
typename TInput,
typename TOutput,
class OutputStage>
98 UniquePoolingCommon<TInput, TOutput>
pooling(
const PoolingArgs &
args,
const OutputStage &os)
101 const bool success = find_implementation<TInput, TOutput, OutputStage>(
args, os, impl);
// Only dereference impl when a matching implementation was found.
102 return UniquePoolingCommon<TInput, TOutput>(success ? impl->
get_instance(
args, os) :
nullptr);
// Fragment of a helper templated on a Strategy; its signature is not visible
// in this excerpt. It returns true only if the requested pooling type,
// window size (rows and cols) and stride (rows and cols) all equal the
// Strategy's corresponding static members.
105 template <
class Strategy>
108 return ((
args.pool_type == Strategy::pooling_type) &&
109 (
args.pool_window.rows == Strategy::pool_rows) &&
110 (
args.pool_window.cols == Strategy::pool_cols) &&
111 (
args.pool_stride.rows == Strategy::stride_rows) &&
112 (
args.pool_stride.cols == Strategy::stride_cols));