[inductor] Add support for tl.make_block_ptr (#116079)
On A100 this is a small performance regression, so I will leave it disabled by default behind the new `triton.use_block_ptr` config flag.
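
To experiment with the block-pointer codegen anyway, here is a minimal sketch of opting in. The `use_block_ptr` flag and the `(a * b).sum(dim=-1)` workload are taken directly from this PR's config change and tests; treat the snippet as illustrative, not as part of the change itself:

    import torch
    import torch._inductor.config as inductor_config

    # Opt in to tl.make_block_ptr() codegen (off by default, per this PR).
    inductor_config.triton.use_block_ptr = True

    @torch.compile
    def f(a, b):
        return (a * b).sum(dim=-1)

    N = 512
    a = torch.randn(N, N, N, device="cuda").permute(2, 1, 0)
    b = torch.randn(N, N, N, device="cuda").permute(1, 2, 0)
    out = f(a, b)  # the generated kernel loads the coalesced input via tl.load(block_ptr0, ...)

The per-test `@config.patch("triton.use_block_ptr", True)` decorator used in the new tests below is the scoped alternative to setting the flag globally.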
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116079
Approved by: https://github.com/shunting314
diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py
index c38ffa5..885dee5 100644
--- a/test/inductor/test_cuda_repro.py
+++ b/test/inductor/test_cuda_repro.py
@@ -1072,6 +1072,48 @@
self.assertEqual(o1, o2)
+ @config.patch("triton.use_block_ptr", True)
+ def test_selecsls42b_misaligned_address(self):
+ # https://github.com/openai/triton/issues/2836
+
+ @torch.compile(fullgraph=True)
+ def fn(arg207_1, arg208_1, convert_element_type_40, expand, full, mul_3):
+ div = torch.ops.aten.div.Scalar(expand, 16)
+ where = torch.ops.aten.where.self(arg207_1, full, div)
+ convert_element_type_43 = torch.ops.prims.convert_element_type.default(
+ where, torch.float32
+ )
+ sum_2 = torch.ops.aten.sum.dim_IntList(convert_element_type_43, [0, 2, 3])
+ sub = torch.ops.aten.sub.Tensor(convert_element_type_40, arg208_1)
+ mul = torch.ops.aten.mul.Tensor(convert_element_type_43, sub)
+ sum_3 = torch.ops.aten.sum.dim_IntList(mul, [0, 2, 3])
+ mul_1 = torch.ops.aten.mul.Tensor(sum_2, 0.0078125)
+ unsqueeze = torch.ops.aten.unsqueeze.default(mul_1, 0)
+ unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, 2)
+ unsqueeze_2 = torch.ops.aten.unsqueeze.default(unsqueeze_1, 3)
+ mul_2 = torch.ops.aten.mul.Tensor(sum_3, 0.0078125)
+ mul_4 = torch.ops.aten.mul.Tensor(mul_2, mul_3)
+ unsqueeze_3 = torch.ops.aten.unsqueeze.default(mul_4, 0)
+ unsqueeze_4 = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2)
+ unsqueeze_5 = torch.ops.aten.unsqueeze.default(unsqueeze_4, 3)
+ mul_6 = torch.ops.aten.mul.Tensor(sub, unsqueeze_5)
+ sub_1 = torch.ops.aten.sub.Tensor(convert_element_type_43, mul_6)
+ sub_2 = torch.ops.aten.sub.Tensor(sub_1, unsqueeze_2)
+ return (sub_2,)
+
+ args = [
+ torch.randn((8, 1024, 4, 4), device="cuda") > 0, # torch.bool tensor
+ torch.randn((1, 1024, 1, 1), device="cuda"),
+ torch.randn((8, 1024, 4, 4), device="cuda"),
+ torch.randn((8, 1024, 1, 1), dtype=torch.float16, device="cuda").expand(
+ (8, 1024, 4, 4)
+ ),
+ torch.randn((), device="cuda"),
+ torch.randn((1024,), device="cuda"),
+ ]
+ fn(*args)
+ torch.cuda.synchronize() # shake out Triton Error [CUDA]: misaligned address
+
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index d48ad5e..5b18cc6 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -546,6 +546,8 @@
stmt = line.split(".store")[-1]
elif "[" in line:
stmt = line.split("[")[-1].split("]")[0]
+ if "tl.make_block_ptr(" in line:
+ continue
if stmt is None:
continue
@@ -8627,6 +8629,7 @@
self.assertEqual(fn_opt(), fn())
+ @config.patch("triton.use_block_ptr", False)
def test_evict_last_non_coalesced_loads(self):
@torch.compile
def f(a, b):
@@ -8638,13 +8641,33 @@
torch.randn(N, N, N, device="cuda").permute(1, 2, 0),
)
code = run_and_get_triton_code(f, *inps)
- self.assertTrue(
- "tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last'"
- in code
+ lines = [line for line in code.split("\n") if "tl.load" in line]
+ self.assertExpectedInline(
+ "\n".join(lines),
+ """\
+ tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
+ tmp1 = tl.load(in_ptr1 + (x3 + (262144*r2)), rmask, eviction_policy='evict_first', other=0.0)""",
)
- self.assertTrue(
- "tl.load(in_ptr1 + (x3 + (262144*r2)), rmask, eviction_policy='evict_first',"
- in code
+
+ @skipIfRocm
+ @config.patch("triton.use_block_ptr", True)
+ def test_evict_last_non_coalesced_loads_block_ptr(self):
+ @torch.compile
+ def f(a, b):
+ return (a * b).sum(dim=-1)
+
+ N = 512
+ inps = (
+ torch.randn(N, N, N, device="cuda").permute(2, 1, 0),
+ torch.randn(N, N, N, device="cuda").permute(1, 2, 0),
+ )
+ code = run_and_get_triton_code(f, *inps)
+ lines = [line for line in code.split("\n") if "tl.load" in line]
+ self.assertExpectedInline(
+ "\n".join(lines),
+ """\
+ tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
+ tmp1 = tl.load(block_ptr0, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""",
)
# Disable index propagation, so the indirect indexing isn't optimized away
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index af08736..a09cf97 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -495,6 +495,7 @@
def __init__(self, name, line):
super().__init__(line)
self.name = name
+ assert not isinstance(line, DeferredLineBase)
def __call__(self):
if all(
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 5bb43dc..0f29230 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -43,6 +43,7 @@
from ..scheduler import BaseScheduling, WhyNoFuse
from ..triton_heuristics import AutotuneHint
from ..utils import (
+ cache_on_self,
do_bench,
get_fused_kernel_name,
get_kernel_metadata,
@@ -50,6 +51,7 @@
is_welford_reduction,
next_power_of_2,
Placeholder,
+ sympy_dot,
sympy_product,
sympy_subs,
sympy_symbol,
@@ -97,6 +99,150 @@
return "tmp" in self.mask_str
+@dataclasses.dataclass
+class BlockPtrOptions:
+ constant_offset: sympy.Expr
+ shape: List[sympy.Expr]
+ strides: List[sympy.Expr]
+ block_shape: List[str]
+ order: List[int]
+ offsets: List[str]
+ mask_vars: Set[sympy.Symbol]
+ reshape_suffix: List[str]
+
+ @staticmethod
+ def create(
+ strides: List[sympy.Expr],
+ constant_offset: sympy.Expr,
+ range_trees: List[IterationRangesEntry],
+ mask_vars: Set[sympy.Symbol],
+ ) -> BlockPtrOptions:
+ """Helper to create a BlockPtrOptions instance"""
+ block_shape = [f"{t.prefix.upper()}BLOCK" for t in range_trees]
+ reshape_suffix = [*block_shape]
+
+ broadcasting_dim = [s == 0 for s in strides]
+ for i, is_broadcasting in enumerate(broadcasting_dim):
+ if is_broadcasting:
+ # drop any stride==0 dimensions for performance
+ reshape_suffix[i] = "1"
+
+ if V.kernel.no_x_dim:
+ assert range_trees[0].prefix == "x"
+ reshape_suffix.pop(0)
+
+ if (
+ not V.kernel.inside_reduction
+ and len(strides) == len(V.kernel.numels) - 1
+ and V.kernel.numels[-1] != 1
+ ):
+ # Need to expand rank by 1 to match rank when self.inside_reduction=True
+ reshape_suffix.append("1")
+
+ def filter(it):
+ """Removes any broadcasting dims from a given sequence"""
+ assert len(it) == len(broadcasting_dim)
+ return [
+ item
+ for item, is_broadcasting in zip(it, broadcasting_dim)
+ if not is_broadcasting
+ ]
+
+ return BlockPtrOptions(
+ constant_offset=V.graph.sizevars.lookup_precomputed_size(constant_offset),
+ shape=[
+ V.graph.sizevars.lookup_precomputed_size(t.numel)
+ for t in filter(range_trees)
+ ],
+ strides=[*map(V.graph.sizevars.lookup_precomputed_size, filter(strides))],
+ block_shape=filter(block_shape),
+ order=V.graph.sizevars.guarded_order(filter(strides)),
+ offsets=filter([f"{t.prefix}offset" for t in range_trees]),
+ mask_vars=mask_vars,
+ reshape_suffix=reshape_suffix,
+ )
+
+ def format(self, name: str, roffset=True) -> str:
+ """
+ Codegen a call to tl.make_block_ptr()
+
+ Args:
+ name: variable name for pointer
+ roffset: should roffset be included in offsets=..., for use with tl.advance()
+
+ Returns:
+ "tl.make_block_ptr(...)"
+ """
+ f = V.kernel.index_to_str
+ offsets = [*self.offsets]
+ if not roffset:
+ offsets[offsets.index("roffset")] = "0"
+ args = [
+ f"{name} + ({f(self.constant_offset)})"
+ if self.constant_offset != 0
+ else name,
+ f"shape={f(self.shape)}",
+ f"strides={f(self.strides)}",
+ f"block_shape={f(self.block_shape)}",
+ f"order={f(self.order)}",
+ f"offsets={f(offsets)}",
+ ]
+ return f"tl.make_block_ptr({', '.join(args)})"
+
+ @cache_on_self
+ def boundary_check(self) -> List[int]:
+ """List of indices to pass to tl.load(boundary_check=...)"""
+ check = []
+ for i in range(len(self.shape)):
+ if (
+ self.block_shape[i] != "1"
+ and not V.graph.sizevars.statically_known_equals(self.strides[i], 0)
+ and not V.graph.sizevars.statically_known_multiple_of(
+ self.shape[i],
+ config.triton.max_block[self.block_shape[i][0]],
+ )
+ and not (V.kernel.no_x_dim and self.block_shape[i] == "XBLOCK")
+ ):
+ check.append(i)
+ return check
+
+ def advance_roffset(self):
+ """Codegen string to pass to tl.advance(name, ...)"""
+ advance = ["0"] * len(self.shape)
+ advance[self.offsets.index("roffset")] = "RBLOCK"
+ return V.kernel.index_to_str(advance)
+
+ def has_rindex(self):
+ return "RBLOCK" in self.block_shape
+
+ def has_tmpmask(self):
+ return False # block_ptr can't do indirect indexing
+
+ def has_mask(self):
+ return bool(self.boundary_check())
+
+
+def triton_reshape(value: str, old_shape: List[str], new_shape: List[str]):
+ """Workaround https://github.com/openai/triton/issues/2836"""
+ assert isinstance(old_shape, list) and isinstance(new_shape, list)
+ if old_shape == new_shape:
+ return value
+ if [s for s in new_shape if s != "1"] != old_shape:
+ return f"tl.reshape({value}, [{', '.join(new_shape)}])"
+ # rewrite to [:, None] syntax, which is less buggy
+ idx = 0
+ expand = []
+ for size in new_shape:
+ if idx < len(old_shape) and size == old_shape[idx]:
+ expand.append(":")
+ idx += 1
+ else:
+ assert size == "1"
+ expand.append("None")
+ assert idx == len(old_shape)
+ return f"{value}[{', '.join(expand)}]"
+
+
class TritonPrinter(PythonPrinter):
def _print_floor(self, expr):
assert len(expr.args) == 1
@@ -207,6 +353,17 @@
return f"tl.{triton_type_name}"
+def triton_store_type(dtype):
+ triton_type_name = str(dtype).split(".")[-1]
+ if triton_type_name == "bool":
+ triton_type_name = "int8"
+ elif triton_type_name == "float8_e4m3fn":
+ triton_type_name = "float8e4nv"
+ elif triton_type_name == "float8_e5m2":
+ triton_type_name = "float8e5"
+ return f"tl.{triton_type_name}"
+
+
def triton_acc_type(dtype):
if is_integer_dtype(dtype) and dtype.is_signed:
nbits = 64 if dtype == torch.int64 else 32
@@ -657,7 +814,7 @@
@classmethod
def index_expr(cls, expr, dtype):
- indexing = V.kernel.indexing(expr)
+ indexing = V.kernel.indexing(expr, block_ptr=False)
assert isinstance(indexing, IndexingOptions)
# This is called from CSEProxy.__getattr__, so we'll set the bounds there
var = V.kernel.cse.generate(V.kernel.compute, indexing.index_str)
@@ -1005,6 +1162,7 @@
self.index_dtype: str = index_dtype
self.min_elem_per_thread = min_elem_per_thread
self.last_usage: Set[str] = set()
+ self.block_ptr_id = itertools.count()
self.persistent_reduction: bool = self.should_use_persistent_reduction()
self.no_x_dim = (
@@ -1289,7 +1447,8 @@
copy_shape=None,
dense_indexing=False,
override_mask=None,
- ) -> IndexingOptions:
+ block_ptr=False,
+ ) -> Union[IndexingOptions, BlockPtrOptions]:
"""
Compute the index and mask to pass to tl.load() or tl.store()
"""
@@ -1353,6 +1512,36 @@
have_dense = False
dense_mask_vars.add(f"{tree.prefix}mask")
+ if (
+ block_ptr
+ and config.triton.use_block_ptr
+ and not override_mask
+ and not self._load_mask
+ and len(mask_vars - dense_mask_vars) == 0
+ and not self.is_indirect_indexing(index)
+ and have_loop_vars
+ # workaround https://github.com/openai/triton/issues/2821
+ and self.index_dtype == "tl.int32"
+ ):
+ index_relative_to_xyr_index = sympy_subs(
+ index, {v: t.expr for v, t in self.range_tree_nodes.items()}
+ )
+ range_trees = self.active_range_trees(reorder=True)
+ symbols = [t.symbol() for t in range_trees]
+ strides = [sympy.Wild(f"stride_{s}", exclude=symbols) for s in symbols]
+ offset = sympy.Wild("_offset", exclude=symbols)
+ m = index_relative_to_xyr_index.match(sympy_dot(symbols, strides) + offset)
+ # TODO(jansel): it is sometimes possible to do higher dimensional block_ptrs with
+ # a tl.reshape to the correct block shape. We will miss these cases today.
+ if m:
+ self.filter_masks(mask_vars)
+ return BlockPtrOptions.create(
+ [m[s] for s in strides],
+ m[offset],
+ range_trees,
+ mask_vars,
+ )
+
expand_str = None
index_str = self.index_to_str(index)
if isinstance(index, sympy.Integer):
@@ -1491,11 +1680,54 @@
)
return strides
+ def codegen_block_ptr(
+ self, name: str, var: str, indexing: BlockPtrOptions, other=""
+ ) -> Tuple[str, Optional[DeferredLine], str]:
+ advance_block_ptr = None
+ check = indexing.boundary_check()
+ if not check:
+ # workaround https://github.com/openai/triton/issues/2813
+ other = ""
+ elif other:
+ assert other == ", other=0.0"
+ other = f", boundary_check={check!r}, padding_option='zero'"
+ else:
+ other = f", boundary_check={check!r}"
+ if (
+ self.inside_reduction
+ and not self.persistent_reduction
+ and indexing.has_rindex()
+ ):
+ block_ptr = f"block_ptr{next(self.block_ptr_id)}"
+ self.body.writeline(
+ DeferredLine(
+ name, f"{block_ptr} = {indexing.format(var, roffset=False)}"
+ )
+ )
+ advance_block_ptr = DeferredLine(
+ name,
+ f"{block_ptr} = tl.advance({block_ptr}, {indexing.advance_roffset()})",
+ )
+ else:
+ block_ptr = indexing.format(var)
+ return block_ptr, advance_block_ptr, other
+
+ def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=""):
+ # broadcasting is not implicit for block_ptrs
+ value = (
+ f"tl.broadcast_to({value}, {self.index_to_str(indexing.reshape_suffix)})"
+ )
+ # drop any extra size=1 dimensions
+ value = triton_reshape(value, indexing.reshape_suffix, indexing.block_shape)
+ # workaround https://github.com/openai/triton/issues/2814
+ value = f"{value}.to({triton_store_type(V.graph.get_dtype(name))})"
+ return f"tl.store({block_ptr}, {value}{other})"
+
def load(self, name: str, index: sympy.Expr):
var = self.args.input(name)
indirect_indexing = self.is_indirect_indexing(index)
original_index = index
- indexing = self.indexing(index)
+ indexing = self.indexing(index, block_ptr=True)
has_rindex = indexing.has_rindex()
has_tmpmask = indexing.has_tmpmask()
@@ -1541,11 +1773,21 @@
else:
other = ""
+ advance_block_ptr = None
append_broadcast = None
if V.graph.is_unspec_arg(name):
line = var
else:
- if isinstance(original_index, sympy.Integer):
+ if isinstance(indexing, BlockPtrOptions):
+ block_ptr, advance_block_ptr, other = self.codegen_block_ptr(
+ name, var, indexing, other
+ )
+ line = f"tl.load({block_ptr}{other}{ep})"
+ # add needed size=1 dimensions
+ line = triton_reshape(
+ line, indexing.block_shape, indexing.reshape_suffix
+ )
+ elif isinstance(original_index, sympy.Integer):
line = f"tl.load({var} + ({original_index}))"
append_broadcast = indexing.expand_str
else:
@@ -1583,6 +1825,9 @@
line = f"tl.broadcast_to({result_var}, {append_broadcast})"
result_var = self.cse.generate(load_buffer, line)
+ if advance_block_ptr:
+ load_buffer.writeline(advance_block_ptr)
+
if not self.inside_reduction or not has_rindex:
self.outside_loop_vars.add(result_var)
@@ -1591,7 +1836,7 @@
def store(self, name, index, value, mode=None):
var = self.args.output(name)
original_index = index
- indexing = self.indexing(index, dense_indexing=True)
+ indexing = self.indexing(index, dense_indexing=True, block_ptr=mode is None)
# Guard against write-after-read corruption in triton.
# See # https://github.com/openai/triton/issues/1615
@@ -1604,13 +1849,24 @@
if is_inplace and is_broadcasted:
self.stores.writeline(DeferredLine(name, "tl.debug_barrier()"))
- if mode is None:
+ advance_block_ptr = None
+ if isinstance(indexing, BlockPtrOptions):
+ block_ptr, advance_block_ptr, other = self.codegen_block_ptr(
+ name, var, indexing
+ )
+ # block_ptr stores don't do implicit casting
+ line = self.codegen_block_ptr_store_line(
+ name, indexing, block_ptr, value, other
+ )
+ elif mode is None:
line = f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})"
elif mode == "atomic_add":
line = f"tl.atomic_add({var} + ({indexing.index_str}), {value}, {indexing.mask_str})"
else:
raise NotImplementedError(f"store mode={mode}")
self.stores.writeline(DeferredLine(name, line))
+ if advance_block_ptr:
+ self.stores.writeline(advance_block_ptr)
if not self.inside_reduction:
self.outside_loop_vars.add(value)
@@ -1896,17 +2152,31 @@
def store_reduction(self, name, index, value):
assert self.inside_reduction
self.inside_reduction = False
- indexing = self.indexing(index)
+ indexing = self.indexing(index, block_ptr=True)
self.inside_reduction = True
var = self.args.output(name)
- assert isinstance(indexing, IndexingOptions)
- self.suffix.writeline(
- DeferredLine(
- name,
- f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})",
+ if isinstance(indexing, BlockPtrOptions):
+ self.suffix.writeline(
+ DeferredLine(
+ name,
+ self.codegen_block_ptr_store_line(
+ name,
+ indexing,
+ indexing.format(var),
+ value,
+ f", boundary_check={indexing.boundary_check()!r}",
+ ),
+ )
)
- )
+ else:
+ assert isinstance(indexing, IndexingOptions)
+ self.suffix.writeline(
+ DeferredLine(
+ name,
+ f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})",
+ )
+ )
def _lift_helper(self, fn, num_args) -> str:
# Lift IR function into a triton function in the global namespace
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index 18649a4..8080673 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -543,6 +543,9 @@
# We should revisit this once we understand more of the source of register spills.
spill_threshold: int = 16
+ # Generate code containing the newer tl.make_block_ptr() API for loads/stores
+ use_block_ptr = False
+
# Inject a bug into our relu implementation; useful for testing our repro
# extraction and minification functionality.
# Valid values: "compile_error", "runtime_error", "accuracy"
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py
index aba6805..01dabbf 100644
--- a/torch/_inductor/select_algorithm.py
+++ b/torch/_inductor/select_algorithm.py
@@ -329,6 +329,7 @@
dense_indexing=False,
copy_shape=None,
override_mask=None,
+ block_ptr=False,
):
"""
Override the default indexing to use our custom mask and force
@@ -339,6 +340,7 @@
dense_indexing=False,
copy_shape=self.template_mask,
override_mask=self.template_mask,
+ block_ptr=block_ptr,
)
def initialize_range_tree(self, pid_cache):
diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py
index b8ed515..c99fcaa 100644
--- a/torch/_inductor/sizevars.py
+++ b/torch/_inductor/sizevars.py
@@ -343,6 +343,23 @@
else:
return self.guard_equals(left, right)
+ def guarded_order(self, seq):
+ """
+ Return the order of a sequence as a permutation of range(len(seq)) and guard on that order not changing.
+ Used for generating block_ptrs.
+ """
+ seq = [*map(self.remove_precomputed_replacements, seq)]
+ seq = [(self.size_hint(var), orig_idx, var) for orig_idx, var in enumerate(seq)]
+ seq.sort()
+ order = [-1] * len(seq)
+ last_var = None
+ for new_index, (_, orig_index, var) in enumerate(seq):
+ order[orig_index] = new_index
+ if last_var is not None:
+ self.guard_leq(last_var, var)
+ last_var = var
+ return order
+
# The evaluate functions evaluate some symbolic sympy expression
# (NB: not necessarily an Expr) and return what the concrete result
# is, guarding on the expression being that result