Compute Library
 21.11
MLGOHeuristics.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef SRC_RUNTIME_CL_MLGO_MLGO_HEURISTICS_H
25 #define SRC_RUNTIME_CL_MLGO_MLGO_HEURISTICS_H
26 
29 
30 #include <iostream>
31 #include <map>
32 #include <string>
33 #include <utility>
34 namespace arm_compute
35 {
36 namespace mlgo
37 {
38 /** Query interface */
39 struct Query
40 {
41  std::string ip_target; /**< The name of the IP target */
42  DataType data_type; /**< Data type */
43  unsigned int m; /**< Number of rows for the lhs matrix. Lhs matrix NOT transposed */
44  unsigned int n; /**< Number of columns for the rhs matrix. Rhs matrix NOT transposed */
45  unsigned int k; /**< Number of rows for the rhs matrix. Rhs matrix NOT transposed */
46  unsigned int b; /**< Batch size */
47 };
48 
49 bool operator==(const GEMMConfigNative &lhs, const GEMMConfigNative &rhs);
51 bool operator==(const GEMMConfigReshaped &lhs, const GEMMConfigReshaped &rhs);
52 
53 /** MLGOHeuristics for configuring GEMM kernels */
55 {
56 public:
57  /** Constructor */
59  /** Default Destructor */
60  ~MLGOHeuristics() = default;
61  /** Prevent Copy Construct */
62  MLGOHeuristics(const MLGOHeuristics &) = delete;
63  /** Prevent Copy Assignment */
64  MLGOHeuristics &operator=(const MLGOHeuristics &) = delete;
65  /** Default Move Constructor */
66  MLGOHeuristics(MLGOHeuristics &&) = default;
67  /** Default Move Assignment */
68  MLGOHeuristics &operator=(MLGOHeuristics &&) = default;
69  /** Query the gemm type
70  *
71  * @param[in] query Query
72  *
73  * @return std::pair<bool, GEMMType> signals if the query succeeded or failed
74  */
75  std::pair<bool, GEMMType> query_gemm_type(const Query &query) const;
76  /** Query the gemm configuration for native kernel
77  *
78  * @param[in] query Query
79  *
80  * @return std::pair<bool, GEMMConfigNative> bool signals if the query succeeded or failed
81  */
82  std::pair<bool, GEMMConfigNative> query_gemm_config_native(const Query &query) const;
83  /** Query the gemm configuration for reshaped only rhs kernel
84  *
85  * @param[in] query Query
86  *
87  * @return std::pair<bool, GEMMConfigReshapedOnlyRHS> bool signals if the query succeeded or failed
88  */
89  std::pair<bool, GEMMConfigReshapedOnlyRHS> query_gemm_config_reshaped_only_rhs(const Query &query) const;
90  /** Query the gemm configuration for reshaped kernel
91  *
92  * @param[in] query Query
93  *
94  * @return std::pair<bool, GEMMConfigReshaped> bool signals if the query succeeded or failed
95  */
96  std::pair<bool, GEMMConfigReshaped> query_gemm_config_reshaped(const Query &query) const;
97  /** (Re)Load the heuristics from reading a dotmlgo file
98  *
99  * @param[in] filename Path to the dotmlgo file
100  *
101  * @return bool Signals if the reload succeeded or failed
102  */
103  bool reload_from_file(const std::string &filename);
104  /** (Re)Load the heuristics from reading an input stream
105  *
106  * @param[in] istream Istream containing mlgo heuristics
107  *
108  * @return bool Signals if the reload succeeded or failed
109  */
110  bool reload_from_stream(std::istream &istream);
111 
112  /** Get the heuristic tree from tree id
113  *
114  * @param[in] id Tree id.
115  *
116  * @return HeuristicTree&
117  */
118  std::pair<bool, HeuristicTree *> get_heuristic_tree(HeuristicTree::TreeID id);
119  /** Add a heuristic tree
120  * @param t Heuristic tree to be added
121  */
122  bool add_heuristic_tree(HeuristicTree &&t);
123 
124  /** Check the validity of the heuristic tree.
125  *
126  * @param id ID of the tree to be checked
127  *
128  * @return bool
129  */
130  bool check_heuristic_tree(HeuristicTree::TreeID id);
131 
132  /** Check the overall validity of the heuristics.
133  * @return bool
134  */
135  bool check_all() const;
136 
137 private:
138  static constexpr size_t _max_num_trees{ 100 }; /**< Max number of trees that can be added*/
139 
140 private:
141  // There exists a one-to-one mappipng between TreeID and Index, either can be used to identify a @ref HeuristicTree
142  std::map<HeuristicTree::TreeID, HeuristicTree::Index> _indices; /**< A mapping from TreeID to Index */
143  std::map<HeuristicTree::Index, HeuristicTree> _trees; /**< A mapping from Index to HeuristicTree */
144  std::map<HeuristicTree::TreeID, bool> _tree_valid; /**< Result cache of the tree validity checks */
145  bool _valid; /**< Overall validity */
146 };
147 
148 } // namespace mlgo
149 } // namespace arm_compute
150 #endif //SRC_RUNTIME_CL_MLGO_MLGO_HEURISTICS_H
unsigned int b
Batch size.
Query interface.
bool operator==(const GEMMConfigNative &lhs, const GEMMConfigNative &rhs)
GEMM Configuration for Reshaped kernel.
Definition: Common.h:66
Copyright (c) 2017-2021 Arm Limited.
std::string ip_target
The name of the IP target.
GEMM Configuration for Reshaped Only RHS kernel.
Definition: Common.h:54
MLGOHeuristics for configuring GEMM kernels.
DataType data_type
Data type.
unsigned int n
Number of columns for the rhs matrix.
A binary decision tree based heuristic.
Definition: HeuristicTree.h:67
MLGOHeuristics mlgo(TokenStream &in, bool &valid)
Definition: MLGOParser.cpp:784
GEMM Configuration for Native kernel.
Definition: Common.h:46
DataType
Available data types.
Definition: Types.h:79
unsigned int m
Number of rows for the lhs matrix.
unsigned int k
Number of rows for the rhs matrix.