Fix torch.hub for new zipfile format. (#42333)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/42239
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42333
Reviewed By: VitalyFedyunin
Differential Revision: D23215210
Pulled By: ailzhang
fbshipit-source-id: 161ead8b457c11655dd2cab5eecfd0edf7ae5c2b
diff --git a/test/test_utils.py b/test/test_utils.py
index 19b16c1..fea4758 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -571,6 +571,18 @@
self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
SUM_OF_HUB_EXAMPLE)
+ # Test the default zipfile serialization format produced by >=1.6 release.
+ @retry(URLError, tries=3, skip_after_retries=True)
+ def test_load_zip_1_6_checkpoint(self):
+ hub_model = hub.load(
+ 'ailzhang/torchhub_example',
+ 'mnist_zip_1_6',
+ pretrained=True,
+ verbose=False)
+ self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
+ SUM_OF_HUB_EXAMPLE)
+
+
def test_hub_dir(self):
with tempfile.TemporaryDirectory('hub_dir') as dirname:
torch.hub.set_dir(dirname)
diff --git a/torch/hub.py b/torch/hub.py
index b7a9ef8..23a7456 100644
--- a/torch/hub.py
+++ b/torch/hub.py
@@ -422,6 +422,31 @@
_download_url_to_file will be removed in after 1.3 release')
download_url_to_file(url, dst, hash_prefix, progress)
+# Hub used to support automatically extracts from zipfile manually compressed by users.
+# The legacy zip format expects only one file from torch.save() < 1.6 in the zip.
+# We should remove this support since zipfile is now default zipfile format for torch.save().
+def _is_legacy_zip_format(filename):
+ if zipfile.is_zipfile(filename):
+ infolist = zipfile.ZipFile(filename).infolist()
+ return len(infolist) == 1 and not infolist[0].is_dir()
+ return False
+
+def _legacy_zip_load(filename, model_dir, map_location):
+ warnings.warn('Falling back to the old format < 1.6. This support will be '
+ 'deprecated in favor of default zipfile format introduced in 1.6. '
+ 'Please redo torch.save() to save it in the new zipfile format.')
+ # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
+ # We deliberately don't handle tarfile here since our legacy serialization format was in tar.
+ # E.g. resnet18-5c106cde.pth which is widely used.
+ with zipfile.ZipFile(filename) as f:
+ members = f.infolist()
+ if len(members) != 1:
+ raise RuntimeError('Only one file(not dir) is allowed in the zipfile')
+ f.extractall(model_dir)
+ extraced_name = members[0].filename
+ extracted_file = os.path.join(model_dir, extraced_name)
+ return torch.load(extracted_file, map_location=map_location)
+
def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None):
r"""Loads the Torch serialized object at the given URL.
@@ -481,16 +506,6 @@
hash_prefix = r.group(1) if r else None
download_url_to_file(url, cached_file, hash_prefix, progress=progress)
- # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
- # We deliberately don't handle tarfile here since our legacy serialization format was in tar.
- # E.g. resnet18-5c106cde.pth which is widely used.
- if zipfile.is_zipfile(cached_file):
- with zipfile.ZipFile(cached_file) as cached_zipfile:
- members = cached_zipfile.infolist()
- if len(members) != 1:
- raise RuntimeError('Only one file(not dir) is allowed in the zipfile')
- cached_zipfile.extractall(model_dir)
- extraced_name = members[0].filename
- cached_file = os.path.join(model_dir, extraced_name)
-
+ if _is_legacy_zip_format(cached_file):
+ return _legacy_zip_load(cached_file, model_dir, map_location)
return torch.load(cached_file, map_location=map_location)