[Traceable FSDP2] Add Dynamo support for run_with_rng_state HOP (#127247)
Test command:
`pytest -rA test/inductor/test_compiled_autograd.py::TestCompiledAutograd::test_trace_run_with_rng_state`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127247
Approved by: https://github.com/bdhirsh
ghstack dependencies: #129414
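For context (not part of the patch below): `run_with_rng_state(rng_state, op, *args, **kwargs)` is the backward-side counterpart of `run_and_save_rng_state(op, *args, **kwargs)`. The partitioner emits the pair so that an RNG-consuming op recomputed in the backward pass (e.g. under activation checkpointing, as in the new test) sees the same random numbers as the original forward run. A minimal sketch of that contract, assuming a CPU tensor input and that calling the HOPs eagerly dispatches to their CPU implementations:
```
import torch

op = torch.ops.aten.rand_like.default  # any RNG-consuming ATen op
x = torch.ones(4)

# Forward side: capture the RNG state, then run the op.
rng_state, out = torch.ops.higher_order.run_and_save_rng_state(op, x)

# Backward side: replay the op under the saved state during recomputation.
replayed = torch.ops.higher_order.run_with_rng_state(rng_state, op, x)

# Expected under the assumptions above: the replay reproduces the forward output.
assert torch.equal(out, replayed)
```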
diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py
index fe65dca..f218a20 100644
--- a/test/inductor/test_compiled_autograd.py
+++ b/test/inductor/test_compiled_autograd.py
@@ -11,6 +11,7 @@
import torch
import torch.nn as nn
+import torch.nn.functional as F
from torch import _inductor as inductor
from torch._dynamo import compiled_autograd, config
from torch._dynamo.utils import counters
@@ -1308,6 +1309,99 @@
self.check_output_and_recompiles(fn, 1)
+ def test_trace_run_with_rng_state(self):
+ def sdpa(xq, xk):
+ return F.scaled_dot_product_attention(xq, xk, xk, is_causal=True)
+
+ def g(xq_1, xk_1, xq_2, xk_2):
+ # xq: (bs, n_local_heads, seqlen, head_dim)
+ # xk: (bs, n_local_heads, cache_len + seqlen, head_dim)
+ y1 = sdpa(xq_1, xk_1)
+ y2 = torch.utils.checkpoint.checkpoint(
+ sdpa, xq_2, xk_2, use_reentrant=False
+ )
+ y = torch.mul(y1, y2)
+ z = torch.matmul(y, y)
+ return z
+
+ def f():
+ bs = 1
+ n_local_heads = 1
+ seqlen = 2
+ head_dim = 2
+ cache_len = 2
+ xq_list = [
+ torch.ones(
+ (bs, n_local_heads, seqlen, head_dim),
+ requires_grad=True,
+ device="cpu",
+ )
+ for _ in range(2)
+ ]
+ xk_list = [
+ torch.ones(
+ (bs, n_local_heads, cache_len + seqlen, head_dim),
+ requires_grad=True,
+ device="cpu",
+ )
+ for _ in range(2)
+ ]
+ out = torch.compile(g, fullgraph=True)(
+ xq_list[0], xk_list[0], xq_list[1], xk_list[1]
+ )
+ out.sum().backward()
+ return out, *[x.grad for x in xq_list + xk_list]
+
+ """
+ Walkthrough of what happens with `run_with_rng_state`:
+ 1. `run_with_rng_state` only shows up in the backward graph (this op is inserted by the partitioner).
+ 2. The Dynamo graph captured by Compiled Autograd looks like:
+ ```
+ ===== __compiled_fn_3 =====
+ torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
+ def forward(self, L_inputs_ : list):
+ ...
+ run_with_rng_state = torch.ops.higher_order.run_with_rng_state(
+ getitem_8,
+ torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
+ getitem_3, getitem_4, getitem_4, 0.0, True,
+ )
+ ...
+ ```
+ 3. We want to preserve this `run_with_rng_state` op when the graph goes through AOTAutograd. We do so via special handling
+ in the `run_with_rng_state` op's py_functionalize_impl.
+ """
+
+ def _run_with_rng_state_op_check(inductor_post_grad_graph):
+ # Checks that `run_with_rng_state` op exists in Compiled Autograd's Inductor post-grad graph.
+ op_set = {node.target for node in inductor_post_grad_graph.nodes}
+ if torch.ops.higher_order.run_and_save_rng_state not in op_set:
+ # This is the backward graph, so check for the existence of the `run_with_rng_state` op
+ self.assertTrue(torch.ops.higher_order.run_with_rng_state in op_set)
+
+ with torch._inductor.config.patch(
+ post_grad_custom_post_pass=_run_with_rng_state_op_check
+ ):
+ compiler_fn = make_compiler_fn(fullgraph=True)
+
+ def make_compiler_fn_with_op_check():
+ def _compiler_fn(gm):
+ # Checks that `run_with_rng_state` op exists in Compiled Autograd's Dynamo graph.
+ self.assertTrue(
+ any(
+ node.target is torch.ops.higher_order.run_with_rng_state
+ for node in gm.graph.nodes
+ )
+ )
+ return compiler_fn(gm)
+
+ return _compiler_fn
+
+ compiler_fn_with_op_check = make_compiler_fn_with_op_check()
+ self.check_output_and_recompiles(
+ f, compiler_fn=compiler_fn_with_op_check, compile_fn=False
+ )
+
def test_autograd_cpp_node(self):
cpp_source = """
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index 59f8c26..f874ca2 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -554,6 +554,8 @@
return TraceWrappedHigherOrderOperatorVariable(value, source, **kwargs)
elif value.__name__ == "strict_mode":
return StrictModeHigherOrderVariable(value, source, **kwargs)
+ elif value.__name__ == "run_with_rng_state":
+ return RunWithRNGStateHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "associative_scan":
return AssociativeScanHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "call_torchbind":
@@ -1440,6 +1442,26 @@
)
+class RunWithRNGStateHigherOrderVariable(TorchHigherOrderOperatorVariable):
+ def call_function(
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+ ) -> "VariableTracker":
+ from .builder import wrap_fx_proxy
+
+ p_args = tuple(arg.as_proxy() for arg in args)
+ p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()}
+ return wrap_fx_proxy(
+ tx=tx,
+ proxy=tx.output.create_proxy(
+ "call_function",
+ self.value,
+ args=p_args,
+ kwargs=p_kwargs,
+ ),
+ example_value=None,
+ )
+
+
class TraceWrappedHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable):
"""
Handles torch._dynamo._trace_wrapped_higher_order_op.inner_trace
diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py
index 1345ff0..f1d2fc6 100644
--- a/torch/_prims/rng_prims.py
+++ b/torch/_prims/rng_prims.py
@@ -247,6 +247,18 @@
with mode:
return op(*args, **kwargs)
+ @run_with_rng_state.py_functionalize_impl
+ def impl_functional(ctx, rng_state, op, *args, **kwargs):
+ unwrapped_rng_state = ctx.unwrap_tensors(rng_state)
+ unwrapped_args = ctx.unwrap_tensors(args)
+ unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
+
+ with ctx.redispatch_to_next():
+ out = run_with_rng_state(
+ unwrapped_rng_state, op, *unwrapped_args, **unwrapped_kwargs
+ )
+ return ctx.wrap_tensors(out)
+
return run_with_rng_state
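For reference, the `py_functionalize_impl` added above unwraps any functional-tensor arguments, redispatches to the next dispatch key, and rewraps the outputs, which is what preserves the `run_with_rng_state` op as the Compiled Autograd graph goes through AOTAutograd (see the walkthrough in the test). The eager behavior it ultimately forwards to is, in effect, "replay `op` under the saved RNG state". A minimal sketch of that assumed CPU semantics, using plain `torch` RNG APIs (the helper name `run_with_rng_state_reference` is ours; this is an illustration, not code from this PR):
```
import torch
import torch.nn.functional as F

def run_with_rng_state_reference(rng_state, op, *args, **kwargs):
    # Assumed CPU semantics of the HOP: temporarily install the saved RNG state,
    # run the wrapped op, then restore the caller's RNG state.
    current_state = torch.get_rng_state()
    torch.set_rng_state(rng_state)
    try:
        return op(*args, **kwargs)
    finally:
        torch.set_rng_state(current_state)

# Replaying a dropout call under the state saved just before the original call
# reproduces the same mask.
x = torch.ones(4)
saved = torch.get_rng_state()
first = F.dropout(x, p=0.5, training=True)
replayed = run_with_rng_state_reference(saved, F.dropout, x, p=0.5, training=True)
assert torch.equal(first, replayed)
```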