Type-annotate serialization.py (#40862)
Summary:
Move Storage class from __init__.pyi.in to types.py and make it a protocol, since this is not a real class
Expose `PyTorchFileReader` and `PyTorchFileWriter` native classes
Ignore function attributes, as there are yet no good way to type annotate those, see https://github.com/python/mypy/issues/2087
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40862
Differential Revision: D22344743
Pulled By: malfet
fbshipit-source-id: 95cdb6f980ee79383960f306223e170c63df3232
diff --git a/mypy.ini b/mypy.ini
index 4ef24b6..0583158 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -317,9 +317,6 @@
[mypy-torch.nn.functional]
ignore_errors = True
-[mypy-torch.serialization]
-ignore_errors = True
-
[mypy-torch.utils]
ignore_errors = True
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index e043621..1ddc9d7 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -2,10 +2,11 @@
import torch
from torch import Tensor
-from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, TypeVar, Type
+from typing import (Any, BinaryIO, Callable, ContextManager, Iterator, List, NamedTuple,
+ Optional, overload, Sequence, Tuple, TypeVar, Type, Union)
from torch._six import inf
-from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Number, Device
+from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage
import builtins
@@ -90,9 +91,6 @@
# Defined in torch/csrc/utils/tensor_qschemes.cpp
per_tensor_affine: qscheme = ...
-# Defined in torch/csrc/generic/Storage.cpp
-class Storage: ...
-
# Defined in torch/csrc/autograd/python_function.cpp
class _FunctionBase(object):
...
@@ -169,6 +167,24 @@
# TODO
...
+# Defined in torch/csrc/jit/python/init.cpp
+class PyTorchFileReader(object):
+ @overload
+ def __init__(self, name: str) -> None: ...
+ @overload
+ def __init__(self, buffer: BinaryIO) -> None: ...
+ def get_record(self, name: str) -> bytes: ...
+ ...
+
+class PyTorchFileWriter(object):
+ @overload
+ def __init__(self, name: str) -> None: ...
+ @overload
+ def __init__(self, buffer: BinaryIO) -> None: ...
+ def write_record(self, name: str, data: bytes, size: _int) -> None: ...
+ def write_end_of_file(self) -> None: ...
+ ...
+
# Defined in torch/csrc/Generator.cpp
class Generator(object):
device: _device
diff --git a/torch/__init__.py b/torch/__init__.py
index 4937fba..fd38164 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -311,11 +311,6 @@
"""
return _C._get_deterministic()
-# If you edit these imports, please update torch/__init__.py.in as well
-from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
-from .serialization import save, load
-from ._tensor_str import set_printoptions
-
################################################################################
# Define Storage and Tensor classes
################################################################################
@@ -388,6 +383,10 @@
# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
_tensor_classes: Set[Type] = set()
+# If you edit these imports, please update torch/__init__.py.in as well
+from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
+from .serialization import save, load
+from ._tensor_str import set_printoptions
################################################################################
# Initialize extension
diff --git a/torch/serialization.py b/torch/serialization.py
index c19a148..62642a3 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -12,6 +12,8 @@
from ._utils import _import_dotted_name
from ._six import string_classes as _string_classes
from torch._utils_internal import get_source_lines_and_file
+from torch.types import Storage
+from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union
import copyreg
import pickle
import pathlib
@@ -26,7 +28,6 @@
PROTOCOL_VERSION = 1001
STORAGE_KEY_SEPARATOR = ','
-
class SourceChangeWarning(Warning):
pass
@@ -41,7 +42,7 @@
_package_registry = []
-def _is_zipfile(f):
+def _is_zipfile(f) -> bool:
# This is a stricter implementation than zipfile.is_zipfile().
# zipfile.is_zipfile() is True if the magic number appears anywhere in the
# binary. Since we expect the files here to be generated by torch.save or
@@ -160,7 +161,7 @@
register_package(20, _cuda_tag, _cuda_deserialize)
-def location_tag(storage):
+def location_tag(storage: Storage):
for _, tagger, _ in _package_registry:
location = tagger(storage)
if location:
@@ -237,29 +238,30 @@
class _open_zipfile_reader(_opener):
- def __init__(self, name_or_buffer):
+ def __init__(self, name_or_buffer) -> None:
super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))
class _open_zipfile_writer_file(_opener):
- def __init__(self, name):
+ def __init__(self, name) -> None:
super(_open_zipfile_writer_file, self).__init__(torch._C.PyTorchFileWriter(str(name)))
- def __exit__(self, *args):
+ def __exit__(self, *args) -> None:
self.file_like.write_end_of_file()
class _open_zipfile_writer_buffer(_opener):
- def __init__(self, buffer):
+ def __init__(self, buffer) -> None:
self.buffer = buffer
super(_open_zipfile_writer_buffer, self).__init__(torch._C.PyTorchFileWriter(buffer))
- def __exit__(self, *args):
+ def __exit__(self, *args) -> None:
self.file_like.write_end_of_file()
self.buffer.flush()
def _open_zipfile_writer(name_or_buffer):
+ container: Type[_opener]
if _is_path(name_or_buffer):
container = _open_zipfile_writer_file
else:
@@ -267,7 +269,7 @@
return container(name_or_buffer)
-def _is_compressed_file(f):
+def _is_compressed_file(f) -> bool:
compress_modules = ['gzip']
try:
return f.__module__ in compress_modules
@@ -291,7 +293,7 @@
return False
-def _check_seekable(f):
+def _check_seekable(f) -> bool:
def raise_err_msg(patterns, e):
for p in patterns:
@@ -307,8 +309,9 @@
return True
except (io.UnsupportedOperation, AttributeError) as e:
raise_err_msg(["seek", "tell"], e)
+ return False
-def _check_dill_version(pickle_module):
+def _check_dill_version(pickle_module) -> None:
'''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
If dill version is lower than 0.3.1, a ValueError is raised.
@@ -327,7 +330,8 @@
pickle_module.__version__
))
-def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True):
+def save(obj, f: Union[str, os.PathLike, BinaryIO],
+ pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:
"""Saves an object to a disk file.
See also: :ref:`recommend-saving-models`
@@ -370,12 +374,12 @@
_legacy_save(obj, opened_file, pickle_module, pickle_protocol)
-def _legacy_save(obj, f, pickle_module, pickle_protocol):
+def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
import torch.nn as nn
serialized_container_types = {}
serialized_storages = {}
- def persistent_id(obj):
+ def persistent_id(obj: Any) -> Optional[Tuple]:
# FIXME: the docs say that persistent_id should only return a string
# but torch store returns tuples. This works only in the binary protocol
# see
@@ -396,6 +400,8 @@
return ('module', obj, source_file, source)
elif torch.is_storage(obj):
+ view_metadata: Optional[Tuple[str, int, int]]
+ obj = cast(Storage, obj)
storage_type = normalize_storage_type(type(obj))
# Offset is always 0, but we keep it for backwards compatibility
# with the old serialization format (which supported storage views)
@@ -589,20 +595,20 @@
def _get_layout(name):
"""Get layout extension object from its string representation.
"""
- cache = _get_layout.cache
+ cache = _get_layout.cache # type: ignore[attr-defined]
if not cache:
for v in torch.__dict__.values():
if isinstance(v, torch.layout):
cache[str(v)] = v
return cache[name]
-
-_get_layout.cache = {}
+# There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087
+_get_layout.cache = {} # type: ignore[attr-defined]
copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
- deserialized_objects = {}
+ deserialized_objects: Dict[int, Any] = {}
restore_location = _get_restore_location(map_location)
@@ -648,7 +654,7 @@
warnings.warn(msg, SourceChangeWarning)
def legacy_load(f):
- deserialized_objects = {}
+ deserialized_objects: Dict[int, Any] = {}
def persistent_load(saved_id):
if isinstance(saved_id, tuple):
@@ -777,7 +783,7 @@
return result
-def _maybe_decode_ascii(bytes_str):
+def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
# When using encoding='bytes' in Py3, some **internal** keys stored as
# strings in Py2 are loaded as bytes. This function decodes them with
# ascii encoding, one that Py3 uses by default.
diff --git a/torch/types.py b/torch/types.py
index be86dfd..ef3c68e 100644
--- a/torch/types.py
+++ b/torch/types.py
@@ -1,5 +1,5 @@
import torch
-from typing import Union, Sequence, List, Tuple
+from typing import Any, List, Sequence, Tuple, Union
import builtins
@@ -29,3 +29,15 @@
# literal device object). This nomenclature is consistent with PythonArgParser.
# None means use the default device (typically CPU)
Device = Union[_device, str, None]
+
+# Storage protocol implemented by ${Type}StorageBase classes
+class Storage(object):
+ _cdata: int
+
+ def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool) -> None:
+ ...
+
+ def size(self) -> int:
+ ...
+
+ ...