[Inductor] Support top level constants in user defined triton kernels (#111970)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111970
Approved by: https://github.com/jansel
ghstack dependencies: #111956
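
For illustration, here is a minimal sketch of the pattern this PR enables,
assuming a CUDA device with Triton installed (SCALE, scale_kernel, and scale
are hypothetical names for this sketch, not part of the PR):

    import torch
    import triton
    import triton.language as tl

    SCALE = 2  # top-level (module global) constant referenced by the kernel

    @triton.jit
    def scale_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr + offsets, mask=mask)
        # SCALE resolves through the kernel's enclosing module globals
        tl.store(out_ptr + offsets, x * SCALE, mask=mask)

    def scale(x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        n = out.numel()
        scale_kernel[(triton.cdiv(n, 16),)](x, out, n, BLOCK_SIZE=16)
        return out

    # Previously, Inductor's generated wrapper spliced in the kernel source
    # without SCALE's definition, so compiling could fail with a NameError;
    # with this change the constant is emitted alongside the kernel.
    y = torch.compile(scale)(torch.randn(32, device="cuda"))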
diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py
index 5f10367..8837c4c 100644
--- a/test/dynamo/test_functions.py
+++ b/test/dynamo/test_functions.py
@@ -1380,6 +1380,11 @@
     # NB: This also addresses a triton limitation where if the kernels are
     # getting called indirectly, triton cannot find the kernels unless they
     # are at top level.
+    # Define constants at top level for the same triton limitation
+    CONSTANT_C = 4
+    STRING_CONSTANT_C = "CONSTANT_C"
+    BOOL_CONSTANT_C = True
+
     @triton.jit
     def add_kernel(
         in_ptr0,
@@ -1933,6 +1938,57 @@
 
     @requires_cuda()
     @requires_triton()
+    def test_triton_kernel_constants(self):
+        @triton.jit
+        def mulC_kernel(
+            in_ptr0,
+            out_ptr,
+            n_elements,
+            BLOCK_SIZE: "tl.constexpr",
+            CONSTANT_NAME: "tl.constexpr",
+        ):
+            pid = tl.program_id(axis=0)
+            block_start = pid * BLOCK_SIZE
+            offsets = block_start + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            x = tl.load(in_ptr0 + offsets, mask=mask)
+            if CONSTANT_NAME.value == STRING_CONSTANT_C:
+                output = CONSTANT_C * x
+            if BOOL_CONSTANT_C:
+                output *= CONSTANT_C
+            tl.store(out_ptr + offsets, output, mask=mask)
+
+        def call_triton(
+            x: torch.Tensor,
+        ):
+            output = torch.zeros_like(x)
+            n_elements = output.numel()
+
+            grid = (x.numel(),)
+            mulC_kernel[grid](
+                x, output, n_elements, BLOCK_SIZE=16, CONSTANT_NAME="CONSTANT_C"
+            )
+            return output
+
+        # Triton kernels capture global constants by their parse-time value,
+        # not their runtime value.
+        global CONSTANT_C
+        prev_c = CONSTANT_C
+        # If the behavior of triton kernels changes, this test will fail.
+        CONSTANT_C = 10
+        assert CONSTANT_C != prev_c
+
+        t = torch.randn(5, device="cuda")
+        torch_result = call_triton(t)
+        compiled_result = torch.compile(call_triton)(t)
+
+        self.assertEqual(torch_result, compiled_result)
+
+        # reset back
+        CONSTANT_C = prev_c
+
+    @requires_cuda()
+    @requires_triton()
     @common_utils.parametrize("grad", [False, True])
     @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
     @patch.object(torch._inductor.config, "implicit_fallbacks", False)
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 8cadb4f..99dc56c 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -857,6 +857,11 @@
                         compile_wrapper.splice(symbol.src, strip=True)
                         symbols_included.add(symbol_name)
                         traverse(symbol)
+                    # Emit plain int/str/bool globals as literal assignments
+                    elif isinstance(symbol, (int, str, bool)):
+                        compile_wrapper.newline()
+                        compile_wrapper.writeline(f"{symbol_name} = {symbol!r}")
+                        symbols_included.add(symbol_name)
 
         traverse(kernel)
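
The wrapper change, in short: while traverse walks the symbols a user-defined
kernel references, globals that are plain int, str, or bool values are now
written into the generated source as literal assignments (via repr), mirroring
how dependent @triton.jit functions are spliced in. For the test above, the
compiled wrapper would contain roughly the following (illustrative, not
verbatim Inductor output):

    CONSTANT_C = 10
    STRING_CONSTANT_C = 'CONSTANT_C'
    BOOL_CONSTANT_C = True

    @triton.jit
    def mulC_kernel(in_ptr0, out_ptr, n_elements,
                    BLOCK_SIZE: "tl.constexpr", CONSTANT_NAME: "tl.constexpr"):
        ...

so the constants resolve at Triton compile time with the values they had when
the graph was traced.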