// Kernel-selection table for uint8_t input / uint32_t accumulator GEMM,
// listed in order of preference (first matching entry wins).
//
// NOTE(review): this text looks like a garbled extraction of the original
// source -- the per-entry wrappers (e.g. "GemmImplementation<...>::with_estimate(",
// "GemmMethod::GEMM_HYBRID,") and several whole lines appear to be missing,
// and stray original line numbers ("49", "50", ...) are embedded in the text.
// Each surviving entry consists of: a kernel name string, a selection
// predicate lambda, an optional cycle-estimate lambda, and an instantiation
// lambda returning a newly allocated GEMM strategy object.  Confirm against
// the upstream file before relying on any of this.
49 static const GemmImplementation<uint8_t, uint32_t> gemm_u8_methods[] = {
50 #ifdef ARM_COMPUTE_ENABLE_SVE
// SVE hybrid kernel using the int8 matrix-multiply (I8MM/MMLA) extension.
// Selected when the CPU reports SVE I8MM support.
53 "sve_hybrid_u8u32_mmla_6x4VL",
54 [](
const GemmArgs &
args) {
return args._ci->has_svei8mm(); },
// Cycle-estimate hook used to rank candidate kernels.
55 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_sve_hybrid_u8u32_mmla_6x4VL, uint8_t, uint32_t>::estimate_cycles<uint32_t>(
args); },
// Instantiation hook: allocates the wrapped strategy object.
56 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_sve_hybrid_u8u32_mmla_6x4VL, uint8_t, uint32_t>(
args); }
// SVE interleaved MMLA kernel; additionally requires K > 8 so the
// interleaving overhead is worthwhile.
60 "sve_interleaved_u8u32_mmla_8x3VL",
61 [](
const GemmArgs &
args) {
return args._ci->has_svei8mm() && (
args._Ksize>8); },
62 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_sve_interleaved_u8u32_mmla_8x3VL, uint8_t, uint32_t>::estimate_cycles<uint32_t>(
args); },
63 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_sve_interleaved_u8u32_mmla_8x3VL, uint8_t, uint32_t>(
args); }
// SVE hybrid dot-product kernel; needs plain SVE only (no I8MM).
67 "sve_hybrid_u8u32_dot_6x4VL",
68 [](
const GemmArgs &
args) {
return args._ci->has_sve(); },
69 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_sve_hybrid_u8u32_dot_6x4VL, uint8_t, uint32_t>::estimate_cycles<uint32_t>(
args); },
70 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_sve_hybrid_u8u32_dot_6x4VL, uint8_t, uint32_t>(
args); }
// SVE interleaved dot-product kernel; requires K > 4.
74 "sve_interleaved_u8u32_dot_8x3VL",
75 [](
const GemmArgs &
args) {
return args._ci->has_sve() && (
args._Ksize>4); },
76 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_sve_interleaved_u8u32_dot_8x3VL, uint8_t, uint32_t>::estimate_cycles<uint32_t>(
args); },
77 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_sve_interleaved_u8u32_dot_8x3VL, uint8_t, uint32_t>(
args); }
// NOTE(review): the "#endif" closing the ARM_COMPUTE_ENABLE_SVE section is
// presumably between the SVE and AArch64-NEON entries but is not visible here.
// AArch64 NEON interleaved I8MM kernel; requires the i8mm feature and K > 8.
82 "a64_interleaved_u8u32_mmla_8x12",
83 [](
const GemmArgs &
args) {
return args._ci->has_i8mm() && (
args._Ksize>8); },
84 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_a64_interleaved_u8u32_mmla_8x12, uint8_t, uint32_t>::estimate_cycles<uint32_t>(
args); },
85 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_a64_interleaved_u8u32_mmla_8x12, uint8_t, uint32_t>(
args); }
// AArch64 NEON hybrid I8MM kernel.
89 "a64_hybrid_u8u32_mmla_6x16",
90 [](
const GemmArgs &
args) {
return args._ci->has_i8mm(); },
91 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_a64_hybrid_u8u32_mmla_6x16, uint8_t, uint32_t>::estimate_cycles<uint32_t>(
args); },
92 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_a64_hybrid_u8u32_mmla_6x16, uint8_t, uint32_t>(
args); }
// Small-K dot-product kernel: needs dotprod, N divisible by 4, K <= 32 and
// no indirect input.  The second lambda here is a preference predicate, not
// a cycle estimate -- it declines this kernel whenever an I8MM kernel
// (SVE or NEON) is available.
96 "a64_smallK_hybrid_u8u32_dot_8x4",
97 [](
const GemmArgs &
args) {
return args._ci->has_dotprod() && (
args._Nsize % 4 == 0) && (
args._Ksize<=32) && !
args._indirect_input; },
98 [](
const GemmArgs &
args) {
return !(
args._ci->has_svei8mm() ||
args._ci->has_i8mm()); },
99 [](
const GemmArgs &
args) {
return new GemmHybrid<cls_a64_smallK_hybrid_u8u32_dot_8x4, uint8_t, uint32_t>(
args); }
// Companion small-K kernel covering 32 < K <= 64; same constraints and
// same "only when no I8MM" preference as the 8x4 variant above.
103 "a64_smallK_hybrid_u8u32_dot_6x4",
104 [](
const GemmArgs &
args) {
return args._ci->has_dotprod() && (
args._Nsize % 4 == 0) && (
args._Ksize>32) && (
args._Ksize<=64) && !
args._indirect_input; },
105 [](
const GemmArgs &
args) {
return !(
args._ci->has_svei8mm() ||
args._ci->has_i8mm()); },
106 [](
const GemmArgs &
args) {
return new GemmHybrid<cls_a64_smallK_hybrid_u8u32_dot_6x4, uint8_t, uint32_t>(
args); }
// Cortex-A53 special case: a 16-bit interleaved kernel (cls_a64_gemm_u16_8x12)
// selected on A53 cores when M > 4.  NOTE(review): the entry's kernel-name
// string and method tag (original lines ~107-111) are missing from this
// extraction.
112 [](
const GemmArgs &
args) {
return args._ci->get_cpu_model() == CPUModel::A53 &&
args._Msize > 4; },
113 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_a64_gemm_u16_8x12, uint8_t, uint32_t>(
args); },
// General AArch64 hybrid dot-product kernel; requires dotprod only.
117 "a64_hybrid_u8u32_dot_6x16",
118 [](
const GemmArgs &
args) {
return args._ci->has_dotprod(); },
119 [](
const GemmArgs &
args) {
return GemmHybridIndirect<cls_a64_hybrid_u8u32_dot_6x16, uint8_t, uint32_t>::estimate_cycles<uint32_t>(
args); },
120 [](
const GemmArgs &
args) {
return new GemmHybridIndirect<cls_a64_hybrid_u8u32_dot_6x16, uint8_t, uint32_t>(
args); }
// Interleaved dot-product kernel (cls_a64_gemm_u8_8x12); dotprod required.
// NOTE(review): the name string for this entry is not visible here.
125 [](
const GemmArgs &
args) {
return args._ci->has_dotprod(); },
126 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_a64_gemm_u8_8x12, uint8_t, uint32_t>::estimate_cycles<uint32_t>(
args); },
127 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_a64_gemm_u8_8x12, uint8_t, uint32_t>(
args); }
// Fallback interleaved kernel (cls_a64_gemm_u8_4x4) with no CPU-feature
// requirement.  NOTE(review): its name string and selection predicate, and
// the table's terminator entry plus closing "};" (original lines ~135-145),
// are missing from this extraction.
133 [](
const GemmArgs &
args) {
return GemmInterleaved<cls_a64_gemm_u8_4x4, uint8_t, uint32_t>::estimate_cycles<uint32_t>(
args); },
134 [](
const GemmArgs &
args) {
return new GemmInterleaved<cls_a64_gemm_u8_4x4, uint8_t, uint32_t>(
args); }
// Specialization hook returning the uint8/uint32 selection table above.
// NOTE(review): the "template<>" prefix and the function's closing "}"
// (original line ~148) are not visible in this extraction -- confirm
// against the upstream file.
146 const GemmImplementation<uint8_t, uint32_t> *gemm_implementation_list<uint8_t, uint32_t>() {
147 return gemm_u8_methods;
// Explicit template instantiations for the uint8/uint32, no-activation
// ("Nothing") GEMM entry points, so their definitions are emitted in this
// translation unit.
// gemm<>: constructs the best available GEMM object for the given args.
151 template UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t, Nothing>(
const GemmArgs &
args,
const Nothing &);
// has_opt_gemm<>: reports whether an optimized kernel exists, returning the
// preferred weight format via the out-parameter.
152 template bool has_opt_gemm<uint8_t, uint32_t, Nothing>(
WeightFormat &weight_format,
const GemmArgs &
args,
const Nothing &);
// get_gemm_method<>: describes the kernel that would be chosen.
153 template KernelDescription get_gemm_method<uint8_t, uint32_t, Nothing>(
const GemmArgs &
args,
const Nothing &);
// get_compatible_kernels<>: lists every kernel whose predicate accepts args.
154 template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint32_t, Nothing> (
const GemmArgs &
args,
const Nothing &);
158 #endif // __aarch64__