24 #ifndef ARM_COMPUTE_CL
25 #error "This example needs to be built with -DARM_COMPUTE_CL"
39 #include "utils/Utils.h"
47 using namespace utils;
// Data members of the GEMM configuration record (the enclosing declaration and
// any earlier members are not visible in this excerpt).
// Whether the rhs matrix is interleaved; initialised to true.
60 bool interleave_rhs{
true};
// Whether the rhs matrix is transposed; initialised to true.
61 bool transpose_rhs{
true};
// Whether the rhs matrix is exported to cl_image; initialised to true.
62 bool export_to_cl_image_rhs{
true};
/** Stream-insertion operator for GemmConfigs.
 *
 * Writes every configuration field to @p os as a "name : value" line, each
 * terminated with std::endl (so the stream is flushed after every field).
 *
 * @param os      Output stream to write to.
 * @param configs Configuration whose fields are printed.
 *
 * @return Declared to return a stream reference; presumably @p os itself --
 *         the return statement is not visible in this excerpt.
 */
72 ::std::ostream &
operator<<(::std::ostream &os,
const GemmConfigs &configs)
// Textual spellings used for the three boolean fields below, so they are
// printed as words rather than via the stream's default bool formatting.
74 std::string false_str = std::string(
"false");
75 std::string true_str = std::string(
"true");
// Block-size fields are streamed as-is.
77 os <<
"m0 : " << configs.m0 << std::endl;
78 os <<
"n0 : " << configs.n0 << std::endl;
79 os <<
"k0 : " << configs.k0 << std::endl;
80 os <<
"h0 : " << configs.h0 << std::endl;
// Boolean fields are mapped to "true"/"false".
81 os <<
"interleave_rhs : " << (configs.interleave_rhs ? true_str : false_str) << std::endl;
82 os <<
"transpose_rhs : " << (configs.transpose_rhs ? true_str : false_str) << std::endl;
83 os <<
"export_to_cl_image_rhs : " << (configs.export_to_cl_image_rhs ? true_str : false_str) << std::endl;
/** Command-line options describing the GEMM configuration.
 *
 * Construction registers seven positional options on a CommandLineParser and
 * keeps a pointer to each so their values can be read later.
 * The class is move-only: copying is deleted, moving is defaulted.
 */
88 class GemmConfigOptions
/** Register the configuration options on @p parser and attach help text.
 *
 * Options are positional and registered in this order, with these defaults:
 * m0 = 4, n0 = 4, k0 = 4, h0 = 1, interleave_rhs = 1, transpose_rhs = 1,
 * export_to_cl_image_rhs = 1.
 *
 * @param parser Parser the options are added to.
 */
95 GemmConfigOptions(CommandLineParser &
parser)
96 : m0(
parser.add_positional_option<SimpleOption<size_t>>(
"m0", 4)),
97 n0(
parser.add_positional_option<SimpleOption<size_t>>(
"n0", 4)),
98 k0(
parser.add_positional_option<SimpleOption<size_t>>(
"k0", 4)),
99 h0(
parser.add_positional_option<SimpleOption<size_t>>(
"h0", 1)),
// The three on/off switches are registered as size_t options; per their help
// text below, 1 enables and 0 disables.
100 interleave_rhs(
parser.add_positional_option<SimpleOption<size_t>>(
"interleave_rhs", 1)),
101 transpose_rhs(
parser.add_positional_option<SimpleOption<size_t>>(
"transpose_rhs", 1)),
102 export_to_cl_image_rhs(
parser.add_positional_option<SimpleOption<size_t>>(
"export_to_cl_image_rhs", 1))
// Help strings shown to the user for each option.
104 m0->set_help(
"Number of rows processed by the matrix multiplication");
105 n0->set_help(
"Number of columns processed by the matrix multiplication");
106 k0->set_help(
"Number of partial accumulations performed by the matrix multiplication");
107 h0->set_help(
"Number of horizontal blocks of size (k0xn0) stored on the same output row");
108 interleave_rhs->set_help(
"Interleave rhs matrix (1) / Do not interleave rhs matrix (0)");
109 transpose_rhs->set_help(
"Transpose rhs matrix (1) / Do not transpose rhs matrix (0)");
110 export_to_cl_image_rhs->set_help(
111 "Export rhs matrix to cl_image (1) / Do not export rhs matrix to cl_image (0)");
/** Copy construction is disabled. */
114 GemmConfigOptions(
const GemmConfigOptions &) =
delete;
/** Copy assignment is disabled. */
116 GemmConfigOptions &operator=(
const GemmConfigOptions &) =
delete;
/** Default move constructor. */
118 GemmConfigOptions(GemmConfigOptions &&) =
default;
/** Default move assignment. */
120 GemmConfigOptions &operator=(GemmConfigOptions &&) =
default;
/** Default destructor; the option pointers below are not deleted here. */
122 ~GemmConfigOptions() =
default;
// Pointers to the options returned by the parser in the constructor.
// NOTE(review): presumably the parser owns these objects and must outlive this
// instance -- confirm against CommandLineParser.
124 SimpleOption<size_t> *m0;
125 SimpleOption<size_t> *n0;
126 SimpleOption<size_t> *k0;
127 SimpleOption<size_t> *h0;
128 SimpleOption<size_t> *interleave_rhs;
129 SimpleOption<size_t> *transpose_rhs;
130 SimpleOption<size_t> *export_to_cl_image_rhs;
/** Build a GemmConfigs from the values held by the command-line options.
 *
 * @param options Options whose current values are read.
 *
 * @return A GemmConfigs; presumably the `configs` object populated below --
 *         its declaration and the return statement are not visible in this
 *         excerpt.
 */
139 GemmConfigs consume_gemm_configs(
const GemmConfigOptions &options)
// Block sizes are copied straight from the option values.
142 configs.m0 = options.m0->value();
143 configs.n0 = options.n0->value();
144 configs.k0 = options.k0->value();
145 configs.h0 = options.h0->value();
// The on/off options are numeric: any non-zero value enables the flag.
146 configs.interleave_rhs = options.interleave_rhs->value() != 0;
147 configs.transpose_rhs = options.transpose_rhs->value() != 0;
148 configs.export_to_cl_image_rhs = options.export_to_cl_image_rhs->value() != 0;
/** Example deriving from Example that sets up a GEMM from command-line
 * parameters and configurations.
 *
 * NOTE(review): large parts of this class (member declarations, braces,
 * conditions guarding the error paths, the bodies of do_run/do_teardown) are
 * not visible in this excerpt; comments below describe only the visible
 * statements.
 */
156 class CLGEMMMatrixMultiplyReshapedOnlyRHSExample :
public Example
/** Parse the command line, derive the GEMM descriptors, validate and
 * configure the gemm object, and allocate the tensors.
 *
 * @param argc Argument count.
 * @param argv Argument vector; argv[0] is passed to the parser when printing help.
 *
 * @return bool; the returned values are not visible in this excerpt.
 */
159 bool do_setup(
int argc,
char **argv)
override
// Scalars forwarded unchanged to gemm.validate() and gemm.configure() below.
162 const float alpha = 1.0f;
163 const float beta = 0.0f;
// Registers the configuration options on the parser.
171 GemmConfigOptions config_options(
parser);
// Help output (the condition triggering it is not visible in this excerpt).
178 parser.print_help(argv[0]);
// Invalid-argument path: report, show help, and announce that defaults are
// used instead.
184 std::cerr <<
"Invalid arguments." << std::endl;
185 parser.print_help(argv[0]);
186 std::cerr <<
"Falling back to default parameters and configs" << std::endl;
// Read the configuration values out of the parsed options.
192 configs = consume_gemm_configs(config_options);
// Echo the parameters and configurations in use.
196 std::cout <<
"Gemm parameters:" << std::endl;
197 std::cout << params << std::endl;
198 std::cout <<
"Gemm configurations:" << std::endl;
199 std::cout << configs << std::endl;
// LHS descriptor: block sizes taken from the configuration.
210 lhs_info.
m0 = configs.m0;
211 lhs_info.
k0 = configs.k0;
// RHS descriptor: block sizes and transpose flag taken from the configuration.
214 rhs_info.
n0 = configs.n0;
215 rhs_info.
k0 = configs.k0;
216 rhs_info.
h0 = configs.h0;
218 rhs_info.
transpose = configs.transpose_rhs;
// Kernel descriptor: problem dimensions taken from the parameters.
222 kernel_info.
m = params.
M;
223 kernel_info.
n = params.
N;
224 kernel_info.
k = params.
K;
// An h0 of 0 is replaced by n / n0, clamped to at least 1.
// NOTE(review): n0 originates from the command line; a value of 0 would make
// this a division by zero unless it is rejected elsewhere -- confirm.
230 if (rhs_info.
h0 == 0)
232 rhs_info.
h0 = std::max(kernel_info.
n / rhs_info.
n0, 1
U);
// Initialise the reshaped-rhs tensor (call arguments not visible in this excerpt).
236 rhs_reshaped.allocator()->init(
// Warning emitted when cl_image export is unavailable (guarding condition not
// visible in this excerpt).
243 std::cerr <<
"cl_image is not supported on the device, disable export_to_cl_image" << std::endl;
// Check the argument combination before configuring.
250 status =
gemm.validate(lhs.info(), rhs_reshaped.info(),
bias.info(),
dst.info(), alpha, beta, lhs_info,
251 rhs_info, kernel_info);
// Reported when the combination is rejected (guarding condition not visible
// in this excerpt).
255 std::cerr <<
"Unsupported arguments." << std::endl;
256 std::cerr <<
"Check documentation for supported/unsupported combinations" << std::endl;
// Configure the gemm object with the same arguments that were validated
// (the tail of this call is not visible in this excerpt).
261 gemm.configure(lhs.info(), rhs_reshaped.info(),
bias.info(),
dst.info(), alpha, beta, lhs_info, rhs_info,
// Allocate backing memory for every tensor used by the example.
265 lhs.allocator()->allocate();
266 rhs.allocator()->allocate();
267 rhs_reshaped.allocator()->allocate();
268 bias.allocator()->allocate();
269 dst.allocator()->allocate();
/** Run step of the example (body not visible in this excerpt). */
273 void do_run()
override
/** Teardown step of the example (body not visible in this excerpt). */
283 void do_teardown()
override
/** Program entry point.
 *
 * Forwards the command line to run_example instantiated with
 * CLGEMMMatrixMultiplyReshapedOnlyRHSExample and returns its result as the
 * process exit status.
 *
 * @param argc Argument count.
 * @param argv Argument vector.
 *
 * @return Value returned by run_example.
 */
302 int main(
int argc,
char **argv)
304 return run_example<CLGEMMMatrixMultiplyReshapedOnlyRHSExample>(argc, argv);