Reduce overhead in CUDAGraph Trees (#98529)
Significantly reduces overhead of constructing Tensors and Storages and checking Storage Liveness. Removes the regression for HF models that I tested and removes 75% of overhead of the extremely overhead bound resnet50 training we have in torchbench. (.91x base commit, 1.02x torchinductor default, 1.16x this PR, 1.25 previous cudagraphs impl).
This PR takes care of all of the lower hanging fruit.
- Computes storage aliasing at record time instead of during at runtime. We no longer need to use a runtime storage cache, and can instead index directly into the existing alias if there is one, or construct a new Storage
- Moves the heavyweight C++ calls into a batch - getting storage weakrefs and constructing tensors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98529
Approved by: https://github.com/jansel, https://github.com/ngimel
diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py
index a0f501c..9736389 100644
--- a/test/inductor/test_cudagraph_trees.py
+++ b/test/inductor/test_cudagraph_trees.py
@@ -43,6 +43,10 @@
)
+def cdata(t):
+ return t.untyped_storage()._cdata
+
+
class TestCase(TorchTestCase):
@classmethod
def setUpClass(cls):
@@ -127,13 +131,15 @@
def get_root_children(self):
return [root.num_descendants() for root in self.get_roots()]
- def cudagraphify_impl(self, *args, **kwargs):
+ def cudagraphify_impl(
+ self, *args, is_inference=True, is_backward=False, **kwargs
+ ):
return tree_cudagraphify_impl(
*args,
**kwargs,
device_index=self.device_idx,
- is_backward=False,
- is_inference=True,
+ is_inference=is_inference,
+ is_backward=is_backward,
)
@staticmethod
@@ -418,7 +424,7 @@
self.assertEqual(all_live_block_count(), 0)
def test_aliased_storage_single_weakref(self):
- @torch.compile
+ @torch.compile(mode="reduce-overhead")
def foo(x):
x = x * 20
x_alias = x[0]
@@ -447,6 +453,47 @@
self.assertFalse(self.get_manager().new_graph_id().id == 0)
+ def test_aliasing_static_ref(self):
+ class Mod(torch.nn.Linear):
+ def forward(self, x):
+ return self.weight.T @ x, self.weight.T, self.weight[0:4]
+
+ m = Mod(10, 10).cuda()
+
+ @torch.compile(mode="reduce-overhead")
+ def foo(mod, x):
+ return mod(x)
+
+ @torch.compile(mode="reduce-overhead")
+ def foo2(x):
+ return x[2:]
+
+ x = torch.rand([10, 10], device="cuda", requires_grad=True)
+ param_c = cdata(m.weight)
+ for _ in range(3):
+ # print("Runnng foo")
+ out1, alias_1, alias_2 = foo(m, x)
+ self.assertEqual(len({param_c, cdata(alias_1), cdata(alias_2)}), 1)
+
+ # print("Runnng foo2")
+ out2 = foo2(out1)
+ out2.sum().backward()
+ self.assertEqual(cdata(out1), cdata(out2))
+
+ def test_aliased_static_parameter(self):
+ inp = torch.rand([20, 20], device="cuda")
+
+ def foo(args):
+ x = args[0]
+ args.clear()
+ return (x[0],)
+
+ foo_cg = self.cudagraphify_impl(foo, [inp], (0,))
+
+ for _ in range(3):
+ out = foo_cg([inp])[0]
+ self.assertEqual(cdata(inp), cdata(out))
+
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_aliased_output_checkpoint(self):
def foo(args):
@@ -580,7 +627,6 @@
return torch.sin(y) * torch.nn.functional.dropout(x, p=0.4)
inp = torch.rand([4, 4], requires_grad=True, device="cuda")
- print("Input ID", id(inp))
out = foo(inp)
out.sum().backward()
diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py
index e1133a9..7358157 100644
--- a/torch/_inductor/cudagraph_trees.py
+++ b/torch/_inductor/cudagraph_trees.py
@@ -324,6 +324,8 @@
Wrapper around a storage weak ref. Will deallocate it upon expiration if invoked.
"""
+ __slots__ = ["ref", "_data_ptr"]
+
storage_ref: Optional[StorageWeakRef]
def __init__(self, inp: Union[Tensor, UntypedStorage]):
@@ -335,8 +337,12 @@
self.ref = StorageWeakRef(stor)
self._data_ptr = stor.data_ptr()
- # NB: only use cdata for debugging, use ref
- self._cdata = self.ref.cdata
+ @classmethod
+ def from_weakref_and_data_ptr(cls, cdata, data_ptr):
+ instance = cls.__new__(cls)
+ instance._data_ptr = data_ptr
+ instance.ref = StorageWeakRef.from_weakref(cdata)
+ return instance
def __call__(self) -> Optional[StorageWeakRefPointer]:
if self.ref is None:
@@ -515,6 +521,43 @@
LevelList = List # levels (distance from root of tree)
+class OutputAliasInfo:
+ pass
+
+
+class UnaliasedStorage(OutputAliasInfo):
+ "Singleton to mark that the graph output constructs a new alias or is None"
+ pass
+
+
+class PersistentStaticStorage(OutputAliasInfo):
+ "Singleton to mark that the graph output storage will be in output_persistent_storage array"
+ pass
+
+
+class AliasesPriorGraphOutput(OutputAliasInfo):
+ "Marks that the graph output aliases an output of a prior graph"
+ __slots__ = ["index"]
+
+ index: PathOutputIndex
+
+ def __init__(self, index):
+ assert isinstance(index, tuple)
+ self.index = index
+
+
+class AliasesNewOutput(OutputAliasInfo):
+ "Marks that the graph output aliases an index in the new, returned outputs"
+
+ __slots__ = ["index"]
+
+ index: int
+
+ def __init__(self, index):
+ assert isinstance(index, int)
+ self.index = index
+
+
class CUDAGraphNode:
"""
A single recording of a function into a CUDA Graph. Recordings of CUDA Graphs share a single memory pool
@@ -572,9 +615,6 @@
# as we execute any tree path. When we retrieve a storage from the cache we
# check that it is still alive, and we hash based on observed recording data ptr
# and storage cdata.
- self.storage_cache: Dict[
- Tuple[StorageDataPtr, NBytes], StorageWeakRefWrapper
- ] = (parent.storage_cache if parent is not None else {})
# we preserve a single reference to executed outputs that is then referenced
# in children to avoid children having to chase parent pointers in the hot path
@@ -611,7 +651,6 @@
# Their locations are static but lifetimes are not. We only include the persistent static
# data ptrs below because the non persistent data ptrs may be outputs of this record and
# fresh allocations.
- self.output_is_alias_of_persistent_static_inputs: OutputList[bool] = []
# precompute expanded dims to avoid computing in the hot path
self.expanded_dims: List[List[int]] = [
@@ -672,7 +711,21 @@
# constructor, otherwise this optimization would not be valid.
# initialized below in _record
+
self.checkpointed_caching_state: Optional[AllocatorState] = None
+
+ # Output Storage Alias information, can be:
+ # - A new, unaliased storage, or the output is None
+ # - An alias of a persistent static input, in which case a storage will be set in the corresponding index
+ # of output_persistent_storage
+ # - An alias of an output of a prior graph
+ # - An alias of an output already created in the reconstructed outputs
+ self.output_storage_alias: OutputList[OutputAliasInfo] = []
+
+ # if an output aliases a static, persistent input then the Storage of the
+ # persistent output will be set here
+ self.output_persistent_storage: OutputList[Optional[UntypedStorage]] = []
+
self.recording_outputs: OutputList[Optional[torch.Tensor]] = self._record(
wrapped_function.model, recording_inputs
)
@@ -714,9 +767,6 @@
if config.triton.slow_cudagraph_asserts:
self.debug_check_invariants_before_invocation()
- if self.parent is None:
- self.storage_cache.clear()
-
assert len(self.static_input_data_ptrs) == len(new_inputs)
# NB: this ranges over non-static inputs too
for idx, data_ptr in enumerate(self.static_input_data_ptrs):
@@ -725,9 +775,6 @@
if data_ptr is not None:
# static input, e.g., parameter
assert data_ptr == new_inputs[idx].data_ptr()
- # TODO - shouldnt need to add this for persistent static inputs
- # since we dont manage the lifetimes of their aliased outputs
- self.add_to_storage_cache(new_inputs[idx].untyped_storage())
else:
# non-static input, need to copy it into CUDA graph
dst = self._reconstruct_from_tensor_metadata(self.inputs_metadata[idx])
@@ -735,18 +782,87 @@
self._copy_input(idx, dst, src)
new_inputs.clear()
- self.graph.replay()
+ self.run_graph()
- outputs = [
- (self._reconstruct_from_tensor_metadata(metadata) if metadata else None)
- for metadata in self.outputs_metadata
- ]
+ outputs = self.reconstruct_outputs()
self._add_replayed_outputs(outputs)
self.debug_check_invariants_after_invocation()
return outputs
+ def reconstruct_outputs(self):
+ "Reconstruct output tensors according to their saved metadata and alias information"
+
+ # The cpp function is constructing a new Tensor according to the saved output metadata
+ # For each element in the corresponding storage list:
+ # - if a Storage is contained, that will be used
+ # - if None is contained, a new Storage will be constructed
+ # - if an int is contained, the storage from the output list at that int will be used
+ storages_info: List[
+ Union[UntypedStorage, None, int]
+ ] = self.prepare_storages_for_construction()
+ outputs_new = []
+
+ # # We recreate the below logic in cpp to reduce overhead, since this is on the hot path
+ """
+ for storage_info, metadata in zip(storages_info, self.outputs_metadata):
+ if metadata is None:
+ outputs_new.append(None)
+ continue
+
+ if isinstance(storage_info, UntypedStorage):
+ s = storage_info
+ if storage_info is None:
+ s = self.create_storage(metadata)
+ else:
+ assert isinstance(storage_info, int)
+ s = outputs_new[storage_info].untyped_storage()
+ outputs_new.append(self._reconstruct_from_tensor_metadata(metadata, storage=s))
+ """
+
+ torch._C._construct_Tensors_From_Storage_and_Metadata(
+ storages_info, self.outputs_metadata, outputs_new
+ )
+
+ return outputs_new
+
+ def prepare_alias_info_for_tensor_construction(
+ self, out_index: int, out_alias_info: OutputAliasInfo, metadata: Dict[str, Any]
+ ) -> List[Union[UntypedStorage, None, int]]:
+ if metadata is None or out_alias_info is UnaliasedStorage:
+ return None
+
+ if out_alias_info is PersistentStaticStorage:
+ return self.output_persistent_storage[out_index]
+
+ if isinstance(out_alias_info, AliasesPriorGraphOutput):
+ depth, existing_output_index = out_alias_info.index
+ ref = self.path_weakrefs[depth][existing_output_index]
+ assert ref()
+ return torch.UntypedStorage._new_with_weak_ptr(ref())
+
+ assert isinstance(out_alias_info, AliasesNewOutput)
+ return out_alias_info.index
+
+ def prepare_storages_for_construction(
+ self,
+ ) -> List[Union[UntypedStorage, None, int]]:
+ output_storages = []
+ for i, (output_storage_alias, metadata) in enumerate(
+ zip(self.output_storage_alias, self.outputs_metadata)
+ ):
+ output_storages.append(
+ self.prepare_alias_info_for_tensor_construction(
+ i, output_storage_alias, metadata
+ )
+ )
+
+ return output_storages
+
+ def run_graph(self):
+ self.graph.replay()
+
def all_outputs_are_dead(self):
"All outputs of the path from this node to its root are dead"
for depth, output_index in self.live_indices_after_graph:
@@ -758,8 +874,8 @@
"Record the model"
# see: output_is_alias_of_persistent_static_inputs above
- static_input_persistent_storage_ptrs: Set[int] = {
- inputs[i].untyped_storage().data_ptr()
+ static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper] = {
+ inputs[i].untyped_storage().data_ptr(): StorageWeakRefWrapper(inputs[i])
for i in self.wrapped_function.static_input_idxs
if not self._is_cuda_graph_recorded_tensor(inputs[i])
}
@@ -794,7 +910,9 @@
return static_outputs
def _add_first_outputs(
- self, outputs, static_input_persistent_storage_ptrs: Set[int]
+ self,
+ outputs,
+ static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper],
):
"Add the outputs from the first invocation of the node and set up metadata"
@@ -807,12 +925,42 @@
self.expected_dead_indices_after_graph = delta
assert len(self.outputs_weakrefs) == 0
+ # index from data pointer to index in outputs
+ output_new_storages_index: Dict[StorageDataPtr, int] = {}
+
for i, o in enumerate(outputs):
- self.output_is_alias_of_persistent_static_inputs.append(
- o is not None
- and o.untyped_storage().data_ptr()
- in static_input_persistent_storage_ptrs
+ if o is None:
+ self.output_storage_alias.append(UnaliasedStorage)
+ self.output_persistent_storage.append(None)
+ continue
+
+ ref = static_input_persistent_storage_ptrs.get(
+ o.untyped_storage().data_ptr(), None
)
+ if ref and ref() is not None:
+ self.output_storage_alias.append(PersistentStaticStorage)
+ self.output_persistent_storage.append(
+ torch.UntypedStorage._new_with_weak_ptr(ref())
+ )
+ continue
+
+ self.output_persistent_storage.append(None)
+
+ path_ref = self._is_alias_of_live_recorded_tensor(o)
+ if path_ref is not None:
+ self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref))
+ continue
+
+ if o.untyped_storage().data_ptr() in output_new_storages_index:
+ self.output_storage_alias.append(
+ AliasesNewOutput(
+ output_new_storages_index[o.untyped_storage().data_ptr()]
+ )
+ )
+ continue
+
+ output_new_storages_index[o.untyped_storage().data_ptr()] = i
+ self.output_storage_alias.append(UnaliasedStorage)
if self.stack_traces is None:
self.stack_traces = [None for _ in range(len(outputs))]
@@ -840,10 +988,27 @@
def _add_replayed_outputs(self, outputs):
self.outputs_weakrefs.clear()
- for out, is_alias in zip(
- outputs, self.output_is_alias_of_persistent_static_inputs
- ):
- self.outputs_weakrefs.append(map_to_ref(out) if not is_alias else None)
+ output_weak_ref_cdatas = []
+ output_data_ptrs = []
+
+ # For output, gets the storage weakref and data_ptr if it is not a static persistent storage alias
+ torch._C._map_Storage_Refs(
+ outputs,
+ self.output_persistent_storage,
+ output_weak_ref_cdatas,
+ output_data_ptrs,
+ )
+ assert len(output_weak_ref_cdatas) == len(output_data_ptrs)
+
+ for ref, data_ptr in zip(output_weak_ref_cdatas, output_data_ptrs):
+ if ref is None:
+ assert data_ptr is None
+ self.outputs_weakrefs.append(None)
+ continue
+
+ self.outputs_weakrefs.append(
+ StorageWeakRefWrapper.from_weakref_and_data_ptr(ref, data_ptr)
+ )
@property
def parent(self):
@@ -878,6 +1043,18 @@
return False
+ def _is_alias_of_live_recorded_tensor(
+ self, t: torch.Tensor
+ ) -> Optional[PathOutputIndex]:
+ for depth, output_refs in enumerate(self.path_weakrefs):
+ for output_index, storage_ref in enumerate(output_refs):
+ if not is_live(storage_ref):
+ continue
+ if storage_ref.data_ptr() == t.untyped_storage().data_ptr():
+ return (depth, output_index)
+
+ return None
+
@staticmethod
def _check_liveness(indices: List[PathOutputIndex], output_refs: List[List[bool]]):
"Check that all of the indices specified are dead references"
@@ -923,7 +1100,6 @@
for i, node in enumerate(self._path_from_root):
assert self.path_weakrefs[i] is node.outputs_weakrefs
- assert self.storage_cache is node.storage_cache
nodes = list(self._path_from_root)
@@ -949,9 +1125,9 @@
live_storage_data_ptrs.add(stor_data_ptr)
live_storage_weak_ptrs.add(stor_weak_ptr)
- is_persistent_alias = nodes[
- depth
- ].output_is_alias_of_persistent_static_inputs[output_idx]
+ is_persistent_alias = (
+ nodes[depth].output_persistent_storage[output_idx] is not None
+ )
if is_persistent_alias:
assert stor_data_ptr not in live_blocks
@@ -1007,7 +1183,6 @@
"Clear the output lists of all nodes in the path and the storage cache"
for li in self.path_weakrefs:
li.clear()
- self.storage_cache.clear()
@staticmethod
def _tensor_metadata(x, ignore_storage_offset=True):
@@ -1022,13 +1197,12 @@
"dtype": x.dtype,
"device": x.device,
"storage_offset": x.storage_offset() if not ignore_storage_offset else 0,
- # ref_cdata was weak pointer of storage observed during recording, it may be
- # different upon execution
- "ref_cdata": x.untyped_storage()._cdata,
}
- def _reconstruct_from_tensor_metadata(self, metadata: Dict[str, Any]) -> Tensor:
- s = self.get_or_create_storage(metadata)
+ def _reconstruct_from_tensor_metadata(
+ self, metadata: Dict[str, Any], storage=None
+ ) -> Tensor:
+ s = self.create_storage(metadata) if storage is None else storage
t = torch.empty([0], device=metadata["device"], dtype=metadata["dtype"])
t.set_(
source=s,
@@ -1038,23 +1212,10 @@
)
return t
- def add_to_storage_cache(self, untyped_storage: UntypedStorage):
- self.storage_cache[
- (untyped_storage.data_ptr(), untyped_storage.nbytes())
- ] = StorageWeakRefWrapper(untyped_storage)
-
- def get_or_create_storage(self, metadata):
- storage_wrapper = self.storage_cache.get(
- (metadata["data_ptr"], metadata["nbytes"]), None
+ def create_storage(self, metadata):
+ return torch._C._construct_storage_from_data_pointer(
+ metadata["data_ptr"], metadata["device"], metadata["nbytes"]
)
- if storage_wrapper is None or not storage_wrapper():
- s = torch._C._construct_storage_from_data_pointer(
- metadata["data_ptr"], metadata["device"], metadata["nbytes"]
- )
- self.add_to_storage_cache(s)
- else:
- s = torch.UntypedStorage._new_with_weak_ptr(storage_wrapper())
- return s
def _allocate_and_copy_recording_inputs(self, inputs):
"""
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index de52ce7..6787fc0 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -1,7 +1,13 @@
#include <ATen/ATen.h>
+#include <ATen/core/TensorBody.h>
#include <ATen/cuda/CUDAConfig.h>
+#include <c10/core/Device.h>
+#include <c10/core/TensorImpl.h>
#include <c10/util/UniqueVoidPtr.h>
+#include <pybind11/pytypes.h>
+#include <torch/csrc/utils/python_arg_parser.h>
#include <unordered_set>
+
#if AT_CUDNN_ENABLED()
#include <ATen/native/cudnn/Macros.h>
@@ -1060,6 +1066,79 @@
});
m.def(
+ "_map_Storage_Refs",
+ [](const py::sequence& outputs,
+ const py::list& outputs_persistent_storage,
+ py::list output_refs,
+ py::list output_data_ptrs) {
+ TORCH_CHECK(outputs.size() == outputs_persistent_storage.size());
+
+ for (size_t i = 0, end = outputs.size(); i < end; ++i) {
+ if (!outputs_persistent_storage[i].is_none() ||
+ outputs[i].is_none()) {
+ output_refs.append(py::none());
+ output_data_ptrs.append(py::none());
+ continue;
+ }
+
+ auto t = outputs[i].cast<at::Tensor>();
+ c10::StorageImpl* storage = t.storage().unsafeGetStorageImpl();
+ auto weak = c10::raw::intrusive_ptr::make_weak(storage);
+ output_refs.append(reinterpret_cast<size_t>(weak));
+ output_data_ptrs.append(
+ reinterpret_cast<size_t>(storage->data_ptr().get()));
+ }
+ });
+
+ m.def(
+ "_construct_Tensors_From_Storage_and_Metadata",
+ [](const py::list& storages,
+ const py::list& metadatas,
+ py::list& outputs) {
+ TORCH_CHECK(storages.size() == metadatas.size());
+ for (size_t i = 0, end = storages.size(); i < end; ++i) {
+ const auto& maybe_metadata = metadatas[i];
+
+ if (maybe_metadata.is_none()) {
+ outputs.append(py::none());
+ continue;
+ }
+
+ const py::dict& metadata = maybe_metadata.cast<py::dict>();
+ c10::Storage s;
+ if (storages[i].is_none()) {
+ s = c10::Storage(
+ c10::Storage::use_byte_size_t(),
+ metadata["nbytes"].cast<int64_t>(),
+ at::DataPtr(
+ reinterpret_cast<void*>(
+ metadata["data_ptr"].cast<size_t>()),
+ metadata["device"].cast<c10::Device>()));
+ } else if (py::isinstance<py::int_>(storages[i])) {
+ s = outputs[storages[i].cast<int64_t>()]
+ .cast<at::Tensor>()
+ .storage();
+ } else {
+ s = storages[i].cast<c10::Storage>();
+ }
+
+ auto dtype_arg = metadata["dtype"].ptr();
+ auto meta = scalarTypeToTypeMeta(toScalarType(dtype_arg));
+
+ constexpr c10::DispatchKeySet cuda_dks(c10::DispatchKey::CUDA);
+ at::Tensor tensor = at::detail::make_tensor_base<c10::TensorImpl>(
+ std::move(s), cuda_dks, meta);
+
+ tensor.unsafeGetTensorImpl()->set_sizes_and_strides(
+ metadata["size"].cast<std::vector<int64_t>>(),
+ metadata["stride"].cast<std::vector<int64_t>>());
+ tensor.unsafeGetTensorImpl()->set_storage_offset(
+ metadata["storage_offset"].cast<int64_t>());
+ outputs.append(std::move(tensor));
+ }
+ });
+
+ m.def(
"_cuda_beginAllocateCurrentStreamToPool",
[](int device, at::cuda::MempoolId_t mempool_id) {
auto stream = at::cuda::getCurrentCUDAStream(device);
diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h
index c8c70c0..5c01442 100644
--- a/torch/csrc/utils/python_arg_parser.h
+++ b/torch/csrc/utils/python_arg_parser.h
@@ -700,14 +700,7 @@
return scalartype(i);
}
-inline at::ScalarType PythonArgs::scalartype(int i) {
- if (!args[i]) {
- auto scalartype = signature.params[i].default_scalartype;
- return (scalartype == at::ScalarType::Undefined)
- ? torch::tensors::get_default_scalar_type()
- : scalartype;
- }
- PyObject* obj = args[i];
+inline at::ScalarType toScalarType(PyObject* obj) {
if (obj == (PyObject*)&PyFloat_Type) {
return at::ScalarType::Double;
}
@@ -720,6 +713,17 @@
return reinterpret_cast<THPDtype*>(obj)->scalar_type;
}
+inline at::ScalarType PythonArgs::scalartype(int i) {
+ if (!args[i]) {
+ auto scalartype = signature.params[i].default_scalartype;
+ return (scalartype == at::ScalarType::Undefined)
+ ? torch::tensors::get_default_scalar_type()
+ : scalartype;
+ }
+ PyObject* obj = args[i];
+ return toScalarType(obj);
+}
+
inline c10::optional<at::ScalarType> PythonArgs::scalartypeOptional(int i) {
if (!args[i])
return c10::nullopt;
diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py
index 6389fc9..7fbde65 100644
--- a/torch/multiprocessing/reductions.py
+++ b/torch/multiprocessing/reductions.py
@@ -25,12 +25,21 @@
The cdata member is a Python number containing the integer representation of
the Storage pointer."""
+ __slots__ = ["cdata", "_free_weak_ref"]
+
def __init__(self, storage):
self.cdata = storage._weak_ref()
# Save a direct reference to _free_weak_ref because the `torch` module
# might be cleared during Python shutdown before this module is cleared.
self._free_weak_ref = torch.Storage._free_weak_ref # type: ignore[attr-defined]
+ @classmethod
+ def from_weakref(cls, cdata):
+ instance = cls.__new__(cls)
+ instance.cdata = cdata
+ instance._free_weak_ref = torch.Storage._free_weak_ref # type: ignore[attr-defined]
+ return instance
+
def expired(self):
return torch.Storage._expired(self.cdata) # type: ignore[attr-defined]