Add PyObject preservation for UntypedStorage (#97470)
Part of #91395
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97470
Approved by: https://github.com/ezyang
diff --git a/build_variables.bzl b/build_variables.bzl
index 4fa0b0d..97c4ab7 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -888,6 +888,7 @@
"torch/csrc/utils/python_dispatch.cpp",
"torch/csrc/utils/python_symnode.cpp",
"torch/csrc/utils/pybind.cpp",
+ "torch/csrc/utils/pyobject_preservation.cpp",
"torch/csrc/utils/structseq.cpp",
"torch/csrc/utils/tensor_apply.cpp",
"torch/csrc/utils/tensor_dtypes.cpp",
diff --git a/c10/core/RefcountedDeleter.cpp b/c10/core/RefcountedDeleter.cpp
new file mode 100644
index 0000000..44555cf
--- /dev/null
+++ b/c10/core/RefcountedDeleter.cpp
@@ -0,0 +1,79 @@
+#include <c10/core/RefcountedDeleter.h>
+
+#include <mutex>
+
+namespace c10 {
+
+void refcounted_deleter(void* ctx_) {
+  RefcountedDeleterContext& ctx =
+      *reinterpret_cast<RefcountedDeleterContext*>(ctx_);
+  // Atomic decrement-and-test in one operation: a separate `refcount--`
+  // followed by a load of `refcount` races when two deleters run
+  // concurrently (both or neither could observe 0). `fetch_sub` guarantees
+  // exactly one caller sees the transition to zero and frees the context.
+  if (ctx.refcount.fetch_sub(1) == 1) {
+    ctx.other_ctx = nullptr;
+    delete &ctx;
+  }
+}
+
+std::mutex replace_data_ptr_mutex;
+
+void maybeApplyRefcountedDeleter(c10::Storage storage) {
+ std::lock_guard<std::mutex> guard(replace_data_ptr_mutex);
+ c10::DataPtr& data_ptr = storage.mutable_data_ptr();
+
+ if ((void*)data_ptr.get_deleter() == (void*)&c10::refcounted_deleter) {
+ // Data pointer is already shared
+ return;
+ }
+
+ void* data = data_ptr.get();
+ void* other_ctx = data_ptr.get_context();
+ c10::DeleterFnPtr other_deleter = data_ptr.get_deleter();
+ c10::Device device = data_ptr.device();
+
+ // Release the context of the original DataPtr so that the data doesn't
+ // get deleted when the original DataPtr is replaced
+ data_ptr.release_context();
+
+ c10::RefcountedDeleterContext* refcount_ctx =
+ new c10::RefcountedDeleterContext(other_ctx, other_deleter);
+
+ c10::DataPtr new_data_ptr(
+ data,
+ reinterpret_cast<void*>(refcount_ctx),
+ &c10::refcounted_deleter,
+ device);
+ storage.set_data_ptr(std::move(new_data_ptr));
+}
+
+c10::Storage newStorageImplFromRefcountedDataPtr(c10::Storage storage) {
+  // Ensure the storage's DataPtr uses the shared refcounted deleter so a
+  // second StorageImpl can safely point at the same data.
+  c10::maybeApplyRefcountedDeleter(storage);
+
+  c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
+
+  c10::DataPtr& data_ptr = storage.mutable_data_ptr();
+
+  // Increment the shared context's refcount *before* constructing
+  // `new_data_ptr`. With the reverse order (construct, then increment), an
+  // error raised in between would destroy `new_data_ptr`, decrement the
+  // refcount, and leave it one lower than it should be. Incrementing first
+  // removes that window entirely (worst case on error is a benign leak).
+  reinterpret_cast<c10::RefcountedDeleterContext*>(data_ptr.get_context())
+      ->refcount++;
+
+  c10::DataPtr new_data_ptr(
+      data_ptr.get(),
+      data_ptr.get_context(),
+      data_ptr.get_deleter(),
+      data_ptr.device());
+
+  c10::Allocator* allocator = c10::GetAllocator(storage_impl->device_type());
+  c10::Storage new_storage = c10::make_intrusive<c10::StorageImpl>(
+      c10::StorageImpl::use_byte_size_t(),
+      storage_impl->nbytes(),
+      allocator,
+      /*resizable=*/storage_impl->resizable());
+  new_storage.set_data_ptr(std::move(new_data_ptr));
+  return new_storage;
+}
+
+} // namespace c10
diff --git a/c10/core/RefcountedDeleter.h b/c10/core/RefcountedDeleter.h
new file mode 100644
index 0000000..afd3978
--- /dev/null
+++ b/c10/core/RefcountedDeleter.h
@@ -0,0 +1,50 @@
+#pragma once
+
+#include <c10/core/Storage.h>
+#include <c10/util/UniqueVoidPtr.h>
+
+#include <atomic>
+#include <memory>
+
+namespace c10 {
+
+// A RefcountedDeleterContext object is used as the `ctx` argument for DataPtr
+// to implement a shared DataPtr. Normally, a DataPtr is unique, but we use
+// this custom context and the `refcounted_deleter` function below to make the
+// DataPtr act like a non-unique DataPtr. This context object holds onto an
+// inner context and deleter function which handle the actual deletion of the
+// data when the refcount reaches 0.
+//
+// This shared DataPtr feature is only used when storages are shared between
+// multiple Python interpreters in MultiPy. Before storages had PyObject
+// preservation, interpreters could just share the same StorageImpl instance.
+// But now a StorageImpl can only be associated with one interpreter in order
+// to properly manage a zombie PyObject. So we share storages across Python
+// interpreters by creating a different StorageImpl instance for each one, but
+// they all point to the same data.
+struct C10_API RefcountedDeleterContext {
+ RefcountedDeleterContext(void* other_ctx, c10::DeleterFnPtr other_deleter)
+ : other_ctx(other_ctx, other_deleter), refcount(1) {}
+
+ std::unique_ptr<void, c10::DeleterFnPtr> other_ctx;
+ std::atomic_int refcount;
+};
+
+// `refcounted_deleter` is used as the `ctx_deleter` for DataPtr to implement
+// a shared DataPtr.
+//
+// Warning: This should only be called on a pointer to
+// a RefcountedDeleterContext that was allocated on the heap with `new`,
+// because when the refcount reaches 0, the context is deleted with `delete`
+C10_API void refcounted_deleter(void* ctx_);
+
+// If the storage's DataPtr does not use `refcounted_deleter`, replace it with
+// a DataPtr that does, so it can be shared between multiple StorageImpls
+C10_API void maybeApplyRefcountedDeleter(c10::Storage storage);
+
+// Create a new StorageImpl that points to the same data. If the original
+// StorageImpl's DataPtr does not use `refcounted_deleter`, it will be replaced
+// with one that does
+C10_API c10::Storage newStorageImplFromRefcountedDataPtr(c10::Storage storage);
+
+} // namespace c10
diff --git a/c10/core/SafePyObject.h b/c10/core/SafePyObject.h
index 932c32e..da5a7c5 100644
--- a/c10/core/SafePyObject.h
+++ b/c10/core/SafePyObject.h
@@ -29,7 +29,7 @@
SafePyObject& operator=(SafePyObject const&) = delete;
~SafePyObject() {
- (*pyinterpreter_)->decref(data_, /*is_tensor*/ false);
+ (*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
}
c10::impl::PyInterpreter& pyinterpreter() const {
diff --git a/c10/core/Storage.cpp b/c10/core/Storage.cpp
index 1361c81..d968339 100644
--- a/c10/core/Storage.cpp
+++ b/c10/core/Storage.cpp
@@ -1,3 +1,18 @@
+#include <c10/core/RefcountedDeleter.h>
#include <c10/core/Storage.h>
-namespace c10 {} // namespace c10
+namespace c10 {
+
+bool isSharedStorageAlias(const Storage& storage0, const Storage& storage1) {
+  // Two distinct StorageImpls alias the same data only when both use the
+  // shared refcounted deleter and share the same RefcountedDeleterContext.
+  const c10::DeleterFnPtr shared_deleter = &c10::refcounted_deleter;
+
+  if (storage0.data_ptr().get_deleter() != shared_deleter ||
+      storage1.data_ptr().get_deleter() != shared_deleter) {
+    return false;
+  }
+  return storage0.data_ptr().get_context() == storage1.data_ptr().get_context();
+}
+
+} // namespace c10
diff --git a/c10/core/Storage.h b/c10/core/Storage.h
index 3b586ab..d812f8f 100644
--- a/c10/core/Storage.h
+++ b/c10/core/Storage.h
@@ -1,12 +1,22 @@
#pragma once
#include <c10/core/StorageImpl.h>
+#include <c10/util/ExclusivelyOwned.h>
namespace c10 {
+struct Storage;
+
+C10_API bool isSharedStorageAlias(
+ const Storage& storage0,
+ const Storage& storage1);
+
struct C10_API Storage {
public:
struct use_byte_size_t {};
+ struct unsafe_borrow_t {
+ explicit unsafe_borrow_t() = default;
+ };
Storage() = default;
Storage(c10::intrusive_ptr<StorageImpl> ptr)
@@ -40,6 +50,14 @@
allocator,
resizable)) {}
+ protected:
+ explicit Storage(unsafe_borrow_t, const Storage& rhs)
+ : storage_impl_(c10::intrusive_ptr<c10::StorageImpl>::reclaim(
+ rhs.storage_impl_.get())) {}
+
+ friend MaybeOwnedTraits<Storage>;
+
+ public:
// Legacy constructor for partially initialized (dtype or memory) storages
// that can be temporarily created with Caffe2 APIs. See the note on top of
// TensorImpl.h for details.
@@ -144,7 +162,9 @@
}
bool is_alias_of(const Storage& other) const {
- return storage_impl_ == other.storage_impl_;
+ return (
+ storage_impl_ == other.storage_impl_ ||
+ isSharedStorageAlias(*this, other));
}
void UniqueStorageShareExternalPointer(
@@ -175,4 +195,67 @@
c10::intrusive_ptr<StorageImpl> storage_impl_;
};
+template <>
+struct MaybeOwnedTraits<c10::Storage> {
+ using owned_type = c10::Storage;
+ using borrow_type = c10::Storage;
+
+ static borrow_type createBorrow(const owned_type& from) {
+ return borrow_type(borrow_type::unsafe_borrow_t{}, from);
+ }
+
+ static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
+ lhs.unsafeReleaseStorageImpl();
+ lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs);
+ }
+
+ static void destroyBorrow(borrow_type& toDestroy) {
+ toDestroy.unsafeReleaseStorageImpl(); // "leak" it, but it was already +0.
+ }
+
+ static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
+ return borrow;
+ }
+
+ static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
+ return &borrow;
+ }
+
+ static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
+ return true;
+ }
+};
+
+template <>
+struct ExclusivelyOwnedTraits<c10::Storage> {
+ using repr_type = c10::Storage;
+ using pointer_type = c10::Storage*;
+ using const_pointer_type = const c10::Storage*;
+
+ static repr_type nullRepr() {
+ return c10::Storage();
+ }
+
+ template <class... Args>
+ static repr_type createInPlace(Args&&... args) {
+ return c10::Storage(std::forward<Args>(args)...);
+ }
+
+ static repr_type moveToRepr(c10::Storage&& x) {
+ return std::move(x);
+ }
+
+ static c10::Storage take(c10::Storage& x) {
+ return std::move(x);
+ }
+
+ static pointer_type getImpl(repr_type& x) {
+ return &x;
+ }
+
+ static const_pointer_type getImpl(const repr_type& x) {
+ return &x;
+ }
+};
+
} // namespace c10
diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h
index 71b978a..04affff 100644
--- a/c10/core/StorageImpl.h
+++ b/c10/core/StorageImpl.h
@@ -205,6 +205,14 @@
return received_cuda_;
}
+ impl::PyObjectSlot* pyobj_slot() {
+ return &pyobj_slot_;
+ }
+
+ const impl::PyObjectSlot* pyobj_slot() const {
+ return &pyobj_slot_;
+ }
+
private:
DataPtr data_ptr_;
SymInt size_bytes_;
diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp
index 539cf5c..5e6ba98 100644
--- a/c10/core/TensorImpl.cpp
+++ b/c10/core/TensorImpl.cpp
@@ -73,9 +73,7 @@
autograd_meta_->set_fw_grad(new_grad, self, level, is_inplace_op);
}
-TensorImpl::~TensorImpl() {
- pyobj_slot_.destroy_pyobj_if_needed();
-}
+TensorImpl::~TensorImpl() = default;
TensorImpl::TensorImpl(
Storage&& storage,
@@ -582,7 +580,7 @@
if (storage_) {
storage_ = {};
}
- pyobj_slot_.destroy_pyobj_if_needed();
+ pyobj_slot_.maybe_destroy_pyobj();
}
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
diff --git a/c10/core/impl/PyInterpreter.cpp b/c10/core/impl/PyInterpreter.cpp
index 63c432e..81b2894 100644
--- a/c10/core/impl/PyInterpreter.cpp
+++ b/c10/core/impl/PyInterpreter.cpp
@@ -10,7 +10,8 @@
return "<unloaded interpreter>";
}
- void decref(PyObject* pyobj, bool is_tensor) const override {} // do nothing
+ void decref(PyObject* pyobj, bool has_pyobj_slot) const override {
+ } // do nothing
#define PANIC(m) \
TORCH_INTERNAL_ASSERT( \
diff --git a/c10/core/impl/PyInterpreter.h b/c10/core/impl/PyInterpreter.h
index b36b26f..aaf41c1 100644
--- a/c10/core/impl/PyInterpreter.h
+++ b/c10/core/impl/PyInterpreter.h
@@ -127,8 +127,8 @@
virtual std::string name() const = 0;
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
- // See NOTE [PyInterpreter::decref takes an `is_tensor` arg]
- virtual void decref(PyObject* pyobj, bool is_tensor) const = 0;
+ // See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
+ virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;
// Perform a detach by deferring to the __torch_dispatch__ implementation of
// detach, which will also arrange for the PyObject to get copied in this
diff --git a/c10/core/impl/PyObjectSlot.cpp b/c10/core/impl/PyObjectSlot.cpp
index 3fc5670..b0012a6 100644
--- a/c10/core/impl/PyObjectSlot.cpp
+++ b/c10/core/impl/PyObjectSlot.cpp
@@ -5,12 +5,16 @@
PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
-void PyObjectSlot::destroy_pyobj_if_needed() {
+PyObjectSlot::~PyObjectSlot() {
+ maybe_destroy_pyobj();
+}
+
+void PyObjectSlot::maybe_destroy_pyobj() {
if (owns_pyobj()) {
TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr);
TORCH_INTERNAL_ASSERT(pyobj_ != nullptr);
(*pyobj_interpreter_.load(std::memory_order_acquire))
- ->decref(_unchecked_untagged_pyobj(), /*is_tensor*/ true);
+ ->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true);
// NB: this destructor can only be entered when there are no
// references to this C++ object (obviously), NOR any references
// to the PyObject (if there are references to the PyObject,
@@ -47,6 +51,14 @@
(*pyobj_interpreter_.load())->name());
}
+bool PyObjectSlot::check_interpreter(PyInterpreter* interpreter) {
+ return interpreter == pyobj_interpreter();
+}
+
+bool PyObjectSlot::has_pyobj() {
+ return check_pyobj(pyobj_interpreter()).has_value();
+}
+
bool PyObjectSlot::owns_pyobj() {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<uintptr_t>(pyobj_) & 1;
diff --git a/c10/core/impl/PyObjectSlot.h b/c10/core/impl/PyObjectSlot.h
index 3973cb7..2b69113 100644
--- a/c10/core/impl/PyObjectSlot.h
+++ b/c10/core/impl/PyObjectSlot.h
@@ -14,7 +14,9 @@
public:
PyObjectSlot();
- void destroy_pyobj_if_needed();
+ ~PyObjectSlot();
+
+ void maybe_destroy_pyobj();
// Associate the TensorImpl with the specified PyObject, and, if necessary,
// also tag the interpreter.
@@ -118,6 +120,13 @@
PyInterpreter& load_pyobj_interpreter() const;
+ // Check if the PyObjectSlot's interpreter is the same as the specified
+ // interpreter
+ bool check_interpreter(PyInterpreter* interpreter);
+
+ // Check if the PyObjectSlot is holding a PyObject, owned or non-owned
+ bool has_pyobj();
+
bool owns_pyobj();
void set_owns_pyobj(bool b);
diff --git a/test/test_torch.py b/test/test_torch.py
index 89dd3da..a46aed2 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -8658,6 +8658,16 @@
T()
+ def test_storage_base_init(self):
+ # Direct construction not OK
+ self.assertRaises(RuntimeError, lambda: torch._C.StorageBase())
+
+ # But construction of subclass is OK
+ class T(torch._C.StorageBase):
+ pass
+
+ T()
+
def test_tensor_base_new(self):
# OK to call super().__new__, see
@@ -8670,6 +8680,18 @@
x = torch.ones(5)
test_tensor = TestTensor(x)
+ def test_storage_base_new(self):
+
+ # OK to call super().__new__, see
+ # https://github.com/pytorch/pytorch/issues/57421
+ class TestStorage(torch._C.StorageBase):
+ @staticmethod
+ def __new__(cls, x, *args, **kwargs):
+ return super().__new__(cls, x, *args, **kwargs)
+
+ x = torch.UntypedStorage(5)
+ test_storage = TestStorage(x)
+
def test_pyobj_preserved(self):
x = torch.empty(2)
x.foo = 2 # put something on __dict__
@@ -8694,6 +8716,160 @@
del z # it's dead again
self.assertEqual(type(y.grad), MyTensor)
+ @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+ def test_storage_dealloc(self):
+ m, t = Tracker.make()
+ s0 = torch.UntypedStorage(10)
+ s1 = s0
+ s0._tracker = t
+ del t
+
+ self.assertFalse(m[0])
+ del s0
+ self.assertFalse(m[0])
+ del s1
+ self.assertTrue(m[0])
+
+ @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+ def test_storage_from_tensor_dealloc(self):
+ m, t = Tracker.make()
+ a = torch.randn(10)
+ s0 = a.untyped_storage()
+ s0._tracker = t
+ del t
+
+ s1 = a.untyped_storage()
+ self.assertTrue(s0 is s1)
+ self.assertTrue(hasattr(s1, '_tracker'))
+
+ del a
+
+ self.assertFalse(m[0])
+ del s0
+ self.assertFalse(m[0])
+ del s1
+ self.assertTrue(m[0])
+
+ @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+ def test_storage_from_tensor_dealloc_zombie(self):
+ m, t = Tracker.make()
+ a = torch.randn(10)
+ s0 = a.untyped_storage()
+ s0._tracker = t
+ del t
+
+ s1 = a.untyped_storage()
+ self.assertTrue(s0 is s1)
+ self.assertTrue(hasattr(s1, '_tracker'))
+
+ self.assertFalse(m[0])
+ del s0
+ self.assertFalse(m[0])
+ del s1
+ self.assertFalse(m[0])
+ del a
+ self.assertTrue(m[0])
+
+ @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+ def test_storage_from_tensor_dealloc_resurrected(self):
+ m, t = Tracker.make()
+ a = torch.randn(10)
+ s0 = a.untyped_storage()
+ s0._tracker = t
+ del t
+
+ s1 = a.untyped_storage()
+ self.assertTrue(s0 is s1)
+ self.assertTrue(hasattr(s1, '_tracker'))
+
+ self.assertFalse(m[0])
+ del s0
+ self.assertFalse(m[0])
+ del s1
+ self.assertFalse(m[0])
+
+ s0 = a.untyped_storage()
+ self.assertTrue(isinstance(s0, torch.UntypedStorage))
+
+ del a
+ self.assertFalse(m[0])
+ del s0
+ self.assertTrue(m[0])
+
+ @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+ def test_storage_dealloc_resurrected(self):
+ m, t = Tracker.make()
+ s = torch.UntypedStorage(10)
+ s._tracker = t
+ del t
+
+ a = torch.tensor(s)
+ self.assertFalse(m[0])
+ del s
+
+ self.assertFalse(m[0])
+
+ s = a.untyped_storage()
+ self.assertTrue(isinstance(s, torch.UntypedStorage))
+
+ del a
+ self.assertFalse(m[0])
+ del s
+ self.assertTrue(m[0])
+
+ @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+ def test_storage_dealloc_subclass_zombie(self):
+ class MyStorage(torch.UntypedStorage):
+ finalized_count = 0
+
+ def __del__(self):
+ MyStorage.finalized_count += 1
+
+ m, t = Tracker.make()
+ s = MyStorage(10)
+ s._tracker = t
+ del t
+
+ a = torch.tensor(s)
+ self.assertFalse(m[0])
+ del s
+
+ self.assertEqual(MyStorage.finalized_count, 0)
+ self.assertFalse(m[0])
+
+ del a
+ self.assertEqual(MyStorage.finalized_count, 1)
+ self.assertTrue(m[0])
+
+ @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+ def test_storage_dealloc_subclass_resurrected(self):
+ class MyStorage(torch.UntypedStorage):
+ finalized_count = 0
+
+ def __del__(self):
+ MyStorage.finalized_count += 1
+
+ m, t = Tracker.make()
+ s = MyStorage(10)
+ s._tracker = t
+ del t
+
+ a = torch.tensor(s)
+ self.assertFalse(m[0])
+ del s
+
+ self.assertEqual(MyStorage.finalized_count, 0)
+ self.assertFalse(m[0])
+
+ s = a.untyped_storage()
+ del a
+ self.assertFalse(m[0])
+ self.assertEqual(MyStorage.finalized_count, 0)
+ self.assertTrue(isinstance(s, MyStorage))
+ del s
+ self.assertEqual(MyStorage.finalized_count, 1)
+ self.assertTrue(m[0])
+
def test_tensor_slot_dealloc(self):
class SlotTensor1(torch._C._TensorBase):
@@ -8715,6 +8891,27 @@
self.assertTrue(m1[0])
self.assertTrue(m2[0])
+ def test_storage_slot_dealloc(self):
+
+ class SlotStorage1(torch._C.StorageBase):
+ __slots__ = ['slot1']
+
+ class SlotStorage2(SlotStorage1):
+ __slots__ = ['slot2']
+
+ m1, t1 = Tracker.make()
+ m2, t2 = Tracker.make()
+ slot_storage = SlotStorage2(torch.UntypedStorage(2))
+ slot_storage.slot1 = t1
+ slot_storage.slot2 = t2
+ del t1
+ del t2
+ self.assertFalse(m1[0])
+ self.assertFalse(m2[0])
+ del slot_storage
+ self.assertTrue(m1[0])
+ self.assertTrue(m2[0])
+
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_tensor_dict_dealloc(self):
m, t = Tracker.make()
@@ -8725,6 +8922,16 @@
del x
self.assertTrue(m[0])
+ @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
+ def test_storage_dict_dealloc(self):
+ m, t = Tracker.make()
+ x = torch.UntypedStorage(2)
+ x.arf = t
+ del t
+ self.assertFalse(m[0])
+ del x
+ self.assertTrue(m[0])
+
def test_tensor_finalizer_dealloc(self):
m = [False]
@@ -8737,9 +8944,20 @@
del fin_tensor
self.assertTrue(m[0])
+ def test_storage_finalizer_dealloc(self):
+ m = [False]
+
+ class FinalizerStorage(torch._C.StorageBase):
+ def __del__(self):
+ m[0] = True
+
+ fin_storage = FinalizerStorage(torch.UntypedStorage(2))
+ self.assertFalse(m[0])
+ del fin_storage
+ self.assertTrue(m[0])
+
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
def test_tensor_weakref_dealloc(self):
-
x = torch.empty(2)
m = [False]
@@ -8751,6 +8969,20 @@
self.assertTrue(m[0])
self.assertEqual(wref(), None)
+ @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
+ def test_storage_weakref_dealloc(self):
+
+ x = torch.UntypedStorage(2)
+ m = [False]
+
+ def cb(r):
+ m[0] = True
+
+ wref = weakref.ref(x, cb)
+ del x
+ self.assertTrue(m[0])
+ self.assertEqual(wref(), None)
+
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_tensor_cycle_via_dict(self):
m1, t1 = Tracker.make()
@@ -8794,6 +9026,49 @@
self.assertTrue(m1[0])
self.assertTrue(m2[0])
+ @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
+ def test_storage_cycle_via_dict(self):
+ m1, t1 = Tracker.make()
+ x = torch.UntypedStorage(2)
+ x._tracker = t1
+ del t1
+
+ m2, t2 = Tracker.make()
+ y = torch.UntypedStorage(2)
+ y._tracker = t2
+ del t2
+
+ x._loop = y
+ y._loop = x
+
+ # C++ reference should keep the cycle live!
+    # This exercises THPVariable_subtype_traverse
+ # NB: Because z.grad is a reference done entirely in C++, cycles
+ # involving it directly are NOT broken by Python GC; you've
+ # set up a good old C++ reference cycle which we cannot safely
+ # break (because C++ references are allowed to be accessed
+ # multithreaded-ly) (TODO: except maybe if you can prove that
+ # only Python has access to the C++ object, in which case you can
+ # also prove that no multithreaded access occurs)
+ z = torch.UntypedStorage(2)
+ z.grad = x
+
+ del x
+ del y
+
+ gc.collect()
+ self.assertFalse(m1[0])
+ self.assertFalse(m2[0])
+
+ with disable_gc():
+ del z
+ self.assertFalse(m1[0])
+ self.assertFalse(m2[0])
+
+ gc.collect()
+ self.assertTrue(m1[0])
+ self.assertTrue(m2[0])
+
def test_tensor_cycle_via_slots(self):
m1 = [False]
m2 = [False]
@@ -8826,6 +9101,38 @@
self.assertTrue(m1[0])
self.assertTrue(m2[0])
+ def test_storage_cycle_via_slots(self):
+ m1 = [False]
+ m2 = [False]
+
+ class SlotStorage1(torch._C.StorageBase):
+ __slots__ = ['slot1']
+
+ def __del__(self):
+ m1[0] = True
+
+ class SlotStorage2(SlotStorage1):
+ __slots__ = ['slot2']
+
+ def __del__(self):
+ m2[0] = True
+
+ x = SlotStorage1(torch.UntypedStorage(2))
+ y = SlotStorage2(torch.UntypedStorage(2))
+
+ x.slot1 = y
+ y.slot2 = x
+
+ del x
+ with disable_gc():
+ del y
+ self.assertFalse(m1[0])
+ self.assertFalse(m2[0])
+
+ gc.collect()
+ self.assertTrue(m1[0])
+ self.assertTrue(m2[0])
+
# FIXME: move to test_autograd?
@skipIfTorchDynamo("TorchDynamo does not work well with hooks")
def test_backward_hooks_traverse(self):
@@ -8854,7 +9161,7 @@
self.assertTrue(m2[0])
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
- def test_dead_weak_ref(self):
+ def test_tensor_dead_weak_ref(self):
x = torch.empty(2)
w_x = weakref.ref(x)
y = torch.empty(2)
@@ -8870,7 +9177,24 @@
self.assertRaises(RuntimeError, lambda: x.sigmoid())
- def test_resurrected_weak_ref(self):
+ @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
+ def test_storage_dead_weak_ref(self):
+ x = torch.UntypedStorage(2)
+ w_x = weakref.ref(x)
+ y = torch.tensor(x)
+ del x
+
+ x = w_x()
+ # Ideally, x would keep the storage live. But CPython doesn't
+ # provide enough hooks to do this. So it will go dead and x
+        # will transmute into a storage with a null StorageImpl. Not great, but the
+ # best we can do.
+ del y
+
+ self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x[0])
+ self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x.float())
+
+ def test_tensor_resurrected_weak_ref(self):
x = torch.empty(2)
w_x = weakref.ref(x)
y = torch.empty(2)
@@ -8883,8 +9207,20 @@
del y
x.sigmoid()
+ def test_storage_resurrected_weak_ref(self):
+ x = torch.UntypedStorage(2)
+ w_x = weakref.ref(x)
+ y = torch.tensor(x)
+ del x
+
+ x = w_x()
+ # Use this to manually fix weak reference after dereferencing them
+ x._fix_weakref()
+ del y
+ x.float()
+
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
- def test_fix_weakref_no_leak(self):
+ def test_tensor_fix_weakref_no_leak(self):
import weakref
called = False
@@ -8900,6 +9236,23 @@
self.assertTrue(called)
+ @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
+ def test_storage_fix_weakref_no_leak(self):
+ import weakref
+
+ called = False
+
+ a = torch.UntypedStorage(1)
+
+ def callback(w):
+ nonlocal called
+ called = True
+ wa = weakref.ref(a, callback)
+ a._fix_weakref()
+ del a
+
+ self.assertTrue(called)
+
# FIXME: move to test_linalg
@torch.inference_mode()
def test_bmm_multithreaded(self):
diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp
index e62d908..cb85098 100644
--- a/torch/csrc/DynamicTypes.cpp
+++ b/torch/csrc/DynamicTypes.cpp
@@ -30,29 +30,6 @@
std::array<THPLayout*, static_cast<int>(at::Layout::NumOptions)>
layout_registry = {};
-at::DeprecatedTypeProperties* get_type_properties(
- at::DeviceType device_type,
- at::ScalarType scalarType) {
- at::Backend backend;
- if (device_type == at::kCPU) {
- backend = at::Backend::CPU;
- } else if (device_type == at::kCUDA) {
- backend = at::Backend::CUDA;
- } else if (device_type == at::kXPU) {
- backend = at::Backend::XPU;
- } else if (device_type == at::kHPU) {
- backend = at::Backend::HPU;
- } else if (device_type == at::kMPS) {
- backend = at::Backend::MPS;
- } else if (device_type == at::DeviceType::Meta) {
- backend = at::Backend::Undefined;
- } else if (device_type == at::DeviceType::PrivateUse1) {
- backend = at::Backend::PrivateUse1;
- } else {
- TORCH_CHECK(false, "Invalid device for storage: ", device_type);
- }
- return &at::getDeprecatedTypeProperties(backend, scalarType);
-}
} // namespace
void registerDtypeObject(THPDtype* dtype, at::ScalarType scalarType) {
@@ -90,13 +67,10 @@
false,
"python bindings to nullptr storage (e.g., from torch.Tensor._make_wrapper_subclass) are currently unsafe and thus disabled. See https://github.com/pytorch/pytorch/issues/61669 for more details");
}
- PyTypeObject* type = reinterpret_cast<PyTypeObject*>(THPStorageClass);
- auto obj = THPObjectPtr(type->tp_alloc(type, 0));
+ PyObject* obj = THPStorage_Wrap(storage);
if (!obj)
throw python_error();
- ((THPStorage*)obj.get())->cdata =
- c10::MaybeOwned<at::Storage>::owned(at::Storage(/* copy */ storage));
- return obj.release();
+ return obj;
}
PyTypeObject* loadTypedStorageTypeObject() {
@@ -120,15 +94,14 @@
if (PyObject_TypeCheck(obj, getTypedStorageTypeObject())) {
return true;
}
- auto obj_type = Py_TYPE(obj);
-
- return obj_type == reinterpret_cast<PyTypeObject*>(THPStorageClass);
+ return THPStorage_Check(obj);
}
-at::Storage createStorageGetType(
- PyObject* obj,
- at::ScalarType& scalar_type,
- bool& is_typed_storage) {
+std::tuple<at::Storage, at::ScalarType, bool> createStorageGetType(
+ PyObject* obj) {
+ at::ScalarType scalar_type;
+ bool is_typed_storage;
+
is_typed_storage = PyObject_TypeCheck(obj, getTypedStorageTypeObject());
PyObject* untyped_storage_obj;
@@ -138,10 +111,9 @@
// stay nonzero since the `TypedStorage` maintains a reference.
PyObject* dtype_obj = PyObject_GetAttrString(obj, "dtype");
TORCH_INTERNAL_ASSERT(dtype_obj);
- Py_DECREF(dtype_obj);
-
TORCH_INTERNAL_ASSERT(THPDtype_Check(dtype_obj));
scalar_type = reinterpret_cast<THPDtype*>(dtype_obj)->scalar_type;
+ Py_DECREF(dtype_obj);
untyped_storage_obj = PyObject_GetAttrString(obj, "_untyped_storage");
TORCH_INTERNAL_ASSERT(untyped_storage_obj);
@@ -152,22 +124,18 @@
untyped_storage_obj = obj;
}
- if (Py_TYPE(untyped_storage_obj) !=
- reinterpret_cast<PyTypeObject*>(THPStorageClass)) {
- throw TypeError("not a storage '%s'", Py_TYPE(obj)->tp_name);
- }
+ TORCH_CHECK(
+ THPStorage_Check(untyped_storage_obj),
+ "not a storage '",
+ Py_TYPE(obj)->tp_name,
+ "'");
- const auto& storage = THPStorage_Unpack(untyped_storage_obj);
- c10::DeviceType device_type = storage.device().type();
- auto type_properties = get_type_properties(device_type, at::kByte);
- return type_properties->unsafeStorageFromTH(
- storage.unsafeGetStorageImpl(), true);
+ auto storage = THPStorage_Unpack(untyped_storage_obj);
+ return std::make_tuple(storage, scalar_type, is_typed_storage);
}
at::Storage createStorage(PyObject* obj) {
- at::ScalarType scalar_type;
- bool is_typed_storage = false;
- return createStorageGetType(obj, scalar_type, is_typed_storage);
+ return std::get<0>(createStorageGetType(obj));
}
} // namespace torch
diff --git a/torch/csrc/DynamicTypes.h b/torch/csrc/DynamicTypes.h
index 1ca813c..1fd0a9d 100644
--- a/torch/csrc/DynamicTypes.h
+++ b/torch/csrc/DynamicTypes.h
@@ -27,10 +27,8 @@
TORCH_PYTHON_API PyObject* createPyObject(const at::Storage& storage);
at::Storage createStorage(PyObject* obj);
-at::Storage createStorageGetType(
- PyObject* obj,
- at::ScalarType& scalar_type,
- bool& is_typed_storage);
+std::tuple<at::Storage, at::ScalarType, bool> createStorageGetType(
+ PyObject* obj);
bool isStorage(PyObject* obj);
TORCH_PYTHON_API THPDtype* getTHPDtype(at::ScalarType scalarType);
diff --git a/torch/csrc/PyInterpreter.cpp b/torch/csrc/PyInterpreter.cpp
index 709175f..a846a17 100644
--- a/torch/csrc/PyInterpreter.cpp
+++ b/torch/csrc/PyInterpreter.cpp
@@ -35,7 +35,7 @@
: public c10::impl::PyInterpreterVTable {
std::string name() const override;
- void decref(PyObject* pyobj, bool is_tensor) const override;
+ void decref(PyObject* pyobj, bool has_pyobj_slot) const override;
// TODO: Need to make this work for StorageImpl too. I imagine I'll want to
// operate upon a PyObjectSlot rather than a TensorImpl
@@ -188,15 +188,15 @@
TorchFunctionName::TorchDispatch));
}
-// NOTE [PyInterpreter::decref takes an `is_tensor` arg]
+// NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
// Before calling PyInterpreter::decref, we must statically know if the
-// pyobj is a Tensor or not.
-// - If it is a tensor, we need to be careful about PyObject resurrection
-// - If it is not a tensor, we can freely decref
+// pyobj has a PyObjectSlot or not.
+// - If it has a PyObjectSlot, we need to be careful about PyObject resurrection
+// - If it does not have a PyObjectSlot, we can freely decref
// One alternative to this is using PyObject_IsInstance
// to get at this information. However, we don't want to risk an incorrect
// `__instancecheck__` changing the semantics here.
-void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool is_tensor)
+void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
const {
// Leak the pyobj if not initialized. This can happen if we are running
// exit handlers that are destructing tensors with residual (owned)
@@ -206,23 +206,33 @@
pybind11::gil_scoped_acquire gil;
// Two possibilities:
- // 1. We are decref-ing a tensor. Then we must be careful about
- // PyObject resurrection (this only applies to Tensors, see
+ // 1. We are decref-ing an object that has a PyObjectSlot, like a Tensor or
+ // Storage. Then we must be careful about PyObject resurrection (see
// THPVariable_clear).
// 2. We are decref-ing some other Python object. We don't do
// PyObject resurrection on non-Tensors, so we just carry on as usual
- if (is_tensor && Py_REFCNT(pyobj) > 1) {
- // It's still alive! This can happen if a weak ref resurrected
- // the PyObject without flipping ownership. At this point it is
- // too late to rescue the object, so just stub out the PyObject
- // so that it fails on subsequent uses. Don't raise an error here;
- // you're probably in a destructor.
- TORCH_WARN(
- "Deallocating Tensor that still has live PyObject references. "
- "This probably happened because you took out a weak reference to "
- "Tensor and didn't call _fix_weakref() after dereferencing it. "
- "Subsequent accesses to this tensor via the PyObject will now fail.");
- ((THPVariable*)pyobj)->cdata = c10::MaybeOwned<torch::autograd::Variable>();
+ if (has_pyobj_slot && Py_REFCNT(pyobj) > 1) {
+ if (THPVariable_Check(pyobj)) {
+ // It's still alive! This can happen if a weak ref resurrected
+ // the PyObject without flipping ownership. At this point it is
+ // too late to rescue the object, so just stub out the PyObject
+ // so that it fails on subsequent uses. Don't raise an error here;
+ // you're probably in a destructor.
+ TORCH_WARN(
+ "Deallocating Tensor that still has live PyObject references. "
+ "This probably happened because you took out a weak reference to "
+ "Tensor and didn't call _fix_weakref() after dereferencing it. "
+ "Subsequent accesses to this tensor via the PyObject will now fail.");
+ ((THPVariable*)pyobj)->cdata =
+ c10::MaybeOwned<torch::autograd::Variable>();
+ } else if (THPStorage_Check(pyobj)) {
+ TORCH_WARN(
+ "Deallocating UntypedStorage that still has live PyObject references. "
+ "This probably happened because you took out a weak reference to "
+ "UntypedStorage and didn't call _fix_weakref() after dereferencing it. "
+ "Subsequent accesses to this storage via the PyObject will now fail.");
+ ((THPStorage*)pyobj)->cdata = c10::MaybeOwned<c10::Storage>();
+ }
}
Py_DECREF(pyobj);
};
diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp
index dfd35ea..1652d4d 100644
--- a/torch/csrc/Storage.cpp
+++ b/torch/csrc/Storage.cpp
@@ -6,6 +6,7 @@
#include <ATen/mps/MPSDevice.h>
#include <c10/core/CPUAllocator.h>
+#include <c10/core/RefcountedDeleter.h>
#include <libshm.h>
#include <torch/csrc/CudaIPCTypes.h>
#include <torch/csrc/Device.h>
@@ -15,6 +16,7 @@
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/copy_utils.h>
+#include <torch/csrc/utils/pyobject_preservation.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <c10/util/intrusive_ptr.h>
@@ -27,28 +29,231 @@
}
}
-PyObject* THPStorageClass = nullptr;
+PyTypeObject* THPStorageClass = nullptr;
-PyObject* THPStorage_New(c10::Storage storage) {
- PyTypeObject* type = (PyTypeObject*)THPStorageClass;
- PyObject* obj = type->tp_alloc(type, 0);
- if (obj) {
- ((THPStorage*)obj)->cdata =
- c10::MaybeOwned<c10::Storage>::owned(std::move(storage));
+PyObject* THPStorage_NewWithStorage(
+ PyTypeObject* type,
+ c10::Storage _storage,
+ c10::impl::PyInterpreterStatus status,
+ bool allow_preexisting_pyobj) {
+ TORCH_CHECK(
+ PyType_IsSubtype(type, &THPStorageType),
+ "Creating a Storage subclass from a class that does not inherit from ",
+ "Storage is not possible. Make sure your class inherits from Storage.");
+
+ auto maybe_pyobj = _storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
+ getPyInterpreter());
+ if (maybe_pyobj.has_value() && maybe_pyobj.value()) {
+ TORCH_CHECK(
+ allow_preexisting_pyobj,
+ "Creating a new Storage subclass ",
+ type->tp_name,
+ " but the raw Storage object is already associated to a python object ",
+ "of type ",
+ maybe_pyobj.value()->ob_type->tp_name);
+ PyObject* obj = *maybe_pyobj;
+ PyTypeObject* obj_type = Py_TYPE(obj);
+ TORCH_CHECK(
+ obj_type == type || PyType_IsSubtype(obj_type, type),
+ "Creating a new Storage subclass ",
+ type->tp_name,
+ " but the raw Storage object is already associated to a python object ",
+ "of type ",
+ maybe_pyobj.value()->ob_type->tp_name,
+ " which is not a subclass of the "
+ "requested type");
+ return THPStorage_Wrap(std::move(_storage));
}
+
+ PyObject* obj = type->tp_alloc(type, 0);
+ TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object");
+
+ auto s = (THPStorage*)obj;
+
+ new (&s->cdata) c10::MaybeOwned<c10::Storage>();
+
+ s->cdata = c10::MaybeOwned<c10::Storage>::owned(std::move(_storage));
+
+ if (!c10::impl::HermeticPyObjectTLS::get_state()) {
+ const auto& storage = THPStorage_Unpack(s);
+ storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj(
+ getPyInterpreter(), obj, status);
+ }
+
return obj;
}
+// Wraps the c10::Storage with a storage PyObject
+PyObject* THPStorage_Wrap(c10::Storage storage) {
+ c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
+ if (c10::impl::HermeticPyObjectTLS::get_state()) {
+ return THPStorage_NewWithStorage(
+ THPStorageClass,
+ std::move(storage),
+ c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
+ }
+ c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot();
+
+ // If the StorageImpl has a PyObject that is managed by a different
+ // interpreter than the current one, create a new StorageImpl that points to
+ // the same data and then create the Python storage from that.
+ // NOTE: This is only supposed to happen in MultiPy
+ if (pyobj_slot->has_pyobj() &&
+ !pyobj_slot->check_interpreter(getPyInterpreter())) {
+ return THPStorage_NewWithStorage(
+ THPStorageClass,
+ c10::newStorageImplFromRefcountedDataPtr(storage),
+ c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
+ }
+ c10::optional<PyObject*> maybe_pyobj =
+ pyobj_slot->check_pyobj(getPyInterpreter());
+ c10::impl::PyInterpreterStatus status;
+ if (maybe_pyobj.has_value()) {
+ auto obj = *maybe_pyobj;
+ if (obj) {
+ if (pyobj_slot->owns_pyobj()) {
+ pyobj_slot->set_owns_pyobj(false);
+ reinterpret_cast<THPStorage*>(obj)->cdata =
+ c10::MaybeOwned<c10::Storage>::owned(std::move(storage));
+ return obj;
+ } else {
+ Py_INCREF(obj);
+ return obj;
+ }
+ }
+ status = c10::impl::PyInterpreterStatus::TAGGED_BY_US;
+ } else {
+ if (storage.use_count() <= 1) {
+ status = c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED;
+ } else {
+ status = c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED;
+ }
+ }
+ return THPStorage_NewWithStorage(THPStorageClass, std::move(storage), status);
+}
+
+static bool THPStorage_isPreservable(THPStorage* self) {
+ if (self->cdata.unsafeIsBorrowed()) {
+ return false;
+ }
+ auto const& storage = THPStorage_Unpack(self);
+
+ if (storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
+ getPyInterpreter()) != c10::make_optional((PyObject*)self)) {
+ return false;
+ }
+ if (storage.use_count() <= 1) {
+ return false;
+ }
+ return true;
+}
+
+static bool THPStorage_tryPreserve(THPStorage* self) {
+ const auto& storage = THPStorage_Unpack(self);
+
+ if (!THPStorage_isPreservable(self)) {
+ return false;
+ }
+
+ c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
+
+ TORCH_INTERNAL_ASSERT(!storage_impl->pyobj_slot()->owns_pyobj());
+
+ storage_impl->pyobj_slot()->set_owns_pyobj(true);
+ Py_INCREF(self);
+
+ self->cdata = c10::MaybeOwned<c10::Storage>::borrowed(storage);
+ return true;
+}
+
static void THPStorage_subclass_dealloc(PyObject* self) {
THPStorage* _self = (THPStorage*)self;
- // Some subclass of StorageBase are GC-tracked objects even
- // though the base class is not.
+
+ if (THPStorage_tryPreserve(_self)) {
+ return;
+ }
+
+ // Some subclasses of StorageBase could be GC-tracked objects even
+ // though the base class is not
auto* type = Py_TYPE(self);
if (PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) != 0) {
PyObject_GC_UnTrack(self);
}
+
+ bool has_finalizer = type->tp_finalize || type->tp_del;
+
+ if (type->tp_finalize) {
+ PyObject_GC_Track(self);
+ if (PyObject_CallFinalizerFromDealloc(self) < 0) {
+ // The finalizer has resurrected the PyObject and there is a new Python
+ // reference to it, so we can just stop deallocating. Read about
+ // resurrection from `__del__` here:
+ // https://docs.python.org/3/reference/datamodel.html#object.__del__
+ return;
+ }
+ PyObject_GC_UnTrack(self);
+ }
+
+ // base test is unnecessary as THPStorage does not set this
+ if (type->tp_weaklistoffset) {
+ PyObject_ClearWeakRefs(self);
+ }
+
+ if (type->tp_del) {
+ PyObject_GC_Track(self);
+ type->tp_del(self);
+ if (self->ob_refcnt > 0) {
+ // Resurrected (see above comment about resurrection from `__del__`)
+ return;
+ }
+ PyObject_GC_UnTrack(self);
+ }
+
+ if (has_finalizer) {
+ /* New weakrefs could be created during the finalizer call.
+ If this occurs, clear them out without calling their
+ finalizers since they might rely on part of the object
+ being finalized that has already been destroyed. */
+ if (type->tp_weaklistoffset) {
+ /* Modeled after GET_WEAKREFS_LISTPTR() */
+ PyWeakReference** list =
+ (PyWeakReference**)PyObject_GET_WEAKREFS_LISTPTR(self);
+ while (*list)
+ _PyWeakref_ClearRef(*list);
+ }
+ }
+
+ // Clear slots
+ {
+ PyTypeObject* base = type;
+ while (base != &THPStorageType) {
+ if (Py_SIZE(base)) {
+ clear_slots(base, self);
+ }
+ base = base->tp_base;
+ TORCH_INTERNAL_ASSERT(base);
+ }
+ }
+
+ // Clear __dict__
+ if (C10_LIKELY(type->tp_dictoffset)) {
+ PyObject** dictptr = _PyObject_GetDictPtr(self);
+ if (dictptr != nullptr) {
+ PyObject* dict = *dictptr;
+ if (dict != nullptr) {
+ Py_DECREF(dict);
+ *dictptr = nullptr;
+ }
+ }
+ }
+
+ TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type);
+
_self->cdata.~MaybeOwned<c10::Storage>();
Py_TYPE(_self)->tp_free(self);
+
+ TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
+ Py_DECREF(type);
}
c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
@@ -149,32 +354,35 @@
"(): only one or neither of 'allocator' or 'device' can ",
"be given, but not both");
- THPStoragePtr self((THPStorage*)type->tp_alloc(type, 0));
- THPUtils_assert(self, "failed to allocate a " THPStorageStr " object");
+ PyObject* self = nullptr;
c10::Allocator* allocator = nullptr;
// torch.Storage(*, ...)
if (r.idx == 0) {
- self->cdata = c10::MaybeOwned<c10::Storage>::owned(make_storage_impl(
- c10::StorageImpl::use_byte_size_t(),
- 0,
- allocator,
- /*resizable=*/true,
- allocator_opt,
- device_opt));
- return (PyObject*)self.release();
+ self = THPStorage_NewWithStorage(
+ type,
+ make_storage_impl(
+ c10::StorageImpl::use_byte_size_t(),
+ 0,
+ allocator,
+ /*resizable=*/true,
+ allocator_opt,
+ device_opt),
+ c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
// torch.Storage(size, *, ...)
} else if (r.idx == 1) {
int64_t size = r.toInt64(0);
- self->cdata = c10::MaybeOwned<c10::Storage>::owned(make_storage_impl(
- c10::StorageImpl::use_byte_size_t(),
- size,
- allocator,
- /*resizable=*/true,
- allocator_opt,
- device_opt));
- return (PyObject*)self.release();
+ self = THPStorage_NewWithStorage(
+ type,
+ make_storage_impl(
+ c10::StorageImpl::use_byte_size_t(),
+ size,
+ allocator,
+ /*resizable=*/true,
+ allocator_opt,
+ device_opt),
+ c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
// torch.Storage(sequence, *, ...)
} else if (r.idx == 2) {
@@ -190,20 +398,23 @@
THPStorageStr,
"(): Could not obtain the length of sequence of type ",
THPUtils_typename(sequence));
- self->cdata = c10::MaybeOwned<c10::Storage>::owned(make_storage_impl(
- c10::StorageImpl::use_byte_size_t(),
- length,
- allocator,
- /*resizable=*/true,
- allocator_opt,
- device_opt));
+ self = THPStorage_NewWithStorage(
+ type,
+ make_storage_impl(
+ c10::StorageImpl::use_byte_size_t(),
+ length,
+ allocator,
+ /*resizable=*/true,
+ allocator_opt,
+ device_opt),
+ c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
THPObjectPtr item;
try {
+ const auto& storage = THPStorage_Unpack(self);
for (Py_ssize_t i = 0; i < length; i++) {
item = PySequence_GetItem(sequence, i);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint8_t value = THPByteUtils_unpackReal(item.get());
- const auto& storage = THPStorage_Unpack(self);
if (allocator == c10::GetDefaultCPUAllocator()) {
static_cast<uint8_t*>(storage.mutable_data())[i] = value;
} else {
@@ -220,20 +431,22 @@
THPUtils_typename(item.get()));
return nullptr;
}
- return (PyObject*)self.release();
}
+ return self;
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static Py_ssize_t THPStorage_length(THPStorage* self) {
HANDLE_TH_ERRORS
+ THPStorage_assertNotNull(self);
return THPStorage_Unpack(self).nbytes();
END_HANDLE_TH_ERRORS_RET(-1)
}
static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {
HANDLE_TH_ERRORS
+ THPStorage_assertNotNull(self);
const auto& storage = THPStorage_Unpack(self);
/* Integer index */
if (THPUtils_checkLong(index)) {
@@ -289,7 +502,11 @@
old_storage_impl->allocator(),
/* resizable */ false);
- PyObject* _ret = THPStorage_New(std::move(new_storage_impl));
+ PyObject* _ret = THPStorage_NewWithStorage(
+ Py_TYPE(self),
+ std::move(new_storage_impl),
+ c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
+
return _ret;
}
PyErr_Format(
@@ -302,6 +519,7 @@
static int THPStorage_set(THPStorage* self, PyObject* index, PyObject* value) {
HANDLE_TH_ERRORS
+ THPStorage_assertNotNull(self);
if (!THPByteUtils_checkReal(value)) {
THPUtils_setError(
"can only set storage content with a int types, but got "
@@ -450,6 +668,7 @@
static PyObject* THPStorage_device(THPStorage* self, void* unused) {
HANDLE_TH_ERRORS
+ THPStorage_assertNotNull(self);
return THPDevice_New(THPStorage_Unpack(self).device());
END_HANDLE_TH_ERRORS
}
@@ -489,7 +708,17 @@
}
void THPStorage_postInit(PyObject* module) {
- THPStorageClass = PyObject_GetAttrString(module, "UntypedStorage");
+ THPStorageClass =
+ (PyTypeObject*)PyObject_GetAttrString(module, "UntypedStorage");
if (!THPStorageClass)
throw python_error();
}
+
+void THPStorage_assertNotNull(THPStorage* storage) {
+ TORCH_CHECK(
+ THPStorage_Unpack(storage).unsafeGetStorageImpl(), "Got a null Storage");
+}
+
+void THPStorage_assertNotNull(PyObject* obj) {
+ THPStorage_assertNotNull((THPStorage*)obj);
+}
diff --git a/torch/csrc/Storage.h b/torch/csrc/Storage.h
index 794e2dd..2a0911f 100644
--- a/torch/csrc/Storage.h
+++ b/torch/csrc/Storage.h
@@ -5,84 +5,43 @@
#define THPStorageStr "torch.UntypedStorage"
-namespace c10 {
-
-template <>
-struct MaybeOwnedTraits<c10::Storage> {
- using owned_type = c10::Storage;
- using borrow_type = c10::Storage;
-
- static borrow_type createBorrow(const owned_type& from) {
- return borrow_type(from);
- }
-
- static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
- lhs.unsafeReleaseStorageImpl();
- lhs = borrow_type(rhs);
- }
-
- static void destroyBorrow(borrow_type& toDestroy) {
- toDestroy.unsafeReleaseStorageImpl(); // "leak" it, but it was already +0.
- }
-
- static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
- return borrow;
- }
-
- static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
- return &borrow;
- }
-
- static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
- return true;
- }
-};
-
-template <>
-struct ExclusivelyOwnedTraits<c10::Storage> {
- using repr_type = c10::Storage;
- using pointer_type = c10::Storage*;
- using const_pointer_type = const c10::Storage*;
-
- static repr_type nullRepr() {
- return c10::Storage();
- }
-
- template <class... Args>
- static repr_type createInPlace(Args&&... args) {
- return c10::Storage(std::forward<Args>(args)...);
- }
-
- static repr_type moveToRepr(c10::Storage&& x) {
- return std::move(x);
- }
-
- static c10::Storage take(c10::Storage& x) {
- return std::move(x);
- }
-
- static pointer_type getImpl(repr_type& x) {
- return &x;
- }
-
- static const_pointer_type getImpl(const repr_type& x) {
- return &x;
- }
-};
-
-} // namespace c10
-
struct THPStorage {
PyObject_HEAD;
c10::MaybeOwned<c10::Storage> cdata;
};
-TORCH_PYTHON_API PyObject* THPStorage_New(c10::Storage storage);
-extern PyObject* THPStorageClass;
+TORCH_PYTHON_API PyObject* THPStorage_Wrap(c10::Storage storage);
+TORCH_PYTHON_API PyObject* THPStorage_NewWithStorage(
+ PyTypeObject* type,
+ c10::Storage _storage,
+ c10::impl::PyInterpreterStatus status,
+ bool allow_preexisting_pyobj = false);
+extern PyTypeObject* THPStorageClass;
+
+static inline bool THPStorage_CheckTypeExact(PyTypeObject* tp) {
+ return tp == THPStorageClass;
+}
+
+static inline bool THPStorage_CheckExact(PyObject* obj) {
+ return THPStorage_CheckTypeExact(Py_TYPE(obj));
+}
+
+inline bool THPStorage_Check(PyObject* obj) {
+ if (!THPStorageClass)
+ return false;
+
+ const auto result = PyObject_IsInstance(obj, (PyObject*)THPStorageClass);
+ if (result == -1)
+ throw python_error();
+ return result;
+}
bool THPStorage_init(PyObject* module);
void THPStorage_postInit(PyObject* module);
+void THPStorage_assertNotNull(THPStorage* storage);
+void THPStorage_assertNotNull(PyObject* obj);
+
extern PyTypeObject THPStorageType;
inline const c10::Storage& THPStorage_Unpack(THPStorage* storage) {
diff --git a/torch/csrc/StorageMethods.cpp b/torch/csrc/StorageMethods.cpp
index b1c66ae..3bdad7a 100644
--- a/torch/csrc/StorageMethods.cpp
+++ b/torch/csrc/StorageMethods.cpp
@@ -41,6 +41,7 @@
static PyObject* THPStorage_nbytes(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
+ THPStorage_assertNotNull(self);
return py::cast(THPStorage_Unpack(self).sym_nbytes()).release().ptr();
END_HANDLE_TH_ERRORS
}
@@ -58,6 +59,7 @@
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS
+ THPStorage_assertNotNull(self);
at::Storage self_ = torch::createStorage(self);
@@ -82,12 +84,14 @@
static PyObject* THPStorage_elementSize(PyObject* _self, PyObject* noargs) {
HANDLE_TH_ERRORS
+ THPStorage_assertNotNull(_self);
return THPUtils_packInt64(sizeof(uint8_t));
END_HANDLE_TH_ERRORS
}
static PyObject* THPStorage_new(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
+ THPStorage_assertNotNull(self);
c10::Allocator* allocator = THPStorage_Unpack(self).allocator();
auto new_storage = c10::make_intrusive<at::StorageImpl>(
c10::StorageImpl::use_byte_size_t(),
@@ -96,12 +100,13 @@
/*resizable=*/true);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- return THPStorage_New(std::move(new_storage));
+ return THPStorage_Wrap(std::move(new_storage));
END_HANDLE_TH_ERRORS
}
static PyObject* THPStorage_resize_(PyObject* self, PyObject* number_arg) {
HANDLE_TH_ERRORS
+ THPStorage_assertNotNull(self);
const auto& storage = THPStorage_Unpack(self);
THPUtils_assert(
THPUtils_checkLong(number_arg),
@@ -164,6 +169,7 @@
static PyObject* THPStorage_fill_(PyObject* self, PyObject* number_arg) {
HANDLE_TH_ERRORS
+ THPStorage_assertNotNull(self);
const auto& storage = THPStorage_Unpack(self);
THPUtils_assert(
THPByteUtils_checkReal(number_arg),
@@ -360,7 +366,7 @@
}
PyBuffer_Release(&buffer);
- return (PyObject*)THPStorage_New(storage);
+ return THPStorage_Wrap(storage);
END_HANDLE_TH_ERRORS
}
@@ -400,12 +406,16 @@
storage->set_nbytes(actual_nbytes);
}
- return (PyObject*)THPStorage_New(std::move(storage));
+ return THPStorage_NewWithStorage(
+ THPStorageClass,
+ std::move(storage),
+ c10::impl::PyInterpreterStatus::TAGGED_BY_US);
END_HANDLE_TH_ERRORS
}
PyObject* THPStorage_writeFile(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS
+ THPStorage_assertNotNull(self);
const auto& storage = THPStorage_Unpack(self);
PyObject* file = PyTuple_GetItem(args, 0);
bool is_real_file = PyTuple_GetItem(args, 1) == Py_True;
@@ -451,12 +461,13 @@
auto storage = THPStorage_readFileRaw<int>(fd, {}, element_size);
if (!storage.defined())
return nullptr;
- return THPStorage_New(std::move(storage));
+ return THPStorage_Wrap(std::move(storage));
END_HANDLE_TH_ERRORS
}
static PyObject* THPStorage_setFromFile(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS
+ THPStorage_assertNotNull(self);
const auto& storage = THPStorage_Unpack(self);
PyObject* file = PyTuple_GET_ITEM(args, 0);
PyObject* offset = PyTuple_GET_ITEM(args, 1);
@@ -538,6 +549,12 @@
END_HANDLE_TH_ERRORS
}
+static PyObject* THPStorage_fix_weakref(PyObject* self, PyObject* noargs) {
+ const auto& storage = THPStorage_Unpack(self);
+ Py_DECREF(THPStorage_Wrap(storage));
+ Py_RETURN_NONE;
+}
+
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static PyMethodDef THPStorage_methods[] = {
{"copy_",
@@ -565,6 +582,7 @@
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_set_cdata", THPStorage__setCdata, METH_O, nullptr},
+ {"_fix_weakref", THPStorage_fix_weakref, METH_NOARGS, nullptr},
{nullptr}};
PyMethodDef* THPStorage_getMethods() {
diff --git a/torch/csrc/StorageSharing.cpp b/torch/csrc/StorageSharing.cpp
index cae1d0e..c257bf9 100644
--- a/torch/csrc/StorageSharing.cpp
+++ b/torch/csrc/StorageSharing.cpp
@@ -33,6 +33,7 @@
static PyObject* THPStorage_sharedDecref(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
+ THPStorage_assertNotNull(self);
const auto& storage = THPStorage_Unpack(self);
c10::DeviceType device_type = storage.device_type();
if (device_type == at::kCPU) {
@@ -49,6 +50,7 @@
static PyObject* THPStorage_sharedIncref(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
+ THPStorage_assertNotNull(self);
const auto& storage = THPStorage_Unpack(self);
c10::DeviceType device_type = storage.device_type();
if (device_type == at::kCPU) {
@@ -74,17 +76,21 @@
int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_EXCLUSIVE;
std::string handle = at::NewProcessWideShmHandle();
- return THPStorage_New(c10::make_intrusive<at::StorageImpl>(
- c10::StorageImpl::use_byte_size_t(),
- size,
- THManagedMapAllocator::makeDataPtr("", handle.c_str(), flags, size),
- /*allocator=*/nullptr,
- /*resizable=*/false));
+ return THPStorage_NewWithStorage(
+ THPStorageClass,
+ c10::make_intrusive<at::StorageImpl>(
+ c10::StorageImpl::use_byte_size_t(),
+ size,
+ THManagedMapAllocator::makeDataPtr("", handle.c_str(), flags, size),
+ /*allocator=*/nullptr,
+ /*resizable=*/false),
+ c10::impl::PyInterpreterStatus::TAGGED_BY_US);
END_HANDLE_TH_ERRORS
}
static PyObject* THPStorage_shareFilename(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
+ THPStorage_assertNotNull(self);
const auto& storage = THPStorage_Unpack(self);
TORCH_CHECK(
storage.device_type() == at::kCPU,
@@ -165,13 +171,16 @@
const char* object_handle = PyBytes_AS_STRING(_object_handle);
int64_t size = THPUtils_unpackLong(_size);
int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_NOCREATE;
- return THPStorage_New(c10::make_intrusive<at::StorageImpl>(
- c10::StorageImpl::use_byte_size_t(),
- size,
- THManagedMapAllocator::makeDataPtr(
- manager_handle, object_handle, flags, size),
- /*allocator=*/nullptr,
- /*resizable=*/false));
+ return THPStorage_NewWithStorage(
+ THPStorageClass,
+ c10::make_intrusive<at::StorageImpl>(
+ c10::StorageImpl::use_byte_size_t(),
+ size,
+ THManagedMapAllocator::makeDataPtr(
+ manager_handle, object_handle, flags, size),
+ /*allocator=*/nullptr,
+ /*resizable=*/false),
+ c10::impl::PyInterpreterStatus::TAGGED_BY_US);
END_HANDLE_TH_ERRORS
}
@@ -182,12 +191,16 @@
if (!PyArg_ParseTuple(args, "L", &size)) {
return nullptr;
}
- return THPStorage_New(at::new_shm_fd_storage(size));
+ return THPStorage_NewWithStorage(
+ THPStorageClass,
+ at::new_shm_fd_storage(size),
+ c10::impl::PyInterpreterStatus::TAGGED_BY_US);
END_HANDLE_TH_ERRORS
}
static PyObject* THPStorage_shareFd(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
+ THPStorage_assertNotNull(self);
const auto& storage = THPStorage_Unpack(self);
TORCH_CHECK(
storage.device_type() == at::kCPU, "_share_fd_: only available on CPU");
@@ -254,17 +267,22 @@
int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_NOCREATE |
at::ALLOCATOR_MAPPED_KEEPFD | at::ALLOCATOR_MAPPED_FROMFD;
- return THPStorage_New(c10::make_intrusive<at::StorageImpl>(
- c10::StorageImpl::use_byte_size_t(),
- size,
- at::MapAllocator::makeDataPtr(at::WITH_FD, "", fd, flags, size, nullptr),
- /*allocator=*/nullptr,
- /*resizable=*/false));
+ return THPStorage_NewWithStorage(
+ THPStorageClass,
+ c10::make_intrusive<at::StorageImpl>(
+ c10::StorageImpl::use_byte_size_t(),
+ size,
+ at::MapAllocator::makeDataPtr(
+ at::WITH_FD, "", fd, flags, size, nullptr),
+ /*allocator=*/nullptr,
+ /*resizable=*/false),
+ c10::impl::PyInterpreterStatus::TAGGED_BY_US);
END_HANDLE_TH_ERRORS
}
static PyObject* THPStorage_shareCuda(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
+ THPStorage_assertNotNull(self);
#ifdef USE_CUDA
const auto& storage = THPStorage_Unpack(self);
TORCH_CHECK(
@@ -541,7 +559,10 @@
base->set_resizable(false);
base->set_received_cuda(true);
- return THPStorage_New(std::move(base));
+ return THPStorage_NewWithStorage(
+ THPStorageClass,
+ std::move(base),
+ c10::impl::PyInterpreterStatus::TAGGED_BY_US);
#else
TORCH_CHECK(false, "CUDA is not available");
#endif
@@ -566,7 +587,7 @@
THPUtils_checkLong(arg), "_new_with_weak_ptr(): arg must be an 'int'");
c10::StorageImpl* weak_storage = (c10::StorageImpl*)PyLong_AsVoidPtr(arg);
if (auto* storage = c10::raw::weak_intrusive_ptr::lock(weak_storage)) {
- return THPStorage_New(
+ return THPStorage_Wrap(
c10::intrusive_ptr<c10::StorageImpl>::reclaim(storage));
}
Py_RETURN_NONE;
@@ -598,6 +619,7 @@
PyObject* THPStorage_sharedFd(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
+ THPStorage_assertNotNull(self);
at::MapAllocator* ctx = nullptr;
const auto& storage = THPStorage_Unpack(self);
if (storage.device_type() == at::kCPU) {
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index f5ecf37..e30e93a 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -26,6 +26,7 @@
#include <torch/csrc/tensor/python_tensor.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
+#include <torch/csrc/utils/pyobject_preservation.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_dispatch.h>
#include <torch/csrc/utils/python_strings.h>
@@ -1662,26 +1663,6 @@
END_HANDLE_TH_ERRORS
}
-static void clear_slots(PyTypeObject* type, PyObject* self) {
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- Py_ssize_t i, n;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- PyMemberDef* mp;
-
- n = Py_SIZE(type);
- mp = type->tp_members;
- for (i = 0; i < n; i++, mp++) {
- if (mp->type == T_OBJECT_EX && !(mp->flags & READONLY)) {
- char* addr = (char*)self + mp->offset;
- PyObject* obj = *(PyObject**)addr;
- if (obj != nullptr) {
- *(PyObject**)addr = nullptr;
- Py_DECREF(obj);
- }
- }
- }
-}
-
// NB: this is not the tp_dealloc on THPVariable; instead, its the dealloc
// on subclasses. It's never valid to construct a THPVariable so it's not
// necessary to implement the dealloc for that case
diff --git a/torch/csrc/utils/pyobject_preservation.cpp b/torch/csrc/utils/pyobject_preservation.cpp
new file mode 100644
index 0000000..4f2d0a2
--- /dev/null
+++ b/torch/csrc/utils/pyobject_preservation.cpp
@@ -0,0 +1,19 @@
+#include <torch/csrc/utils/pyobject_preservation.h>
+
+#include <structmember.h>
+
+void clear_slots(PyTypeObject* type, PyObject* self) {
+ Py_ssize_t n = Py_SIZE(type);
+ PyMemberDef* mp = type->tp_members;
+
+ for (Py_ssize_t i = 0; i < n; i++, mp++) {
+ if (mp->type == T_OBJECT_EX && !(mp->flags & READONLY)) {
+ char* addr = (char*)self + mp->offset;
+ PyObject* obj = *(PyObject**)addr;
+ if (obj != nullptr) {
+ *(PyObject**)addr = nullptr;
+ Py_DECREF(obj);
+ }
+ }
+ }
+}
diff --git a/torch/csrc/utils/pyobject_preservation.h b/torch/csrc/utils/pyobject_preservation.h
new file mode 100644
index 0000000..456095d
--- /dev/null
+++ b/torch/csrc/utils/pyobject_preservation.h
@@ -0,0 +1,7 @@
+#pragma once
+
+#include <torch/csrc/python_headers.h>
+
+// This file contains utilities used for handling PyObject preservation
+
+void clear_slots(PyTypeObject* type, PyObject* self);
diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h
index ee41d79..68dba42 100644
--- a/torch/csrc/utils/python_arg_parser.h
+++ b/torch/csrc/utils/python_arg_parser.h
@@ -1055,8 +1055,8 @@
is_typed_storage = false;
storage_scalar_type = at::ScalarType::Undefined;
} else {
- storage =
- createStorageGetType(args[i], storage_scalar_type, is_typed_storage);
+ std::tie(storage, storage_scalar_type, is_typed_storage) =
+ createStorageGetType(args[i]);
}
return storage;
}
diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp
index 57388fb..c154404 100644
--- a/torch/csrc/utils/tensor_new.cpp
+++ b/torch/csrc/utils/tensor_new.cpp
@@ -385,10 +385,12 @@
at::tracer::impl::NoTracerDispatchMode tracer_guard;
if (isStorage(data)) {
- ScalarType storage_scalar_type{ScalarType::Undefined};
bool is_typed_storage = false;
- Storage storage =
- createStorageGetType(data, storage_scalar_type, is_typed_storage);
+ ScalarType storage_scalar_type{ScalarType::Undefined};
+ Storage storage;
+ std::tie(storage, storage_scalar_type, is_typed_storage) =
+ createStorageGetType(data);
+
TORCH_CHECK(
!is_typed_storage || storage_scalar_type == scalar_type,
"Expected a Storage of type ",