Type-annotate serialization.py (#40862)

Summary:
Move Storage class from __init__.pyi.in to types.py and make it a protocol, since this is not a real class
Expose `PyTorchFileReader` and `PyTorchFileWriter` native classes

Ignore function attributes, as there are yet no good way to type annotate those, see https://github.com/python/mypy/issues/2087
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40862

Differential Revision: D22344743

Pulled By: malfet

fbshipit-source-id: 95cdb6f980ee79383960f306223e170c63df3232
diff --git a/mypy.ini b/mypy.ini
index 4ef24b6..0583158 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -317,9 +317,6 @@
 [mypy-torch.nn.functional]
 ignore_errors = True
 
-[mypy-torch.serialization]
-ignore_errors = True
-
 [mypy-torch.utils]
 ignore_errors = True
 
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index e043621..1ddc9d7 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -2,10 +2,11 @@
 
 import torch
 from torch import Tensor
-from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, TypeVar, Type
+from typing import (Any, BinaryIO, Callable, ContextManager, Iterator, List, NamedTuple,
+        Optional, overload, Sequence, Tuple, TypeVar, Type, Union)
 from torch._six import inf
 
-from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Number, Device
+from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage
 
 import builtins
 
@@ -90,9 +91,6 @@
 # Defined in torch/csrc/utils/tensor_qschemes.cpp
 per_tensor_affine: qscheme = ...
 
-# Defined in torch/csrc/generic/Storage.cpp
-class Storage: ...
-
 # Defined in torch/csrc/autograd/python_function.cpp
 class _FunctionBase(object):
     ...
@@ -169,6 +167,24 @@
     # TODO
     ...
 
+# Defined in torch/csrc/jit/python/init.cpp
+class PyTorchFileReader(object):
+    @overload
+    def __init__(self, name: str) -> None: ...
+    @overload
+    def __init__(self, buffer: BinaryIO) -> None: ...
+    def get_record(self, name: str) -> bytes: ...
+    ...
+
+class PyTorchFileWriter(object):
+    @overload
+    def __init__(self, name: str) -> None: ...
+    @overload
+    def __init__(self, buffer: BinaryIO) -> None: ...
+    def write_record(self, name: str, data: bytes, size: _int) -> None: ...
+    def write_end_of_file(self) -> None: ...
+    ...
+
 # Defined in torch/csrc/Generator.cpp
 class Generator(object):
     device: _device
diff --git a/torch/__init__.py b/torch/__init__.py
index 4937fba..fd38164 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -311,11 +311,6 @@
     """
     return _C._get_deterministic()
 
-# If you edit these imports, please update torch/__init__.py.in as well
-from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
-from .serialization import save, load
-from ._tensor_str import set_printoptions
-
 ################################################################################
 # Define Storage and Tensor classes
 ################################################################################
@@ -388,6 +383,10 @@
 # The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
 _tensor_classes: Set[Type] = set()
 
+# If you edit these imports, please update torch/__init__.py.in as well
+from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
+from .serialization import save, load
+from ._tensor_str import set_printoptions
 
 ################################################################################
 # Initialize extension
diff --git a/torch/serialization.py b/torch/serialization.py
index c19a148..62642a3 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -12,6 +12,8 @@
 from ._utils import _import_dotted_name
 from ._six import string_classes as _string_classes
 from torch._utils_internal import get_source_lines_and_file
+from torch.types import Storage
+from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union
 import copyreg
 import pickle
 import pathlib
@@ -26,7 +28,6 @@
 PROTOCOL_VERSION = 1001
 STORAGE_KEY_SEPARATOR = ','
 
-
 class SourceChangeWarning(Warning):
     pass
 
@@ -41,7 +42,7 @@
 _package_registry = []
 
 
-def _is_zipfile(f):
+def _is_zipfile(f) -> bool:
     # This is a stricter implementation than zipfile.is_zipfile().
     # zipfile.is_zipfile() is True if the magic number appears anywhere in the
     # binary. Since we expect the files here to be generated by torch.save or
@@ -160,7 +161,7 @@
 register_package(20, _cuda_tag, _cuda_deserialize)
 
 
-def location_tag(storage):
+def location_tag(storage: Storage):
     for _, tagger, _ in _package_registry:
         location = tagger(storage)
         if location:
@@ -237,29 +238,30 @@
 
 
 class _open_zipfile_reader(_opener):
-    def __init__(self, name_or_buffer):
+    def __init__(self, name_or_buffer) -> None:
         super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))
 
 
 class _open_zipfile_writer_file(_opener):
-    def __init__(self, name):
+    def __init__(self, name) -> None:
         super(_open_zipfile_writer_file, self).__init__(torch._C.PyTorchFileWriter(str(name)))
 
-    def __exit__(self, *args):
+    def __exit__(self, *args) -> None:
         self.file_like.write_end_of_file()
 
 
 class _open_zipfile_writer_buffer(_opener):
-    def __init__(self, buffer):
+    def __init__(self, buffer) -> None:
         self.buffer = buffer
         super(_open_zipfile_writer_buffer, self).__init__(torch._C.PyTorchFileWriter(buffer))
 
-    def __exit__(self, *args):
+    def __exit__(self, *args) -> None:
         self.file_like.write_end_of_file()
         self.buffer.flush()
 
 
 def _open_zipfile_writer(name_or_buffer):
+    container: Type[_opener]
     if _is_path(name_or_buffer):
         container = _open_zipfile_writer_file
     else:
@@ -267,7 +269,7 @@
     return container(name_or_buffer)
 
 
-def _is_compressed_file(f):
+def _is_compressed_file(f) -> bool:
     compress_modules = ['gzip']
     try:
         return f.__module__ in compress_modules
@@ -291,7 +293,7 @@
         return False
 
 
-def _check_seekable(f):
+def _check_seekable(f) -> bool:
 
     def raise_err_msg(patterns, e):
         for p in patterns:
@@ -307,8 +309,9 @@
         return True
     except (io.UnsupportedOperation, AttributeError) as e:
         raise_err_msg(["seek", "tell"], e)
+    return False
 
-def _check_dill_version(pickle_module):
+def _check_dill_version(pickle_module) -> None:
     '''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
     If dill version is lower than 0.3.1, a ValueError is raised.
 
@@ -327,7 +330,8 @@
                 pickle_module.__version__
             ))
 
-def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True):
+def save(obj, f: Union[str, os.PathLike, BinaryIO],
+         pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:
     """Saves an object to a disk file.
 
     See also: :ref:`recommend-saving-models`
@@ -370,12 +374,12 @@
         _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
 
 
-def _legacy_save(obj, f, pickle_module, pickle_protocol):
+def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
     import torch.nn as nn
     serialized_container_types = {}
     serialized_storages = {}
 
-    def persistent_id(obj):
+    def persistent_id(obj: Any) -> Optional[Tuple]:
         # FIXME: the docs say that persistent_id should only return a string
         # but torch store returns tuples. This works only in the binary protocol
         # see
@@ -396,6 +400,8 @@
             return ('module', obj, source_file, source)
 
         elif torch.is_storage(obj):
+            view_metadata: Optional[Tuple[str, int, int]]
+            obj = cast(Storage, obj)
             storage_type = normalize_storage_type(type(obj))
             # Offset is always 0, but we keep it for backwards compatibility
             # with the old serialization format (which supported storage views)
@@ -589,20 +595,20 @@
 def _get_layout(name):
     """Get layout extension object from its string representation.
     """
-    cache = _get_layout.cache
+    cache = _get_layout.cache   # type: ignore[attr-defined]
     if not cache:
         for v in torch.__dict__.values():
             if isinstance(v, torch.layout):
                 cache[str(v)] = v
     return cache[name]
 
-
-_get_layout.cache = {}
+# There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087
+_get_layout.cache = {}   # type: ignore[attr-defined]
 copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
 
 
 def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
-    deserialized_objects = {}
+    deserialized_objects: Dict[int, Any] = {}
 
     restore_location = _get_restore_location(map_location)
 
@@ -648,7 +654,7 @@
             warnings.warn(msg, SourceChangeWarning)
 
     def legacy_load(f):
-        deserialized_objects = {}
+        deserialized_objects: Dict[int, Any] = {}
 
         def persistent_load(saved_id):
             if isinstance(saved_id, tuple):
@@ -777,7 +783,7 @@
     return result
 
 
-def _maybe_decode_ascii(bytes_str):
+def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
     # When using encoding='bytes' in Py3, some **internal** keys stored as
     # strings in Py2 are loaded as bytes. This function decodes them with
     # ascii encoding, one that Py3 uses by default.
diff --git a/torch/types.py b/torch/types.py
index be86dfd..ef3c68e 100644
--- a/torch/types.py
+++ b/torch/types.py
@@ -1,5 +1,5 @@
 import torch
-from typing import Union, Sequence, List, Tuple
+from typing import Any, List, Sequence, Tuple, Union
 
 import builtins
 
@@ -29,3 +29,15 @@
 # literal device object).  This nomenclature is consistent with PythonArgParser.
 # None means use the default device (typically CPU)
 Device = Union[_device, str, None]
+
+# Storage protocol implemented by ${Type}StorageBase classes
+class Storage(object):
+    _cdata: int
+
+    def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool) -> None:
+        ...
+
+    def size(self) -> int:
+        ...
+
+    ...