LazyGraphModule: improve the fix for the FakeTensorMode mismatch issue (#119311)
The previous fix https://github.com/pytorch/pytorch/pull/118981 misses some corner cases. It works when both LazyGraphModule and compiled autograd are enabled, but it fails with the FakeTensorMode mismatch error again when LazyGraphModule + CompiledAutograd + DynamicShape are all enabled. Disabling any one of the three avoids the issue.
The reason enabling DynamicShape breaks the previous fix is that, when there are symints saved for backward, we call the bw_compiler here before the backward pass ever runs: https://github.com/pytorch/pytorch/blob/73f0fdea5b845a09d849404b06383c329c2c5a8a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py#L382
The bw_compiler may trigger an extra GraphModule recompilation on the bw_module, which causes its forward method to become the lazy one again. The fix is simply to delay applying the previous fix until after this potential extra bw_compiler call.
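To make the ordering issue concrete, here is a minimal, self-contained sketch. `ToyLazyModule` is a hypothetical stand-in for `_LazyGraphModule` (not the real class): any `recompile()` puts `forward` back into a lazy stub, and `force_recompile()` installs the real `forward` immediately.
```
# Hypothetical stand-in for _LazyGraphModule, only to illustrate the ordering.
class ToyLazyModule:
    def __init__(self):
        self.recompile()  # starts out lazy

    def recompile(self):
        # Like _LazyGraphModule: any recompile makes forward lazy again.
        self.forward = self._lazy_forward
        self.is_lazy = True

    def force_recompile(self):
        # Like _LazyGraphModule.force_recompile: install the real forward now.
        self.forward = self._real_forward
        self.is_lazy = False

    def _lazy_forward(self, x):
        self.force_recompile()
        return self._real_forward(x)

    def _real_forward(self, x):
        return x + 1


bw_module = ToyLazyModule()

# Previous fix (PR #118981): force_recompile right after partitioning.
bw_module.force_recompile()

# With dynamic shapes, the bw_compiler runs eagerly afterwards and may trigger
# another recompile(), which makes forward lazy again and reintroduces the
# FakeTensorMode mismatch once compiled autograd runs it in the backward pass.
bw_module.recompile()
assert bw_module.is_lazy

# This PR: force_recompile *after* the potential eager bw_compiler call, so the
# real forward is guaranteed to be in place when compiled autograd runs it.
bw_module.force_recompile()
assert not bw_module.is_lazy
```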
Repro on hf_Whisper:
```
CUDA_VISIBLE_DEVICES=1 time benchmarks/dynamo/torchbench.py -dcuda --training --backend=inductor --disable-cudagraphs --accuracy --only hf_Whisper --repeat 1 --compiled-autograd --dynamic-batch-only
```
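For a smaller, self-contained approximation of the same scenario, the sketch below mirrors the shape of the unit test added in the diff (compiled autograd + an inductor-compiled function + a dimension marked dynamic); it is an illustration only and is not guaranteed to reproduce the exact failure on its own.
```
import torch

B = 8
x = torch.rand(B, 16)
y = torch.rand(B, 16, requires_grad=True)

# Mark the batch dimension dynamic so symints are saved for backward and the
# bw_compiler runs eagerly, which is the corner case this PR fixes.
torch._dynamo.mark_dynamic(x, 0)
torch._dynamo.mark_dynamic(y, 0)

def f():
    out = x + y
    out.sum().backward()
    return out

# Compiler used by compiled autograd for the backward graph.
compiler_fn = lambda gm: torch.compile(gm, backend="inductor")

with torch._dynamo.compiled_autograd.enable(compiler_fn):
    torch.compile(f, backend="inductor")()
```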
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119311
Approved by: https://github.com/xmfan, https://github.com/jansel
diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py
index b747f6d..0d1bb4a 100644
--- a/test/inductor/test_compiled_autograd.py
+++ b/test/inductor/test_compiled_autograd.py
@@ -636,14 +636,19 @@
self.check_output_and_recompiles(fn, 3)
- def test_mismatch_fake_tensor_mode(self):
+ def test_mismatch_fake_tensor_mode(self, dynamic_shape=False):
"""
Repro the failure of training nanogpt with both compiled-autograd
and _LazyGraphModule. Check https://github.com/pytorch/pytorch/pull/118981
for more context.
"""
- x = torch.rand(2, 16)
- y = nn.Parameter(torch.rand(2, 16))
+ B = 8
+ x = torch.rand(B, 16)
+ y = torch.rand(B, 16, requires_grad=True)
+
+ if dynamic_shape:
+ torch._dynamo.mark_dynamic(x, 0)
+ torch._dynamo.mark_dynamic(y, 0)
def f():
out = x + y
@@ -655,6 +660,9 @@
self.check_output_and_recompiles(f, compile_fn=True)
+ def test_mismatch_fake_tensor_mode_dynamic_shape(self):
+ self.test_mismatch_fake_tensor_mode(dynamic_shape=True)
+
def load_test_module(name):
testdir = Path(__file__).absolute().parent.parent
diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
index e6a2222..59a4572 100644
--- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
+++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
@@ -183,18 +183,6 @@
fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs
)
- # Compiled autograd will run the bw_module in the backward pass,
- # so recompilation need happen anyway if the backward pass is ever
- # called.
- #
- # The reason we do the GraphModule recompilation here is because
- # the lazy recompilation will cause issue in the backward pass
- # with compiled autograd.
- if torch._dynamo.compiled_autograd.compiled_autograd_enabled_count:
- from torch.fx._lazy_graph_module import _LazyGraphModule
-
- _LazyGraphModule.force_recompile(bw_module)
-
fw_outs = next(n for n in fw_module.graph.nodes if n.op == "output").args[0]
# we only need to bookkeep the symints that are saved for bw, not any symints
# the user forward might have returned in its own output
@@ -387,6 +375,24 @@
"failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed",
exc_info=True,
)
+ # Compiled autograd will run the bw_module in the backward pass,
+ # so recompilation needs to happen anyway if the backward pass is
+ # ever called.
+ #
+ # The reason we do the GraphModule recompilation here is that the
+ # lazy recompilation would cause issues in the backward pass with
+ # compiled autograd.
+ #
+ # Do the _LazyGraphModule.force_recompile here rather than when the
+ # bw_module is first generated by the partitioner, because bw_module.recompile
+ # may be called on some code path later and make _LazyGraphModule.forward
+ # lazy again. One example is when dynamic shape is enabled upfront: the
+ # bw_compiler is called above, which can trigger an extra GraphModule
+ # recompilation on bw_module.
+ if torch._dynamo.compiled_autograd.compiled_autograd_enabled_count:
+ from torch.fx._lazy_graph_module import _LazyGraphModule
+
+ _LazyGraphModule.force_recompile(bw_module)
saved_context = TracingContext.try_get()