import difflib
import inspect
import os
import shutil
import struct
import sys
import torch
import tarfile
import tempfile
import warnings
from contextlib import closing, contextmanager
from ._utils import _import_dotted_name

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle
DEFAULT_PROTOCOL = 2
LONG_SIZE = struct.Struct('=l').size
INT_SIZE = struct.Struct('=i').size
SHORT_SIZE = struct.Struct('=h').size
MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
PROTOCOL_VERSION = 1001
STORAGE_KEY_SEPARATOR = ','

class SourceChangeWarning(Warning):
    pass

@contextmanager
def mkdtemp():
    path = tempfile.mkdtemp()
    try:
        yield path
    finally:
        # remove the temporary directory even if the body raised
        shutil.rmtree(path)
_package_registry = []

def register_package(priority, tagger, deserializer):
    # keep the registry sorted, so lower priority values are tried first
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    _package_registry.sort()

def _cpu_tag(obj):
    if type(obj).__module__ == 'torch':
        return 'cpu'

def _cuda_tag(obj):
    if type(obj).__module__ == 'torch.cuda':
        return 'cuda:' + str(obj.get_device())

def _cpu_deserialize(obj, location):
    if location == 'cpu':
        return obj

def _cuda_deserialize(obj, location):
    if location.startswith('cuda'):
        device_id = max(int(location[5:]), 0)
        return obj.cuda(device_id)
register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, _cuda_tag, _cuda_deserialize)
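
# A hedged, illustrative sketch of registering a user extension. The
# 'mydevice' tag, the torch.mydevice module, and the ``obj.mydevice()`` call
# below are hypothetical, not part of torch:
#
#     def _mydevice_tag(obj):
#         if type(obj).__module__ == 'torch.mydevice':
#             return 'mydevice'
#
#     def _mydevice_deserialize(obj, location):
#         if location == 'mydevice':
#             return obj.mydevice()
#
#     register_package(30, _mydevice_tag, _mydevice_deserialize)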

def location_tag(storage):
    for _, tagger, _ in _package_registry:
        location = tagger(storage)
        if location:
            return location
    raise RuntimeError("don't know how to determine data location of " +
                       torch.typename(storage))

def default_restore_location(storage, location):
    for _, _, fn in _package_registry:
        result = fn(storage, location)
        if result is not None:
            return result
    raise RuntimeError("don't know how to restore data location of " +
                       torch.typename(storage) + " (tagged with " +
                       location + ")")

def normalize_storage_type(storage_type):
    return getattr(torch, storage_type.__name__)

def storage_to_tensor_type(storage):
    storage_type = type(storage)
    module = _import_dotted_name(storage_type.__module__)
    return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))

def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
    """Saves an object to a disk file.

    See also: :ref:`recommend-saving-models`

    Args:
        obj: saved object
        f: a file-like object (has to implement fileno that returns a file
            descriptor) or a string containing a file name
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol
    """
    new_fd = False
    if isinstance(f, str) or (sys.version_info[0] == 2 and isinstance(f, unicode)):
        new_fd = True
        f = open(f, "wb")
    try:
        return _save(obj, f, pickle_module, pickle_protocol)
    finally:
        if new_fd:
            f.close()
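
# A minimal usage sketch for ``save`` (the ``net`` model and the 'model.pt'
# path are hypothetical, not defined in this module):
#
#     net = torch.nn.Linear(10, 2)
#     torch.save(net.state_dict(), 'model.pt')  # save just the parameters
#     torch.save(net, 'model.pt')               # or pickle the whole module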

def _save(obj, f, pickle_module, pickle_protocol):
    import torch.nn as nn
    serialized_container_types = {}
    serialized_storages = {}

    def persistent_id(obj):
        # FIXME: the docs say that persistent_id should only return a string,
        # but torch's persistent ids are tuples. This works only in the
        # binary protocol; see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_file = inspect.getsourcefile(obj)
                source = inspect.getsource(obj)
            except Exception:  # saving the source is optional, so ignore any errors
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)
        elif torch.is_storage(obj):
            storage_type = normalize_storage_type(type(obj))
            root, offset = obj._root_storage()
            root_key = str(root._cdata)
            location = location_tag(obj)
            serialized_storages[root_key] = root
            is_view = obj._cdata != root._cdata
            if is_view:
                view_metadata = (str(obj._cdata), offset, obj.size())
            else:
                view_metadata = None
            return ('storage',
                    storage_type,
                    root_key,
                    location,
                    root.size(),
                    view_metadata)
        return None
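
    # Informally, the persistent ids produced above take two shapes (a
    # description of the code above, not a stable public format):
    #   ('module', <class>, <source file>, <source text>) for nn.Module
    #       subclasses, so their source can be re-checked at load time;
    #   ('storage', <storage type>, <root key>, <location tag>, <root size>,
    #       <view metadata or None>) for storages. Only root storages are
    #       written out; views are rebuilt from (view key, offset, size).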

    sys_info = dict(
        protocol_version=PROTOCOL_VERSION,
        little_endian=sys.byteorder == 'little',
        type_sizes=dict(
            short=SHORT_SIZE,
            int=INT_SIZE,
            long=LONG_SIZE,
        ),
    )

    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)

    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
    f.flush()
    for key in serialized_storage_keys:
        serialized_storages[key]._write_file(f)
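
# For reference, the byte stream written by ``_save`` is, in order: the
# magic number, the protocol version, ``sys_info``, the pickled ``obj``
# (with storages and container classes replaced by persistent ids), the
# sorted list of storage keys, and finally the raw data of each root
# storage in key order.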

def load(f, map_location=None, pickle_module=pickle):
    """Loads an object saved with :func:`torch.save` from a file.

    torch.load can dynamically remap storages to be loaded on a different
    device using the map_location argument. If it's a callable, it will be
    called with two arguments: storage and location tag. It's expected to
    either return a storage that's been moved to a different location, or
    None (in which case the location will be resolved using the default
    method). If this argument is a dict, it's expected to be a mapping from
    location tags used in the file to location tags of the current system.

    By default the location tags are 'cpu' for host tensors and
    'cuda:device_id' (e.g. 'cuda:2') for cuda tensors. User extensions can
    register their own tagging and deserialization methods using
    register_package.

    Args:
        f: a file-like object (has to implement fileno that returns a file
            descriptor, and must implement seek), or a string containing a
            file name
        map_location: a function or a dict specifying how to remap storage
            locations
        pickle_module: module used for unpickling metadata and objects (has
            to match the pickle_module used to serialize the file)

    Example:
        >>> torch.load('tensors.pt')
        # Load all tensors onto the CPU
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
        # Map tensors from GPU 1 to GPU 0
        >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
    """
    new_fd = False
    if isinstance(f, str) or (sys.version_info[0] == 2 and isinstance(f, unicode)):
        new_fd = True
        f = open(f, 'rb')
    try:
        return _load(f, map_location, pickle_module)
    finally:
        if new_fd:
            f.close()
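
# A hedged sketch of a callable ``map_location`` beyond the docstring
# examples: returning None for a storage defers to default_restore_location.
# The helper name ``to_gpu0`` is hypothetical:
#
#     def to_gpu0(storage, location):
#         if location.startswith('cuda'):
#             return storage.cuda(0)
#         # falling through (None) keeps the default handling, e.g. 'cpu'
#
#     result = torch.load('tensors.pt', map_location=to_gpu0)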

def _load(f, map_location, pickle_module):
    deserialized_objects = {}

    if map_location is None:
        restore_location = default_restore_location
    elif isinstance(map_location, dict):
        def restore_location(storage, location):
            location = map_location.get(location, location)
            return default_restore_location(storage, location)
    else:
        def restore_location(storage, location):
            result = map_location(storage, location)
            if result is None:
                result = default_restore_location(storage, location)
            return result

    def _check_container_source(container_type, source_file, original_source):
        current_source = inspect.getsource(container_type)
        if original_source != current_source:
            if container_type.dump_patches:
                file_name = container_type.__name__ + '.patch'
                diff = difflib.unified_diff(current_source.split('\n'),
                                            original_source.split('\n'),
                                            source_file,
                                            source_file, lineterm="")
                lines = '\n'.join(diff)
                try:
                    with open(file_name, 'a+') as f:
                        f.seek(0, 2)
                        file_size = f.tell()  # file.seek() returns None on Python 2
                        f.seek(0)
                        if file_size == 0:
                            f.write(lines)
                        elif file_size != len(lines) or f.read() != lines:
                            raise IOError
                    msg = ("Saved a reverse patch to " + file_name + ". "
                           "Run `patch -p0 < " + file_name + "` to revert your "
                           "changes.")
                except IOError:
                    msg = ("Tried to save a patch, but couldn't create a "
                           "writable file " + file_name + ". Make sure it "
                           "doesn't exist and your working directory is "
                           "writable.")
            else:
                msg = ("you can retrieve the original source code by "
                       "accessing the object's source attribute or set "
                       "`torch.nn.Module.dump_patches = True` and use the "
                       "patch tool to revert the changes.")
            msg = ("source code of class '{}' has changed. {}"
                   .format(torch.typename(container_type), msg))
            warnings.warn(msg, SourceChangeWarning)

    def legacy_load(f):
        deserialized_objects = {}

        def persistent_load(saved_id):
            if isinstance(saved_id, tuple):
                # Ignore containers that don't have any sources saved
                if all(saved_id[1:]):
                    _check_container_source(*saved_id)
                return saved_id[0]
            return deserialized_objects[int(saved_id)]

        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:
            tar.extract('storages', path=tmpdir)
            with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
                num_storages = pickle_module.load(f)
                for i in range(num_storages):
                    args = pickle_module.load(f)
                    key, location, storage_type = args
                    obj = storage_type._new_with_file(f)
                    obj = restore_location(obj, location)
                    deserialized_objects[key] = obj

                storage_views = pickle_module.load(f)
                for target_cdata, root_cdata, offset, size in storage_views:
                    root = deserialized_objects[root_cdata]
                    deserialized_objects[target_cdata] = root[offset:offset + size]

            tar.extract('tensors', path=tmpdir)
            with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
                num_tensors = pickle_module.load(f)
                for i in range(num_tensors):
                    args = pickle_module.load(f)
                    key, storage_id, original_tensor_type = args
                    storage = deserialized_objects[storage_id]
                    tensor_type = storage_to_tensor_type(storage)
                    tensor = tensor_type._new_with_metadata_file(f, storage)
                    deserialized_objects[key] = tensor

            pickle_file = tar.extractfile('pickle')
            unpickler = pickle_module.Unpickler(pickle_file)
            unpickler.persistent_load = persistent_load
            result = unpickler.load()
            return result
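
    # For reference, the legacy format handled above is a plain tar archive
    # with three members: 'storages' (a count, then (key, location, type)
    # records each followed by the storage data), 'tensors' (a count, then
    # per-tensor metadata referencing a storage key), and 'pickle' (the
    # object graph, using integer persistent ids into the tables above).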

    deserialized_objects = {}

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = saved_id[0]
        data = saved_id[1:]

        if typename == 'module':
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        elif typename == 'storage':
            data_type, root_key, location, size, view_metadata = data
            if root_key not in deserialized_objects:
                deserialized_objects[root_key] = restore_location(
                    data_type(size), location)
            storage = deserialized_objects[root_key]
            if view_metadata is not None:
                view_key, offset, view_size = view_metadata
                if view_key not in deserialized_objects:
                    deserialized_objects[view_key] = storage[offset:offset + view_size]
                return deserialized_objects[view_key]
            else:
                return storage
        else:
            raise RuntimeError("Unknown saved id type: %s" % saved_id[0])

    # try the legacy loader first, which only works if f is a tarfile
    try:
        return legacy_load(f)
    except tarfile.TarError:
        pass

    f.seek(0)
    magic_number = pickle_module.load(f)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f)
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError("Invalid protocol version: %s" % protocol_version)

    _sys_info = pickle_module.load(f)
    unpickler = pickle_module.Unpickler(f)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    deserialized_storage_keys = pickle_module.load(f)

    offset = f.tell()
    for key in deserialized_storage_keys:
        assert key in deserialized_objects
        deserialized_objects[key]._set_from_file(f, offset)
        offset = None

    return result