[reland][inductor] coordinate descent tuning upon max-autotune (#99594)
Reland https://github.com/pytorch/pytorch/pull/97203 .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99594
Approved by: https://github.com/jansel
diff --git a/test/inductor/test_coordinate_descent_tuner.py b/test/inductor/test_coordinate_descent_tuner.py
new file mode 100644
index 0000000..71c5590
--- /dev/null
+++ b/test/inductor/test_coordinate_descent_tuner.py
@@ -0,0 +1,37 @@
+# Owner(s): ["module: inductor"]
+
+import sys
+import unittest
+
+from torch._dynamo.test_case import run_tests, TestCase
+from torch.testing._internal.common_utils import IS_LINUX
+from torch.testing._internal.inductor_utils import HAS_CUDA
+
+# triton is required by these tests. Without it: exit with status 0 when this
+# file is run as a script, otherwise raise unittest.SkipTest.
+try:
+ import triton
+except ImportError:
+ if __name__ == "__main__":
+ sys.exit(0)
+ raise unittest.SkipTest("requires triton")
+
+from torch._inductor.coordinate_descent_tuner import CoordescTuner
+
+
+class TestCoordinateDescentTuner(TestCase):
+    def test_abs_function(self):
+        """
+        The benchmark result is simply abs(XBLOCK - 15).
+
+        XBLOCK starts at 1 and the tuner only halves or doubles it, so the
+        values it can reach are powers of two; 16 is the one closest to 15.
+        """
+        tuner = CoordescTuner()
+        baseline_config = triton.Config({"XBLOCK": 1}, num_warps=8, num_stages=1)
+
+        def func(config):
+            return abs(config.kwargs["XBLOCK"] - 15)
+
+        best_config = tuner.autotune(func, baseline_config)
+        # assertEqual reports the actual value on failure, unlike assertTrue(a == b)
+        self.assertEqual(best_config.kwargs.get("XBLOCK"), 16)
+
+
+if __name__ == "__main__":
+ # only run the tests on Linux with CUDA available
+ if IS_LINUX and HAS_CUDA:
+ run_tests()
diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py
index 76d92a7..86a1a7b 100644
--- a/test/inductor/test_cuda_repro.py
+++ b/test/inductor/test_cuda_repro.py
@@ -306,7 +306,11 @@
https://github.com/pytorch/torchdynamo/issues/1670
"""
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
- from torch._inductor.triton_heuristics import CachingAutotuner, grid
+ from torch._inductor.triton_heuristics import (
+ CachingAutotuner,
+ grid,
+ HeuristicType,
+ )
from torch._inductor.utils import instance_descriptor
def autotune(configs, meta):
@@ -318,6 +322,7 @@
configs=configs,
save_cache_hook=False,
mutated_arg_names=["in_out_ptr0"],
+ heuristic_type=HeuristicType.POINTWISE,
)
return decorator
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index c37bdc8..283a38a 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -60,6 +60,11 @@
# We will disable creating subprocess for autotuning if this is False
autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"
+
+# enable coordinate descent tuning of triton configs; opt in by setting the
+# environment variable TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1
+coordinate_descent_tuning = (
+ os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1"
+)
+
# control store vs recompute heuristic
# For fanouts, rematerialization can lead to exponential blowup. So, have
# smaller threshold
@@ -248,7 +253,9 @@
descriptive_names = "original_aten"
# use alternate codegen for smaller reductions
- persistent_reductions = True
+ persistent_reductions = (
+ os.environ.get("TORCHINDUCTOR_PERSISTENT_REDUCTIONS", "1") == "1"
+ )
# hint to Triton when arguments are divisible by 16
divisible_by_16 = True
diff --git a/torch/_inductor/coordinate_descent_tuner.py b/torch/_inductor/coordinate_descent_tuner.py
new file mode 100644
index 0000000..93cd4a7
--- /dev/null
+++ b/torch/_inductor/coordinate_descent_tuner.py
@@ -0,0 +1,169 @@
+import copy
+import logging
+from typing import Callable, Optional
+
+from .utils import has_triton, triton_config_to_hashable
+
+if has_triton():
+ import triton
+else:
+ triton = None
+
+
+log = logging.getLogger(__name__)
+
+
+def get_field(config, name):
+    """
+    Read the tunable field `name` from a triton config.
+
+    `num_warps` and `num_stages` are attributes of the config; every other
+    field is looked up in `config.kwargs` and is None when absent.
+    """
+    if name in ("num_warps", "num_stages"):
+        return getattr(config, name)
+    return config.kwargs.get(name)
+
+
+def set_field(config, name, value):
+    """
+    Write `value` into the tunable field `name` of a triton config, in place.
+
+    `num_warps` and `num_stages` are set as attributes of the config; every
+    other field is stored in `config.kwargs`.
+    """
+    if name in ("num_warps", "num_stages"):
+        setattr(config, name, value)
+    else:
+        config.kwargs[name] = value
+
+
+class CoordescTuner:
+    """
+    The coordinate descent tuner. Tune one field/coordinate at a time.
+
+    Starting from a baseline config, every tunable field is moved in turn to
+    its neighbouring values while the other fields are kept fixed. A move is
+    accepted only if the timing improves by more than the threshold used in
+    `has_improvement`. Passes over the fields are repeated until a whole pass
+    brings no improvement.
+
+    Timings are cached per config (keyed by `triton_config_to_hashable`), so a
+    config whose timing is already known is not benchmarked again.
+
+    TODO: will it be necessary to tune multiple fields simultaneously?
+
+    TODO: what if both increasing and decreasing a field can improve perf,
+    i.e., there are multiple local optima?
+    """
+
+    def __init__(self, is_mm=False, name="unknown"):
+        self.is_mm = is_mm  # we will tune num_stages for mm
+        self.cached_benchmark_results = {}
+        self.name = name  # only used in log messages
+
+    def cache_benchmark_result(self, config, timing):
+        """Record `timing` as the benchmark result of `config`."""
+        self.cached_benchmark_results[triton_config_to_hashable(config)] = timing
+
+    def lookup_in_cache(self, config):
+        """Return the cached timing of `config`, or None if there is none."""
+        return self.cached_benchmark_results.get(triton_config_to_hashable(config))
+
+    def call_func(self, func, config):
+        """
+        Return the timing of `config`, calling `func(config)` only when no
+        timing is cached for it. An exception raised by `func` propagates to
+        the caller and nothing is cached in that case.
+        """
+        found = self.lookup_in_cache(config)
+        if found is not None:
+            log.debug(" CACHED")
+            return found
+        timing = func(config)
+        self.cache_benchmark_result(config, timing)
+        return timing
+
+    @property
+    def tunable_fields(self):
+        """Names of the fields to tune, in the order they are visited."""
+        out = [
+            "XBLOCK",
+            "YBLOCK",
+            "ZBLOCK",
+            # NOTE: we should not tune RBLOCK for persistent reduction.
+            # We rely on the fact that persistent reduction's triton.Config
+            # does not have the RBLOCK field to guarantee that.
+            "RBLOCK",
+            # the following 3 are for mm
+            "BLOCK_M",
+            "BLOCK_N",
+            "BLOCK_K",
+            "num_warps",
+        ]
+        if self.is_mm:
+            out.append("num_stages")
+
+        return out
+
+    @staticmethod
+    def get_neighbour_values(name, cur_val):
+        """
+        Return the candidate values next to `cur_val` for field `name`.
+
+        `num_stages` moves by one in each direction; every other field is
+        halved and doubled. A smaller neighbour that is not positive is
+        dropped, so the result always contains at least the larger neighbour.
+        """
+        lhs_val = None
+        rhs_val = None
+        if name == "num_stages":
+            lhs_val = cur_val - 1
+            rhs_val = cur_val + 1
+        else:
+            lhs_val = cur_val // 2
+            rhs_val = cur_val * 2
+
+        out = []
+        if lhs_val > 0:
+            out.append(lhs_val)
+        out.append(rhs_val)
+        return out
+
+    @staticmethod
+    def has_improvement(baseline, test):
+        """
+        Return True if timing `test` beats timing `baseline` by more than the
+        threshold. A `test` of None never counts as an improvement.
+        """
+        threshold = 0.001  # 0.1%
+        return test is not None and test < baseline * (1 - threshold)
+
+    def autotune(
+        self,
+        func: Callable[["triton.Config"], float],
+        baseline_config: "triton.Config",
+        baseline_timing: Optional[float] = None,
+    ) -> "triton.Config":
+        """
+        Run coordinate descent starting from `baseline_config`.
+
+        Args:
+            func: benchmark function mapping a config to its timing (lower is
+                better). A config for which `func` raises is skipped.
+            baseline_config: the starting config. It is not modified;
+                candidates are deep copies of the current best config.
+            baseline_timing: timing of `baseline_config`. When None it is
+                taken from the cache or measured with `func`.
+
+        Returns:
+            The best config found, which is `baseline_config` itself when no
+            candidate improved on it.
+        """
+        if baseline_timing is None:
+            baseline_timing = self.call_func(func, baseline_config)
+
+        log.debug("= Do coordinate descent tuning for %s =", self.name)
+        log.debug(
+            "Baseline Config %s, baseline timing %f", baseline_config, baseline_timing
+        )
+        improved = True
+        best_config = baseline_config
+        best_timing = baseline_timing
+        tunable_fields = self.tunable_fields
+
+        while improved:
+            improved = False
+
+            for name in tunable_fields:
+                cur_val = get_field(best_config, name)
+                # some kernels don't have RBLOCK/YBLOCK/ZBLOCK, so cur_val may be None
+                if cur_val is None:
+                    continue
+
+                candidate_values = self.get_neighbour_values(name, cur_val)
+                assert len(candidate_values) > 0
+
+                for next_val in candidate_values:
+                    candidate_config = copy.deepcopy(best_config)
+                    set_field(candidate_config, name, next_val)
+                    log.debug("Try config %s", candidate_config)
+                    try:
+                        candidate_timing = self.call_func(func, candidate_config)
+                    except Exception as e:
+                        # best effort: a candidate that cannot be benchmarked
+                        # is simply not considered
+                        log.debug("Got exception %s", e)
+                        continue
+
+                    if self.has_improvement(best_timing, candidate_timing):
+                        improved = True
+                        log.debug(
+                            "Tune from %s %f -> %s %f",
+                            best_config,
+                            best_timing,
+                            candidate_config,
+                            candidate_timing,
+                        )
+                        best_timing = candidate_timing
+                        best_config = candidate_config
+
+        if log.isEnabledFor(logging.DEBUG):
+            # Compute the speedup only when it is going to be logged, and never
+            # divide by a zero timing (possible with a synthetic benchmark
+            # function): that used to raise ZeroDivisionError after tuning had
+            # already succeeded.
+            if best_timing:
+                speedup = baseline_timing / best_timing
+            else:
+                speedup = float("inf") if baseline_timing else 1.0
+            log.debug(
+                "Improve from %s %f -> %s %f, %.3fx",
+                baseline_config,
+                baseline_timing,
+                best_config,
+                best_timing,
+                speedup,
+            )
+
+        return best_config
diff --git a/torch/_inductor/triton_heuristics.py b/torch/_inductor/triton_heuristics.py
index 4c401a6..a16d50c 100644
--- a/torch/_inductor/triton_heuristics.py
+++ b/torch/_inductor/triton_heuristics.py
@@ -10,6 +10,7 @@
import os.path
import re
import threading
+from enum import auto, Enum
from typing import List
import torch
@@ -17,6 +18,7 @@
from . import config
from .codecache import cache_dir, CudaKernelParamCache
+from .coordinate_descent_tuner import CoordescTuner
from .ir import ReductionHint, TileHint
from .utils import (
@@ -27,6 +29,7 @@
get_num_bytes,
has_triton,
next_power_of_2,
+ triton_config_to_hashable,
)
@@ -43,6 +46,13 @@
triton = None
+class HeuristicType(Enum):
+ """
+ Tags a CachingAutotuner with the kind of kernel heuristic its configs come from.
+ """
+
+ POINTWISE = auto()
+ REDUCTION = auto()
+ # configs of this kind are asserted to carry no RBLOCK field before
+ # coordinate descent tuning
+ PERSISTENT_REDUCTION = auto()
+ # triton templates; coordinate descent tuning skips these
+ TEMPLATE = auto()
+
+
class CachingAutotuner(KernelInterface):
"""
Simplified version of Triton autotuner that has no invalidation
@@ -51,13 +61,22 @@
configs, and does not rely on the Triton JIT.
"""
- def __init__(self, fn, meta, configs, save_cache_hook, mutated_arg_names):
+ def __init__(
+ self, fn, meta, configs, save_cache_hook, mutated_arg_names, heuristic_type
+ ):
super().__init__()
self.fn = fn
self.meta = meta
self.save_cache_hook = save_cache_hook
self.mutated_arg_names = mutated_arg_names
self.configs = configs
+ self.heuristic_type = heuristic_type
+
+ if log.isEnabledFor(logging.DEBUG):
+ log.debug("CachingAutotuner gets %d configs", len(self.configs))
+ for c in self.configs:
+ log.debug(c)
+
self.launchers = []
self.lock = threading.Lock()
if os.getenv("TRITON_CACHE_DIR") is None:
@@ -67,6 +86,8 @@
str(self.meta.get("device", 0)),
)
+ self.coordesc_tuner = CoordescTuner(is_mm=False, name=self.fn.__name__)
+
def precompile(self, warm_cache_only_with_cc=None):
with self.lock:
if self.launchers:
@@ -164,8 +185,7 @@
return do_bench(kernel_call, rep=40, fast_flush=True)
- @dynamo_timed
- def benchmark_all_configs(self, *args, **kwargs):
+ def clone_args(self, *args):
from .compile_fx import clone_preserve_strides
# clone inplace buffers to avoid autotune contaminating them if
@@ -179,10 +199,24 @@
else:
cloned_args.append(arg)
+ return cloned_args
+
+ @dynamo_timed
+ def benchmark_all_configs(self, *args, **kwargs):
+ cloned_args = self.clone_args(*args)
timings = {
launcher: self.bench(launcher, *cloned_args, **kwargs)
for launcher in self.launchers
}
+
+ for k, v in timings.items():
+ self.coordesc_tuner.cache_benchmark_result(k.config, v)
+
+ if log.isEnabledFor(logging.DEBUG):
+ log.debug("Benchmark all input configs get:")
+ for k, v in timings.items():
+ log.debug("%s: %f", k.config, v)
+
return timings
def autotune_to_one_config(self, *args, **kwargs):
@@ -210,6 +244,43 @@
}
CudaKernelParamCache.set(key, params, launcher.bin.asm["cubin"])
+ def coordinate_descent_tuning(self, launcher, *args, **kwargs):
+ """
+ Coordinate descent tuning can be run with or without max-autotune.
+
+ The only difference between these two is the starting config for coordinate descent tuning.
+ E.g., assume regular autotune only gets one config C1, while max-autotune gets 4 configs C1, C2, C3, C4
+ and max-autotune figures out C3 is the best.
+
+ Then if coordinate descent tuning is run with max-autotune disabled, it will start from C1;
+ while if coordinate descent tuning is run with max-autotune enabled, it will start from C3.
+
+ `launcher` is the launcher of the starting config; `args` and `kwargs` are
+ forwarded to `self.bench`. Returns the launcher recorded for the best
+ config found. For a triton template the input launcher is returned
+ unchanged.
+ """
+ if self.heuristic_type == HeuristicType.TEMPLATE:
+ # skip triton template
+ return launcher
+
+ # benchmark on the args returned by clone_args rather than the caller's args
+ cloned_args = self.clone_args(*args)
+ # launcher compiled for each config, seeded with the starting one
+ config2launcher = {launcher.config: launcher}
+
+ def benchmark_one_config(config):
+ # compile the candidate while holding self.lock, remember its
+ # launcher, then benchmark it
+ with self.lock:
+ launcher = self._precompile_config(config, None)
+ config2launcher[config] = launcher
+ return self.bench(launcher, *cloned_args, **kwargs)
+
+ assert not (
+ self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION
+ and "RBLOCK" in launcher.config.kwargs
+ ), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have RBLOCK"
+ # baseline timing is passed as None: the tuner takes it from its cache
+ # or measures it with benchmark_one_config
+ best_config = self.coordesc_tuner.autotune(
+ benchmark_one_config, launcher.config, None
+ )
+ # mark the config so that run() does not tune it again
+ best_config.found_by_coordesc = True
+
+ if self.save_cache_hook:
+ self.save_cache_hook(best_config, found_by_coordesc=True)
+ # NOTE(review): .get() yields None if best_config has no entry in
+ # config2launcher (entries are only added by benchmark_one_config) —
+ # confirm the tuner can never return such a config
+ return config2launcher.get(best_config)
+
def run(self, *args, grid, stream):
if len(self.launchers) != 1:
if len(self.launchers) == 0:
@@ -217,6 +288,14 @@
if len(self.launchers) > 1:
self.autotune_to_one_config(*args, grid=grid)
+ if (
+ not getattr(self.launchers[0].config, "found_by_coordesc", False)
+ and config.coordinate_descent_tuning
+ ):
+ self.launchers = [
+ self.coordinate_descent_tuning(self.launchers[0], *args, grid=grid)
+ ]
+
(launcher,) = self.launchers
if launcher.store_cubin:
self.save_cuda_kernel(grid, stream, self.launchers[0])
@@ -327,13 +406,22 @@
with open(cache_filename, "r") as fd:
best_config = json.loads(fd.read())
- if best_config.get("configs_hash") != configs_hash:
+ if best_config.pop("configs_hash", None) != configs_hash:
return None
+ if config.coordinate_descent_tuning and best_config.pop("found_by_coordesc", False):
+ num_warps = best_config.pop("num_warps")
+ num_stages = best_config.pop("num_stages")
+ triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages)
+ triton_config.found_by_coordesc = True
+ return triton_config
+
matching_configs = [
cfg
for cfg in configs
if all(val == best_config.get(key) for key, val in cfg.kwargs.items())
+ and cfg.num_warps == best_config.get("num_warps")
+ and cfg.num_stages == best_config.get("num_stages")
]
if len(matching_configs) != 1:
return None
@@ -344,6 +432,7 @@
def cached_autotune(
configs: List[Config],
meta,
+ heuristic_type,
filename=None,
):
"""
@@ -353,22 +442,30 @@
configs = unique_configs(configs)
assert len(configs) == 1 or filename
- # The autotune cache will simply replace the list of candidate configs with
- # the best config cached. We don't want that when we benchmark triton kernels.
- # We need the perf for each of the candidate config instead.
- cache_autotune_result = not config.benchmark_kernel
-
# on disk caching logic
- if cache_autotune_result and filename is not None and len(configs) > 1:
+ if filename is not None and (len(configs) > 1 or config.coordinate_descent_tuning):
cache_filename = os.path.splitext(filename)[0] + ".best_config"
configs_hash = hash_configs(configs)
best_config = load_cached_autotuning(cache_filename, configs_hash, configs)
if best_config:
configs = [best_config]
- def save_cache_hook(cfg):
+ def save_cache_hook(cfg, found_by_coordesc=False):
with open(cache_filename, "w") as fd:
- fd.write(json.dumps({**cfg.kwargs, "configs_hash": configs_hash}))
+ fd.write(
+ json.dumps(
+ {
+ **cfg.kwargs,
+ "num_warps": cfg.num_warps,
+ "num_stages": cfg.num_stages,
+ "configs_hash": configs_hash,
+ "found_by_coordesc": found_by_coordesc,
+ }
+ )
+ )
+ if log.isEnabledFor(logging.DEBUG):
+ type_str = "coordesc" if found_by_coordesc else "heuristic"
+ log.debug("Save %s tuning result to %s", type_str, cache_filename)
else:
save_cache_hook = None
@@ -384,6 +481,7 @@
configs=configs,
save_cache_hook=save_cache_hook,
mutated_arg_names=mutated_arg_names,
+ heuristic_type=heuristic_type,
)
return CachingAutotuner(
fn,
@@ -391,6 +489,7 @@
configs=configs,
save_cache_hook=save_cache_hook,
mutated_arg_names=mutated_arg_names,
+ heuristic_type=heuristic_type,
)
return decorator
@@ -400,8 +499,9 @@
"""Remove duplicate configurations"""
seen = set()
pruned_configs = []
+
for cfg in configs:
- key = tuple(cfg.kwargs.items())
+ key = triton_config_to_hashable(cfg)
if key not in seen:
seen.add(key)
pruned_configs.append(cfg)
@@ -554,12 +654,22 @@
bs = max(256, min(numel // 128, 1024))
if len(size_hints) == 1:
- return cached_autotune([triton_config(size_hints, bs)], meta=meta)
+ return cached_autotune(
+ [triton_config(size_hints, bs)],
+ meta=meta,
+ heuristic_type=HeuristicType.POINTWISE,
+ filename=filename,
+ )
if len(size_hints) == 2:
if (
not config.triton.autotune_pointwise or tile_hint == TileHint.SQUARE
) and not (config.max_autotune or config.max_autotune_pointwise):
- return cached_autotune([triton_config(size_hints, 32, 32)], meta=meta)
+ return cached_autotune(
+ [triton_config(size_hints, 32, 32)],
+ meta=meta,
+ heuristic_type=HeuristicType.POINTWISE,
+ filename=filename,
+ )
return cached_autotune(
[
triton_config(size_hints, 32, 32),
@@ -571,10 +681,16 @@
],
meta=meta,
filename=filename,
+ heuristic_type=HeuristicType.POINTWISE,
)
if len(size_hints) == 3:
if not config.triton.autotune_pointwise:
- return cached_autotune([triton_config(size_hints, 16, 16, 16)], meta=meta)
+ return cached_autotune(
+ [triton_config(size_hints, 16, 16, 16)],
+ meta=meta,
+ heuristic_type=HeuristicType.POINTWISE,
+ filename=filename,
+ )
return cached_autotune(
[
triton_config(size_hints, 16, 16, 16),
@@ -587,6 +703,7 @@
],
meta=meta,
filename=filename,
+ heuristic_type=HeuristicType.POINTWISE,
)
raise NotImplementedError(f"size_hints: {size_hints}")
@@ -606,14 +723,32 @@
if config.max_autotune or config.max_autotune_pointwise:
pass # skip all these cases
elif reduction_hint == ReductionHint.INNER:
- return cached_autotune([contiguous_config], meta=meta)
+ return cached_autotune(
+ [contiguous_config],
+ meta=meta,
+ heuristic_type=HeuristicType.REDUCTION,
+ filename=filename,
+ )
elif reduction_hint == ReductionHint.OUTER:
- return cached_autotune([outer_config], meta=meta)
+ return cached_autotune(
+ [outer_config],
+ meta=meta,
+ heuristic_type=HeuristicType.REDUCTION,
+ filename=filename,
+ )
elif reduction_hint == ReductionHint.OUTER_TINY:
- return cached_autotune([tiny_config], meta=meta)
+ return cached_autotune(
+ [tiny_config],
+ meta=meta,
+ heuristic_type=HeuristicType.REDUCTION,
+ filename=filename,
+ )
if not config.triton.autotune_pointwise:
return cached_autotune(
- [triton_config_reduction(size_hints, 32, 128)], meta=meta
+ [triton_config_reduction(size_hints, 32, 128)],
+ meta=meta,
+ heuristic_type=HeuristicType.REDUCTION,
+ filename=filename,
)
return cached_autotune(
[
@@ -625,6 +760,7 @@
],
meta=meta,
filename=filename,
+ heuristic_type=HeuristicType.REDUCTION,
)
raise NotImplementedError(f"size_hints: {size_hints}")
@@ -658,6 +794,7 @@
configs,
meta=meta,
filename=filename,
+ heuristic_type=HeuristicType.PERSISTENT_REDUCTION,
)
@@ -666,7 +803,10 @@
Compile a triton template
"""
return cached_autotune(
- [triton.Config({}, num_stages=num_stages, num_warps=num_warps)], meta=meta
+ [triton.Config({}, num_stages=num_stages, num_warps=num_warps)],
+ meta=meta,
+ heuristic_type=HeuristicType.TEMPLATE,
+ filename=filename,
)
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 6d5d714..bf493ac 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -1050,3 +1050,14 @@
parse_profile_event_list(
benchmark_name, event_list, wall_time_ms, times * repeat
)
+
+
+def triton_config_to_hashable(cfg):
+    """
+    Build a tuple that uniquely identifies a triton config, suitable as a
+    dictionary key: the config's kwargs as (name, value) pairs sorted by
+    name, followed by its num_warps and num_stages.
+    """
+    return (
+        *sorted(cfg.kwargs.items()),
+        ("num_warps", cfg.num_warps),
+        ("num_stages", cfg.num_stages),
+    )