[functorch] Fix proxy unwrapping for cond(). (#91907)
In control_flow.cond(), we unwrap arguments' proxy by using
get_proxy_slot() call which call a lambda in the end to get the stored
proxy. For SymInt and SymFloat we hide the proxy under a thunk instead
of storing proxy on .proxy attribute diretly, therefore we need to
special case SymInt for unwrapping here.
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91907
Approved by: https://github.com/ezyang
diff --git a/functorch/experimental/_cond.py b/functorch/experimental/_cond.py
index d435a86..c4ae34e 100644
--- a/functorch/experimental/_cond.py
+++ b/functorch/experimental/_cond.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass
+from functools import partial
import torch
from torch.multiprocessing.reductions import StorageWeakRef
@@ -10,10 +11,10 @@
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
get_isolated_graphmodule,
- get_proxy_slot,
ProxyTorchDispatchMode,
make_fx,
track_tensor_tree,
+ unwrap_proxy,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._python_dispatch import (
@@ -36,11 +37,6 @@
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
- def _unwrap_proxy(e):
- if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)):
- return e
- return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy)
-
assert isinstance(operands, list), "Cond operands must be a list of tensors"
assert all(isinstance(o, torch.Tensor) for o in operands), "Cond operands must be a list of tensors"
@@ -87,7 +83,7 @@
args = (pred, true_graph, false_graph, [operands])
- proxy_args = pytree.tree_map(_unwrap_proxy, args)
+ proxy_args = pytree.tree_map(partial(unwrap_proxy, proxy_mode), args)
out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {},
name="conditional")
diff --git a/functorch/experimental/_map.py b/functorch/experimental/_map.py
index 271869c..568b2de 100644
--- a/functorch/experimental/_map.py
+++ b/functorch/experimental/_map.py
@@ -1,3 +1,5 @@
+from functools import partial
+
import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
@@ -5,10 +7,10 @@
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
- get_proxy_slot,
make_fx,
ProxyTorchDispatchMode,
track_tensor_tree,
+ unwrap_proxy,
)
from torch.utils._python_dispatch import (
_get_current_dispatch_mode,
@@ -21,12 +23,6 @@
def trace_map(proxy_mode, func_overload, f, xs, *args):
- def _unwrap_proxy(e):
- if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)):
- return e
- return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy)
-
-
if not isinstance(xs, torch.Tensor):
raise ValueError("map() must loop over a tensor")
if len(xs.shape) == 0 or xs.shape[0] == 0:
@@ -48,7 +44,7 @@
proxy_mode.tracer.root.register_module(next_name, body_graph)
node_args = (body_graph, xs, *args)
- proxy_args = pytree.tree_map(_unwrap_proxy, node_args)
+ proxy_args = pytree.tree_map(partial(unwrap_proxy, proxy_mode), node_args)
out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {},
name="map")
outs = [body_graph(x, *args) for x in xs]
diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py
index 2c984fd..c007bdd 100644
--- a/test/functorch/test_control_flow.py
+++ b/test/functorch/test_control_flow.py
@@ -636,5 +636,19 @@
res = gm(p, pred, xs, y)
self.assertEqual(res, main(p, pred, xs, y))
+ def test_cond_with_sym_pred(self):
+ def true_fn(x):
+ return x + x
+
+ def false_fn(x):
+ return x * x
+
+ def foo(x):
+ return cond(x.shape[0] == 4, true_fn, false_fn, [x])
+
+ gm = make_fx(foo, tracing_mode="symbolic")(torch.ones(3, 2, 1))
+ x = torch.ones(4, 3, 2)
+ self.assertEqual(foo(x), gm(x))
+
if __name__ == '__main__':
run_tests()
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index da61e78..36cfdd8 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -101,6 +101,14 @@
def snapshot_fake(val):
return val.detach()
+def unwrap_proxy(proxy_mode, e):
+ if isinstance(e, torch.Tensor):
+ return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy)
+ elif isinstance(e, (torch.SymInt, torch.SymFloat)):
+ return get_proxy_slot(e.node, proxy_mode.tracer, e, lambda e: e())
+ else:
+ return e
+
# What invariants do we have for the 'val' set on the FX node? It has accurate
# metadata... but only for metadata that exists "below" all other subsystems
# (most notably autograd, but also vmap, functorch transforms, etc). This means