32 #if defined(__aarch64__) 33 #if defined(ARM_COMPUTE_ENABLE_SVE) 34 #if defined(ARM_COMPUTE_ENABLE_SVE2) 36 #endif // defined(ARM_COMPUTE_ENABLE_SVE2) 39 #endif // defined(ARM_COMPUTE_ENABLE_SVE) 43 #endif // defined(__aarch64__) 52 template <
class Strategy>
53 bool is_supported(
const PoolingArgs &
args,
const Nothing &)
55 return ((args.pool_type == Strategy::pooling_type()) &&
56 (args.pool_window.rows == Strategy::pool_rows()) &&
57 (args.pool_window.cols == Strategy::pool_cols()) &&
58 (args.pool_stride.rows == Strategy::stride_rows()) &&
59 (args.pool_stride.cols == Strategy::stride_cols()));
// Table of candidate uint8_t -> uint8_t pooling implementations.
// Each entry holds:
//   - a PoolingMethod kind;
//   - an identifying name string;
//   - a predicate (PoolingArgs, Nothing) -> bool that decides whether the
//     implementation can handle the request;
//   - a factory lambda that heap-allocates the implementation object
//     (PoolingCommon<uint8_t, uint8_t> *).
// The final PoolingMethod::DEFAULT entry has an empty name and nullptr
// callbacks and appears to act as an end-of-list marker.
// NOTE(review): entry order presumably sets preference when several predicates
// match. Confirm against the selection code, which is not shown.
// NOTE(review): the embedded original line numbers skip (e.g. 68 -> 71,
// 120 -> 122, 149 -> 152), so braces and apparently one field per entry are
// missing from this extraction. The table will not compile as shown.
63 static const PoolingImplementation<uint8_t, uint8_t> pooling_u8_methods[] = {
// Portable C++ kernel: its predicate accepts any 1x1 pooling window,
// regardless of stride or pooling type.
65 PoolingMethod::DEPTHFIRST,
66 "cpp_u8_nhwc_1x1_stride_any_depthfirst",
67 [] (
const PoolingArgs &
args,
const Nothing &) ->
bool {
68 return args.pool_window.rows == 1 && args.pool_window.cols == 1;
71 [] (
const PoolingArgs &
args,
const Nothing &) -> PoolingCommon<uint8_t, uint8_t> * {
72 return new PoolingDepthfirstGeneric<cpp_nhwc_1x1_stride_any_depthfirst<uint8_t>>(
args);
// The following entries are AArch64-only. The SVE/SVE2 entries also need the
// matching compile-time macro and a runtime has_sve()/has_sve2() check.
75 #if defined(__aarch64__) 76 #if defined(ARM_COMPUTE_ENABLE_SVE) 77 #if defined(ARM_COMPUTE_ENABLE_SVE2) 79 PoolingMethod::DEPTHFIRST,
80 "sve_u8_nhwc_avg_generic_depthfirst",
81 [] (
const PoolingArgs &
args,
const Nothing &) ->
bool {
// Requires all of the following:
//   - SVE2 at runtime;
//   - an AVERAGE pool;
//   - either exclude_padding, or zero padding on all four sides.
// NOTE(review): presumably the kernel cannot represent padded uint8_t
// elements correctly otherwise; confirm the exact reason.
85 return args.cpu_info->has_sve2() && (args.exclude_padding ||
86 (args.padding.top == 0 && args.padding.bottom == 0 &&
87 args.padding.left == 0 && args.padding.right == 0)
88 ) && args.pool_type == PoolingType::AVERAGE;
91 [] (
const PoolingArgs &
args,
const Nothing &) -> PoolingCommon<uint8_t, uint8_t> * {
92 return new PoolingDepthfirstGeneric<sve_u8_nhwc_avg_generic_depthfirst>(
args);
95 #endif // defined(ARM_COMPUTE_ENABLE_SVE2) 97 PoolingMethod::DEPTHFIRST,
// SVE fixed-geometry max kernel, built with PoolingDepthfirst rather than
// PoolingDepthfirstGeneric.
98 "sve_u8_nhwc_max_2x2_s1_output2x2_depthfirst",
99 [] (
const PoolingArgs &
args,
const Nothing &unused) ->
bool {
// Runtime SVE check plus an exact match of pool type, window and stride via
// is_supported<>.
// NOTE(review): the parameter named `unused` is in fact forwarded here.
100 return args.cpu_info->has_sve() && is_supported<sve_u8_nhwc_max_2x2_s1_output2x2_depthfirst>(
args, unused);
103 [] (
const PoolingArgs &
args,
const Nothing &) -> PoolingCommon<uint8_t, uint8_t> * {
104 return new PoolingDepthfirst<sve_u8_nhwc_max_2x2_s1_output2x2_depthfirst>(
args);
108 PoolingMethod::DEPTHFIRST,
// Generic SVE max kernel: any window/stride, as long as SVE is available at
// runtime and the pool type is MAX.
109 "sve_u8_nhwc_max_generic_depthfirst",
110 [] (
const PoolingArgs &
args,
const Nothing &) ->
bool {
return args.cpu_info->has_sve() && args.pool_type ==
PoolingType::MAX; },
112 [] (
const PoolingArgs &
args,
const Nothing &) -> PoolingCommon<uint8_t, uint8_t> * {
113 return new PoolingDepthfirstGeneric<sve_u8_nhwc_max_generic_depthfirst>(
args);
116 #endif // defined(ARM_COMPUTE_ENABLE_SVE) 118 PoolingMethod::DEPTHFIRST,
// AArch64 fixed-geometry max kernel. Its predicate is is_supported<> used
// directly, with no runtime CPU-feature check.
119 "a64_u8_nhwc_max_2x2_s1_output2x2_depthfirst",
120 is_supported<a64_u8_nhwc_max_2x2_s1_output2x2_depthfirst>,
122 [] (
const PoolingArgs &
args,
const Nothing &) -> PoolingCommon<uint8_t, uint8_t> * {
123 return new PoolingDepthfirst<a64_u8_nhwc_max_2x2_s1_output2x2_depthfirst>(
args);
127 PoolingMethod::DEPTHFIRST,
128 "a64_u8_nhwc_avg_generic_depthfirst",
129 [] (
const PoolingArgs &
args,
const Nothing &) ->
bool {
// Same padding/AVERAGE restriction as the SVE2 average kernel, but with no
// runtime CPU-feature check.
133 return (args.exclude_padding ||
134 (args.padding.top == 0 && args.padding.bottom == 0 &&
135 args.padding.left == 0 && args.padding.right == 0)
136 ) && args.pool_type == PoolingType::AVERAGE;
139 [] (
const PoolingArgs &
args,
const Nothing &) -> PoolingCommon<uint8_t, uint8_t> * {
140 return new PoolingDepthfirstGeneric<a64_u8_nhwc_avg_generic_depthfirst>(
args);
144 PoolingMethod::DEPTHFIRST,
// Generic AArch64 max kernel: accepts any request whose pool type is MAX.
145 "a64_u8_nhwc_max_generic_depthfirst",
146 [] (
const PoolingArgs &
args,
const Nothing &) ->
bool {
return args.pool_type ==
PoolingType::MAX; },
148 [] (
const PoolingArgs &
args,
const Nothing &) -> PoolingCommon<uint8_t, uint8_t> * {
149 return new PoolingDepthfirstGeneric<a64_u8_nhwc_max_generic_depthfirst>(
args);
152 #endif // defined(__aarch64__) 153 { PoolingMethod::DEFAULT,
// Final entry: empty name and nullptr for every callback.
"",
nullptr,
nullptr,
nullptr },
// Returns the pooling_u8_methods table.
// NOTE(review): the enclosing function header is not visible in this
// extraction. It is presumably the uint8_t specialisation of
// pooling_implementation_list(); confirm.
159 return pooling_u8_methods;
162 template UniquePoolingCommon<uint8_t, uint8_t>
pooling(
const PoolingArgs &,
const Nothing &);
// NOTE(review): this float/float instantiation does not match the uint8_t
// table in this chunk, and it lacks a terminating ';'. It looks like residue
// from another translation unit; confirm before relying on it.
template UniquePoolingCommon< float, float > pooling(const PoolingArgs &, const Nothing &)
const PoolingImplementation< float, float > * pooling_implementation_list()