Refactor thunkify to return proper thunk abstraction (#132407)
This is superior to lru_cache because (1) the Thunk type makes the
delay-and-cache contract explicit, and (2) it does not keep the original
function (or anything it closes over) alive after the thunk has been forced.
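
A minimal sketch (not part of this diff) illustrating point (2). It assumes
the Thunk class added below is in scope; the Big class and make_* helpers are
hypothetical stand-ins for state closed over by a deferred computation:

```python
import functools
import weakref


class Big:  # hypothetical stand-in for a large captured object
    pass


def make_lru_thunk():
    big = Big()
    # lru_cache keeps the wrapped lambda (and its closure over big) alive
    # for as long as the returned wrapper exists, even after the first call.
    return weakref.ref(big), functools.lru_cache(1)(lambda: id(big))


def make_real_thunk():
    big = Big()
    # Thunk.force() sets self.f = None, releasing the lambda and its closure.
    return weakref.ref(big), Thunk(lambda: id(big))


ref, t = make_lru_thunk()
t()
assert ref() is not None  # big still reachable through the lru_cache wrapper

ref, t = make_real_thunk()
t.force()
assert ref() is None  # big freed: Thunk dropped f (assumes CPython refcounting)
```
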
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132407
Approved by: https://github.com/albanD
ghstack dependencies: #131649
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index d1e3256..ae35fc8 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -44,7 +44,7 @@
from torch.utils._traceback import CapturedTraceback
from torch.utils.weak import WeakTensorKeyDictionary, WeakIdKeyDictionary, _WeakHashRef
from typing import (
-    Any, Callable, Dict, List, Optional, Tuple, Union, Mapping, Sequence,
+    Any, Callable, Dict, List, Optional, Tuple, Union, Mapping, Sequence, Generic,
TypeVar, Generator, Protocol, overload, Type, TYPE_CHECKING)
from typing_extensions import Concatenate, ParamSpec, Self
from weakref import WeakKeyDictionary
@@ -174,7 +174,25 @@
return bool(get_proxy_slot(obj, tracer, False, lambda _: True))
-_PySymProxyType = Callable[[], Proxy]
+class Thunk(Generic[R]):
+    f: Optional[Callable[[], R]]
+    r: Optional[R]
+
+    __slots__ = ['f', 'r']
+
+    def __init__(self, f: Callable[[], R]):
+        self.f = f
+        self.r = None
+
+    def force(self) -> R:
+        if self.f is None:
+            return self.r  # type: ignore[return-value]
+        self.r = self.f()
+        self.f = None
+        return self.r
+
+
+_PySymProxyType = Thunk[Proxy]
@overload
@@ -341,12 +359,12 @@
proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
return proxy
-def thunkify(f: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs) -> Callable[[], R]:
+def thunkify(f: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs) -> Thunk[R]:
"""
Delays computation of f until it's called again
Also caches the result
"""
-    return functools.lru_cache(1)(functools.partial(f, *args, **kwargs))
+    return Thunk(functools.partial(f, *args, **kwargs))
def track_tensor(tensor: Tensor, proxy: Proxy, *, constant: Optional[Tensor], tracer: _ProxyTracer) -> None:
def try_set_proxy_slot(
@@ -407,7 +425,7 @@
assert isinstance(proxy, Proxy)
# NB: eagerly set meta here, so that the numbering is in order
set_meta(proxy, e)
-            set_proxy_slot(e, tracer, lambda: proxy)
+            set_proxy_slot(e, tracer, thunkify(lambda: proxy))
elif isinstance(e, _AnyScriptObject):
assert isinstance(proxy, Proxy)
set_proxy_slot(e, tracer, proxy)
@@ -484,7 +502,7 @@
else:
assert isinstance(e, py_sym_types)
# NB: we REQUIRE all symints to be tracked
-            return get_proxy_slot(e, tracer)()
+            return get_proxy_slot(e, tracer).force()
return inner
@overload
@@ -850,7 +868,7 @@
if isinstance(e, Tensor):
return get_proxy_slot(e, self, e, lambda x: x.proxy)
elif isinstance(e, py_sym_types):
-            return get_proxy_slot(e, self, e, lambda e: e())
+            return get_proxy_slot(e, self, e, lambda e: e.force())
elif isinstance(e, _AnyScriptObject):
return get_proxy_slot(e, self, e)
else:
@@ -930,7 +948,7 @@
)
def get_sym_proxy_slot(t: PySymType) -> Proxy:
-            return get_proxy_slot(t, tracer)()
+            return get_proxy_slot(t, tracer).force()
out = pytree.tree_map_only(
py_sym_types,
@@ -1120,7 +1138,7 @@
def _compute_proxy(self, func: OpOverload, args: Tuple[object, ...], out: PySymType) -> Proxy:
n_args = tuple(
-            get_proxy_slot(a, self.tracer)().node if isinstance(a, py_sym_types) else a
+            get_proxy_slot(a, self.tracer).force().node if isinstance(a, py_sym_types) else a
for a in args
)
@@ -1160,7 +1178,6 @@
# were symbolic) and it is no longer necessary to trace the
# computation. This could occur if func triggered some guards.
if isinstance(out, py_sym_types):
-            # Delays tracing out the proxies on this op until we actually need it
p_out_thunk = thunkify(self._compute_proxy, func=func, args=args, out=out)
set_proxy_slot(out, self.tracer, p_out_thunk)