Compute Library
 21.02
HeuristicTree.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 #include "arm_compute/core/Log.h"
26 
27 #include <algorithm>
28 #include <deque>
29 #include <set>
30 namespace arm_compute
31 {
32 namespace mlgo
33 {
34 namespace
35 {
36 bool evaluate(GEMMShape shape, Condition cond)
37 {
38  // PRE: all features and ConditionalOps are valid
39  constexpr float eps = 0.0001f;
40  // Calculate all secondary features
41  std::vector<std::pair<std::string, float>> cond_values
42  {
43  { "m", static_cast<float>(shape.m) },
44  { "n", static_cast<float>(shape.n) },
45  { "k", static_cast<float>(shape.k) },
46  { "b", static_cast<float>(shape.b) },
47  { "r_mn", static_cast<float>(shape.m) / shape.n },
48  { "r_mk", static_cast<float>(shape.m) / shape.k },
49  { "r_nk", static_cast<float>(shape.n) / shape.k },
50  { "r_mnk", static_cast<float>(shape.m) / (static_cast<float>(shape.n) / shape.k) },
51  { "workload", (static_cast<float>(shape.m) * shape.n * shape.b) / 20.0 }
52  };
53  auto cond_value_pair_it = std::find_if(cond_values.begin(), cond_values.end(),
54  [&cond](decltype(*cond_values.begin()) it)
55  {
56  return it.first == cond.feature;
57  });
58 
59  ARM_COMPUTE_ERROR_ON(cond_value_pair_it == cond_values.end());
60  const float cond_value = cond_value_pair_it->second;
61  switch(cond.op)
62  {
63  case ConditionalOp::LT:
64  {
65  return cond_value < cond.threshold;
66  }
67  case ConditionalOp::LE:
68  {
69  return cond_value <= cond.threshold;
70  }
71  case ConditionalOp::GT:
72  {
73  return cond_value > cond.threshold;
74  }
75  case ConditionalOp::GE:
76  {
77  return cond_value >= cond.threshold;
78  }
79  case ConditionalOp::EQ:
80  default:
81  {
82  return std::abs(cond_value - cond.threshold) < eps;
83  }
84  }
85 }
86 
87 } // namespace
88 
// Namespace-scope definitions of the class's static constexpr data members.
// Before C++17 such a definition is required whenever a member is odr-used
// (for example bound to a reference); from C++17 onwards static constexpr
// members are implicitly inline and these lines are redundant but harmless.
89 constexpr size_t HeuristicTree::_max_num_nodes;
90 constexpr size_t HeuristicTree::_max_query_depth;
91 constexpr HeuristicTree::NodeID HeuristicTree::_root;
92 
95 {
96 }
97 
98 HeuristicTree::HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type)
99  : _id{ id }, _heuristic_type{ h_type }, _ip_target{ ip_target }, _data_type{ data_type }, _tree{}
100 {
101 }
102 
103 template <typename T>
104 std::pair<bool, T> HeuristicTree::query(GEMMShape shape) const
105 {
106  // Root ID = 0;
107  auto cur_node = _tree.at(_root).get();
108  size_t depth = 0;
109  while(cur_node->type() != NodeType::Leaf)
110  {
111  if(depth > _max_query_depth)
112  {
113  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding max query depth: %zu. Is the tree too deep?", _max_query_depth);
114  return std::make_pair(false, T{});
115  }
116  ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Branch, "Unexpected NodeType");
117  auto br_node = dynamic_cast<BranchNode *>(cur_node);
118  if(evaluate(shape, br_node->condition))
119  {
120  cur_node = _tree.at(br_node->true_node).get();
121  }
122  else
123  {
124  cur_node = _tree.at(br_node->false_node).get();
125  }
126  ++depth;
127  }
128  ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Leaf, "Unexpected NodeType");
129  auto l_node = dynamic_cast<LeafNode<T> *>(cur_node);
130  return std::make_pair(true, l_node->value);
131 }
132 
133 template <typename T>
135 {
136  if(_tree.size() >= _max_num_nodes)
137  {
138  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
139  return false;
140  }
141  if(_tree.find(id) != _tree.end())
142  {
143  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
144  return false;
145  }
146  _tree[id] = std::make_unique<LeafNode<T>>(id, val);
147  return true;
148 }
149 
151 {
152  if(_tree.size() >= _max_num_nodes)
153  {
154  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
155  return false;
156  }
157 
158  const std::set<std::string> supported_features =
159  {
160  "m", "n", "k", "b", "r_mn", "r_mk", "r_nk", "r_mnk", "workload"
161  };
162  const auto orig_feature = cond.feature;
163  std::transform(cond.feature.begin(), cond.feature.end(), cond.feature.begin(), [](char c)
164  {
165  return std::tolower(c);
166  });
167  if(supported_features.find(cond.feature) == supported_features.end())
168  {
169  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Unsupported feature %s", orig_feature.c_str());
170  return false;
171  }
172 
173  if(_tree.find(id) != _tree.end())
174  {
175  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
176  return false;
177  }
178  _tree[id] = std::make_unique<BranchNode>(id, cond, t_node, f_node);
179  return true;
180 }
181 
182 bool HeuristicTree::check_if_structurally_correct() const
183 {
184  std::set<NodeID> visited;
185  std::deque<NodeID> to_visit{ _root };
186 
187  while(!to_visit.empty())
188  {
189  auto id = to_visit.front();
190  to_visit.pop_front();
191  if(_tree.find(id) == _tree.end())
192  {
193  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing node %zu", id);
194  return false;
195  }
196  auto not_seen_before = visited.insert(id);
197  if(!not_seen_before.second)
198  {
199  ARM_COMPUTE_LOG_INFO_MSG_CORE("Not a tree; contains cycles or loops");
200  return false;
201  }
202  auto cur_node = _tree.at(id).get();
203  if(cur_node->type() == NodeType::Branch)
204  {
205  auto br_node = dynamic_cast<BranchNode *>(cur_node);
206  to_visit.push_back(br_node->true_node);
207  to_visit.push_back(br_node->false_node);
208  }
209  }
210  if(visited.size() != _tree.size())
211  {
212  ARM_COMPUTE_LOG_INFO_MSG_CORE("Contains disjoint nodes");
213  return false;
214  }
215  return true;
216 }
217 
219 {
220  if(_tree.empty())
221  {
222  ARM_COMPUTE_LOG_INFO_MSG_CORE("Empty tree encountered");
223  return false;
224  }
225  if(_tree.find(_root) == _tree.end())
226  {
227  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing root. Root must have a Node ID of %zu", _root);
228  return false;
229  }
230  return check_if_structurally_correct();
231 }
232 
233 /** Explicit template instantiation @relates HeuristicTree */
234 template std::pair<bool, GEMMType> HeuristicTree::query<GEMMType>(GEMMShape shape) const;
235 /** Explicit template instantiation @relates HeuristicTree */
236 template std::pair<bool, GEMMConfigNative> HeuristicTree::query<GEMMConfigNative>(GEMMShape shape) const;
237 /** Explicit template instantiation @relates HeuristicTree */
238 template std::pair<bool, GEMMConfigReshapedOnlyRHS> HeuristicTree::query<GEMMConfigReshapedOnlyRHS>(GEMMShape shape) const;
239 /** Explicit template instantiation @relates HeuristicTree */
240 template std::pair<bool, GEMMConfigReshaped> HeuristicTree::query<GEMMConfigReshaped>(GEMMShape shape) const;
241 
242 /** Explicit template instantiation @relates HeuristicTree */
243 template bool HeuristicTree::add_leaf(NodeID id, GEMMType val);
244 /** Explicit template instantiation @relates HeuristicTree */
245 template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigNative val);
246 /** Explicit template instantiation @relates HeuristicTree */
248 /** Explicit template instantiation @relates HeuristicTree */
249 template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigReshaped val);
250 
251 } // namespace mlgo
252 
253 } // namespace arm_compute
GEMM Configuration for Reshaped kernel.
Definition: Common.h:66
bool add_branch(NodeID id, Condition cond, NodeID true_node, NodeID false_node)
Add a branch node.
1 channel, 1 F32 per channel
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true then an error message is printed and an exception thrown.
Definition: Error.h:466
CLGEMMKernelType
OpenCL GEMM kernel types.
Definition: CLTypes.h:31
HeuristicType
Types of Heuristic (tree)
Definition: Common.h:35
Copyright (c) 2017-2021 Arm Limited.
std::pair< bool, T > query(GEMMShape shape) const
Query a leaf value given a gemm shape.
const DataType data_type
Definition: Im2Col.cpp:150
std::string tolower(std::string string)
Convert string to lower case.
Definition: Utility.h:203
#define ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE(fmt,...)
Log information level formatted message to the core system logger.
Definition: Log.h:99
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
Definition: Error.h:456
GEMM Configuration for Reshaped Only RHS kernel.
Definition: Common.h:54
Greater than or equal to.
bool check()
Check if tree is valid.
A binary decision tree based heuristic.
Definition: HeuristicTree.h:67
MLGOHeuristics mlgo(TokenStream &in, bool &valid)
Definition: MLGOParser.cpp:784
A branch condition expression evaluating: feature op threshold.
Definition: HeuristicTree.h:50
#define ARM_COMPUTE_LOG_INFO_MSG_CORE(msg)
Log information level message to the core system logger.
Definition: Log.h:87
GEMM Configuration for Native kernel.
Definition: Common.h:46
GEMM Shape used for query.
Definition: HeuristicTree.h:58
bool add_leaf(NodeID id, T leaf_value)
Add a leaf node.
TreeID id() const
Get tree ID.
DataType
Available data types.
Definition: Types.h:77
std::string feature
Feature name.
Definition: HeuristicTree.h:52