[FX] Fix using fx.wrap as a decorator (#50677)
Summary:
`torch.fx.wrap()` could not be used as a decorator, as its docstring claims it can, because it returned `None` instead of the wrapped function.
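Why the return value matters: `@wrap` rebinds the decorated name to whatever `wrap` returns, so returning `None` leaves the function uncallable. A minimal sketch of the decorator pattern (a simplified stand-in, not the actual FX patching machinery):

```python
_registered = []  # simplified stand-in for FX's _wrapped_fns_to_patch bookkeeping

def wrap_sketch(fn):
    # Record the function's name so a tracer could patch its call sites later.
    _registered.append(fn.__name__)
    # Returning the original function is what makes `@wrap_sketch` work as a
    # decorator; returning None would rebind the decorated name to None.
    return fn

@wrap_sketch
def wrapped_via_decorator(a):
    return a + 1

print(wrapped_via_decorator(0))  # 1 -- still callable because wrap_sketch returned fn
```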
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50677
Test Plan: Added `test_wrapped_via_decorator`, which previously failed with `'NoneType' object is not callable` and now passes.
Reviewed By: jamesr66a
Differential Revision: D25949313
Pulled By: jansel
fbshipit-source-id: 02d0f9adeed812f58ec94c94dd4adc43578f21ce
diff --git a/test/test_fx.py b/test/test_fx.py
index ec8321b..4f01e87 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -54,6 +54,11 @@
 wrap('len')
+
+@wrap
+def wrapped_via_decorator(a):
+    return a + 1
+
 class Pair(NamedTuple):
     x : torch.Tensor
     y : torch.Tensor
@@ -244,6 +249,16 @@
         self.assertIn('a_lifted_leaf2', m.code)
         self.assertEqual(27, m(2))

+    def test_wrapped_via_decorator(self):
+        self.assertEqual(wrapped_via_decorator(0), 1)
+
+        def to_trace(y):
+            return wrapped_via_decorator(y)
+
+        m = symbolic_trace(to_trace)
+        self.assertIn('wrapped_via_decorator', m.code)
+        self.assertEqual(m(0), 1)
+
     def test_graph_edit_with_proxy(self):
         class M(torch.nn.Module):
             def forward(self, a, b):
diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py
index c7b7c8f..87a3e4a 100644
--- a/torch/fx/symbolic_trace.py
+++ b/torch/fx/symbolic_trace.py
@@ -460,6 +460,7 @@
         raise NotImplementedError('wrap must be called at the top level of a module')
     _wrapped_fns_to_patch.append((f.f_globals, fn_name))
+    return fn_or_name

 def symbolic_trace(root : Union[torch.nn.Module, Callable]) -> GraphModule:
     """Symbolic tracing API