// Evaluates `cond` against `shape`: the feature named by cond.feature is looked
// up in a table of values computed from the shape, and that value is compared
// with cond.threshold. The result of the comparison is returned.
// NOTE(review): some lines of this function (braces and the construct that
// selects between the comparisons at the end) are not part of this excerpt, so
// the comments below describe only what is visible.
39 bool evaluate(GEMMShape
shape, Condition cond)
// Tolerance used by the approximate-equality comparison at the end.
42 constexpr
float eps = 0.0001f;
// Feature table: (name, value) pairs with every value stored as float.
// The first four entries are the shape fields m, n, k and b themselves.
44 std::vector<std::pair<std::string, float>> cond_values{
45 {
"m",
static_cast<float>(
shape.m)},
46 {
"n",
static_cast<float>(
shape.n)},
47 {
"k",
static_cast<float>(
shape.k)},
48 {
"b",
static_cast<float>(
shape.b)},
// Ratio features. The left operand is cast to float first, so each division
// is carried out in floating point rather than as an integer division.
// NOTE(review): no guard against n or k being zero is visible here — confirm
// callers never pass such a shape.
49 {
"r_mn",
static_cast<float>(
shape.m) /
shape.n},
50 {
"r_mk",
static_cast<float>(
shape.m) /
shape.k},
51 {
"r_nk",
static_cast<float>(
shape.n) /
shape.k},
// r_mnk is m / (n / k), i.e. m * k / n — not m / (n * k).
52 {
"r_mnk",
static_cast<float>(
shape.m) / (
static_cast<float>(
shape.n) /
shape.k)},
// workload is (m * n * b) / 20.0; the divisor is a double literal, so the
// quotient is converted back to float when stored in the pair.
53 {
"workload", (
static_cast<float>(
shape.m) *
shape.n *
shape.b) / 20.0}};
// Linear search for the table entry whose name equals cond.feature.
// The lambda parameter (named `it`) is a reference to a table element, not an
// iterator.
54 auto cond_value_pair_it =
55 std::find_if(cond_values.begin(), cond_values.end(),
56 [&cond](decltype(*cond_values.begin()) it) { return it.first == cond.feature; });
// Value of the matched feature.
// NOTE(review): this dereference is only valid if a match was found; no check
// against cond_values.end() is visible in this excerpt — confirm one exists in
// the lines not shown.
59 const float cond_value = cond_value_pair_it->second;
// The visible return statements implement <, <=, > and >= against
// cond.threshold, followed by an equality test that treats values closer than
// eps to the threshold as equal. Which one runs is decided by code that is not
// part of this excerpt (presumably the condition's operator — confirm).
64 return cond_value < cond.threshold;
68 return cond_value <= cond.threshold;
72 return cond_value > cond.threshold;
76 return cond_value >= cond.threshold;
81 return std::abs(cond_value - cond.threshold) < eps;
// Out-of-class definitions of HeuristicTree's static constexpr data members.
// Before C++17 such a definition is required whenever the member is odr-used
// (for example when it is bound to a reference). Since C++17 static constexpr
// members are implicitly inline, which makes these redeclarations redundant.
88 constexpr
size_t HeuristicTree::_max_num_nodes;
89 constexpr
size_t HeuristicTree::_max_query_depth;
// Member initializer list (the constructor's signature is not part of this
// excerpt): stores the id, heuristic type, IP target and data type that were
// passed in, and value-initializes _tree.
97 : _id{
id}, _heuristic_type{h_type}, _ip_target{ip_target}, _data_type{
data_type}, _tree{}
// Looks up a value of type T for a shape by walking the tree from _root.
// Returns a pair whose first member says whether a value was obtained and whose
// second member is that value.
// NOTE(review): the function signature, the loop construct and the test that
// tells branch nodes from leaf nodes are not part of this excerpt.
101 template <
typename T>
// The walk starts at the root node.
105 auto cur_node = _tree.at(_root).get();
// Depth limit: once depth exceeds _max_query_depth the walk is abandoned and
// {false, T{}} is returned, i.e. a failure flag with a value-initialized T.
109 if (depth > _max_query_depth)
113 return std::make_pair(
false, T{});
// The current node is handled as a branch: its condition is evaluated on the
// shape and the walk continues with the true child or the false child.
116 auto br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node);
117 if (evaluate(
shape, br_node->condition))
119 cur_node = _tree.at(br_node->true_node).get();
123 cur_node = _tree.at(br_node->false_node).get();
// The current node is handled as a leaf holding a T: success is reported
// together with the value stored in the leaf.
128 auto l_node = utils::cast::polymorphic_downcast<LeafNode<T> *>(cur_node);
129 return std::make_pair(
true, l_node->value);
// Inserts a leaf node that holds `val` under the key `id`.
// NOTE(review): the function signature and the statements executed when one of
// the checks below fires are not part of this excerpt.
132 template <
typename T>
// Capacity check: the tree already holds _max_num_nodes entries or more.
135 if (_tree.size() >= _max_num_nodes)
// Duplicate check: a node is already stored under this id.
140 if (_tree.find(
id) != _tree.end())
// Store a newly allocated LeafNode<T>, constructed from (id, val), under id.
145 _tree[
id] = std::make_unique<LeafNode<T>>(
id, val);
// Inserts a branch node with condition `cond` and child ids `t_node` / `f_node`
// under the key `id`.
// NOTE(review): the function signature and the statements executed when one of
// the checks below fires are not part of this excerpt.
// Capacity check: the tree already holds _max_num_nodes entries or more.
151 if (_tree.size() >= _max_num_nodes)
// Feature names a branch condition is allowed to use. They are the same names
// for which evaluate() builds values in its feature table.
157 const std::set<std::string> supported_features = {
"m",
"n",
"k",
"b",
"r_mn",
"r_mk",
"r_nk",
"r_mnk",
"workload"};
// Copy of the feature name as supplied by the caller (where it is used is on
// lines not shown in this excerpt).
158 const auto orig_feature = cond.
feature;
// Per-character lowercasing callback; the call that applies it is not visible
// in this excerpt (presumably it lowercases cond.feature in place — confirm).
// NOTE(review): std::tolower has undefined behaviour when given a negative
// char value; converting the argument to unsigned char first would avoid that.
160 [](
char c) { return std::tolower(c); });
// Validation: cond.feature is not one of the supported feature names.
161 if (supported_features.find(cond.
feature) == supported_features.end())
// Duplicate check: a node is already stored under this id.
167 if (_tree.find(
id) != _tree.end())
// Store a newly allocated BranchNode, constructed from
// (id, cond, t_node, f_node), under id.
172 _tree[
id] = std::make_unique<BranchNode>(
id, cond, t_node, f_node);
// Checks the structure of the tree with a breadth-first walk that starts at
// _root. The visible checks are: every id reached during the walk has a node in
// _tree, no id is reached more than once, and every stored node is reached.
// NOTE(review): the statements executed when a check fails, the final return
// and the test that tells branch nodes from leaf nodes are not part of this
// excerpt.
176 bool HeuristicTree::check_if_structurally_correct()
const
// Ids that have already been taken off the work list.
178 std::set<NodeID> visited;
// FIFO work list, seeded with the root id.
179 std::deque<NodeID> to_visit{_root};
181 while (!to_visit.empty())
183 auto id = to_visit.front();
184 to_visit.pop_front();
// The id was referenced but no node is stored under it.
185 if (_tree.find(
id) == _tree.end())
// insert() reports in .second whether the id was newly added; false means the
// id was already visited, i.e. the node is reachable along more than one path.
190 auto not_seen_before = visited.insert(
id);
191 if (!not_seen_before.second)
196 auto cur_node = _tree.at(
id).get();
// For a branch node both children are queued for a later visit.
199 auto br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node);
200 to_visit.push_back(br_node->true_node);
201 to_visit.push_back(br_node->false_node);
// After the walk: every visited id exists in _tree, so a size mismatch means
// some stored nodes were never reached from the root.
204 if (visited.size() != _tree.size())
// Validity check (the function signature and the statement executed when the
// root is missing are not part of this excerpt): first tests whether a node is
// stored under _root, then returns the result of the structural check.
219 if (_tree.find(_root) == _tree.end())
224 return check_if_structurally_correct();
// Explicit instantiations of HeuristicTree::query for the value types
// GEMMType, GEMMConfigNative, GEMMConfigReshapedOnlyRHS and GEMMConfigReshaped.
// NOTE(review): presumably required because the template is defined in this
// source file rather than in a header — confirm.
228 template std::pair<bool, GEMMType> HeuristicTree::query<GEMMType>(
GEMMShape shape)
const;
230 template std::pair<bool, GEMMConfigNative> HeuristicTree::query<GEMMConfigNative>(
GEMMShape shape)
const;
232 template std::pair<bool, GEMMConfigReshapedOnlyRHS>
233 HeuristicTree::query<GEMMConfigReshapedOnlyRHS>(
GEMMShape shape)
const;
235 template std::pair<bool, GEMMConfigReshaped> HeuristicTree::query<GEMMConfigReshaped>(
GEMMShape shape)
const;