Remove dtype from torch.Storage and use only torch.ByteStorage (#62030)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62030
Remove dtype tracking from the Python Storage interface, remove all the different `<type>Storage` classes except for `ByteStorage`, and update serialization accordingly, while maintaining as much forward/backward compatibility (FC/BC) as possible.
Fixes https://github.com/pytorch/pytorch/issues/47442
* **THE SERIALIZATION FORMAT IS FULLY FC/BC.** We worked very hard to make sure this is the case. We will probably want to break FC at some point so that the serialized structure of tensors makes more sense, but not today.
* There is now only a single `torch.ByteStorage` class. Methods like `Tensor.set_` no longer check that the dtype of the storage is appropriate.
* As we no longer know what the dtype of a storage is, we've **removed** the `size` method from `Storage`, replacing it with `nbytes`. This helps catch otherwise silent errors where the number of elements is confused with the number of bytes.
* `Storage._new_shared` takes an `nbytes` kwarg and rejects the previous positional-only calls. `Storage._new_with_file` and `_set_from_file` require explicit element size arguments.
* It's no longer possible to convert storages to different types using the `float`/`double`/etc. methods. Instead, do the conversion through a tensor.
* It's no longer possible to allocate a typed storage directly using the `FloatStorage`/`DoubleStorage`/etc. constructors. Instead, construct a tensor and extract its storage (see the first sketch after this list). The classes still exist, but they are used purely for unpickling.
* The preexisting serialization format stores the dtype with the storage, and in fact this dtype is used to determine the dtype of the tensor overall.
To accommodate this case, we introduce a new TypedStorage concept that exists only at unpickling time and is used to temporarily hold the dtype so we can construct a tensor. **If you overrode the handling of pickling/unpickling, you MUST add handling for TypedStorage** (see the second sketch after this list) or your serialization code will degrade to standard file-based serialization.
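A rough illustration of the new storage surface (a minimal sketch of the intended usage, not code from this diff; exact class availability at this commit may differ slightly):

```python
import torch

# Old API (removed): per-dtype storage classes and an element-count size()
#   s = torch.FloatStorage(10)   # no longer constructible directly
#   s.size()                     # returned the number of elements
#   s.double()                   # converted between storage types

# New API: build a tensor and take its (byte-addressed) storage
t = torch.zeros(10, dtype=torch.float32)
s = t.storage()
print(s.nbytes())                # 40 (10 elements * 4 bytes each), not 10

# dtype conversion now goes through a tensor, not the storage
converted = t.to(torch.float64).storage()
print(converted.nbytes())        # 80
```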
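For code that overrides storage pickling, here is a minimal sketch of the extra branch that becomes necessary, modeled on the `_save` changes in the diff below; the returned tuple is illustrative only, not torch's exact persistent-id layout:

```python
import torch

def persistent_id(obj):
    # After this change a storage can show up either as a plain untyped
    # storage or, during (un)pickling only, as torch.storage.TypedStorage.
    if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
        if isinstance(obj, torch.storage.TypedStorage):
            storage = obj._storage       # underlying untyped storage
            dtype = obj.dtype            # dtype preserved for old readers
            numel = obj.size()           # element count
        else:
            storage = obj
            dtype = torch.uint8
            numel = obj.nbytes()         # byte count; dtype is unknown
        # Illustrative record; the real persistent-id tuples in
        # torch/serialization.py also carry a location tag and storage type.
        return ('storage', dtype, str(storage._cdata), numel)
    return None
```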
Original pull request: https://github.com/pytorch/pytorch/pull/59671
Reviewed By: soulitzer, ngimel
Differential Revision: D29466819
Pulled By: ezyang
fbshipit-source-id: 4a14e5d3c2b08e06e558683d97f7378a3180b00e
diff --git a/torch/serialization.py b/torch/serialization.py
index 4443561..e28ce0e 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -13,6 +13,7 @@
from ._six import string_classes as _string_classes
from torch._sources import get_source_lines_and_file
from torch.types import Storage
+from torch.storage import _get_dtype_from_pickle_storage_type
from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO
import copyreg
import pickle
@@ -152,7 +153,7 @@
if getattr(obj, "_torch_load_uninitialized", False):
storage_type = getattr(torch.cuda, type(obj).__name__)
with torch.cuda.device(device):
- return storage_type(obj.size())
+ return storage_type(obj.nbytes())
else:
return obj.cuda(device)
@@ -161,7 +162,7 @@
register_package(20, _cuda_tag, _cuda_deserialize)
-def location_tag(storage: Storage):
+def location_tag(storage: Union[Storage, torch.storage.TypedStorage]):
for _, tagger, _ in _package_registry:
location = tagger(storage)
if location:
@@ -406,28 +407,75 @@
"for correctness upon loading.")
return ('module', obj, source_file, source)
- elif torch.is_storage(obj):
+ if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
+ if isinstance(obj, torch.storage.TypedStorage):
+ # TODO: Once we decide to break serialization FC, this case
+ # can be deleted
+ storage = obj._storage
+ storage_type_str = obj.pickle_storage_type()
+ storage_type = getattr(torch, storage_type_str)
+ dtype = obj.dtype
+ storage_numel = obj.size()
+
+ else:
+ storage = obj
+ storage_type = normalize_storage_type(type(obj))
+ dtype = torch.uint8
+ storage_numel = cast(Storage, storage).nbytes()
+
view_metadata: Optional[Tuple[str, int, int]]
- obj = cast(Storage, obj)
- storage_type = normalize_storage_type(type(obj))
+ storage = cast(Storage, storage)
+
# Offset is always 0, but we keep it for backwards compatibility
# with the old serialization format (which supported storage views)
offset = 0
- obj_key = str(obj._cdata)
- location = location_tag(obj)
- serialized_storages[obj_key] = obj
- is_view = obj._cdata != obj._cdata
+ storage_key = str(storage._cdata)
+ location = location_tag(storage)
+
+ # TODO: There's an issue here with FC. It might be impossible to
+ # solve, but it's worth noting. Imagine we save a list `[storage,
+ # tensor]`, where `tensor.storage()` is the same as `storage`, and
+ # `tensor.element_size() > 1`. Let's say that `tensor.dtype ==
+ # torch.float`. The storage will be serialized with element size
+ # of 1, since we're choosing to serialize the first occurrence of
+ # a duplicate storage. Since this legacy serialization format saves
+ # the numel of the storage, rather than nbytes directly, we'll be
+ # effectively saving nbytes in this case. We'll be able to load it
+ # and the tensor back up with no problems in _this_ and future
+ # versions of pytorch, but in older versions, here's the problem:
+ # the storage will be loaded up as an UntypedStorage, and then the
+ # FloatTensor will be loaded and the UntypedStorage will be assigned to
+ # it. Since the storage dtype does not match the tensor dtype, this
+ # will cause an error. If we reverse the list, like `[tensor,
+ # storage]`, then we will save the `tensor.storage()` as a faked
+ # `FloatStorage`, and the saved size will be the correct
+ # dtype-specific numel count that old versions expect. `tensor`
+ # will be able to load up properly in old versions, pointing to
+ # a FloatStorage. However, `storage` is still being translated to
+ # an UntypedStorage, and it will try to resolve to the same
+ # FloatStorage that `tensor` contains. This will also cause an
+ # error. It doesn't seem like there's any way around this.
+ # Probably, we just cannot maintain FC for the legacy format if the
+ # saved list contains both a tensor and a storage that point to the
+ # same data. We should still be able to maintain FC for lists of
+ # just tensors, as long as all views share the same dtype as the
+ # tensor they are viewing.
+
+ if storage_key not in serialized_storages:
+ serialized_storages[storage_key] = (storage, dtype)
+ is_view = storage._cdata != storage._cdata
if is_view:
- view_metadata = (str(obj._cdata), offset, obj.size())
+ view_metadata = (str(storage._cdata), offset, storage.nbytes())
else:
view_metadata = None
- return ('storage',
- storage_type,
- obj_key,
- location,
- obj.size(),
- view_metadata)
+ res = ('storage',
+ storage_type,
+ storage_key,
+ location,
+ storage_numel,
+ view_metadata)
+ return res
return None
sys_info = dict(
@@ -451,7 +499,8 @@
pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
f.flush()
for key in serialized_storage_keys:
- serialized_storages[key]._write_file(f, _should_read_directly(f), True)
+ storage, dtype = serialized_storages[key]
+ storage._write_file(f, _should_read_directly(f), True, torch._utils._element_size(dtype))
def _save(obj, zip_file, pickle_module, pickle_protocol):
@@ -464,17 +513,32 @@
# see
# https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
# https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
- if torch.is_storage(obj):
- storage_type = normalize_storage_type(type(obj))
- obj_key = id_map.setdefault(obj._cdata, str(len(id_map)))
- location = location_tag(obj)
- serialized_storages[obj_key] = obj
+ if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
+
+ if isinstance(obj, torch.storage.TypedStorage):
+ # TODO: Once we decide to break serialization FC, this case
+ # can be deleted
+ storage = obj._storage
+ storage_type_str = obj.pickle_storage_type()
+ storage_type = getattr(torch, storage_type_str)
+ storage_numel = obj.size()
+
+ else:
+ storage = obj
+ storage_type = normalize_storage_type(type(obj))
+ storage_numel = storage.nbytes()
+
+ storage = cast(Storage, storage)
+ storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
+ location = location_tag(storage)
+ serialized_storages[storage_key] = storage
return ('storage',
storage_type,
- obj_key,
+ storage_key,
location,
- obj.size())
+ storage_numel)
+
return None
# Write the pickle data for `obj`
@@ -495,7 +559,7 @@
if storage.device.type != 'cpu':
storage = storage.cpu()
# Now that it is on the CPU we can directly copy it into the zip file
- num_bytes = storage.size() * storage.element_size()
+ num_bytes = storage.nbytes()
zip_file.write_record(name, storage.data_ptr(), num_bytes)
@@ -630,6 +694,16 @@
restore_location = _get_restore_location(map_location)
+ class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined]
+
+ def find_class(self, mod_name, name):
+ if type(name) is str and 'Storage' in name:
+ try:
+ return StorageType(name)
+ except KeyError:
+ pass
+ return super().find_class(mod_name, name)
+
def _check_container_source(container_type, source_file, original_source):
try:
current_source = ''.join(get_source_lines_and_file(container_type)[0])
@@ -690,14 +764,25 @@
for i in range(num_storages):
args = pickle_module.load(f, **pickle_load_args)
key, location, storage_type = args
- obj = storage_type._new_with_file(f)
+ dtype = storage_type.dtype
+ obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
obj = restore_location(obj, location)
- deserialized_objects[key] = obj
+ # TODO: Once we decide to break serialization FC, we can
+ # stop wrapping with TypedStorage
+ deserialized_objects[key] = torch.storage.TypedStorage(
+ wrap_storage=obj,
+ dtype=dtype)
storage_views = pickle_module.load(f, **pickle_load_args)
- for target_cdata, root_cdata, offset, size in storage_views:
+ for target_cdata, root_cdata, offset, numel in storage_views:
root = deserialized_objects[root_cdata]
- deserialized_objects[target_cdata] = root[offset:offset + size]
+ element_size = torch._utils._element_size(root.dtype)
+ offset_bytes = offset * element_size
+ # TODO: Once we decide to break serialization FC, we can
+ # stop wrapping with TypedStorage
+ deserialized_objects[target_cdata] = torch.storage.TypedStorage(
+ wrap_storage=root._storage[offset_bytes:offset_bytes + numel * element_size],
+ dtype=root.dtype)
tar.extract('tensors', path=tmpdir)
with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
@@ -706,18 +791,18 @@
args = pickle_module.load(f, **pickle_load_args)
key, storage_id, original_tensor_type = args
storage = deserialized_objects[storage_id]
- tensor_type = storage_to_tensor_type(storage)
ndim, = struct.unpack('<i', f.read(4))
# skip next 4 bytes; legacy encoding treated ndim as 8 bytes
f.read(4)
- size = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
+ numel = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
storage_offset, = struct.unpack('<q', f.read(8))
- tensor = tensor_type().set_(storage, storage_offset, size, stride)
+ tensor = torch.tensor([], dtype=storage.dtype).set_(
+ storage._storage, storage_offset, numel, stride)
deserialized_objects[key] = tensor
pickle_file = tar.extractfile('pickle')
- unpickler = pickle_module.Unpickler(pickle_file, **pickle_load_args)
+ unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
unpickler.persistent_load = persistent_load
result = unpickler.load()
return result
@@ -735,20 +820,37 @@
_check_container_source(*data)
return data[0]
elif typename == 'storage':
- data_type, root_key, location, size, view_metadata = data
+ storage_type, root_key, location, numel, view_metadata = data
location = _maybe_decode_ascii(location)
+ dtype = storage_type.dtype
+
+ nbytes = numel * torch._utils._element_size(dtype)
+
if root_key not in deserialized_objects:
- obj = data_type(size)
+ obj = cast(Storage, torch.UntypedStorage(nbytes))
obj._torch_load_uninitialized = True
- deserialized_objects[root_key] = restore_location(obj, location)
- storage = deserialized_objects[root_key]
+ # TODO: Once we decide to break serialization FC, we can
+ # stop wrapping with TypedStorage
+ deserialized_objects[root_key] = torch.storage.TypedStorage(
+ wrap_storage=restore_location(obj, location),
+ dtype=dtype)
+
+ typed_storage = deserialized_objects[root_key]
if view_metadata is not None:
view_key, offset, view_size = view_metadata
+ offset_bytes = offset * torch._utils._element_size(dtype)
+ view_size_bytes = view_size * torch._utils._element_size(dtype)
if view_key not in deserialized_objects:
- deserialized_objects[view_key] = storage[offset:offset + view_size]
- return deserialized_objects[view_key]
+ # TODO: Once we decide to break serialization FC, we can
+ # stop wrapping with TypedStorage
+ deserialized_objects[view_key] = torch.storage.TypedStorage(
+ wrap_storage=typed_storage._storage[offset_bytes:offset_bytes + view_size_bytes],
+ dtype=dtype)
+ res = deserialized_objects[view_key]
+
else:
- return storage
+ res = typed_storage
+ return res
else:
raise RuntimeError("Unknown saved id type: %s" % saved_id[0])
@@ -782,7 +884,7 @@
raise RuntimeError("Invalid protocol version: %s" % protocol_version)
_sys_info = pickle_module.load(f, **pickle_load_args)
- unpickler = pickle_module.Unpickler(f, **pickle_load_args)
+ unpickler = UnpicklerWrapper(f, **pickle_load_args)
unpickler.persistent_load = persistent_load
result = unpickler.load()
@@ -791,7 +893,10 @@
offset = f.tell() if f_should_read_directly else None
for key in deserialized_storage_keys:
assert key in deserialized_objects
- deserialized_objects[key]._set_from_file(f, offset, f_should_read_directly)
+ typed_storage = deserialized_objects[key]
+ typed_storage._storage._set_from_file(
+ f, offset, f_should_read_directly,
+ torch._utils._element_size(typed_storage.dtype))
if offset is not None:
offset = f.tell()
@@ -833,17 +938,27 @@
return result
return restore_location
+class StorageType():
+ def __init__(self, name):
+ self.dtype = _get_dtype_from_pickle_storage_type(name)
+
+ def __str__(self):
+ return f'StorageType(dtype={self.dtype})'
+
def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickle_load_args):
restore_location = _get_restore_location(map_location)
loaded_storages = {}
- def load_tensor(data_type, size, key, location):
+ def load_tensor(dtype, numel, key, location):
name = f'data/{key}'
- dtype = data_type(0).dtype
- storage = zip_file.get_storage_from_record(name, size, dtype).storage()
- loaded_storages[key] = restore_location(storage, location)
+ storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage).storage()._untyped()
+ # TODO: Once we decide to break serialization FC, we can
+ # stop wrapping with TypedStorage
+ loaded_storages[key] = torch.storage.TypedStorage(
+ wrap_storage=restore_location(storage, location),
+ dtype=dtype)
def persistent_load(saved_id):
assert isinstance(saved_id, tuple)
@@ -852,11 +967,14 @@
assert typename == 'storage', \
f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
- data_type, key, location, size = data
+ storage_type, key, location, numel = data
+ dtype = storage_type.dtype
+
if key not in loaded_storages:
- load_tensor(data_type, size, key, _maybe_decode_ascii(location))
- storage = loaded_storages[key]
- return storage
+ nbytes = numel * torch._utils._element_size(dtype)
+ load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
+
+ return loaded_storages[key]
load_module_mapping: Dict[str, str] = {
# See https://github.com/pytorch/pytorch/pull/51633
@@ -871,6 +989,11 @@
# Lets us override the imports that pickle uses when unpickling an object.
# This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
def find_class(self, mod_name, name):
+ if type(name) is str and 'Storage' in name:
+ try:
+ return StorageType(name)
+ except KeyError:
+ pass
mod_name = load_module_mapping.get(mod_name, mod_name)
return super().find_class(mod_name, name)