Improve disable name match (#71499)

Summary:
Allows disabling issues to disable all parametrized tests with dtypes.

Tested locally with:
1. .pytorch-disabled-tests.json as
```
{"test_bitwise_ops (__main__.TestBinaryUfuncs)": ["https://github.com/pytorch/pytorch/issues/99999", ["mac"]]}
```
and running `python test_binary_ufuncs.py --import-disabled-tests -k test_bitwise_ops` yields all tests skipped.

2. .pytorch-disabled-tests.json as
```
{"test_bitwise_ops_cpu_int16 (__main__.TestBinaryUfuncsCPU)": ["https://github.com/pytorch/pytorch/issues/99999", ["mac"]]}
```
and running `python test_binary_ufuncs.py --import-disabled-tests -k test_bitwise_ops` yields only `test_bitwise_ops_cpu_int16` skipped.

NOTE: this only works with dtype parametrization, not all prefixes, e.g., disabling `test_async_script` would NOT disable `test_async_script_capture`. This is the most intuitive behavior, I believe, but I can be convinced otherwise.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71499

Reviewed By: mruberry

Differential Revision: D33742723

Pulled By: janeyx99

fbshipit-source-id: 98a84f9e80402978fa8d22e0f018e6c6c4339a72
(cherry picked from commit 3f778919caebd3f5cae13963b4824088543e2311)
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 1bf7413..41e15b4 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -61,6 +61,7 @@
 from statistics import mean
 import functools
 from .composite_compliance import no_dispatch
+from torch.testing._internal.common_dtype import get_all_dtypes
 
 torch.backends.disable_global_flags()
 
@@ -1370,31 +1371,57 @@
 except ImportError:
     print('Fail to import hypothesis in common_utils, tests are not derandomized')
 
+# Used in check_if_enable to see if a test method should be disabled by an issue,
+# sanitizes a test method name from appended suffixes by @dtypes parametrization.
+# e.g., an issue with title "DISABLED test_bitwise_ops (__main__.TestBinaryUfuncs)" should
+# disabled ALL parametrized test_bitwise_ops tests, such test_bitwise_ops_cuda_int32
+def remove_device_and_dtype_suffixes(test_name: str) -> str:
+    # import statement is localized to avoid circular dependency issues with common_device_type.py
+    from torch.testing._internal.common_device_type import get_device_type_test_bases
+    device_suffixes = [x.device_type for x in get_device_type_test_bases()]
+    dtype_suffixes = [str(dt)[len("torch."):] for dt in get_all_dtypes()]
+
+    test_name_chunks = test_name.split("_")
+    if len(test_name_chunks) > 0 and test_name_chunks[-1] in dtype_suffixes:
+        if len(test_name_chunks) > 1 and test_name_chunks[-2] in device_suffixes:
+            return "_".join(test_name_chunks[0:-2])
+        return "_".join(test_name_chunks[0:-1])
+    return test_name
+
+
 def check_if_enable(test: unittest.TestCase):
     test_suite = str(test.__class__).split('\'')[1]
-    test_name = f'{test._testMethodName} ({test_suite})'
-    if slow_tests_dict is not None and test_name in slow_tests_dict:
+    raw_test_name = f'{test._testMethodName} ({test_suite})'
+    if slow_tests_dict is not None and raw_test_name in slow_tests_dict:
         getattr(test, test._testMethodName).__dict__['slow_test'] = True
         if not TEST_WITH_SLOW:
             raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
+    sanitized_test_method_name = remove_device_and_dtype_suffixes(test._testMethodName)
     if not IS_SANDCASTLE and disabled_tests_dict is not None:
-        if test_name in disabled_tests_dict:
-            issue_url, platforms = disabled_tests_dict[test_name]
-            platform_to_conditional: Dict = {
-                "mac": IS_MACOS,
-                "macos": IS_MACOS,
-                "win": IS_WINDOWS,
-                "windows": IS_WINDOWS,
-                "linux": IS_LINUX,
-                "rocm": TEST_WITH_ROCM,
-                "asan": TEST_WITH_ASAN
-            }
-            if platforms == [] or any([platform_to_conditional[platform] for platform in platforms]):
-                raise unittest.SkipTest(
-                    f"Test is disabled because an issue exists disabling it: {issue_url}" +
-                    f" for {'all' if platforms == [] else ''}platform(s) {', '.join(platforms)}. " +
-                    "If you're seeing this on your local machine and would like to enable this test, " +
-                    "please make sure IN_CI is not set and you are not using the flag --import-disabled-tests.")
+        for disabled_test, (issue_url, platforms) in disabled_tests_dict.items():
+            disable_test_parts = disabled_test.split()
+            if len(disable_test_parts) > 1:
+                disabled_test_name = disable_test_parts[0]
+                disabled_test_suite = disable_test_parts[1][1:-1]
+                # if test method name or its sanitized version exactly matches the disabled test method name
+                # AND allow non-parametrized suite names to disable parametrized ones (TestSuite disables TestSuiteCPU)
+                if (test._testMethodName == disabled_test_name or sanitized_test_method_name == disabled_test_name) \
+                   and disabled_test_suite in test_suite:
+                    platform_to_conditional: Dict = {
+                        "mac": IS_MACOS,
+                        "macos": IS_MACOS,
+                        "win": IS_WINDOWS,
+                        "windows": IS_WINDOWS,
+                        "linux": IS_LINUX,
+                        "rocm": TEST_WITH_ROCM,
+                        "asan": TEST_WITH_ASAN
+                    }
+                    if platforms == [] or any([platform_to_conditional[platform] for platform in platforms]):
+                        skip_msg = f"Test is disabled because an issue exists disabling it: {issue_url}" \
+                            f" for {'all' if platforms == [] else ''}platform(s) {', '.join(platforms)}. " \
+                            "If you're seeing this on your local machine and would like to enable this test, " \
+                            "please make sure IN_CI is not set and you are not using the flag --import-disabled-tests."
+                        raise unittest.SkipTest(skip_msg)
     if TEST_SKIP_FAST:
         if not getattr(test, test._testMethodName).__dict__.get('slow_test', False):
             raise unittest.SkipTest("test is fast; we disabled it with PYTORCH_TEST_SKIP_FAST")