Type-annotate serialization.py (#40862)

Summary:
Move the `Storage` class from `__init__.pyi.in` to `types.py` and make it a protocol, since it is not a real class
Expose `PyTorchFileReader` and `PyTorchFileWriter` native classes
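
For the second point, a rough sketch of the kind of stub declarations that expose the native zip-archive classes; only the members visible in the diff below are listed, and the signatures are assumptions rather than the real bindings:

```python
# Hypothetical stub sketch; the real declarations are generated from torch/_C/__init__.pyi.in
# and expose more members than shown here.
from typing import BinaryIO, Union


class PyTorchFileReader:
    def __init__(self, name_or_buffer: Union[str, BinaryIO]) -> None: ...


class PyTorchFileWriter:
    def __init__(self, name_or_buffer: Union[str, BinaryIO]) -> None: ...
    def write_end_of_file(self) -> None: ...
```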
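
For the first point, a minimal sketch of how `Storage` could be expressed as a structural type in `torch/types.py`; the member list is illustrative (methods the serialization code happens to call), not the actual definition:

```python
# Hypothetical sketch only -- the actual torch/types.py definition may differ.
from typing import Any

from typing_extensions import Protocol  # typing.Protocol on Python >= 3.8


class Storage(Protocol):
    def size(self) -> int: ...
    def element_size(self) -> int: ...
    # Used when writing storages in the legacy save path; exact signature left loose here.
    def _write_file(self, *args: Any, **kwargs: Any) -> None: ...
```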

Ignore function attributes, as there is not yet a good way to type-annotate them; see https://github.com/python/mypy/issues/2087
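
As a standalone illustration of the pattern applied to `_get_layout.cache` in the diff below (the names here are made up):

```python
def lookup(name: str) -> int:
    # Function attributes are not part of the callable's declared type,
    # so mypy flags this access; the targeted ignore silences only that error.
    return lookup.cache[name]  # type: ignore[attr-defined]


lookup.cache = {"a": 1}  # type: ignore[attr-defined]
```
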
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40862

Differential Revision: D22344743

Pulled By: malfet

fbshipit-source-id: 95cdb6f980ee79383960f306223e170c63df3232
diff --git a/torch/serialization.py b/torch/serialization.py
index c19a148..62642a3 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -12,6 +12,8 @@
 from ._utils import _import_dotted_name
 from ._six import string_classes as _string_classes
 from torch._utils_internal import get_source_lines_and_file
+from torch.types import Storage
+from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union
 import copyreg
 import pickle
 import pathlib
@@ -26,7 +28,6 @@
 PROTOCOL_VERSION = 1001
 STORAGE_KEY_SEPARATOR = ','
 
-
 class SourceChangeWarning(Warning):
     pass
 
@@ -41,7 +42,7 @@
 _package_registry = []
 
 
-def _is_zipfile(f):
+def _is_zipfile(f) -> bool:
     # This is a stricter implementation than zipfile.is_zipfile().
     # zipfile.is_zipfile() is True if the magic number appears anywhere in the
     # binary. Since we expect the files here to be generated by torch.save or
@@ -160,7 +161,7 @@
 register_package(20, _cuda_tag, _cuda_deserialize)
 
 
-def location_tag(storage):
+def location_tag(storage: Storage):
     for _, tagger, _ in _package_registry:
         location = tagger(storage)
         if location:
@@ -237,29 +238,30 @@
 
 
 class _open_zipfile_reader(_opener):
-    def __init__(self, name_or_buffer):
+    def __init__(self, name_or_buffer) -> None:
         super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))
 
 
 class _open_zipfile_writer_file(_opener):
-    def __init__(self, name):
+    def __init__(self, name) -> None:
         super(_open_zipfile_writer_file, self).__init__(torch._C.PyTorchFileWriter(str(name)))
 
-    def __exit__(self, *args):
+    def __exit__(self, *args) -> None:
         self.file_like.write_end_of_file()
 
 
 class _open_zipfile_writer_buffer(_opener):
-    def __init__(self, buffer):
+    def __init__(self, buffer) -> None:
         self.buffer = buffer
         super(_open_zipfile_writer_buffer, self).__init__(torch._C.PyTorchFileWriter(buffer))
 
-    def __exit__(self, *args):
+    def __exit__(self, *args) -> None:
         self.file_like.write_end_of_file()
         self.buffer.flush()
 
 
 def _open_zipfile_writer(name_or_buffer):
+    container: Type[_opener]
     if _is_path(name_or_buffer):
         container = _open_zipfile_writer_file
     else:
@@ -267,7 +269,7 @@
     return container(name_or_buffer)
 
 
-def _is_compressed_file(f):
+def _is_compressed_file(f) -> bool:
     compress_modules = ['gzip']
     try:
         return f.__module__ in compress_modules
@@ -291,7 +293,7 @@
         return False
 
 
-def _check_seekable(f):
+def _check_seekable(f) -> bool:
 
     def raise_err_msg(patterns, e):
         for p in patterns:
@@ -307,8 +309,9 @@
         return True
     except (io.UnsupportedOperation, AttributeError) as e:
         raise_err_msg(["seek", "tell"], e)
+    return False
 
-def _check_dill_version(pickle_module):
+def _check_dill_version(pickle_module) -> None:
     '''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
     If dill version is lower than 0.3.1, a ValueError is raised.
 
@@ -327,7 +330,8 @@
                 pickle_module.__version__
             ))
 
-def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True):
+def save(obj, f: Union[str, os.PathLike, BinaryIO],
+         pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:
     """Saves an object to a disk file.
 
     See also: :ref:`recommend-saving-models`
@@ -370,12 +374,12 @@
         _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
 
 
-def _legacy_save(obj, f, pickle_module, pickle_protocol):
+def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
     import torch.nn as nn
     serialized_container_types = {}
     serialized_storages = {}
 
-    def persistent_id(obj):
+    def persistent_id(obj: Any) -> Optional[Tuple]:
         # FIXME: the docs say that persistent_id should only return a string
         # but torch store returns tuples. This works only in the binary protocol
         # see
@@ -396,6 +400,8 @@
             return ('module', obj, source_file, source)
 
         elif torch.is_storage(obj):
+            view_metadata: Optional[Tuple[str, int, int]]
+            obj = cast(Storage, obj)
             storage_type = normalize_storage_type(type(obj))
             # Offset is always 0, but we keep it for backwards compatibility
             # with the old serialization format (which supported storage views)
@@ -589,20 +595,20 @@
 def _get_layout(name):
     """Get layout extension object from its string representation.
     """
-    cache = _get_layout.cache
+    cache = _get_layout.cache   # type: ignore[attr-defined]
     if not cache:
         for v in torch.__dict__.values():
             if isinstance(v, torch.layout):
                 cache[str(v)] = v
     return cache[name]
 
-
-_get_layout.cache = {}
+# There is not yet a good way to type-annotate function attributes, see https://github.com/python/mypy/issues/2087
+_get_layout.cache = {}   # type: ignore[attr-defined]
 copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
 
 
 def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
-    deserialized_objects = {}
+    deserialized_objects: Dict[int, Any] = {}
 
     restore_location = _get_restore_location(map_location)
 
@@ -648,7 +654,7 @@
             warnings.warn(msg, SourceChangeWarning)
 
     def legacy_load(f):
-        deserialized_objects = {}
+        deserialized_objects: Dict[int, Any] = {}
 
         def persistent_load(saved_id):
             if isinstance(saved_id, tuple):
@@ -777,7 +783,7 @@
     return result
 
 
-def _maybe_decode_ascii(bytes_str):
+def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
     # When using encoding='bytes' in Py3, some **internal** keys stored as
     # strings in Py2 are loaded as bytes. This function decodes them with
     # ascii encoding, one that Py3 uses by default.