Compute Library
 23.05
GemmTuner.py
Go to the documentation of this file.
1 # Copyright (c) 2019-2020 Arm Limited.
2 #
3 # SPDX-License-Identifier: MIT
4 #
5 # Permission is hereby granted, free of charge, to any person obtaining a copy
6 # of this software and associated documentation files (the "Software"), to
7 # deal in the Software without restriction, including without limitation the
8 # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
9 # sell copies of the Software, and to permit persons to whom the Software is
10 # furnished to do so, subject to the following conditions:
11 #
12 # The above copyright notice and this permission notice shall be included in all
13 # copies or substantial portions of the Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 # SOFTWARE.
22 
23 #!/usr/bin/python3
24 
25 import argparse
26 import csv
27 import json
28 import logging
29 import math
30 import os
31 from collections import Counter, defaultdict, deque, namedtuple
32 from enum import Enum
33 from pathlib import Path
34 from typing import Deque, Dict, Generator, List, NamedTuple, Set, Tuple, Union
35 
36 ################################################################################
37 # Types
38 ################################################################################
39 
# Gemm strategy
class Strategy(Enum):
    """ GEMM kernel strategies that the tuner can compare.

    Members are numbered from 1 in declaration order, exactly as the functional
    ``Enum("Strategy", [...])`` form numbers them.
    """

    Native = 1
    ReshapedOnlyRHS = 2
    Reshaped = 3
42 
# Gemm parameter


class GEMMParam(NamedTuple):
    """ Description of one GEMM problem: matrix dimensions, batch size and data type.
    """

    M: int  # Number of lhs matrix rows
    N: int  # Number of rhs matrix columns
    K: int  # Number of lhs matrix columns/rhs matrix rows
    B: int  # Batch size
    data_type: str  # Data type

    @classmethod
    def parse_from_strs(cls, *M_N_K_B, data_type):
        """ Build a GEMMParam from the M, N, K, B values (strings or anything int() accepts).

        :param M_N_K_B: The dimension values, in field order.
        :param data_type: Data type name; converted with str().
        """
        dims = [int(value) for value in M_N_K_B]
        return cls(*dims, str(data_type))

    def __str__(self):
        """ Comma-separated field values, e.g. ``64,32,16,1,F32``. """
        return ",".join(str(field) for field in self)
59 
60 
# Gemm configuration for strategy Native
class NativeGEMMConfig(NamedTuple):
    """ Block sizes used by the Native GEMM strategy.
    """

    m0: int  # Number of rows processed by the matrix multiplication
    n0: int  # Number of columns processed by the matrix multiplication
    k0: int  # Number of partial accumulations performed by the matrix multiplication

    @classmethod
    def parse_from_strs(cls, *args):
        """ Build a config from the m0, n0, k0 values given as strings. """
        return cls(*(int(arg) for arg in args))

    def __str__(self):
        """ Comma-separated field values, e.g. ``4,4,8``. """
        return ",".join(str(field) for field in self)
74 
75 
# Gemm configuration for strategy Reshaped Only RHS
class ReshapedOnlyRHSGEMMConfig(NamedTuple):
    """ Block sizes and rhs reshape options used by the Reshaped-Only-RHS GEMM strategy.
    """

    m0: int  # Number of rows processed by the matrix multiplication
    n0: int  # Number of columns processed by the matrix multiplication
    k0: int  # Number of partial accumulations performed by the matrix multiplication
    # Number of horizontal blocks of size (k0xn0) stored on the same output row
    h0: int
    # Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
    interleave_rhs: bool
    # Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)
    transpose_rhs: bool
    # Export rhs matrix to cl_image (1) / Do not export rhs matrix to cl_image (0)
    export_to_cl_image_rhs: bool

    @classmethod
    def parse_from_strs(cls, *args):
        """ Build a config from string values: the block sizes followed by three 0/1 flags.

        A flag is True only when its integer value is exactly 1.
        """
        *block_sizes, interleave, transpose, export = (int(arg) for arg in args)
        return cls(*block_sizes, interleave == 1, transpose == 1, export == 1)

    def __str__(self):
        """ Comma-separated field values, e.g. ``4,4,4,2,True,False,True``. """
        return ",".join(str(field) for field in self)
100 
101 
# Gemm configuration for strategy Reshaped
class ReshapedGEMMConfig(NamedTuple):
    """ Block sizes and lhs/rhs reshape options used by the Reshaped GEMM strategy.
    """

    m0: int  # Number of rows processed by the matrix multiplication
    n0: int  # Number of columns processed by the matrix multiplication
    k0: int  # Number of partial accumulations performed by the matrix multiplication
    # Number of vertical blocks of size (m0xk0) stored on the same output row
    v0: int
    # Number of horizontal blocks of size (k0xn0) stored on the same output row
    h0: int
    # Interleave lhs matrix (1) / Do not interleave lhs matrix (0)
    interleave_lhs: bool
    # Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
    interleave_rhs: bool
    # Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)
    transpose_rhs: bool
    # Export rhs matrix to cl_image (1) / Do not export rhs matrix to cl_image (0)
    export_to_cl_image_rhs: bool

    @classmethod
    def parse_from_strs(cls, *args):
        """ Build a config from string values: the block sizes followed by four 0/1 flags.

        A flag is True only when its integer value is exactly 1.
        """
        *block_sizes, lhs_interleave, rhs_interleave, rhs_transpose, rhs_export = (
            int(arg) for arg in args
        )
        return cls(
            *block_sizes,
            lhs_interleave == 1,
            rhs_interleave == 1,
            rhs_transpose == 1,
            rhs_export == 1,
        )

    def __str__(self):
        """ Comma-separated field values, e.g. ``4,4,2,2,8,True,False,True,False``. """
        return ",".join(str(field) for field in self)
131 
132 
# Measurement we take from the benchmark result.
class Measurement(NamedTuple):
    """ OpenCL timer measurements (milliseconds) taken from a benchmark result.

    The arithmetic operators combine two Measurements field by field (``**`` takes a
    plain number instead) and return a new Measurement. They override the tuple
    meaning of ``+`` (concatenation) and ``*`` (repetition) inherited from NamedTuple.
    """

    opencl_timer_ms_reshape: float  # Reshape time in ms
    opencl_timer_ms_kernel: float  # Kernel time in ms

    def get_total_ms(self):
        """ Return the total time: reshape time plus kernel time. """
        return self.opencl_timer_ms_reshape + self.opencl_timer_ms_kernel

    def is_close_to(self, other, tol):
        """ Return True if the total times of self and ``other`` differ by strictly less than ``tol``. """
        return math.fabs(self.get_total_ms() - other.get_total_ms()) < tol

    def is_better_than(self, other, tol):
        """ Return True if self is faster than ``other`` by a margin of at least ``tol``.

        A lower total time that is still within ``tol`` of ``other`` does not count as better.
        """
        # tol must be forwarded: is_close_to has no default for it, so omitting it raised TypeError.
        return self.get_total_ms() < other.get_total_ms() and not self.is_close_to(
            other, tol
        )

    def __add__(self, other):
        """ Element-wise sum of two Measurements. """
        return Measurement(
            self.opencl_timer_ms_reshape + other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel + other.opencl_timer_ms_kernel,
        )

    def __sub__(self, other):
        """ Element-wise difference of two Measurements. """
        return Measurement(
            self.opencl_timer_ms_reshape - other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel - other.opencl_timer_ms_kernel,
        )

    def __mul__(self, other):
        """ Element-wise product of two Measurements. """
        return Measurement(
            self.opencl_timer_ms_reshape * other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel * other.opencl_timer_ms_kernel,
        )

    def __floordiv__(self, other):
        """ Element-wise floor division of two Measurements. """
        return Measurement(
            self.opencl_timer_ms_reshape // other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel // other.opencl_timer_ms_kernel,
        )

    def __truediv__(self, other):
        """ Element-wise true division of two Measurements. """
        return Measurement(
            self.opencl_timer_ms_reshape / other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel / other.opencl_timer_ms_kernel,
        )

    def __pow__(self, power):
        """ Raise both timings to the scalar ``power``. """
        return Measurement(
            self.opencl_timer_ms_reshape ** power, self.opencl_timer_ms_kernel ** power
        )

    def __str__(self):
        """ Comma-separated field values, e.g. ``0.5,1.25``. """
        return ",".join(map(str, self))
186 
187 
# GEMMConfig Type: any one of the strategy-specific configuration tuples
GEMMConfigT = Union[
    NativeGEMMConfig,
    ReshapedOnlyRHSGEMMConfig,
    ReshapedGEMMConfig,
]
191 
192 
# Representation of the benchmark result from a single experiment
class BenchmarkResult(NamedTuple):
    """ One benchmark run: the problem, the strategy and config used, and the measured timings.
    """

    gemm_param: GEMMParam  # GEMM problem that was benchmarked
    strategy: Strategy  # Strategy that was run
    gemm_config: GEMMConfigT  # Strategy-specific configuration that was run
    measurement: Measurement  # Timings recorded for this run
199 
200 
class GEMMBenchmarkResultRecorder:
    """ A recorder that records and organises GEMM Benchmark results, and produces various reports on the record.
    """

    SummaryLevel = Enum("SummaryLevel", ["Short", "Detailed"])

    def __init__(self, tol=0.01):
        """ Initializer

        :param tol: Tolerance, in ms of total OpenCL time, within which two measurements are
            treated as equivalent when selecting the best configurations.
        """
        self._benchmark_result_record: List[BenchmarkResult] = []
        # Strategies recorded
        self._strategies = set()
        self._tol = tol

    def add(self, benchmark_result: BenchmarkResult):
        """ Add a benchmark result to the record.
        """
        _, strategy, _, _ = benchmark_result
        # Update strategies encountered
        self._strategies.add(strategy)

        self._benchmark_result_record.append(benchmark_result)

    def get_record(self) -> Generator[BenchmarkResult, None, None]:
        """ Return an iterator that iterates over the record.
        """
        yield from self._benchmark_result_record

    def get_best_gemm_configs(self):
        """ Get the best GEMMConfig set per GEMMParam per Strategy.

        :return: Mapping (GEMMParam, Strategy) -> list of (GEMMConfig, Measurement). Each list is
            sorted by total time. It holds the fastest config first, followed by every other config
            whose total time is within the recorder's tolerance of it.
        """
        # Group all recorded results by (GEMMParam, Strategy), keeping record order
        all_gc_sets: Dict[
            Tuple[GEMMParam, Strategy], List[Tuple[GEMMConfigT, Measurement]]
        ] = defaultdict(list)
        for gemm_param, strategy, gemm_config, measurement in self.get_record():
            all_gc_sets[(gemm_param, strategy)].append((gemm_config, measurement))

        best_gc_sets: Dict[
            Tuple[GEMMParam, Strategy], List[Tuple[GEMMConfigT, Measurement]]
        ] = defaultdict(list)
        for key, gc_set in all_gc_sets.items():
            # Sort each group once (stable, so equal timings keep record order) rather than
            # re-sorting after every insertion.
            gc_set = sorted(gc_set, key=lambda gc_and_m: gc_and_m[1].get_total_ms())
            best_gc, best_m = gc_set[0]
            # Keep the best config plus the configs within tolerance of its measurement
            best_gc_sets[key] = [(best_gc, best_m)] + [
                (gemm_config, measurement)
                for gemm_config, measurement in gc_set[1:]
                if measurement.is_close_to(best_m, self._tol)
            ]

        return best_gc_sets

    def get_best_gemm_configs_as_sequence(self):
        """ Get the best GEMMConfig set per GEMMParam per Strategy, and flatten the result into a sequence
        of BenchmarkResults
        """
        for (gemm_param, strategy), best_gc_set in self.get_best_gemm_configs().items():
            for best_gemm_config, best_measurement in best_gc_set:
                yield BenchmarkResult(
                    gemm_param, strategy, best_gemm_config, best_measurement
                )

    def get_config_distributions(self):
        """ Return GEMMConfigDistribution for each strategy, built from the best config sets.
        """
        gemm_config_distributions: Dict[Strategy, GEMMConfigDistribution] = defaultdict(
            GEMMConfigDistribution
        )
        for benchmark_result in self.get_best_gemm_configs_as_sequence():
            _, strategy, _, _ = benchmark_result
            gemm_config_distributions[strategy].add(benchmark_result)

        return gemm_config_distributions

    def get_best_gemm_strategies(self):
        """ Get the best Strategy per GEMMParam: the one with the lowest total time over all its records.
        """
        all_results: Dict[GEMMParam, List[Tuple[Strategy, Measurement]]] = defaultdict(
            list
        )
        for gemm_param, strategy, _, measurement in self.get_record():
            all_results[gemm_param].append((strategy, measurement))

        best_strategies: Dict[GEMMParam, Strategy] = {}
        for gemm_param, results_set in all_results.items():
            # min() returns the first of equally fast results, like the head of a stable sort
            best_s, _ = min(results_set, key=lambda s_and_m: s_and_m[1].get_total_ms())
            best_strategies[gemm_param] = best_s

        return best_strategies

    def save_to_jsons(self, out_dir, only_best_config=True):
        """ Save records to an output directory of JSON files.
        The directory is organized such that each strategy gets its own JSON file.
        The directory also includes a JSON file to define the best strategy per GEMM Param.

        :param out_dir: Output directory; created (with any missing parents) if it does not exist.
        :param only_best_config: If True, save only the best config sets; otherwise save every record.
        """
        if not os.path.exists(out_dir):
            logging.info(
                "Output directory {} does not exist. Creating...".format(
                    out_dir)
            )
            # makedirs also creates missing parent directories
            os.makedirs(out_dir, exist_ok=True)

        out_json_path = os.path.join(out_dir, "gemm_type_selection.json")
        if check_out_path(out_json_path):
            results = self.get_best_gemm_strategies()
            results = {str(key): value.name for key, value in results.items()}
            dump_json(out_json_path, results)

        # Sorted so the (possibly interactive) per-strategy prompts come in a stable order
        for strategy in sorted(self._strategies, key=lambda s: s.value):
            out_json_path = os.path.join(
                out_dir, ("gemm_config_" + strategy.name.lower() + ".json")
            )
            if not check_out_path(out_json_path):
                continue
            record = (
                self.get_best_gemm_configs_as_sequence()
                if only_best_config
                else self.get_record()
            )
            results = defaultdict(list)
            for res in record:
                if res.strategy == strategy:
                    results[str(res.gemm_param)].append(
                        {
                            "GEMMConfig": str(res.gemm_config),
                            "OpenCL_Timer_ms_reshape": str(
                                res.measurement.opencl_timer_ms_reshape
                            ),
                            "OpenCL_Timer_ms_kernel": str(
                                res.measurement.opencl_timer_ms_kernel
                            ),
                        }
                    )
            dump_json(out_json_path, results)

    def summary(self, sum_level=SummaryLevel.Short):
        """ Return the summary string of the record.

        :param sum_level: SummaryLevel.Detailed additionally lists the GEMM parameters per strategy.
        """
        num_raw_records = len(self._benchmark_result_record)
        gemm_params_per_strategy = defaultdict(list)
        for gemm_param, strategy in self.get_best_gemm_configs().keys():
            gemm_params_per_strategy[strategy].append(gemm_param)
        strategy_names = ", ".join(
            s.name for s in sorted(self._strategies, key=lambda s: s.value)
        )
        global_summary = f"""
=== {self.__class__.__name__} Summary ===
[Global]
Strategies recorded: {strategy_names}
Total number of results recorded: {num_raw_records}

[Per strategy]
        """
        strategy_summaries = []
        for strategy in gemm_params_per_strategy:
            summary = f"""
Strategy {strategy.name}:
GEMM parameters:
    Number of: {len(gemm_params_per_strategy[strategy])}
            """
            if sum_level == self.__class__.SummaryLevel.Detailed:
                summary += f"""
    Content: {gemm_params_per_strategy[strategy]}
                """
            strategy_summaries.append(summary)
        return global_summary + "".join(strategy_summaries)
375 
376 
class GEMMConfigDistribution:
    """ A representation of the GEMM Configuration distribution produced by the GEMMBenchmarkResultRecorder.
    """

    def __init__(self):
        """ Initializer
        """
        # For each GEMMConfig, the (GEMMParam, Measurement) pairs recorded for it via add()
        self._gemm_config_dist: Dict[
            GEMMConfigT, List[Tuple[GEMMParam, Measurement]]
        ] = defaultdict(list)
        # Number of results ("votes") recorded for each GEMMConfig
        self._gemm_config_freq = Counter()

    def add(self, benchmark_result: BenchmarkResult):
        """ Add a benchmark result to the distribution, counting one vote for its GEMMConfig.
        """
        gemm_param, _, gemm_config, measurement = benchmark_result
        self._gemm_config_dist[gemm_config].append((gemm_param, measurement))
        self._gemm_config_freq[gemm_config] += 1

    def distribution(self):
        """ Get the mapping GEMMConfig -> list of (GEMMParam, Measurement) recorded so far.
        """
        return self._gemm_config_dist

    def frequency(self):
        """ Get the frequency of each (best) gemm config recorded, as (config, votes) pairs, most voted first.
        """
        return self._gemm_config_freq.most_common()

    def best_config(self):
        """ Get the overall best config, as voted by all benchmark results.

        :return: A list holding the single most voted (config, votes) pair, or an empty list if nothing was added.
        """
        return self._gemm_config_freq.most_common(1)

    def std(self):
        """ Get the (population) standard deviation of the vote counts as a measure of dispersion of the
        distribution. We should aim for higher values as they indicate there is high variation in the
        distribution. Thus the evidence of the best config is stronger. Returns 0 when nothing was added.
        """
        freqs = list(self._gemm_config_freq.values())
        if not freqs:
            return 0
        mean_freq = sum(freqs) / len(freqs)
        return math.sqrt(sum((freq - mean_freq) ** 2 for freq in freqs) / len(freqs))
418 
419 
420 ################################################################################
421 # Globals
422 ################################################################################
423 
# Gemm config type factory: Strategy -> the GEMMConfig type used by that strategy
GEMM_CONFIG_FACTORY = {
    Strategy.Native: NativeGEMMConfig,
    Strategy.ReshapedOnlyRHS: ReshapedOnlyRHSGEMMConfig,
    Strategy.Reshaped: ReshapedGEMMConfig,
}

# Example binary name -> Strategy (assumed to be a 1-to-1 mapping)
EXAMPLE_FILE_2_STRATEGY = dict(
    benchmark_cl_gemm_native=Strategy.Native,
    benchmark_cl_gemm_reshaped_rhs_only=Strategy.ReshapedOnlyRHS,
    benchmark_cl_gemm_reshaped=Strategy.Reshaped,
)

# Gemm example arguments type factory: Strategy -> namedtuple type "<Strategy>_Gemm_Example_Args".
# The example args are the GEMMParam fields followed by the strategy's GEMMConfig fields.
# For instance, a reshaped-rhs-only run could take:
#   100,100,100,1, 4, 4, 4, 1, 1, 1, 0
#   M  ,N  ,K  ,B,m0,n0,k0,h0,interleave_rhs,transpose_rhs,export_to_cl_image_rhs
#   |-- GEMMParam --|------------------- GEMMConfig -------------------------------|
GEMM_EXAMPLE_ARGS_FACTORY = {
    # GEMMParam's trailing data type field is left out: it is passed as a separate --type option
    strategy: namedtuple(
        "{}_Gemm_Example_Args".format(strategy.name),
        GEMMParam._fields[:-1] + GEMM_CONFIG_FACTORY[strategy]._fields,
    )
    # Iterating an Enum yields only canonical members, so any enum aliases are skipped
    for strategy in Strategy
}

# File extension used for benchmark result json files
BENCHMARK_RESULT_JSON_EXTENSION = "gemmtuner_benchmark"
462 
463 ################################################################################
464 # Functions
465 ################################################################################
466 
467 
def parse_benchmark_commandline(commandline: str) -> Dict[str, str]:
    """ Parse the benchmark example command-line string into a dictionary of command-line arguments.

    The first whitespace-separated token (the program name) is discarded. Each remaining token is
    split at its FIRST "=" into a name, with leading dashes stripped, and a value. Tokens without
    "=" map to an empty string. A data type option glued to the example args by a comma
    (",--type=...") is split off first.

    Example:
        "prog --example_args=1,2,3,--type=F32" -> {"example_args": "1,2,3", "type": "F32"}
    """
    # Separate the data type option from the example_args portion of the string
    commandline = commandline.replace(",--type=", " --type=")

    parsed = {}
    # Discard program name
    for arg in commandline.split()[1:]:
        # partition splits only at the first "=", so values containing "=" stay intact and
        # bare flags (no "=") no longer break tuple unpacking.
        name, _, val = arg.partition("=")
        # Strip '-'/"--" if it exists
        parsed[name.lstrip("-")] = val
    return parsed
486 
487 
def extract_benchmark_results(
    json_results: Dict, measurement_method="avg"
) -> Generator[BenchmarkResult, None, None]:
    """ Parse the benchmark result and extract relevant information, namely:
        GEMM param,
        Strategy,
        GEMM config,
        Measurements

    :param json_results: Parsed benchmark JSON documents. NOTE: despite the ``Dict`` annotation, this
        is iterated, one result dict per benchmark run.
    :param measurement_method: "avg" to average each measurement's raw samples, "min" to take the smallest.
    :raises ValueError: for an invalid measurement method (raised on the first iteration), for a result
        that does not contain exactly one test, or for a measurement not taken by OpenCLTimer.
    """
    # Reduce the raw samples of one measurement to a single value
    reducers = {
        "min": min,
        "avg": lambda raw: sum(raw) / len(raw),
    }
    if measurement_method not in reducers:
        raise ValueError(
            "Invalid measurement method: {}".format(measurement_method)
        )
    reduce_raw = reducers[measurement_method]

    # Gemm_Example_Arg consists of GEMMParam first and then GEMMConfig (in that order).
    # However data type option is parsed separately from end of options, hence -1 is applied to fields length
    gemm_param_fields_len = len(GEMMParam._fields) - 1

    for json_res in json_results:
        # Get example test and test data. There should only be 1 test per run
        example_tests = list(json_res["tests"].items())
        if len(example_tests) != 1:
            raise ValueError(
                "Expected exactly 1 test per benchmark result, found {}".format(
                    len(example_tests))
            )
        example_fn, example_test_data = example_tests[0]

        # Keep only the example binary name (drop directory components), then get strategy
        example_fn = example_fn.split(os.path.sep)[-1]
        strategy = EXAMPLE_FILE_2_STRATEGY[example_fn]

        # Get gemm params + gemm configs from example args
        benchmark_args = parse_benchmark_commandline(json_res["CommandLine"])
        Gemm_Example_Args_T = GEMM_EXAMPLE_ARGS_FACTORY[strategy]
        example_args = Gemm_Example_Args_T(
            *(benchmark_args["example_args"].split(",")))
        gemm_param = GEMMParam.parse_from_strs(
            *example_args[:gemm_param_fields_len],
            data_type=benchmark_args["type"])
        GEMMConfig = GEMM_CONFIG_FACTORY[strategy]
        gemm_config = GEMMConfig.parse_from_strs(
            *example_args[gemm_param_fields_len:])

        # Some strategies report two measurements (one also for the reshape kernel); each one is
        # classified as reshape or kernel time. Their sum is Measurement.get_total_ms().
        measurement_ms_reshape = 0
        measurement_ms_kernel = 0
        for measurement_instrument, data in example_test_data["measurements"].items():
            # Get instrument name and check that it is the one we expect
            measurement_instrument_name = measurement_instrument.split("/")[0]
            if measurement_instrument_name != "OpenCLTimer":
                raise ValueError(
                    "Unexpected measurement instrument: {}".format(
                        measurement_instrument_name)
                )
            measurement_val = reduce_raw(data["raw"])

            measurement_type = measurement_instrument.split("/")[1]
            if "reshape" in measurement_type.split("_"):
                measurement_ms_reshape = measurement_val
            else:
                measurement_ms_kernel = measurement_val

        measurement = Measurement(
            measurement_ms_reshape, measurement_ms_kernel)

        yield BenchmarkResult(gemm_param, strategy, gemm_config, measurement)
556 
557 
def parse_json(dir_name):
    """ Lazily yield the parsed JSON content of every benchmark result file found under ``dir_name``.

    The search is recursive (rglob) and matches files ending in ``.`` + BENCHMARK_RESULT_JSON_EXTENSION.
    """
    pattern = "*.{}".format(BENCHMARK_RESULT_JSON_EXTENSION)
    for result_path in Path(dir_name).rglob(pattern):
        with open(result_path) as result_file:
            yield json.load(result_file)
564 
565 
def check_out_path(out_path):
    """ Decide whether an output file may be written at ``out_path``.

    When the path already exists, the user is asked on stdin whether to overwrite it. Only an
    answer of "y" (any case) accepts.

    :return: True if the file should be written, False if the user declined.
    """
    if os.path.exists(out_path):
        prompt = "Output JSON {} already exists. Overwrite? [Y/N]: ".format(out_path)
        if input(prompt).lower() != "y":
            logging.info("Skipping {}".format(out_path))
            return False
    logging.info("Saving JSON file to {}".format(out_path))
    return True
580 
581 
def dump_json(out_path, dict):
    """ Write ``dict`` as JSON to ``out_path``, replacing any existing file, and log "Saved".
    """
    # NOTE: the parameter name shadows the builtin ``dict``; kept as-is for keyword callers.
    with open(out_path, "w") as out_file:
        json.dump(dict, fp=out_file)
        logging.info("Saved")
586 
587 
588 ################################################################################
589 # Main
590 ################################################################################
591 
592 
def main(args):
    """ Analyse a directory of benchmark results and report the best GEMM configurations.

    :param args: Parsed command-line arguments. The function reads ``benchmark_results_dir``,
        ``tolerance``, ``debug`` and ``output_dir``. ``debug`` selects the detailed summary and,
        when saving, makes every record be saved instead of only the best configs. JSON files are
        written only when ``output_dir`` is not None.
    """
    logging.info(
        "Searching best gemm configurations from {}".format(args.benchmark_results_dir)
    )

    # Feed every parsed benchmark result into the recorder
    recorder = GEMMBenchmarkResultRecorder(tol=args.tolerance)
    for result in extract_benchmark_results(parse_json(args.benchmark_results_dir)):
        recorder.add(result)

    levels = GEMMBenchmarkResultRecorder.SummaryLevel
    sum_level = levels.Detailed if args.debug else levels.Short

    # Overall summary of what was recorded
    logging.info(recorder.summary(sum_level=sum_level))

    # Per-strategy GEMM configuration vote distributions
    config_dists = recorder.get_config_distributions()

    logging.info("=== Result ===")
    for strategy, config_dist in config_dists.items():
        logging.info("Strategy: {}".format(strategy.name))
        logging.debug("GEMM Config, Votes")
        for config, freq in config_dist.frequency():
            logging.debug("{}, {}".format(config, freq))
        best, spread = config_dist.best_config(), config_dist.std()
        logging.info("Best GEMM Config: {} with std: {}".format(best, spread))

    # Persist the recorded results as JSON files when an output directory was given
    if args.output_dir is None:
        return
    recorder.save_to_jsons(args.output_dir, only_best_config=not args.debug)
637 
638 
if __name__ == "__main__":
    # Command-line entry point: parse options, configure logging, run main()
    parser = argparse.ArgumentParser(description="CL GEMM Tuner")
    parser.add_argument(
        "-b",
        "--benchmark_results",
        dest="benchmark_results_dir",
        metavar="PATH",
        action="store",
        type=str,
        # Adjacent literals are joined at compile time, so no source indentation leaks into the
        # help text (a backslash continuation inside the literal would embed it).
        help=(
            "Path to benchmark result directory, where benchmark result json files have a file "
            "extension of '{}'"
        ).format(BENCHMARK_RESULT_JSON_EXTENSION),
        required=True,
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        dest="output_dir",
        metavar="PATH",
        action="store",
        type=str,
        help="Path to directory that holds output JSON files. One for strategy selection and one per strategy for GEMM config selection",
    )
    parser.add_argument(
        "-t",
        "--tolerance",
        action="store",
        type=float,
        default=0.01,
        help=(
            "For testing if two GEMMConfigs are equivalent in terms of performance. The tolerance is OpenCL timer in "
            "milliseconds. Recommended value: <= 0.1 ms"
        ),
    )
    parser.add_argument(
        "-D",
        "--debug",
        dest="debug",
        action="store_true",
        help="Enable script debugging output",
    )
    args = parser.parse_args()
    logging_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=logging_level)
    logging.debug("Arguments: {}".format(args))
    main(args)
def __str__(self)
Definition: GemmTuner.py:57
std::string join(T first, T last, const std::string &separator)
Helper function to concatenate multiple strings.
Definition: Utils.h:93
def __truediv__(self, other)
Definition: GemmTuner.py:173
def __add__(self, other)
Definition: GemmTuner.py:149
def parse_from_strs(cls, M_N_K_B, data_type)
Definition: GemmTuner.py:54
def parse_from_strs(cls, args)
Definition: GemmTuner.py:68
def check_out_path(out_path)
Definition: GemmTuner.py:566
def __pow__(self, power)
Definition: GemmTuner.py:179
def extract_benchmark_results
Definition: GemmTuner.py:489
def __mul__(self, other)
Definition: GemmTuner.py:161
def parse_json(dir_name)
Definition: GemmTuner.py:558
def get_total_ms(self)
Definition: GemmTuner.py:138
def is_close_to(self, other, tol)
Definition: GemmTuner.py:141
def is_better_than(self, other, tol)
Definition: GemmTuner.py:144
def dump_json(out_path, dict)
Definition: GemmTuner.py:582
void map(T &tensor, bool blocking)
Maps a tensor if needed.
Definition: Utils.h:212
def summary(self, sum_level=SummaryLevel.Short)
Definition: GemmTuner.py:347
def save_to_jsons(self, out_dir, only_best_config=True)
Definition: GemmTuner.py:303
dst_shape set(0, output_wh.first)
def main(args)
Main.
Definition: GemmTuner.py:593
def parse_from_strs(cls, args)
Definition: GemmTuner.py:121
def __sub__(self, other)
Definition: GemmTuner.py:155
def __floordiv__(self, other)
Definition: GemmTuner.py:167
def parse_benchmark_commandline
Functions.
Definition: GemmTuner.py:468