38 bool evaluate(GEMMShape
shape, Condition cond)
41 constexpr
float eps = 0.0001f;
43 std::vector<std::pair<std::string, float>> cond_values
45 {
"m",
static_cast<float>(shape.m) },
46 {
"n",
static_cast<float>(shape.n) },
47 {
"k",
static_cast<float>(shape.k) },
48 {
"b",
static_cast<float>(shape.b) },
49 {
"r_mn",
static_cast<float>(shape.m) / shape.n },
50 {
"r_mk",
static_cast<float>(shape.m) / shape.k },
51 {
"r_nk",
static_cast<float>(shape.n) / shape.k },
52 {
"r_mnk",
static_cast<float>(shape.m) / (static_cast<float>(shape.n) / shape.k) },
53 {
"workload", (
static_cast<float>(shape.m) * shape.n * shape.b) / 20.0 }
55 auto cond_value_pair_it = std::find_if(cond_values.begin(), cond_values.end(),
56 [&cond](decltype(*cond_values.begin()) it)
58 return it.first == cond.feature;
62 const float cond_value = cond_value_pair_it->second;
67 return cond_value < cond.threshold;
71 return cond_value <= cond.threshold;
75 return cond_value > cond.threshold;
79 return cond_value >= cond.threshold;
84 return std::abs(cond_value - cond.threshold) < eps;
91 constexpr
size_t HeuristicTree::_max_num_nodes;
92 constexpr
size_t HeuristicTree::_max_query_depth;
101 : _id{
id }, _heuristic_type{ h_type }, _ip_target{ ip_target }, _data_type{
data_type }, _tree{}
105 template <
typename T>
109 auto cur_node = _tree.at(_root).get();
113 if(depth > _max_query_depth)
116 return std::make_pair(
false, T{});
119 auto br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node);
120 if(evaluate(shape, br_node->condition))
122 cur_node = _tree.at(br_node->true_node).get();
126 cur_node = _tree.at(br_node->false_node).get();
131 auto l_node = utils::cast::polymorphic_downcast<LeafNode<T> *>(cur_node);
132 return std::make_pair(
true, l_node->value);
135 template <
typename T>
138 if(_tree.size() >= _max_num_nodes)
143 if(_tree.find(
id) != _tree.end())
148 _tree[
id] = std::make_unique<LeafNode<T>>(
id, val);
154 if(_tree.size() >= _max_num_nodes)
160 const std::set<std::string> supported_features =
162 "m",
"n",
"k",
"b",
"r_mn",
"r_mk",
"r_nk",
"r_mnk",
"workload" 164 const auto orig_feature = cond.
feature;
169 if(supported_features.find(cond.
feature) == supported_features.end())
175 if(_tree.find(
id) != _tree.end())
180 _tree[
id] = std::make_unique<BranchNode>(
id, cond, t_node, f_node);
184 bool HeuristicTree::check_if_structurally_correct()
const 186 std::set<NodeID> visited;
187 std::deque<NodeID> to_visit{ _root };
189 while(!to_visit.empty())
191 auto id = to_visit.front();
192 to_visit.pop_front();
193 if(_tree.find(
id) == _tree.end())
198 auto not_seen_before = visited.insert(
id);
199 if(!not_seen_before.second)
204 auto cur_node = _tree.at(
id).get();
207 auto br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node);
208 to_visit.push_back(br_node->true_node);
209 to_visit.push_back(br_node->false_node);
212 if(visited.size() != _tree.size())
227 if(_tree.find(_root) == _tree.end())
232 return check_if_structurally_correct();
236 template std::pair<bool, GEMMType> HeuristicTree::query<GEMMType>(
GEMMShape shape)
const;
238 template std::pair<bool, GEMMConfigNative> HeuristicTree::query<GEMMConfigNative>(
GEMMShape shape)
const;
240 template std::pair<bool, GEMMConfigReshapedOnlyRHS> HeuristicTree::query<GEMMConfigReshapedOnlyRHS>(
GEMMShape shape)
const;
242 template std::pair<bool, GEMMConfigReshaped> HeuristicTree::query<GEMMConfigReshaped>(
GEMMShape shape)
const;
GEMM Configuration for Reshaped kernel.
bool add_branch(NodeID id, Condition cond, NodeID true_node, NodeID false_node)
Add a branch node.
1 channel, 1 F32 per channel
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true then an error message is printed and an exception thrown.
CLGEMMKernelType
OpenCL GEMM kernel types.
HeuristicType
Types of Heuristic (tree)
Copyright (c) 2017-2022 Arm Limited.
std::pair< bool, T > query(GEMMShape shape) const
Query a leaf value given a gemm shape.
std::string tolower(std::string string)
Convert string to lower case.
#define ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE(fmt,...)
Log information level formatted message to the core system logger.
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
HeuristicTree()
Constructor.
GEMM Configuration for Reshaped Only RHS kernel.
Greater than or equal to.
bool check()
Check if tree is valid.
A binary decision tree based heuristic.
MLGOHeuristics mlgo(TokenStream &in, bool &valid)
A branch condition expression evaluating: feature op threshold.
#define ARM_COMPUTE_LOG_INFO_MSG_CORE(msg)
Log information level message to the core system logger.
GEMM Configuration for Native kernel.
GEMM Shape used for query.
bool add_leaf(NodeID id, T leaf_value)
Add a leaf node.
TreeID id() const
Get tree ID.
DataType
Available data types.
std::string feature
Feature name.