Refactor thunkify to return proper thunk abstraction (#132407)

This is superior to lru_cache because (1) it's more explicit and (2) it
doesn't keep the original function (and whatever its bound arguments
reference) alive after the thunk has been forced.
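
To illustrate the difference, here is a standalone sketch (not part of
this patch); the simplified Thunk copy and the hypothetical Payload
class below exist only for the example:

    import functools
    import weakref

    # Simplified, non-generic copy of the Thunk added by this patch,
    # reproduced here only so the sketch is self-contained.
    class Thunk:
        __slots__ = ['f', 'r']

        def __init__(self, f):
            self.f = f
            self.r = None

        def force(self):
            if self.f is None:
                return self.r
            self.r = self.f()
            self.f = None  # drop the callable once the result is cached
            return self.r

    class Payload:  # stands in for whatever the thunked call keeps alive
        pass

    # lru_cache: the wrapper holds the partial (and hence the payload)
    # for as long as the wrapper exists, even after the cached call.
    p = Payload()
    ref = weakref.ref(p)
    cached = functools.lru_cache(1)(functools.partial(lambda x: 42, p))
    cached()
    del p
    assert ref() is not None  # still reachable through the cache wrapper

    # Thunk: force() drops the reference to the partial, so the payload
    # is collectible as soon as nothing else refers to it (CPython).
    p = Payload()
    ref = weakref.ref(p)
    t = Thunk(functools.partial(lambda x: 42, p))
    t.force()
    del p
    assert ref() is None  # the thunk no longer keeps the payload alive

The __slots__ declaration mirrors the real class, presumably to keep
per-thunk overhead low: the tracer creates one of these thunks for
every SymInt proxy it tracks.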

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132407
Approved by: https://github.com/albanD
ghstack dependencies: #131649
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index d1e3256..ae35fc8 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -44,7 +44,7 @@
 from torch.utils._traceback import CapturedTraceback
 from torch.utils.weak import WeakTensorKeyDictionary, WeakIdKeyDictionary, _WeakHashRef
 from typing import (
-    Any, Callable, Dict, List, Optional, Tuple, Union, Mapping, Sequence,
+    Any, Callable, Dict, List, Optional, Tuple, Union, Mapping, Sequence, Generic,
     TypeVar, Generator, Protocol, overload, Type, TYPE_CHECKING)
 from typing_extensions import Concatenate, ParamSpec, Self
 from weakref import WeakKeyDictionary
@@ -174,7 +174,25 @@
     return bool(get_proxy_slot(obj, tracer, False, lambda _: True))
 
 
-_PySymProxyType = Callable[[], Proxy]
+class Thunk(Generic[R]):
+    f: Optional[Callable[[], R]]
+    r: Optional[R]
+
+    __slots__ = ['f', 'r']
+
+    def __init__(self, f: Callable[[], R]):
+        self.f = f
+        self.r = None
+
+    def force(self) -> R:
+        if self.f is None:
+            return self.r  # type: ignore[return-value]
+        self.r = self.f()
+        self.f = None
+        return self.r
+
+
+_PySymProxyType = Thunk[Proxy]
 
 
 @overload
@@ -341,12 +359,12 @@
         proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
     return proxy
 
-def thunkify(f: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs) -> Callable[[], R]:
+def thunkify(f: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs) -> Thunk[R]:
     """
     Delays computation of f until the thunk is forced.
     Also caches the result.
     """
-    return functools.lru_cache(1)(functools.partial(f, *args, **kwargs))
+    return Thunk(functools.partial(f, *args, **kwargs))
 
 def track_tensor(tensor: Tensor, proxy: Proxy, *, constant: Optional[Tensor], tracer: _ProxyTracer) -> None:
     def try_set_proxy_slot(
@@ -407,7 +425,7 @@
             assert isinstance(proxy, Proxy)
             # NB: eagerly set meta here, so that the numbering is in order
             set_meta(proxy, e)
-            set_proxy_slot(e, tracer, lambda: proxy)
+            set_proxy_slot(e, tracer, thunkify(lambda: proxy))
         elif isinstance(e, _AnyScriptObject):
             assert isinstance(proxy, Proxy)
             set_proxy_slot(e, tracer, proxy)
@@ -484,7 +502,7 @@
         else:
             assert isinstance(e, py_sym_types)
             # NB: we REQUIRE all symints to be tracked
-            return get_proxy_slot(e, tracer)()
+            return get_proxy_slot(e, tracer).force()
     return inner
 
 @overload
@@ -850,7 +868,7 @@
         if isinstance(e, Tensor):
             return get_proxy_slot(e, self, e, lambda x: x.proxy)
         elif isinstance(e, py_sym_types):
-            return get_proxy_slot(e, self, e, lambda e: e())
+            return get_proxy_slot(e, self, e, lambda e: e.force())
         elif isinstance(e, _AnyScriptObject):
             return get_proxy_slot(e, self, e)
         else:
@@ -930,7 +948,7 @@
         )
 
         def get_sym_proxy_slot(t: PySymType) -> Proxy:
-            return get_proxy_slot(t, tracer)()
+            return get_proxy_slot(t, tracer).force()
 
         out = pytree.tree_map_only(
             py_sym_types,
@@ -1120,7 +1138,7 @@
 
     def _compute_proxy(self, func: OpOverload, args: Tuple[object, ...], out: PySymType) -> Proxy:
         n_args = tuple(
-            get_proxy_slot(a, self.tracer)().node if isinstance(a, py_sym_types) else a
+            get_proxy_slot(a, self.tracer).force().node if isinstance(a, py_sym_types) else a
             for a in args
         )
 
@@ -1160,7 +1178,6 @@
         # were symbolic) and it is no longer necessary to trace the
         # computation.  This could occur if func triggered some guards.
         if isinstance(out, py_sym_types):
-            # Delays tracing out the proxies on this op until we actually need it
             p_out_thunk = thunkify(self._compute_proxy, func=func, args=args, out=out)
             set_proxy_slot(out, self.tracer, p_out_thunk)