24 #ifndef ARM_COMPUTE_CL
25 #error "This example needs to be built with -DARM_COMPUTE_CL"
58 #include "utils/Utils.h"
67 using namespace utils;
/** Command-line options for the CL GEMM validation example.
 *
 * Each option is registered on the CommandLineParser passed to the constructor.
 * The class keeps non-owning pointers to the registered option objects.
 * Defaults as registered:
 *   m=7, n=3, k=5, b=1, alpha=1, beta=0,
 *   offset_i0/offset_i1/offset_o=10, scale_i0/scale_i1/scale_o=1/255.
 */
79 class GEMMCommandLineOptions final
/** Register all GEMM options on @p parser and attach their help texts.
 *
 * @param[in,out] parser Parser the options are added to.
 */
82 explicit GEMMCommandLineOptions(CommandLineParser &
parser) noexcept
83 :
help(
parser.add_option<ToggleOption>(
"help")),
84 add_bias(
parser.add_option<ToggleOption>(
"add_bias")),
85 M(
parser.add_option<SimpleOption<int>>(
"m", 7)),
86 N(
parser.add_option<SimpleOption<int>>(
"n", 3)),
87 K(
parser.add_option<SimpleOption<int>>(
"k", 5)),
88 B(
parser.add_option<SimpleOption<int>>(
"b", 1)),
89 alpha(
parser.add_option<SimpleOption<float>>(
"alpha", 1.f)),
90 beta(
parser.add_option<SimpleOption<float>>(
"beta", 0.f)),
91 offset_src0(
parser.add_option<SimpleOption<int>>(
"offset_i0", 10)),
92 offset_src1(
parser.add_option<SimpleOption<int>>(
"offset_i1", 10)),
93 offset_dst(
parser.add_option<SimpleOption<int>>(
"offset_o", 10)),
94 scale_src0(
parser.add_option<SimpleOption<float>>(
"scale_i0", 1.f / 255)),
95 scale_src1(
parser.add_option<SimpleOption<float>>(
"scale_i1", 1.f / 255)),
96 scale_dst(
parser.add_option<SimpleOption<float>>(
"scale_o", 1.f / 255)),
// NOTE(review): the initializer of this set is not visible in this excerpt.
// Presumably it lists the data types accepted by the `data_type` EnumOption; confirm.
100 const std::set<arm_compute::DataType> supported_data_types
// Human-readable descriptions for each registered option.
109 help->set_help(
"Show this help message");
110 add_bias->set_help(
"Add bias to the GEMM. Used when running in QASYMM8");
111 M->set_help(
"M value");
112 N->set_help(
"N value");
113 K->set_help(
"K value");
114 B->set_help(
"B value - number of batches");
115 alpha->set_help(
"Alpha value");
116 beta->set_help(
"Beta value");
117 offset_src0->set_help(
"Offset of first input. Used when running in QASYMM8");
118 offset_src1->set_help(
"Offset of second input. Used when running in QASYMM8");
119 offset_dst->set_help(
"Offset of output. Used when running in QASYMM8");
120 scale_src0->set_help(
"Scale of first input. Used when running in QASYMM8");
121 scale_src1->set_help(
"Scale of second input. Used when running in QASYMM8");
122 scale_dst->set_help(
"Scale of output. Used when running in QASYMM8");
/** Copy operations are deleted. */
126 GEMMCommandLineOptions(
const GEMMCommandLineOptions &) =
delete;
128 GEMMCommandLineOptions &operator=(
const GEMMCommandLineOptions &) =
delete;
/** Move operations are defaulted and noexcept. */
130 GEMMCommandLineOptions(GEMMCommandLineOptions &&) noexcept(
true) =
default;
132 GEMMCommandLineOptions &operator=(GEMMCommandLineOptions &&) noexcept(
true) =
default;
/** Default destructor. The option pointers are not deleted here. */
134 ~GEMMCommandLineOptions() =
default;
// Non-owning pointers to the options registered on the parser.
// Presumably the parser owns the option objects; confirm against CommandLineParser.
138 ToggleOption *add_bias;
// GEMM dimensions: M, N, K and B (number of batches).
139 SimpleOption<int> *
M;
140 SimpleOption<int> *
N;
141 SimpleOption<int> *
K;
142 SimpleOption<int> *
B;
// Scaling factors passed to the floating-point GEMM.
143 SimpleOption<float> *alpha;
144 SimpleOption<float> *beta;
// Quantization offsets and scales for the two inputs and the output (QASYMM8 path).
145 SimpleOption<int> *offset_src0;
146 SimpleOption<int> *offset_src1;
147 SimpleOption<int> *offset_dst;
148 SimpleOption<float> *scale_src0;
149 SimpleOption<float> *scale_src1;
150 SimpleOption<float> *scale_dst;
// Data type the example runs in.
151 EnumOption<arm_compute::DataType> *
data_type;
/** Validation example that runs GEMM on OpenCL and checks it against a host reference.
 *
 * - Floating-point path: configures `mm_gemm` with (src0, src1, src2, alpha, beta).
 * - Quantized path: chains `mm_gemmlowp` (integer matrix multiply into `tmp_dst`)
 *   with `mm_gemmlowp_output_stage` (requantization into `dst`).
 */
155 class CLGEMMValidateExample :
public ValidateExample
/** Parse options, then configure and allocate tensors and functions.
 *
 * @param[in] argc Argument count.
 * @param[in] argv Argument vector. argv[0] is passed to print_help().
 *
 * @return Return statements are not visible in this excerpt.
 */
158 bool do_setup(
int argc,
char **argv)
override
// Register the GEMM-specific options on `parser`.
164 GEMMCommandLineOptions gemm_options(
parser);
// Help is printed only when --help was explicitly set and is true.
168 const bool print_help = gemm_options.help->is_set() ? gemm_options.help->value() :
false;
171 parser.print_help(argv[0]);
// Copy the parsed values into member fields, then print them.
176 consume_params(gemm_options);
177 print_parameters_internal();
// Real-valued requantization factor: (scale_src0 * scale_src1) / scale_dst.
// Presumably split into dst_multiplier / dst_shift in lines not shown here; confirm.
184 float multiplier = scale_src0 * scale_src1 / scale_dst;
// Attach the user-supplied scale/offset to each quantized input.
197 src0.info()->set_quantization_info(
QuantizationInfo(scale_src0, offset_src0));
198 src1.info()->set_quantization_info(
QuantizationInfo(scale_src1, offset_src1));
// Integer matrix multiply of src0 x src1 into the intermediate tmp_dst.
// The third argument is passed as nullptr.
204 mm_gemmlowp.configure(&src0, &src1,
nullptr, &tmp_dst);
// Output-stage parameters: fixed-point shift and output offset.
209 gemm_info.gemmlowp_shift = dst_shift;
210 gemm_info.gemmlowp_offset = offset_dst;
// Requantize tmp_dst into dst. The bias tensor is passed only when add_bias is true.
211 mm_gemmlowp_output_stage.configure(&tmp_dst, add_bias ? &biases :
nullptr, &
dst, gemm_info);
212 tmp_dst.allocator()->allocate();
// NOTE(review): in the visible lines, biases is allocated even when add_bias is false.
// Confirm that this is intended or that a guard exists in lines not shown.
213 biases.allocator()->allocate();
// Floating-point GEMM with alpha/beta scaling and src2 as the third operand.
// Presumably dst = alpha * (src0 x src1) + beta * src2, per CLGEMM semantics; confirm.
219 mm_gemm.configure(&src0, &src1, &src2, &
dst, alpha, beta);
// Allocate backing memory once the functions are configured.
223 src0.allocator()->allocate();
224 src1.allocator()->allocate();
225 dst.allocator()->allocate();
226 src2.allocator()->allocate();
/** Print the current parameters. Called from do_setup() after consume_params().
 *  The body is not visible in this excerpt.
 */
235 void print_parameters_internal()
/** Compute host reference results for comparison with the device output.
 *  The comparison itself is not visible in this excerpt.
 */
259 void do_validate()
override
// Half-precision reference GEMM.
273 SimpleTensor<half> ref_dst = reference::gemm<half>(ref_src0, ref_src1, ref_src2, alpha, beta);
// Single-precision reference GEMM.
287 SimpleTensor<float> ref_dst = reference::gemm<float>(ref_src0, ref_src1, ref_src2, alpha, beta);
// The reference quantize-down helper takes per-channel vectors.
// Here each holds a single (per-tensor) value.
303 const std::vector<int32_t> dst_multiplier_vec = { dst_multiplier };
304 const std::vector<int32_t> dst_shift_vec = { dst_shift };
// Reference requantization with bias.
// Presumably selected when add_bias is true; the branch condition is not visible.
311 ref_dst = reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, uint8_t>(ref_tmp_dst, biases, dst_multiplier_vec, dst_shift_vec, offset_dst);
// Reference requantization without bias.
315 ref_dst = reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, uint8_t>(ref_tmp_dst, dst_multiplier_vec, dst_shift_vec, offset_dst);
/** Run the configured device functions. Only the output-stage run is visible here. */
324 void do_run()
override
332 mm_gemmlowp_output_stage.run();
/** Fill @p tensor with random values; the distribution is chosen by its data type. */
345 template <
typename U>
348 switch(
tensor.data_type())
// Uniform real values in [-1, 1). Presumably used for the floating-point types; confirm.
358 std::uniform_real_distribution<float>
distribution(-1.0f, 1.0f);
// Uniform integers in [-6000, 6000]. The matching data-type case is not visible here.
365 std::uniform_int_distribution<>
distribution(-6000, 6000);
/** Copy parsed option values from @p opts into this object's members.
 *
 * @param[in] opts Options previously registered on the parser.
 */
374 void consume_params(
const GEMMCommandLineOptions &opts)
384 alpha = opts.alpha->value();
385 beta = opts.beta->value();
386 offset_src0 = opts.offset_src0->value();
387 offset_src1 = opts.offset_src1->value();
388 offset_dst = opts.offset_dst->value();
389 scale_src0 = opts.scale_src0->value();
390 scale_src1 = opts.scale_src1->value();
391 scale_dst = opts.scale_dst->value();
// add_bias defaults to true when --add_bias is not given on the command line.
392 add_bias = opts.add_bias->is_set() ? opts.add_bias->value() :
true;
// GEMM dimensions. The defaults match the CLI defaults.
// NOTE(review): these are size_t, but the CLI options are SimpleOption<int>.
// The assignments are not visible here; if they copy the value directly, a negative
// CLI value would wrap. Confirm.
403 size_t M{ 7 },
N{ 3 },
K{ 5 },
B{ 1 };
// GEMM scaling factors.
405 float alpha{ 1.0 }, beta{ 0.0 };
// Quantization offsets and scales (QASYMM8 path).
406 int offset_src0{ 10 }, offset_src1{ 10 }, offset_dst{ 10 };
407 float scale_src0{ 1.0f / 255 }, scale_src1{ 1.0f / 255 }, scale_dst{ 1.0f / 255 };
// Fixed-point output-stage parameters, used by both the device output stage and the reference.
408 int32_t dst_multiplier{ 0 }, dst_shift{ 0 };
// Whether the output stage receives a bias tensor.
409 bool add_bias{
true };
/** Program entry point: runs CLGEMMValidateExample through utils::run_example.
 *
 * @param[in] argc Argument count.
 * @param[in] argv Argument vector.
 *
 * @return Whatever utils::run_example returns.
 */
418 int main(
int argc,
char **argv)
420 return utils::run_example<CLGEMMValidateExample>(argc, argv);