34 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
39 #endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
51 #ifdef ARM_COMPUTE_ENABLE_SVE
52 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
57 #endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
58 #ifdef ARM_COMPUTE_ENABLE_SME2
67 #endif // ARM_COMPUTE_ENABLE_SME2
80 #endif // ARM_COMPUTE_ENABLE_SVE
84 static const GemmImplementation<float, float> gemm_fp32_methods[] =
90 [](
const GemmArgs &
args) {
return args._Msize==1 &&
args._nbatches>1 && !
args._indirect_input; },
92 [](
const GemmArgs &
args) {
return new GemvBatched<float, float>(
args); }
95 #ifdef ARM_COMPUTE_ENABLE_BF16
99 "a64_interleaved_bf16fp32_mmla_8x12",
100 [](
const GemmArgs &
args) {
return args._fast_mode &&
args._ci->has_bf16(); },
101 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_a64_interleaved_bf16fp32_mmla_8x12, float, float>::estimate_cycles<float>(
args); },
102 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_a64_interleaved_bf16fp32_mmla_8x12, float, float>(
args); }
107 "a64_hybrid_fp32bf16fp32_mmla_6x16",
108 [](
const GemmArgs &
args) {
return args._fast_mode &&
args._ci->has_bf16(); },
109 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_a64_hybrid_fp32bf16fp32_mmla_6x16, float, float>::estimate_cycles<float>(
args); },
110 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_a64_hybrid_fp32bf16fp32_mmla_6x16, float, float>(
args); }
114 "a64_hybrid_fp32bf16fp32_mmla_4x24",
115 [](
const GemmArgs &
args) {
return args._fast_mode &&
args._ci->has_bf16(); },
116 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_a64_hybrid_fp32bf16fp32_mmla_4x24, float, float>::estimate_cycles<float>(
args); },
117 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_a64_hybrid_fp32bf16fp32_mmla_4x24, float, float>(
args); }
120 #ifdef ARM_COMPUTE_ENABLE_SVE
121 #ifdef ARM_COMPUTE_ENABLE_SME2
125 "sme2_gemv_fp32bf16fp32_dot_16VL",
126 [](
const GemmArgs &
args) {
return args._fast_mode &&
args._ci->has_sme2() &&
args._Msize==1 &&
args._nbatches==1 && !
args._indirect_input; },
128 [](
const GemmArgs &
args) {
return new GemvPretransposed<cls_sme2_gemv_fp32bf16fp32_dot_16VL, float, float>(
args); }
132 "sme2_gemv_fp32_mla_16VL",
133 [](
const GemmArgs &
args) {
return args._ci->has_sme2() &&
args._Msize==1 &&
args._nbatches==1 && !
args._indirect_input; },
135 [](
const GemmArgs &
args) {
return new GemvPretransposed<cls_sme2_gemv_fp32_mla_16VL, float, float>(
args); }
137 #ifdef ARM_COMPUTE_ENABLE_BF16
140 "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL",
141 [](
const GemmArgs &
args) {
return args._fast_mode &&
args._ci->has_sme2(); },
142 [](
const GemmArgs &
args) {
const auto VL = sme::get_vector_length<float>();
143 return args._Msize <= VL || (2*VL <
args._Msize &&
args._Msize <= 3*VL); },
144 [](
const GemmArgs &
args) {
return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL, float, float>(
args); }
146 #endif // ARM_COMPUTE_ENABLE_BF16
149 "sme2_interleaved_nomerge_fp32_mopa_1VLx4VL",
150 [](
const GemmArgs &
args) {
return args._ci->has_sme2(); },
151 [](
const GemmArgs &
args) {
const auto VL = sme::get_vector_length<float>();
152 return args._Msize <= VL || (2*VL <
args._Msize &&
args._Msize <= 3*VL); },
153 [](
const GemmArgs &
args) {
return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_1VLx4VL, float, float>(
args); }
155 #ifdef ARM_COMPUTE_ENABLE_BF16
158 "sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL",
159 [](
const GemmArgs &
args) {
return args._fast_mode &&
args._ci->has_sme2(); },
160 [](
const GemmArgs &
args) {
const auto VL = sme::get_vector_length<float>();
161 return args._Nsize <= VL || (2*VL <
args._Nsize &&
args._Nsize <= 3*VL); },
162 [](
const GemmArgs &
args) {
return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL, float, float>(
args); }
164 #endif // ARM_COMPUTE_ENABLE_BF16
167 "sme2_interleaved_nomerge_fp32_mopa_4VLx1VL",
168 [](
const GemmArgs &
args) {
return args._ci->has_sme2(); },
169 [](
const GemmArgs &
args) {
const auto VL = sme::get_vector_length<float>();
170 return args._Nsize <= VL || (2*VL <
args._Nsize &&
args._Nsize <= 3*VL); },
171 [](
const GemmArgs &
args) {
return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_4VLx1VL, float, float>(
args); }
173 #ifdef ARM_COMPUTE_ENABLE_BF16
176 "sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL",
177 [](
const GemmArgs &
args) {
return args._fast_mode &&
args._ci->has_sme2(); },
179 [](
const GemmArgs &
args) {
return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL, float, float>(
args); }
181 #endif // ARM_COMPUTE_ENABLE_BF16
184 "sme2_interleaved_nomerge_fp32_mopa_2VLx2VL",
185 [](
const GemmArgs &
args) {
return args._ci->has_sme2(); },
187 [](
const GemmArgs &
args) {
return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_2VLx2VL, float, float>(
args); }
189 #endif // ARM_COMPUTE_ENABLE_SME2
190 #ifdef ARM_COMPUTE_ENABLE_BF16
193 "sve_interleaved_bf16fp32_mmla_8x3VL",
194 [](
const GemmArgs &
args) {
return args._fast_mode &&
args._ci->has_svebf16(); },
195 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_sve_interleaved_bf16fp32_mmla_8x3VL, float, float>::estimate_cycles<float>(
args); },
196 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_sve_interleaved_bf16fp32_mmla_8x3VL, float, float>(
args); }
200 "sve_hybrid_fp32bf16fp32_mmla_6x4VL",
201 [](
const GemmArgs &
args) {
return args._fast_mode &&
args._ci->has_bf16(); },
202 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_6x4VL, float, float>::estimate_cycles<float>(
args); },
203 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_6x4VL, float, float>(
args); }
207 "sve_hybrid_fp32bf16fp32_mmla_4x6VL",
208 [](
const GemmArgs &
args) {
return args._fast_mode &&
args._ci->has_bf16(); },
209 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_4x6VL, float, float>::estimate_cycles<float>(
args); },
210 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_4x6VL, float, float>(
args); }
213 #ifdef ARM_COMPUTE_ENABLE_SVEF32MM
218 "sve_interleaved_fp32_mmla_8x3VL",
219 [](
const GemmArgs &
args) {
return args._ci->has_svef32mm() && (
args._Ksize>4); },
220 [](
const GemmArgs &
args) {
return !(
args._fast_mode &&
args._ci->has_bf16()); },
221 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_sve_interleaved_fp32_mmla_8x3VL, float, float>(
args); }
223 #endif // ARM_COMPUTE_ENABLE_SVEF32MM
227 "sve_hybrid_fp32_mla_8x1VL",
228 [](
const GemmArgs &
args) {
return args._ci->has_sve(); },
229 [](
const GemmArgs &
args) {
return (
args._Nsize < 12); },
230 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_sve_hybrid_fp32_mla_8x1VL, float, float>(
args); }
234 "sve_hybrid_fp32_mla_6x4VL",
235 [](
const GemmArgs &
args) {
return args._ci->has_sve(); },
236 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_sve_hybrid_fp32_mla_6x4VL, float, float>::estimate_cycles<float>(
args); },
237 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_sve_hybrid_fp32_mla_6x4VL, float, float>(
args); }
241 "sve_interleaved_fp32_mla_8x3VL",
242 [](
const GemmArgs &
args) {
return args._ci->has_sve(); },
243 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_sve_interleaved_fp32_mla_8x3VL, float, float>::estimate_cycles<float>(
args); },
244 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_sve_interleaved_fp32_mla_8x3VL, float, float>(
args); }
246 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
247 #ifdef ARM_COMPUTE_ENABLE_BF16
250 "sve_ffinterleaved_bf16fp32_mmla_8x3VL",
252 [](
const GemmArgs &
args) {
return args._fast_mode &&
args._ci->has_svebf16(); },
253 [](
const GemmArgs &
args) {
return GemmInterleavedFixedFormat<cls_sve_ffinterleaved_bf16fp32_mmla_8x3VL, float, float>::estimate_cycles<float>(
args); },
254 [](
const GemmArgs &
args) {
return new GemmInterleavedFixedFormat<cls_sve_ffinterleaved_bf16fp32_mmla_8x3VL, float, float>(
args); }
258 "sve_ffhybrid_fp32bf16fp32_mmla_4x6VL",
260 [](
const GemmArgs &
args) {
return args._fast_mode &&
args._ci->has_svebf16(); },
261 [](
const GemmArgs &
args) {
return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32bf16fp32_mmla_4x6VL, float, float>::estimate_cycles<float>(
args); },
262 [](
const GemmArgs &
args) {
return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32bf16fp32_mmla_4x6VL, float, float>(
args); }
267 "sve_ffinterleaved_fp32_mla_8x3VL",
269 [](
const GemmArgs &
args) {
return args._ci->has_sve(); },
270 [](
const GemmArgs &
args) {
return GemmInterleavedFixedFormat<cls_sve_ffinterleaved_fp32_mla_8x3VL, float, float>::estimate_cycles<float>(
args); },
271 [](
const GemmArgs &
args) {
return new GemmInterleavedFixedFormat<cls_sve_ffinterleaved_fp32_mla_8x3VL, float, float>(
args); }
275 "sve_ffhybrid_fp32_mla_6x4VL",
277 [](
const GemmArgs &
args) {
return args._ci->has_sve(); },
278 [](
const GemmArgs &
args) {
return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32_mla_6x4VL, float, float>::estimate_cycles<float>(
args); },
279 [](
const GemmArgs &
args) {
return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32_mla_6x4VL, float, float>(
args); }
288 [](
const GemmArgs &
args) {
return args._ci->get_cpu_model() == CPUModel::A35; },
289 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_a64_sgemm_8x6, float, float>(
args); }
294 "a64_smallK_hybrid_fp32_mla_8x4",
295 [](
const GemmArgs &
args) {
return args._Ksize <= 8 && (
args._Nsize % 4)==0 && !
args._indirect_input; },
297 [](
const GemmArgs &
args) {
return new GemmHybrid<cls_a64_smallK_hybrid_fp32_mla_8x4, float, float>(
args); }
301 "a64_smallK_hybrid_fp32_mla_6x4",
302 [](
const GemmArgs &
args) {
return (
args._Ksize > 8 &&
args._Ksize <= 16) && (
args._Nsize % 4)==0 && !
args._indirect_input; },
304 [](
const GemmArgs &
args) {
return new GemmHybrid<cls_a64_smallK_hybrid_fp32_mla_6x4, float, float>(
args); }
308 "a64_hybrid_fp32_mla_8x4",
310 [](
const GemmArgs &
args) {
return (
args._Nsize < 12); },
311 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_a64_hybrid_fp32_mla_8x4, float, float>(
args); }
315 "a64_hybrid_fp32_mla_4x24",
317 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_a64_hybrid_fp32_mla_4x24, float, float>::estimate_cycles<float>(
args); },
318 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_a64_hybrid_fp32_mla_4x24, float, float>(
args); }
322 "a64_hybrid_fp32_mla_6x16",
324 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_a64_hybrid_fp32_mla_6x16, float, float>::estimate_cycles<float>(
args); },
325 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_a64_hybrid_fp32_mla_6x16, float, float>(
args); }
331 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_a64_sgemm_8x12, float, float>::estimate_cycles<float>(
args); },
332 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_a64_sgemm_8x12, float, float>(
args); }
334 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
335 #ifdef ARM_COMPUTE_ENABLE_BF16
339 "a64_ffinterleaved_bf16fp32_mmla_8x12",
341 [](
const GemmArgs &
args) {
return args._fast_mode &&
args._ci->has_bf16(); },
342 [](
const GemmArgs &
args) {
return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_mmla_8x12, float, float>::estimate_cycles<float>(
args); },
343 [](
const GemmArgs &
args) {
return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_mmla_8x12, float, float>(
args); }
347 "a64_ffhybrid_fp32bf16fp32_mmla_4x24",
349 [](
const GemmArgs &
args) {
return args._fast_mode &&
args._ci->has_bf16(); },
350 [](
const GemmArgs &
args) {
return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_4x24, float, float>::estimate_cycles<float>(
args); },
351 [](
const GemmArgs &
args) {
return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_4x24, float, float>(
args); }
356 "a64_ffinterleaved_fp32_mla_8x12",
359 [](
const GemmArgs &
args) {
return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_fp32_mla_8x12, float, float>::estimate_cycles<float>(
args); },
360 [](
const GemmArgs &
args) {
return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_fp32_mla_8x12, float, float>(
args); }
364 "a64_ffhybrid_fp32_mla_6x16",
367 [](
const GemmArgs &
args) {
return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32_mla_6x16, float, float>::estimate_cycles<float>(
args); },
368 [](
const GemmArgs &
args) {
return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32_mla_6x16, float, float>(
args); }
379 [](
const GemmArgs &
args) {
return new GemmInterleaved<sgemm_8x6, float, float>(
args); }
394 return gemm_fp32_methods;