Improve ProxyTensor tensor_tree list/tuple handling (#99897)

This PR improves list/tuple handling by merging the logic into
`wrap_with_proxy` directly and calling `set_meta` whenever the current
proxy is an `fx.Proxy`. It also fixes the problem where, even though the
`fused_adam` node has `val` in its metadata, some of the corresponding
`getitem` calls that follow `fused_adam` do not.
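
A minimal repro sketch of the behavior being checked (the function `f` and the
use of `torch.split` are illustrative, not part of this PR): any op that
returns a list/tuple under `make_fx` should now leave a `val` entry on the
downstream `getitem` nodes as well, not just on the node that produced the
container.

```python
import operator
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(xs):
    # split traces to a list-returning op; indexing it emits getitem nodes
    chunks = torch.split(xs, 2)
    return chunks[0] + chunks[1]

gm = make_fx(f, tracing_mode='fake')(torch.randn(4, 3))
for n in gm.graph.nodes:
    if n.op == "call_function" and n.target is operator.getitem:
        assert 'val' in n.meta, f"{n} is missing 'val' metadata"
```
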
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99897
Approved by: https://github.com/ezyang
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 5003765..ce2c737 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -757,7 +757,7 @@
 
     def test_fused_adam(self):
         # See https://github.com/pytorch/pytorch/issues/99356
-        params = [torch.randn(10, 10, requires_grad=True) for _ in range(10)]
+        params = [torch.randn(10, 10) for _ in range(10)]
         grads = [torch.randn(10, 10) for _ in range(10)]
         exp_avgs = [torch.randn(10, 10) for _ in range(10)]
         exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)]
@@ -765,7 +765,7 @@
         state_steps = [torch.tensor(0) for _ in range(10)]
 
         def fused_adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps):
-            return aten._fused_adam.default(
+            (new_params, _, _, _, _) = aten._fused_adam.default(
                 params,
                 grads,
                 exp_avgs,
@@ -781,6 +781,11 @@
                 maximize=False,
             )
 
+            for p, new_p in zip(params, new_params):
+                p.copy_(new_p)
+
+            return params
+
         gm = make_fx(fused_adam, tracing_mode='fake')(
             params,
             grads,
@@ -789,8 +794,9 @@
             max_exp_avg_sqs,
             state_steps,
         )
+        ensure_ops_have_val = [aten._fused_adam.default, operator.getitem]
         for n in gm.graph.nodes:
-            if n.op == "call_function" and n.target == aten._fused_adam.default:
+            if n.op == "call_function" and n.target in ensure_ops_have_val:
                 self.assertIn('val', n.meta)
 
     def test_alias(self):
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index dab0e2e..0032f1d 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -190,27 +190,22 @@
             # NB: eagerly set meta here, so that the numbering is in order
             set_meta(proxy, e)
             set_proxy_slot(e.node, tracer, lambda: proxy)
-        elif isinstance(e, list):
+        elif isinstance(e, (tuple, list)):
+            if isinstance(proxy, fx.Proxy):
+                set_meta(proxy, e)
+
             # example use case: allreduce_ returns ([tensor], work)
             for idx, ee in enumerate(e):
                 wrap_with_proxy(ee, proxy[idx], get_constant(idx))
 
+
     def get_constant(idx):
         if constant is None:
             return None
         else:
             return constant[idx]
 
-    # Unfortunately, tree_map cannot directly be used here. As the resulting
-    # object may be a proxy that represents a tuple, we may need to
-    # explicitly unwrap the proxy by simulating the flattening operations.
-    if isinstance(inner_res, (tuple, list)):
-        if isinstance(proxy_res, fx.Proxy):
-            set_meta(proxy_res, inner_res)
-        for idx, e in enumerate(inner_res):
-            wrap_with_proxy(e, proxy_res[idx], get_constant(idx))
-    elif isinstance(inner_res, py_sym_types + (torch.Tensor,)):
-        wrap_with_proxy(inner_res, proxy_res, constant)
+    wrap_with_proxy(inner_res, proxy_res, constant)
 
     return inner_res