36 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
40 #endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
47 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
50 #endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
52 #ifdef ARM_COMPUTE_ENABLE_SVE
53 #ifdef ARM_COMPUTE_ENABLE_SME2
58 #endif // ARM_COMPUTE_ENABLE_SME2
66 #endif // ARM_COMPUTE_ENABLE_SVE
70 static const GemmImplementation<bfloat16, float> gemm_bf16_methods[] =
73 #ifdef ARM_COMPUTE_ENABLE_BF16
74 #ifdef ARM_COMPUTE_ENABLE_SVE
75 #ifdef ARM_COMPUTE_ENABLE_SME2
// Table entry for the kernel named "sme2_gemv_bf16fp32_dot_16VL".
79 "sme2_gemv_bf16fp32_dot_16VL",
// Returns true only when the CPU info reports SME2 and the problem is a
// single row (M == 1), a single batch, and does not use indirect input.
80 [](
const GemmArgs &
args) {
return args._ci->has_sme2() &&
args._Msize==1 &&
args._nbatches==1 && !
args._indirect_input; },
// Builds the GemvPretransposed wrapper for this kernel class with a raw `new`.
// NOTE(review): presumably the caller takes ownership of the pointer - confirm.
82 [](
const GemmArgs &
args) {
return new GemvPretransposed<cls_sme2_gemv_bf16fp32_dot_16VL, bfloat16, float>(
args); }
// Table entry for the kernel named "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL".
86 "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL",
// Returns true when the CPU info reports SME2.
87 [](
const GemmArgs &
args) {
return args._ci->has_sme2(); },
// With VL = sme::get_vector_length<float>(), returns true when
// M <= VL, or when 2*VL < M <= 3*VL.
// NOTE(review): presumably a preference heuristic for this tile shape
// based on the M dimension - confirm against GemmImplementation.
88 [](
const GemmArgs &
args) {
const auto VL = sme::get_vector_length<float>();
89 return args._Msize <= VL || (2*VL <
args._Msize &&
args._Msize <= 3*VL); },
// Builds the GemmInterleavedNoMerge wrapper for this kernel class with a raw `new`.
90 [](
const GemmArgs &
args) {
return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL, bfloat16, float>(
args); }
// Table entry for the kernel named "sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL".
94 "sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL",
// Returns true when the CPU info reports SME2.
95 [](
const GemmArgs &
args) {
return args._ci->has_sme2(); },
// With VL = sme::get_vector_length<float>(), returns true when
// N <= VL, or when 2*VL < N <= 3*VL (same test as the 1VLx4VL entry,
// applied to the N dimension instead of M).
// NOTE(review): presumably a preference heuristic - confirm against GemmImplementation.
96 [](
const GemmArgs &
args) {
const auto VL = sme::get_vector_length<float>();
97 return args._Nsize <= VL || (2*VL <
args._Nsize &&
args._Nsize <= 3*VL); },
// Builds the GemmInterleavedNoMerge wrapper for this kernel class with a raw `new`.
98 [](
const GemmArgs &
args) {
return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL, bfloat16, float>(
args); }
// Table entry for the kernel named "sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL".
102 "sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL",
// Returns true when the CPU info reports SME2; no size condition is applied here.
103 [](
const GemmArgs &
args) {
return args._ci->has_sme2(); },
// Builds the GemmInterleavedNoMerge wrapper for this kernel class with a raw `new`.
// NOTE(review): unlike the 1VLx4VL / 4VLx1VL entries, no M/N-based lambda is
// visible for this entry in this view.
105 [](
const GemmArgs &
args) {
return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL, bfloat16, float>(
args); }
107 #endif // ARM_COMPUTE_ENABLE_SME2
// Table entry for the kernel named "sve_interleaved_bf16fp32_mmla_8x3VL".
111 "sve_interleaved_bf16fp32_mmla_8x3VL",
// Returns true when the CPU info reports SVE BF16 support and K > 4.
112 [](
const GemmArgs &
args) {
return args._ci->has_svebf16() && (
args._Ksize>4); },
// Returns the value of GemmInterleaved<...>::estimate_cycles<bfloat16>(args)
// for this kernel class.
113 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_sve_interleaved_bf16fp32_mmla_8x3VL, bfloat16, float>::estimate_cycles<bfloat16>(
args); },
// Builds the GemmInterleaved wrapper for this kernel class with a raw `new`.
114 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_sve_interleaved_bf16fp32_mmla_8x3VL, bfloat16, float>(
args); }
// Table entry for the kernel named "sve_hybrid_bf16fp32_mmla_6x4VL".
118 "sve_hybrid_bf16fp32_mmla_6x4VL",
// Returns true when the CPU info reports SVE BF16 support.
119 [](
const GemmArgs &
args) {
return args._ci->has_svebf16(); },
// Returns the value of GemmHybridIndirect<...>::estimate_cycles<bfloat16>(args)
// for this kernel class.
120 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_sve_hybrid_bf16fp32_mmla_6x4VL, bfloat16, float>::estimate_cycles<bfloat16>(
args); },
// Builds the GemmHybridIndirect wrapper for this kernel class with a raw `new`.
121 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_sve_hybrid_bf16fp32_mmla_6x4VL, bfloat16, float>(
args); }
// Table entry for the kernel named "sve_hybrid_bf16fp32_dot_6x4VL".
125 "sve_hybrid_bf16fp32_dot_6x4VL",
// Returns true when the CPU info reports SVE BF16 support.
126 [](
const GemmArgs &
args) {
return args._ci->has_svebf16(); },
// Returns the value of GemmHybridIndirect<...>::estimate_cycles<bfloat16>(args)
// for this kernel class.
127 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_sve_hybrid_bf16fp32_dot_6x4VL, bfloat16, float>::estimate_cycles<bfloat16>(
args); },
// Builds the GemmHybridIndirect wrapper for this kernel class with a raw `new`.
128 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_sve_hybrid_bf16fp32_dot_6x4VL, bfloat16, float>(
args); }
// Table entry for the kernel named "sve_interleaved_bf16fp32_dot_8x3VL".
132 "sve_interleaved_bf16fp32_dot_8x3VL",
// Returns true when the CPU info reports SVE BF16 support and K > 2
// (the mmla interleaved entry uses K > 4 instead).
133 [](
const GemmArgs &
args) {
return args._ci->has_svebf16() && (
args._Ksize>2); },
// Returns the value of GemmInterleaved<...>::estimate_cycles<bfloat16>(args)
// for this kernel class.
134 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_sve_interleaved_bf16fp32_dot_8x3VL, bfloat16, float>::estimate_cycles<bfloat16>(
args); },
// Builds the GemmInterleaved wrapper for this kernel class with a raw `new`.
135 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_sve_interleaved_bf16fp32_dot_8x3VL, bfloat16, float>(
args); }
137 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
// Table entry for the kernel named "sve_ffinterleaved_bf16fp32_mmla_8x3VL";
// only compiled when ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS is defined.
140 "sve_ffinterleaved_bf16fp32_mmla_8x3VL",
// Returns true when the CPU info reports SVE BF16 support; no K-size
// condition is applied, unlike the non-fixed-format interleaved mmla entry.
142 [](
const GemmArgs &
args) {
return args._ci->has_svebf16(); },
// Returns the value of GemmInterleavedFixedFormat<...>::estimate_cycles<bfloat16>(args)
// for this kernel class.
143 [](
const GemmArgs &
args) {
return GemmInterleavedFixedFormat<cls_sve_ffinterleaved_bf16fp32_mmla_8x3VL, bfloat16, float>::estimate_cycles<bfloat16>(
args); },
// Builds the GemmInterleavedFixedFormat wrapper for this kernel class with a raw `new`.
144 [](
const GemmArgs &
args) {
return new GemmInterleavedFixedFormat<cls_sve_ffinterleaved_bf16fp32_mmla_8x3VL, bfloat16, float>(
args); }
// Table entry for the kernel named "sve_ffhybrid_bf16fp32_mmla_6x4VL".
148 "sve_ffhybrid_bf16fp32_mmla_6x4VL",
// Returns true when the CPU info reports SVE BF16 support.
150 [](
const GemmArgs &
args) {
return args._ci->has_svebf16(); },
// Returns the value of GemmHybridIndirectFixedFormat<...>::estimate_cycles<bfloat16>(args)
// for this kernel class.
151 [](
const GemmArgs &
args) {
return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_bf16fp32_mmla_6x4VL, bfloat16, float>::estimate_cycles<bfloat16>(
args); },
// Builds the GemmHybridIndirectFixedFormat wrapper for this kernel class with a raw `new`.
152 [](
const GemmArgs &
args) {
return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_bf16fp32_mmla_6x4VL, bfloat16, float>(
args); }
// Table entry for the kernel named "a64_hybrid_bf16fp32_mmla_6x16".
158 "a64_hybrid_bf16fp32_mmla_6x16",
// Returns true when the CPU info reports BF16 support (has_bf16(), as opposed
// to the has_svebf16() check used by the sve_* entries).
159 [](
const GemmArgs &
args) {
return args._ci->has_bf16(); },
// Returns the value of GemmHybridIndirect<...>::estimate_cycles<bfloat16>(args)
// for this kernel class.
160 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_a64_hybrid_bf16fp32_mmla_6x16, bfloat16, float>::estimate_cycles<bfloat16>(
args); },
// Builds the GemmHybridIndirect wrapper for this kernel class with a raw `new`.
161 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_a64_hybrid_bf16fp32_mmla_6x16, bfloat16, float>(
args); }
// Table entry for the kernel named "a64_interleaved_bf16fp32_mmla_8x12".
165 "a64_interleaved_bf16fp32_mmla_8x12",
// Returns true when the CPU info reports BF16 support and K > 4.
166 [](
const GemmArgs &
args) {
return args._ci->has_bf16() && (
args._Ksize>4); },
// Returns the value of GemmInterleaved<...>::estimate_cycles<bfloat16>(args)
// for this kernel class.
167 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_a64_interleaved_bf16fp32_mmla_8x12, bfloat16, float>::estimate_cycles<bfloat16>(
args); },
// Builds the GemmInterleaved wrapper for this kernel class with a raw `new`.
168 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_a64_interleaved_bf16fp32_mmla_8x12, bfloat16, float>(
args); }
// Table entry for the kernel named "a64_hybrid_bf16fp32_dot_6x16".
172 "a64_hybrid_bf16fp32_dot_6x16",
// Returns true when the CPU info reports BF16 support.
173 [](
const GemmArgs &
args) {
return args._ci->has_bf16(); },
// Returns the value of GemmHybridIndirect<...>::estimate_cycles<bfloat16>(args)
// for this kernel class.
174 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_a64_hybrid_bf16fp32_dot_6x16, bfloat16, float>::estimate_cycles<bfloat16>(
args); },
// Builds the GemmHybridIndirect wrapper for this kernel class with a raw `new`.
175 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_a64_hybrid_bf16fp32_dot_6x16, bfloat16, float>(
args); }
// Table entry for the kernel named "a64_interleaved_bf16fp32_dot_8x12".
179 "a64_interleaved_bf16fp32_dot_8x12",
// Returns true when the CPU info reports BF16 support and K > 2
// (the mmla interleaved entry uses K > 4 instead).
180 [](
const GemmArgs &
args) {
return args._ci->has_bf16() && (
args._Ksize>2); },
// Returns the value of GemmInterleaved<...>::estimate_cycles<bfloat16>(args)
// for this kernel class.
181 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_a64_interleaved_bf16fp32_dot_8x12, bfloat16, float>::estimate_cycles<bfloat16>(
args); },
// Builds the GemmInterleaved wrapper for this kernel class with a raw `new`.
182 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_a64_interleaved_bf16fp32_dot_8x12, bfloat16, float>(
args); }
184 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
// Table entry for the kernel named "a64_ffinterleaved_bf16fp32_mmla_8x12";
// only compiled when ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS is defined.
187 "a64_ffinterleaved_bf16fp32_mmla_8x12",
// Returns true when the CPU info reports BF16 support; no K-size condition
// is applied, unlike the non-fixed-format interleaved mmla entry.
189 [](
const GemmArgs &
args) {
return args._ci->has_bf16(); },
// Returns the value of GemmInterleavedFixedFormat<...>::estimate_cycles<bfloat16>(args)
// for this kernel class.
190 [](
const GemmArgs &
args) {
return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_mmla_8x12, bfloat16, float>::estimate_cycles<bfloat16>(
args); },
// Builds the GemmInterleavedFixedFormat wrapper for this kernel class with a raw `new`.
191 [](
const GemmArgs &
args) {
return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_mmla_8x12, bfloat16, float>(
args); }
// Table entry for the kernel named "a64_ffhybrid_bf16fp32_mmla_6x16".
195 "a64_ffhybrid_bf16fp32_mmla_6x16",
// Returns true when the CPU info reports BF16 support.
197 [](
const GemmArgs &
args) {
return args._ci->has_bf16(); },
// Returns the value of GemmHybridIndirectFixedFormat<...>::estimate_cycles<bfloat16>(args)
// for this kernel class.
198 [](
const GemmArgs &
args) {
return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_bf16fp32_mmla_6x16, bfloat16, float>::estimate_cycles<bfloat16>(
args); },
// Builds the GemmHybridIndirectFixedFormat wrapper for this kernel class with a raw `new`.
199 [](
const GemmArgs &
args) {
return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_bf16fp32_mmla_6x16, bfloat16, float>(
args); }
// Table entry for the kernel named "a64_ffinterleaved_bf16fp32_dot_8x12".
203 "a64_ffinterleaved_bf16fp32_dot_8x12",
// Returns true when the CPU info reports BF16 support; no K-size condition
// is applied, unlike the non-fixed-format interleaved dot entry.
205 [](
const GemmArgs &
args) {
return args._ci->has_bf16(); },
// Returns the value of GemmInterleavedFixedFormat<...>::estimate_cycles<bfloat16>(args)
// for this kernel class.
206 [](
const GemmArgs &
args) {
return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>::estimate_cycles<bfloat16>(
args); },
// Builds the GemmInterleavedFixedFormat wrapper for this kernel class with a raw `new`.
207 [](
const GemmArgs &
args) {
return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>(
args); }
// Lambdas for the entry that uses the cls_a64_sgemm_8x12 kernel class with
// template arguments bfloat16, float.
// NOTE(review): the entry's name string and any CPU-feature lambda are not
// visible in this view; presumably this is the fallback entry - confirm.
// Returns the value of GemmInterleaved<...>::estimate_cycles<bfloat16>(args)
// for this kernel class.
214 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_a64_sgemm_8x12, bfloat16, float>::estimate_cycles<bfloat16>(
args); },
// Builds the GemmInterleaved wrapper for this kernel class with a raw `new`.
215 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_a64_sgemm_8x12, bfloat16, float>(
args); }
230 return gemm_bf16_methods;