[reland][inductor] coordinate descent tuning upon max-autotune (#99594)

Reland https://github.com/pytorch/pytorch/pull/97203 .

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99594
Approved by: https://github.com/jansel
diff --git a/test/inductor/test_coordinate_descent_tuner.py b/test/inductor/test_coordinate_descent_tuner.py
new file mode 100644
index 0000000..71c5590
--- /dev/null
+++ b/test/inductor/test_coordinate_descent_tuner.py
@@ -0,0 +1,37 @@
+# Owner(s): ["module: inductor"]
+
+import sys
+import unittest
+
+from torch._dynamo.test_case import run_tests, TestCase
+from torch.testing._internal.common_utils import IS_LINUX
+from torch.testing._internal.inductor_utils import HAS_CUDA
+
+try:
+    import triton
+except ImportError:
+    if __name__ == "__main__":
+        sys.exit(0)
+    raise unittest.SkipTest("requires triton")
+
+from torch._inductor.coordinate_descent_tuner import CoordescTuner
+
+
+class TestCoordinateDescentTuner(TestCase):
+    def test_abs_function(self):
+        """
+        The benchmark result is simply abs(XBLOCK - 15), so tuning ends at 16.
+        """
+        tuner = CoordescTuner()
+        baseline_config = triton.Config({"XBLOCK": 1}, num_warps=8, num_stages=1)
+
+        def func(config):
+            return abs(config.kwargs["XBLOCK"] - 15)
+
+        best_config = tuner.autotune(func, baseline_config)
+        self.assertEqual(best_config.kwargs.get("XBLOCK"), 16)
+
+
+if __name__ == "__main__":
+    if IS_LINUX and HAS_CUDA:
+        run_tests()
diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py
index 76d92a7..86a1a7b 100644
--- a/test/inductor/test_cuda_repro.py
+++ b/test/inductor/test_cuda_repro.py
@@ -306,7 +306,11 @@
         https://github.com/pytorch/torchdynamo/issues/1670
         """
         from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
-        from torch._inductor.triton_heuristics import CachingAutotuner, grid
+        from torch._inductor.triton_heuristics import (
+            CachingAutotuner,
+            grid,
+            HeuristicType,
+        )
         from torch._inductor.utils import instance_descriptor
 
         def autotune(configs, meta):
@@ -318,6 +322,7 @@
                     configs=configs,
                     save_cache_hook=False,
                     mutated_arg_names=["in_out_ptr0"],
+                    heuristic_type=HeuristicType.POINTWISE,
                 )
 
             return decorator
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index c37bdc8..283a38a 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -60,6 +60,11 @@
 # We will disable creating subprocess for autotuning if this is False
 autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"
 
+
+coordinate_descent_tuning = (
+    os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1"
+)
+
 # control store vs recompute heuristic
 # For fanouts, rematerialization can lead to exponential blowup. So, have
 # smaller threshold
@@ -248,7 +253,9 @@
     descriptive_names = "original_aten"
 
     # use alternate codegen for smaller reductions
-    persistent_reductions = True
+    persistent_reductions = (
+        os.environ.get("TORCHINDUCTOR_PERSISTENT_REDUCTIONS", "1") == "1"
+    )
 
     # hint to Triton when arguments are divisible by 16
     divisible_by_16 = True
diff --git a/torch/_inductor/coordinate_descent_tuner.py b/torch/_inductor/coordinate_descent_tuner.py
new file mode 100644
index 0000000..93cd4a7
--- /dev/null
+++ b/torch/_inductor/coordinate_descent_tuner.py
@@ -0,0 +1,169 @@
+import copy
+import logging
+from typing import Callable, Optional
+
+from .utils import has_triton, triton_config_to_hashable
+
+if has_triton():
+    import triton
+else:
+    triton = None
+
+
+log = logging.getLogger(__name__)
+
+
+def get_field(config, name):
+    """Read field `name` of a triton config (None if a kwargs field is missing)."""
+    if name in ("num_warps", "num_stages"):
+        # launch parameters are plain attributes on the config object
+        return getattr(config, name)
+    # every other field (XBLOCK, RBLOCK, ...) is looked up in config.kwargs
+    return config.kwargs.get(name)
+
+
+def set_field(config, name, value):
+    """Write `value` into field `name` of a triton config, in place."""
+    if name in ("num_warps", "num_stages"):
+        # launch parameters are plain attributes on the config object
+        setattr(config, name, value)
+    else:
+        config.kwargs[name] = value
+
+
+class CoordescTuner:
+    """
+    The coordinate descent tuner. Tune one field/coordinate at a time.
+
+    TODO: will it be necessary to tune multiple fields simultaneously?
+    TODO: what if both increasing and decreasing a field can improve perf,
+          i.e., there are multiple local optima?
+    """
+
+    def __init__(self, is_mm=False, name="unknown"):
+        self.is_mm = is_mm  # we will tune num_stages for mm
+        self.cached_benchmark_results = {}
+        self.name = name
+
+    def cache_benchmark_result(self, config, timing):
+        self.cached_benchmark_results[triton_config_to_hashable(config)] = timing
+
+    def lookup_in_cache(self, config):
+        return self.cached_benchmark_results.get(triton_config_to_hashable(config))
+
+    def call_func(self, func, config):
+        """Return func(config), reusing a cached timing for an equal config."""
+        found = self.lookup_in_cache(config)
+        if found is not None:
+            log.debug("  CACHED")
+            return found
+        timing = func(config)
+        self.cache_benchmark_result(config, timing)
+        return timing
+
+    @property
+    def tunable_fields(self):
+        out = [
+            "XBLOCK",
+            "YBLOCK",
+            "ZBLOCK",
+            # NOTE: we should not tune RBLOCK for persistent reduction.
+            # We rely on the fact that persistent reduction's triton.Config
+            # does not have the RBLOCK field to guarantee that.
+            "RBLOCK",
+            # the following 3 are for mm
+            "BLOCK_M",
+            "BLOCK_N",
+            "BLOCK_K",
+            "num_warps",
+        ]
+        if self.is_mm:
+            out.append("num_stages")
+
+        return out
+
+    @staticmethod
+    def get_neighbour_values(name, cur_val):
+        # Candidates next to cur_val: +/-1 for num_stages, half/double for every
+        # other field. A non-positive lower neighbour is dropped.
+        if name == "num_stages":
+            lhs_val = cur_val - 1
+            rhs_val = cur_val + 1
+        else:
+            lhs_val = cur_val // 2
+            rhs_val = cur_val * 2
+
+        out = []
+        if lhs_val > 0:
+            out.append(lhs_val)
+        out.append(rhs_val)
+        return out
+
+    @staticmethod
+    def has_improvement(baseline, test):
+        threshold = 0.001  # 0.1%
+        return test is not None and test < baseline * (1 - threshold)
+
+    def autotune(
+        self,
+        func: Callable[["triton.Config"], float],
+        baseline_config: "triton.Config",
+        baseline_timing: Optional[float] = None,
+    ) -> "triton.Config":
+        """Run coordinate descent from baseline_config and return the best config."""
+        if baseline_timing is None:
+            baseline_timing = self.call_func(func, baseline_config)
+
+        log.debug("= Do coordinate descent tuning for %s =", self.name)
+        log.debug(
+            "Baseline Config %s, baseline timing %f", baseline_config, baseline_timing
+        )
+        improved = True
+        best_config = baseline_config
+        best_timing = baseline_timing
+        tunable_fields = self.tunable_fields
+
+        while improved:
+            improved = False
+            for name in tunable_fields:
+                cur_val = get_field(best_config, name)
+                # some kernels don't have RBLOCK/YBLOCK/ZBLOCK. So cur_val may be None
+                if cur_val is None:
+                    continue
+
+                candidate_values = self.get_neighbour_values(name, cur_val)
+                assert len(candidate_values) > 0
+
+                for next_val in candidate_values:
+                    candidate_config = copy.deepcopy(best_config)
+                    set_field(candidate_config, name, next_val)
+                    log.debug("Try config %s", candidate_config)
+                    try:
+                        candidate_timing = self.call_func(func, candidate_config)
+                    except Exception as e:
+                        log.debug("Got exception %s", e)
+                        continue
+
+                    if self.has_improvement(best_timing, candidate_timing):
+                        improved = True
+                        log.debug(
+                            "Tune from %s %f -> %s %f",
+                            best_config,
+                            best_timing,
+                            candidate_config,
+                            candidate_timing,
+                        )
+                        best_timing = candidate_timing
+                        best_config = candidate_config
+
+        log.debug(
+            "Improve from %s %f -> %s %f, %.3fx",
+            baseline_config,
+            baseline_timing,
+            best_config,
+            best_timing,
+            # a zero best timing must not raise ZeroDivisionError in a debug log
+            baseline_timing / best_timing if best_timing else float("inf"),
+        )
+
+        return best_config
diff --git a/torch/_inductor/triton_heuristics.py b/torch/_inductor/triton_heuristics.py
index 4c401a6..a16d50c 100644
--- a/torch/_inductor/triton_heuristics.py
+++ b/torch/_inductor/triton_heuristics.py
@@ -10,6 +10,7 @@
 import os.path
 import re
 import threading
+from enum import auto, Enum
 from typing import List
 
 import torch
@@ -17,6 +18,7 @@
 
 from . import config
 from .codecache import cache_dir, CudaKernelParamCache
+from .coordinate_descent_tuner import CoordescTuner
 
 from .ir import ReductionHint, TileHint
 from .utils import (
@@ -27,6 +29,7 @@
     get_num_bytes,
     has_triton,
     next_power_of_2,
+    triton_config_to_hashable,
 )
 
 
@@ -43,6 +46,13 @@
     triton = None
 
 
+class HeuristicType(Enum):  # passed as heuristic_type to CachingAutotuner
+    POINTWISE = auto()
+    REDUCTION = auto()
+    PERSISTENT_REDUCTION = auto()  # configs of this type must not have RBLOCK
+    TEMPLATE = auto()  # triton template; coordinate descent tuning skips it
+
+
 class CachingAutotuner(KernelInterface):
     """
     Simplified version of Triton autotuner that has no invalidation
@@ -51,13 +61,22 @@
     configs, and does not rely on the Triton JIT.
     """
 
-    def __init__(self, fn, meta, configs, save_cache_hook, mutated_arg_names):
+    def __init__(
+        self, fn, meta, configs, save_cache_hook, mutated_arg_names, heuristic_type
+    ):
         super().__init__()
         self.fn = fn
         self.meta = meta
         self.save_cache_hook = save_cache_hook
         self.mutated_arg_names = mutated_arg_names
         self.configs = configs
+        self.heuristic_type = heuristic_type
+
+        if log.isEnabledFor(logging.DEBUG):
+            log.debug("CachingAutotuner gets %d configs", len(self.configs))
+            for c in self.configs:
+                log.debug(c)
+
         self.launchers = []
         self.lock = threading.Lock()
         if os.getenv("TRITON_CACHE_DIR") is None:
@@ -67,6 +86,8 @@
                 str(self.meta.get("device", 0)),
             )
 
+        self.coordesc_tuner = CoordescTuner(is_mm=False, name=self.fn.__name__)
+
     def precompile(self, warm_cache_only_with_cc=None):
         with self.lock:
             if self.launchers:
@@ -164,8 +185,7 @@
 
         return do_bench(kernel_call, rep=40, fast_flush=True)
 
-    @dynamo_timed
-    def benchmark_all_configs(self, *args, **kwargs):
+    def clone_args(self, *args):
         from .compile_fx import clone_preserve_strides
 
         # clone inplace buffers to avoid autotune contaminating them if
@@ -179,10 +199,24 @@
             else:
                 cloned_args.append(arg)
 
+        return cloned_args
+
+    @dynamo_timed
+    def benchmark_all_configs(self, *args, **kwargs):
+        cloned_args = self.clone_args(*args)
         timings = {
             launcher: self.bench(launcher, *cloned_args, **kwargs)
             for launcher in self.launchers
         }
+        # seed the coordesc tuner's cache so these configs are not re-benchmarked
+        for k, v in timings.items():
+            self.coordesc_tuner.cache_benchmark_result(k.config, v)
+
+        if log.isEnabledFor(logging.DEBUG):
+            log.debug("Benchmark all input configs get:")
+            for k, v in timings.items():
+                log.debug("%s: %f", k.config, v)
+
         return timings
 
     def autotune_to_one_config(self, *args, **kwargs):
@@ -210,6 +244,43 @@
         }
         CudaKernelParamCache.set(key, params, launcher.bin.asm["cubin"])
 
+    def coordinate_descent_tuning(self, launcher, *args, **kwargs):
+        """
+        Refine the config of `launcher` with the coordinate descent tuner.
+
+        This works with or without max-autotune; the two only differ in the
+        starting config. E.g., if regular autotune has a single config C1 while
+        max-autotune benchmarks C1, C2, C3, C4 and finds C3 to be the best, then
+        coordinate descent tuning starts from C1 when max-autotune is disabled
+        and from C3 when it is enabled.
+
+        Triton templates are skipped and keep the launcher they came in with.
+        """
+        if self.heuristic_type == HeuristicType.TEMPLATE:
+            return launcher
+
+        cloned_args = self.clone_args(*args)
+        config2launcher = {launcher.config: launcher}
+
+        def benchmark_one_config(config):
+            # compile the candidate under the lock, remember its launcher, time it
+            with self.lock:
+                compiled = self._precompile_config(config, None)
+            config2launcher[config] = compiled
+            return self.bench(compiled, *cloned_args, **kwargs)
+
+        start_config = launcher.config
+        assert (
+            self.heuristic_type != HeuristicType.PERSISTENT_REDUCTION
+            or "RBLOCK" not in start_config.kwargs
+        ), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have RBLOCK"
+        best_config = self.coordesc_tuner.autotune(benchmark_one_config, start_config)
+        best_config.found_by_coordesc = True
+
+        if self.save_cache_hook:
+            self.save_cache_hook(best_config, found_by_coordesc=True)
+        return config2launcher.get(best_config)
+
     def run(self, *args, grid, stream):
         if len(self.launchers) != 1:
             if len(self.launchers) == 0:
@@ -217,6 +288,14 @@
             if len(self.launchers) > 1:
                 self.autotune_to_one_config(*args, grid=grid)
 
+        if (
+            not getattr(self.launchers[0].config, "found_by_coordesc", False)
+            and config.coordinate_descent_tuning
+        ):
+            self.launchers = [
+                self.coordinate_descent_tuning(self.launchers[0], *args, grid=grid)
+            ]
+
         (launcher,) = self.launchers
         if launcher.store_cubin:
             self.save_cuda_kernel(grid, stream, self.launchers[0])
@@ -327,13 +406,22 @@
 
     with open(cache_filename, "r") as fd:
         best_config = json.loads(fd.read())
-    if best_config.get("configs_hash") != configs_hash:
+    if best_config.pop("configs_hash", None) != configs_hash:
         return None
 
+    if config.coordinate_descent_tuning and best_config.pop("found_by_coordesc", False):
+        num_warps = best_config.pop("num_warps")
+        num_stages = best_config.pop("num_stages")
+        triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages)
+        triton_config.found_by_coordesc = True
+        return triton_config
+
     matching_configs = [
         cfg
         for cfg in configs
         if all(val == best_config.get(key) for key, val in cfg.kwargs.items())
+        and cfg.num_warps == best_config.get("num_warps")
+        and cfg.num_stages == best_config.get("num_stages")
     ]
     if len(matching_configs) != 1:
         return None
@@ -344,6 +432,7 @@
 def cached_autotune(
     configs: List[Config],
     meta,
+    heuristic_type,
     filename=None,
 ):
     """
@@ -353,22 +442,30 @@
     configs = unique_configs(configs)
     assert len(configs) == 1 or filename
 
-    # The autotune cache will simply replace the list of candidate configs with
-    # the best config cached. We don't want that when we benchmark triton kernels.
-    # We need the perf for each of the candidate config instead.
-    cache_autotune_result = not config.benchmark_kernel
-
     # on disk caching logic
-    if cache_autotune_result and filename is not None and len(configs) > 1:
+    if filename is not None and (len(configs) > 1 or config.coordinate_descent_tuning):
         cache_filename = os.path.splitext(filename)[0] + ".best_config"
         configs_hash = hash_configs(configs)
         best_config = load_cached_autotuning(cache_filename, configs_hash, configs)
         if best_config:
             configs = [best_config]
 
-        def save_cache_hook(cfg):
+        def save_cache_hook(cfg, found_by_coordesc=False):
             with open(cache_filename, "w") as fd:
-                fd.write(json.dumps({**cfg.kwargs, "configs_hash": configs_hash}))
+                fd.write(
+                    json.dumps(
+                        {
+                            **cfg.kwargs,
+                            "num_warps": cfg.num_warps,
+                            "num_stages": cfg.num_stages,
+                            "configs_hash": configs_hash,
+                            "found_by_coordesc": found_by_coordesc,
+                        }
+                    )
+                )
+            if log.isEnabledFor(logging.DEBUG):
+                type_str = "coordesc" if found_by_coordesc else "heuristic"
+                log.debug("Save %s tuning result to %s", type_str, cache_filename)
 
     else:
         save_cache_hook = None
@@ -384,6 +481,7 @@
                 configs=configs,
                 save_cache_hook=save_cache_hook,
                 mutated_arg_names=mutated_arg_names,
+                heuristic_type=heuristic_type,
             )
         return CachingAutotuner(
             fn,
@@ -391,6 +489,7 @@
             configs=configs,
             save_cache_hook=save_cache_hook,
             mutated_arg_names=mutated_arg_names,
+            heuristic_type=heuristic_type,
         )
 
     return decorator
@@ -400,8 +499,9 @@
     """Remove duplicate configurations"""
     seen = set()
     pruned_configs = []
+
     for cfg in configs:
-        key = tuple(cfg.kwargs.items())
+        key = triton_config_to_hashable(cfg)
         if key not in seen:
             seen.add(key)
             pruned_configs.append(cfg)
@@ -554,12 +654,22 @@
     bs = max(256, min(numel // 128, 1024))
 
     if len(size_hints) == 1:
-        return cached_autotune([triton_config(size_hints, bs)], meta=meta)
+        return cached_autotune(
+            [triton_config(size_hints, bs)],
+            meta=meta,
+            heuristic_type=HeuristicType.POINTWISE,
+            filename=filename,
+        )
     if len(size_hints) == 2:
         if (
             not config.triton.autotune_pointwise or tile_hint == TileHint.SQUARE
         ) and not (config.max_autotune or config.max_autotune_pointwise):
-            return cached_autotune([triton_config(size_hints, 32, 32)], meta=meta)
+            return cached_autotune(
+                [triton_config(size_hints, 32, 32)],
+                meta=meta,
+                heuristic_type=HeuristicType.POINTWISE,
+                filename=filename,
+            )
         return cached_autotune(
             [
                 triton_config(size_hints, 32, 32),
@@ -571,10 +681,16 @@
             ],
             meta=meta,
             filename=filename,
+            heuristic_type=HeuristicType.POINTWISE,
         )
     if len(size_hints) == 3:
         if not config.triton.autotune_pointwise:
-            return cached_autotune([triton_config(size_hints, 16, 16, 16)], meta=meta)
+            return cached_autotune(
+                [triton_config(size_hints, 16, 16, 16)],
+                meta=meta,
+                heuristic_type=HeuristicType.POINTWISE,
+                filename=filename,
+            )
         return cached_autotune(
             [
                 triton_config(size_hints, 16, 16, 16),
@@ -587,6 +703,7 @@
             ],
             meta=meta,
             filename=filename,
+            heuristic_type=HeuristicType.POINTWISE,
         )
     raise NotImplementedError(f"size_hints: {size_hints}")
 
@@ -606,14 +723,32 @@
         if config.max_autotune or config.max_autotune_pointwise:
             pass  # skip all these cases
         elif reduction_hint == ReductionHint.INNER:
-            return cached_autotune([contiguous_config], meta=meta)
+            return cached_autotune(
+                [contiguous_config],
+                meta=meta,
+                heuristic_type=HeuristicType.REDUCTION,
+                filename=filename,
+            )
         elif reduction_hint == ReductionHint.OUTER:
-            return cached_autotune([outer_config], meta=meta)
+            return cached_autotune(
+                [outer_config],
+                meta=meta,
+                heuristic_type=HeuristicType.REDUCTION,
+                filename=filename,
+            )
         elif reduction_hint == ReductionHint.OUTER_TINY:
-            return cached_autotune([tiny_config], meta=meta)
+            return cached_autotune(
+                [tiny_config],
+                meta=meta,
+                heuristic_type=HeuristicType.REDUCTION,
+                filename=filename,
+            )
         if not config.triton.autotune_pointwise:
             return cached_autotune(
-                [triton_config_reduction(size_hints, 32, 128)], meta=meta
+                [triton_config_reduction(size_hints, 32, 128)],
+                meta=meta,
+                heuristic_type=HeuristicType.REDUCTION,
+                filename=filename,
             )
         return cached_autotune(
             [
@@ -625,6 +760,7 @@
             ],
             meta=meta,
             filename=filename,
+            heuristic_type=HeuristicType.REDUCTION,
         )
     raise NotImplementedError(f"size_hints: {size_hints}")
 
@@ -658,6 +794,7 @@
         configs,
         meta=meta,
         filename=filename,
+        heuristic_type=HeuristicType.PERSISTENT_REDUCTION,
     )
 
 
@@ -666,7 +803,10 @@
     Compile a triton template
     """
     return cached_autotune(
-        [triton.Config({}, num_stages=num_stages, num_warps=num_warps)], meta=meta
+        [triton.Config({}, num_stages=num_stages, num_warps=num_warps)],
+        meta=meta,
+        heuristic_type=HeuristicType.TEMPLATE,
+        filename=filename,
     )
 
 
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 6d5d714..bf493ac 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -1050,3 +1050,14 @@
         parse_profile_event_list(
             benchmark_name, event_list, wall_time_ms, times * repeat
         )
+
+
+def triton_config_to_hashable(cfg):
+    """
+    Hashable identity of a triton config: sorted kwargs + num_warps/num_stages.
+    """
+    return (
+        *sorted(cfg.kwargs.items()),
+        ("num_warps", cfg.num_warps),
+        ("num_stages", cfg.num_stages),
+    )