Make upgrader test model generation more robust (#72030)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72030
Test Plan: Imported from OSS
Reviewed By: mrshenli
Differential Revision: D33863263
Pulled By: tugsbayasgalan
fbshipit-source-id: 931578848ba530583008be6540003b2dcf4d55ce
(cherry picked from commit 67cd085104631264eb12c2c808eb4ed7b973a652)
diff --git a/test/jit/fixtures_srcs/generate_models.py b/test/jit/fixtures_srcs/generate_models.py
index 095154c..d4c554d 100644
--- a/test/jit/fixtures_srcs/generate_models.py
+++ b/test/jit/fixtures_srcs/generate_models.py
@@ -135,8 +135,12 @@
torch.jit.save(script_module, buffer)
buffer.seek(0)
zipped_model = zipfile.ZipFile(buffer)
- version = int(zipped_model.read('archive/version').decode("utf-8"))
- return version
+ try:
+ version = int(zipped_model.read('archive/version').decode("utf-8"))
+ return version
+ except KeyError:
+ version = int(zipped_model.read('archive/.data/version').decode("utf-8"))
+ return version
"""
Loop through all test modules. If the corresponding model doesn't exist in
@@ -158,9 +162,6 @@
all_models = get_all_models(model_directory_path)
for a_module, expect_operator in ALL_MODULES.items():
print(a_module, expect_operator)
- script_module = torch.jit.script(a_module)
- actual_model_version = get_output_model_version(script_module)
-
# For example: TestVersionedDivTensorExampleV7
torch_module_name = type(a_module).__name__
@@ -169,11 +170,16 @@
'_' + char.lower() if char.isupper() else char for char in torch_module_name
]).lstrip('_')
+ # Some models may not compile anymore, so skip the ones
+ # that already has pt file for them.
logger.info(f"Processing {torch_module_name}")
if model_exist(model_name, all_models):
logger.info(f"Model {model_name} already exists, skipping")
continue
+ script_module = torch.jit.script(a_module)
+ actual_model_version = get_output_model_version(script_module)
+
current_operator_version = torch._C._get_max_operator_version()
if actual_model_version >= current_operator_version + 1:
logger.error(