Compute Library
 21.05
HeuristicTree.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef SRC_RUNTIME_CL_MLGO_HEURISTIC_TREE_H
25 #define SRC_RUNTIME_CL_MLGO_HEURISTIC_TREE_H
26 
#include "arm_compute/core/Types.h"

#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
34 
35 namespace arm_compute
36 {
37 namespace mlgo
38 {
/** Comparison operator of a branch condition, applied as: feature <op> threshold */
enum class ConditionalOp : int
{
    EQ, /**< Equal                    (feature == threshold) */
    LT, /**< Less than                (feature <  threshold) */
    LE, /**< Less than or equal to    (feature <= threshold) */
    GT, /**< Greater than             (feature >  threshold) */
    GE  /**< Greater than or equal to (feature >= threshold) */
};
48 
49 /** A branch condition expression evaluating: feature op threshold */
50 struct Condition
51 {
52  std::string feature; /**< Feature name */
53  ConditionalOp op; /**< Condtional op */
54  float threshold; /**< Threshold value */
55 };
56 
/** GEMM problem dimensions passed to a heuristic tree query
 *
 * The dimensions describe the lhs and rhs matrices in their NOT transposed layout.
 */
struct GEMMShape
{
    /** Number of rows of the lhs matrix */
    unsigned int m;
    /** Number of columns of the rhs matrix */
    unsigned int n;
    /** Number of rows of the rhs matrix */
    unsigned int k;
    /** Batch size */
    unsigned int b;
};
65 
66 /** A binary decision tree based heuristic */
68 {
69 public:
70  using NodeID = size_t;
71  using TreeID = size_t;
72  using Index = std::tuple<HeuristicType, std::string, DataType>;
73  enum class NodeType
74  {
75  Branch,
76  Leaf
77  };
78  struct Node
79  {
80  virtual NodeType type() const = 0;
81  virtual ~Node() = default;
82  };
83 
84  struct BranchNode : public Node
85  {
86  BranchNode(NodeID id, Condition cond, NodeID t_node, NodeID f_node)
87  : id{ id }, condition{ cond }, true_node{ t_node }, false_node{ f_node }
88  {
89  }
90  NodeType type() const override
91  {
92  return NodeType::Branch;
93  }
98  };
99 
100  template <typename T>
101  struct LeafNode : public Node
102  {
103  LeafNode(NodeID id, T val)
104  : id{ id }, value{ val }
105  {
106  }
107  NodeType type() const override
108  {
109  return NodeType::Leaf;
110  }
112  T value;
113  };
114 
115 public:
116  /** Constructor */
117  HeuristicTree();
118  /** Constructor */
119  HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type);
120  // Since the HeuristicTree is a handle that owns the the nodes, it is move-only
121  /** Prevent copy construction */
122  HeuristicTree(const HeuristicTree &) = delete;
123  /** Prevent copy assignment */
124  HeuristicTree &operator=(const HeuristicTree &) = delete;
125  /** Move constructor */
126  HeuristicTree(HeuristicTree &&other) noexcept = default;
127  /** Move assignment */
128  HeuristicTree &operator=(HeuristicTree &&other) = default;
129 
130  /** Query a leaf value given a gemm shape
131  *
132  * @tparam T Leaf value type
133  * @param shape A @ref GEMMShape for the query
134  * @return std::pair<bool, T> Outcome contains bool, signalling if the query succeeded or not
135  */
136  template <typename T>
137  std::pair<bool, T> query(GEMMShape shape) const;
138 
139  /** Add a leaf node
140  *
141  * @tparam T Leaf value type
142  * @param id Leaf node ID
143  * @param leaf_value Leaf node value
144  * @return bool If the addition succeeded or not
145  */
146  template <typename T>
147  bool add_leaf(NodeID id, T leaf_value);
148  /** Add a branch node
149  *
150  * @param id Branch node ID
151  * @param cond Branch node @ref Condition
152  * @param true_node True node's ID
153  * @param false_node False node's ID
154  * @return bool If the addition succeeded or not
155  */
156  bool add_branch(NodeID id, Condition cond, NodeID true_node, NodeID false_node);
157 
158  /** Get tree ID
159  * @return TreeID
160  */
161  TreeID id() const
162  {
163  return _id;
164  }
165 
166  /** Get tree index
167  * @return Index
168  */
169  Index index() const
170  {
171  return std::make_tuple(_heuristic_type, _ip_target, _data_type);
172  }
173 
174  /** Check if tree is valid
175  * @return bool
176  */
177  bool check();
178 
179 private:
180  static constexpr size_t _max_query_depth{ 1000 }; // Maximum depth of query
181  static constexpr size_t _max_num_nodes{ 100000 }; // Maximum number of nodes contained by the tree
182  static constexpr NodeID _root{ 0 }; // Root tree ID
183 
184 private:
185  bool check_if_structurally_correct() const;
186 
187 private:
188  TreeID _id; /**< Heuristic tree ID */
189  HeuristicType _heuristic_type; /**< Heuristic type */
190  std::string _ip_target; /**< IP target associated with the tree */
191  DataType _data_type; /**< Data type associated with the tree */
192  std::map<NodeID, std::unique_ptr<Node>> _tree; /**< Tree representation */
193 };
194 } // namespace mlgo
195 
196 } // namespace arm_compute
197 
198 #endif //SRC_RUNTIME_CL_MLGO_HEURISTIC_TREE_H
float threshold
Threshold value.
Definition: HeuristicTree.h:54
unsigned int n
Number of columns for the rhs matrix.
Definition: HeuristicTree.h:61
HeuristicTree & operator=(const HeuristicTree &)=delete
Prevent copy assignment.
unsigned int m
Number of rows for the lhs matrix.
Definition: HeuristicTree.h:60
bool add_branch(NodeID id, Condition cond, NodeID true_node, NodeID false_node)
Add a branch node.
std::tuple< HeuristicType, std::string, DataType > Index
Definition: HeuristicTree.h:72
unsigned int k
Number of rows for the rhs matrix.
Definition: HeuristicTree.h:62
HeuristicType
Types of Heuristic (tree)
Definition: Common.h:35
Copyright (c) 2017-2021 Arm Limited.
ConditionalOp
Conditional ops.
Definition: HeuristicTree.h:40
std::pair< bool, T > query(GEMMShape shape) const
Query a leaf value given a gemm shape.
const DataType data_type
Definition: Im2Col.cpp:150
Greater than or equal to.
unsigned int b
Batch size.
Definition: HeuristicTree.h:63
bool check()
Check if tree is valid.
BranchNode(NodeID id, Condition cond, NodeID t_node, NodeID f_node)
Definition: HeuristicTree.h:86
A binary decision tree based heuristic.
Definition: HeuristicTree.h:67
MLGOHeuristics mlgo(TokenStream &in, bool &valid)
Definition: MLGOParser.cpp:784
A branch condition expression evaluating: feature op threshold.
Definition: HeuristicTree.h:50
virtual NodeType type() const =0
Index index() const
Get tree index.
GEMM Shape used for query.
Definition: HeuristicTree.h:58
bool add_leaf(NodeID id, T leaf_value)
Add a leaf node.
TreeID id() const
Get tree ID.
DataType
Available data types.
Definition: Types.h:77
ConditionalOp op
Conditional op.
Definition: HeuristicTree.h:53
std::string feature
Feature name.
Definition: HeuristicTree.h:52