24 #ifndef ARM_COMPUTE_CL
25 #error "This example needs to be built with -DARM_COMPUTE_CL"
39 #include "utils/Utils.h"
46 using namespace utils;
// Stream a GemmConfigs as human-readable "name : value" lines.
// Writes the m0, n0 and k0 fields of `configs` to `os`, each on its own line
// (std::endl, so each line also flushes the stream).
// NOTE(review): false_str / true_str are not used in the lines visible here;
// presumably they format boolean config fields further down — confirm.
67 ::std::ostream &
operator<<(::std::ostream &os,
const GemmConfigs &configs)
69 std::string false_str = std::string(
"false");
70 std::string true_str = std::string(
"true");
72 os <<
"m0 : " << configs.m0 << std::endl;
73 os <<
"n0 : " << configs.n0 << std::endl;
74 os <<
"k0 : " << configs.k0 << std::endl;
// Command-line options describing the GEMM kernel configuration.
// Registers three positional size_t options on a CommandLineParser ("m0", "n0"
// and "k0", each defaulting to 4) and attaches a help string to each.
// Copying is disabled; move construction/assignment and destruction are defaulted.
79 class GemmConfigOptions
// Registers the options on `parser` in the order m0, n0, k0.
// NOTE(review): since they are positional, this order is presumably the expected
// command-line argument order — confirm against CommandLineParser.
86 GemmConfigOptions(CommandLineParser &
parser)
87 : m0(
parser.add_positional_option<SimpleOption<size_t>>(
"m0", 4)),
88 n0(
parser.add_positional_option<SimpleOption<size_t>>(
"n0", 4)),
89 k0(
parser.add_positional_option<SimpleOption<size_t>>(
"k0", 4))
91 m0->set_help(
"Number of rows processed by the matrix multiplication");
92 n0->set_help(
"Number of columns processed by the matrix multiplication");
93 k0->set_help(
"Number of partial accumulations performed by the matrix multiplication");
// Copying is disallowed.
96 GemmConfigOptions(
const GemmConfigOptions &) =
delete;
98 GemmConfigOptions &operator=(
const GemmConfigOptions &) =
delete;
// Moving is allowed.
100 GemmConfigOptions(GemmConfigOptions &&) =
default;
102 GemmConfigOptions &operator=(GemmConfigOptions &&) =
default;
104 ~GemmConfigOptions() =
default;
// Pointers returned by parser.add_positional_option; the destructor does not delete them.
// NOTE(review): presumably owned by the parser, so they must not be used after
// the parser is destroyed — confirm.
106 SimpleOption<size_t> *m0;
107 SimpleOption<size_t> *n0;
108 SimpleOption<size_t> *k0;
// Copy the parsed m0, n0 and k0 option values into a GemmConfigs.
// NOTE(review): the declaration of `configs` and the return are not visible in
// this extract; presumably a local GemmConfigs is filled and returned — confirm.
117 GemmConfigs consume_gemm_configs(
const GemmConfigOptions &options)
120 configs.m0 = options.m0->value();
121 configs.n0 = options.n0->value();
122 configs.k0 = options.k0->value();
// Example that sets up a native GEMM (matrix multiplication) on the OpenCL (CL)
// backend via `gemm`, with block sizes m0/n0/k0 taken from the command line.
130 class CLGEMMMatrixMultiplyNativeExample :
public Example
// Parse arguments, fill the LHS/RHS/kernel descriptors, validate and configure
// `gemm`, then allocate the lhs, rhs, bias and dst tensors.
133 bool do_setup(
int argc,
char **argv)
override
// Scalars passed to both validate() and configure().
// NOTE(review): presumably dst = alpha * lhs * rhs + beta * bias, so beta = 0
// means bias does not contribute — confirm against the kernel documentation.
136 const float alpha = 1.0f;
137 const float beta = 0.0f;
// Register the m0/n0/k0 positional options on the parser.
145 GemmConfigOptions config_options(
parser);
// Print usage (the condition that triggers this is not visible in this extract).
152 parser.print_help(argv[0]);
// Invalid-argument path: report the error and print usage. The message says the
// example falls back to default parameters and configs instead of aborting.
158 std::cerr <<
"Invalid arguments." << std::endl;
159 parser.print_help(argv[0]);
160 std::cerr <<
"Falling back to default parameters and configs" << std::endl;
// Read the parsed m0/n0/k0 values into `configs`.
166 configs = consume_gemm_configs(config_options);
// Echo the effective problem parameters and kernel configuration.
170 std::cout <<
"Gemm parameters:" << std::endl;
171 std::cout << params << std::endl;
172 std::cout <<
"Gemm configurations:" << std::endl;
173 std::cout << configs << std::endl;
// Map the configuration onto the kernel descriptors:
// LHS gets m0/k0, RHS gets n0/k0, and kernel_info gets the M/N/K problem sizes.
184 lhs_info.
m0 = configs.m0;
185 lhs_info.
k0 = configs.k0;
188 rhs_info.
n0 = configs.n0;
189 rhs_info.
k0 = configs.k0;
192 kernel_info.
m = params.
M;
193 kernel_info.
n = params.
N;
194 kernel_info.
k = params.
K;
// Check the tensor/descriptor combination is supported before configuring.
// On failure the "Unsupported arguments." messages below are printed; what
// happens afterwards is in lines not visible here.
202 status =
gemm.validate(lhs.info(), rhs.info(),
bias.info(),
dst.info(), alpha, beta, lhs_info, rhs_info,
207 std::cerr <<
"Unsupported arguments." << std::endl;
208 std::cerr <<
"Check documentation for supported/unsupported combinations" << std::endl;
// Configure with the same arguments that were validated.
213 gemm.configure(lhs.info(), rhs.info(),
bias.info(),
dst.info(), alpha, beta, lhs_info, rhs_info, kernel_info);
// Allocate backing memory for all four tensors once configuration is done.
216 lhs.allocator()->allocate();
217 rhs.allocator()->allocate();
218 bias.allocator()->allocate();
219 dst.allocator()->allocate();
// Body not visible in this extract.
223 void do_run()
override
// Body not visible in this extract.
233 void do_teardown()
override
// Program entry point: hand argc/argv to run_example for
// CLGEMMMatrixMultiplyNativeExample and return its result.
251 int main(
int argc,
char **argv)
253 return run_example<CLGEMMMatrixMultiplyNativeExample>(argc, argv);