[Inductor] Support top level constants in user defined triton kernels (#111970)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111970
Approved by: https://github.com/jansel
ghstack dependencies: #111956
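This lets a user-defined `@triton.jit` kernel read plain `int`, `str`, and `bool` constants defined at module (top) level when the calling function runs under `torch.compile`; Inductor's wrapper codegen now copies those values into the generated kernel source. Below is a minimal sketch of the pattern this enables, modeled on the test added in this PR (kernel and function names, the `* 2` scaling, and the tensor size are illustrative, not taken from the PR):

```python
import torch
import triton
import triton.language as tl

# Top-level constants referenced from inside the jitted kernel.
CONSTANT_C = 4
BOOL_CONSTANT_C = True


@triton.jit
def mul_by_constant_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr + offsets, mask=mask)
    out = CONSTANT_C * x  # top-level int constant
    if BOOL_CONSTANT_C:  # top-level bool constant
        out = out * 2
    tl.store(out_ptr + offsets, out, mask=mask)


def call_kernel(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    mul_by_constant_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=16)
    return out


x = torch.randn(32, device="cuda")
# The eager triton launch and the torch.compile'd version should agree.
# Before this change, the wrapper source generated by Inductor did not define
# CONSTANT_C or BOOL_CONSTANT_C for the inlined kernel.
torch.testing.assert_close(call_kernel(x), torch.compile(call_kernel)(x))
```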
diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py
index 5f10367..8837c4c 100644
--- a/test/dynamo/test_functions.py
+++ b/test/dynamo/test_functions.py
@@ -1380,6 +1380,11 @@
 # NB: This also addresses a triton limitation where if the kernels are
 # getting called indirectly, triton cannot find the kernels unless they
 # are at top level.
+# Define constants here for the same triton limitation
+CONSTANT_C = 4
+STRING_CONSTANT_C = "CONSTANT_C"
+BOOL_CONSTANT_C = True
+
 @triton.jit
 def add_kernel(
     in_ptr0,
@@ -1933,6 +1938,57 @@
     @requires_cuda()
     @requires_triton()
+    def test_triton_kernel_constants(self):
+        @triton.jit
+        def mulC_kernel(
+            in_ptr0,
+            out_ptr,
+            n_elements,
+            BLOCK_SIZE: "tl.constexpr",
+            CONSTANT_NAME: "tl.constexpr",
+        ):
+            pid = tl.program_id(axis=0)
+            block_start = pid * BLOCK_SIZE
+            offsets = block_start + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            x = tl.load(in_ptr0 + offsets, mask=mask)
+            if CONSTANT_NAME.value == STRING_CONSTANT_C:
+                output = CONSTANT_C * x
+            if BOOL_CONSTANT_C:
+                output *= CONSTANT_C
+            tl.store(out_ptr + offsets, output, mask=mask)
+
+        def call_triton(
+            x: torch.Tensor,
+        ):
+            output = torch.zeros_like(x)
+            n_elements = output.numel()
+
+            grid = (x.numel(),)
+            mulC_kernel[grid](
+                x, output, n_elements, BLOCK_SIZE=16, CONSTANT_NAME="CONSTANT_C"
+            )
+            return output
+
+        # Triton kernels capture global constants by their parse-time value,
+        # not their runtime value.
+        global CONSTANT_C
+        prev_c = CONSTANT_C
+        # If the behavior of triton kernels changes, this test will fail.
+        CONSTANT_C = 10
+        assert CONSTANT_C != prev_c
+
+        t = torch.randn(5, device="cuda")
+        torch_result = call_triton(t)
+        compiled_result = torch.compile(call_triton)(t)
+
+        self.assertEqual(torch_result, compiled_result)
+
+        # Reset back to the original value.
+        CONSTANT_C = prev_c
+
+    @requires_cuda()
+    @requires_triton()
     @common_utils.parametrize("grad", [False, True])
     @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
     @patch.object(torch._inductor.config, "implicit_fallbacks", False)
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 8cadb4f..99dc56c 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -857,6 +857,10 @@
                         compile_wrapper.splice(symbol.src, strip=True)
                         symbols_included.add(symbol_name)
                         traverse(symbol)
+                    elif isinstance(symbol, (int, str, bool)):
+                        compile_wrapper.newline()
+                        compile_wrapper.writeline(f"{symbol_name} = {symbol!r}")
+                        symbols_included.add(symbol_name)
         traverse(kernel)
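The added `elif` branch writes any plain `int`, `str`, or `bool` global that the kernel references into the generated source as `name = repr(value)`, so the inlined kernel still sees a binding for that name. Here is a small self-contained sketch of that idea, separate from Inductor (the `emit_constant_globals` helper and the `scale` function are made up for illustration; the real code walks the Triton `JITFunction`'s referenced names as shown above):

```python
CONSTANT_C = 4
STRING_CONSTANT_C = "CONSTANT_C"
BOOL_CONSTANT_C = True


def emit_constant_globals(fn) -> str:
    """Render every int/str/bool global referenced by fn as a `name = value`
    line, mirroring compile_wrapper.writeline(f"{symbol_name} = {symbol!r})."""
    lines = []
    for name in fn.__code__.co_names:  # global names the function body uses
        value = fn.__globals__.get(name)
        if isinstance(value, (int, str, bool)):
            lines.append(f"{name} = {value!r}")
    return "\n".join(lines)


def scale(x):
    if BOOL_CONSTANT_C:
        return CONSTANT_C * x
    return x


# STRING_CONSTANT_C is not referenced by scale, so it is not emitted.
print(emit_constant_globals(scale))
# BOOL_CONSTANT_C = True
# CONSTANT_C = 4
```

Because the value's `repr()` is baked into the emitted source, the compiled kernel sees the constant as it was at compile time, which lines up with the parse-time capture behavior of Triton that the new test pins down.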