Simplify Storage meta conversion with PyObject preservation (#122018)
Thanks to https://github.com/pytorch/pytorch/pull/109039 we can rely on
finalizers on the Storage PyObject to handle removal from the dict.
Irritatingly, we still have to attach a finalizer, because we don't have
a weak key AND value dict (only one or the other).
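
For illustration, a minimal sketch of that pattern (not the actual PyTorch
code; IdKey and Memo are made-up names standing in for WeakIdRef and the
converter): the memo is value-weak, so an entry disappears when the memoized
value dies, and a finalizer attached to the key object evicts the entry when
the source storage dies instead.

    import weakref

    class IdKey:
        # hashes by the identity of the referent (stand-in for WeakIdRef)
        def __init__(self, obj):
            self._id = id(obj)

        def __hash__(self):
            return self._id

        def __eq__(self, other):
            return isinstance(other, IdKey) and self._id == other._id

    class Memo:
        def __init__(self):
            # value-weak: an entry vanishes when the memoized value dies
            self.storage_memo = weakref.WeakValueDictionary()

        def set(self, key_obj, value):
            key = IdKey(key_obj)
            self_ref = weakref.ref(self)

            def evict():
                memo = self_ref()
                if memo is not None:
                    # the key object died; the WeakValueDictionary alone
                    # would keep the entry while the value lives, so evict
                    memo.storage_memo.pop(key, None)

            # fires when key_obj is garbage collected
            weakref.finalize(key_obj, evict)
            self.storage_memo[key] = value

Entries thus drop out in both directions: when the value dies (via the
WeakValueDictionary) or when the key object dies (via the finalizer), which
is the weak-key-and-value behavior neither built-in container provides alone.
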
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122018
Approved by: https://github.com/eellison, https://github.com/kurtamohler
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 4217329..fedae64 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -937,11 +937,9 @@
self.assertEqual(stor_id, torch._C._storage_id(y_conv))
del x
self.assertEqual(len(converter.tensor_memo), 1)
- converter.meta_converter.check_for_expired_weak_storages()
self.assertEqual(len(converter.meta_converter.storage_memo), 1)
del y
self.assertEqual(len(converter.tensor_memo), 0)
- converter.meta_converter.check_for_expired_weak_storages()
self.assertEqual(len(converter.meta_converter.storage_memo), 0)
@@ -952,11 +950,11 @@
mode = FakeTensorMode()
converter = FakeTensorConverter()
x_conv = converter(mode, x)
- x_conv_storage = torch._C._storage_id(x_conv)
+ x_conv_storage = x_conv.untyped_storage()
del x_conv
self.assertFalse(x in converter.tensor_memo)
y_conv = converter(mode, y)
- self.assertEqual(x_conv_storage, torch._C._storage_id(y_conv))
+ self.assertIs(x_conv_storage, y_conv.untyped_storage())
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
def test_dead_key(self):
diff --git a/test/test_meta.py b/test/test_meta.py
index 65e17ce..be968b7 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -294,7 +294,6 @@
self.assertEqual(len(m.storage_memo), 1)
del x
self.assertEqual(len(m.tensor_memo), 0)
- m.check_for_expired_weak_storages()
self.assertEqual(len(m.storage_memo), 0)
li = []
r = []
@@ -304,7 +303,6 @@
self.assertEqual(len(m.tensor_memo), 4)
del li
self.assertEqual(len(m.tensor_memo), 0)
- m.check_for_expired_weak_storages()
self.assertEqual(len(m.storage_memo), 0)
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index 2745bac..aca1625 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -20,7 +20,6 @@
)
from torch._guards import Source
-from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import (
is_traceable_wrapper_subclass,
transform_subclass,
@@ -112,11 +111,8 @@
# and tensor storages.
class MetaConverter:
def __init__(self):
- self.storage_memo = {}
+ self.storage_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
self.tensor_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
- self.maybe_storages_to_delete = []
- self.check_expired_frequency = 128
- self.check_expired_count = 0
self.hit = 0
self.miss = 0
self.del_hook = None
@@ -125,26 +121,6 @@
def successful(self):
return self.hit > 0 and self.miss == 0
- def check_for_expired_weak_storages(self):
- new_li = []
- stor_to_delete = []
- for obj in self.maybe_storages_to_delete:
- if not obj.expired():
- new_li.append(obj)
- else:
- stor_to_delete.append(obj)
- for obj in stor_to_delete:
- self.storage_memo.pop(obj, None)
- self.maybe_storages_to_delete = new_li
-
- # if for some reason we have aquired many storages which have not expired
- # even though a tensor with their storage has expired (aliasing or otherwise)
- # check for expired storages less often so as to bound the amount of work we
- # do checking for expired storages
- self.check_expired_frequency = max(
- self.check_expired_frequency, len(self.maybe_storages_to_delete)
- )
-
def get_tensor_memo(self, t):
return self.tensor_memo.get(WeakIdRef(t), None)
@@ -152,10 +128,6 @@
# hold a weak ref to self, otherwise it will be kept alive
# by the del_ten closure
self_weak_ref = weakref.ref(self)
- if is_sparse_any(t) or t.is_mkldnn or is_functorch_wrapped_tensor(t):
- weak_st = None
- else:
- weak_st = StorageWeakRef(t._typed_storage())
tensor_ref_key = WeakIdRef(t)
def del_ten():
@@ -165,36 +137,42 @@
return
# on shutdown, tensor_ref_key may not be in memo
self_ref.tensor_memo.pop(tensor_ref_key, None)
- if weak_st and weak_st.expired():
- self_ref.storage_memo.pop(weak_st, None)
- elif weak_st is not None:
- # [expired-storages]
- # NB: even though the tensor has died,
- # the deallocation of its storage can take longer,
- # even when the storage has no other uses/views.
- # In this case, the StorageWeakRef object will be kept alive
- # longer than it needs to be, however the storage itself
- # will be deallocated. We retain the possibly dead storages
- # and periodically check if any of them are expired and
- # can be freed.
- self_ref.maybe_storages_to_delete.append(weak_st)
weakref.finalize(t, del_ten)
self.tensor_memo[tensor_ref_key] = v
+ def get_storage_memo(self, s):
+ return self.storage_memo.get(WeakIdRef(s), None)
+
+ def set_storage_memo(self, s, v):
+ # hold a weak ref to self, otherwise it will be kept alive
+ # by the del_storage closure
+ self_weak_ref = weakref.ref(self)
+ storage_ref_key = WeakIdRef(s)
+
+ def del_storage():
+ # storage outlives the converter
+ self_ref = self_weak_ref()
+ if self_ref is None:
+ return
+ # on shutdown, storage_ref_key may not be in memo
+ self_ref.storage_memo.pop(storage_ref_key, None)
+
+ weakref.finalize(s, del_storage)
+ self.storage_memo[storage_ref_key] = v
+
# NB: doesn't actually return a storage, because meta storage is
# not supported
- def meta_storage(self, s, callback):
- # NB: TypedStorage is freshly allocated and cannot be used as hash
- # key index.
-
+ def meta_storage(self, s: torch.UntypedStorage, callback):
# Use a Weak Ref to s in order to not leak memory
- swr = StorageWeakRef(s)
- if swr not in self.storage_memo:
- self.storage_memo[swr] = callback(
+ if self.get_storage_memo(s) is None:
+ r_s = callback(
lambda: torch.empty(s.size(), dtype=torch.uint8, device="meta")
).untyped_storage()
- return self.storage_memo[swr]
+ self.set_storage_memo(s, r_s)
+ return r_s
+ else:
+ return self.get_storage_memo(s)
# This function assumes that it's possible to do the conversion
# NB: name here is used in a conventional way by Dynamo; it corresponds
@@ -494,12 +472,6 @@
torch._check(sym_eq(fake_t.storage_offset(), storage_offset))
return fake_t
- # see expired-storages
- self.check_expired_count += 1
- if self.check_expired_count >= self.check_expired_frequency:
- self.check_for_expired_weak_storages()
- self.check_expired_count = 0
-
if self.get_tensor_memo(t) is None:
with torch.inference_mode(t.is_inference()):
if t.is_sparse:
@@ -816,8 +788,7 @@
return NotImplemented
s = t.untyped_storage()
- swr = StorageWeakRef(s)
- if swr not in self.storage_memo and (
+ if WeakIdRef(s) not in self.storage_memo and (
r.is_nested
or (
r.stride() == strides
@@ -825,7 +796,7 @@
)
):
# You're normal and happy, install the fresh storage into the memo
- self.storage_memo[swr] = r.untyped_storage()
+ self.set_storage_memo(s, r.untyped_storage())
else:
# You're in crazy town; somehow you gave us a tensor
# that wasn't a view, but had nonzero storage offset,