Improve ProxyTensor tensor_tree list/tuple handling (#99897)
This PR improves list/tuple handling by merging the logic directly into
`wrap_with_proxy`, and by calling `set_meta` whenever the current proxy
is an `fx.Proxy`. This also fixes the problem that, even when the
`fused_adam` node itself has `val` in its meta, some of the `getitem`
calls that follow `fused_adam` do not.
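To make the invariant concrete, here is a minimal sketch mirroring the new
test below (`check_val` is a hypothetical helper name, not part of this PR):
after tracing with `make_fx`, both the `_fused_adam` node and the `getitem`
nodes that unpack its multi-output result should carry `val` metadata.

    import operator
    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    aten = torch.ops.aten

    def check_val(gm: torch.fx.GraphModule) -> None:
        # Both the fused_adam node and the getitem nodes that unpack
        # its tuple-of-lists result should carry 'val' in their meta.
        for n in gm.graph.nodes:
            if n.op == "call_function" and n.target in (
                aten._fused_adam.default,
                operator.getitem,
            ):
                assert "val" in n.meta, f"{n.format_node()} is missing 'val'"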
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99897
Approved by: https://github.com/ezyang
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 5003765..ce2c737 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -757,7 +757,7 @@
 
     def test_fused_adam(self):
         # See https://github.com/pytorch/pytorch/issues/99356
-        params = [torch.randn(10, 10, requires_grad=True) for _ in range(10)]
+        params = [torch.randn(10, 10) for _ in range(10)]
         grads = [torch.randn(10, 10) for _ in range(10)]
         exp_avgs = [torch.randn(10, 10) for _ in range(10)]
         exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)]
@@ -765,7 +765,7 @@
         state_steps = [torch.tensor(0) for _ in range(10)]
 
         def fused_adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps):
-            return aten._fused_adam.default(
+            (new_params, _, _, _, _) = aten._fused_adam.default(
                 params,
                 grads,
                 exp_avgs,
@@ -781,6 +781,11 @@
                 maximize=False,
             )
 
+            for p, new_p in zip(params, new_params):
+                p.copy_(new_p)
+
+            return params
+
         gm = make_fx(fused_adam, tracing_mode='fake')(
             params,
             grads,
@@ -789,8 +794,9 @@
             max_exp_avg_sqs,
             state_steps,
         )
+        ensure_ops_have_val = [aten._fused_adam.default, operator.getitem]
         for n in gm.graph.nodes:
-            if n.op == "call_function" and n.target == aten._fused_adam.default:
+            if n.op == "call_function" and n.target in ensure_ops_have_val:
                 self.assertIn('val', n.meta)
 
     def test_alias(self):
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index dab0e2e..0032f1d 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -190,27 +190,22 @@
             # NB: eagerly set meta here, so that the numbering is in order
             set_meta(proxy, e)
             set_proxy_slot(e.node, tracer, lambda: proxy)
-        elif isinstance(e, list):
+        elif isinstance(e, (tuple, list)):
+            if isinstance(proxy, fx.Proxy):
+                set_meta(proxy, e)
+
             # example use case: allreduce_ returns ([tensor], work)
             for idx, ee in enumerate(e):
                 wrap_with_proxy(ee, proxy[idx], get_constant(idx))
+
     def get_constant(idx):
         if constant is None:
             return None
         else:
             return constant[idx]
 
-    # Unfortunately, tree_map cannot directly be used here. As the resulting
-    # object may be a proxy that represents a tuple, we may need to
-    # explicitly unwrap the proxy by simulating the flattening operations.
-    if isinstance(inner_res, (tuple, list)):
-        if isinstance(proxy_res, fx.Proxy):
-            set_meta(proxy_res, inner_res)
-        for idx, e in enumerate(inner_res):
-            wrap_with_proxy(e, proxy_res[idx], get_constant(idx))
-    elif isinstance(inner_res, py_sym_types + (torch.Tensor,)):
-        wrap_with_proxy(inner_res, proxy_res, constant)
+    wrap_with_proxy(inner_res, proxy_res, constant)
 
     return inner_res
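For background on why the `getitem` nodes now pick up `val`: indexing an
`fx.Proxy` records an `operator.getitem` node, so the recursion above walks
`proxy[idx]` and the tensor branch's `set_meta` attaches `val` to each of
those nodes. A rough, self-contained sketch of that mechanism (illustrative
only, not part of this diff):

    import operator
    import torch.fx as fx

    g = fx.Graph()
    x = g.placeholder("x")  # stand-in for a multi-output op's result
    p = fx.Proxy(x)

    # Indexing a Proxy emits a call_function node targeting operator.getitem;
    # these are the nodes that previously ended up without 'val' metadata.
    first = p[0]
    assert first.node.op == "call_function"
    assert first.node.target is operator.getitem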