Simplify Storage meta conversion with PyObject preservation (#122018)

Thanks to https://github.com/pytorch/pytorch/pull/109039 (PyObject
preservation for Storage), we can rely on finalizers on the Storage
PyObject to handle removal from the memo dict.

Irritatingly, we still have to attach a finalizer, because there is no
dict that is weak in both its keys AND its values (only one or the
other).
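
For illustration, here is a minimal sketch of the pattern (not code from
this PR: Memo, Obj, real, and meta are hypothetical stand-ins, plain id()
stands in for torch's WeakIdRef, and under CPython refcounting the asserts
pass immediately). The WeakValueDictionary evicts an entry when its value
dies; the finalizer on the key's referent evicts it when the key dies:

```python
import weakref


class Memo:
    """Approximates a dict that is weak in BOTH keys and values.

    weakref.WeakValueDictionary evicts an entry only when the VALUE is
    collected; a weakref.finalize on the key's referent evicts it when
    the KEY is collected.
    """

    def __init__(self):
        self.memo = weakref.WeakValueDictionary()

    def set(self, real, meta):
        self_ref = weakref.ref(self)  # don't let the closure pin the memo
        key = id(real)                # stand-in for torch's WeakIdRef

        def on_real_dead():
            memo = self_ref()
            if memo is not None:      # the memo may already be gone
                memo.memo.pop(key, None)

        # Evict the entry once the key's referent is collected.
        weakref.finalize(real, on_real_dead)
        self.memo[key] = meta

    def get(self, real):
        return self.memo.get(id(real))


class Obj:  # any object that supports weak references
    pass


m, a, b = Memo(), Obj(), Obj()
m.set(a, b)
assert m.get(a) is b
del b  # value collected -> entry evicted automatically
assert m.get(a) is None
del a  # key collected -> the finalizer would evict the entry, too
```

get_storage_memo/set_storage_memo in the diff below follow this shape,
mirroring the existing tensor memo; the finalizer is reliable because,
after #109039, the Storage PyObject stays alive as long as the underlying
C++ storage does.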

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122018
Approved by: https://github.com/eellison, https://github.com/kurtamohler
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 4217329..fedae64 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -937,11 +937,9 @@
         self.assertEqual(stor_id, torch._C._storage_id(y_conv))
         del x
         self.assertEqual(len(converter.tensor_memo), 1)
-        converter.meta_converter.check_for_expired_weak_storages()
         self.assertEqual(len(converter.meta_converter.storage_memo), 1)
         del y
         self.assertEqual(len(converter.tensor_memo), 0)
-        converter.meta_converter.check_for_expired_weak_storages()
         self.assertEqual(len(converter.meta_converter.storage_memo), 0)
 
 
@@ -952,11 +950,11 @@
         mode = FakeTensorMode()
         converter = FakeTensorConverter()
         x_conv = converter(mode, x)
-        x_conv_storage = torch._C._storage_id(x_conv)
+        x_conv_storage = x_conv.untyped_storage()
         del x_conv
         self.assertFalse(x in converter.tensor_memo)
         y_conv = converter(mode, y)
-        self.assertEqual(x_conv_storage, torch._C._storage_id(y_conv))
+        self.assertIs(x_conv_storage, y_conv.untyped_storage())
 
     @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
     def test_dead_key(self):
diff --git a/test/test_meta.py b/test/test_meta.py
index 65e17ce..be968b7 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -294,7 +294,6 @@
         self.assertEqual(len(m.storage_memo), 1)
         del x
         self.assertEqual(len(m.tensor_memo), 0)
-        m.check_for_expired_weak_storages()
         self.assertEqual(len(m.storage_memo), 0)
         li = []
         r = []
@@ -304,7 +303,6 @@
         self.assertEqual(len(m.tensor_memo), 4)
         del li
         self.assertEqual(len(m.tensor_memo), 0)
-        m.check_for_expired_weak_storages()
         self.assertEqual(len(m.storage_memo), 0)
 
     @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index 2745bac..aca1625 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -20,7 +20,6 @@
 )
 from torch._guards import Source
 
-from torch.multiprocessing.reductions import StorageWeakRef
 from torch.utils._python_dispatch import (
     is_traceable_wrapper_subclass,
     transform_subclass,
@@ -112,11 +111,8 @@
 # and tensor storages.
 class MetaConverter:
     def __init__(self):
-        self.storage_memo = {}
+        self.storage_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
         self.tensor_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
-        self.maybe_storages_to_delete = []
-        self.check_expired_frequency = 128
-        self.check_expired_count = 0
         self.hit = 0
         self.miss = 0
         self.del_hook = None
@@ -125,26 +121,6 @@
     def successful(self):
         return self.hit > 0 and self.miss == 0
 
-    def check_for_expired_weak_storages(self):
-        new_li = []
-        stor_to_delete = []
-        for obj in self.maybe_storages_to_delete:
-            if not obj.expired():
-                new_li.append(obj)
-            else:
-                stor_to_delete.append(obj)
-        for obj in stor_to_delete:
-            self.storage_memo.pop(obj, None)
-        self.maybe_storages_to_delete = new_li
-
-        # if for some reason we have aquired many storages which have not expired
-        # even though a tensor with their storage has expired (aliasing or otherwise)
-        # check for expired storages less often so as to bound the amount of work we
-        # do checking for expired storages
-        self.check_expired_frequency = max(
-            self.check_expired_frequency, len(self.maybe_storages_to_delete)
-        )
-
     def get_tensor_memo(self, t):
         return self.tensor_memo.get(WeakIdRef(t), None)
 
@@ -152,10 +128,6 @@
         # hold a weak ref to self, otherwise it will be kept alive
         # by the del_ten closure
         self_weak_ref = weakref.ref(self)
-        if is_sparse_any(t) or t.is_mkldnn or is_functorch_wrapped_tensor(t):
-            weak_st = None
-        else:
-            weak_st = StorageWeakRef(t._typed_storage())
         tensor_ref_key = WeakIdRef(t)
 
         def del_ten():
@@ -165,36 +137,42 @@
                 return
             # on shutdown, tensor_ref_key may not be in memo
             self_ref.tensor_memo.pop(tensor_ref_key, None)
-            if weak_st and weak_st.expired():
-                self_ref.storage_memo.pop(weak_st, None)
-            elif weak_st is not None:
-                # [expired-storages]
-                # NB: even though the tensor has died,
-                # the deallocation of its storage can take longer,
-                # even when the storage has no other uses/views.
-                # In this case, the StorageWeakRef object will be kept alive
-                # longer than it needs to be, however the storage itself
-                # will be deallocated. We retain the possibly dead storages
-                # and periodically check if any of them are expired and
-                # can be freed.
-                self_ref.maybe_storages_to_delete.append(weak_st)
 
         weakref.finalize(t, del_ten)
         self.tensor_memo[tensor_ref_key] = v
 
+    def get_storage_memo(self, s):
+        return self.storage_memo.get(WeakIdRef(s), None)
+
+    def set_storage_memo(self, s, v):
+        # hold a weak ref to self, otherwise it will be kept alive
+        # by the del_storage closure
+        self_weak_ref = weakref.ref(self)
+        storage_ref_key = WeakIdRef(s)
+
+        def del_storage():
+            # storage outlives the converter
+            self_ref = self_weak_ref()
+            if self_ref is None:
+                return
+            # on shutdown, storage_ref_key may not be in memo
+            self_ref.storage_memo.pop(storage_ref_key, None)
+
+        weakref.finalize(s, del_storage)
+        self.storage_memo[storage_ref_key] = v
+
     # NB: doesn't actually return a storage, because meta storage is
     # not supported
-    def meta_storage(self, s, callback):
-        # NB: TypedStorage is freshly allocated and cannot be used as hash
-        # key index.
-
+    def meta_storage(self, s: torch.UntypedStorage, callback):
         # Use a Weak Ref to s in order to not leak memory
-        swr = StorageWeakRef(s)
-        if swr not in self.storage_memo:
-            self.storage_memo[swr] = callback(
+        # Fetch once; the dict is weak-valued, so a repeated lookup can miss.
+        r_s = self.get_storage_memo(s)
+        if r_s is None:
+            r_s = callback(
                 lambda: torch.empty(s.size(), dtype=torch.uint8, device="meta")
             ).untyped_storage()
-        return self.storage_memo[swr]
+            self.set_storage_memo(s, r_s)
+        return r_s
 
     # This function assumes that it's possible to do the conversion
     # NB: name here is used in a conventional way by Dynamo; it corresponds
@@ -494,12 +472,6 @@
             torch._check(sym_eq(fake_t.storage_offset(), storage_offset))
             return fake_t
 
-        # see expired-storages
-        self.check_expired_count += 1
-        if self.check_expired_count >= self.check_expired_frequency:
-            self.check_for_expired_weak_storages()
-            self.check_expired_count = 0
-
         if self.get_tensor_memo(t) is None:
             with torch.inference_mode(t.is_inference()):
                 if t.is_sparse:
@@ -816,8 +788,7 @@
                         return NotImplemented
 
                     s = t.untyped_storage()
-                    swr = StorageWeakRef(s)
-                    if swr not in self.storage_memo and (
+                    if WeakIdRef(s) not in self.storage_memo and (
                         r.is_nested
                         or (
                             r.stride() == strides
@@ -825,7 +796,7 @@
                         )
                     ):
                         # You're normal and happy, install the fresh storage into the memo
-                        self.storage_memo[swr] = r.untyped_storage()
+                        self.set_storage_memo(s, r.untyped_storage())
                     else:
                         # You're in crazy town; somehow you gave us a tensor
                         # that wasn't a view, but had nonzero storage offset,