Compute Library
 21.08
MLGOHeuristics.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
26 #include "arm_compute/core/Log.h"
29 
30 #include <fstream>
31 
32 namespace arm_compute
33 {
34 namespace mlgo
35 {
36 bool operator==(const GEMMConfigNative &lhs, const GEMMConfigNative &rhs)
37 {
38  return std::tie(lhs.m0, lhs.n0, lhs.k0) == std::tie(rhs.m0, rhs.n0, rhs.k0);
39 }
41 {
42  return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.h0, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.h0, rhs.interleave_rhs, rhs.transpose_rhs,
43  rhs.export_cl_image);
44 }
45 bool operator==(const GEMMConfigReshaped &lhs, const GEMMConfigReshaped &rhs)
46 {
47  return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.v0, lhs.h0, lhs.interleave_lhs, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.v0, rhs.h0,
49 }
50 
51 constexpr size_t MLGOHeuristics::_max_num_trees;
52 
54  : _indices{}, _trees{}, _tree_valid{}, _valid{ false }
55 {
56 }
57 
58 std::pair<bool, GEMMType> MLGOHeuristics::query_gemm_type(const Query &query) const
59 {
60  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm type. %s.", to_string(query).c_str());
61  const auto invalid = GEMMType::RESHAPED;
62  if(!_valid)
63  {
64  ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
65  return { false, invalid };
66  }
67  auto index = std::make_tuple(HeuristicType::GEMM_Type, query.ip_target, query.data_type);
68  GEMMShape shape_query{ query.m, query.n, query.k, query.b };
69  if(_trees.find(index) == _trees.end())
70  {
71  ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
72  return { false, invalid };
73  }
74  return _trees.at(index).query<GEMMType>(shape_query);
75 }
76 std::pair<bool, GEMMConfigNative> MLGOHeuristics::query_gemm_config_native(const Query &query) const
77 {
78  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config native. %s.", to_string(query).c_str());
79  const auto invalid = GEMMConfigNative{};
80  if(!_valid)
81  {
82  ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
83  return { false, invalid };
84  }
85  auto index = std::make_tuple(HeuristicType::GEMM_Config_Native, query.ip_target, query.data_type);
86  GEMMShape shape_query{ query.m, query.n, query.k, query.b };
87  if(_trees.find(index) == _trees.end())
88  {
89  ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
90  return { false, invalid };
91  }
92  return _trees.at(index).query<GEMMConfigNative>(shape_query);
93 }
94 std::pair<bool, GEMMConfigReshapedOnlyRHS> MLGOHeuristics::query_gemm_config_reshaped_only_rhs(const Query &query) const
95 {
96  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped only rhs. %s.", to_string(query).c_str());
97  const auto invalid = GEMMConfigReshapedOnlyRHS{};
98  if(!_valid)
99  {
100  ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
101  return { false, invalid };
102  }
103  auto index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped_Only_RHS, query.ip_target, query.data_type);
104  GEMMShape shape_query{ query.m, query.n, query.k, query.b };
105  if(_trees.find(index) == _trees.end())
106  {
107  ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
108  return { false, invalid };
109  }
110  return _trees.at(index).query<GEMMConfigReshapedOnlyRHS>(shape_query);
111 }
112 std::pair<bool, GEMMConfigReshaped> MLGOHeuristics::query_gemm_config_reshaped(const Query &query) const
113 {
114  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped. %s.", to_string(query).c_str());
115  const auto invalid = GEMMConfigReshaped{};
116  if(!_valid)
117  {
118  ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
119  return { false, invalid };
120  }
121  auto index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped, query.ip_target, query.data_type);
122  GEMMShape shape_query{ query.m, query.n, query.k, query.b };
123  if(_trees.find(index) == _trees.end())
124  {
125  ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
126  return { false, invalid };
127  }
128  return _trees.at(index).query<GEMMConfigReshaped>(shape_query);
129 }
130 
132 {
133  bool status;
134  HeuristicTree *tree{ nullptr };
135  std::tie(status, tree) = get_heuristic_tree(id);
136  if(!status)
137  {
138  return status;
139  }
140  status = tree->check();
141  if(!status)
142  {
143  return status;
144  }
145  _tree_valid[id] = true;
146  return true;
147 }
148 
150 {
151  // Tree validities are already checked and cached.
152  bool all_trees_are_checked = std::find_if(_tree_valid.begin(), _tree_valid.end(), [](auto v)
153  {
154  return !v.second;
155  })
156  == _tree_valid.end();
157  if(!all_trees_are_checked)
158  {
159  ARM_COMPUTE_LOG_INFO_MSG_CORE("Missing checks on some trees. Make sure to call check_heuristic_tree after each tree is completed. This could also indicate there are no trees in the dotmlgo");
160  return false;
161  }
162 
163  // Other top level checks...
164 
165  return true;
166 }
167 
168 std::pair<bool, HeuristicTree *> MLGOHeuristics::get_heuristic_tree(HeuristicTree::TreeID id)
169 {
170  if(_indices.find(id) == _indices.end())
171  {
172  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot find tree with id %zu", id);
173  return std::make_pair(false, nullptr);
174  }
175  const auto index = _indices[id];
176 
177  if(_trees.find(index) == _trees.end())
178  {
179  ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
180  return std::make_pair(false, nullptr);
181  }
182  auto &t = _trees[index];
183 
184  return std::make_pair(true, &t);
185 }
186 
188 {
189  if(_indices.size() >= _max_num_trees)
190  {
191  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the max number of trees allowed: %zu", _max_num_trees);
192  return false;
193  }
194  // PRE: correctness of t is guaranteed by the tree construction process
195  // Ensure unique id
196  const auto id = t.id();
197  if(_indices.find(id) != _indices.end())
198  {
199  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add redundant trees; tree id %zu already exists", id);
200  return false;
201  }
202 
203  // Ensure unique index
204  const auto index = t.index();
205  if(_trees.find(index) != _trees.end())
206  {
207  ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot add redundant trees; tree index already exists");
208  return false;
209  }
210 
211  _indices[id] = index;
212  _trees[index] = std::move(t);
213  _tree_valid[id] = false;
214  return true;
215 }
216 
217 bool MLGOHeuristics::reload_from_file(const std::string &filename)
218 {
219  std::ifstream fs;
220  fs.exceptions(std::ifstream::badbit);
221  fs.open(filename, std::ios::in);
222  if(!fs.is_open())
223  {
224  ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot open DotMLGO file %s. Use default heuristics instead", filename.c_str());
225  return _valid = false;
226  }
227  return reload_from_stream(fs);
228 }
229 
231 {
232  auto parsed = parser::parse_mlgo(in);
233  if(!parsed.first)
234  {
235  ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO parsing failed. Use default heuristics instead");
236  return _valid = false;
237  }
238  *this = std::move(parsed.second);
239  ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO loaded successfully");
240  return _valid = true;
241 }
242 
243 } // namespace mlgo
244 } // namespace arm_compute
unsigned int m0
Number of rows processed by the matrix multiplication.
Definition: Common.h:48
bool transpose_rhs
True if the (k0xn0) block has to be transposed before being stored.
Definition: Common.h:75
unsigned int b
Batch size.
Query interface.
std::pair< bool, HeuristicTree * > get_heuristic_tree(HeuristicTree::TreeID id)
Get the heuristic tree from tree id.
unsigned int k0
Number of partial accumulations performed by the matrix multiplication.
Definition: Common.h:58
bool reload_from_file(const std::string &filename)
(Re)Load the heuristics by reading a dotmlgo file.
bool operator==(const GEMMConfigNative &lhs, const GEMMConfigNative &rhs)
bool interleave_rhs
True if the h0 (k0xn0) blocks have to be interleaved in the output row.
Definition: Common.h:74
GEMM Configuration for Reshaped kernel.
Definition: Common.h:66
bool interleave_rhs
True if the h0 (k0xn0) blocks have to be interleaved in the output row.
Definition: Common.h:60
std::string to_string(const GEMMConfigNative &config)
Definition: Utils.cpp:156
unsigned int h0
Number of horizontal blocks of size (k0xn0) stored on the same output row.
Definition: Common.h:72
bool export_cl_image
True if the reshaped rhs has to be exported to cl_image.
Definition: Common.h:62
CLGEMMKernelType
OpenCL GEMM kernel types.
Definition: CLTypes.h:31
Reshaped GEMM kernel where both lhs and rhs matrices are reshaped.
unsigned int n0
Number of columns processed by the matrix multiplication.
Definition: Common.h:69
std::pair< bool, GEMMType > query_gemm_type(const Query &query) const
Query the gemm type.
unsigned int h0
Number of horizontal blocks of size (k0xn0) stored on the same output row.
Definition: Common.h:59
Copyright (c) 2017-2021 Arm Limited.
bool check_all() const
Check the overall validity of the heuristics.
About the gemm config for reshaped kernel.
About the gemm config for native kernel.
std::string ip_target
The name of the IP target.
#define ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE(fmt,...)
Log information level formatted message to the core system logger.
Definition: Log.h:99
unsigned int n0
Number of columns processed by the matrix multiplication.
Definition: Common.h:57
unsigned int m0
Number of rows processed by the matrix multiplication.
Definition: Common.h:56
bool add_heuristic_tree(HeuristicTree &&t)
Add a heuristic tree.
GEMM Configuration for Reshaped Only RHS kernel.
Definition: Common.h:54
std::pair< bool, MLGOHeuristics > parse_mlgo(std::istream &in)
Parse and construct a MLGOHeuristics from input stream.
Definition: MLGOParser.cpp:798
DataType data_type
Data type.
unsigned int n
Number of columns for the rhs matrix.
bool export_cl_image
True if the reshaped rhs has to be exported to cl_image.
Definition: Common.h:76
unsigned int k0
Number of partial accumulations performed by the matrix multiplication.
Definition: Common.h:50
unsigned int m0
Number of rows processed by the matrix multiplication.
Definition: Common.h:68
A binary decision tree based heuristic.
Definition: HeuristicTree.h:67
bool transpose_rhs
True if the (k0xn0) block has to be transposed before being stored.
Definition: Common.h:61
unsigned int v0
Number of vertical blocks of size (m0xk0) stored on the same output row.
Definition: Common.h:71
bool interleave_lhs
True if the v0 (m0xk0) blocks have to be interleaved in the output row.
Definition: Common.h:73
std::pair< bool, GEMMConfigReshapedOnlyRHS > query_gemm_config_reshaped_only_rhs(const Query &query) const
Query the gemm configuration for reshaped only rhs kernel.
std::pair< bool, GEMMConfigNative > query_gemm_config_native(const Query &query) const
Query the gemm configuration for native kernel.
MLGOHeuristics mlgo(TokenStream &in, bool &valid)
Definition: MLGOParser.cpp:784
bool check_heuristic_tree(HeuristicTree::TreeID id)
Check the validity of the heuristic tree.
unsigned int k0
Number of partial accumulations performed by the matrix multiplication.
Definition: Common.h:70
#define ARM_COMPUTE_LOG_INFO_MSG_CORE(msg)
Log information level message to the core system logger.
Definition: Log.h:87
GEMM Configuration for Native kernel.
Definition: Common.h:46
About the gemm config for reshaped only rhs kernel.
GEMM Shape used for query.
Definition: HeuristicTree.h:58
bool reload_from_stream(std::istream &istream)
(Re)Load the heuristics by reading an input stream.
unsigned int n0
Number of columns processed by the matrix multiplication.
Definition: Common.h:49
unsigned int m
Number of rows for the lhs matrix.
std::pair< bool, GEMMConfigReshaped > query_gemm_config_reshaped(const Query &query) const
Query the gemm configuration for reshaped kernel.
unsigned int k
Number of rows for the rhs matrix.