24 #ifndef ARM_COMPUTE_CL
25 #error "This example needs to be built with -DARM_COMPUTE_CL"
42 #include "utils/Utils.h"
48 using namespace utils;
/** Flag for interleaving the lhs matrix; enabled by default. */
bool interleave_lhs{true};
/** Flag for transposing the lhs matrix; enabled by default. */
bool transpose_lhs{true};
/** Flag for interleaving the rhs matrix; enabled by default. */
bool interleave_rhs{true};
/** Flag for transposing the rhs matrix; enabled by default. */
bool transpose_rhs{true};
75 ::std::ostream &
operator<<(::std::ostream &os,
const GemmConfigs &configs)
77 std::string false_str = std::string(
"false");
78 std::string true_str = std::string(
"true");
80 os <<
"m0 : " << configs.m0 << std::endl;
81 os <<
"n0 : " << configs.n0 << std::endl;
82 os <<
"k0 : " << configs.k0 << std::endl;
83 os <<
"v0 : " << configs.v0 << std::endl;
84 os <<
"h0 : " << configs.h0 << std::endl;
85 os <<
"interleave_lhs : " << (configs.interleave_lhs ? true_str : false_str) << std::endl;
86 os <<
"transpose_lhs : " << (configs.transpose_lhs ? true_str : false_str) << std::endl;
87 os <<
"interleave_rhs : " << (configs.interleave_rhs ? true_str : false_str) << std::endl;
88 os <<
"transpose_rhs : " << (configs.transpose_rhs ? true_str : false_str) << std::endl;
93 class GemmConfigOptions
100 GemmConfigOptions(CommandLineParser &
parser)
101 : m0(
parser.add_positional_option<SimpleOption<size_t>>(
"m0", 4)),
102 n0(
parser.add_positional_option<SimpleOption<size_t>>(
"n0", 4)),
103 k0(
parser.add_positional_option<SimpleOption<size_t>>(
"k0", 4)),
104 v0(
parser.add_positional_option<SimpleOption<size_t>>(
"v0", 1)),
105 h0(
parser.add_positional_option<SimpleOption<size_t>>(
"h0", 1)),
106 interleave_lhs(
parser.add_positional_option<SimpleOption<size_t>>(
"interleave_lhs", 1)),
107 interleave_rhs(
parser.add_positional_option<SimpleOption<size_t>>(
"interleave_rhs", 1)),
108 transpose_rhs(
parser.add_positional_option<SimpleOption<size_t>>(
"transpose_rhs", 1))
110 m0->set_help(
"Number of rows processed by the matrix multiplication");
111 n0->set_help(
"Number of columns processed by the matrix multiplication");
112 k0->set_help(
"Number of partial accumulations performed by the matrix multiplication");
113 v0->set_help(
"Number of vertical blocks of size (m0xk0) stored on the same output row");
114 h0->set_help(
"Number of horizontal blocks of size (k0xn0) stored on the same output row");
115 interleave_lhs->set_help(
"Interleave lhs matrix (1) / Do not interleave lhs matrix (0)");
116 interleave_rhs->set_help(
"Interleave rhs matrix (1) / Do not interleave rhs matrix (0)");
120 transpose_rhs->set_help(
"Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do "
121 "transpose lhs matrix (0)");
124 GemmConfigOptions(
const GemmConfigOptions &) =
delete;
126 GemmConfigOptions &operator=(
const GemmConfigOptions &) =
delete;
128 GemmConfigOptions(GemmConfigOptions &&) =
default;
130 GemmConfigOptions &operator=(GemmConfigOptions &&) =
default;
132 ~GemmConfigOptions() =
default;
134 SimpleOption<size_t> *m0;
135 SimpleOption<size_t> *n0;
136 SimpleOption<size_t> *k0;
137 SimpleOption<size_t> *v0;
138 SimpleOption<size_t> *h0;
139 SimpleOption<size_t> *interleave_lhs;
140 SimpleOption<size_t> *interleave_rhs;
144 SimpleOption<size_t> *
154 GemmConfigs consume_gemm_configs(
const GemmConfigOptions &options)
157 configs.m0 = options.m0->value();
158 configs.n0 = options.n0->value();
159 configs.k0 = options.k0->value();
160 configs.v0 = options.v0->value();
161 configs.h0 = options.h0->value();
162 configs.interleave_lhs = options.interleave_lhs->value() != 0;
166 configs.transpose_lhs = options.transpose_rhs->value() == 0;
167 configs.interleave_rhs = options.interleave_rhs->value() != 0;
168 configs.transpose_rhs = options.transpose_rhs->value() != 0;
// Example class deriving from Example. The overrides visible here are do_setup(),
// do_run() and do_teardown().
// NOTE(review): this listing looks like a lossy extract - statements carry a stray
// leading number, and the braces, guarding conditions and several declarations
// (parser, params, configs, q_info, tensors, kernels, ...) are not visible. Confirm
// every detail against the complete file before relying on these notes.
177 class CLGEMMLowpMatrixMultiplyReshapedExample :
public Example
// Setup step. The visible statements: register the gemm config options on the parser,
// read the configs, print parameters and configs, set quantization info on the tensors,
// fill lhs_info / rhs_info from the configs, validate and configure reshape_lhs and
// gemm, and finally allocate the tensors.
180 bool do_setup(
int argc,
char **argv)
override
// Register the gemm configuration options on the parser.
189 GemmConfigOptions config_options(
parser);
// Print the usage text (the condition guarding this call is not visible).
194 parser.print_help(argv[0]);
// Invalid-argument path: report, print usage and announce the fallback to defaults
// (the condition guarding this path is not visible).
200 std::cerr <<
"Invalid arguments." << std::endl;
201 parser.print_help(argv[0]);
202 std::cerr <<
"Falling back to default parameters and configs" << std::endl;
// Read the gemm configs from the registered options.
207 configs = consume_gemm_configs(config_options);
// Echo the parameters and configurations in use.
210 std::cout <<
"Gemm parameters:" << std::endl;
211 std::cout << params << std::endl;
212 std::cout <<
"Gemm configurations:" << std::endl;
213 std::cout << configs << std::endl;
// Apply the same quantization info to the lhs, rhs and dst tensors.
224 lhs.info()->set_quantization_info(q_info);
225 rhs.info()->set_quantization_info(q_info);
226 dst.info()->set_quantization_info(q_info);
// Copy the lhs block configuration from the configs.
229 lhs_info.
m0 = configs.m0;
230 lhs_info.
k0 = configs.k0;
231 lhs_info.
v0 = configs.v0;
233 lhs_info.
transpose = configs.transpose_lhs;
// Copy the rhs block configuration from the configs.
236 rhs_info.
n0 = configs.n0;
237 rhs_info.
k0 = configs.k0;
238 rhs_info.
h0 = configs.h0;
240 rhs_info.
transpose = configs.transpose_rhs;
// When h0 is 0, replace it with max(N / n0, 1).
// NOTE(review): the second std::max argument is split across two lines as `1` / `U`;
// presumably the literal `1U` - confirm against the complete file.
243 if (rhs_info.
h0 == 0)
245 rhs_info.
h0 = std::max(
static_cast<unsigned int>(params.
N) / rhs_info.
n0, 1
U);
// Initialise the reshaped tensors (the init arguments are not visible here), then
// give them the same quantization info as the other tensors.
248 lhs_reshaped.allocator()->init(
250 rhs_reshaped.allocator()->init(
252 lhs_reshaped.info()->set_quantization_info(q_info);
253 rhs_reshaped.info()->set_quantization_info(q_info);
// Message for the case where cl_image export is unsupported (condition not visible).
259 std::cerr <<
"cl_image is not supported on the device, disable export_to_cl_image" << std::endl;
// Argument list fragment: N, K, h0 and v0 converted to int. The enclosing expression is
// not visible - presumably the initialisation of gemm_info; confirm.
265 static_cast<int>(params.
N),
266 static_cast<int>(params.
K),
267 static_cast<int>(configs.h0),
268 static_cast<int>(configs.v0),
// Validate the lhs reshape arguments and report a failure on std::cerr.
274 if (!reshape_lhs.validate(lhs.info(), lhs_reshaped.info(), lhs_info, gemm_info.reinterpret_input_as_3d()))
276 std::cerr <<
"Invalid arguments for ClGemmReshapeLHSMatrixKernel." << std::endl;
// Validate the matrix multiply arguments and report a failure on std::cerr.
280 if (!
gemm.validate(lhs_reshaped.info(), rhs_reshaped.info(),
dst.info(), lhs_info, rhs_info, gemm_info))
282 std::cerr <<
"Invalid arguments for ClGemmLowpMatrixMultiplyReshapedKernel." << std::endl;
// Configure both operations with the same arguments that were validated above.
287 reshape_lhs.configure(lhs.info(), lhs_reshaped.info(), lhs_info);
289 gemm.configure(lhs_reshaped.info(), rhs_reshaped.info(),
dst.info(), lhs_info, rhs_info, gemm_info);
// Allocate all five tensors.
292 lhs.allocator()->allocate();
293 rhs.allocator()->allocate();
294 lhs_reshaped.allocator()->allocate();
295 rhs_reshaped.allocator()->allocate();
296 dst.allocator()->allocate();
// Run step. Only the lhs reshape call is visible; reshape_lsh_pack is declared outside
// the visible lines.
300 void do_run()
override
303 reshape_lhs.run(reshape_lsh_pack);
// Teardown step; its body is not visible in this listing.
312 void do_teardown()
override
332 int main(
int argc,
char **argv)
334 return run_example<CLGEMMLowpMatrixMultiplyReshapedExample>(argc, argv);