24 #ifndef ARM_COMPUTE_CL
25 #error "This example needs to be built with -DARM_COMPUTE_CL"
42 #include "utils/Utils.h"
48 using namespace utils;
// Boolean layout flags of the GEMM configuration. Every flag is brace-initialised
// to true.
// NOTE(review): the enclosing type declaration and the numeric block-size fields
// (m0, n0, k0, v0, h0 used elsewhere in this file) are not visible in this excerpt.
// True when the lhs matrix is interleaved.
62 bool interleave_lhs{
true};
// True when the lhs matrix is transposed.
63 bool transpose_lhs{
true};
// True when the rhs matrix is interleaved.
64 bool interleave_rhs{
true};
// True when the rhs matrix is transposed.
65 bool transpose_rhs{
true};
// True when the rhs matrix is exported to cl_image.
66 bool export_to_cl_image_rhs{
true};
// Writes a GemmConfigs to an output stream, one "name : value" line per field.
// Numeric fields are streamed directly; boolean fields are printed as the text
// "true" or "false". Every line is terminated with std::endl (which also flushes).
// NOTE(review): the function braces and the return statement are not visible in
// this excerpt.
76 ::std::ostream &
operator<<(::std::ostream &os,
const GemmConfigs &configs)
// Text representations used for the boolean fields below.
78 std::string false_str = std::string(
"false");
79 std::string true_str = std::string(
"true");
// Block-size fields.
81 os <<
"m0 : " << configs.m0 << std::endl;
82 os <<
"n0 : " << configs.n0 << std::endl;
83 os <<
"k0 : " << configs.k0 << std::endl;
84 os <<
"v0 : " << configs.v0 << std::endl;
85 os <<
"h0 : " << configs.h0 << std::endl;
// Layout flags, rendered through true_str / false_str.
86 os <<
"interleave_lhs : " << (configs.interleave_lhs ? true_str : false_str) << std::endl;
87 os <<
"transpose_lhs : " << (configs.transpose_lhs ? true_str : false_str) << std::endl;
88 os <<
"interleave_rhs : " << (configs.interleave_rhs ? true_str : false_str) << std::endl;
89 os <<
"transpose_rhs : " << (configs.transpose_rhs ? true_str : false_str) << std::endl;
90 os <<
"export_to_cl_image_rhs : " << (configs.export_to_cl_image_rhs ? true_str : false_str) << std::endl;
// Command-line options describing a GEMM configuration.
// The constructor registers each option with the supplied parser as a positional
// option and keeps the returned pointer so the parsed value can be read later.
// All options, including the on/off flags, are SimpleOption<size_t>; the flags use
// 1 / 0 as described by their help strings.
95 class GemmConfigOptions
// Registers the options on `parser`, in the order listed in the initialiser list,
// then attaches a help string to each one.
// NOTE(review): the second argument of add_positional_option is presumably the
// option's default value - confirm against SimpleOption.
102 GemmConfigOptions(CommandLineParser &
parser)
103 : m0(
parser.add_positional_option<SimpleOption<size_t>>(
"m0", 4)),
104 n0(
parser.add_positional_option<SimpleOption<size_t>>(
"n0", 4)),
105 k0(
parser.add_positional_option<SimpleOption<size_t>>(
"k0", 4)),
106 v0(
parser.add_positional_option<SimpleOption<size_t>>(
"v0", 1)),
107 h0(
parser.add_positional_option<SimpleOption<size_t>>(
"h0", 1)),
108 interleave_lhs(
parser.add_positional_option<SimpleOption<size_t>>(
"interleave_lhs", 1)),
109 interleave_rhs(
parser.add_positional_option<SimpleOption<size_t>>(
"interleave_rhs", 1)),
// There is no separate transpose_lhs option: per the help text below, the single
// transpose_rhs option selects which of the two matrices is transposed.
110 transpose_rhs(
parser.add_positional_option<SimpleOption<size_t>>(
"transpose_rhs", 1)),
111 export_to_cl_image_rhs(
parser.add_positional_option<SimpleOption<size_t>>(
"export_to_cl_image_rhs", 1))
// Help strings shown for each option.
113 m0->set_help(
"Number of rows processed by the matrix multiplication");
114 n0->set_help(
"Number of columns processed by the matrix multiplication");
115 k0->set_help(
"Number of partial accumulations performed by the matrix multiplication");
116 v0->set_help(
"Number of vertical blocks of size (m0xk0) stored on the same output row");
117 h0->set_help(
"Number of horizontal blocks of size (k0xn0) stored on the same output row");
118 interleave_lhs->set_help(
"Interleave lhs matrix (1) / Do not interleave lhs matrix (0)");
119 interleave_rhs->set_help(
"Interleave rhs matrix (1) / Do not interleave rhs matrix (0)");
123 transpose_rhs->set_help(
"Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do "
124 "transpose lhs matrix (0)");
125 export_to_cl_image_rhs->set_help(
126 "Export rhs matrix to cl_image (1) / Do not export rhs matrix to cl_image (0)");
// Copy construction and copy assignment are disabled.
129 GemmConfigOptions(
const GemmConfigOptions &) =
delete;
131 GemmConfigOptions &operator=(
const GemmConfigOptions &) =
delete;
// Move construction, move assignment and destruction use the compiler defaults.
133 GemmConfigOptions(GemmConfigOptions &&) =
default;
135 GemmConfigOptions &operator=(GemmConfigOptions &&) =
default;
137 ~GemmConfigOptions() =
default;
// Option handles obtained from the parser in the constructor.
// NOTE(review): presumably non-owning (the parser keeps the options alive) - confirm.
139 SimpleOption<size_t> *m0;
140 SimpleOption<size_t> *n0;
141 SimpleOption<size_t> *k0;
142 SimpleOption<size_t> *v0;
143 SimpleOption<size_t> *h0;
144 SimpleOption<size_t> *interleave_lhs;
145 SimpleOption<size_t> *interleave_rhs;
// NOTE(review): the declarator name is missing from the next line in this excerpt;
// presumably transpose_rhs, which the constructor initialises.
149 SimpleOption<size_t> *
151 SimpleOption<size_t> *export_to_cl_image_rhs;
// Converts parsed command-line options into a GemmConfigs value.
// Block sizes are copied as-is; the on/off options are converted to bool, with
// any non-zero value meaning true.
// NOTE(review): the declaration of `configs` and the return statement are not
// visible in this excerpt.
160 GemmConfigs consume_gemm_configs(
const GemmConfigOptions &options)
163 configs.m0 = options.m0->value();
164 configs.n0 = options.n0->value();
165 configs.k0 = options.k0->value();
166 configs.v0 = options.v0->value();
167 configs.h0 = options.h0->value();
168 configs.interleave_lhs = options.interleave_lhs->value() != 0;
// transpose_lhs is derived from the transpose_rhs option rather than having its
// own: it is true exactly when that option is 0, so the two flags are opposites.
172 configs.transpose_lhs = options.transpose_rhs->value() == 0;
173 configs.interleave_rhs = options.interleave_rhs->value() != 0;
174 configs.transpose_rhs = options.transpose_rhs->value() != 0;
175 configs.export_to_cl_image_rhs = options.export_to_cl_image_rhs->value() != 0;
// Example that sets up, validates, configures and runs a GEMM on reshaped lhs/rhs
// tensors, driven by command-line options.
// NOTE(review): member declarations, braces and several statements of this class
// are not visible in this excerpt; the comments below cover only the visible lines.
186 class CLGEMMMatrixMultiplyReshapedExample :
public Example
// Parses the command line, fills the lhs/rhs/kernel info structures from the
// resulting configuration, validates and configures the kernels, and calls
// allocate() on every tensor's allocator.
189 bool do_setup(
int argc,
char **argv)
override
// Scalars passed to gemm.validate() and gemm.configure() below.
192 const float alpha = 1.0f;
193 const float beta = 0.0f;
// Registers the GEMM configuration options on the parser.
201 GemmConfigOptions config_options(
parser);
// Help output. NOTE(review): the condition guarding this call is not visible here.
208 parser.print_help(argv[0]);
// Invalid-argument path: report the problem, print the help, and announce that
// default parameters and configs are used instead.
214 std::cerr <<
"Invalid arguments." << std::endl;
215 parser.print_help(argv[0]);
216 std::cerr <<
"Falling back to default parameters and configs" << std::endl;
// Take the configuration from the parsed options.
// NOTE(review): presumably only reached when the arguments are valid - confirm.
222 configs = consume_gemm_configs(config_options);
// Echo the parameters and configuration in use.
226 std::cout <<
"Gemm parameters:" << std::endl;
227 std::cout << params << std::endl;
228 std::cout <<
"Gemm configurations:" << std::endl;
229 std::cout << configs << std::endl;
// lhs info: block sizes and transpose flag from the configuration.
240 lhs_info.
m0 = configs.m0;
241 lhs_info.
k0 = configs.k0;
242 lhs_info.
v0 = configs.v0;
244 lhs_info.
transpose = configs.transpose_lhs;
// rhs info: block sizes and transpose flag from the configuration.
247 rhs_info.
n0 = configs.n0;
248 rhs_info.
k0 = configs.k0;
249 rhs_info.
h0 = configs.h0;
251 rhs_info.
transpose = configs.transpose_rhs;
// Kernel info: GEMM dimensions taken from params.M / N / K.
255 kernel_info.
m = params.
M;
256 kernel_info.
n = params.
N;
257 kernel_info.
k = params.
K;
// An h0 of 0 is replaced by n / n0, clamped to at least 1.
263 if (rhs_info.
h0 == 0)
265 rhs_info.
h0 = std::max(kernel_info.
n / rhs_info.
n0, 1
U);
// Initialise the reshaped tensors. NOTE(review): the init() arguments are not
// visible in this excerpt.
269 lhs_reshaped.allocator()->init(
273 rhs_reshaped.allocator()->init(
// Message for devices without cl_image support; per its text, export_to_cl_image
// is disabled in that case.
280 std::cerr <<
"cl_image is not supported on the device, disable export_to_cl_image" << std::endl;
// Error report for an earlier check whose condition is not visible here.
291 std::cerr <<
"Unsupported arguments." << std::endl;
292 std::cerr <<
"Check documentation for supported/unsupported combinations" << std::endl;
// Validate the GEMM with the reshaped inputs before configuring it.
296 status =
gemm.validate(lhs_reshaped.info(), rhs_reshaped.info(),
bias.info(),
dst.info(), alpha, beta, lhs_info,
297 rhs_info, kernel_info);
// Error report, presumably emitted when the validation above fails - the
// condition is not visible here.
301 std::cerr <<
"Unsupported arguments." << std::endl;
302 std::cerr <<
"Check documentation for supported/unsupported combinations" << std::endl;
// Configure the lhs reshape and the GEMM with the same arguments that were
// passed to gemm.validate().
307 reshape_lhs.configure(lhs.info(), lhs_reshaped.info(), lhs_info);
310 gemm.configure(lhs_reshaped.info(), rhs_reshaped.info(),
bias.info(),
dst.info(), alpha, beta, lhs_info,
311 rhs_info, kernel_info);
// Call allocate() on the allocator of every tensor used by the example.
314 lhs.allocator()->allocate();
315 rhs.allocator()->allocate();
316 lhs_reshaped.allocator()->allocate();
317 rhs_reshaped.allocator()->allocate();
318 bias.allocator()->allocate();
319 dst.allocator()->allocate();
// Runs the lhs reshape with its tensor pack.
// NOTE(review): any further steps of do_run are not visible in this excerpt.
323 void do_run()
override
327 reshape_lhs.run(reshape_lsh_pack);
// Teardown hook. NOTE(review): the body is not visible in this excerpt.
337 void do_teardown()
override
// Program entry point: forwards argc/argv to run_example, instantiated for
// CLGEMMMatrixMultiplyReshapedExample, and returns its result as the exit status.
358 int main(
int argc,
char **argv)
360 return run_example<CLGEMMMatrixMultiplyReshapedExample>(argc, argv);