Hack up make_fx to natively support varargs (#83210)
This is kind of nasty but it works. I attempted to fix FX
first but the inspect logic is impenetrable.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83210
Approved by: https://github.com/Chillee, https://github.com/albanD
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index f51248f..48c96ba 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -329,6 +329,12 @@
inp = torch.randn(3, 3, 250, 250)
self._test(f, [inp, dict(mod.named_parameters()), dict(mod.named_buffers())])
+ def test_varargs(self):
+ def f(*args):
+ return sum(args)
+
+ self._test(f, [torch.randn(2), torch.randn(2)])
+
def test_proxy_tensor(self):
def f_grad(x):
val = x.cos().cos().sum()
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index 8251753..2d3f3c8 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -15,6 +15,7 @@
from torch.utils._mode_utils import no_dispatch
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from contextlib import contextmanager, nullcontext
+import inspect
from torch.utils._python_dispatch import TorchDispatchMode, enable_torch_dispatch_mode
from torch._subclasses import FakeTensor
@@ -27,6 +28,12 @@
CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}
+def fake_signature(fn, nargs):
+ """FX gets confused by varargs, de-confuse it"""
+ argnames = ",".join(f"arg{i}" for i in range(nargs))
+ return eval(f"lambda {argnames}: fn({argnames})", {"fn": fn})
+
+
class ProxySymInt(object):
def __init__(self, sym_int, proxy):
assert isinstance(sym_int, torch._C.SymIntNode) or isinstance(sym_int, int)
@@ -575,8 +582,15 @@
else:
args = pytree.tree_map(wrap_fn_map[tracing_mode], args)
+ if not hasattr(f, '__code__') or inspect.unwrap(f).__code__.co_flags & inspect.CO_VARARGS:
+ # FX doesn't support varargs, so we gotta fake up a wrapper
+ # TODO: Would be nice to fix this at the source...
+ func = fake_signature(f, len(phs))
+ else:
+ func = f
+
with decompose(decomposition_table), fake_tensor_mode, proxy_mode: # type: ignore[attr-defined]
- t = dispatch_trace(wrap_key(f, args, proxy_mode), tracer=fx_tracer, concrete_args=tuple(phs))
+ t = dispatch_trace(wrap_key(func, args, proxy_mode), tracer=fx_tracer, concrete_args=tuple(phs))
# TODO: kind of a bad way to do it, should maybe figure out a better way
t.shape_env = shape_env # type: ignore[assignment]