[inductor] Add support for tl.make_block_ptr (#116079)

On A100 this is a small performance regression:
![image](https://github.com/pytorch/pytorch/assets/533820/b30eee9d-c0fe-4123-99da-d554fc5d0171)

So I will leave it disabled by default.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116079
Approved by: https://github.com/shunting314
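
The new codegen path is gated behind `config.triton.use_block_ptr` (default `False`). As a rough sketch of what changes in the emitted Triton code: the first load below is quoted from the updated tests, while the `tl.make_block_ptr(...)` argument values are illustrative placeholders, not output from a real run:

```python
# use_block_ptr=False (default): plain pointer arithmetic
tmp1 = tl.load(in_ptr1 + (x3 + (262144*r2)), rmask,
               eviction_policy='evict_first', other=0.0)

# use_block_ptr=True: the same load via a block pointer (values illustrative)
block_ptr0 = tl.make_block_ptr(in_ptr1, shape=[262144, 512], strides=[1, 262144],
                               block_shape=[XBLOCK, RBLOCK], order=[0, 1],
                               offsets=[xoffset, 0])
tmp1 = tl.load(block_ptr0, boundary_check=[1], padding_option='zero',
               eviction_policy='evict_first')
# inside a non-persistent reduction loop, the pointer is advanced per iteration:
block_ptr0 = tl.advance(block_ptr0, [0, RBLOCK])
```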
diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py
index c38ffa5..885dee5 100644
--- a/test/inductor/test_cuda_repro.py
+++ b/test/inductor/test_cuda_repro.py
@@ -1072,6 +1072,48 @@
 
         self.assertEqual(o1, o2)
 
+    @config.patch("triton.use_block_ptr", True)
+    def test_selecsls42b_misaligned_address(self):
+        # https://github.com/openai/triton/issues/2836
+
+        @torch.compile(fullgraph=True)
+        def fn(arg207_1, arg208_1, convert_element_type_40, expand, full, mul_3):
+            div = torch.ops.aten.div.Scalar(expand, 16)
+            where = torch.ops.aten.where.self(arg207_1, full, div)
+            convert_element_type_43 = torch.ops.prims.convert_element_type.default(
+                where, torch.float32
+            )
+            sum_2 = torch.ops.aten.sum.dim_IntList(convert_element_type_43, [0, 2, 3])
+            sub = torch.ops.aten.sub.Tensor(convert_element_type_40, arg208_1)
+            mul = torch.ops.aten.mul.Tensor(convert_element_type_43, sub)
+            sum_3 = torch.ops.aten.sum.dim_IntList(mul, [0, 2, 3])
+            mul_1 = torch.ops.aten.mul.Tensor(sum_2, 0.0078125)
+            unsqueeze = torch.ops.aten.unsqueeze.default(mul_1, 0)
+            unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, 2)
+            unsqueeze_2 = torch.ops.aten.unsqueeze.default(unsqueeze_1, 3)
+            mul_2 = torch.ops.aten.mul.Tensor(sum_3, 0.0078125)
+            mul_4 = torch.ops.aten.mul.Tensor(mul_2, mul_3)
+            unsqueeze_3 = torch.ops.aten.unsqueeze.default(mul_4, 0)
+            unsqueeze_4 = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2)
+            unsqueeze_5 = torch.ops.aten.unsqueeze.default(unsqueeze_4, 3)
+            mul_6 = torch.ops.aten.mul.Tensor(sub, unsqueeze_5)
+            sub_1 = torch.ops.aten.sub.Tensor(convert_element_type_43, mul_6)
+            sub_2 = torch.ops.aten.sub.Tensor(sub_1, unsqueeze_2)
+            return (sub_2,)
+
+        args = [
+            torch.randn((8, 1024, 4, 4), device="cuda") > 0,  # torch.bool tensor
+            torch.randn((1, 1024, 1, 1), device="cuda"),
+            torch.randn((8, 1024, 4, 4), device="cuda"),
+            torch.randn((8, 1024, 1, 1), dtype=torch.float16, device="cuda").expand(
+                (8, 1024, 4, 4)
+            ),
+            torch.randn((), device="cuda"),
+            torch.randn((1024,), device="cuda"),
+        ]
+        fn(*args)
+        torch.cuda.synchronize()  # shake out Triton Error [CUDA]: misaligned address
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index d48ad5e..5b18cc6 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -546,6 +546,8 @@
                 stmt = line.split(".store")[-1]
             elif "[" in line:
                 stmt = line.split("[")[-1].split("]")[0]
+            if "tl.make_block_ptr(" in line:
+                continue
 
             if stmt is None:
                 continue
@@ -8627,6 +8629,7 @@
 
                 self.assertEqual(fn_opt(), fn())
 
+        @config.patch("triton.use_block_ptr", False)
         def test_evict_last_non_coalesced_loads(self):
             @torch.compile
             def f(a, b):
@@ -8638,13 +8641,33 @@
                 torch.randn(N, N, N, device="cuda").permute(1, 2, 0),
             )
             code = run_and_get_triton_code(f, *inps)
-            self.assertTrue(
-                "tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last'"
-                in code
+            lines = [line for line in code.split("\n") if "tl.load" in line]
+            self.assertExpectedInline(
+                "\n".join(lines),
+                """\
+        tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
+        tmp1 = tl.load(in_ptr1 + (x3 + (262144*r2)), rmask, eviction_policy='evict_first', other=0.0)""",
             )
-            self.assertTrue(
-                "tl.load(in_ptr1 + (x3 + (262144*r2)), rmask, eviction_policy='evict_first',"
-                in code
+
+        @skipIfRocm
+        @config.patch("triton.use_block_ptr", True)
+        def test_evict_last_non_coalesced_loads_block_ptr(self):
+            @torch.compile
+            def f(a, b):
+                return (a * b).sum(dim=-1)
+
+            N = 512
+            inps = (
+                torch.randn(N, N, N, device="cuda").permute(2, 1, 0),
+                torch.randn(N, N, N, device="cuda").permute(1, 2, 0),
+            )
+            code = run_and_get_triton_code(f, *inps)
+            lines = [line for line in code.split("\n") if "tl.load" in line]
+            self.assertExpectedInline(
+                "\n".join(lines),
+                """\
+        tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
+        tmp1 = tl.load(block_ptr0, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""",
             )
 
         # Disable index propagation, so the indirect indexing isn't optimized away
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index af08736..a09cf97 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -495,6 +495,7 @@
     def __init__(self, name, line):
         super().__init__(line)
         self.name = name
+        assert not isinstance(line, DeferredLineBase)
 
     def __call__(self):
         if all(
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 5bb43dc..0f29230 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -43,6 +43,7 @@
 from ..scheduler import BaseScheduling, WhyNoFuse
 from ..triton_heuristics import AutotuneHint
 from ..utils import (
+    cache_on_self,
     do_bench,
     get_fused_kernel_name,
     get_kernel_metadata,
@@ -50,6 +51,7 @@
     is_welford_reduction,
     next_power_of_2,
     Placeholder,
+    sympy_dot,
     sympy_product,
     sympy_subs,
     sympy_symbol,
@@ -97,6 +99,150 @@
         return "tmp" in self.mask_str
 
 
+@dataclasses.dataclass
+class BlockPtrOptions:
+    constant_offset: sympy.Expr
+    shape: List[sympy.Expr]
+    strides: List[sympy.Expr]
+    block_shape: List[str]
+    order: List[int]
+    offsets: List[str]
+    mask_vars: Set[sympy.Symbol]
+    reshape_suffix: List[str]
+
+    @staticmethod
+    def create(
+        strides: List[sympy.Expr],
+        constant_offset: sympy.Expr,
+        range_trees: List[IterationRangesEntry],
+        mask_vars: Set[sympy.Symbol],
+    ) -> BlockPtrOptions:
+        """Helper to create a  BlockPtrOptions instance"""
+        block_shape = [f"{t.prefix.upper()}BLOCK" for t in range_trees]
+        reshape_suffix = [*block_shape]
+
+        broadcasting_dim = [s == 0 for s in strides]
+        for i, is_broadcasting in enumerate(broadcasting_dim):
+            if is_broadcasting:
+                # drop any stride==0 dimensions for performance
+                reshape_suffix[i] = "1"
+
+        if V.kernel.no_x_dim:
+            assert range_trees[0].prefix == "x"
+            reshape_suffix.pop(0)
+
+        if (
+            not V.kernel.inside_reduction
+            and len(strides) == len(V.kernel.numels) - 1
+            and V.kernel.numels[-1] != 1
+        ):
+            # Need to expand rank by 1 to match rank when self.inside_reduction=True
+            reshape_suffix.append("1")
+
+        def filter(it):
+            """Removes any broadcasting dims from a given sequence"""
+            assert len(it) == len(broadcasting_dim)
+            return [
+                item
+                for item, is_broadcasting in zip(it, broadcasting_dim)
+                if not is_broadcasting
+            ]
+
+        return BlockPtrOptions(
+            constant_offset=V.graph.sizevars.lookup_precomputed_size(constant_offset),
+            shape=[
+                V.graph.sizevars.lookup_precomputed_size(t.numel)
+                for t in filter(range_trees)
+            ],
+            strides=[*map(V.graph.sizevars.lookup_precomputed_size, filter(strides))],
+            block_shape=filter(block_shape),
+            order=V.graph.sizevars.guarded_order(filter(strides)),
+            offsets=filter([f"{t.prefix}offset" for t in range_trees]),
+            mask_vars=mask_vars,
+            reshape_suffix=reshape_suffix,
+        )
+
+    def format(self, name: str, roffset=True) -> str:
+        """
+        Codegen a call to tl.make_block_ptr()
+
+        Args:
+            name: variable name for pointer
+            roffset: should roffset be included in offsets=...; pass False when the pointer will be advanced via tl.advance()
+
+        Returns:
+            "tl.make_block_ptr(...)"
+        """
+        f = V.kernel.index_to_str
+        offsets = [*self.offsets]
+        if not roffset:
+            offsets[offsets.index("roffset")] = "0"
+        args = [
+            f"{name} + ({f(self.constant_offset)})"
+            if self.constant_offset != 0
+            else name,
+            f"shape={f(self.shape)}",
+            f"strides={f(self.strides)}",
+            f"block_shape={f(self.block_shape)}",
+            f"order={f(self.order)}",
+            f"offsets={f(offsets)}",
+        ]
+        return f"tl.make_block_ptr({', '.join(args)})"
+
+    @cache_on_self
+    def boundary_check(self) -> List[int]:
+        """List of indices to pass to tl.load(boundary_check=...)"""
+        check = []
+        for i in range(len(self.shape)):
+            if (
+                self.block_shape[i] != "1"
+                and not V.graph.sizevars.statically_known_equals(self.strides[i], 0)
+                and not V.graph.sizevars.statically_known_multiple_of(
+                    self.shape[i],
+                    config.triton.max_block[self.block_shape[i][0]],
+                )
+                and not (V.kernel.no_x_dim and self.block_shape[i] == "XBLOCK")
+            ):
+                check.append(i)
+        return check
+
+    def advance_roffset(self):
+        """Codegen string to pass to tl.advance(name, ...)"""
+        advance = ["0"] * len(self.shape)
+        advance[self.offsets.index("roffset")] = "RBLOCK"
+        return V.kernel.index_to_str(advance)
+
+    def has_rindex(self):
+        return "RBLOCK" in self.block_shape
+
+    def has_tmpmask(self):
+        return False  # block_ptr can't do indirect indexing
+
+    def has_mask(self):
+        return bool(self.boundary_check())
+
+
+def triton_reshape(value: str, old_shape: List[str], new_shape: List[str]):
+    """Workaround https://github.com/openai/triton/issues/2836"""
+    assert isinstance(old_shape, list) and isinstance(new_shape, list)
+    if old_shape == new_shape:
+        return value
+    if [s for s in new_shape if s != "1"] != old_shape:
+        return f"tl.reshape({value}, [{', '.join(new_shape)}])"
+    # rewrite to [:, None] syntax, which is less buggy
+    idx = 0
+    expand = []
+    for size in new_shape:
+        if idx < len(old_shape) and size == old_shape[idx]:
+            expand.append(":")
+            idx += 1
+        else:
+            assert size == "1"
+            expand.append("None")
+    assert idx == len(old_shape)
+    return f"{value}[{', '.join(expand)}]"
+
+
 class TritonPrinter(PythonPrinter):
     def _print_floor(self, expr):
         assert len(expr.args) == 1
@@ -207,6 +353,17 @@
     return f"tl.{triton_type_name}"
 
 
+def triton_store_type(dtype):
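+    # Type for tl.store(): unlike the compute type, bool is stored in memory
+    # as int8 (Triton computes bool as tl.int1).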
+    triton_type_name = str(dtype).split(".")[-1]
+    if triton_type_name == "bool":
+        triton_type_name = "int8"
+    elif triton_type_name == "float8_e4m3fn":
+        triton_type_name = "float8e4nv"
+    elif triton_type_name == "float8_e5m2":
+        triton_type_name = "float8e5"
+    return f"tl.{triton_type_name}"
+
+
 def triton_acc_type(dtype):
     if is_integer_dtype(dtype) and dtype.is_signed:
         nbits = 64 if dtype == torch.int64 else 32
@@ -657,7 +814,7 @@
 
     @classmethod
     def index_expr(cls, expr, dtype):
-        indexing = V.kernel.indexing(expr)
+        indexing = V.kernel.indexing(expr, block_ptr=False)
         assert isinstance(indexing, IndexingOptions)
         # This is called from CSEProxy.__getattr__,  so we'll set the bounds there
         var = V.kernel.cse.generate(V.kernel.compute, indexing.index_str)
@@ -1005,6 +1162,7 @@
         self.index_dtype: str = index_dtype
         self.min_elem_per_thread = min_elem_per_thread
         self.last_usage: Set[str] = set()
+        self.block_ptr_id = itertools.count()
 
         self.persistent_reduction: bool = self.should_use_persistent_reduction()
         self.no_x_dim = (
@@ -1289,7 +1447,8 @@
         copy_shape=None,
         dense_indexing=False,
         override_mask=None,
-    ) -> IndexingOptions:
+        block_ptr=False,
+    ) -> Union[IndexingOptions, BlockPtrOptions]:
         """
         Compute the index and mask to pass to tl.load() or tl.store()
         """
@@ -1353,6 +1512,36 @@
                 have_dense = False
             dense_mask_vars.add(f"{tree.prefix}mask")
 
+        if (
+            block_ptr
+            and config.triton.use_block_ptr
+            and not override_mask
+            and not self._load_mask
+            and len(mask_vars - dense_mask_vars) == 0
+            and not self.is_indirect_indexing(index)
+            and have_loop_vars
+            # workaround https://github.com/openai/triton/issues/2821
+            and self.index_dtype == "tl.int32"
+        ):
+            index_relative_to_xyr_index = sympy_subs(
+                index, {v: t.expr for v, t in self.range_tree_nodes.items()}
+            )
+            range_trees = self.active_range_trees(reorder=True)
+            symbols = [t.symbol() for t in range_trees]
+            strides = [sympy.Wild(f"stride_{s}", exclude=symbols) for s in symbols]
+            offset = sympy.Wild("_offset", exclude=symbols)
+            m = index_relative_to_xyr_index.match(sympy_dot(symbols, strides) + offset)
+            # TODO(jansel): it is sometimes possible to do higher dimensional block_ptrs with
+            #               a tl.reshape to the correct block shape.  We will miss these cases today.
+            if m:
+                self.filter_masks(mask_vars)
+                return BlockPtrOptions.create(
+                    [m[s] for s in strides],
+                    m[offset],
+                    range_trees,
+                    mask_vars,
+                )
+
         expand_str = None
         index_str = self.index_to_str(index)
         if isinstance(index, sympy.Integer):
@@ -1491,11 +1680,54 @@
             )
         return strides
 
+    def codegen_block_ptr(
+        self, name: str, var: str, indexing: BlockPtrOptions, other=""
+    ) -> Tuple[str, Optional[DeferredLine], str]:
+        advance_block_ptr = None
+        check = indexing.boundary_check()
+        if not check:
+            # workaround https://github.com/openai/triton/issues/2813
+            other = ""
+        elif other:
+            assert other == ", other=0.0"
+            other = f", boundary_check={check!r}, padding_option='zero'"
+        else:
+            other = f", boundary_check={check!r}"
+        if (
+            self.inside_reduction
+            and not self.persistent_reduction
+            and indexing.has_rindex()
+        ):
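+            # Non-persistent reduction: emit the block_ptr once before the
+            # r-loop with roffset pinned to 0, then have the caller emit a
+            # tl.advance() that steps it by RBLOCK each iteration.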
+            block_ptr = f"block_ptr{next(self.block_ptr_id)}"
+            self.body.writeline(
+                DeferredLine(
+                    name, f"{block_ptr} = {indexing.format(var, roffset=False)}"
+                )
+            )
+            advance_block_ptr = DeferredLine(
+                name,
+                f"{block_ptr} = tl.advance({block_ptr}, {indexing.advance_roffset()})",
+            )
+        else:
+            block_ptr = indexing.format(var)
+        return block_ptr, advance_block_ptr, other
+
+    def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=""):
+        # broadcasting is not implicit for block_ptrs
+        value = (
+            f"tl.broadcast_to({value}, {self.index_to_str(indexing.reshape_suffix)})"
+        )
+        # drop any extra size=1 dimensions
+        value = triton_reshape(value, indexing.reshape_suffix, indexing.block_shape)
+        # workaround https://github.com/openai/triton/issues/2814
+        value = f"{value}.to({triton_store_type(V.graph.get_dtype(name))})"
+        return f"tl.store({block_ptr}, {value}{other})"
+
     def load(self, name: str, index: sympy.Expr):
         var = self.args.input(name)
         indirect_indexing = self.is_indirect_indexing(index)
         original_index = index
-        indexing = self.indexing(index)
+        indexing = self.indexing(index, block_ptr=True)
         has_rindex = indexing.has_rindex()
         has_tmpmask = indexing.has_tmpmask()
 
@@ -1541,11 +1773,21 @@
         else:
             other = ""
 
+        advance_block_ptr = None
         append_broadcast = None
         if V.graph.is_unspec_arg(name):
             line = var
         else:
-            if isinstance(original_index, sympy.Integer):
+            if isinstance(indexing, BlockPtrOptions):
+                block_ptr, advance_block_ptr, other = self.codegen_block_ptr(
+                    name, var, indexing, other
+                )
+                line = f"tl.load({block_ptr}{other}{ep})"
+                # add needed size=1 dimensions
+                line = triton_reshape(
+                    line, indexing.block_shape, indexing.reshape_suffix
+                )
+            elif isinstance(original_index, sympy.Integer):
                 line = f"tl.load({var} + ({original_index}))"
                 append_broadcast = indexing.expand_str
             else:
@@ -1583,6 +1825,9 @@
             line = f"tl.broadcast_to({result_var}, {append_broadcast})"
             result_var = self.cse.generate(load_buffer, line)
 
+        if advance_block_ptr:
+            load_buffer.writeline(advance_block_ptr)
+
         if not self.inside_reduction or not has_rindex:
             self.outside_loop_vars.add(result_var)
 
@@ -1591,7 +1836,7 @@
     def store(self, name, index, value, mode=None):
         var = self.args.output(name)
         original_index = index
-        indexing = self.indexing(index, dense_indexing=True)
+        indexing = self.indexing(index, dense_indexing=True, block_ptr=mode is None)
 
         # Guard against write-after-read corruption in triton.
         # See # https://github.com/openai/triton/issues/1615
@@ -1604,13 +1849,24 @@
         if is_inplace and is_broadcasted:
             self.stores.writeline(DeferredLine(name, "tl.debug_barrier()"))
 
-        if mode is None:
+        advance_block_ptr = None
+        if isinstance(indexing, BlockPtrOptions):
+            block_ptr, advance_block_ptr, other = self.codegen_block_ptr(
+                name, var, indexing
+            )
+            # block_ptr stores don't do implicit casting
+            line = self.codegen_block_ptr_store_line(
+                name, indexing, block_ptr, value, other
+            )
+        elif mode is None:
             line = f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})"
         elif mode == "atomic_add":
             line = f"tl.atomic_add({var} + ({indexing.index_str}), {value}, {indexing.mask_str})"
         else:
             raise NotImplementedError(f"store mode={mode}")
         self.stores.writeline(DeferredLine(name, line))
+        if advance_block_ptr:
+            self.stores.writeline(advance_block_ptr)
 
         if not self.inside_reduction:
             self.outside_loop_vars.add(value)
@@ -1896,17 +2152,31 @@
     def store_reduction(self, name, index, value):
         assert self.inside_reduction
         self.inside_reduction = False
-        indexing = self.indexing(index)
+        indexing = self.indexing(index, block_ptr=True)
         self.inside_reduction = True
         var = self.args.output(name)
 
-        assert isinstance(indexing, IndexingOptions)
-        self.suffix.writeline(
-            DeferredLine(
-                name,
-                f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})",
+        if isinstance(indexing, BlockPtrOptions):
+            self.suffix.writeline(
+                DeferredLine(
+                    name,
+                    self.codegen_block_ptr_store_line(
+                        name,
+                        indexing,
+                        indexing.format(var),
+                        value,
+                        f", boundary_check={indexing.boundary_check()!r}",
+                    ),
+                )
             )
-        )
+        else:
+            assert isinstance(indexing, IndexingOptions)
+            self.suffix.writeline(
+                DeferredLine(
+                    name,
+                    f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})",
+                )
+            )
 
     def _lift_helper(self, fn, num_args) -> str:
         # Lift IR function into a triton function in the global namespace
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index 18649a4..8080673 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -543,6 +543,9 @@
     # We should revisit this once we understand more of the source of register spills.
     spill_threshold: int = 16
 
+    # Generate code containing the newer tl.make_block_ptr() API for loads/stores
+    use_block_ptr: bool = False
+
     # Inject a bug into our relu implementation; useful for testing our repro
     # extraction and minification functionality.
     # Valid values: "compile_error", "runtime_error", "accuracy"
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py
index aba6805..01dabbf 100644
--- a/torch/_inductor/select_algorithm.py
+++ b/torch/_inductor/select_algorithm.py
@@ -329,6 +329,7 @@
         dense_indexing=False,
         copy_shape=None,
         override_mask=None,
+        block_ptr=False,
     ):
         """
         Override the default indexing to use our custom mask and force
@@ -339,6 +340,7 @@
             dense_indexing=False,
             copy_shape=self.template_mask,
             override_mask=self.template_mask,
+            block_ptr=block_ptr,
         )
 
     def initialize_range_tree(self, pid_cache):
diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py
index b8ed515..c99fcaa 100644
--- a/torch/_inductor/sizevars.py
+++ b/torch/_inductor/sizevars.py
@@ -343,6 +343,23 @@
         else:
             return self.guard_equals(left, right)
 
+    def guarded_order(self, seq):
+        """
+        Return the order of a sequence as a permutation of range(len(seq)) and guard on that order not changing.
+        Used for generating block_ptrs.
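+
+        e.g. with size hints (262144, 1, 512) this returns [2, 0, 1]:
+        each element is mapped to its rank in the ascending sort.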
+        """
+        seq = [*map(self.remove_precomputed_replacements, seq)]
+        seq = [(self.size_hint(var), orig_idx, var) for orig_idx, var in enumerate(seq)]
+        seq.sort()
+        order = [-1] * len(seq)
+        last_var = None
+        for new_index, (_, orig_index, var) in enumerate(seq):
+            order[orig_index] = new_index
+            if last_var is not None:
+                self.guard_leq(last_var, var)
+            last_var = var
+        return order
+
     # The evaluate functions evaluate some symbolic sympy expression
     # (NB: not necessarily an Expr) and return what the concrete result
     # is, guarding on the expression being that result