overestimate `time_taken_ns` for autotuning (#133633)
tldr; in `autotune_to_one_config` we now include the precompile time, and in coordesc tuning we include the time from `autotune_to_one_config`, since this is a precursor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133633
Approved by: https://github.com/oulgen, https://github.com/eellison
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index e54647e..2648cd9 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -224,6 +224,9 @@
)
self.filename = filename
+ self.precompile_time_taken_ns = 0
+ self.autotune_time_taken_ns = 0
+
def precompile(self, warm_cache_only=False):
with self.lock:
if self.launchers:
@@ -726,10 +729,13 @@
"""Do the actual autotuning"""
start_time = time.time_ns()
timings = self.benchmark_all_configs(*args, **kwargs)
- time_taken_ns = time.time_ns() - start_time
+ benchmark_time_taken_ns = time.time_ns() - start_time
self.launchers = [builtins.min(timings, key=timings.get)]
+ self.autotune_time_taken_ns = (
+ self.precompile_time_taken_ns + benchmark_time_taken_ns
+ )
if self.save_cache_hook:
- self.save_cache_hook(self.launchers[0].config, time_taken_ns)
+ self.save_cache_hook(self.launchers[0].config, self.autotune_time_taken_ns)
def save_gpu_kernel(self, grid, stream, launcher):
if callable(grid):
@@ -811,17 +817,23 @@
best_config = self.coordesc_tuner.autotune(
benchmark_one_config, launcher.config, None
)
- time_taken_ns = time.time_ns() - start_time
+ coordesc_time_taken_ns = time.time_ns() - start_time
best_config.found_by_coordesc = True
if self.save_cache_hook:
- self.save_cache_hook(best_config, time_taken_ns, found_by_coordesc=True)
+ self.save_cache_hook(
+ best_config,
+ self.autotune_time_taken_ns + coordesc_time_taken_ns,
+ found_by_coordesc=True,
+ )
return config2launcher.get(best_config)
def run(self, *args, grid, stream, **kwargs):
if len(self.launchers) != 1:
if len(self.launchers) == 0:
+ start_time = time.time_ns()
self.precompile()
+ self.precompile_time_taken_ns = time.time_ns() - start_time
if len(self.launchers) > 1:
self.autotune_to_one_config(*args, grid=grid, **kwargs)