import os
import sys
import shutil
import struct
import tarfile
import tempfile
from contextlib import closing, contextmanager

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle

import torch
from ._utils import _import_dotted_name

DEFAULT_PROTOCOL = 2

# Sizes (in bytes) of the C integer types used in serialized metadata,
# in native byte order.
LONG_SIZE = struct.Struct('=l').size
INT_SIZE = struct.Struct('=i').size
SHORT_SIZE = struct.Struct('=h').size

def _add_to_tar(fn, tar_file, name):
    # Serialize into a named temporary file first, because tarfile needs to
    # know the size of an entry before it can be added to the archive.
    tmp_file = tempfile.NamedTemporaryFile(delete=False)
    fn(tmp_file)
    tmp_file.close()

    tar_file.add(tmp_file.name, arcname=name)
    if os.path.isfile(tmp_file.name):
        os.remove(tmp_file.name)


@contextmanager
def mkdtemp():
    """Context manager that creates a temporary directory and removes it on exit."""
    path = tempfile.mkdtemp()
    try:
        yield path
    finally:
        shutil.rmtree(path)

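# A minimal usage sketch for the helper above (the name `tmpdir` is
# illustrative): the directory exists only inside the `with` block and is
# removed afterwards, even if the body raises.
#
#     with mkdtemp() as tmpdir:
#         scratch = os.path.join(tmpdir, 'scratch.bin')
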
_package_registry = []


def register_package(priority, tagger, deserializer):
    # The registry is kept sorted by priority, so entries with lower
    # priority values are consulted first when tagging and restoring.
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    _package_registry.sort()


def _cpu_tag(obj):
    if type(obj).__module__ == 'torch':
        return 'cpu'


def _cuda_tag(obj):
    if type(obj).__module__ == 'torch.cuda':
        return 'cuda:' + str(obj.get_device())


def _cpu_deserialize(obj, location):
    if location == 'cpu':
        return obj


def _cuda_deserialize(obj, location):
    if location.startswith('cuda'):
        # Location tags have the form 'cuda:device_id', e.g. 'cuda:2'.
        device_id = max(int(location[5:]), 0)
        return obj.cuda(device_id)


register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, _cuda_tag, _cuda_deserialize)
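
# A sketch of a user extension, with hypothetical names: a package that tags
# storages living in a 'torch.foo' module and knows how to restore them.
# Priority 30 places it after the built-in cpu and cuda handlers above.
#
#     def _foo_tag(obj):
#         if type(obj).__module__ == 'torch.foo':
#             return 'foo'
#
#     def _foo_deserialize(obj, location):
#         if location == 'foo':
#             return obj.foo()  # hypothetical conversion method
#
#     register_package(30, _foo_tag, _foo_deserialize)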


def location_tag(storage):
    for _, tagger, _ in _package_registry:
        location = tagger(storage)
        if location:
            return location
    raise RuntimeError("don't know how to determine data location of " +
                       torch.typename(storage))


def default_restore_location(storage, location):
    for _, _, fn in _package_registry:
        result = fn(storage, location)
        if result is not None:
            return result
    raise RuntimeError("don't know how to restore data location of " +
                       torch.typename(storage) + " (tagged with " +
                       location + ")")


def normalize_storage_type(storage_type):
    # Map a possibly device-specific storage type (e.g. torch.cuda.FloatStorage)
    # to its torch-module equivalent (torch.FloatStorage).
    return getattr(torch, storage_type.__name__)


def storage_to_tensor_type(storage):
    storage_type = type(storage)
    module = _import_dotted_name(storage_type.__module__)
    return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
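
# For example (the cuda case assumes a CUDA-enabled build):
#
#     normalize_storage_type(torch.cuda.FloatStorage)  # -> torch.FloatStorage
#     storage_to_tensor_type(torch.FloatStorage())     # -> torch.FloatTensor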


# TODO: choose pickle protocol
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
    """Saves an object to a disk file.

    Args:
        obj: the object to save
        f: a file-like object (it has to implement a fileno method that
            returns a file descriptor)
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol
    """
    serialized_tensors = {}
    serialized_storages = {}

    def persistent_id(obj):
        # Tensors and storages are serialized out of band; the pickle stream
        # records only their keys (the _cdata pointers, as strings).
        if torch.is_tensor(obj):
            serialized_tensors[obj._cdata] = obj
            return str(obj._cdata)
        elif torch.is_storage(obj):
            serialized_storages[obj._cdata] = obj
            return str(obj._cdata)
        return None

    def save_tensors(f):
        pickle_module.dump(len(serialized_tensors), f, protocol=pickle_protocol)
        for key, tensor in serialized_tensors.items():
            storage = tensor.storage()
            if storage is not None:
                storage_id = storage._cdata
                serialized_storages[storage_id] = storage
            else:
                storage_id = None

            pickle_module.dump((key, storage_id, type(tensor)), f,
                               protocol=pickle_protocol)
            # Flush the pickled bytes before raw metadata is written through
            # the file descriptor, so the two streams don't interleave.
            f.flush()
            tensor._write_metadata(f)

    def save_storages(f):
        storage_views = []
        storage_views_roots = {}

        # Storages that are views into another storage are recorded as
        # (view_key, root_key, offset, size) tuples; only root storages
        # have their actual data written out.
        for key, storage in serialized_storages.items():
            root, offset = storage._root_storage()
            if root is not storage:
                storage_views_roots[root._cdata] = root
                storage_views.append((storage._cdata, root._cdata, offset,
                                      storage.size()))
        for view_info in storage_views:
            del serialized_storages[view_info[0]]
        serialized_storages.update(storage_views_roots)

        pickle_module.dump(len(serialized_storages), f, protocol=pickle_protocol)
        for key, storage in serialized_storages.items():
            location = location_tag(storage)
            storage_type = normalize_storage_type(type(storage))
            pickle_module.dump((key, location, storage_type), f,
                               protocol=pickle_protocol)
            f.flush()
            storage._write_file(f)

        pickle_module.dump(storage_views, f, protocol=pickle_protocol)

    def pickle_objects(f):
        pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
        pickler.persistent_id = persistent_id
        pickler.dump(obj)

    def save_sys_info(f):
        sys_info = dict(
            protocol_version=1000,
            little_endian=sys.byteorder == 'little',
            type_sizes=dict(
                short=SHORT_SIZE,
                int=INT_SIZE,
                long=LONG_SIZE,
            ),
        )
        pickle_module.dump(sys_info, f, protocol=pickle_protocol)

    # The result is an uncompressed tar archive with four members:
    # sys_info, pickle, tensors and storages.
    with closing(tarfile.open(fileobj=f, mode='w:', format=tarfile.PAX_FORMAT)) as tar:
        _add_to_tar(save_sys_info, tar, 'sys_info')
        _add_to_tar(pickle_objects, tar, 'pickle')
        _add_to_tar(save_tensors, tar, 'tensors')
        _add_to_tar(save_storages, tar, 'storages')
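
# A minimal usage sketch (the file name is hypothetical). Because save
# writes a tar archive through the file descriptor, the target must be
# opened in binary mode and expose a real fileno:
#
#     x = torch.randn(3, 3)
#     with open('/tmp/tensor.pt', 'wb') as f:
#         torch.save(x, f)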


def load(f, map_location=None, pickle_module=pickle):
    """Loads an object saved with torch.save from a disk file.

    torch.load can dynamically remap storages to be loaded on a different
    device using the map_location argument. If it's a callable, it will be
    called with two arguments (storage, location tag) for each storage. It
    should return either a storage that's been moved to a different location,
    or None (in which case the location will be resolved using the default
    method). If this argument is a dict, it's expected to map location tags
    used in the file to location tags of the current system.

    By default the location tags are 'cpu' for host tensors and
    'cuda:device_id' (e.g. 'cuda:2') for cuda tensors. User extensions can
    register their own tagging and deserialization methods using
    register_package.

    Args:
        f: a file-like object (it has to implement a fileno method that
            returns a file descriptor)
        map_location: a function or a dict specifying how to remap storage
            locations
        pickle_module: module used for unpickling metadata and objects (it
            has to match the pickle_module used to serialize the file)
    """
    deserialized_objects = {}

    if map_location is None:
        restore_location = default_restore_location
    elif isinstance(map_location, dict):
        def restore_location(storage, location):
            location = map_location.get(location, location)
            return default_restore_location(storage, location)
    else:
        def restore_location(storage, location):
            result = map_location(storage, location)
            if result is None:
                result = default_restore_location(storage, location)
            return result

    def persistent_load(saved_id):
        return deserialized_objects[int(saved_id)]

    with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
            mkdtemp() as tmpdir:

        tar.extract('storages', path=tmpdir)
        with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as storage_file:
            num_storages = pickle_module.load(storage_file)
            for i in range(num_storages):
                args = pickle_module.load(storage_file)
                key, location, storage_type = args
                obj = storage_type._new_with_file(storage_file)
                obj = restore_location(obj, location)
                deserialized_objects[key] = obj

            # Recreate views as slices into their (already loaded) roots.
            storage_views = pickle_module.load(storage_file)
            for target_cdata, root_cdata, offset, size in storage_views:
                root = deserialized_objects[root_cdata]
                deserialized_objects[target_cdata] = root[offset:offset + size]

        tar.extract('tensors', path=tmpdir)
        with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as tensor_file:
            num_tensors = pickle_module.load(tensor_file)
            for i in range(num_tensors):
                args = pickle_module.load(tensor_file)
                key, storage_id, original_tensor_type = args
                storage = deserialized_objects.get(storage_id, None)
                if storage is not None:
                    tensor_type = storage_to_tensor_type(storage)
                    tensor = tensor_type._new_with_metadata_file(tensor_file, storage)
                else:
                    tensor = original_tensor_type._new_with_metadata_file(tensor_file, storage)
                deserialized_objects[key] = tensor

        pickle_file = tar.extractfile('pickle')
        unpickler = pickle_module.Unpickler(pickle_file)
        unpickler.persistent_load = persistent_load
        result = unpickler.load()
        return result
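
# A minimal usage sketch (the file name is hypothetical). A dict
# map_location remaps tags found in the file; a callable receives
# (storage, location) and may return None to fall back to the default:
#
#     with open('/tmp/tensor.pt', 'rb') as f:
#         x = torch.load(f, map_location={'cuda:0': 'cpu'})
#
#     with open('/tmp/tensor.pt', 'rb') as f:
#         y = torch.load(f, map_location=lambda storage, location: storage)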