import torch
import os
import weakref
import multiprocessing
from multiprocessing.reduction import ForkingPickler
import sys
try:
    # Early load resource_sharer to prevent a partially initialized instance
    # from being inherited in a forked child process. The reduce_storage method
    # requires this module indirectly through DupFd(). The built-in mp.Queue
    # class pickles arguments in a background thread which may overlap with the
    # fork.
    import multiprocessing.resource_sharer
except ImportError:
    pass


class StorageRef(object):
    # An object with a cdata field which may be set to None. We subclass object
    # instead of using a dict() to support weak references.
    def __init__(self, ptr):
        self.cdata = ptr


# mapping from handles to StorageRef objects
shared_cache = weakref.WeakValueDictionary()
def rebuild_event(handle):
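    # Called in the receiving process to reconstruct a torch.cuda.Event from
    # the IPC handle produced by reduce_event below.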
    return torch.cuda.Event(_handle=handle)


def reduce_event(event):
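    # Pickles a torch.cuda.Event by sending only its interprocess (IPC) handle.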
    return (rebuild_event, (event.ipc_handle(),))


def rebuild_tensor(cls, storage, metadata):
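    # Called in the receiving process: wraps the already-rebuilt shared storage
    # in a new tensor of the original type with the original offset, size and
    # stride.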
    storage_offset, size, stride = metadata
    new_tensor = cls()
    new_tensor.set_(storage, storage_offset, size, stride)
    return new_tensor


def reduce_tensor(tensor):
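    # Pickles a tensor as (tensor type, storage, view metadata). The storage
    # itself is reduced separately via reduce_storage, which init_reductions
    # registers for all storage classes.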
    metadata = (tensor.storage_offset(), tensor.size(), tensor.stride())
    storage = tensor.storage()
    return (rebuild_tensor, (type(tensor), storage, metadata))


def fd_id(fd):
    # Returns a tuple which uniquely identifies a file descriptor. On macOS,
    # this doesn't work with shared memory handles, which is why we don't
    # support the "file_descriptor" sharing method on that platform.
    stat = os.fstat(fd)
    return (stat.st_ino, stat.st_dev)


def storage_from_cache(cls, key):
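    # Returns a storage of type `cls` backed by an already-shared block if
    # `key` is present in shared_cache, otherwise None.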
    storage_ref = shared_cache.get(key)
    if storage_ref is None:
        return None
    return cls._new_with_weak_ptr(storage_ref)


def rebuild_storage_fd(cls, df, size):
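    # Called in the receiving process to recreate a storage shared through a
    # duplicated file descriptor (the "file_descriptor" sharing strategy).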
    if sys.version_info[0] == 2:
        fd = multiprocessing.reduction.rebuild_handle(df)
    else:
        fd = df.detach()
    try:
        storage = storage_from_cache(cls, fd_id(fd))
        if storage is not None:
            return storage
        storage = cls._new_shared_fd(fd, size)
        shared_cache[fd_id(fd)] = storage._weak_ref(StorageRef)
        return storage
    finally:
        os.close(fd)


def rebuild_storage_filename(cls, manager, handle, size):
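    # Called in the receiving process to recreate a storage shared through a
    # named shared-memory file (the "file_system" strategy); the trailing
    # _shared_decref() releases the extra reference taken by the sender.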
    storage = storage_from_cache(cls, handle)
    if storage is not None:
        return storage._shared_decref()
    storage = cls._new_shared_filename(manager, handle, size)
    shared_cache[handle] = storage._weak_ref(StorageRef)
    return storage._shared_decref()


def rebuild_storage_cuda(cls, device, handle, size, offset, view_size):
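    # Called in the receiving process to map a CUDA allocation exported via a
    # CUDA IPC handle; cache hits are returned as views to avoid re-mapping.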
    storage = storage_from_cache(cls, handle)
    if storage is not None:
        return storage._new_view(offset, view_size)
    torch.cuda._lazy_init()
    storage = cls._new_shared_cuda(device, handle, size, offset, view_size)
    shared_cache[handle] = storage._weak_ref(StorageRef)
    return storage


def reduce_storage(storage):
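    # Pickles a storage by moving it into shared memory (or exporting a CUDA
    # IPC handle) and sending just the handle plus the metadata needed to
    # rebuild it on the other side.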
    from . import get_sharing_strategy
    if storage.is_cuda:
        metadata = storage._share_cuda_()
        cache_key = metadata[1]
        rebuild = rebuild_storage_cuda
    elif get_sharing_strategy() == 'file_system':
        metadata = storage._share_filename_()
        cache_key = metadata[1]
        rebuild = rebuild_storage_filename
        storage._shared_incref()
    else:
        fd, size = storage._share_fd_()
        if sys.version_info[0] == 2:
            df = multiprocessing.reduction.reduce_handle(fd)
        else:
            df = multiprocessing.reduction.DupFd(fd)
        cache_key = fd_id(fd)
        metadata = (df, size)
        rebuild = rebuild_storage_fd

    shared_cache[cache_key] = storage._weak_ref(StorageRef)
    return (rebuild, (type(storage),) + metadata)


def init_reductions():
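    # Registers the reducers above with ForkingPickler so that CUDA events,
    # storages and tensors are shared rather than copied when sent to another
    # process.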
    ForkingPickler.register(torch.cuda.Event, reduce_event)

    for t in torch._storage_classes:
        ForkingPickler.register(t, reduce_storage)

    for t in torch._tensor_classes:
        ForkingPickler.register(t, reduce_tensor)
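
# Illustrative usage sketch (assumes torch.multiprocessing calls
# init_reductions() when it is imported; the names below are hypothetical and
# not part of this module). With the reducers registered, tensors placed on a
# multiprocessing queue are shared, not copied:
#
#     import torch
#     import torch.multiprocessing as mp
#
#     def worker(q):
#         t = q.get()        # rebuilt here via rebuild_storage_* / rebuild_tensor
#         t.add_(1)          # writes to memory shared with the parent process
#
#     if __name__ == '__main__':
#         q = mp.Queue()
#         p = mp.Process(target=worker, args=(q,))
#         p.start()
#         q.put(torch.zeros(4))   # pickled via reduce_tensor / reduce_storage
#         p.join()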