Make aot_module_simplified accept fake tensors (#89670)

Strategy taken from voz's #89392 but my implementation strategy
is a bit different.

If a fake tensor is provided, we use its FakeTensorMode
(and more importantly, its ShapeEnv--this is what is tested
in the new unit test).  Only one tensor needs to be fake;
if nothing is fake we just make a fresh mode as before.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89670
Approved by: https://github.com/voznesenskym
diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py
index 32ffb68..992648a 100644
--- a/functorch/_src/aot_autograd.py
+++ b/functorch/_src/aot_autograd.py
@@ -16,7 +16,7 @@
 import torch.utils.dlpack
 from torch import Tensor
 from torch._dynamo.utils import dynamo_timed
-from torch._subclasses import FakeTensorMode, CrossRefFakeMode
+from torch._subclasses import FakeTensorMode, CrossRefFakeMode, FakeTensor
 from torch.fx import immutable_collections, Interpreter
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
 from torch.multiprocessing.reductions import StorageWeakRef
@@ -1604,18 +1604,30 @@
     if config.use_dynamic_shapes:
         assert config.use_fake_tensor, "Dynamic shapes only works with fake tensor"
 
-    shape_env = ShapeEnv() if config.use_dynamic_shapes else None
-    fake_mode = FakeTensorMode(shape_env=shape_env) if config.use_fake_tensor else nullcontext()
+    # Check flat_args to see if they're already fake.  If so, use that fake
+    # mode instead.
+
+    for x in flat_args:
+        if isinstance(x, FakeTensor):
+            fake_mode = x.fake_mode
+            break
+    else:
+        shape_env = ShapeEnv() if config.use_dynamic_shapes else None
+        fake_mode = FakeTensorMode(shape_env=shape_env) if config.use_fake_tensor else nullcontext()
+
     cross_ref = CrossRefFakeMode() if config.debug_fake_cross_ref else nullcontext()
     python_dispatcher_mode = enable_python_dispatcher() if config.use_dynamic_shapes else nullcontext()
 
     with torch.autograd.set_multithreading_enabled(False), preserve_rng_state(), cross_ref, fake_mode, python_dispatcher_mode:
 
         def process_inputs(flat_args):
-            if config.use_fake_tensor:
+            if config.use_fake_tensor or isinstance(fake_mode, FakeTensorMode):
                 def convert(idx, x):
                     if not isinstance(x, torch.Tensor):
                         return x
+                    if isinstance(x, FakeTensor):
+                        assert x.fake_mode is fake_mode
+                        return x
                     if idx < aot_config.num_params_buffers and config.static_weight_shapes:
                         return fake_mode.from_tensor(x, static_shapes=True)
                     return fake_mode.from_tensor(x, static_shapes=False)
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 99a776e..a815316 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -38,8 +38,9 @@
     skip,
     skipOps,
 )
-from torch._subclasses.fake_tensor import DynamicOutputShapeException
+from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode
 from torch.fx.experimental.proxy_tensor import is_sym_node
+from torch.fx.experimental.symbolic_shapes import ShapeEnv
 
 USE_TORCHVISION = False
 try:
@@ -1546,6 +1547,43 @@
         assert torch.allclose(inputs[0].grad, cloned_inputs[0].grad)
         assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad)
 
+    def test_aot_module_simplified_dynamic(self):
+        class MockModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(20, 30)
+
+            def forward(self, x, y):
+                return (self.linear(x) + y, )
+
+        mod = MockModule()
+
+        shape_env = ShapeEnv()
+        fake_mode = FakeTensorMode(shape_env=shape_env)
+
+        x = torch.randn(128, 20, requires_grad=True)
+        y = torch.randn(128, 30, requires_grad=True)
+
+        inputs = [x, y]
+        fake_inputs = [fake_mode.from_tensor(x) for x in inputs]
+        compiled_f = aot_module_simplified(mod, fake_inputs, nop)
+
+        ref = mod(*inputs)
+        ref[0].sum().backward()
+
+        cloned_inputs = [x.detach().clone().requires_grad_(True) for x in inputs]
+        res = compiled_f(*cloned_inputs)
+        res[0].sum().backward()
+
+        self.assertExpectedInline(shape_env.format_guards(), """\
+ - Eq(s1, 20)
+ - Eq(s2, 30)""")
+
+        assert torch.allclose(ref[0], res[0])
+        assert torch.allclose(inputs[0].grad, cloned_inputs[0].grad)
+        assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad)
+
+
     def test_aot_module_simplified_preserves_stack_trace(self):
         class MockModule(torch.nn.Module):
             def __init__(self):