[Traceable FSDP2] Add Dynamo support for run_with_rng_state HOP (#127247)

Test command:
`pytest -rA test/inductor/test_compiled_autograd.py::TestCompiledAutograd::test_trace_run_with_rng_state`
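
For context, `run_with_rng_state` is the higher-order op that the partitioner inserts into backward graphs to replay an RNG-consuming op under the RNG state captured in the forward (the counterpart of `run_and_save_rng_state`). The sketch below is an illustrative, hand-condensed version of the scenario the new test covers; the shapes and the direct `compiled_autograd.enable` usage here are mine, not taken verbatim from the test:

```python
# Illustrative sketch only (not part of this patch): activation checkpointing
# makes the partitioner emit run_with_rng_state in the backward graph, which
# compiled autograd's Dynamo re-trace must now be able to handle.
import torch
import torch.nn.functional as F
from torch._dynamo import compiled_autograd

def sdpa(q, k):
    return F.scaled_dot_product_attention(q, k, k, is_causal=True)

def g(q, k):
    # Recomputed in the backward under the RNG state saved in the forward.
    return torch.utils.checkpoint.checkpoint(sdpa, q, k, use_reentrant=False)

q = torch.ones(1, 1, 2, 2, requires_grad=True)
k = torch.ones(1, 1, 4, 2, requires_grad=True)
out = torch.compile(g, fullgraph=True)(q, k)

def compiler_fn(gm):
    return torch.compile(gm, backend="inductor", fullgraph=True)

with compiled_autograd.enable(compiler_fn):
    out.sum().backward()
```

Previously Dynamo had no handler for this HOP, so compiled autograd could not capture such backward graphs with `fullgraph=True`; the new `RunWithRNGStateHigherOrderVariable` simply re-emits the op as a `call_function` node (see the `higher_order_ops.py` hunk below).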

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127247
Approved by: https://github.com/bdhirsh
ghstack dependencies: #129414
diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py
index fe65dca..f218a20 100644
--- a/test/inductor/test_compiled_autograd.py
+++ b/test/inductor/test_compiled_autograd.py
@@ -11,6 +11,7 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch import _inductor as inductor
 from torch._dynamo import compiled_autograd, config
 from torch._dynamo.utils import counters
@@ -1308,6 +1309,99 @@
 
         self.check_output_and_recompiles(fn, 1)
 
+    def test_trace_run_with_rng_state(self):
+        def sdpa(xq, xk):
+            return F.scaled_dot_product_attention(xq, xk, xk, is_causal=True)
+
+        def g(xq_1, xk_1, xq_2, xk_2):
+            # xq: (bs, n_local_heads, seqlen, head_dim)
+            # xk: (bs, n_local_heads, cache_len + seqlen, head_dim)
+            y1 = sdpa(xq_1, xk_1)
+            y2 = torch.utils.checkpoint.checkpoint(
+                sdpa, xq_2, xk_2, use_reentrant=False
+            )
+            y = torch.mul(y1, y2)
+            z = torch.matmul(y, y)
+            return z
+
+        def f():
+            bs = 1
+            n_local_heads = 1
+            seqlen = 2
+            head_dim = 2
+            cache_len = 2
+            xq_list = [
+                torch.ones(
+                    (bs, n_local_heads, seqlen, head_dim),
+                    requires_grad=True,
+                    device="cpu",
+                )
+                for _ in range(2)
+            ]
+            xk_list = [
+                torch.ones(
+                    (bs, n_local_heads, cache_len + seqlen, head_dim),
+                    requires_grad=True,
+                    device="cpu",
+                )
+                for _ in range(2)
+            ]
+            out = torch.compile(g, fullgraph=True)(
+                xq_list[0], xk_list[0], xq_list[1], xk_list[1]
+            )
+            out.sum().backward()
+            return out, *[x.grad for x in xq_list + xk_list]
+
+        """
+        Walkthrough of what happens with `run_with_rng_state`:
+        1. `run_with_rng_state` only shows up in the backward graph (this op is inserted by the partitioner).
+        2. The Dynamo graph captured by Compiled Autograd looks like:
+        ```
+        ===== __compiled_fn_3 =====
+        torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
+            def forward(self, L_inputs_ : list):
+                ...
+                run_with_rng_state = torch.ops.higher_order.run_with_rng_state(
+                    getitem_8,
+                    torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
+                    getitem_3, getitem_4, getitem_4, 0.0, True,
+                )
+                ...
+        ```
+        3. We want this `run_with_rng_state` op to be preserved when it goes through AOTAutograd. We do this by adding special handling
+        in the op's `py_functionalize_impl`.
+        """
+
+        def _run_with_rng_state_op_check(inductor_post_grad_graph):
+            # Checks that `run_with_rng_state` op exists in Compiled Autograd's Inductor post-grad graph.
+            op_set = {node.target for node in inductor_post_grad_graph.nodes}
+            if torch.ops.higher_order.run_and_save_rng_state not in op_set:
+                # This is the backward graph, so check that the `run_with_rng_state` op is present
+                self.assertTrue(torch.ops.higher_order.run_with_rng_state in op_set)
+
+        with torch._inductor.config.patch(
+            post_grad_custom_post_pass=_run_with_rng_state_op_check
+        ):
+            compiler_fn = make_compiler_fn(fullgraph=True)
+
+            def make_compiler_fn_with_op_check():
+                def _compiler_fn(gm):
+                    # Checks that `run_with_rng_state` op exists in Compiled Autograd's Dynamo graph.
+                    self.assertTrue(
+                        any(
+                            node.target is torch.ops.higher_order.run_with_rng_state
+                            for node in gm.graph.nodes
+                        )
+                    )
+                    return compiler_fn(gm)
+
+                return _compiler_fn
+
+            compiler_fn_with_op_check = make_compiler_fn_with_op_check()
+            self.check_output_and_recompiles(
+                f, compiler_fn=compiler_fn_with_op_check, compile_fn=False
+            )
+
     def test_autograd_cpp_node(self):
         cpp_source = """
 struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index 59f8c26..f874ca2 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -554,6 +554,8 @@
             return TraceWrappedHigherOrderOperatorVariable(value, source, **kwargs)
         elif value.__name__ == "strict_mode":
             return StrictModeHigherOrderVariable(value, source, **kwargs)
+        elif value.__name__ == "run_with_rng_state":
+            return RunWithRNGStateHigherOrderVariable(value, source, **kwargs)
         elif value.__name__ == "associative_scan":
             return AssociativeScanHigherOrderVariable(value, source, **kwargs)
         elif value.__name__ == "call_torchbind":
@@ -1440,6 +1442,26 @@
         )
 
 
+class RunWithRNGStateHigherOrderVariable(TorchHigherOrderOperatorVariable):
+    def call_function(
+        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+    ) -> "VariableTracker":
+        from .builder import wrap_fx_proxy
+
+        p_args = tuple(arg.as_proxy() for arg in args)
+        p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()}
+        return wrap_fx_proxy(
+            tx=tx,
+            proxy=tx.output.create_proxy(
+                "call_function",
+                self.value,
+                args=p_args,
+                kwargs=p_kwargs,
+            ),
+            example_value=None,
+        )
+
+
 class TraceWrappedHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable):
     """
     Handles torch._dynamo._trace_wrapped_higher_order_op.inner_trace
diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py
index 1345ff0..f1d2fc6 100644
--- a/torch/_prims/rng_prims.py
+++ b/torch/_prims/rng_prims.py
@@ -247,6 +247,18 @@
         with mode:
             return op(*args, **kwargs)
 
+    @run_with_rng_state.py_functionalize_impl
+    def impl_functional(ctx, rng_state, op, *args, **kwargs):
+        unwrapped_rng_state = ctx.unwrap_tensors(rng_state)
+        unwrapped_args = ctx.unwrap_tensors(args)
+        unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
+
+        with ctx.redispatch_to_next():
+            out = run_with_rng_state(
+                unwrapped_rng_state, op, *unwrapped_args, **unwrapped_kwargs
+            )
+            return ctx.wrap_tensors(out)
+
     return run_with_rng_state
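
A note on the `py_functionalize_impl` above: it unwraps the functional tensors and redispatches to the same HOP rather than decomposing it, which is what keeps `run_with_rng_state` intact as a single node through AOTAutograd, as described in the test's walkthrough. At runtime the op replays its wrapped callable under the saved RNG state so the backward recompute sees the same randomness as the forward. The snippet below is a rough, hypothetical stand-in for that behavior on CPU, not the registered implementation:

```python
import torch

def run_with_rng_state_sketch(rng_state, op, *args, **kwargs):
    # Hypothetical stand-in for the HOP's eager CPU behavior: install the saved
    # RNG state, run the wrapped op, then restore the caller's state.
    prev = torch.get_rng_state()
    torch.set_rng_state(rng_state)
    try:
        return op(*args, **kwargs)
    finally:
        torch.set_rng_state(prev)

# Replaying a dropout call from a captured RNG state reproduces the same mask.
x = torch.ones(8)
state = torch.get_rng_state()
a = torch.dropout(x, 0.5, True)
b = run_with_rng_state_sketch(state, torch.dropout, x, 0.5, True)
assert torch.equal(a, b)
```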