[inductor] If a kernel contains bucketize, try using config with num_elements_per_warp=32 (#104456)

In the binary-search Triton implementation added in #104007, configs with num_elements_per_warp=32 perform much better than configs with larger values.
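For intuition, a back-of-the-envelope sketch (following the example style in the triton_config docstring further down in the diff; the variable names here are illustrative, not from this PR):

```
# A block covering 256 elements with num_elements_per_warp=32 suggests
# 256 (elem) / 32 (elem/warp) = 8 warps, i.e. 8 * 32 = 256 threads --
# one thread per element, which is the regime where the binary search
# performs well.
block_elements = 256
num_elements_per_warp = 32
suggested_num_warps = block_elements // num_elements_per_warp  # -> 8
```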

This PR adds an autotuning config for this purpose. But since autotuning can increase compile times and this config isn't generally useful, we only try it when bucketize is present in the kernel. This is done by adding an extra `autotune_hints` field to triton_meta, which the pointwise autotuning heuristic reads to generate additional configs.
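Concretely, the metadata attached to a generated kernel ends up looking roughly like this (a sketch based on the diff below; the exact dict contents vary per kernel):

```
from torch._inductor.triton_heuristics import AutotuneHint

triton_meta = {
    "device": 0,
    "constants": {},
    "mutated_arg_names": [],
    # Only populated when the kernel body contains a bucketize op:
    "autotune_hints": {AutotuneHint.ELEMENTS_PER_WARP_32},
}
# pointwise(...) then calls autotune_hints_to_configs(...) to expand the
# hint into extra candidate configs with num_elements_per_warp=32.
```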

Performance: measured with the benchmark reused from https://gist.github.com/davidberard98/066fd2115f59f5889ef61e4527d1eba5.

Before:
```
Eager 0.30088499188423157 ms
PT2   0.9296960234642029 ms
```

After:
```
Eager 0.3011910021305084 ms
PT2   0.22977299988269806 ms
```
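For reference, a minimal sketch of this kind of microbenchmark (the actual script lives in the gist above; this version uses the public torch.bucketize rather than the internal prim, takes its shapes from the test in the diff below, and uses simple CUDA-event timing, which may differ from how the gist measures):

```
import torch

def bench_ms(fn, *args, iters=100):
    # Simple CUDA-event timing; the gist may measure differently.
    for _ in range(10):  # warmup
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

x = torch.rand(16, 16, 64, 64, device="cuda")
boundaries = torch.tensor([-0.9, -0.8, 0.1, 0.2, 0.5, 0.9], device="cuda")

def fn(a, b):
    return torch.bucketize(a, b)

compiled = torch.compile(fn)
compiled(x, boundaries)  # compile (and autotune) before timing
print(f"Eager {bench_ms(fn, x, boundaries)} ms")
print(f"PT2   {bench_ms(compiled, x, boundaries)} ms")
```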

Differential Revision: [D47237103](https://our.internmc.facebook.com/intern/diff/D47237103)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104456
Approved by: https://github.com/eellison
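For reference, here is what the hint expands to in the common case (a worked example tracing the new autotune_hints_to_configs in the diff below; bs = 1024 follows the pointwise heuristic's `bs = max(256, min(numel // 128, 1024))` for a large tensor):

```
# Worked example: the ELEMENTS_PER_WARP_32 hint for a 2D kernel with bs=1024.
bs = 1024
xyz_options = ((bs // 4, 1), (1, bs // 4))  # 2D case -> (256, 1) and (1, 256)
for xyz in xyz_options:
    # Each tuple becomes triton_config(size_hints, *xyz, num_elements_per_warp=32)
    print("extra config block sizes:", xyz)
```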
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 6c62f36..55300e7 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -6557,6 +6557,23 @@
             for right in [True, False]:
                 self.common(fn, (input, offsets, out_int32, right), check_lowp=False)
 
+    @patch.object(config.triton, "autotune_pointwise", True)
+    def test_inductor_bucketize_add_autotune(self):
+        """
+        Causes a @pointwise(size_hints) where size_hints is 2D
+        """
+
+        def fn(input, offsets, add_value):
+            return torch.ops.prims._inductor_bucketize(input, offsets) + add_value
+
+        input = torch.rand((16, 16, 64, 64))
+        boundaries = torch.tensor([-0.9, -0.8, 0.1, 0.2, 0.5, 0.9])
+        add_value = torch.randint(0, 1024, (16, 16, 64, 64)).to(
+            memory_format=torch.channels_last
+        )
+
+        self.common(fn, (input, boundaries, add_value), check_lowp=False)
+
 
 @dataclasses.dataclass
 class TestFailure:
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index c5e268e..834639e 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -20,6 +20,7 @@
 from ..codecache import code_hash, get_path
 from ..ir import ReductionHint
 from ..optimize_indexing import indexing_dtype_strength_reduction
+from ..triton_heuristics import AutotuneHint
 from ..utils import (
     DeferredLineBase,
     get_fused_kernel_name,
@@ -741,6 +742,9 @@
         )
         self.initialize_range_tree(pid_cache)
 
+        # A set of autotuning hints to pass as part of triton_meta
+        self.autotune_hints: Set[AutotuneHint] = set()
+
         # define this in a closure to make cache local to object
         @functools.lru_cache(None)
         def simplify_indexing(index: sympy.Expr):
@@ -1303,6 +1307,12 @@
         See [Note: Inductor bucketize op]
         """
 
+        # Triton performance for bucketize_binary_search is much better when the number
+        # of threads equals the number of elements.
+        # If we're trying to use a bucketize kernel, we should make sure that an
+        # autotuning config with num_elements_per_warp=32 exists.
+        self.autotune_hints.add(AutotuneHint.ELEMENTS_PER_WARP_32)
+
         offsets_ptr = self.args.input(offsets_name)
         block_size = self.dense_size_str()
 
@@ -1604,7 +1614,7 @@
                     import triton.language as tl
                     from torch._inductor.ir import ReductionHint
                     from torch._inductor.ir import TileHint
-                    from torch._inductor.triton_heuristics import {heuristics}
+                    from torch._inductor.triton_heuristics import AutotuneHint, {heuristics}
                     from torch._inductor.utils import instance_descriptor
                     from torch._inductor import triton_helpers
                 """
@@ -1648,6 +1658,7 @@
             "device": V.graph.scheduler.current_device.index,
             "constants": {},
             "mutated_arg_names": mutated_args,
+            "autotune_hints": set(self.autotune_hints),
         }
 
         for tree in self.range_trees:
diff --git a/torch/_inductor/triton_heuristics.py b/torch/_inductor/triton_heuristics.py
index f0ccebb..6102766 100644
--- a/torch/_inductor/triton_heuristics.py
+++ b/torch/_inductor/triton_heuristics.py
@@ -11,7 +11,7 @@
 import re
 import threading
 from enum import auto, Enum
-from typing import List
+from typing import List, Set
 
 import torch
 from torch._dynamo.utils import dynamo_timed
@@ -53,6 +53,54 @@
     TEMPLATE = auto()
 
 
+class AutotuneHint(Enum):
+    ELEMENTS_PER_WARP_32 = 0
+
+    # Triton codegen tries to codegen a set of AutotuneHints.
+    # Enum.__repr__ looks like "<AutotuneHint.ELEMENTS_PER_WARP_32: 0>",
+    # which isn't valid Python.
+    # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32".
+    __repr__ = Enum.__str__
+
+
+def autotune_hints_to_configs(
+    hints: Set[AutotuneHint], size_hints, block_size
+) -> List[Config]:
+    """
+    AutotuneHints can be attached to the metadata of triton kernels for providing
+    suggestions about what to try for autotuning. One reason to do this is if there are
+    some configs that are only useful in specific scenarios, in which case we can avoid
+    wasting compile time on autotuning unless we know we are in one of those scenarios.
+
+    Based on those hints, this function will generate a list of additional autotuning
+    configs to try.
+    """
+    configs = []
+
+    for hint in hints:
+        if hint == AutotuneHint.ELEMENTS_PER_WARP_32:
+            if len(size_hints) == 1:
+                xyz_options = ((block_size // 4,),)
+            elif len(size_hints) == 2:
+                xyz_options = ((block_size // 4, 1), (1, block_size // 4))
+            elif len(size_hints) == 3:
+                xyz_options = (
+                    (block_size // 4, 1, 1),
+                    (1, block_size // 4, 1),
+                    (1, 1, block_size // 4),
+                )
+            for xyz in xyz_options:
+                configs.append(
+                    triton_config(
+                        size_hints,
+                        *xyz,
+                        num_elements_per_warp=32,
+                    )
+                )
+
+    return configs
+
+
 def disable_pointwise_autotuning():
     # Autotuning can give different benchmarking results from run to run, and
     # therefore we disable autotuning when use_deterministic flag is on.
@@ -594,6 +642,13 @@
     Construct a pointwise triton config with some adjustment heuristics
     based on size_hints. Size_hints is a tuple of numels in each tile
     dimension and will be rounded up to the nearest power of 2.
+
+    num_elements_per_warp is a suggestion for controlling how many warps
+    the triton config should contain. e.g.: if x=16, y=8, z=4 then
+    num_elements = 16*8*4 = 512. Then if we set num_elements_per_warp=128,
+    we'll launch 512 (elem) / 128 (elem/warp) = 4 warps. Note that it's
+    just a suggestion, and sometimes other adjustment heuristics will
+    override the num_elements_per_warp.
     """
     # Ideally we want to read this from some device config
     maxGridSize = [2147483647, 65535, 65535]
@@ -718,6 +773,10 @@
     numel = functools.reduce(operator.mul, size_hints)
     bs = max(256, min(numel // 128, 1024))
 
+    hinted_configs = autotune_hints_to_configs(
+        meta.get("autotune_hints", set()), size_hints, bs
+    )
+
     if len(size_hints) == 1:
         if disable_pointwise_autotuning() and not (
             config.max_autotune or config.max_autotune_pointwise
@@ -735,6 +794,7 @@
                 [
                     triton_config(size_hints, bs, num_elements_per_warp=256),
                     triton_config(size_hints, bs // 2, num_elements_per_warp=64),
+                    *hinted_configs,
                 ],
                 meta=meta,
                 heuristic_type=HeuristicType.POINTWISE,
@@ -760,6 +820,7 @@
                 triton_config(size_hints, 16, 256),
                 triton_config(size_hints, bs, 1),
                 triton_config(size_hints, 1, bs),
+                *hinted_configs,
             ],
             meta=meta,
             filename=filename,
@@ -784,6 +845,7 @@
                 triton_config(size_hints, bs, 1, 1),
                 triton_config(size_hints, 1, bs, 1),
                 triton_config(size_hints, 1, 1, bs),
+                *hinted_configs,
             ],
             meta=meta,
             filename=filename,