Compute Library
 21.11
HeuristicTree.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 #include "arm_compute/core/Log.h"
26 
27 #include "support/Cast.h"
28 
29 #include <algorithm>
30 #include <deque>
31 #include <set>
32 namespace arm_compute
33 {
34 namespace mlgo
35 {
36 namespace
37 {
38 bool evaluate(GEMMShape shape, Condition cond)
39 {
40  // PRE: all features and ConditionalOps are valid
41  constexpr float eps = 0.0001f;
42  // Calculate all secondary features
43  std::vector<std::pair<std::string, float>> cond_values
44  {
45  { "m", static_cast<float>(shape.m) },
46  { "n", static_cast<float>(shape.n) },
47  { "k", static_cast<float>(shape.k) },
48  { "b", static_cast<float>(shape.b) },
49  { "r_mn", static_cast<float>(shape.m) / shape.n },
50  { "r_mk", static_cast<float>(shape.m) / shape.k },
51  { "r_nk", static_cast<float>(shape.n) / shape.k },
52  { "r_mnk", static_cast<float>(shape.m) / (static_cast<float>(shape.n) / shape.k) },
53  { "workload", (static_cast<float>(shape.m) * shape.n * shape.b) / 20.0 }
54  };
55  auto cond_value_pair_it = std::find_if(cond_values.begin(), cond_values.end(),
56  [&cond](decltype(*cond_values.begin()) it)
57  {
58  return it.first == cond.feature;
59  });
60 
61  ARM_COMPUTE_ERROR_ON(cond_value_pair_it == cond_values.end());
62  const float cond_value = cond_value_pair_it->second;
63  switch(cond.op)
64  {
65  case ConditionalOp::LT:
66  {
67  return cond_value < cond.threshold;
68  }
69  case ConditionalOp::LE:
70  {
71  return cond_value <= cond.threshold;
72  }
73  case ConditionalOp::GT:
74  {
75  return cond_value > cond.threshold;
76  }
77  case ConditionalOp::GE:
78  {
79  return cond_value >= cond.threshold;
80  }
81  case ConditionalOp::EQ:
82  default:
83  {
84  return std::abs(cond_value - cond.threshold) < eps;
85  }
86  }
87 }
88 
89 } // namespace
90 
91 constexpr size_t HeuristicTree::_max_num_nodes;
92 constexpr size_t HeuristicTree::_max_query_depth;
93 constexpr HeuristicTree::NodeID HeuristicTree::_root;
94 
97 {
98 }
99 
100 HeuristicTree::HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type)
101  : _id{ id }, _heuristic_type{ h_type }, _ip_target{ ip_target }, _data_type{ data_type }, _tree{}
102 {
103 }
104 
105 template <typename T>
106 std::pair<bool, T> HeuristicTree::query(GEMMShape shape) const
107 {
108  // Root ID = 0;
109  auto cur_node = _tree.at(_root).get();
110  size_t depth = 0;
111  while(cur_node->type() != NodeType::Leaf)
112  {
113  if(depth > _max_query_depth)
114  {
115  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding max query depth: %zu. Is the tree too deep?", _max_query_depth);
116  return std::make_pair(false, T{});
117  }
118  ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Branch, "Unexpected NodeType");
119  auto br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node);
120  if(evaluate(shape, br_node->condition))
121  {
122  cur_node = _tree.at(br_node->true_node).get();
123  }
124  else
125  {
126  cur_node = _tree.at(br_node->false_node).get();
127  }
128  ++depth;
129  }
130  ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Leaf, "Unexpected NodeType");
131  auto l_node = utils::cast::polymorphic_downcast<LeafNode<T> *>(cur_node);
132  return std::make_pair(true, l_node->value);
133 }
134 
135 template <typename T>
137 {
138  if(_tree.size() >= _max_num_nodes)
139  {
140  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
141  return false;
142  }
143  if(_tree.find(id) != _tree.end())
144  {
145  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
146  return false;
147  }
148  _tree[id] = std::make_unique<LeafNode<T>>(id, val);
149  return true;
150 }
151 
153 {
154  if(_tree.size() >= _max_num_nodes)
155  {
156  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
157  return false;
158  }
159 
160  const std::set<std::string> supported_features =
161  {
162  "m", "n", "k", "b", "r_mn", "r_mk", "r_nk", "r_mnk", "workload"
163  };
164  const auto orig_feature = cond.feature;
165  std::transform(cond.feature.begin(), cond.feature.end(), cond.feature.begin(), [](char c)
166  {
167  return std::tolower(c);
168  });
169  if(supported_features.find(cond.feature) == supported_features.end())
170  {
171  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Unsupported feature %s", orig_feature.c_str());
172  return false;
173  }
174 
175  if(_tree.find(id) != _tree.end())
176  {
177  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
178  return false;
179  }
180  _tree[id] = std::make_unique<BranchNode>(id, cond, t_node, f_node);
181  return true;
182 }
183 
184 bool HeuristicTree::check_if_structurally_correct() const
185 {
186  std::set<NodeID> visited;
187  std::deque<NodeID> to_visit{ _root };
188 
189  while(!to_visit.empty())
190  {
191  auto id = to_visit.front();
192  to_visit.pop_front();
193  if(_tree.find(id) == _tree.end())
194  {
195  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing node %zu", id);
196  return false;
197  }
198  auto not_seen_before = visited.insert(id);
199  if(!not_seen_before.second)
200  {
201  ARM_COMPUTE_LOG_INFO_MSG_CORE("Not a tree; contains cycles or loops");
202  return false;
203  }
204  auto cur_node = _tree.at(id).get();
205  if(cur_node->type() == NodeType::Branch)
206  {
207  auto br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node);
208  to_visit.push_back(br_node->true_node);
209  to_visit.push_back(br_node->false_node);
210  }
211  }
212  if(visited.size() != _tree.size())
213  {
214  ARM_COMPUTE_LOG_INFO_MSG_CORE("Contains disjoint nodes");
215  return false;
216  }
217  return true;
218 }
219 
221 {
222  if(_tree.empty())
223  {
224  ARM_COMPUTE_LOG_INFO_MSG_CORE("Empty tree encountered");
225  return false;
226  }
227  if(_tree.find(_root) == _tree.end())
228  {
229  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing root. Root must have a Node ID of %zu", _root);
230  return false;
231  }
232  return check_if_structurally_correct();
233 }
234 
235 /** Explicit template instantiation @relates HeuristicTree */
236 template std::pair<bool, GEMMType> HeuristicTree::query<GEMMType>(GEMMShape shape) const;
237 /** Explicit template instantiation @relates HeuristicTree */
238 template std::pair<bool, GEMMConfigNative> HeuristicTree::query<GEMMConfigNative>(GEMMShape shape) const;
239 /** Explicit template instantiation @relates HeuristicTree */
240 template std::pair<bool, GEMMConfigReshapedOnlyRHS> HeuristicTree::query<GEMMConfigReshapedOnlyRHS>(GEMMShape shape) const;
241 /** Explicit template instantiation @relates HeuristicTree */
242 template std::pair<bool, GEMMConfigReshaped> HeuristicTree::query<GEMMConfigReshaped>(GEMMShape shape) const;
243 
244 /** Explicit template instantiation @relates HeuristicTree */
245 template bool HeuristicTree::add_leaf(NodeID id, GEMMType val);
246 /** Explicit template instantiation @relates HeuristicTree */
247 template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigNative val);
248 /** Explicit template instantiation @relates HeuristicTree */
250 /** Explicit template instantiation @relates HeuristicTree */
251 template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigReshaped val);
252 
253 } // namespace mlgo
254 
255 } // namespace arm_compute
GEMM Configuration for Reshaped kernel.
Definition: Common.h:66
bool add_branch(NodeID id, Condition cond, NodeID true_node, NodeID false_node)
Add a branch node.
1 channel, 1 F32 per channel
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true then an error message is printed and an exception thrown.
Definition: Error.h:466
CLGEMMKernelType
OpenCL GEMM kernel types.
Definition: CLTypes.h:31
HeuristicType
Types of Heuristic (tree)
Definition: Common.h:35
Copyright (c) 2017-2021 Arm Limited.
std::pair< bool, T > query(GEMMShape shape) const
Query a leaf value given a gemm shape.
const DataType data_type
Definition: Im2Col.cpp:150
std::string tolower(std::string string)
Convert string to lower case.
Definition: Utility.h:205
#define ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE(fmt,...)
Log information level formatted message to the core system logger.
Definition: Log.h:99
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
Definition: Error.h:456
GEMM Configuration for Reshaped Only RHS kernel.
Definition: Common.h:54
Greater than or equal to.
bool check()
Check if tree is valid.
A binary decision tree based heuristic.
Definition: HeuristicTree.h:67
MLGOHeuristics mlgo(TokenStream &in, bool &valid)
Definition: MLGOParser.cpp:784
A branch condition expression evaluating: feature op threshold.
Definition: HeuristicTree.h:50
#define ARM_COMPUTE_LOG_INFO_MSG_CORE(msg)
Log information level message to the core system logger.
Definition: Log.h:87
GEMM Configuration for Native kernel.
Definition: Common.h:46
GEMM Shape used for query.
Definition: HeuristicTree.h:58
bool add_leaf(NodeID id, T leaf_value)
Add a leaf node.
TreeID id() const
Get tree ID.
DataType
Available data types.
Definition: Types.h:79
std::string feature
Feature name.
Definition: HeuristicTree.h:52