24 #ifndef ARM_COMPUTE_CL
25 #error "This example needs to be built with -DARM_COMPUTE_CL"
41 #include "utils/Utils.h"
49 using namespace utils;
63 bool interleave_rhs{
true};
64 bool transpose_rhs{
true};
74 ::std::ostream &
operator<<(::std::ostream &os,
const GemmConfigs &configs)
76 std::string false_str = std::string(
"false");
77 std::string true_str = std::string(
"true");
79 os <<
"m0 : " << configs.m0 << std::endl;
80 os <<
"n0 : " << configs.n0 << std::endl;
81 os <<
"k0 : " << configs.k0 << std::endl;
82 os <<
"h0 : " << configs.h0 << std::endl;
83 os <<
"interleave_rhs : " << (configs.interleave_rhs ? true_str : false_str) << std::endl;
84 os <<
"transpose_rhs : " << (configs.transpose_rhs ? true_str : false_str) << std::endl;
89 class GemmConfigOptions
96 GemmConfigOptions(CommandLineParser &
parser)
97 : m0(
parser.add_positional_option<SimpleOption<size_t>>(
"m0", 4)),
98 n0(
parser.add_positional_option<SimpleOption<size_t>>(
"n0", 4)),
99 k0(
parser.add_positional_option<SimpleOption<size_t>>(
"k0", 4)),
100 h0(
parser.add_positional_option<SimpleOption<size_t>>(
"h0", 1)),
101 interleave_rhs(
parser.add_positional_option<SimpleOption<size_t>>(
"interleave_rhs", 1)),
102 transpose_rhs(
parser.add_positional_option<SimpleOption<size_t>>(
"transpose_rhs", 1))
104 m0->set_help(
"Number of rows processed by the matrix multiplication");
105 n0->set_help(
"Number of columns processed by the matrix multiplication");
106 k0->set_help(
"Number of partial accumulations performed by the matrix multiplication");
107 h0->set_help(
"Number of horizontal blocks of size (k0xn0) stored on the same output row");
108 interleave_rhs->set_help(
"Interleave rhs matrix (1) / Do not interleave rhs matrix (0)");
109 transpose_rhs->set_help(
"Transpose rhs matrix (1) / Do not transpose rhs matrix (0)");
112 GemmConfigOptions(
const GemmConfigOptions &) =
delete;
114 GemmConfigOptions &operator=(
const GemmConfigOptions &) =
delete;
116 GemmConfigOptions(GemmConfigOptions &&) =
default;
118 GemmConfigOptions &operator=(GemmConfigOptions &&) =
default;
120 ~GemmConfigOptions() =
default;
122 SimpleOption<size_t> *m0;
123 SimpleOption<size_t> *n0;
124 SimpleOption<size_t> *k0;
125 SimpleOption<size_t> *h0;
126 SimpleOption<size_t> *interleave_rhs;
127 SimpleOption<size_t> *transpose_rhs;
136 GemmConfigs consume_gemm_configs(
const GemmConfigOptions &options)
139 configs.m0 = options.m0->value();
140 configs.n0 = options.n0->value();
141 configs.k0 = options.k0->value();
142 configs.h0 = options.h0->value();
143 configs.interleave_rhs = options.interleave_rhs->value() != 0;
144 configs.transpose_rhs = options.transpose_rhs->value() != 0;
// Example tuning the OpenCL GEMMLowp "reshaped RHS only" kernel with the fused
// fixed-point output stage.
// NOTE(review): this chunk is a fragment — many original lines are missing
// between the visible ones; comments below only describe what is visible.
154 class CLGEMMLowpMatrixMultiplyReshapedOnlyRHSFusedOutputStageFixedpointExample :
public Example
// Parse command-line arguments, configure kernels and allocate tensors.
157 bool do_setup(
int argc,
char **argv)
override
// Register the tunable GEMM configuration options on the parser.
166 GemmConfigOptions config_options(
parser);
171 parser.print_help(argv[0]);
// On invalid arguments: report, show usage, and fall back to defaults
// rather than aborting.
177 std::cerr <<
"Invalid arguments." << std::endl;
178 parser.print_help(argv[0]);
179 std::cerr <<
"Falling back to default parameters and configs" << std::endl;
184 configs = consume_gemm_configs(config_options);
// Echo the effective parameters and configurations.
187 std::cout <<
"Gemm parameters:" << std::endl;
188 std::cout << params << std::endl;
189 std::cout <<
"Gemm configurations:" << std::endl;
190 std::cout << configs << std::endl;
// Apply the same quantization info to all tensors.
// NOTE(review): q_info is defined on lines not visible here — confirm its scale/offset.
204 lhs.info()->set_quantization_info(q_info);
205 rhs.info()->set_quantization_info(q_info);
206 bias.info()->set_quantization_info(q_info);
207 dst.info()->set_quantization_info(q_info);
// Propagate the tuned block sizes into the LHS/RHS matrix descriptors.
210 lhs_info.
m0 = configs.m0;
211 lhs_info.
k0 = configs.k0;
214 rhs_info.
n0 = configs.n0;
215 rhs_info.
k0 = configs.k0;
216 rhs_info.
h0 = configs.h0;
218 rhs_info.
transpose = configs.transpose_rhs;
// h0 == 0 is a sentinel: derive it from N / n0, clamped to at least 1.
221 if (rhs_info.
h0 == 0)
223 rhs_info.
h0 = std::max(
static_cast<unsigned int>(params.
N) / rhs_info.
n0, 1
U);
// Initialise the reshaped RHS tensor and give it the same quantization info.
226 rhs_reshaped.allocator()->init(
228 rhs_reshaped.info()->set_quantization_info(q_info);
// Fall back gracefully when the device cannot export the RHS to a cl_image.
233 std::cerr <<
"cl_image is not supported on the device, disable export_to_cl_image" << std::endl;
246 const unsigned int num_filters = 1;
// Derive the activation clamp range from the destination data type.
262 std::tie(min_val, max_val) =
get_min_max(
dst.info()->data_type());
264 auto min_activation = min_val.get<int32_t>();
265 auto max_activation = max_val.get<int32_t>();
// Fused output stage: take the offset from the destination quantization info.
268 gemmlowp_output_stage.
gemmlowp_offset =
dst.info()->quantization_info().uniform().offset;
// Problem dimensions M/N/K for the kernel.
274 gemm_info.
m = params.
M;
275 gemm_info.
n = params.
N;
276 gemm_info.
k = params.
K;
// Quantization offsets of the input matrices; a zero offset lets the
// corresponding row/column-sum reduction be skipped below.
285 gemm_info.
a_offset = lhs.info()->quantization_info().uniform().offset;
286 gemm_info.
b_offset = rhs.info()->quantization_info().uniform().offset;
// Row-sum reduction of matrix A (needed when b_offset != 0).
293 vector_sum_row.allocator()->init(info_vector_sum_row);
295 mtx_a_reduction = std::make_unique<ClGemmLowpMatrixAReduction>();
299 std::cerr <<
"Invalid arguments for CLGEMMLowpMatrixAReductionKernel." << std::endl;
// Column-sum vector (needed when a_offset != 0).
309 vector_sum_col.allocator()->init(info_vector_sum_col);
// Validate before configuring: pass nullptr for the reduction vectors
// whose corresponding offset is zero.
314 if (!
gemm.validate(lhs.info(), rhs_reshaped.info(),
dst.info(), gemm_info,
315 gemm_info.
a_offset == 0 ?
nullptr : vector_sum_col.info(),
316 gemm_info.
b_offset == 0 ?
nullptr : vector_sum_row.info(),
bias.info(),
317 dst_multipliers.info(), dst_shifts.info()))
319 std::cerr <<
"Invalid arguments for ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel." << std::endl;
// Configure the kernel with the same argument selection used for validate().
324 gemm.configure(lhs.info(), rhs_reshaped.info(),
dst.info(), gemm_info,
325 gemm_info.
a_offset == 0 ?
nullptr : vector_sum_col.info(),
326 gemm_info.
b_offset == 0 ?
nullptr : vector_sum_row.info(),
bias.info(), dst_multipliers.info(),
// Allocate backing memory for every tensor used by the example.
330 lhs.allocator()->allocate();
331 rhs.allocator()->allocate();
332 rhs_reshaped.allocator()->allocate();
333 bias.allocator()->allocate();
334 dst.allocator()->allocate();
335 vector_sum_col.allocator()->allocate();
336 vector_sum_row.allocator()->allocate();
337 dst_multipliers.allocator()->allocate();
338 dst_shifts.allocator()->allocate();
// Execute one iteration: run the A-reduction first (if it was created),
// then the GEMM kernel (on lines not visible here).
342 void do_run()
override
344 if (mtx_a_reduction !=
nullptr)
347 mtx_a_reduction->run(red_pack);
364 void do_teardown()
override
// Optional reduction operator; only instantiated when b_offset != 0.
380 std::unique_ptr<ClGemmLowpMatrixAReduction> mtx_a_reduction{
nullptr};
388 int main(
int argc,
char **argv)
390 return run_example<CLGEMMLowpMatrixMultiplyReshapedOnlyRHSFusedOutputStageFixedpointExample>(argc, argv);