[dynamo] add repro for functorch/fx interop issue (`allow_in_graph`) (#111746)
Fixes https://github.com/pytorch/pytorch/issues/109025 by adding repro
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111746
Approved by: https://github.com/voznesenskym
diff --git a/test/dynamo/test_interop.py b/test/dynamo/test_interop.py
index 1576706..163effe 100644
--- a/test/dynamo/test_interop.py
+++ b/test/dynamo/test_interop.py
@@ -31,6 +31,34 @@
trace_fn = torch.jit.trace(fn, [torch.zeros(10), torch.zeros(10)])
self._common(lambda a, b: trace_fn(a, b) + 1)
+ def test_vmap_in_graph(self):
+ from functools import wraps
+
+ from torch._dynamo import allow_in_graph
+
+ def traceable(f):
+ f = allow_in_graph(f)
+
+ @wraps(f)
+ def wrapper(*args, **kwargs):
+ return f(*args, **kwargs)
+
+ return wrapper
+
+ cnts = torch._dynamo.testing.CompileCounter()
+ x = torch.randn(3, 5, 3)
+
+ def fn(x):
+ return torch.vmap(torch.Tensor.t)(x)
+
+ fn_opt = torch.compile(fn, backend=cnts, fullgraph=True)
+ fn_opt_traceable = torch.compile(traceable(fn), backend=cnts, fullgraph=True)
+
+ self.assertEqual(fn(x), fn_opt(x))
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(fn_opt(x), fn_opt_traceable(x))
+ self.assertEqual(cnts.frame_count, 2)
+
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests