[FX] Fix using fx.wrap as a decorator (#50677)

Summary:
`torch.fx.wrap()` could not be used as a decorator as the docstring claimed because it returned None.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50677

Test Plan: Added `test_wrapped_via_decorator` which used to fail with `'NoneType' object is not callable` and now passes

Reviewed By: jamesr66a

Differential Revision: D25949313

Pulled By: jansel

fbshipit-source-id: 02d0f9adeed812f58ec94c94dd4adc43578f21ce
diff --git a/test/test_fx.py b/test/test_fx.py
index ec8321b..4f01e87 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -54,6 +54,11 @@
 
 wrap('len')
 
+
+@wrap
+def wrapped_via_decorator(a):
+    return a + 1
+
 class Pair(NamedTuple):
     x : torch.Tensor
     y : torch.Tensor
@@ -244,6 +249,16 @@
         self.assertIn('a_lifted_leaf2', m.code)
         self.assertEqual(27, m(2))
 
+    def test_wrapped_via_decorator(self):
+        self.assertEqual(wrapped_via_decorator(0), 1)
+
+        def to_trace(y):
+            return wrapped_via_decorator(y)
+
+        m = symbolic_trace(to_trace)
+        self.assertIn('wrapped_via_decorator', m.code)
+        self.assertEqual(m(0), 1)
+
     def test_graph_edit_with_proxy(self):
         class M(torch.nn.Module):
             def forward(self, a, b):
diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py
index c7b7c8f..87a3e4a 100644
--- a/torch/fx/symbolic_trace.py
+++ b/torch/fx/symbolic_trace.py
@@ -460,6 +460,7 @@
         raise NotImplementedError('wrap must be called at the top level of a module')
 
     _wrapped_fns_to_patch.append((f.f_globals, fn_name))
+    return fn_or_name
 
 def symbolic_trace(root : Union[torch.nn.Module, Callable]) -> GraphModule:
     """Symbolic tracing API