Type-annotate serialization.py (#40862)
Summary:
Move the Storage class from __init__.pyi.in to types.py and make it a Protocol, since it is not a real class
Expose `PyTorchFileReader` and `PyTorchFileWriter` native classes
Ignore type errors on function attributes, as there is not yet a good way to type-annotate them; see https://github.com/python/mypy/issues/2087
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40862
Differential Revision: D22344743
Pulled By: malfet
fbshipit-source-id: 95cdb6f980ee79383960f306223e170c63df3232
diff --git a/torch/serialization.py b/torch/serialization.py
index c19a148..62642a3 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -12,6 +12,8 @@
from ._utils import _import_dotted_name
from ._six import string_classes as _string_classes
from torch._utils_internal import get_source_lines_and_file
+from torch.types import Storage
+from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union
import copyreg
import pickle
import pathlib
@@ -26,7 +28,6 @@
PROTOCOL_VERSION = 1001
STORAGE_KEY_SEPARATOR = ','
-
class SourceChangeWarning(Warning):
pass
@@ -41,7 +42,7 @@
_package_registry = []
-def _is_zipfile(f):
+def _is_zipfile(f) -> bool:
# This is a stricter implementation than zipfile.is_zipfile().
# zipfile.is_zipfile() is True if the magic number appears anywhere in the
# binary. Since we expect the files here to be generated by torch.save or
@@ -160,7 +161,7 @@
register_package(20, _cuda_tag, _cuda_deserialize)
-def location_tag(storage):
+def location_tag(storage: Storage):
for _, tagger, _ in _package_registry:
location = tagger(storage)
if location:
@@ -237,29 +238,30 @@
class _open_zipfile_reader(_opener):
- def __init__(self, name_or_buffer):
+ def __init__(self, name_or_buffer) -> None:
super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))
class _open_zipfile_writer_file(_opener):
- def __init__(self, name):
+ def __init__(self, name) -> None:
super(_open_zipfile_writer_file, self).__init__(torch._C.PyTorchFileWriter(str(name)))
- def __exit__(self, *args):
+ def __exit__(self, *args) -> None:
self.file_like.write_end_of_file()
class _open_zipfile_writer_buffer(_opener):
- def __init__(self, buffer):
+ def __init__(self, buffer) -> None:
self.buffer = buffer
super(_open_zipfile_writer_buffer, self).__init__(torch._C.PyTorchFileWriter(buffer))
- def __exit__(self, *args):
+ def __exit__(self, *args) -> None:
self.file_like.write_end_of_file()
self.buffer.flush()
def _open_zipfile_writer(name_or_buffer):
+ container: Type[_opener]
if _is_path(name_or_buffer):
container = _open_zipfile_writer_file
else:
@@ -267,7 +269,7 @@
return container(name_or_buffer)
-def _is_compressed_file(f):
+def _is_compressed_file(f) -> bool:
compress_modules = ['gzip']
try:
return f.__module__ in compress_modules
@@ -291,7 +293,7 @@
return False
-def _check_seekable(f):
+def _check_seekable(f) -> bool:
def raise_err_msg(patterns, e):
for p in patterns:
@@ -307,8 +309,9 @@
return True
except (io.UnsupportedOperation, AttributeError) as e:
raise_err_msg(["seek", "tell"], e)
+ return False
-def _check_dill_version(pickle_module):
+def _check_dill_version(pickle_module) -> None:
'''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
If dill version is lower than 0.3.1, a ValueError is raised.
@@ -327,7 +330,8 @@
pickle_module.__version__
))
-def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True):
+def save(obj, f: Union[str, os.PathLike, BinaryIO],
+ pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:
"""Saves an object to a disk file.
See also: :ref:`recommend-saving-models`
@@ -370,12 +374,12 @@
_legacy_save(obj, opened_file, pickle_module, pickle_protocol)
-def _legacy_save(obj, f, pickle_module, pickle_protocol):
+def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
import torch.nn as nn
serialized_container_types = {}
serialized_storages = {}
- def persistent_id(obj):
+ def persistent_id(obj: Any) -> Optional[Tuple]:
# FIXME: the docs say that persistent_id should only return a string
# but torch store returns tuples. This works only in the binary protocol
# see
@@ -396,6 +400,8 @@
return ('module', obj, source_file, source)
elif torch.is_storage(obj):
+ view_metadata: Optional[Tuple[str, int, int]]
+ obj = cast(Storage, obj)
storage_type = normalize_storage_type(type(obj))
# Offset is always 0, but we keep it for backwards compatibility
# with the old serialization format (which supported storage views)
@@ -589,20 +595,20 @@
def _get_layout(name):
"""Get layout extension object from its string representation.
"""
- cache = _get_layout.cache
+ cache = _get_layout.cache # type: ignore[attr-defined]
if not cache:
for v in torch.__dict__.values():
if isinstance(v, torch.layout):
cache[str(v)] = v
return cache[name]
-
-_get_layout.cache = {}
+# There is not yet a good way to type-annotate function attributes; see https://github.com/python/mypy/issues/2087
+_get_layout.cache = {} # type: ignore[attr-defined]
copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
- deserialized_objects = {}
+ deserialized_objects: Dict[int, Any] = {}
restore_location = _get_restore_location(map_location)
@@ -648,7 +654,7 @@
warnings.warn(msg, SourceChangeWarning)
def legacy_load(f):
- deserialized_objects = {}
+ deserialized_objects: Dict[int, Any] = {}
def persistent_load(saved_id):
if isinstance(saved_id, tuple):
@@ -777,7 +783,7 @@
return result
-def _maybe_decode_ascii(bytes_str):
+def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
# When using encoding='bytes' in Py3, some **internal** keys stored as
# strings in Py2 are loaded as bytes. This function decodes them with
# ascii encoding, one that Py3 uses by default.