Make aot_module_simplified accept fake tensors (#89670)
Strategy taken from voz's #89392, but my implementation
is a bit different.
If a fake tensor is provided, we use its FakeTensorMode
(and, more importantly, its ShapeEnv; this is what the new
unit test exercises). Only one of the inputs needs to be fake;
if none of them are fake, we just make a fresh mode as before.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89670
Approved by: https://github.com/voznesenskym
diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py
index 32ffb68..992648a 100644
--- a/functorch/_src/aot_autograd.py
+++ b/functorch/_src/aot_autograd.py
@@ -16,7 +16,7 @@
import torch.utils.dlpack
from torch import Tensor
from torch._dynamo.utils import dynamo_timed
-from torch._subclasses import FakeTensorMode, CrossRefFakeMode
+from torch._subclasses import FakeTensorMode, CrossRefFakeMode, FakeTensor
from torch.fx import immutable_collections, Interpreter
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.multiprocessing.reductions import StorageWeakRef
@@ -1604,18 +1604,30 @@
if config.use_dynamic_shapes:
assert config.use_fake_tensor, "Dynamic shapes only works with fake tensor"
- shape_env = ShapeEnv() if config.use_dynamic_shapes else None
- fake_mode = FakeTensorMode(shape_env=shape_env) if config.use_fake_tensor else nullcontext()
+ # Check flat_args to see if they're already fake. If so, use that fake
+ # mode instead.
+
+ for x in flat_args:
+ if isinstance(x, FakeTensor):
+ fake_mode = x.fake_mode
+ break
+ else:
+ shape_env = ShapeEnv() if config.use_dynamic_shapes else None
+ fake_mode = FakeTensorMode(shape_env=shape_env) if config.use_fake_tensor else nullcontext()
+
cross_ref = CrossRefFakeMode() if config.debug_fake_cross_ref else nullcontext()
python_dispatcher_mode = enable_python_dispatcher() if config.use_dynamic_shapes else nullcontext()
with torch.autograd.set_multithreading_enabled(False), preserve_rng_state(), cross_ref, fake_mode, python_dispatcher_mode:
def process_inputs(flat_args):
- if config.use_fake_tensor:
+ if config.use_fake_tensor or isinstance(fake_mode, FakeTensorMode):
def convert(idx, x):
if not isinstance(x, torch.Tensor):
return x
+ if isinstance(x, FakeTensor):
+ assert x.fake_mode is fake_mode
+ return x
if idx < aot_config.num_params_buffers and config.static_weight_shapes:
return fake_mode.from_tensor(x, static_shapes=True)
return fake_mode.from_tensor(x, static_shapes=False)
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 99a776e..a815316 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -38,8 +38,9 @@
skip,
skipOps,
)
-from torch._subclasses.fake_tensor import DynamicOutputShapeException
+from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode
from torch.fx.experimental.proxy_tensor import is_sym_node
+from torch.fx.experimental.symbolic_shapes import ShapeEnv
USE_TORCHVISION = False
try:
@@ -1546,6 +1547,43 @@
assert torch.allclose(inputs[0].grad, cloned_inputs[0].grad)
assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad)
+ def test_aot_module_simplified_dynamic(self):
+ class MockModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = torch.nn.Linear(20, 30)
+
+ def forward(self, x, y):
+ return (self.linear(x) + y, )
+
+ mod = MockModule()
+
+ shape_env = ShapeEnv()
+ fake_mode = FakeTensorMode(shape_env=shape_env)
+
+ x = torch.randn(128, 20, requires_grad=True)
+ y = torch.randn(128, 30, requires_grad=True)
+
+ inputs = [x, y]
+ fake_inputs = [fake_mode.from_tensor(x) for x in inputs]
+ compiled_f = aot_module_simplified(mod, fake_inputs, nop)
+
+ ref = mod(*inputs)
+ ref[0].sum().backward()
+
+ cloned_inputs = [x.detach().clone().requires_grad_(True) for x in inputs]
+ res = compiled_f(*cloned_inputs)
+ res[0].sum().backward()
+
+ self.assertExpectedInline(shape_env.format_guards(), """\
+ - Eq(s1, 20)
+ - Eq(s2, 30)""")
+
+ assert torch.allclose(ref[0], res[0])
+ assert torch.allclose(inputs[0].grad, cloned_inputs[0].grad)
+ assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad)
+
+
def test_aot_module_simplified_preserves_stack_trace(self):
class MockModule(torch.nn.Module):
def __init__(self):