32 #if defined(__aarch64__)
33 #if defined(ARM_COMPUTE_ENABLE_SME)
37 #endif // defined(ARM_COMPUTE_ENABLE_SME)
38 #if defined(ARM_COMPUTE_ENABLE_SVE)
42 #endif // defined(ARM_COMPUTE_ENABLE_SVE)
46 #endif // defined(__aarch64__)
// Table of candidate uint8_t -> uint8_t pooling implementations.
// Each entry visible here supplies: a PoolingMethod, a kernel name string, an
// "is supported" predicate over (const PoolingArgs &, const Nothing &), and a
// factory lambda that allocates a strategy with `new` and wraps it in either a
// PoolingDepthfirst<uint8_t> or a PoolingDepthfirstGeneric<uint8_t>.
// (Some fields/lines of each entry are not visible in this view.)
// NOTE(review): presumably the selector walks this table in order and uses the
// first entry whose predicate accepts the arguments, so feature-gated and
// specialised kernels appear before generic ones — confirm against the code
// that consumes pooling_u8_methods.
53 static const PoolingImplementation<uint8_t, uint8_t> pooling_u8_methods[] = {
// Portable C++ kernel, outside the __aarch64__ guard so it is present on every
// target. Its predicate accepts any arguments whose pooling window is 1x1.
55 PoolingMethod::DEPTHFIRST,
56 "cpp_u8_nhwc_1x1_stride_any_depthfirst",
57 [] (
const PoolingArgs &
args,
const Nothing &) ->
bool {
58 return args.pool_window.rows == 1 &&
args.pool_window.cols == 1;
// Factory: the strategy is heap-allocated and handed to the generic wrapper.
// NOTE(review): presumably the wrapper takes ownership of `strat` — confirm.
61 [] (
const PoolingArgs &
args,
const Nothing &) -> PoolingCommon<uint8_t, uint8_t> * {
62 auto strat =
new cpp_nhwc_1x1_stride_any_depthfirst<uint8_t>(
args.cpu_info);
63 return new PoolingDepthfirstGeneric<uint8_t>(strat,
args);
66 #if defined(__aarch64__)
67 #if defined(ARM_COMPUTE_ENABLE_SME)
// SME kernels, compiled only when ARM_COMPUTE_ENABLE_SME is defined.
// Fixed-shape MAX kernel: requires runtime SME support (has_sme()) and that the
// kernel's own is_supported<> check accepts the args and output stage `os`.
// Wrapped in the non-generic PoolingDepthfirst.
69 PoolingMethod::DEPTHFIRST,
70 "sme_u8_nhwc_max_2x2_s1_output2x2_depthfirst",
71 [] (
const PoolingArgs &
args,
const Nothing &os) ->
bool {
72 return args.cpu_info->has_sme() &&
73 is_supported<sme_u8_nhwc_max_2x2_s1_output2x2_depthfirst>(
args, os);
76 [] (
const PoolingArgs &
args,
const Nothing &) -> PoolingCommon<uint8_t, uint8_t> * {
77 auto strat =
new sme_u8_nhwc_max_2x2_s1_output2x2_depthfirst(
args.cpu_info);
78 return new PoolingDepthfirst<uint8_t>(strat,
args);
// Generic AVERAGE kernel. Requires all three of:
//  - either exclude_padding is set, or all four padding values are zero;
//  - the pool type is AVERAGE;
//  - runtime SME2 support.
// NOTE(review): the padding condition presumably exists because the kernel
// cannot count padded elements in the average — confirm.
82 PoolingMethod::DEPTHFIRST,
83 "sme_u8_nhwc_avg_generic_depthfirst",
84 [] (
const PoolingArgs &
args,
const Nothing &) ->
bool {
88 return (
args.exclude_padding ||
89 (
args.padding.top == 0 &&
args.padding.bottom == 0 &&
90 args.padding.left == 0 &&
args.padding.right == 0)
91 ) &&
args.pool_type == PoolingType::AVERAGE &&
92 args.cpu_info->has_sme2();
95 [] (
const PoolingArgs &
args,
const Nothing &) -> PoolingCommon<uint8_t, uint8_t> * {
96 auto strat =
new sme_u8_nhwc_avg_generic_depthfirst(
args.cpu_info);
97 return new PoolingDepthfirstGeneric<uint8_t>(strat,
args);
// Generic MAX kernel. Its predicate body is not visible in this view.
// NOTE(review): by analogy with the AVERAGE entry it presumably checks the pool
// type and SME2 support — confirm.
101 PoolingMethod::DEPTHFIRST,
102 "sme_u8_nhwc_max_generic_depthfirst",
103 [] (
const PoolingArgs &
args,
const Nothing &) ->
bool {
107 [] (
const PoolingArgs &
args,
const Nothing &) -> PoolingCommon<uint8_t, uint8_t> * {
108 auto strat =
new sme_u8_nhwc_max_generic_depthfirst(
args.cpu_info);
109 return new PoolingDepthfirstGeneric<uint8_t>(strat,
args);
112 #endif // defined(ARM_COMPUTE_ENABLE_SME)
113 #if defined(ARM_COMPUTE_ENABLE_SVE)
// SVE kernels, compiled only when ARM_COMPUTE_ENABLE_SVE is defined.
// Fixed-shape MAX kernel: requires runtime SVE support (has_sve()) and that the
// kernel's own is_supported<> check accepts the args and output stage `os`.
// Wrapped in the non-generic PoolingDepthfirst.
115 PoolingMethod::DEPTHFIRST,
116 "sve_u8_nhwc_max_2x2_s1_output2x2_depthfirst",
117 [] (
const PoolingArgs &
args,
const Nothing &os) ->
bool {
118 return args.cpu_info->has_sve() &&
119 is_supported<sve_u8_nhwc_max_2x2_s1_output2x2_depthfirst>(
args, os);
122 [] (
const PoolingArgs &
args,
const Nothing &) -> PoolingCommon<uint8_t, uint8_t> * {
123 auto strat =
new sve_u8_nhwc_max_2x2_s1_output2x2_depthfirst(
args.cpu_info);
124 return new PoolingDepthfirst<uint8_t>(strat,
args);
// Generic AVERAGE kernel. Has the same padding and pool-type conditions as the
// SME variant, but requires runtime SVE2 support (has_sve2()).
128 PoolingMethod::DEPTHFIRST,
129 "sve_u8_nhwc_avg_generic_depthfirst",
130 [] (
const PoolingArgs &
args,
const Nothing &) ->
bool {
134 return (
args.exclude_padding ||
135 (
args.padding.top == 0 &&
args.padding.bottom == 0 &&
136 args.padding.left == 0 &&
args.padding.right == 0)
137 ) &&
args.pool_type == PoolingType::AVERAGE &&
138 args.cpu_info->has_sve2();
141 [] (
const PoolingArgs &
args,
const Nothing &) -> PoolingCommon<uint8_t, uint8_t> * {
142 auto strat =
new sve_u8_nhwc_avg_generic_depthfirst(
args.cpu_info);
143 return new PoolingDepthfirstGeneric<uint8_t>(strat,
args);
// Generic MAX kernel. Its predicate body is not visible in this view — TODO
// confirm which pool type and CPU features it requires.
147 PoolingMethod::DEPTHFIRST,
148 "sve_u8_nhwc_max_generic_depthfirst",
149 [] (
const PoolingArgs &
args,
const Nothing &) ->
bool {
153 [] (
const PoolingArgs &
args,
const Nothing &) -> PoolingCommon<uint8_t, uint8_t> * {
154 auto strat =
new sve_u8_nhwc_max_generic_depthfirst(
args.cpu_info);
155 return new PoolingDepthfirstGeneric<uint8_t>(strat,
args);
158 #endif // defined(ARM_COMPUTE_ENABLE_SVE)
// Baseline AArch64 kernels (inside __aarch64__, outside the SME/SVE guards).
// Fixed-shape MAX kernel: the kernel's is_supported<> template is used directly
// as the predicate, with no extra CPU-feature check.
160 PoolingMethod::DEPTHFIRST,
161 "a64_u8_nhwc_max_2x2_s1_output2x2_depthfirst",
162 is_supported<a64_u8_nhwc_max_2x2_s1_output2x2_depthfirst>,
164 [] (
const PoolingArgs &
args,
const Nothing &) -> PoolingCommon<uint8_t, uint8_t> * {
165 auto strat =
new a64_u8_nhwc_max_2x2_s1_output2x2_depthfirst(
args.cpu_info);
166 return new PoolingDepthfirst<uint8_t>(strat,
args);
// Generic AVERAGE kernel. Has the same padding and pool-type conditions as the
// SME/SVE variants, but no CPU-feature check.
170 PoolingMethod::DEPTHFIRST,
171 "a64_u8_nhwc_avg_generic_depthfirst",
172 [] (
const PoolingArgs &
args,
const Nothing &) ->
bool {
176 return (
args.exclude_padding ||
177 (
args.padding.top == 0 &&
args.padding.bottom == 0 &&
178 args.padding.left == 0 &&
args.padding.right == 0)
179 ) &&
args.pool_type == PoolingType::AVERAGE;
182 [] (
const PoolingArgs &
args,
const Nothing &) -> PoolingCommon<uint8_t, uint8_t> * {
183 auto strat =
new a64_u8_nhwc_avg_generic_depthfirst(
args.cpu_info);
184 return new PoolingDepthfirstGeneric<uint8_t>(strat,
args);
// Generic MAX kernel. Its predicate is not visible in this view — TODO confirm.
188 PoolingMethod::DEPTHFIRST,
189 "a64_u8_nhwc_max_generic_depthfirst",
192 [] (
const PoolingArgs &
args,
const Nothing &) -> PoolingCommon<uint8_t, uint8_t> * {
193 auto strat =
new a64_u8_nhwc_max_generic_depthfirst(
args.cpu_info);
194 return new PoolingDepthfirstGeneric<uint8_t>(strat,
args);
197 #endif // defined(__aarch64__)
// Sentinel entry (present on all targets): DEFAULT method, empty name and
// null pointers in the remaining fields.
// NOTE(review): presumably marks the end of the table for the selector — confirm.
198 { PoolingMethod::DEFAULT,
"",
nullptr,
nullptr,
nullptr },
204 return pooling_u8_methods;
// Explicit instantiation of the pooling() factory with uint8_t input and
// output and a `const Nothing &` second parameter, so that this specialisation
// is emitted in this translation unit.
// NOTE(review): `Nothing` presumably denotes "no output stage /
// requantisation" — confirm against the pooling() template declaration.
207 template UniquePoolingCommon<uint8_t, uint8_t>
pooling(
const PoolingArgs &,
const Nothing &);