31 from collections
import Counter, defaultdict, deque, namedtuple
33 from pathlib
import Path
34 from typing
import Deque, Dict, Generator, List, NamedTuple, Set, Tuple, Union
41 Strategy = Enum(
"Strategy", [
"Native",
"ReshapedOnlyRHS",
"Reshaped"])
55 return cls(*
map(int, M_N_K_B),
str(data_type))
69 (*mnk,) =
map(int, args)
88 export_to_cl_image_rhs: bool
92 (*mnkh, interleave_rhs, transpose_rhs, export_to_cl_image_rhs,) =
map(int, args)
93 interleave_rhs = interleave_rhs == 1
94 transpose_rhs = transpose_rhs == 1
95 export_to_cl_image_rhs = export_to_cl_image_rhs == 1
96 return cls(*mnkh, interleave_rhs, transpose_rhs, export_to_cl_image_rhs)
118 export_to_cl_image_rhs: bool
122 (*mnkvh, interleave_lhs, interleave_rhs, transpose_rhs, export_to_cl_image_rhs,) =
map(int, args)
123 interleave_lhs = interleave_lhs == 1
124 interleave_rhs = interleave_rhs == 1
125 transpose_rhs = transpose_rhs == 1
126 export_to_cl_image_rhs = export_to_cl_image_rhs == 1
127 return cls(*mnkvh, interleave_lhs, interleave_rhs, transpose_rhs, export_to_cl_image_rhs)
130 return ",".
join(
map(str, self))
135 opencl_timer_ms_reshape: float
136 opencl_timer_ms_kernel: float
139 return self.opencl_timer_ms_reshape + self.opencl_timer_ms_kernel
142 return math.fabs(self.
get_total_ms() - other.get_total_ms()) < tol
151 self.opencl_timer_ms_reshape + other.opencl_timer_ms_reshape,
152 self.opencl_timer_ms_kernel + other.opencl_timer_ms_kernel,
157 self.opencl_timer_ms_reshape - other.opencl_timer_ms_reshape,
158 self.opencl_timer_ms_kernel - other.opencl_timer_ms_kernel,
163 self.opencl_timer_ms_reshape * other.opencl_timer_ms_reshape,
164 self.opencl_timer_ms_kernel * other.opencl_timer_ms_kernel,
169 self.opencl_timer_ms_reshape // other.opencl_timer_ms_reshape,
170 self.opencl_timer_ms_kernel // other.opencl_timer_ms_kernel,
175 self.opencl_timer_ms_reshape / other.opencl_timer_ms_reshape,
176 self.opencl_timer_ms_kernel / other.opencl_timer_ms_kernel,
181 self.opencl_timer_ms_reshape ** power, self.opencl_timer_ms_kernel ** power
185 return ",".
join(
map(str, self))
189 GEMMConfigT = Union[NativeGEMMConfig,
190 ReshapedOnlyRHSGEMMConfig, ReshapedGEMMConfig]
195 gemm_param: GEMMParam
197 gemm_config: GEMMConfigT
198 measurement: Measurement
202 """ A recorder that records and organises GEMM Benchmark results, and produces various reports on the record.
205 SummaryLevel = Enum(
"SummaryLevel", [
"Short",
"Detailed"])
210 self._benchmark_result_record: List[BenchmarkResult] = []
215 def add(self, benchmark_result: BenchmarkResult):
216 """ Add a benchmark result to the record.
218 gemm_param, strategy, gemm_config, measurement = benchmark_result
222 self._benchmark_result_record.append(benchmark_result)
224 def get_record(self) -> Generator[BenchmarkResult, None, None]:
225 """ Return an iterator that iterates over the record.
227 yield from self._benchmark_result_record
230 """ Get the best GEMMConfig set per GEMMParam per Strategy
233 Tuple[GEMMParam, Strategy], List[Tuple[GEMMConfig, Measurement]]
234 ] = defaultdict(list)
235 for gemm_param, strategy, gemm_config, measurement
in self.
get_record():
236 best_gc_set = best_gc_sets.setdefault((gemm_param, strategy), [])
237 best_gc_set.append((gemm_config, measurement))
239 best_gc_set = sorted(
240 best_gc_set, key=
lambda gc_and_m: gc_and_m[1].get_total_ms()
243 best_gc, best_m = best_gc_set[0]
245 (gemm_config, measurement)
246 for gemm_config, measurement
in best_gc_set[1:]
247 if measurement.is_close_to(best_m, self.
_tol)
250 best_gc_set_new.insert(0, (best_gc, best_m))
251 best_gc_sets[(gemm_param, strategy)] = best_gc_set_new
256 """ Get the best GEMMConfig set per GEMMParam per Strategy, and flatten the result into a sequence
260 (gemm_param, strategy),
263 for best_gemm_config, best_measurement
in best_gc_sets:
265 gemm_param, strategy, best_gemm_config, best_measurement
269 """ Return GEMMConfigDistribution for each strategy
271 gemm_config_distributions: Dict[Strategy, GEMMConfigDistribution] = defaultdict(
272 GEMMConfigDistribution
275 _, strategy, _, _ = benchmark_result
276 gemm_config_distributions[strategy].
add(benchmark_result)
278 return gemm_config_distributions
281 """ Get the best Strategy per GEMMParam
283 all_results: Dict[GEMMParam, List[Tuple[Strategy, Measurement]]] = defaultdict(
287 best_strategies: Dict[GEMMParam, Strategy] = {}
289 for gemm_param, strategy, gemm_config, measurement
in self.
get_record():
290 all_results[gemm_param].append((strategy, measurement))
292 for gemm_param, results_set
in all_results.items():
294 results_set = sorted(
295 results_set, key=
lambda s_and_m: s_and_m[1].get_total_ms()
298 best_s, best_m = results_set[0]
299 best_strategies[gemm_param] = best_s
301 return best_strategies
304 """ Save records to an output directory of JSON files.
305 The directory is organized such that each strategy gets its own JSON file.
306 The directory also includes a JSON file to define the best strategy per GEMM Param.
308 if not os.path.exists(out_dir):
310 "Output directory {} does not exist. Creating...".
format(
315 out_json_path = os.path.join(out_dir,
"gemm_type_selection.json")
318 results = {
str(key): value.name
for key, value
in results.items()}
322 out_json_path = os.path.join(
323 out_dir, (
"gemm_config_" + strategy.name.lower() +
".json")
331 results = defaultdict(list)
333 if res.strategy == strategy:
334 results[
str(res.gemm_param)].append(
336 "GEMMConfig":
str(res.gemm_config),
337 "OpenCL_Timer_ms_reshape":
str(
338 res.measurement.opencl_timer_ms_reshape
340 "OpenCL_Timer_ms_kernel":
str(
341 res.measurement.opencl_timer_ms_kernel
347 def summary(self, sum_level=SummaryLevel.Short):
348 """ Return the summary string of the record
350 num_raw_records = sum(1
for _
in self.
get_record())
351 gemm_params_per_strategy = defaultdict(list)
353 gemm_params_per_strategy[strategy].append(gemm_param)
354 global_summary = f
"""
355 === {self.__class__.__name__} Summary ===
357 Strategies recorded: {", ".join(map(lambda s: s.name, self._strategies))}
358 Total number of results recorded: {num_raw_records}
362 strategy_summaries = []
363 for strategy
in gemm_params_per_strategy:
365 Strategy {strategy.name}:
367 Number of: {len(gemm_params_per_strategy[strategy])}
369 if sum_level == self.__class__.SummaryLevel.Detailed:
371 Content: {gemm_params_per_strategy[strategy]}
373 strategy_summaries.append(summary)
374 return global_summary +
"".
join(strategy_summaries)
378 """ A representation of the GEMM Configuration distribution produced by the GEMMBenchmarkResultRecorder.
384 self._gemm_config_dist: Dict[
385 GEMMConfig, List[Tuple[GEMMParam, Measurement]]
386 ] = defaultdict(list)
389 def add(self, benchmark_result: BenchmarkResult):
390 """ Add a benchmark result to the distribution
392 gemm_param, _, gemm_config, measurement = benchmark_result
393 self._gemm_config_dist[gemm_config].append((gemm_param, measurement))
397 return self._gemm_config_dist
400 """ Get the frequency of each (best) gemm config recorded
405 """ Get the overall best config, as voted by all benchmark results.
410 """ Get the standard deviation as a measure of dispersion of the distribution. We should aim for higher values
411 as they indicate there is high variation in the distribution. Thus the evidence of the best config is stronger.
416 mean_freq = sum(freqs) / len(freqs)
417 return math.sqrt(sum((freq - mean_freq) ** 2
for freq
in freqs) / len(freqs))
426 GEMM_CONFIG_FACTORY = {
427 Strategy.Native: NativeGEMMConfig,
428 Strategy.ReshapedOnlyRHS: ReshapedOnlyRHSGEMMConfig,
429 Strategy.Reshaped: ReshapedGEMMConfig,
434 EXAMPLE_FILE_2_STRATEGY = {
435 "benchmark_cl_gemm_native": Strategy.Native,
436 "benchmark_cl_gemm_reshaped_rhs_only": Strategy.ReshapedOnlyRHS,
437 "benchmark_cl_gemm_reshaped": Strategy.Reshaped,
450 GEMM_EXAMPLE_ARGS_FACTORY = {
452 strategy: namedtuple(
453 "{}_Gemm_Example_Args".
format(strategy_name),
454 GEMMParam._fields[:-1] + GEMM_CONFIG_FACTORY[strategy]._fields,
456 for strategy_name, strategy
in Strategy.__members__.items()
457 if strategy_name == strategy.name
461 BENCHMARK_RESULT_JSON_EXTENSION =
"gemmtuner_benchmark"
469 """ Parse the benchmark example command-line string into a dictionary of command-line arguments
472 commandline = commandline.replace(
",--type=",
" --type=")
474 args = commandline.split()
478 args =
map(
lambda arg: arg.split(
"="), args)
480 def transform(_name):
482 _name = _name.lstrip(
"-")
485 return {transform(name): val
for name, val
in args}
489 json_results: Dict, measurement_method=
"avg"
490 ) -> Generator[BenchmarkResult,
None,
None]:
491 """ Parse the benchmark result and extract relevant information, namely:
497 for json_res
in json_results:
500 example_tests = list(json_res[
"tests"].items())
501 assert len(example_tests) == 1
502 example_fn, example_test_data = example_tests[0]
505 example_fn = example_fn.split(os.path.sep)[-1]
508 strategy = EXAMPLE_FILE_2_STRATEGY[example_fn]
512 Gemm_Example_Args_T = GEMM_EXAMPLE_ARGS_FACTORY[strategy]
513 example_args = Gemm_Example_Args_T(
514 *(benchmark_args[
"example_args"].split(
",")))
517 gemm_param_fields_len = len(GEMMParam._fields) - 1
518 gemm_param = GEMMParam.parse_from_strs(
519 *example_args[:gemm_param_fields_len],
520 data_type = benchmark_args[
"type"])
521 GEMMConfig = GEMM_CONFIG_FACTORY[strategy]
522 gemm_config = GEMMConfig.parse_from_strs(
523 *example_args[gemm_param_fields_len:])
526 measurements = list(example_test_data[
"measurements"].items())
529 measurement_ms_reshape = 0
530 measurement_ms_kernel = 0
531 for single_measurement
in measurements:
532 measurement_instrument, data = single_measurement
534 measurement_instrument_name = measurement_instrument.split(
"/")[0]
535 assert measurement_instrument_name ==
"OpenCLTimer"
537 if measurement_method ==
"min":
538 measurement_val = min(data[
"raw"])
539 elif measurement_method ==
"avg":
540 measurement_val = sum(data[
"raw"]) / len(data[
"raw"])
543 "Invalid measurement method: {}".
format(measurement_method)
546 measurement_type = measurement_instrument.split(
"/")[1]
547 if "reshape" in measurement_type.split(
"_"):
548 measurement_ms_reshape = measurement_val
550 measurement_ms_kernel = measurement_val
553 measurement_ms_reshape, measurement_ms_kernel)
559 """ Glob all benchmark result json files and parse them into json objects (dicts).
561 for res_fn
in Path(dir_name).rglob(
"*.{}".
format(BENCHMARK_RESULT_JSON_EXTENSION)):
562 with open(res_fn)
as res_fp:
563 yield json.load(res_fp)
567 if os.path.exists(out_path):
570 "Output JSON {} already exists. Overwrite? [Y/N]: ".
format(
576 logging.info(
"Skipping {}".
format(out_path))
578 logging.info(
"Saving JSON file to {}".
format(out_path))
583 with open(out_path,
"w")
as f:
585 logging.info(
"Saved")
595 "Searching best gemm configurations from {}".
format(
596 args.benchmark_results_dir)
605 for benchmark_result
in benchmark_results:
606 benchmark_result_recorder.add(benchmark_result)
609 recorder_sum_level = GEMMBenchmarkResultRecorder.SummaryLevel.Detailed
611 recorder_sum_level = GEMMBenchmarkResultRecorder.SummaryLevel.Short
614 logging.info(benchmark_result_recorder.summary(
615 sum_level=recorder_sum_level))
618 all_config_dists = benchmark_result_recorder.get_config_distributions()
620 logging.info(
"=== Result ===")
621 for strategy, config_dist
in all_config_dists.items():
622 logging.info(
"Strategy: {}".
format(strategy.name))
623 logging.debug(
"GEMM Config, Votes")
624 for config, freq
in config_dist.frequency():
625 logging.debug(
"{}, {}".
format(config, freq))
627 "Best GEMM Config: {} with std: {}".
format(
628 config_dist.best_config(), config_dist.std()
633 if args.output_dir
is not None:
634 benchmark_result_recorder.save_to_jsons(
635 args.output_dir, only_best_config=(
not args.debug)
639 if __name__ ==
"__main__":
640 parser = argparse.ArgumentParser(description=
"CL GEMM Tuner")
643 "--benchmark_results",
644 dest=
"benchmark_results_dir",
648 help=
"Path to benchmark result directory, where benchmark result json files have a file \
649 extension of '{}'".
format(
650 BENCHMARK_RESULT_JSON_EXTENSION
661 help=
"Path to directory that holds output JSON files. One for strategy selection and one per strategy for GEMM config selection",
669 help=
"For testing if two GEMMConfigs are equivalent in terms of performance. The tolerance is OpenCL timer in\
670 milliseconds. Recommended value: <= 0.1 ms",
677 help=
"Enable script debugging output",
679 args = parser.parse_args()
680 logging_level = logging.DEBUG
if args.debug
else logging.INFO
681 logging.basicConfig(level=logging_level)
682 logging.debug(
"Arguments: {}".
format(args))