import os
import sys
import tempfile
import tarfile
import shutil
import struct
from contextlib import closing, contextmanager

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle

import torch
from ._utils import _import_dotted_name

DEFAULT_PROTOCOL = 2

LONG_SIZE = struct.Struct('=l').size
INT_SIZE = struct.Struct('=i').size
SHORT_SIZE = struct.Struct('=h').size
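
# These native C type sizes are recorded by save() in the archive's
# 'sys_info' entry, so a reader can detect binary incompatibility with the
# system that wrote the file.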


def _add_to_tar(fn, tar_file, name):
    # Run fn to write its data into a temporary file, then add that file to
    # the archive under the given name.
    tmp_file = tempfile.NamedTemporaryFile(delete=False)
    fn(tmp_file)
    tmp_file.close()
    tar_file.add(tmp_file.name, arcname=name)
    if os.path.isfile(tmp_file.name):
        os.remove(tmp_file.name)


@contextmanager
def mkdtemp():
    # Temporary directory that is cleaned up even if the body raises.
    path = tempfile.mkdtemp()
    try:
        yield path
    finally:
        shutil.rmtree(path)


_package_registry = []


def register_package(priority, tagger, deserializer):
    # The registry is sorted by priority; lower values are consulted first.
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    _package_registry.sort()


def _cpu_tag(obj):
    if type(obj).__module__ == 'torch':
        return 'cpu'


def _cuda_tag(obj):
    if type(obj).__module__ == 'torch.cuda':
        return 'cuda:' + str(obj.get_device())


def _cpu_deserialize(obj, location):
    if location == 'cpu':
        return obj


def _cuda_deserialize(obj, location):
    if location.startswith('cuda'):
        device_id = max(int(location[5:]), 0)
        return obj.cuda(device_id)


register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, _cuda_tag, _cuda_deserialize)
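
# A sketch of how a user extension could plug into this registry. The 'xpu'
# tag and both helpers below are hypothetical; they only illustrate the
# (priority, tagger, deserializer) contract:
#
#   def _xpu_tag(obj):
#       if type(obj).__module__ == 'torch.xpu':
#           return 'xpu'
#
#   def _xpu_deserialize(obj, location):
#       if location == 'xpu':
#           return obj  # a real extension would move obj to its device here
#
#   register_package(30, _xpu_tag, _xpu_deserialize)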


def location_tag(storage):
    for _, tagger, _ in _package_registry:
        location = tagger(storage)
        if location:
            return location
    raise RuntimeError("don't know how to determine data location of " +
                       torch.typename(storage))


def default_restore_location(storage, location):
    for _, _, fn in _package_registry:
        result = fn(storage, location)
        if result is not None:
            return result
    raise RuntimeError("don't know how to restore data location of " +
                       torch.typename(storage) + " (tagged with " +
                       location + ")")


def normalize_storage_type(storage_type):
    return getattr(torch, storage_type.__name__)


def storage_to_tensor_type(storage):
    storage_type = type(storage)
    module = _import_dotted_name(storage_type.__module__)
    return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
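
# For example, storage_to_tensor_type maps a torch.FloatStorage to
# torch.FloatTensor, and a torch.cuda.FloatStorage to torch.cuda.FloatTensor.

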
# TODO: choose pickle protocol
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
"""Saves an object to a disk file.
Args:
obj: saved object
f: a file-like object (has to implement fileno that returns a file descriptor)
pickle_module: module used for pickling metadata and objects
pickle_protocol: can be specified to override the default protocol
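
    Example:
        >>> # illustrative usage; 'tensor.pt' is a hypothetical path
        >>> x = torch.FloatTensor(2, 3).zero_()
        >>> with open('tensor.pt', 'wb') as f:
        ...     torch.save(x, f)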
"""
serialized_tensors = {}
serialized_storages = {}
def persistent_id(obj):
if torch.is_tensor(obj):
serialized_tensors[obj._cdata] = obj
return str(obj._cdata)
elif torch.is_storage(obj):
serialized_storages[obj._cdata] = obj
return str(obj._cdata)
return None
def save_tensors(f):
pickle_module.dump(len(serialized_tensors), f, protocol=pickle_protocol)
for key, tensor in serialized_tensors.items():
storage = tensor.storage()
if storage is not None:
storage_id = storage._cdata
serialized_storages[storage_id] = storage
else:
storage_id = None
pickle_module.dump((key, storage_id, type(tensor)), f,
protocol=pickle_protocol)
f.flush()
tensor._write_metadata(f)
    def save_storages(f):
        storage_views = []
        storage_views_roots = {}

        # Storages that are views into a larger storage are recorded as
        # (view, root, offset, size) tuples; only the root's data is written.
        for key, storage in serialized_storages.items():
            root, offset = storage._root_storage()
            if root is not storage:
                storage_views_roots[root._cdata] = root
                storage_views.append((storage._cdata, root._cdata, offset,
                                      storage.size()))
        for view_info in storage_views:
            del serialized_storages[view_info[0]]
        serialized_storages.update(storage_views_roots)

        pickle_module.dump(len(serialized_storages), f, protocol=pickle_protocol)
        for key, storage in serialized_storages.items():
            location = location_tag(storage)
            storage_type = normalize_storage_type(type(storage))
            pickle_module.dump((key, location, storage_type), f,
                               protocol=pickle_protocol)
            f.flush()
            storage._write_file(f)

        pickle_module.dump(storage_views, f, protocol=pickle_protocol)
    def pickle_objects(f):
        pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
        pickler.persistent_id = persistent_id
        pickler.dump(obj)

    def save_sys_info(f):
        sys_info = dict(
            protocol_version=1000,
            little_endian=sys.byteorder == 'little',
            type_sizes=dict(
                short=SHORT_SIZE,
                int=INT_SIZE,
                long=LONG_SIZE,
            ),
        )
        pickle_module.dump(sys_info, f, protocol=pickle_protocol)

    with closing(tarfile.open(fileobj=f, mode='w:', format=tarfile.PAX_FORMAT)) as tar:
        _add_to_tar(save_sys_info, tar, 'sys_info')
        _add_to_tar(pickle_objects, tar, 'pickle')
        _add_to_tar(save_tensors, tar, 'tensors')
        _add_to_tar(save_storages, tar, 'storages')


def load(f, map_location=None, pickle_module=pickle):
"""Loads an object saved with torch.save from a disk file.
torch.load can dynamically remap storages to be loaded on a different device
using the map_location argument. If it's a callable, it will be called with
two arguments: storage and location tag. It's expected to either return a
storage that's been moved to a different location, or None (and the location
will be resolved using the default method). If this argument is a dict it's
expected to be a mapping from location tags used in a file, to location
tags of the current system.
By default the location tags are 'cpu' for host tensors and 'cuda:device_id'
(e.g. 'cuda:2') for cuda tensors. User extensions can register their own
tagging and deserialization methods using register_package.
Args:
f: a file-like object (has to implement fileno that returns a file descriptor)
map_location: a function or a dict specifying how to remap storage locations
pickle_module: module used for unpickling metadata and objects (has to match
the pickle_module used to serialize file)
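
    Example:
        >>> # illustrative; 'tensor.pt' is a hypothetical path
        >>> with open('tensor.pt', 'rb') as f:
        ...     x = torch.load(f)
        >>> # load everything onto the CPU, regardless of where it was saved
        >>> with open('tensor.pt', 'rb') as f:
        ...     x = torch.load(f, map_location=lambda storage, loc: storage)
        >>> # remap storages tagged 'cuda:1' to 'cuda:0'
        >>> with open('tensor.pt', 'rb') as f:
        ...     x = torch.load(f, map_location={'cuda:1': 'cuda:0'})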
"""
deserialized_objects = {}
if map_location is None:
restore_location = default_restore_location
elif isinstance(map_location, dict):
def restore_location(storage, location):
location = map_location.get(location, location)
return default_restore_location(storage, location)
else:
def restore_location(storage, location):
result = map_location(storage, location)
if not result:
result = default_restore_location(storage, location)
return result
def persistent_load(saved_id):
return deserialized_objects[int(saved_id)]
    with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
            mkdtemp() as tmpdir:
        # Restore storages first; view storages are then reconstructed as
        # slices of their root storage.
        tar.extract('storages', path=tmpdir)
        with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
            num_storages = pickle_module.load(f)
            for i in range(num_storages):
                args = pickle_module.load(f)
                key, location, storage_type = args
                obj = storage_type._new_with_file(f)
                obj = restore_location(obj, location)
                deserialized_objects[key] = obj

            storage_views = pickle_module.load(f)
            for target_cdata, root_cdata, offset, size in storage_views:
                root = deserialized_objects[root_cdata]
                deserialized_objects[target_cdata] = root[offset:offset + size]

        tar.extract('tensors', path=tmpdir)
        with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
            num_tensors = pickle_module.load(f)
            for i in range(num_tensors):
                args = pickle_module.load(f)
                key, storage_id, original_tensor_type = args
                storage = deserialized_objects.get(storage_id, None)
                if storage is not None:
                    tensor_type = storage_to_tensor_type(storage)
                    tensor = tensor_type._new_with_metadata_file(f, storage)
                else:
                    tensor = original_tensor_type._new_with_metadata_file(f, storage)
                deserialized_objects[key] = tensor

        # The main pickle stream references the tensors and storages restored
        # above through persistent_load.
        pickle_file = tar.extractfile('pickle')
        unpickler = pickle_module.Unpickler(pickle_file)
        unpickler.persistent_load = persistent_load
        result = unpickler.load()
        return result
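

# A minimal round-trip sketch (illustrative only; 'checkpoint.tar' is a
# hypothetical path):
#
#   x = torch.FloatTensor(3, 3).fill_(1)
#   with open('checkpoint.tar', 'wb') as f:
#       torch.save(x, f)
#   with open('checkpoint.tar', 'rb') as f:
#       y = torch.load(f)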