Generic test parametrization functionality (#60753)

Summary:
This PR plays around with implementation & usage of a `parametrize` decorator for test parametrization similar to `pytest.mark.parametrize`, based on previous work introducing a `_TestParametrizer` class. It works with the internal `DeviceTest` hierarchy & composes with `dtype`, `skip*`, and other decorators. Basic usage is demonstrated in `test/test_blah.py`:

```python
import unittest
from itertools import product
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, deviceCountAtLeast, ops)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import (
    TestCase, run_tests, parametrize, instantiate_parametrized_tests, subtest)

class TestBlah(TestCase):
    parametrize("x", range(5))
    def test_default_names(self, x):
        print('Passed in:', x)

    # Use default names but add an expected failure.
    parametrize("x", [subtest(0, decorators=[unittest.expectedFailure]),
                       *range(1, 5)])
    def test_default_names_expected_failure(self, x):
        if x == 0:
            raise RuntimeError('Boom')
        print('Passed in:', x)

    parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias')
    def test_custom_names(self, bias):
        print('Passed in:', bias)

    parametrize("bias", [subtest(True, name='bias'),
                          subtest(False, name='no_bias')])
    def test_custom_names_alternate(self, bias):
        print('Passed in:', bias)

    parametrize("x,y", [(1, 2), (1, 3), (1, 4)])
    def test_two_things_default_names(self, x, y):
        print('Passed in:', x, y)

    parametrize("x", [1, 2, 3])
    parametrize("y", [4, 5, 6])
    def test_two_things_composition(self, x, y):
        print('Passed in:', x, y)

    parametrize("x", [subtest(0, decorators=[unittest.expectedFailure]),
                       *range(1, 3)])
    parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])])
    def test_two_things_composition_expected_failure(self, x, y):
        if x == 0 or y == 6:
            raise RuntimeError('Boom')
        print('Passed in:', x, y)

    parametrize("x", [1, 2])
    parametrize("y", [3, 4])
    parametrize("z", [5, 6])
    def test_three_things_composition(self, x, y, z):
        print('Passed in:', x, y, z)

    parametrize("x", [1, 2], name_fn=str)
    parametrize("y", [3, 4], name_fn=str)
    parametrize("z", [5, 6], name_fn=str)
    def test_three_things_composition_custom_names(self, x, y, z):
        print('Passed in:', x, y, z)

    parametrize("x,y", product(range(2), range(3)))
    def test_two_things_product(self, x, y):
        print('Passed in:', x, y)

    parametrize("x,y", [subtest((1, 2), name='double'),
                         subtest((1, 3), name='triple'),
                         subtest((1, 4), name='quadruple')])
    def test_two_things_custom_names(self, x, y):
        print('Passed in:', x, y)

    parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: '{}_{}'.format(x, y))
    def test_two_things_custom_names_alternate(self, x, y):
        print('Passed in:', x, y)

class TestDeviceBlah(TestCase):
    parametrize("x", range(10))
    def test_default_names(self, device, x):
        print('Passed in:', device, x)

    parametrize("x,y", [(1, 2), (3, 4), (5, 6)])
    def test_two_things(self, device, x, y):
        print('Passed in:', device, x, y)

    deviceCountAtLeast(1)
    def test_multiple_devices(self, devices):
        print('Passed in:', devices)

    ops(op_db)
    parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled')
    def test_op_parametrized(self, device, dtype, op, flag):
        print('Passed in:', device, dtype, op, flag)

instantiate_parametrized_tests(TestBlah)
instantiate_device_type_tests(TestDeviceBlah, globals())

if __name__ == '__main__':
    run_tests()
```

Generated tests:
```
TestBlah.test_custom_names_alternate_bias
TestBlah.test_custom_names_alternate_no_bias
TestBlah.test_custom_names_bias
TestBlah.test_custom_names_no_bias
TestBlah.test_default_names_expected_failure_x_0
TestBlah.test_default_names_expected_failure_x_1
TestBlah.test_default_names_expected_failure_x_2
TestBlah.test_default_names_expected_failure_x_3
TestBlah.test_default_names_expected_failure_x_4
TestBlah.test_default_names_x_0
TestBlah.test_default_names_x_1
TestBlah.test_default_names_x_2
TestBlah.test_default_names_x_3
TestBlah.test_default_names_x_4
TestBlah.test_three_things_composition_custom_names_1_3_5
TestBlah.test_three_things_composition_custom_names_1_3_6
TestBlah.test_three_things_composition_custom_names_1_4_5
TestBlah.test_three_things_composition_custom_names_1_4_6
TestBlah.test_three_things_composition_custom_names_2_3_5
TestBlah.test_three_things_composition_custom_names_2_3_6
TestBlah.test_three_things_composition_custom_names_2_4_5
TestBlah.test_three_things_composition_custom_names_2_4_6
TestBlah.test_three_things_composition_x_1_y_3_z_5
TestBlah.test_three_things_composition_x_1_y_3_z_6
TestBlah.test_three_things_composition_x_1_y_4_z_5
TestBlah.test_three_things_composition_x_1_y_4_z_6
TestBlah.test_three_things_composition_x_2_y_3_z_5
TestBlah.test_three_things_composition_x_2_y_3_z_6
TestBlah.test_three_things_composition_x_2_y_4_z_5
TestBlah.test_three_things_composition_x_2_y_4_z_6
TestBlah.test_two_things_composition_expected_failure_x_0_y_4
TestBlah.test_two_things_composition_expected_failure_x_0_y_5
TestBlah.test_two_things_composition_expected_failure_x_0_y_6
TestBlah.test_two_things_composition_expected_failure_x_1_y_4
TestBlah.test_two_things_composition_expected_failure_x_1_y_5
TestBlah.test_two_things_composition_expected_failure_x_1_y_6
TestBlah.test_two_things_composition_expected_failure_x_2_y_4
TestBlah.test_two_things_composition_expected_failure_x_2_y_5
TestBlah.test_two_things_composition_expected_failure_x_2_y_6
TestBlah.test_two_things_composition_x_1_y_4
TestBlah.test_two_things_composition_x_1_y_5
TestBlah.test_two_things_composition_x_1_y_6
TestBlah.test_two_things_composition_x_2_y_4
TestBlah.test_two_things_composition_x_2_y_5
TestBlah.test_two_things_composition_x_2_y_6
TestBlah.test_two_things_composition_x_3_y_4
TestBlah.test_two_things_composition_x_3_y_5
TestBlah.test_two_things_composition_x_3_y_6
TestBlah.test_two_things_custom_names_alternate_1_2
TestBlah.test_two_things_custom_names_alternate_1_3
TestBlah.test_two_things_custom_names_alternate_1_4
TestBlah.test_two_things_custom_names_double
TestBlah.test_two_things_custom_names_quadruple
TestBlah.test_two_things_custom_names_triple
TestBlah.test_two_things_default_names_x_1_y_2
TestBlah.test_two_things_default_names_x_1_y_3
TestBlah.test_two_things_default_names_x_1_y_4
TestBlah.test_two_things_product_x_0_y_0
TestBlah.test_two_things_product_x_0_y_1
TestBlah.test_two_things_product_x_0_y_2
TestBlah.test_two_things_product_x_1_y_0
TestBlah.test_two_things_product_x_1_y_1
TestBlah.test_two_things_product_x_1_y_2
TestDeviceBlahCPU.test_default_names_x_0_cpu
TestDeviceBlahCPU.test_default_names_x_1_cpu
TestDeviceBlahCPU.test_default_names_x_2_cpu
TestDeviceBlahCPU.test_default_names_x_3_cpu
TestDeviceBlahCPU.test_default_names_x_4_cpu
TestDeviceBlahCPU.test_default_names_x_5_cpu
TestDeviceBlahCPU.test_default_names_x_6_cpu
TestDeviceBlahCPU.test_default_names_x_7_cpu
TestDeviceBlahCPU.test_default_names_x_8_cpu
TestDeviceBlahCPU.test_default_names_x_9_cpu
TestDeviceBlahCPU.test_multiple_devices_cpu
TestDeviceBlahCPU.test_op_parametrized_<opname>_<variant>_cpu_uint8_flag_enabled_cpu
TestDeviceBlahCPU.test_two_things_x_1_y_2_cpu
TestDeviceBlahCPU.test_two_things_x_3_y_4_cpu
TestDeviceBlahCPU.test_two_things_x_5_y_6_cpu
TestDeviceBlahMETA.test_default_names_x_0_meta
TestDeviceBlahMETA.test_default_names_x_1_meta
TestDeviceBlahMETA.test_default_names_x_2_meta
TestDeviceBlahMETA.test_default_names_x_3_meta
TestDeviceBlahMETA.test_default_names_x_4_meta
TestDeviceBlahMETA.test_default_names_x_5_meta
TestDeviceBlahMETA.test_default_names_x_6_meta
TestDeviceBlahMETA.test_default_names_x_7_meta
TestDeviceBlahMETA.test_default_names_x_8_meta
TestDeviceBlahMETA.test_default_names_x_9_meta
TestDeviceBlahMETA.test_multiple_devices_meta
TestDeviceBlahMETA.test_op_parametrized_<opname>_<variant>_meta_uint8_flag_enabled_meta
TestDeviceBlahMETA.test_two_things_x_1_y_2_meta
TestDeviceBlahMETA.test_two_things_x_3_y_4_meta
TestDeviceBlahMETA.test_two_things_x_5_y_6_meta
```

Caveats:
* `parametrize` decorators cannot be "stacked" yet; each one overwrites the previous. This will change to either:
  * Allow stacking of multiple decorators
  * Error out with a nice error message if multiple decorators are specified

The PR introduces `instantiate_parametrized_tests()` in addition to `instantiate_device_type_tests()`. The former should be used for non-device-specific tests, and the latter should be used for device-specific tests, as usual. Both of these support the `parametrize` decorator. Only the latter supports the `ops` decorator (no change here- this was already the case).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60753

Reviewed By: saketh-are

Differential Revision: D30606615

Pulled By: jbschlosser

fbshipit-source-id: a34f36d643f68a6e221f419d9bb3e1ae1d84dd65
diff --git a/test/test_testing.py b/test/test_testing.py
index e45977f..d777587 100644
--- a/test/test_testing.py
+++ b/test/test_testing.py
@@ -12,11 +12,12 @@
 
 from torch.testing import make_tensor
 from torch.testing._internal.common_utils import \
-    (IS_FBCODE, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, skipIfRocm, slowTest)
+    (IS_FBCODE, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, skipIfRocm, slowTest,
+     parametrize, subtest, instantiate_parametrized_tests, dtype_name)
 from torch.testing._internal.common_device_type import \
     (PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes,
      get_device_type_test_bases, instantiate_device_type_tests, onlyCUDA, onlyOnCPUAndCUDA,
-     deviceCountAtLeast)
+     deviceCountAtLeast, ops)
 from torch.testing._internal.common_methods_invocations import op_db
 import torch.testing._internal.opinfo_helper as opinfo_helper
 from torch.testing._internal.common_dtype import get_all_dtypes
@@ -1425,5 +1426,312 @@
             fn()
 
 
+def _get_test_names_for_test_class(test_cls):
+    """ Convenience function to get all test names for a given test class. """
+    test_names = ['{}.{}'.format(test_cls.__name__, key) for key in test_cls.__dict__
+                  if key.startswith('test_')]
+    return sorted(test_names)
+
+
+class TestTestParametrization(TestCase):
+    def test_default_names(self):
+
+        class TestParametrized(TestCase):
+            @parametrize("x", range(5))
+            def test_default_names(self, x):
+                pass
+
+            @parametrize("x,y", [(1, 2), (2, 3), (3, 4)])
+            def test_two_things_default_names(self, x, y):
+                pass
+
+        instantiate_parametrized_tests(TestParametrized)
+
+        expected_test_names = [
+            'TestParametrized.test_default_names_x_0',
+            'TestParametrized.test_default_names_x_1',
+            'TestParametrized.test_default_names_x_2',
+            'TestParametrized.test_default_names_x_3',
+            'TestParametrized.test_default_names_x_4',
+            'TestParametrized.test_two_things_default_names_x_1_y_2',
+            'TestParametrized.test_two_things_default_names_x_2_y_3',
+            'TestParametrized.test_two_things_default_names_x_3_y_4',
+        ]
+        test_names = _get_test_names_for_test_class(TestParametrized)
+        self.assertEqual(expected_test_names, test_names)
+
+    def test_name_fn(self):
+
+        class TestParametrized(TestCase):
+            @parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias')
+            def test_custom_names(self, bias):
+                pass
+
+            @parametrize("x", [1, 2], name_fn=str)
+            @parametrize("y", [3, 4], name_fn=str)
+            @parametrize("z", [5, 6], name_fn=str)
+            def test_three_things_composition_custom_names(self, x, y, z):
+                pass
+
+            @parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: '{}__{}'.format(x, y))
+            def test_two_things_custom_names_alternate(self, x, y):
+                pass
+
+        instantiate_parametrized_tests(TestParametrized)
+
+        expected_test_names = [
+            'TestParametrized.test_custom_names_bias',
+            'TestParametrized.test_custom_names_no_bias',
+            'TestParametrized.test_three_things_composition_custom_names_1_3_5',
+            'TestParametrized.test_three_things_composition_custom_names_1_3_6',
+            'TestParametrized.test_three_things_composition_custom_names_1_4_5',
+            'TestParametrized.test_three_things_composition_custom_names_1_4_6',
+            'TestParametrized.test_three_things_composition_custom_names_2_3_5',
+            'TestParametrized.test_three_things_composition_custom_names_2_3_6',
+            'TestParametrized.test_three_things_composition_custom_names_2_4_5',
+            'TestParametrized.test_three_things_composition_custom_names_2_4_6',
+            'TestParametrized.test_two_things_custom_names_alternate_1__2',
+            'TestParametrized.test_two_things_custom_names_alternate_1__3',
+            'TestParametrized.test_two_things_custom_names_alternate_1__4',
+        ]
+        test_names = _get_test_names_for_test_class(TestParametrized)
+        self.assertEqual(expected_test_names, test_names)
+
+    def test_subtest_names(self):
+
+        class TestParametrized(TestCase):
+            @parametrize("bias", [subtest(True, name='bias'),
+                                  subtest(False, name='no_bias')])
+            def test_custom_names(self, bias):
+                pass
+
+            @parametrize("x,y", [subtest((1, 2), name='double'),
+                                 subtest((1, 3), name='triple'),
+                                 subtest((1, 4), name='quadruple')])
+            def test_two_things_custom_names(self, x, y):
+                pass
+
+        instantiate_parametrized_tests(TestParametrized)
+
+        expected_test_names = [
+            'TestParametrized.test_custom_names_bias',
+            'TestParametrized.test_custom_names_no_bias',
+            'TestParametrized.test_two_things_custom_names_double',
+            'TestParametrized.test_two_things_custom_names_quadruple',
+            'TestParametrized.test_two_things_custom_names_triple',
+        ]
+        test_names = _get_test_names_for_test_class(TestParametrized)
+        self.assertEqual(expected_test_names, test_names)
+
+    @parametrize("x", [1, subtest(2, decorators=[unittest.expectedFailure]), 3])
+    def test_subtest_expected_failure(self, x):
+        if x == 2:
+            raise RuntimeError('Boom')
+
+    @parametrize("x", [subtest(1, decorators=[unittest.expectedFailure]), 2, 3])
+    @parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])])
+    def test_two_things_subtest_expected_failure(self, x, y):
+        if x == 1 or y == 6:
+            raise RuntimeError('Boom')
+
+
+class TestTestParametrizationDeviceType(TestCase):
+    def test_unparametrized_names(self, device):
+        # This test exists to protect against regressions in device / dtype test naming
+        # due to parametrization logic.
+
+        device = self.device_type
+
+        class TestParametrized(TestCase):
+            def test_device_specific(self, device):
+                pass
+
+            @dtypes(torch.float32, torch.float64)
+            def test_device_dtype_specific(self, device, dtype):
+                pass
+
+        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
+
+        device_cls = locals()['TestParametrized{}'.format(device.upper())]
+        expected_test_names = [name.format(device_cls.__name__, device) for name in (
+            '{}.test_device_dtype_specific_{}_float32',
+            '{}.test_device_dtype_specific_{}_float64',
+            '{}.test_device_specific_{}')
+        ]
+        test_names = _get_test_names_for_test_class(device_cls)
+        self.assertEqual(expected_test_names, test_names)
+
+    def test_default_names(self, device):
+        device = self.device_type
+
+        class TestParametrized(TestCase):
+            @parametrize("x", range(5))
+            def test_default_names(self, device, x):
+                pass
+
+            @parametrize("x,y", [(1, 2), (2, 3), (3, 4)])
+            def test_two_things_default_names(self, device, x, y):
+                pass
+
+
+        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
+
+        device_cls = locals()['TestParametrized{}'.format(device.upper())]
+        expected_test_names = [name.format(device_cls.__name__, device) for name in (
+            '{}.test_default_names_x_0_{}',
+            '{}.test_default_names_x_1_{}',
+            '{}.test_default_names_x_2_{}',
+            '{}.test_default_names_x_3_{}',
+            '{}.test_default_names_x_4_{}',
+            '{}.test_two_things_default_names_x_1_y_2_{}',
+            '{}.test_two_things_default_names_x_2_y_3_{}',
+            '{}.test_two_things_default_names_x_3_y_4_{}')
+        ]
+        test_names = _get_test_names_for_test_class(device_cls)
+        self.assertEqual(expected_test_names, test_names)
+
+    # Note: Currently, the device string is inserted into the name multiple times.
+    # To fix this, the responsibility for adding the device string can be pushed outside
+    # into instantiate_device_type_tests(). This will result in the device string always being
+    # at the end of the test name, which is different from now for @ops tests. This possibly
+    # breaking change will be made in a future PR.
+    @unittest.expectedFailure
+    def test_name_fn(self, device):
+        device = self.device_type
+
+        class TestParametrized(TestCase):
+            @parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias')
+            def test_custom_names(self, device, bias):
+                pass
+
+            @parametrize("x", [1, 2], name_fn=str)
+            @parametrize("y", [3, 4], name_fn=str)
+            @parametrize("z", [5, 6], name_fn=str)
+            def test_three_things_composition_custom_names(self, device, x, y, z):
+                pass
+
+            @parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: '{}__{}'.format(x, y))
+            def test_two_things_custom_names_alternate(self, device, x, y):
+                pass
+
+        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
+
+        device_cls = locals()['TestParametrized{}'.format(device.upper())]
+        expected_test_names = [name.format(device_cls.__name__, device) for name in (
+            '{}.test_custom_names_bias_{}',
+            '{}.test_custom_names_no_bias_{}',
+            '{}.test_three_things_composition_custom_names_1_3_5_{}',
+            '{}.test_three_things_composition_custom_names_1_3_6_{}',
+            '{}.test_three_things_composition_custom_names_1_4_5_{}',
+            '{}.test_three_things_composition_custom_names_1_4_6_{}',
+            '{}.test_three_things_composition_custom_names_2_3_5_{}',
+            '{}.test_three_things_composition_custom_names_2_3_6_{}',
+            '{}.test_three_things_composition_custom_names_2_4_5_{}',
+            '{}.test_three_things_composition_custom_names_2_4_6_{}',
+            '{}.test_two_things_custom_names_alternate_1__2_{}',
+            '{}.test_two_things_custom_names_alternate_1__3_{}',
+            '{}.test_two_things_custom_names_alternate_1__4_{}')
+        ]
+        test_names = _get_test_names_for_test_class(device_cls)
+        self.assertEqual(expected_test_names, test_names)
+
+    def test_subtest_names(self, device):
+        device = self.device_type
+
+        class TestParametrized(TestCase):
+            @parametrize("bias", [subtest(True, name='bias'),
+                                  subtest(False, name='no_bias')])
+            def test_custom_names(self, device, bias):
+                pass
+
+            @parametrize("x,y", [subtest((1, 2), name='double'),
+                                 subtest((1, 3), name='triple'),
+                                 subtest((1, 4), name='quadruple')])
+            def test_two_things_custom_names(self, device, x, y):
+                pass
+
+        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
+
+        device_cls = locals()['TestParametrized{}'.format(device.upper())]
+        expected_test_names = [name.format(device_cls.__name__, device) for name in (
+            '{}.test_custom_names_bias_{}',
+            '{}.test_custom_names_no_bias_{}',
+            '{}.test_two_things_custom_names_double_{}',
+            '{}.test_two_things_custom_names_quadruple_{}',
+            '{}.test_two_things_custom_names_triple_{}')
+        ]
+        test_names = _get_test_names_for_test_class(device_cls)
+        self.assertEqual(expected_test_names, test_names)
+
+    # Note: Currently, the device string is inserted into the name multiple times.
+    # To fix this, the responsibility for adding the device string can be pushed outside
+    # into instantiate_device_type_tests(). This will result in the device string always being
+    # at the end of the test name, which is different from now for @ops tests. This possibly
+    # breaking change will be made in a future PR.
+    @unittest.expectedFailure
+    def test_ops_composition_names(self, device):
+        device = self.device_type
+
+        class TestParametrized(TestCase):
+            @ops(op_db)
+            @parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled')
+            def test_op_parametrized(self, device, dtype, op, flag):
+                pass
+
+        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
+
+        device_cls = locals()['TestParametrized{}'.format(device.upper())]
+        expected_test_names = []
+        for op in op_db:
+            for dtype in op.default_test_dtypes(device):
+                for flag_part in ('_flag_disabled_', '_flag_enabled_'):
+                    op_name = '{}{}'.format(op.name, '_' + op.variant_test_name if op.variant_test_name else '')
+                    part1 = '{}.test_op_parametrized_{}'.format(device_cls.__name__, op_name)
+                    expected_test_names.append(part1 + '_' + dtype_name(dtype) + flag_part + device)
+
+        test_names = _get_test_names_for_test_class(device_cls)
+        self.assertEqual(sorted(expected_test_names), sorted(test_names))
+
+    def test_dtypes_composition_names(self, device):
+        # Test checks that @parametrize and @dtypes compose as expected.
+
+        device = self.device_type
+
+        class TestParametrized(TestCase):
+            @dtypes(torch.float32, torch.float64)
+            @parametrize("x", range(3))
+            def test_parametrized(self, x, dtype):
+                pass
+
+        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
+
+        device_cls = locals()['TestParametrized{}'.format(device.upper())]
+        expected_test_names = [name.format(device_cls.__name__, device) for name in (
+            '{}.test_parametrized_x_0_{}_float32',
+            '{}.test_parametrized_x_0_{}_float64',
+            '{}.test_parametrized_x_1_{}_float32',
+            '{}.test_parametrized_x_1_{}_float64',
+            '{}.test_parametrized_x_2_{}_float32',
+            '{}.test_parametrized_x_2_{}_float64')
+        ]
+        test_names = _get_test_names_for_test_class(device_cls)
+        self.assertEqual(sorted(expected_test_names), sorted(test_names))
+
+    @parametrize("x", [1, subtest(2, decorators=[unittest.expectedFailure]), 3])
+    def test_subtest_expected_failure(self, device, x):
+        if x == 2:
+            raise RuntimeError('Boom')
+
+    @parametrize("x", [subtest(1, decorators=[unittest.expectedFailure]), 2, 3])
+    @parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])])
+    def test_two_things_subtest_expected_failure(self, device, x, y):
+        if x == 1 or y == 6:
+            raise RuntimeError('Boom')
+
+
+instantiate_parametrized_tests(TestTestParametrization)
+instantiate_device_type_tests(TestTestParametrizationDeviceType, globals())
+
+
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py
index 971b3a6..ee3c7ff 100644
--- a/torch/testing/_internal/common_device_type.py
+++ b/torch/testing/_internal/common_device_type.py
@@ -13,7 +13,7 @@
 from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \
     skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \
     IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, DeterministicGuard, TEST_SKIP_NOARCH, \
-    TEST_WITH_MIOPEN_SUGGEST_NHWC
+    _TestParametrizer, dtype_name, TEST_WITH_MIOPEN_SUGGEST_NHWC
 from torch.testing._internal.common_cuda import _get_torch_cuda_version
 from torch.testing._internal.common_dtype import get_all_dtypes
 
@@ -252,19 +252,14 @@
 # then inherit from it for your generic test.
 
 
-def _dtype_name(dtype):
-    """ Returns the pretty name of the dtype (e.g. torch.int64 -> int64). """
-    return str(dtype).split('.')[1]
-
-
 def _dtype_test_suffix(dtypes):
     """ Returns the test suffix for a dtype, sequence of dtypes, or None. """
     if isinstance(dtypes, list) or isinstance(dtypes, tuple):
         if len(dtypes) == 0:
             return ''
-        return '_' + '_'.join((_dtype_name(d) for d in dtypes))
+        return '_' + '_'.join((dtype_name(d) for d in dtypes))
     elif dtypes:
-        return '_{}'.format(_dtype_name(dtypes))
+        return '_{}'.format(dtype_name(dtypes))
     else:
         return ''
 
@@ -382,22 +377,32 @@
 
                 return result
 
-            assert not hasattr(cls, test_name), "Redefinition of test {0}".format(test_name)
-            setattr(cls, test_name, instantiated_test)
+            assert not hasattr(cls, name), "Redefinition of test {0}".format(name)
+            setattr(cls, name, instantiated_test)
 
         # Handles tests that need parametrization (e.g. those that run across a set of
         # ops / modules using the @ops or @modules decorators).
-        if hasattr(test, 'parametrize_fn'):
-            for (test, test_name, param_kwargs) in test.parametrize_fn(test, generic_cls, cls):
-                instantiate_test_helper(cls=cls, name=test_name, test=test, param_kwargs=param_kwargs)
-        else:
-            dtypes = cls._get_dtypes(test)
-            dtypes = tuple(dtypes) if dtypes is not None else (None,)
-            for dtype in dtypes:
-                param_kwargs = {}
-                _update_param_kwargs(param_kwargs, 'dtype', dtype)
-                test_name = '{}_{}{}'.format(name, cls.device_type, _dtype_test_suffix(dtype))
-                instantiate_test_helper(cls=cls, name=test_name, test=test, param_kwargs=param_kwargs)
+
+        def default_parametrize_fn(test, generic_cls, cls):
+            # By default, parametrize only over device.
+            test_suffix = cls.device_type
+            yield (test, test_suffix, {})
+
+        parametrize_fn = test.parametrize_fn if hasattr(test, 'parametrize_fn') else default_parametrize_fn
+        for (test, test_suffix, param_kwargs) in parametrize_fn(test, generic_cls, cls):
+            if hasattr(test, 'handles_dtypes') and test.handles_dtypes:
+                full_name = '{}_{}'.format(name, test_suffix)
+                instantiate_test_helper(cls=cls, name=full_name, test=test, param_kwargs=param_kwargs)
+            else:
+                # The parametrize_fn doesn't handle dtypes internally; handle them here instead by generating
+                # a test per dtype.
+                dtypes = cls._get_dtypes(test)
+                dtypes = tuple(dtypes) if dtypes is not None else (None,)
+                for dtype in dtypes:
+                    all_param_kwargs = dict(param_kwargs)
+                    _update_param_kwargs(all_param_kwargs, 'dtype', dtype)
+                    full_name = '{}_{}{}'.format(name, test_suffix, _dtype_test_suffix(dtype))
+                    instantiate_test_helper(cls=cls, name=full_name, test=test, param_kwargs=all_param_kwargs)
 
     def run(self, result=None):
         super().run(result=result)
@@ -634,43 +639,6 @@
     none = 5  # Instantiate no dtype variants (no dtype kwarg needed)
 
 
-class _TestParametrizer(object):
-    """
-    Decorator class for parametrizing a test function, yielding a set of new tests spawned
-    from the original generic test, each specialized for a specific set of test inputs. For
-    example, parametrizing a test across the set of ops will result in a test function per op.
-
-    The decision of how to parametrize / what to parametrize over is intended to be implemented
-    by each derived class.
-
-    In the details, the decorator adds a 'parametrize_fn' property to the test function that is called
-    during device-specific test instantiation performed in instantiate_device_type_tests(). Because of this,
-    there is no need to parametrize over device type, as that is already handled separately.
-    """
-    def _parametrize_test(self, test, generic_cls, device_cls):
-        """
-        Parametrizes the given test function across whatever dimension is specified by the derived class.
-        Tests can be parametrized over any arbitrary dimension or combination of dimensions, such as all
-        ops, all modules, or all ops + their associated dtypes.
-
-        Args:
-            test (fn): Test function to parametrize over; must support least a device arg
-            generic_cls (class): Generic test class object containing tests (e.g. TestFoo)
-            device_cls (class): Device-specialized test class object (e.g. TestFooCPU)
-
-        Returns:
-            Generator object returning 3-tuples of:
-                test (fn): Parametrized test function; must support a device arg and args for any params
-                test_name (str): Parametrized name of the test (e.g. test_bar_opname_int64)
-                param_kwargs (dict): Param kwargs to pass to the test (e.g. {'op': 'add', 'dtype': torch.int64})
-        """
-        raise NotImplementedError
-
-    def __call__(self, fn):
-        fn.parametrize_fn = self._parametrize_test
-        return fn
-
-
 # Decorator that defines the OpInfos a test template should be instantiated for.
 #
 # Example usage:
@@ -712,6 +680,7 @@
 class ops(_TestParametrizer):
     def __init__(self, op_list, *, dtypes: OpDTypes = OpDTypes.basic,
                  allowed_dtypes: Optional[Sequence[torch.dtype]] = None):
+        super().__init__(handles_dtypes=True)
         self.op_list = op_list
         self.opinfo_dtypes = dtypes
         self.allowed_dtypes = set(allowed_dtypes) if allowed_dtypes is not None else None
@@ -745,11 +714,10 @@
 
             for dtype in dtypes:
                 # Construct the test name.
-                test_name = '{}_{}{}_{}{}'.format(test.__name__,
-                                                  op.name.replace('.', '_'),
-                                                  '_' + op.variant_test_name if op.variant_test_name else '',
-                                                  device_cls.device_type,
-                                                  _dtype_test_suffix(dtype))
+                test_name = '{}{}_{}{}'.format(op.name.replace('.', '_'),
+                                               '_' + op.variant_test_name if op.variant_test_name else '',
+                                               device_cls.device_type,
+                                               _dtype_test_suffix(dtype))
 
                 # Construct parameter kwargs to pass to the test.
                 param_kwargs = {'op': op}
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 5dd1cb2..5aa8b67 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -758,6 +758,12 @@
         return (supported if self._default_test_dtypes is None
                 else supported.intersection(self._default_test_dtypes))
 
+    @property
+    def formatted_name(self):
+        """Returns a formatted full name for this OpInfo that can be used in test names."""
+        variant = '_' + self.variant_test_name if self.variant_test_name else ''
+        return '{}{}'.format(self.name.replace('.', '_'), variant)
+
 
 def _generate_reduction_inputs(device, dtype, requires_grad):
     """Generates input tensors for testing reduction operators"""
diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py
index b1cbbb3..a7133b7 100644
--- a/torch/testing/_internal/common_modules.py
+++ b/torch/testing/_internal/common_modules.py
@@ -49,6 +49,7 @@
     """ PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """
 
     def __init__(self, module_info_list):
+        super().__init__(handles_dtypes=True)
         self.module_info_list = module_info_list
 
     def _parametrize_test(self, test, generic_cls, device_cls):
@@ -56,10 +57,9 @@
             # TODO: Factor some of this out since it's similar to OpInfo.
             for dtype in floating_types():
                 # Construct the test name.
-                test_name = '{}_{}_{}{}'.format(test.__name__,
-                                                module_info.name.replace('.', '_'),
-                                                device_cls.device_type,
-                                                _dtype_test_suffix(dtype))
+                test_name = '{}_{}{}'.format(module_info.name.replace('.', '_'),
+                                             device_cls.device_type,
+                                             _dtype_test_suffix(dtype))
 
                 # Construct parameter kwargs to pass to the test.
                 param_kwargs = {'module_info': module_info}
@@ -153,6 +153,10 @@
     def name(self):
         return formatted_module_name(self.module_cls)
 
+    @property
+    def formatted_name(self):
+        return self.name.replace('.', '_')
+
 
 def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 922d5c8..b32c3a1 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -76,6 +76,267 @@
 slow_tests_dict: Optional[Dict[str, Any]] = None
 disabled_tests_dict: Optional[Dict[str, Any]] = None
 
+
+class _TestParametrizer(object):
+    """
+    Decorator class for parametrizing a test function, yielding a set of new tests spawned
+    from the original generic test, each specialized for a specific set of test inputs. For
+    example, parametrizing a test across the set of ops will result in a test function per op.
+
+    The decision of how to parametrize / what to parametrize over is intended to be implemented
+    by each derived class.
+
+    In the details, the decorator adds a 'parametrize_fn' property to the test function that is called
+    during device-specific test instantiation performed in instantiate_device_type_tests(). Because of this,
+    there is no need to parametrize over device type, as that is already handled separately.
+
+    If the decorator is applied to a test function that already has a 'parametrize_fn' property, a new
+    composite 'parametrize_fn' will be created that generates tests with the product of the parameters
+    generated by the old and new parametrize_fns. This allows for convenient composability of decorators.
+
+    Args:
+        handles_dtypes (bool): If True, indicates that it is the responsibility of the decorator to handle
+            dtypes internally. This allows for more flexibility when needed (e.g. for op-specific dtype handling).
+            Default: True
+    """
+    def __init__(self, handles_dtypes=True):
+        self.handles_dtypes = handles_dtypes
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        """
+        Parametrizes the given test function across whatever dimension is specified by the derived class.
+        Tests can be parametrized over any arbitrary dimension or combination of dimensions, such as all
+        ops, all modules, or all ops + their associated dtypes.
+
+        Args:
+            test (fn): Test function to parametrize over
+            generic_cls (class): Generic test class object containing tests (e.g. TestFoo)
+            device_cls (class): Device-specialized test class object (e.g. TestFooCPU); set to None
+                if the tests are not part of a device-specific set
+
+        Returns:
+            Generator object returning 3-tuples of:
+                test (fn): Parametrized test function; must support a device arg and args for any params
+                test_name (str): Parametrized suffix for the test (e.g. opname_int64); will be appended to
+                    the base name of the test
+                param_kwargs (dict): Param kwargs to pass to the test (e.g. {'op': 'add', 'dtype': torch.int64})
+        """
+        raise NotImplementedError
+
+    def __call__(self, fn):
+        if hasattr(fn, 'parametrize_fn'):
+            # Do composition with the product of args.
+            old_parametrize_fn = fn.parametrize_fn
+            new_parametrize_fn = self._parametrize_test
+
+            def composite_fn(test, generic_cls, device_cls,
+                             old_parametrize_fn=old_parametrize_fn,
+                             new_parametrize_fn=new_parametrize_fn):
+                old_tests = [(test, test_name, param_kwargs) for (test, test_name, param_kwargs) in
+                             old_parametrize_fn(test, generic_cls, device_cls)]
+                for (old_test, old_test_name, old_param_kwargs) in old_tests:
+                    for (new_test, new_test_name, new_param_kwargs) in \
+                            new_parametrize_fn(old_test, generic_cls, device_cls):
+                        full_param_kwargs = {**old_param_kwargs, **new_param_kwargs}
+                        yield (new_test, '{}_{}'.format(new_test_name, old_test_name), full_param_kwargs)
+
+            fn.parametrize_fn = composite_fn
+            old_handles_dtypes = fn.handles_dtypes if hasattr(fn, 'handles_dtypes') else False
+            if self.handles_dtypes and old_handles_dtypes:
+                raise RuntimeError('Cannot compose multiple parametrization decorators that handle dtypes; '
+                                   'their dtype handling conflicts')
+            fn.handles_dtypes = self.handles_dtypes or old_handles_dtypes
+        else:
+            fn.parametrize_fn = self._parametrize_test
+            fn.handles_dtypes = self.handles_dtypes
+        return fn
+
+
+def instantiate_parametrized_tests(generic_cls):
+    """
+    Instantiates tests that have been decorated with a parametrize_fn. This is generally performed by a
+    decorator subclass of _TestParametrizer. The generic test will be replaced on the test class by
+    parametrized tests with specialized names.
+
+    Args:
+        generic_cls (class): Generic test class object containing tests (e.g. TestFoo)
+    """
+    for attr_name in tuple(dir(generic_cls)):
+        class_attr = getattr(generic_cls, attr_name)
+        if not hasattr(class_attr, 'parametrize_fn'):
+            continue
+
+        if hasattr(class_attr, 'handles_dtypes') and class_attr.handles_dtypes:
+            raise RuntimeError('instantiate_parametrized_tests() should not be used with decorators '
+                               'that handle dtypes internally (e.g. @ops, @modules, etc.). Use '
+                               'instantiate_device_type_tests() with these instead.')
+
+        # Remove the generic test from the test class.
+        delattr(generic_cls, attr_name)
+
+        # Add parametrized tests to the test class.
+        def instantiate_test_helper(cls, name, test, param_kwargs):
+            @wraps(test)
+            def instantiated_test(self, param_kwargs=param_kwargs):
+                test(self, **param_kwargs)
+
+            assert not hasattr(generic_cls, name), "Redefinition of test {0}".format(name)
+            setattr(generic_cls, name, instantiated_test)
+
+        for (test, test_suffix, param_kwargs) in class_attr.parametrize_fn(
+                class_attr, generic_cls=generic_cls, device_cls=None):
+            full_name = '{}_{}'.format(test.__name__, test_suffix)
+            instantiate_test_helper(cls=generic_cls, name=full_name, test=test, param_kwargs=param_kwargs)
+
+
+class subtest(object):
+    """
+    Explicit subtest case for use with test parametrization.
+    Allows for explicit naming of individual subtest cases as well as applying
+    decorators to the parametrized test.
+
+    Args:
+        arg_values (iterable): Iterable of arg values (e.g. range(10)) or
+            tuples of arg values (e.g. [(1, 2), (3, 4)]).
+        name (str): Optional name to use for the test.
+        decorators (iterable): Iterable of decorators to apply to the generated test.
+    """
+    __slots__ = ['arg_values', 'name', 'decorators']
+
+    def __init__(self, arg_values, name=None, decorators=None):
+        self.arg_values = arg_values
+        self.name = name
+        self.decorators = decorators if decorators else []
+
+
+class parametrize(_TestParametrizer):
+    """
+    Decorator for applying generic test parametrizations.
+
+    The interface for this decorator is modeled after `@pytest.mark.parametrize`.
+    Basic usage between this decorator and pytest's is identical. The first argument
+    should be a string containing comma-separated names of parameters for the test, and
+    the second argument should be an iterable returning values or tuples of values for
+    the case of multiple parameters.
+
+    Beyond this basic usage, the decorator provides some additional functionality that
+    pytest does not.
+
+    1. Parametrized tests end up as generated test functions on unittest test classes.
+    Since this differs from how pytest works, this decorator takes on the additional
+    responsibility of naming these test functions. The default test names consists of
+    the test's base name followed by each parameter name + value (e.g. "test_bar_x_1_y_foo"),
+    but custom names can be defined using `name_fn` or the `subtest` structure (see below).
+
+    2. The decorator specially handles parameter values of type `subtest`, which allows for
+    more fine-grained control over both test naming and test execution. In particular, it can
+    be used to tag subtests with explicit test names or apply arbitrary decorators (see examples
+    below).
+
+    Examples::
+
+        @parametrize("x", range(5))
+        def test_foo(self, x):
+            ...
+
+        @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')])
+        def test_bar(self, x, y):
+            ...
+
+        @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')],
+                     name_fn=lambda x, y: '{}_{}'.format(x, y))
+        def test_bar_custom_names(self, x, y):
+            ...
+
+        @parametrize("x, y", [subtest((1, 2), name='double'),
+                              subtest((1, 3), name='triple', decorators=[unittest.expectedFailure]),
+                              subtest((1, 4), name='quadruple')])
+        def test_baz(self, x, y):
+            ...
+
+    Args:
+        arg_str (str): String of arg names separate by commas (e.g. "x,y").
+        arg_values (iterable): Iterable of arg values (e.g. range(10)) or
+            tuples of arg values (e.g. [(1, 2), (3, 4)]).
+        name_fn (callable): Optional function that takes in parameters and returns subtest name.
+    """
+    def __init__(self, arg_str, arg_values, name_fn=None):
+        super().__init__(handles_dtypes=False)
+        self.arg_names = arg_str.split(',')
+        self.arg_values = arg_values
+        self.name_fn = name_fn
+
+    def _formatted_str_repr(self, name, value):
+        """ Returns a string representation for the given arg that is suitable for use in test function names. """
+        if isinstance(value, torch.dtype):
+            return dtype_name(value)
+        elif isinstance(value, torch.device):
+            return str(value)
+        # Can't use isinstance as it would cause a circular import
+        elif value.__class__.__name__ == 'OpInfo' or value.__class__.__name__ == 'ModuleInfo':
+            return value.formatted_name
+        else:
+            # Include name and value separated by underscore.
+            return '{}_{}'.format(name, str(value).replace('.', '_'))
+
+    def _default_subtest_name(self, values):
+        return '_'.join([self._formatted_str_repr(a, v) for a, v in zip(self.arg_names, values)])
+
+    def _get_subtest_name(self, values, explicit_name=None):
+        if explicit_name:
+            subtest_name = explicit_name
+        elif self.name_fn:
+            subtest_name = self.name_fn(*values)
+        else:
+            subtest_name = self._default_subtest_name(values)
+        return subtest_name
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        if len(self.arg_names) == 0:
+            # No additional parameters needed for the test.
+            test_name = device_cls.device_type if device_cls else ''
+            yield (test, test_name, {})
+        else:
+            # Each "values" item is expected to be either:
+            # * A tuple of values with one for each arg. For a single arg, a single item is expected.
+            # * A subtest instance with arg_values matching the previous.
+            for values in self.arg_values:
+                maybe_name = None
+                if isinstance(values, subtest):
+                    sub = values
+                    values = sub.arg_values
+                    maybe_name = sub.name
+
+                    # Apply decorators.
+                    @wraps(test)
+                    def test_wrapper(*args, **kwargs):
+                        return test(*args, **kwargs)
+
+                    for decorator in sub.decorators:
+                        test_wrapper = decorator(test_wrapper)
+
+                    gen_test = test_wrapper
+                else:
+                    gen_test = test
+
+                values = list(values) if len(self.arg_names) > 1 else [values]
+                if len(values) != len(self.arg_names):
+                    raise RuntimeError('Expected # values == # arg names, but got: {} '
+                                       'values and {} names for test "{}"'.format(
+                                           len(values), len(self.arg_names), test.__name__))
+
+                param_kwargs = {
+                    name: value for name, value in zip(self.arg_names, values)
+                }
+
+                subtest_name = self._get_subtest_name(values, explicit_name=maybe_name)
+                test_name = '{}{}'.format(subtest_name, '_' + device_cls.device_type if device_cls else '')
+                if '.' in test_name:
+                    raise RuntimeError('Test name cannot contain periods, but got: {}'.format(test_name))
+
+                yield (gen_test, test_name, param_kwargs)
+
+
 class ProfilingMode(Enum):
     LEGACY = 1
     SIMPLE = 2
@@ -271,6 +532,12 @@
 def get_test_names(test_cases):
     return ['.'.join(case.id().split('.')[-2:]) for case in test_cases]
 
+def _print_test_names():
+    suite = unittest.TestLoader().loadTestsFromModule(__main__)
+    test_cases = discover_test_cases_recursively(suite)
+    for name in get_test_names(test_cases):
+        print(name)
+
 def chunk_list(lst, nchunks):
     return [lst[i::nchunks] for i in range(nchunks)]
 
@@ -300,10 +567,7 @@
             print(f'[WARNING] disabled test file provided but not found: {IMPORT_DISABLED_TESTS}')
     # Determine the test launch mechanism
     if TEST_DISCOVER:
-        suite = unittest.TestLoader().loadTestsFromModule(__main__)
-        test_cases = discover_test_cases_recursively(suite)
-        for name in get_test_names(test_cases):
-            print(name)
+        _print_test_names()
     elif TEST_IN_SUBPROCESS:
         suite = unittest.TestLoader().loadTestsFromModule(__main__)
         test_cases = discover_test_cases_recursively(suite)
@@ -2585,3 +2849,7 @@
         return wrapper
 
     return decorator
+
+def dtype_name(dtype):
+    """ Returns the pretty name of the dtype (e.g. torch.int64 -> int64). """
+    return str(dtype).split('.')[1]