26 #if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC))
38 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
41 #endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
45 #ifdef ARM_COMPUTE_ENABLE_SME2
50 #endif // ARM_COMPUTE_ENABLE_SME2
51 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
54 #endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
// Table of candidate GEMM implementations for __fp16 inputs and __fp16 outputs.
// Each visible entry carries:
//   * a kernel name string;
//   * a support predicate over GemmArgs (CPU features via args._ci->has_sme2(),
//     has_sve(), has_fp16(), plus problem-shape checks on args);
//   * optionally a second lambda that is either a shape predicate (returns
//     bool) or a cycle estimate (estimate_cycles<__fp16>(args));
//   * a factory lambda that heap-allocates (`new`) the chosen GEMM object.
// NOTE(review): the selection policy that consumes this table lives outside
// this file (presumably gemm_implementation.hpp); confirm how entry order and
// the predicates/estimates are used before reordering or editing entries.
60 static const GemmImplementation<__fp16, __fp16> gemm_fp16_methods[] = {
61 #ifdef ARM_COMPUTE_ENABLE_SVE
62 #ifdef ARM_COMPUTE_ENABLE_SME2
// SME2 GEMV kernel: only for a single output row (M == 1), a single batch,
// and non-indirect input.
65 "sme2_gemv_fp16fp32fp16_dot_16VL",
66 [](
const GemmArgs &
args) {
return args._ci->has_sme2() &&
args._Msize==1 &&
args._nbatches==1 && !
args._indirect_input; },
68 [](
const GemmArgs &
args) {
return new GemvPretransposed<cls_sme2_gemv_fp16fp32fp16_dot_16VL, __fp16, __fp16>(
args); }
// SME2 MOPA, 4VLx1VL tile. The shape predicate is true when N <= VL or
// 2*VL < N <= 3*VL, where VL = sme::get_vector_length<float>().
72 "sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL",
73 [](
const GemmArgs &
args) {
return args._ci->has_sme2(); },
74 [](
const GemmArgs &
args) {
const auto VL = sme::get_vector_length<float>();
75 return args._Nsize <= VL || (2*VL <
args._Nsize &&
args._Nsize <= 3*VL); },
76 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL, __fp16, __fp16, Nothing, false, false, false, true>(
args); }
// SME2 MOPA, 1VLx4VL tile. The shape predicate is true when M <= VL or
// 2*VL < M <= 3*VL, where VL = sme::get_vector_length<float>().
80 "sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL",
81 [](
const GemmArgs &
args) {
return args._ci->has_sme2(); },
82 [](
const GemmArgs &
args) {
const auto VL = sme::get_vector_length<float>();
83 return args._Msize <= VL || (2*VL <
args._Msize &&
args._Msize <= 3*VL); },
84 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL, __fp16, __fp16, Nothing, false, false, false, true>(
args); }
// SME2 MOPA, 2VLx2VL tile. The support predicate only requires SME2.
88 "sme2_interleaved_nomerge_fp16fp32fp16_mopa_2VLx2VL",
89 [](
const GemmArgs &
args) {
return args._ci->has_sme2(); },
91 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_2VLx2VL, __fp16, __fp16, Nothing, false, false, false, true>(
args); }
93 #endif // ARM_COMPUTE_ENABLE_SME2
// SVE hybrid kernel: requires has_sve() and supplies a cycle estimate via
// GemmHybridIndirect<...>::estimate_cycles<__fp16>().
96 "sve_hybrid_fp16_mla_6x4VL",
97 [](
const GemmArgs &
args) {
return args._ci->has_sve(); },
98 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_sve_hybrid_fp16_mla_6x4VL, __fp16, __fp16>::estimate_cycles<__fp16>(
args); },
99 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_sve_hybrid_fp16_mla_6x4VL, __fp16, __fp16>(
args); }
// SVE interleaved kernel: only offered when K > 4.
103 "sve_interleaved_fp16_mla_8x3VL",
104 [](
const GemmArgs &
args) {
return args._ci->has_sve() && (
args._Ksize > 4); },
105 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>::estimate_cycles<__fp16>(
args); },
106 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>(
args); }
108 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
// Fixed-format ("ff") SVE variants, built only with
// ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS.
111 "sve_ffinterleaved_fp16_mla_8x3VL",
113 [](
const GemmArgs &
args) {
return args._ci->has_sve(); },
114 [](
const GemmArgs &
args) {
return GemmInterleavedFixedFormat<cls_sve_ffinterleaved_fp16_mla_8x3VL, __fp16, __fp16>::estimate_cycles<__fp16>(
args); },
115 [](
const GemmArgs &
args) {
return new GemmInterleavedFixedFormat<cls_sve_ffinterleaved_fp16_mla_8x3VL, __fp16, __fp16>(
args); }
119 "sve_ffhybrid_fp16_mla_6x4VL",
121 [](
const GemmArgs &
args) {
return args._ci->has_sve(); },
122 [](
const GemmArgs &
args) {
return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>::estimate_cycles<__fp16>(
args); },
123 [](
const GemmArgs &
args) {
return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>(
args); }
// NOTE(review): this file is already wrapped in
// `#if defined(__aarch64__) && (...)`, so the nested __aarch64__ check below
// is always true. The `#elif defined(__arm__)` and `#else #error` branches
// further down can never be compiled from this file; consider removing them.
127 #
if defined(__aarch64__)
// AArch64 kernels requiring FP16 arithmetic support (has_fp16()).
130 "a64_hybrid_fp16_mla_6x32",
131 [](
const GemmArgs &
args) {
return args._ci->has_fp16(); },
132 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_a64_hybrid_fp16_mla_6x32, __fp16, __fp16>::estimate_cycles<__fp16>(
args); },
133 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_a64_hybrid_fp16_mla_6x32, __fp16, __fp16>(
args); }
// Interleaved AArch64 kernel built on cls_a64_hgemm_8x24; also requires
// has_fp16().
138 [](
const GemmArgs &
args) {
return args._ci->has_fp16(); },
139 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>::estimate_cycles<__fp16>(
args); },
140 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>(
args); }
142 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
// Fixed-format ("ff") AArch64 variants; also require has_fp16().
145 "a64_ffinterleaved_fp16_mla_8x24",
147 [](
const GemmArgs &
args) {
return args._ci->has_fp16(); },
148 [](
const GemmArgs &
args) {
return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_fp16_mla_8x24, __fp16, __fp16>::estimate_cycles<__fp16>(
args); },
149 [](
const GemmArgs &
args) {
return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_fp16_mla_8x24, __fp16, __fp16>(
args); }
153 "a64_ffhybrid_fp16_mla_6x32",
155 [](
const GemmArgs &
args) {
return args._ci->has_fp16(); },
156 [](
const GemmArgs &
args) {
return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>::estimate_cycles<__fp16>(
args); },
157 [](
const GemmArgs &
args) {
return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>(
args); }
// Fallback for CPUs that report no FP16 arithmetic (!has_fp16()). It
// instantiates GemmInterleaved with the cls_a64_sgemm_8x12 kernel class for
// __fp16 operands; presumably the kernel class handles any fp16<->fp32
// conversion (confirm there).
164 [](
const GemmArgs &
args) {
return !
args._ci->has_fp16(); },
165 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_a64_sgemm_8x12, __fp16, __fp16>(
args); }
167 #elif defined(__arm__)
// AArch32 path: unreachable here because the file-level guard requires
// __aarch64__.
173 [](
const GemmArgs &
args) {
return new GemmInterleaved<sgemm_8x6, __fp16, __fp16>(
args); }
175 #else // not AArch64 or AArch32
176 # error Unknown Architecture
// Specialization of gemm_implementation_list for __fp16 inputs and __fp16
// outputs: returns a pointer to the first entry of the gemm_fp16_methods
// candidate table.
188 const GemmImplementation<__fp16, __fp16> *gemm_implementation_list<__fp16, __fp16>() {
189 return gemm_fp16_methods;
// Explicit instantiations of the generic GEMM entry points for the
// <__fp16, __fp16, Nothing> combination. They force code for gemm(),
// has_opt_gemm(), get_gemm_method() and get_compatible_kernels() to be
// emitted in this translation unit, so other translation units can link
// against the fp16 versions without instantiating the templates themselves.
// NOTE(review): `Nothing` presumably denotes "no output stage"; confirm
// against the template declarations.
193 template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16, Nothing>(
const GemmArgs &
args,
const Nothing &);
194 template bool has_opt_gemm<__fp16, __fp16, Nothing>(
WeightFormat &weight_format,
const GemmArgs &
args,
const Nothing &);
195 template KernelDescription get_gemm_method<__fp16, __fp16, Nothing>(
const GemmArgs &
args,
const Nothing &);
196 template std::vector<KernelDescription> get_compatible_kernels<__fp16, __fp16, Nothing>(
const GemmArgs &
args,
const Nothing &);
200 #endif // defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC))