Compute Library
 21.11
HeuristicTree Class Reference

A binary decision tree based heuristic. More...

#include <HeuristicTree.h>

Data Structures

struct  BranchNode
 
struct  LeafNode
 
struct  Node
 

Public Types

enum  NodeType { Branch, Leaf }
 
using NodeID = size_t
 
using TreeID = size_t
 
using Index = std::tuple< HeuristicType, std::string, DataType >
 

Public Member Functions

 HeuristicTree ()
 Constructor. More...
 
 HeuristicTree (TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type)
 Constructor. More...
 
 HeuristicTree (const HeuristicTree &)=delete
 Prevent copy construction. More...
 
HeuristicTree & operator= (const HeuristicTree &)=delete
 Prevent copy assignment. More...
 
 HeuristicTree (HeuristicTree &&other) noexcept=default
 Move constructor. More...
 
HeuristicTree & operator= (HeuristicTree &&other)=default
 Move assignment. More...
 
template<typename T >
std::pair< bool, T > query (GEMMShape shape) const
 Query a leaf value given a gemm shape. More...
 
template<typename T >
bool add_leaf (NodeID id, T leaf_value)
 Add a leaf node. More...
 
bool add_branch (NodeID id, Condition cond, NodeID true_node, NodeID false_node)
 Add a branch node. More...
 
TreeID id () const
 Get tree ID. More...
 
Index index () const
 Get tree index. More...
 
bool check ()
 Check if tree is valid. More...
 

Related Functions

(Note that these are not member functions.)

template std::pair< bool, GEMMType > query< GEMMType > (GEMMShape shape) const
 Explicit template instantiation. More...
 
template std::pair< bool, GEMMConfigNative > query< GEMMConfigNative > (GEMMShape shape) const
 Explicit template instantiation. More...
 
template std::pair< bool, GEMMConfigReshapedOnlyRHS > query< GEMMConfigReshapedOnlyRHS > (GEMMShape shape) const
 Explicit template instantiation. More...
 
template std::pair< bool, GEMMConfigReshaped > query< GEMMConfigReshaped > (GEMMShape shape) const
 Explicit template instantiation. More...
 
template bool add_leaf (NodeID id, GEMMType val)
 Explicit template instantiation. More...
 
template bool add_leaf (NodeID id, GEMMConfigNative val)
 Explicit template instantiation. More...
 
template bool add_leaf (NodeID id, GEMMConfigReshapedOnlyRHS val)
 Explicit template instantiation. More...
 
template bool add_leaf (NodeID id, GEMMConfigReshaped val)
 Explicit template instantiation. More...
 

Detailed Description

A binary decision tree based heuristic.

Definition at line 67 of file HeuristicTree.h.

Member Typedef Documentation

◆ Index

using Index = std::tuple<HeuristicType, std::string, DataType>

Definition at line 72 of file HeuristicTree.h.

◆ NodeID

using NodeID = size_t

Definition at line 70 of file HeuristicTree.h.

◆ TreeID

using TreeID = size_t

Definition at line 71 of file HeuristicTree.h.

Member Enumeration Documentation

◆ NodeType

enum NodeType
strong
Enumerator
Branch 
Leaf 

Definition at line 73 of file HeuristicTree.h.

74  {
75  Branch,
76  Leaf
77  };

Constructor & Destructor Documentation

◆ HeuristicTree() [1/4]

Constructor.

Definition at line 95 of file HeuristicTree.cpp.

References arm_compute::F32, and arm_compute::mlgo::GEMM_Type.

97 {
98 }
1 channel, 1 F32 per channel

◆ HeuristicTree() [2/4]

HeuristicTree ( TreeID  id,
HeuristicType  h_type,
const std::string &  ip_target,
DataType  data_type 
)

Constructor.

Definition at line 100 of file HeuristicTree.cpp.

References arm_compute::test::validation::data_type.

101  : _id{ id }, _heuristic_type{ h_type }, _ip_target{ ip_target }, _data_type{ data_type }, _tree{}
102 {
103 }
const DataType data_type
Definition: Im2Col.cpp:150

◆ HeuristicTree() [3/4]

HeuristicTree ( const HeuristicTree & )
delete

Prevent copy construction.

◆ HeuristicTree() [4/4]

HeuristicTree ( HeuristicTree &&  other)
default noexcept

Move constructor.

Member Function Documentation

◆ add_branch()

bool add_branch ( NodeID  id,
Condition  cond,
NodeID  true_node,
NodeID  false_node 
)

Add a branch node.

Parameters
id  Branch node ID
cond  Branch node Condition
true_node  True node's ID
false_node  False node's ID
Returns
bool If the addition succeeded or not

Definition at line 152 of file HeuristicTree.cpp.

References ARM_COMPUTE_LOG_INFO_MSG_CORE, ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE, HeuristicTree::Branch, Condition::feature, HeuristicTree::id(), and arm_compute::utility::tolower().

Referenced by arm_compute::mlgo::parser::heuristic_tree().

153 {
154  if(_tree.size() >= _max_num_nodes)
155  {
156  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
157  return false;
158  }
159 
160  const std::set<std::string> supported_features =
161  {
162  "m", "n", "k", "b", "r_mn", "r_mk", "r_nk", "r_mnk", "workload"
163  };
164  const auto orig_feature = cond.feature;
165  std::transform(cond.feature.begin(), cond.feature.end(), cond.feature.begin(), [](char c)
166  {
167  return std::tolower(c);
168  });
169  if(supported_features.find(cond.feature) == supported_features.end())
170  {
171  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Unsupported feature %s", orig_feature.c_str());
172  return false;
173  }
174 
175  if(_tree.find(id) != _tree.end())
176  {
177  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
178  return false;
179  }
180  _tree[id] = std::make_unique<BranchNode>(id, cond, t_node, f_node);
181  return true;
182 }
std::string tolower(std::string string)
Convert string to lower case.
Definition: Utility.h:205
#define ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE(fmt,...)
Log information level formatted message to the core system logger.
Definition: Log.h:99
TreeID id() const
Get tree ID.

◆ add_leaf()

bool add_leaf ( NodeID  id,
T  leaf_value 
)

Add a leaf node.

Template Parameters
T  Leaf value type
Parameters
id  Leaf node ID
leaf_value  Leaf node value
Returns
bool If the addition succeeded or not

Definition at line 136 of file HeuristicTree.cpp.

References ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE, and HeuristicTree::id().

Referenced by HeuristicTree::check(), and arm_compute::mlgo::parser::heuristic_tree().

137 {
138  if(_tree.size() >= _max_num_nodes)
139  {
140  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
141  return false;
142  }
143  if(_tree.find(id) != _tree.end())
144  {
145  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
146  return false;
147  }
148  _tree[id] = std::make_unique<LeafNode<T>>(id, val);
149  return true;
150 }
#define ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE(fmt,...)
Log information level formatted message to the core system logger.
Definition: Log.h:99
TreeID id() const
Get tree ID.

◆ check()

bool check ( )

Check if tree is valid.

Returns
bool

Definition at line 220 of file HeuristicTree.cpp.

References HeuristicTree::add_leaf(), ARM_COMPUTE_LOG_INFO_MSG_CORE, ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE, and arm_compute::test::validation::shape.

221 {
222  if(_tree.empty())
223  {
224  ARM_COMPUTE_LOG_INFO_MSG_CORE("Empty tree encountered");
225  return false;
226  }
227  if(_tree.find(_root) == _tree.end())
228  {
229  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing root. Root must have a Node ID of %zu", _root);
230  return false;
231  }
232  return check_if_structurally_correct();
233 }
#define ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE(fmt,...)
Log information level formatted message to the core system logger.
Definition: Log.h:99
#define ARM_COMPUTE_LOG_INFO_MSG_CORE(msg)
Log information level message to the core system logger.
Definition: Log.h:87

◆ id()

TreeID id ( ) const
inline

Get tree ID.

Returns
TreeID

Definition at line 161 of file HeuristicTree.h.

Referenced by HeuristicTree::add_branch(), and HeuristicTree::add_leaf().

162  {
163  return _id;
164  }

◆ index()

Index index ( void  ) const
inline

Get tree index.

Returns
Index

Definition at line 169 of file HeuristicTree.h.

Referenced by arm_compute::mlgo::parser::heuristic_tree().

170  {
171  return std::make_tuple(_heuristic_type, _ip_target, _data_type);
172  }

◆ operator=() [1/2]

HeuristicTree& operator= ( const HeuristicTree & )
delete

Prevent copy assignment.

◆ operator=() [2/2]

HeuristicTree& operator= ( HeuristicTree &&  other)
default

Move assignment.

◆ query()

std::pair< bool, T > query ( GEMMShape  shape) const

Query a leaf value given a gemm shape.

Template Parameters
T  Leaf value type
Parameters
shape  A GEMMShape for the query
Returns
std::pair<bool, T> Outcome contains bool, signalling if the query succeeded or not

Definition at line 106 of file HeuristicTree.cpp.

References ARM_COMPUTE_ERROR_ON_MSG, ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE, HeuristicTree::Branch, and HeuristicTree::Leaf.

107 {
108  // Root ID = 0;
109  auto cur_node = _tree.at(_root).get();
110  size_t depth = 0;
111  while(cur_node->type() != NodeType::Leaf)
112  {
113  if(depth > _max_query_depth)
114  {
115  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding max query depth: %zu. Is the tree too deep?", _max_query_depth);
116  return std::make_pair(false, T{});
117  }
118  ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Branch, "Unexpected NodeType");
119  auto br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node);
120  if(evaluate(shape, br_node->condition))
121  {
122  cur_node = _tree.at(br_node->true_node).get();
123  }
124  else
125  {
126  cur_node = _tree.at(br_node->false_node).get();
127  }
128  ++depth;
129  }
130  ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Leaf, "Unexpected NodeType");
131  auto l_node = utils::cast::polymorphic_downcast<LeafNode<T> *>(cur_node);
132  return std::make_pair(true, l_node->value);
133 }
#define ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE(fmt,...)
Log information level formatted message to the core system logger.
Definition: Log.h:99
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
Definition: Error.h:456

Friends And Related Function Documentation

◆ add_leaf() [1/4]

template bool add_leaf ( NodeID  id,
GEMMType  val 
)
related

Explicit template instantiation.

◆ add_leaf() [2/4]

template bool add_leaf ( NodeID  id,
GEMMConfigNative  val 
)
related

Explicit template instantiation.

◆ add_leaf() [3/4]

template bool add_leaf ( NodeID  id,
GEMMConfigReshapedOnlyRHS  val 
)
related

Explicit template instantiation.

◆ add_leaf() [4/4]

template bool add_leaf ( NodeID  id,
GEMMConfigReshaped  val 
)
related

Explicit template instantiation.

◆ query() [1/4]

template std::pair< bool, GEMMType > query< GEMMType > ( GEMMShape  shape) const
related

Explicit template instantiation.

◆ query() [2/4]

template std::pair< bool, GEMMConfigNative > query< GEMMConfigNative > ( GEMMShape  shape) const
related

Explicit template instantiation.

◆ query() [3/4]

template std::pair< bool, GEMMConfigReshapedOnlyRHS > query< GEMMConfigReshapedOnlyRHS > ( GEMMShape  shape) const
related

Explicit template instantiation.

◆ query() [4/4]

template std::pair< bool, GEMMConfigReshaped > query< GEMMConfigReshaped > ( GEMMShape  shape) const
related

Explicit template instantiation.


The documentation for this class was generated from the following files: