[functorch] Fix proxy unwrapping for cond(). (#91907)

In control_flow.cond(), we unwrap each argument's proxy via a
get_proxy_slot() call, which ultimately invokes a lambda to retrieve
the stored proxy. For SymInt and SymFloat the proxy is hidden behind a
thunk instead of being stored directly on the .proxy attribute, so the
unwrapping has to special-case SymInt and SymFloat. This PR factors
the logic into an unwrap_proxy() helper in proxy_tensor.py that is
shared by cond() and map().

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91907
Approved by: https://github.com/ezyang
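
For reference, here is the new helper annotated for exposition. The
comments are editorial, and the reading of get_proxy_slot(obj, tracer,
default, transform) as "return default when obj was never tracked,
otherwise apply transform to the stored slot" is an inference from the
call sites, not text from the PR:

    import torch
    from torch.fx.experimental.proxy_tensor import get_proxy_slot

    def unwrap_proxy(proxy_mode, e):
        if isinstance(e, torch.Tensor):
            # Tensors are tracked directly and expose the proxy on the
            # slot's .proxy attribute.
            return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy)
        elif isinstance(e, (torch.SymInt, torch.SymFloat)):
            # Symbolic scalars are tracked by their underlying SymNode,
            # and the slot holds a thunk; calling it materializes the proxy.
            return get_proxy_slot(e.node, proxy_mode.tracer, e, lambda e: e())
        else:
            # Plain Python values (ints, graph modules, ...) pass through
            # unchanged; only traced values need unwrapping.
            return e

At the call sites below, partial(unwrap_proxy, proxy_mode) is mapped
over the argument pytree, so tensors, symbolic scalars, and plain
Python values can be mixed freely.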
diff --git a/functorch/experimental/_cond.py b/functorch/experimental/_cond.py
index d435a86..c4ae34e 100644
--- a/functorch/experimental/_cond.py
+++ b/functorch/experimental/_cond.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from functools import partial
 import torch
 from torch.multiprocessing.reductions import StorageWeakRef
 
@@ -10,10 +11,10 @@
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import (
     get_isolated_graphmodule,
-    get_proxy_slot,
     ProxyTorchDispatchMode,
     make_fx,
     track_tensor_tree,
+    unwrap_proxy,
 )
 from torch.fx.passes.shape_prop import _extract_tensor_metadata
 from torch.utils._python_dispatch import (
@@ -36,11 +37,6 @@
 
 
 def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
-    def _unwrap_proxy(e):
-        if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)):
-            return e
-        return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy)
-
     assert isinstance(operands, list), "Cond operands must be a list of tensors"
     assert all(isinstance(o, torch.Tensor) for o in operands), "Cond operands must be a list of tensors"
 
@@ -87,7 +83,7 @@
 
     args = (pred, true_graph, false_graph, [operands])
 
-    proxy_args = pytree.tree_map(_unwrap_proxy, args)
+    proxy_args = pytree.tree_map(partial(unwrap_proxy, proxy_mode), args)
 
     out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {},
                                                name="conditional")
diff --git a/functorch/experimental/_map.py b/functorch/experimental/_map.py
index 271869c..568b2de 100644
--- a/functorch/experimental/_map.py
+++ b/functorch/experimental/_map.py
@@ -1,3 +1,5 @@
+from functools import partial
+
 import torch
 import torch.utils._pytree as pytree
 from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
@@ -5,10 +7,10 @@
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import (
     disable_proxy_modes_tracing,
-    get_proxy_slot,
     make_fx,
     ProxyTorchDispatchMode,
     track_tensor_tree,
+    unwrap_proxy,
 )
 from torch.utils._python_dispatch import (
     _get_current_dispatch_mode,
@@ -21,12 +23,6 @@
 
 
 def trace_map(proxy_mode, func_overload, f, xs, *args):
-    def _unwrap_proxy(e):
-        if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)):
-            return e
-        return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy)
-
-
     if not isinstance(xs, torch.Tensor):
         raise ValueError("map() must loop over a tensor")
     if len(xs.shape) == 0 or xs.shape[0] == 0:
@@ -48,7 +44,7 @@
 
     proxy_mode.tracer.root.register_module(next_name, body_graph)
     node_args = (body_graph, xs, *args)
-    proxy_args = pytree.tree_map(_unwrap_proxy, node_args)
+    proxy_args = pytree.tree_map(partial(unwrap_proxy, proxy_mode), node_args)
     out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {},
                                                name="map")
     outs = [body_graph(x, *args) for x in xs]
diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py
index 2c984fd..c007bdd 100644
--- a/test/functorch/test_control_flow.py
+++ b/test/functorch/test_control_flow.py
@@ -636,5 +636,19 @@
         res = gm(p, pred, xs, y)
         self.assertEqual(res, main(p, pred, xs, y))
 
+    def test_cond_with_sym_pred(self):
+        def true_fn(x):
+            return x + x
+
+        def false_fn(x):
+            return x * x
+
+        def foo(x):
+            return cond(x.shape[0] == 4, true_fn, false_fn, [x])
+
+        gm = make_fx(foo, tracing_mode="symbolic")(torch.ones(3, 2, 1))
+        x = torch.ones(4, 3, 2)
+        self.assertEqual(foo(x), gm(x))
+
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index da61e78..36cfdd8 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -101,6 +101,14 @@
 def snapshot_fake(val):
     return val.detach()
 
+def unwrap_proxy(proxy_mode, e):
+    if isinstance(e, torch.Tensor):
+        return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy)
+    elif isinstance(e, (torch.SymInt, torch.SymFloat)):
+        return get_proxy_slot(e.node, proxy_mode.tracer, e, lambda e: e())
+    else:
+        return e
+
 # What invariants do we have for the 'val' set on the FX node?  It has accurate
 # metadata... but only for metadata that exists "below" all other subsystems
 # (most notably autograd, but also vmap, functorch transforms, etc).  This means