42 #ifdef ARM_COMPUTE_ENABLE_SVE
43 #ifdef ARM_COMPUTE_ENABLE_SME2
47 #endif // ARM_COMPUTE_ENABLE_SME2
53 #endif // ARM_COMPUTE_ENABLE_SVE
// -----------------------------------------------------------------------------
// gemm_s8_methods: priority-ordered candidate table of int8-input /
// int32-output GEMM kernels. Each visible entry consists of a kernel name
// string, one or more selector lambdas over GemmArgs (CPU-feature gating
// and, where present, a shape/cycle-estimate preference), and an
// instantiation lambda that allocates the matching strategy wrapper.
//
// NOTE(review): this view is a lossy extraction — the original-file line
// numbers embedded in the text (57, 58, 63, ...) jump, so structural lines
// (per-entry braces, GemmMethod tags, GemmImplementation::with_estimate(...)
// wrappers, some kernel-name strings, and the closing "};") are missing.
// Code below is left byte-identical; only comments have been added.
// -----------------------------------------------------------------------------
57 static const GemmImplementation<int8_t, int32_t> gemm_s8_methods[] = {
58 #ifdef ARM_COMPUTE_ENABLE_SVE
59 #ifdef ARM_COMPUTE_ENABLE_SME2
// --- SME2 MOPA kernel, 1VLx4VL tiling. Gated on the CPU reporting SME2.
63 "sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL",
64 [](
const GemmArgs &
args) {
return args._ci->has_sme2(); },
// Preference: choose the 1VLx4VL shape when M fits in a single vector
// length, or falls in the (2*VL, 3*VL] band.
65 [](
const GemmArgs &
args) {
const auto VL = sme::get_vector_length<int32_t>();
66 return args._Msize <= VL || (2*VL <
args._Msize &&
args._Msize <= 3*VL); },
67 [](
const GemmArgs &
args) {
return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL, int8_t, int32_t>(
args); }
// --- SME2 MOPA kernel, 4VLx1VL tiling. Same SME2 gate; the preference
// heuristic mirrors the one above but tests N instead of M.
71 "sme2_interleaved_nomerge_s8s32_mopa_4VLx1VL",
72 [](
const GemmArgs &
args) {
return args._ci->has_sme2(); },
73 [](
const GemmArgs &
args) {
const auto VL = sme::get_vector_length<int32_t>();
74 return args._Nsize <= VL || (2*VL <
args._Nsize &&
args._Nsize <= 3*VL); },
75 [](
const GemmArgs &
args) {
return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_s8s32_mopa_4VLx1VL, int8_t, int32_t>(
args); }
// --- SME2 MOPA kernel, 2VLx2VL tiling: SME2 gate only, no visible shape
// preference — presumably the general/fallback SME2 tile shape.
79 "sme2_interleaved_nomerge_s8s32_mopa_2VLx2VL",
80 [](
const GemmArgs &
args) {
return args._ci->has_sme2(); },
82 [](
const GemmArgs &
args) {
return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_s8s32_mopa_2VLx2VL, int8_t, int32_t>(
args); }
84 #endif // ARM_COMPUTE_ENABLE_SME2
// --- SVE hybrid kernel using the I8MM (SMMLA) extension; second lambda is
// a cycle estimate used for kernel ranking rather than a hard gate.
87 "sve_hybrid_s8s32_mmla_6x4VL",
88 [](
const GemmArgs &
args) {
return args._ci->has_svei8mm(); },
89 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_sve_hybrid_s8s32_mmla_6x4VL, int8_t, int32_t>::estimate_cycles<int32_t>(
args); },
90 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_sve_hybrid_s8s32_mmla_6x4VL, int8_t, int32_t>(
args); }
// --- SVE interleaved I8MM kernel: additionally requires K > 8.
94 "sve_interleaved_s8s32_mmla_8x3VL",
95 [](
const GemmArgs &
args) {
return args._ci->has_svei8mm() && (
args._Ksize>8); },
96 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_sve_interleaved_s8s32_mmla_8x3VL, int8_t, int32_t>::estimate_cycles<int32_t>(
args); },
97 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_sve_interleaved_s8s32_mmla_8x3VL, int8_t, int32_t>(
args); }
// --- Plain-SVE hybrid dot-product kernel: needs SVE and K >= 16.
101 "sve_hybrid_s8s32_dot_6x4VL",
102 [](
const GemmArgs &
args) {
return args._ci->has_sve() &&
args._Ksize>=16; },
103 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_sve_hybrid_s8s32_dot_6x4VL, int8_t, int32_t>::estimate_cycles<int32_t>(
args); },
104 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_sve_hybrid_s8s32_dot_6x4VL, int8_t, int32_t>(
args); }
// --- Plain-SVE interleaved dot-product kernel: needs SVE and K > 4.
108 "sve_interleaved_s8s32_dot_8x3VL",
109 [](
const GemmArgs &
args) {
return args._ci->has_sve() && (
args._Ksize>4); },
110 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_sve_interleaved_s8s32_dot_8x3VL, int8_t, int32_t>::estimate_cycles<int32_t>(
args); },
111 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_sve_interleaved_s8s32_dot_8x3VL, int8_t, int32_t>(
args); }
// --- Neon (A64) interleaved I8MM kernel: needs the i8mm feature and K > 8.
116 "a64_interleaved_s8s32_mmla_8x12",
117 [](
const GemmArgs &
args) {
return args._ci->has_i8mm() && (
args._Ksize>8); },
118 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_a64_interleaved_s8s32_mmla_8x12, int8_t, int32_t>::estimate_cycles<int32_t>(
args); },
119 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_a64_interleaved_s8s32_mmla_8x12, int8_t, int32_t>(
args); }
// --- Neon hybrid I8MM kernel: i8mm feature gate only.
123 "a64_hybrid_s8s32_mmla_6x16",
124 [](
const GemmArgs &
args) {
return args._ci->has_i8mm(); },
125 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_a64_hybrid_s8s32_mmla_6x16, int8_t, int32_t>::estimate_cycles<int32_t>(
args); },
126 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_a64_hybrid_s8s32_mmla_6x16, int8_t, int32_t>(
args); }
// --- Small-K dot-product specialisation: dotprod CPU, N a multiple of 4,
// K <= 32, and no indirect input. Second lambda makes it preferred only
// when neither SVE-I8MM nor Neon-I8MM is available.
130 "a64_smallK_hybrid_s8s32_dot_8x4",
131 [](
const GemmArgs &
args) {
return args._ci->has_dotprod() && (
args._Nsize % 4 == 0) && (
args._Ksize<=32) && !
args._indirect_input; },
132 [](
const GemmArgs &
args) {
return !(
args._ci->has_svei8mm() ||
args._ci->has_i8mm()); },
133 [](
const GemmArgs &
args) {
return new GemmHybrid<cls_a64_smallK_hybrid_s8s32_dot_8x4, int8_t, int32_t>(
args); }
// --- Companion small-K kernel covering the 32 < K <= 64 band, same gates.
137 "a64_smallK_hybrid_s8s32_dot_6x4",
138 [](
const GemmArgs &
args) {
return args._ci->has_dotprod() && (
args._Nsize % 4 == 0) && (
args._Ksize>32) && (
args._Ksize<=64) && !
args._indirect_input; },
139 [](
const GemmArgs &
args) {
return !(
args._ci->has_svei8mm() ||
args._ci->has_i8mm()); },
140 [](
const GemmArgs &
args) {
return new GemmHybrid<cls_a64_smallK_hybrid_s8s32_dot_6x4, int8_t, int32_t>(
args); }
// --- Cortex-A53 special case (entry's name string is on a missing line):
// prefers the 16-bit kernel when M > 28 or the M remainder mod 8 exceeds 4.
146 [](
const GemmArgs &
args) {
return args._ci->get_cpu_model() == CPUModel::A53 && ((
args._Msize > 28) || ((
args._Msize % 8) > 4)); },
147 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_a64_gemm_s16_8x12, int8_t, int32_t>(
args); },
// --- General Neon hybrid dot-product kernel: dotprod gate only.
152 "a64_hybrid_s8s32_dot_6x16",
153 [](
const GemmArgs &
args) {
return args._ci->has_dotprod(); },
154 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_a64_hybrid_s8s32_dot_6x16, int8_t, int32_t>::estimate_cycles<int32_t>(
args); },
155 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_a64_hybrid_s8s32_dot_6x16, int8_t, int32_t>(
args); }
// --- Interleaved dotprod kernel (cls_a64_gemm_s8_8x12); its name-string
// line is missing from this extraction.
160 [](
const GemmArgs &
args) {
return args._ci->has_dotprod(); },
161 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_a64_gemm_s8_8x12, int8_t, int32_t>::estimate_cycles<int32_t>(
args); },
162 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_a64_gemm_s8_8x12, int8_t, int32_t>(
args); }
// --- Baseline Neon kernel (cls_a64_gemm_s8_4x4): selector/name lines are
// missing here — presumably the unconditional last-resort entry; confirm
// against the upstream file.
168 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_a64_gemm_s8_4x4, int8_t, int32_t>::estimate_cycles<int32_t>(
args); },
169 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_a64_gemm_s8_4x4, int8_t, int32_t>(
args); }
// Specialisation hook used by the generic arm_gemm front end: returns the
// int8/int32 candidate table defined above. NOTE(review): the "template<>"
// introducer and the function's closing brace fall on lines missing from
// this extraction; only the two lines below are visible.
182 const GemmImplementation<int8_t, int32_t> *gemm_implementation_list<int8_t, int32_t>() {
183 return gemm_s8_methods;
// Explicit instantiations of the public int8/int32 (Nothing-activation)
// entry points, so their definitions are emitted in this translation unit.
187 template UniqueGemmCommon<int8_t, int32_t> gemm<int8_t, int32_t, Nothing>(
const GemmArgs &
args,
const Nothing &);
188 template bool has_opt_gemm<int8_t, int32_t, Nothing>(
WeightFormat &weight_format,
const GemmArgs &
args,
const Nothing &);
189 template KernelDescription get_gemm_method<int8_t, int32_t, Nothing>(
const GemmArgs &
args,
const Nothing &);
190 template std::vector<KernelDescription> get_compatible_kernels<int8_t, int32_t, Nothing> (
const GemmArgs &
args,
const Nothing &);
194 #endif // __aarch64__