39 #ifdef ARM_COMPUTE_ENABLE_SVE
40 #ifdef ARM_COMPUTE_ENABLE_SME2
45 #endif // ARM_COMPUTE_ENABLE_SME2
53 #endif // ARM_COMPUTE_ENABLE_SVE
// Ordered candidate table of uint8->uint8 requantizing (Requantize32) GEMM
// implementations. Earlier entries are preferred; each entry supplies (where
// visible here) a kernel name string, an is-supported / is-preferred lambda
// pair over (GemmArgs, Requantize32), and a factory lambda that news up the
// wrapping strategy object.
// NOTE(review): this chunk is missing interleaved lines from the original
// table (GemmMethod tags and entry braces) — comments below annotate only the
// selection/factory lambdas that are visible. Verify against the full file.
64 static const GemmImplementation<uint8_t, uint8_t, Requantize32> gemm_quint8_methods[] =
66 #ifdef ARM_COMPUTE_ENABLE_SVE
67 #ifdef ARM_COMPUTE_ENABLE_SME2
// --- SME2 kernels (highest priority when hardware supports them) ---
// GEMV kernel for the degenerate single-row case; pretransposed weights.
71 "sme2_gemv_u8qa_dot_16VL",
74 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new GemvPretransposed<cls_sme2_gemv_u8qa_dot_16VL, uint8_t, uint8_t, Requantize32>(
args, qp); }
// MOPA (outer-product) kernel, 1VLx4VL tiling. Supported only with SME2 and
// when requantization needs no left shifts (per-channel shifts absent, or
// per-layer left shift zero) — the inline requant path cannot left-shift.
78 "sme2_interleaved_nomerge_u8q_mopa_1VLx4VL",
79 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts ==
nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
80 [](
const GemmArgs &
args,
const Requantize32 &) {
// Preferred when M falls in VL-sized bands that this tall-thin tiling covers well.
const auto VL = sme::get_vector_length<uint32_t>();
81 return args._Msize <= VL || (2*VL <
args._Msize &&
args._Msize <= 3*VL); },
82 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_u8q_mopa_1VLx4VL, uint8_t, uint8_t>(
args, qp); }
// MOPA kernel, 4VLx1VL tiling: same support condition, but preference is
// keyed on N (short-wide shapes) rather than M.
86 "sme2_interleaved_nomerge_u8q_mopa_4VLx1VL",
87 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts ==
nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
88 [](
const GemmArgs &
args,
const Requantize32 &) {
const auto VL = sme::get_vector_length<int32_t>();
89 return args._Nsize <= VL || (2*VL <
args._Nsize &&
args._Nsize <= 3*VL); },
90 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_u8q_mopa_4VLx1VL, uint8_t, uint8_t>(
args, qp); }
// MOPA kernel, square 2VLx2VL tiling: general SME2 fallback for shapes not
// caught by the two asymmetric tilings above. Same no-left-shift condition.
94 "sme2_interleaved_nomerge_u8q_mopa_2VLx2VL",
95 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts ==
nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
97 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_u8q_mopa_2VLx2VL, uint8_t, uint8_t>(
args, qp); }
99 #endif // ARM_COMPUTE_ENABLE_SME2
// --- SVE kernels ---
// Hybrid MMLA kernel with fused requantization; preference is cycle-estimated.
102 "sve_hybrid_u8qa_mmla_4x4VL",
104 [](
const GemmArgs &
args,
const Requantize32 &) {
return GemmHybridIndirect<cls_sve_hybrid_u8qa_mmla_4x4VL, uint8_t, uint8_t, Requantize32>::estimate_cycles<uint8_t>(
args); },
105 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new GemmHybridIndirect<cls_sve_hybrid_u8qa_mmla_4x4VL, uint8_t, uint8_t, Requantize32>(
args, qp); }
// Interleaved MMLA kernel producing u32 accumulators, requantized afterwards.
// Needs SVE I8MM and K>8 (MMLA consumes 8 K-values per instruction).
109 "sve_interleaved_u8u32_mmla_8x3VL",
110 [](
const GemmArgs &
args,
const Requantize32 &) {
return args._ci->has_svei8mm() && (
args._Ksize>8); },
111 [](
const GemmArgs &
args,
const Requantize32 &) {
return GemmInterleavedQuantized<cls_sve_interleaved_u8u32_mmla_8x3VL, uint8_t, uint8_t>::estimate_cycles<uint8_t>(
args); },
112 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new GemmInterleavedQuantized<cls_sve_interleaved_u8u32_mmla_8x3VL, uint8_t, uint8_t>(
args, qp); }
// Hybrid MMLA u32-accumulator kernel (separate requantize; note the `true`
// template flag on GemmHybridIndirect). Requires SVE I8MM.
116 "sve_hybrid_u8u32_mmla_6x4VL",
117 [](
const GemmArgs &
args,
const Requantize32 &) {
return args._ci->has_svei8mm(); },
118 [](
const GemmArgs &
args,
const Requantize32 &) {
return GemmHybridIndirect<cls_sve_hybrid_u8u32_mmla_6x4VL, uint8_t, uint8_t, Requantize32, true>::estimate_cycles<uint8_t>(
args); },
119 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new GemmHybridIndirect<cls_sve_hybrid_u8u32_mmla_6x4VL, uint8_t, uint8_t, Requantize32, true>(
args, qp); }
// Hybrid dot-product kernel with fused requantization.
123 "sve_hybrid_u8qa_dot_4x4VL",
125 [](
const GemmArgs &
args,
const Requantize32 &) {
return GemmHybridIndirect<cls_sve_hybrid_u8qa_dot_4x4VL, uint8_t, uint8_t, Requantize32>::estimate_cycles<uint8_t>(
args); },
126 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new GemmHybridIndirect<cls_sve_hybrid_u8qa_dot_4x4VL, uint8_t, uint8_t, Requantize32>(
args, qp); }
// Hybrid dot-product u32-accumulator kernel; baseline SVE requirement only.
130 "sve_hybrid_u8u32_dot_6x4VL",
131 [](
const GemmArgs &
args,
const Requantize32 &) {
return args._ci->has_sve(); },
132 [](
const GemmArgs &
args,
const Requantize32 &) {
return GemmHybridIndirect<cls_sve_hybrid_u8u32_dot_6x4VL, uint8_t, uint8_t, Requantize32, true>::estimate_cycles<uint8_t>(
args); },
133 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new GemmHybridIndirect<cls_sve_hybrid_u8u32_dot_6x4VL, uint8_t, uint8_t, Requantize32, true>(
args, qp); }
// Interleaved dot-product kernel; needs SVE and K>4 (dot consumes 4 K-values).
137 "sve_interleaved_u8u32_dot_8x3VL",
138 [](
const GemmArgs &
args,
const Requantize32 &) {
return args._ci->has_sve() && (
args._Ksize>4); },
139 [](
const GemmArgs &
args,
const Requantize32 &) {
return GemmInterleavedQuantized<cls_sve_interleaved_u8u32_dot_8x3VL, uint8_t, uint8_t>::estimate_cycles<uint8_t>(
args); },
140 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new GemmInterleavedQuantized<cls_sve_interleaved_u8u32_dot_8x3VL, uint8_t, uint8_t>(
args, qp); }
// --- AArch64 (Advanced SIMD / NEON) kernels ---
// Hybrid MMLA kernel with fused requantization.
145 "a64_hybrid_u8qa_mmla_4x16",
147 [](
const GemmArgs &
args,
const Requantize32 &) {
return GemmHybridIndirect<cls_a64_hybrid_u8qa_mmla_4x16, uint8_t, uint8_t, Requantize32>::estimate_cycles<uint8_t>(
args); },
148 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new GemmHybridIndirect<cls_a64_hybrid_u8qa_mmla_4x16, uint8_t, uint8_t, Requantize32>(
args, qp); }
// Interleaved MMLA kernel; needs I8MM and K>8.
152 "a64_interleaved_u8u32_mmla_8x12",
153 [](
const GemmArgs &
args,
const Requantize32 &) {
return args._ci->has_i8mm() && (
args._Ksize>8); },
154 [](
const GemmArgs &
args,
const Requantize32 &) {
return GemmInterleavedQuantized<cls_a64_interleaved_u8u32_mmla_8x12, uint8_t, uint8_t>::estimate_cycles<uint8_t>(
args); },
155 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new GemmInterleavedQuantized<cls_a64_interleaved_u8u32_mmla_8x12, uint8_t, uint8_t>(
args, qp); }
// Hybrid MMLA u32-accumulator kernel; needs I8MM.
159 "a64_hybrid_u8u32_mmla_6x16",
160 [](
const GemmArgs &
args,
const Requantize32 &) {
return args._ci->has_i8mm(); },
161 [](
const GemmArgs &
args,
const Requantize32 &) {
return GemmHybridIndirect<cls_a64_hybrid_u8u32_mmla_6x16, uint8_t, uint8_t, Requantize32, true>::estimate_cycles<uint8_t>(
args); },
162 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new GemmHybridIndirect<cls_a64_hybrid_u8u32_mmla_6x16, uint8_t, uint8_t, Requantize32, true>(
args, qp); }
// Small-K specialist (K<=32): needs dotprod, N divisible by 4, and direct
// (non-indirect) input. Only preferred when no MMLA extension is available.
166 "a64_smallK_hybrid_u8u32_dot_8x4",
167 [](
const GemmArgs &
args,
const Requantize32 &) {
return args._ci->has_dotprod() && (
args._Nsize % 4 == 0) && (
args._Ksize<=32) && !
args._indirect_input; },
168 [](
const GemmArgs &
args,
const Requantize32 &) {
return !(
args._ci->has_svei8mm() ||
args._ci->has_i8mm()); },
169 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new GemmHybridQuantized<cls_a64_smallK_hybrid_u8u32_dot_8x4, uint8_t, uint8_t>(
args, qp); }
// Small-K specialist for the next band (32<K<=64); same constraints otherwise.
173 "a64_smallK_hybrid_u8u32_dot_6x4",
174 [](
const GemmArgs &
args,
const Requantize32 &) {
return args._ci->has_dotprod() && (
args._Nsize % 4 == 0) && (
args._Ksize>32) && (
args._Ksize<=64) && !
args._indirect_input; },
175 [](
const GemmArgs &
args,
const Requantize32 &) {
return !(
args._ci->has_svei8mm() ||
args._ci->has_i8mm()); },
176 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new GemmHybridQuantized<cls_a64_smallK_hybrid_u8u32_dot_6x4, uint8_t, uint8_t>(
args, qp); }
// Cortex-A53 workaround path (u16 widening kernel) for M>4.
// NOTE(review): the name/GemmMethod lines of this entry are not visible here.
182 [](
const GemmArgs &
args,
const Requantize32 &) {
return args._ci->get_cpu_model() == CPUModel::A53 &&
args._Msize > 4; },
183 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new GemmInterleavedQuantized<cls_a64_gemm_u16_8x12, uint8_t, uint8_t>(
args, qp); },
// Hybrid dot-product kernel with fused requantization.
187 "a64_hybrid_u8qa_dot_4x16",
189 [](
const GemmArgs &
args,
const Requantize32 &) {
return GemmHybridIndirect<cls_a64_hybrid_u8qa_dot_4x16, uint8_t, uint8_t, Requantize32>::estimate_cycles<uint8_t>(
args); },
190 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new GemmHybridIndirect<cls_a64_hybrid_u8qa_dot_4x16, uint8_t, uint8_t, Requantize32>(
args, qp); }
// Hybrid dot-product u32-accumulator kernel; needs dotprod.
194 "a64_hybrid_u8u32_dot_6x16",
195 [](
const GemmArgs &
args,
const Requantize32 &) {
return args._ci->has_dotprod(); },
196 [](
const GemmArgs &
args,
const Requantize32 &) {
return GemmHybridIndirect<cls_a64_hybrid_u8u32_dot_6x16, uint8_t, uint8_t, Requantize32, true>::estimate_cycles<uint8_t>(
args); },
197 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new GemmHybridIndirect<cls_a64_hybrid_u8u32_dot_6x16, uint8_t, uint8_t, Requantize32, true>(
args, qp); }
// Interleaved dot-product kernel (cls_a64_gemm_u8_8x12); needs dotprod.
// NOTE(review): this entry's name line is not visible in this chunk.
202 [](
const GemmArgs &
args,
const Requantize32 &) {
return args._ci->has_dotprod(); },
203 [](
const GemmArgs &
args,
const Requantize32 &) {
return GemmInterleavedQuantized<cls_a64_gemm_u8_8x12, uint8_t, uint8_t>::estimate_cycles<uint8_t>(
args); },
204 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new GemmInterleavedQuantized<cls_a64_gemm_u8_8x12, uint8_t, uint8_t>(
args, qp); }
// Generic baseline interleaved kernel (cls_a64_gemm_u8_4x4) — no CPU-feature
// gate visible; acts as the always-available native path.
210 [](
const GemmArgs &
args,
const Requantize32 &) {
return GemmInterleavedQuantized<cls_a64_gemm_u8_4x4, uint8_t, uint8_t>::estimate_cycles<uint8_t>(
args); },
211 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new GemmInterleavedQuantized<cls_a64_gemm_u8_4x4, uint8_t, uint8_t>(
args, qp); }
// Last-resort QuantizeWrapper: wraps a non-quantized GEMM and requantizes the
// result. Supported only for direct input; never "preferred" (returns false).
216 [](
const GemmArgs &
args,
const Requantize32 &) {
return !
args._indirect_input; },
217 [](
const GemmArgs &,
const Requantize32 &) {
return false; },
218 [](
const GemmArgs &
args,
const Requantize32 &qp) {
return new QuantizeWrapper<uint8_t, uint8_t, uint32_t>(
args, qp); }
// Template specialization hook: exposes the uint8/Requantize32 method table
// above to the generic arm_gemm implementation-selection machinery.
// NOTE(review): the function's closing brace lies outside this visible chunk.
230 const GemmImplementation<uint8_t, uint8_t, Requantize32> *gemm_implementation_list<uint8_t, uint8_t, Requantize32>() {
231 return gemm_quint8_methods;
// Explicit template instantiations for the uint8/uint8/Requantize32 variant so
// the generic gemm entry points are emitted in this translation unit.
234 template UniqueGemmCommon<uint8_t, uint8_t> gemm<uint8_t, uint8_t, Requantize32>(
const GemmArgs &
args,
const Requantize32 &os);
235 template bool has_opt_gemm<uint8_t, uint8_t, Requantize32>(
WeightFormat &weight_format,
const GemmArgs &
args,
const Requantize32 &os);
236 template KernelDescription get_gemm_method<uint8_t, uint8_t, Requantize32>(
const GemmArgs &
args,
const Requantize32 &os);
237 template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint8_t, Requantize32>(
const GemmArgs &
args,
const Requantize32 &os);
241 #endif // __aarch64__