38 return std::tie(lhs.
m0, lhs.
n0, lhs.
k0) == std::tie(rhs.
m0, rhs.
n0, rhs.
k0);
42 return std::tie(lhs.
m0, lhs.
n0, lhs.
k0, lhs.
h0, lhs.
interleave_rhs, lhs.
transpose_rhs, lhs.
export_cl_image) == std::tie(rhs.
m0, rhs.
n0, rhs.
k0, rhs.
h0, rhs.
interleave_rhs, rhs.
transpose_rhs,
47 return std::tie(lhs.
m0, lhs.
n0, lhs.
k0, lhs.
v0, lhs.
h0, lhs.
interleave_lhs, lhs.
interleave_rhs, lhs.
transpose_rhs, lhs.
export_cl_image) == std::tie(rhs.
m0, rhs.
n0, rhs.
k0, rhs.
v0, rhs.
h0,
51 constexpr
size_t MLGOHeuristics::_max_num_trees;
54 : _indices{}, _trees{}, _tree_valid{}, _valid{
false }
65 return {
false, invalid };
69 if(_trees.find(index) == _trees.end())
72 return {
false, invalid };
74 return _trees.at(index).query<
GEMMType>(shape_query);
83 return {
false, invalid };
87 if(_trees.find(index) == _trees.end())
90 return {
false, invalid };
101 return {
false, invalid };
105 if(_trees.find(index) == _trees.end())
108 return {
false, invalid };
119 return {
false, invalid };
123 if(_trees.find(index) == _trees.end())
126 return {
false, invalid };
140 status = tree->check();
145 _tree_valid[id] =
true;
152 bool all_trees_are_checked = std::find_if(_tree_valid.begin(), _tree_valid.end(), [](
auto v)
156 == _tree_valid.end();
157 if(!all_trees_are_checked)
159 ARM_COMPUTE_LOG_INFO_MSG_CORE(
"Missing checks on some trees. Make sure to call check_heuristic_tree after each tree is completed. This could also indicate there are no trees in the dotmlgo");
170 if(_indices.find(
id) == _indices.end())
173 return std::make_pair(
false,
nullptr);
175 const auto index = _indices[id];
177 if(_trees.find(index) == _trees.end())
180 return std::make_pair(
false,
nullptr);
182 auto &
t = _trees[index];
184 return std::make_pair(
true, &
t);
189 if(_indices.size() >= _max_num_trees)
196 const auto id =
t.id();
197 if(_indices.find(
id) != _indices.end())
204 const auto index =
t.index();
205 if(_trees.find(index) != _trees.end())
211 _indices[id] = index;
212 _trees[index] = std::move(
t);
213 _tree_valid[id] =
false;
220 fs.exceptions(std::ifstream::badbit);
221 fs.open(filename, std::ios::in);
225 return _valid =
false;
236 return _valid =
false;
238 *
this = std::move(parsed.second);
240 return _valid =
true;
unsigned int m0
Number of rows processed by the matrix multiplication.
bool transpose_rhs
True if the (k0xn0) block has to be transposed before being stored.
unsigned int b
Batch size.
std::pair< bool, HeuristicTree * > get_heuristic_tree(HeuristicTree::TreeID id)
Get the heuristic tree from tree id.
unsigned int k0
Number of partial accumulations performed by the matrix multiplication.
MLGOHeuristics()
Constructor.
bool reload_from_file(const std::string &filename)
(Re)load the heuristics by reading a dotmlgo file.
bool operator==(const GEMMConfigNative &lhs, const GEMMConfigNative &rhs)
bool interleave_rhs
True if the h0 (k0xn0) blocks have to be interleaved in the output row.
GEMM Configuration for Reshaped kernel.
bool interleave_rhs
True if the h0 (k0xn0) blocks have to be interleaved in the output row.
std::string to_string(const GEMMConfigNative &config)
unsigned int h0
Number of horizontal blocks of size (k0xn0) stored on the same output row.
bool export_cl_image
True if the reshaped rhs has to be exported to cl_image.
CLGEMMKernelType
OpenCL GEMM kernel types.
Reshaped GEMM kernel where both lhs and rhs matrices are reshaped.
unsigned int n0
Number of columns processed by the matrix multiplication.
std::pair< bool, GEMMType > query_gemm_type(const Query &query) const
Query the gemm type.
unsigned int h0
Number of horizontal blocks of size (k0xn0) stored on the same output row.
Copyright (c) 2017-2021 Arm Limited.
bool check_all() const
Check the overall validity of the heuristics.
About the gemm config for reshaped kernel.
About the gemm config for native kernel.
std::string ip_target
The name of the IP target.
#define ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE(fmt,...)
Log information level formatted message to the core system logger.
unsigned int n0
Number of columns processed by the matrix multiplication.
unsigned int m0
Number of rows processed by the matrix multiplication.
bool add_heuristic_tree(HeuristicTree &&t)
Add a heuristic tree.
GEMM Configuration for Reshaped Only RHS kernel.
std::pair< bool, MLGOHeuristics > parse_mlgo(std::istream &in)
Parse and construct a MLGOHeuristics from input stream.
DataType data_type
Data type.
unsigned int n
Number of columns for the rhs matrix.
bool export_cl_image
True if the reshaped rhs has to be exported to cl_image.
unsigned int k0
Number of partial accumulations performed by the matrix multiplication.
unsigned int m0
Number of rows processed by the matrix multiplication.
A binary decision tree based heuristic.
bool transpose_rhs
True if the (k0xn0) block has to be transposed before being stored.
unsigned int v0
Number of vertical blocks of size (m0xk0) stored on the same output row.
bool interleave_lhs
True if the v0 (m0xk0) blocks have to be interleaved in the output row.
std::pair< bool, GEMMConfigReshapedOnlyRHS > query_gemm_config_reshaped_only_rhs(const Query &query) const
Query the gemm configuration for reshaped only rhs kernel.
std::pair< bool, GEMMConfigNative > query_gemm_config_native(const Query &query) const
Query the gemm configuration for native kernel.
MLGOHeuristics mlgo(TokenStream &in, bool &valid)
bool check_heuristic_tree(HeuristicTree::TreeID id)
Check the validity of the heuristic tree.
unsigned int k0
Number of partial accumulations performed by the matrix multiplication.
#define ARM_COMPUTE_LOG_INFO_MSG_CORE(msg)
Log information level message to the core system logger.
GEMM Configuration for Native kernel.
About the gemm config for reshaped only rhs kernel.
GEMM Shape used for query.
bool reload_from_stream(std::istream &istream)
(Re)load the heuristics by reading an input stream.
unsigned int n0
Number of columns processed by the matrix multiplication.
unsigned int m
Number of rows for the lhs matrix.
std::pair< bool, GEMMConfigReshaped > query_gemm_config_reshaped(const Query &query) const
Query the gemm configuration for reshaped kernel.
unsigned int k
Number of rows for the rhs matrix.