Hack up make_fx to natively support varargs (#83210)

This is kind of nasty but it works.  I attempted to fix FX
first but the inspect logic is impenetrable.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83210
Approved by: https://github.com/Chillee, https://github.com/albanD
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index f51248f..48c96ba 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -329,6 +329,12 @@
         inp = torch.randn(3, 3, 250, 250)
         self._test(f, [inp, dict(mod.named_parameters()), dict(mod.named_buffers())])
 
+    def test_varargs(self):
+        def f(*args):
+            return sum(args)
+
+        self._test(f, [torch.randn(2), torch.randn(2)])
+
     def test_proxy_tensor(self):
         def f_grad(x):
             val = x.cos().cos().sum()
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index 8251753..2d3f3c8 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -15,6 +15,7 @@
 from torch.utils._mode_utils import no_dispatch
 from torch.fx.passes.shape_prop import _extract_tensor_metadata
 from contextlib import contextmanager, nullcontext
+import inspect
 
 from torch.utils._python_dispatch import TorchDispatchMode, enable_torch_dispatch_mode
 from torch._subclasses import FakeTensor
@@ -27,6 +28,12 @@
 CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}
 
 
+def fake_signature(fn, nargs):
+    """FX gets confused by varargs, de-confuse it"""
+    argnames = ",".join(f"arg{i}" for i in range(nargs))
+    return eval(f"lambda {argnames}: fn({argnames})", {"fn": fn})
+
+
 class ProxySymInt(object):
     def __init__(self, sym_int, proxy):
         assert isinstance(sym_int, torch._C.SymIntNode) or isinstance(sym_int, int)
@@ -575,8 +582,15 @@
         else:
             args = pytree.tree_map(wrap_fn_map[tracing_mode], args)
 
+        if not hasattr(f, '__code__') or inspect.unwrap(f).__code__.co_flags & inspect.CO_VARARGS:
+            # FX doesn't support varargs, so we gotta fake up a wrapper
+            # TODO: Would be nice to fix this at the source...
+            func = fake_signature(f, len(phs))
+        else:
+            func = f
+
         with decompose(decomposition_table), fake_tensor_mode, proxy_mode:  # type: ignore[attr-defined]
-            t = dispatch_trace(wrap_key(f, args, proxy_mode), tracer=fx_tracer, concrete_args=tuple(phs))
+            t = dispatch_trace(wrap_key(func, args, proxy_mode), tracer=fx_tracer, concrete_args=tuple(phs))
 
         # TODO: kind of a bad way to do it, should maybe figure out a better way
         t.shape_env = shape_env  # type: ignore[assignment]