Fix torch.hub for new zipfile format. (#42333)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/42239

Pull Request resolved: https://github.com/pytorch/pytorch/pull/42333

Reviewed By: VitalyFedyunin

Differential Revision: D23215210

Pulled By: ailzhang

fbshipit-source-id: 161ead8b457c11655dd2cab5eecfd0edf7ae5c2b
diff --git a/test/test_utils.py b/test/test_utils.py
index 19b16c1..fea4758 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -571,6 +571,18 @@
         self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
                          SUM_OF_HUB_EXAMPLE)
 
+    # Test the default zipfile serialization format produced by >=1.6 release.
+    @retry(URLError, tries=3, skip_after_retries=True)
+    def test_load_zip_1_6_checkpoint(self):
+        hub_model = hub.load(
+            'ailzhang/torchhub_example',
+            'mnist_zip_1_6',
+            pretrained=True,
+            verbose=False)
+        self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
+                         SUM_OF_HUB_EXAMPLE)
+
+
     def test_hub_dir(self):
         with tempfile.TemporaryDirectory('hub_dir') as dirname:
             torch.hub.set_dir(dirname)
diff --git a/torch/hub.py b/torch/hub.py
index b7a9ef8..23a7456 100644
--- a/torch/hub.py
+++ b/torch/hub.py
@@ -422,6 +422,31 @@
             _download_url_to_file will be removed in after 1.3 release')
     download_url_to_file(url, dst, hash_prefix, progress)
 
+# Hub used to support automatically extracts from zipfile manually compressed by users.
+# The legacy zip format expects only one file from torch.save() < 1.6 in the zip.
+# We should remove this support since zipfile is now default zipfile format for torch.save().
+def _is_legacy_zip_format(filename):
+    if zipfile.is_zipfile(filename):
+        infolist = zipfile.ZipFile(filename).infolist()
+        return len(infolist) == 1 and not infolist[0].is_dir()
+    return False
+
+def _legacy_zip_load(filename, model_dir, map_location):
+    warnings.warn('Falling back to the old format < 1.6. This support will be '
+                  'deprecated in favor of default zipfile format introduced in 1.6. '
+                  'Please redo torch.save() to save it in the new zipfile format.')
+    # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
+    #       We deliberately don't handle tarfile here since our legacy serialization format was in tar.
+    #       E.g. resnet18-5c106cde.pth which is widely used.
+    with zipfile.ZipFile(filename) as f:
+        members = f.infolist()
+        if len(members) != 1:
+            raise RuntimeError('Only one file(not dir) is allowed in the zipfile')
+        f.extractall(model_dir)
+        extraced_name = members[0].filename
+        extracted_file = os.path.join(model_dir, extraced_name)
+    return torch.load(extracted_file, map_location=map_location)
+
 def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None):
     r"""Loads the Torch serialized object at the given URL.
 
@@ -481,16 +506,6 @@
             hash_prefix = r.group(1) if r else None
         download_url_to_file(url, cached_file, hash_prefix, progress=progress)
 
-    # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
-    #       We deliberately don't handle tarfile here since our legacy serialization format was in tar.
-    #       E.g. resnet18-5c106cde.pth which is widely used.
-    if zipfile.is_zipfile(cached_file):
-        with zipfile.ZipFile(cached_file) as cached_zipfile:
-            members = cached_zipfile.infolist()
-            if len(members) != 1:
-                raise RuntimeError('Only one file(not dir) is allowed in the zipfile')
-            cached_zipfile.extractall(model_dir)
-            extraced_name = members[0].filename
-            cached_file = os.path.join(model_dir, extraced_name)
-
+    if _is_legacy_zip_format(cached_file):
+        return _legacy_zip_load(cached_file, model_dir, map_location)
     return torch.load(cached_file, map_location=map_location)