Compute Library
 21.11
MLGOHeuristics Class Reference

MLGOHeuristics for configuring GEMM kernels.

#include <MLGOHeuristics.h>

Public Member Functions

 MLGOHeuristics ()
 Constructor.
 
 ~MLGOHeuristics ()=default
 Default destructor.
 
 MLGOHeuristics (const MLGOHeuristics &)=delete
 Copy construction is deleted.
 
MLGOHeuristics & operator= (const MLGOHeuristics &)=delete
 Copy assignment is deleted.
 
 MLGOHeuristics (MLGOHeuristics &&)=default
 Default move constructor.
 
MLGOHeuristics & operator= (MLGOHeuristics &&)=default
 Default move assignment.
 
std::pair< bool, GEMMType > query_gemm_type (const Query &query) const
 Query the GEMM type.
 
std::pair< bool, GEMMConfigNative > query_gemm_config_native (const Query &query) const
 Query the GEMM configuration for the native kernel.
 
std::pair< bool, GEMMConfigReshapedOnlyRHS > query_gemm_config_reshaped_only_rhs (const Query &query) const
 Query the GEMM configuration for the reshaped-only-RHS kernel.
 
std::pair< bool, GEMMConfigReshaped > query_gemm_config_reshaped (const Query &query) const
 Query the GEMM configuration for the reshaped kernel.
 
bool reload_from_file (const std::string &filename)
 (Re)load the heuristics from a dotmlgo file.
 
bool reload_from_stream (std::istream &istream)
 (Re)load the heuristics from an input stream.
 
std::pair< bool, HeuristicTree * > get_heuristic_tree (HeuristicTree::TreeID id)
 Get the heuristic tree with the given tree id.
 
bool add_heuristic_tree (HeuristicTree &&t)
 Add a heuristic tree.
 
bool check_heuristic_tree (HeuristicTree::TreeID id)
 Check the validity of the heuristic tree.
 
bool check_all () const
 Check the overall validity of the heuristics.
 

Detailed Description

MLGOHeuristics for configuring GEMM kernels.

Definition at line 54 of file MLGOHeuristics.h.

Constructor & Destructor Documentation

◆ MLGOHeuristics() [1/3]

Constructor.

Definition at line 53 of file MLGOHeuristics.cpp.

MLGOHeuristics::MLGOHeuristics()
    : _indices{}, _trees{}, _tree_valid{}, _valid{ false }
{
}

◆ ~MLGOHeuristics()

~MLGOHeuristics ( )
default

Default Destructor.

◆ MLGOHeuristics() [2/3]

MLGOHeuristics ( const MLGOHeuristics & )
delete

Prevent Copy Construct.

◆ MLGOHeuristics() [3/3]

MLGOHeuristics ( MLGOHeuristics &&  )
default

Default Move Constructor.
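
The special member functions documented above make the class move-only: copies are deleted and moves are defaulted. The following is a minimal sketch of what that implies for callers. `MoveOnlyTable` and `take` are hypothetical names used only for illustration; they are not part of the library.

```cpp
#include <cassert>
#include <map>
#include <type_traits>
#include <utility>

// Hypothetical stand-in that declares its special members the same way
// as the class documented above.
class MoveOnlyTable
{
public:
    MoveOnlyTable()  = default;
    ~MoveOnlyTable() = default;
    MoveOnlyTable(const MoveOnlyTable &) = delete;
    MoveOnlyTable &operator=(const MoveOnlyTable &) = delete;
    MoveOnlyTable(MoveOnlyTable &&)                 = default;
    MoveOnlyTable &operator=(MoveOnlyTable &&) = default;

    std::map<int, int> entries{};
};

static_assert(!std::is_copy_constructible<MoveOnlyTable>::value, "copy construction is deleted");
static_assert(!std::is_copy_assignable<MoveOnlyTable>::value, "copy assignment is deleted");
static_assert(std::is_move_constructible<MoveOnlyTable>::value, "move construction is available");
static_assert(std::is_move_assignable<MoveOnlyTable>::value, "move assignment is available");

// Ownership is handed over by value; callers must pass an rvalue (std::move).
MoveOnlyTable take(MoveOnlyTable source)
{
    return source;
}
```

A caller would write `MoveOnlyTable b = take(std::move(a));`. Passing `a` without `std::move` would not compile, because the copy constructor is deleted.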

Member Function Documentation

◆ add_heuristic_tree()

bool add_heuristic_tree ( HeuristicTree &&  t)

Add a heuristic tree.

Parameters
t    Heuristic tree to be added

Definition at line 187 of file MLGOHeuristics.cpp.

References ARM_COMPUTE_LOG_INFO_MSG_CORE, and ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE.

Referenced by arm_compute::mlgo::parser::heuristics_table_entry().

bool MLGOHeuristics::add_heuristic_tree(HeuristicTree &&t)
{
    if(_indices.size() >= _max_num_trees)
    {
        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the max number of trees allowed: %zu", _max_num_trees);
        return false;
    }
    // PRE: correctness of t is guaranteed by the tree construction process
    // Ensure unique id
    const auto id = t.id();
    if(_indices.find(id) != _indices.end())
    {
        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add redundant trees; tree id %zu already exists", id);
        return false;
    }

    // Ensure unique index
    const auto index = t.index();
    if(_trees.find(index) != _trees.end())
    {
        ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot add redundant trees; tree index already exists");
        return false;
    }

    _indices[id]    = index;
    _trees[index]   = std::move(t);
    _tree_valid[id] = false; // a newly added tree is unchecked until check_heuristic_tree succeeds
    return true;
}
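
The fragment above keeps two maps in step: tree id to index, and index to tree. It rejects a tree when the capacity is reached, when its id is already registered, or when its index is already registered. The sketch below reproduces only that bookkeeping. `ToyTree`, `ToyRegistry`, the string-based `TreeIndex`, and the limit of 100 are assumptions made for illustration and do not reflect the library's real types or limits.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <tuple>
#include <utility>

// Hypothetical stand-ins for the library's tree id and index types.
using TreeId    = std::size_t;
using TreeIndex = std::tuple<std::string, std::string, std::string>; // (heuristic type, target, data type)

constexpr std::size_t max_num_trees = 100; // assumed limit, for illustration only

struct ToyTree
{
    TreeId    id;
    TreeIndex index;
};

class ToyRegistry
{
public:
    // Same order of checks as the fragment above: capacity, unique id, unique index.
    bool add(ToyTree &&t)
    {
        if(_indices.size() >= max_num_trees)
        {
            return false;
        }
        const auto id = t.id;
        if(_indices.find(id) != _indices.end())
        {
            return false;
        }
        const auto index = t.index;
        if(_trees.find(index) != _trees.end())
        {
            return false;
        }
        _indices[id]    = index;
        _trees[index]   = std::move(t);
        _tree_valid[id] = false; // new trees start out unchecked
        return true;
    }

    std::size_t size() const
    {
        return _trees.size();
    }

private:
    std::map<TreeId, TreeIndex>  _indices{};
    std::map<TreeIndex, ToyTree> _trees{};
    std::map<TreeId, bool>       _tree_valid{};
};
```

Adding a second tree that reuses either an existing id or an existing index returns false and leaves the registry unchanged.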

◆ check_all()

bool check_all ( ) const

Check the overall validity of the heuristics.

Returns
bool True if every registered tree has been marked valid by check_heuristic_tree, false otherwise

Definition at line 149 of file MLGOHeuristics.cpp.

References ARM_COMPUTE_LOG_INFO_MSG_CORE.

Referenced by arm_compute::mlgo::parser::mlgo().

bool MLGOHeuristics::check_all() const
{
    // Tree validities are already checked and cached.
    bool all_trees_are_checked = std::find_if(_tree_valid.begin(), _tree_valid.end(), [](auto v)
    {
        return !v.second;
    })
    == _tree_valid.end();
    if(!all_trees_are_checked)
    {
        ARM_COMPUTE_LOG_INFO_MSG_CORE("Missing checks on some trees. Make sure to call check_heuristic_tree after each tree is completed. This could also indicate there are no trees in the dotmlgo");
        return false;
    }

    // Other top level checks...

    return true;
}
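
The check above only reads the cached per-tree flags and does not revisit the trees themselves. A minimal sketch of the same test, written with `std::all_of` instead of `std::find_if`, is shown below. `ValidityCache` and `all_checked` are hypothetical names standing in for `_tree_valid` and the first part of `check_all`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>

// Hypothetical stand-in for the per-tree validity cache.
using ValidityCache = std::map<std::size_t, bool>;

// True when no entry is still marked false. Equivalent to checking that
// find_if for a false entry returns end().
bool all_checked(const ValidityCache &cache)
{
    return std::all_of(cache.begin(), cache.end(),
                       [](const ValidityCache::value_type &entry)
    {
        return entry.second;
    });
}
```

In this sketch an empty cache passes vacuously, since there is no false entry to find.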

◆ check_heuristic_tree()

bool check_heuristic_tree ( HeuristicTree::TreeID  id)

Check the validity of the heuristic tree.

Parameters
id    ID of the tree to be checked
Returns
bool True if the tree exists and passes its check, false otherwise

Definition at line 131 of file MLGOHeuristics.cpp.

References MLGOHeuristics::get_heuristic_tree().

Referenced by arm_compute::mlgo::parser::heuristic_tree().

bool MLGOHeuristics::check_heuristic_tree(HeuristicTree::TreeID id)
{
    bool           status;
    HeuristicTree *tree{ nullptr };
    std::tie(status, tree) = get_heuristic_tree(id);
    if(!status)
    {
        return status;
    }
    status = tree->check();
    if(!status)
    {
        return status;
    }
    _tree_valid[id] = true; // cache the result for check_all
    return true;
}

◆ get_heuristic_tree()

std::pair< bool, HeuristicTree * > get_heuristic_tree ( HeuristicTree::TreeID  id)

Get the heuristic tree from tree id.

Parameters
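[in] id  Tree id.
Returns
std::pair<bool, HeuristicTree *> The bool signals whether the tree was found; the pointer is nullptr when it was not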

Definition at line 168 of file MLGOHeuristics.cpp.

References ARM_COMPUTE_LOG_INFO_MSG_CORE, and ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE.

Referenced by MLGOHeuristics::check_heuristic_tree(), and arm_compute::mlgo::parser::heuristic_tree().

std::pair<bool, HeuristicTree *> MLGOHeuristics::get_heuristic_tree(HeuristicTree::TreeID id)
{
    if(_indices.find(id) == _indices.end())
    {
        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot find tree with id %zu", id);
        return std::make_pair(false, nullptr);
    }
    const auto index = _indices[id];

    if(_trees.find(index) == _trees.end())
    {
        ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
        return std::make_pair(false, nullptr);
    }
    auto &t = _trees[index];

    return std::make_pair(true, &t);
}
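
The lookup above reports failure through the bool and returns a null pointer alongside it, so callers need to test the flag before dereferencing. The sketch below shows one way a caller might unpack such a result with `std::tie`, as `check_heuristic_tree` does. `ToyTree`, `ToyStore`, and `name_or` are hypothetical names introduced for illustration.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <tuple>
#include <utility>

// Hypothetical stand-in for a stored tree.
struct ToyTree
{
    std::string name;
};

class ToyStore
{
public:
    void put(int id, const std::string &name)
    {
        _items[id] = ToyTree{ name };
    }

    // Same return convention as the lookup above: (found, pointer-or-nullptr).
    std::pair<bool, ToyTree *> get(int id)
    {
        const auto it = _items.find(id);
        if(it == _items.end())
        {
            return { false, nullptr };
        }
        return { true, &it->second };
    }

private:
    std::map<int, ToyTree> _items{};
};

// Caller side: unpack with std::tie and test the flag before touching the pointer.
std::string name_or(ToyStore &store, int id, const std::string &fallback)
{
    bool     found{ false };
    ToyTree *tree{ nullptr };
    std::tie(found, tree) = store.get(id);
    if(!found)
    {
        return fallback;
    }
    assert(tree != nullptr);
    return tree->name;
}
```

The returned pointer refers to an element owned by the store, so it is only meaningful while the store is alive and the element has not been erased.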

◆ operator=() [1/2]

MLGOHeuristics& operator= ( const MLGOHeuristics & )
delete

Prevent Copy Assignment.

◆ operator=() [2/2]

MLGOHeuristics& operator= ( MLGOHeuristics &&  )
default

Default Move Assignment.

◆ query_gemm_config_native()

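std::pair< bool, GEMMConfigNative > query_gemm_config_native ( const Query & query) const

Query the GEMM configuration for the native kernel.

Parameters
[in] query  Query giving the IP target, data type and GEMM shape (m, n, k, b)
Returns
std::pair<bool, GEMMConfigNative> The bool signals whether the query succeeded; on failure the config is a default-constructed placeholder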

Definition at line 76 of file MLGOHeuristics.cpp.

References ARM_COMPUTE_LOG_INFO_MSG_CORE, ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE, Query::b, Query::data_type, arm_compute::mlgo::GEMM_Config_Native, Query::ip_target, Query::k, Query::m, Query::n, and arm_compute::mlgo::to_string().

Referenced by arm_compute::cl_gemm::auto_heuristics::select_mlgo_gemm_config_native().

std::pair<bool, GEMMConfigNative> MLGOHeuristics::query_gemm_config_native(const Query &query) const
{
    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config native. %s.", to_string(query).c_str());
    const auto invalid = GEMMConfigNative{};
    if(!_valid)
    {
        ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
        return { false, invalid };
    }
    auto      index = std::make_tuple(HeuristicType::GEMM_Config_Native, query.ip_target, query.data_type);
    GEMMShape shape_query{ query.m, query.n, query.k, query.b };
    if(_trees.find(index) == _trees.end())
    {
        ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
        return { false, invalid };
    }
    return _trees.at(index).query<GEMMConfigNative>(shape_query);
}
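
Each query function follows the same shape: return `{ false, placeholder }` when the heuristics are invalid or no tree matches the index, and otherwise forward the tree's answer. The log messages suggest that callers fall back to default heuristics on failure. The sketch below illustrates that caller-side pattern under simplifying assumptions: `KernelConfig`, `TreeIndex`, `ToyHeuristics`, and `select_config` are hypothetical, and each "tree" is collapsed to a single stored config rather than a decision tree over the GEMM shape.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <tuple>
#include <utility>

// Hypothetical stand-ins; field names are for illustration only.
struct KernelConfig
{
    unsigned int m0;
    unsigned int n0;
    unsigned int k0;
};
using TreeIndex = std::tuple<std::string, std::string, std::string>; // (heuristic type, target, data type)

class ToyHeuristics
{
public:
    bool                              valid{ false };
    std::map<TreeIndex, KernelConfig> trees{};

    // Same failure convention as the query functions: (false, placeholder).
    std::pair<bool, KernelConfig> query(const TreeIndex &index) const
    {
        const KernelConfig invalid{};
        if(!valid)
        {
            return { false, invalid };
        }
        const auto it = trees.find(index);
        if(it == trees.end())
        {
            return { false, invalid };
        }
        return { true, it->second };
    }
};

// Caller side: use the queried config only when the flag is true.
KernelConfig select_config(const ToyHeuristics &h, const TreeIndex &index, const KernelConfig &fallback)
{
    const auto result = h.query(index);
    return result.first ? result.second : fallback;
}
```

The placeholder in the second element carries no meaning on failure, so the sketch never reads it unless the flag is true.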

◆ query_gemm_config_reshaped()

std::pair< bool, GEMMConfigReshaped > query_gemm_config_reshaped ( const Query & query) const

Query the GEMM configuration for the reshaped kernel.

Parameters
[in] query  Query giving the IP target, data type and GEMM shape (m, n, k, b)
Returns
std::pair<bool, GEMMConfigReshaped> The bool signals whether the query succeeded; on failure the config is a default-constructed placeholder

Definition at line 112 of file MLGOHeuristics.cpp.

References ARM_COMPUTE_LOG_INFO_MSG_CORE, ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE, Query::b, Query::data_type, arm_compute::mlgo::GEMM_Config_Reshaped, Query::ip_target, Query::k, Query::m, Query::n, and arm_compute::mlgo::to_string().

Referenced by arm_compute::cl_gemm::auto_heuristics::select_mlgo_gemm_config_reshaped(), and arm_compute::test::validation::TEST_CASE().

std::pair<bool, GEMMConfigReshaped> MLGOHeuristics::query_gemm_config_reshaped(const Query &query) const
{
    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped. %s.", to_string(query).c_str());
    const auto invalid = GEMMConfigReshaped{};
    if(!_valid)
    {
        ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
        return { false, invalid };
    }
    auto      index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped, query.ip_target, query.data_type);
    GEMMShape shape_query{ query.m, query.n, query.k, query.b };
    if(_trees.find(index) == _trees.end())
    {
        ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
        return { false, invalid };
    }
    return _trees.at(index).query<GEMMConfigReshaped>(shape_query);
}

◆ query_gemm_config_reshaped_only_rhs()

std::pair< bool, GEMMConfigReshapedOnlyRHS > query_gemm_config_reshaped_only_rhs ( const Query & query) const

Query the GEMM configuration for the reshaped-only-RHS kernel.

Parameters
[in] query  Query giving the IP target, data type and GEMM shape (m, n, k, b)
Returns
std::pair<bool, GEMMConfigReshapedOnlyRHS> The bool signals whether the query succeeded; on failure the config is a default-constructed placeholder

Definition at line 94 of file MLGOHeuristics.cpp.

References ARM_COMPUTE_LOG_INFO_MSG_CORE, ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE, Query::b, Query::data_type, arm_compute::mlgo::GEMM_Config_Reshaped_Only_RHS, Query::ip_target, Query::k, Query::m, Query::n, and arm_compute::mlgo::to_string().

Referenced by arm_compute::cl_gemm::auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(), and arm_compute::test::validation::TEST_CASE().

std::pair<bool, GEMMConfigReshapedOnlyRHS> MLGOHeuristics::query_gemm_config_reshaped_only_rhs(const Query &query) const
{
    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped only rhs. %s.", to_string(query).c_str());
    const auto invalid = GEMMConfigReshapedOnlyRHS{};
    if(!_valid)
    {
        ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
        return { false, invalid };
    }
    auto      index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped_Only_RHS, query.ip_target, query.data_type);
    GEMMShape shape_query{ query.m, query.n, query.k, query.b };
    if(_trees.find(index) == _trees.end())
    {
        ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
        return { false, invalid };
    }
    return _trees.at(index).query<GEMMConfigReshapedOnlyRHS>(shape_query);
}

◆ query_gemm_type()

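std::pair< bool, GEMMType > query_gemm_type ( const Query & query) const

Query the GEMM type.

Parameters
[in] query  Query giving the IP target, data type and GEMM shape (m, n, k, b)
Returns
std::pair<bool, GEMMType> The bool signals whether the query succeeded; on failure the type is the placeholder GEMMType::RESHAPED and should not be used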

Definition at line 58 of file MLGOHeuristics.cpp.

References ARM_COMPUTE_LOG_INFO_MSG_CORE, ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE, Query::b, Query::data_type, arm_compute::mlgo::GEMM_Type, Query::ip_target, Query::k, Query::m, Query::n, arm_compute::RESHAPED, and arm_compute::mlgo::to_string().

Referenced by arm_compute::cl_gemm::auto_heuristics::select_mlgo_gemm_kernel(), and arm_compute::test::validation::TEST_CASE().

std::pair<bool, GEMMType> MLGOHeuristics::query_gemm_type(const Query &query) const
{
    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm type. %s.", to_string(query).c_str());
    const auto invalid = GEMMType::RESHAPED; // placeholder returned alongside false
    if(!_valid)
    {
        ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
        return { false, invalid };
    }
    auto      index = std::make_tuple(HeuristicType::GEMM_Type, query.ip_target, query.data_type);
    GEMMShape shape_query{ query.m, query.n, query.k, query.b };
    if(_trees.find(index) == _trees.end())
    {
        ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
        return { false, invalid };
    }
    return _trees.at(index).query<GEMMType>(shape_query);
}

◆ reload_from_file()

bool reload_from_file ( const std::string &  filename)

(Re)load the heuristics from a dotmlgo file.

Parameters
[in] filename  Path to the dotmlgo file
Returns
bool Signals whether the reload succeeded

Definition at line 217 of file MLGOHeuristics.cpp.

References ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE, and MLGOHeuristics::reload_from_stream().

bool MLGOHeuristics::reload_from_file(const std::string &filename)
{
    std::ifstream fs;
    fs.exceptions(std::ifstream::badbit);
    fs.open(filename, std::ios::in);
    if(!fs.is_open())
    {
        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot open DotMLGO file %s. Use default heuristics instead", filename.c_str());
        return _valid = false;
    }
    return reload_from_stream(fs);
}
◆ reload_from_stream()

bool reload_from_stream ( std::istream &  istream)

(Re)load the heuristics from an input stream.

Parameters
[in] istream  Input stream containing MLGO heuristics
Returns
bool Signals whether the reload succeeded

Definition at line 230 of file MLGOHeuristics.cpp.

References ARM_COMPUTE_LOG_INFO_MSG_CORE, and arm_compute::mlgo::parser::parse_mlgo().

Referenced by MLGOHeuristics::reload_from_file(), and arm_compute::test::validation::TEST_CASE().

// Note: the definition names the parameter `in`, while the declaration names it `istream`.
bool MLGOHeuristics::reload_from_stream(std::istream &in)
{
    auto parsed = parser::parse_mlgo(in);
    if(!parsed.first)
    {
        ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO parsing failed. Use default heuristics instead");
        return _valid = false;
    }
    *this = std::move(parsed.second);
    ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO loaded successfully");
    return _valid = true;
}
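
The reload above parses into a separate object first and only then move-assigns it into `*this`, so a failed parse never leaves the existing object half-updated; it only clears the validity flag. The sketch below shows that parse-then-swap pattern in isolation. `RowTable` and its toy `parse` function are hypothetical and stand in for the real class and parser.

```cpp
#include <cassert>
#include <cstddef>
#include <istream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in: a table of text rows loaded from a stream.
class RowTable
{
public:
    bool reload(std::istream &in)
    {
        auto parsed = parse(in);
        if(!parsed.first)
        {
            return _valid = false; // old rows are kept, but the object is flagged invalid
        }
        *this = std::move(parsed.second); // replace all state in one step
        return _valid = true;
    }

    bool valid() const
    {
        return _valid;
    }

    std::size_t size() const
    {
        return _rows.size();
    }

private:
    // Toy parser: every line becomes a row; empty input counts as a parse failure.
    static std::pair<bool, RowTable> parse(std::istream &in)
    {
        RowTable    table;
        std::string line;
        while(std::getline(in, line))
        {
            table._rows.push_back(line);
        }
        const bool ok = !table._rows.empty();
        return { ok, std::move(table) };
    }

    std::vector<std::string> _rows{};
    bool                     _valid{ false };
};
```

In this sketch the validity flag is set after the move assignment, because the freshly parsed object still carries its own default flag value.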

The documentation for this class was generated from the following files: