Improve disable name match (#71499)
Summary:
Allows disabling issues to disable all parametrized tests with dtypes.
Tested locally with:
1. .pytorch-disabled-tests.json as
```
{"test_bitwise_ops (__main__.TestBinaryUfuncs)": ["https://github.com/pytorch/pytorch/issues/99999", ["mac"]]}
```
and running `python test_binary_ufuncs.py --import-disabled-tests -k test_bitwise_ops` yields all tests skipped.
2. .pytorch-disabled-tests.json as
```
{"test_bitwise_ops_cpu_int16 (__main__.TestBinaryUfuncsCPU)": ["https://github.com/pytorch/pytorch/issues/99999", ["mac"]]}
```
and running `python test_binary_ufuncs.py --import-disabled-tests -k test_bitwise_ops` yields only `test_bitwise_ops_cpu_int16` skipped.
NOTE: this only works with dtype parametrization, not all prefixes, e.g., disabling `test_async_script` would NOT disable `test_async_script_capture`. This is the most intuitive behavior, I believe, but I can be convinced otherwise.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71499
Reviewed By: mruberry
Differential Revision: D33742723
Pulled By: janeyx99
fbshipit-source-id: 98a84f9e80402978fa8d22e0f018e6c6c4339a72
(cherry picked from commit 3f778919caebd3f5cae13963b4824088543e2311)
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 1bf7413..41e15b4 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -61,6 +61,7 @@
from statistics import mean
import functools
from .composite_compliance import no_dispatch
+from torch.testing._internal.common_dtype import get_all_dtypes
torch.backends.disable_global_flags()
@@ -1370,31 +1371,57 @@
except ImportError:
print('Fail to import hypothesis in common_utils, tests are not derandomized')
+# Used in check_if_enable to see if a test method should be disabled by an issue,
+# sanitizes a test method name from appended suffixes by @dtypes parametrization.
+# e.g., an issue with title "DISABLED test_bitwise_ops (__main__.TestBinaryUfuncs)" should
+# disabled ALL parametrized test_bitwise_ops tests, such test_bitwise_ops_cuda_int32
+def remove_device_and_dtype_suffixes(test_name: str) -> str:
+ # import statement is localized to avoid circular dependency issues with common_device_type.py
+ from torch.testing._internal.common_device_type import get_device_type_test_bases
+ device_suffixes = [x.device_type for x in get_device_type_test_bases()]
+ dtype_suffixes = [str(dt)[len("torch."):] for dt in get_all_dtypes()]
+
+ test_name_chunks = test_name.split("_")
+ if len(test_name_chunks) > 0 and test_name_chunks[-1] in dtype_suffixes:
+ if len(test_name_chunks) > 1 and test_name_chunks[-2] in device_suffixes:
+ return "_".join(test_name_chunks[0:-2])
+ return "_".join(test_name_chunks[0:-1])
+ return test_name
+
+
def check_if_enable(test: unittest.TestCase):
test_suite = str(test.__class__).split('\'')[1]
- test_name = f'{test._testMethodName} ({test_suite})'
- if slow_tests_dict is not None and test_name in slow_tests_dict:
+ raw_test_name = f'{test._testMethodName} ({test_suite})'
+ if slow_tests_dict is not None and raw_test_name in slow_tests_dict:
getattr(test, test._testMethodName).__dict__['slow_test'] = True
if not TEST_WITH_SLOW:
raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
+ sanitized_test_method_name = remove_device_and_dtype_suffixes(test._testMethodName)
if not IS_SANDCASTLE and disabled_tests_dict is not None:
- if test_name in disabled_tests_dict:
- issue_url, platforms = disabled_tests_dict[test_name]
- platform_to_conditional: Dict = {
- "mac": IS_MACOS,
- "macos": IS_MACOS,
- "win": IS_WINDOWS,
- "windows": IS_WINDOWS,
- "linux": IS_LINUX,
- "rocm": TEST_WITH_ROCM,
- "asan": TEST_WITH_ASAN
- }
- if platforms == [] or any([platform_to_conditional[platform] for platform in platforms]):
- raise unittest.SkipTest(
- f"Test is disabled because an issue exists disabling it: {issue_url}" +
- f" for {'all' if platforms == [] else ''}platform(s) {', '.join(platforms)}. " +
- "If you're seeing this on your local machine and would like to enable this test, " +
- "please make sure IN_CI is not set and you are not using the flag --import-disabled-tests.")
+ for disabled_test, (issue_url, platforms) in disabled_tests_dict.items():
+ disable_test_parts = disabled_test.split()
+ if len(disable_test_parts) > 1:
+ disabled_test_name = disable_test_parts[0]
+ disabled_test_suite = disable_test_parts[1][1:-1]
+ # if test method name or its sanitized version exactly matches the disabled test method name
+ # AND allow non-parametrized suite names to disable parametrized ones (TestSuite disables TestSuiteCPU)
+ if (test._testMethodName == disabled_test_name or sanitized_test_method_name == disabled_test_name) \
+ and disabled_test_suite in test_suite:
+ platform_to_conditional: Dict = {
+ "mac": IS_MACOS,
+ "macos": IS_MACOS,
+ "win": IS_WINDOWS,
+ "windows": IS_WINDOWS,
+ "linux": IS_LINUX,
+ "rocm": TEST_WITH_ROCM,
+ "asan": TEST_WITH_ASAN
+ }
+ if platforms == [] or any([platform_to_conditional[platform] for platform in platforms]):
+ skip_msg = f"Test is disabled because an issue exists disabling it: {issue_url}" \
+ f" for {'all' if platforms == [] else ''}platform(s) {', '.join(platforms)}. " \
+ "If you're seeing this on your local machine and would like to enable this test, " \
+ "please make sure IN_CI is not set and you are not using the flag --import-disabled-tests."
+ raise unittest.SkipTest(skip_msg)
if TEST_SKIP_FAST:
if not getattr(test, test._testMethodName).__dict__.get('slow_test', False):
raise unittest.SkipTest("test is fast; we disabled it with PYTORCH_TEST_SKIP_FAST")