Generic test parametrization functionality (#60753)
Summary:
This PR plays around with implementation & usage of a `parametrize` decorator for test parametrization similar to `pytest.mark.parametrize`, based on previous work introducing a `_TestParametrizer` class. It works with the internal `DeviceTest` hierarchy & composes with `dtype`, `skip*`, and other decorators. Basic usage is demonstrated in `test/test_blah.py`:
```python
import unittest
from itertools import product
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, deviceCountAtLeast, ops)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import (
TestCase, run_tests, parametrize, instantiate_parametrized_tests, subtest)
class TestBlah(TestCase):
parametrize("x", range(5))
def test_default_names(self, x):
print('Passed in:', x)
# Use default names but add an expected failure.
parametrize("x", [subtest(0, decorators=[unittest.expectedFailure]),
*range(1, 5)])
def test_default_names_expected_failure(self, x):
if x == 0:
raise RuntimeError('Boom')
print('Passed in:', x)
parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias')
def test_custom_names(self, bias):
print('Passed in:', bias)
parametrize("bias", [subtest(True, name='bias'),
subtest(False, name='no_bias')])
def test_custom_names_alternate(self, bias):
print('Passed in:', bias)
parametrize("x,y", [(1, 2), (1, 3), (1, 4)])
def test_two_things_default_names(self, x, y):
print('Passed in:', x, y)
parametrize("x", [1, 2, 3])
parametrize("y", [4, 5, 6])
def test_two_things_composition(self, x, y):
print('Passed in:', x, y)
parametrize("x", [subtest(0, decorators=[unittest.expectedFailure]),
*range(1, 3)])
parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])])
def test_two_things_composition_expected_failure(self, x, y):
if x == 0 or y == 6:
raise RuntimeError('Boom')
print('Passed in:', x, y)
parametrize("x", [1, 2])
parametrize("y", [3, 4])
parametrize("z", [5, 6])
def test_three_things_composition(self, x, y, z):
print('Passed in:', x, y, z)
parametrize("x", [1, 2], name_fn=str)
parametrize("y", [3, 4], name_fn=str)
parametrize("z", [5, 6], name_fn=str)
def test_three_things_composition_custom_names(self, x, y, z):
print('Passed in:', x, y, z)
parametrize("x,y", product(range(2), range(3)))
def test_two_things_product(self, x, y):
print('Passed in:', x, y)
parametrize("x,y", [subtest((1, 2), name='double'),
subtest((1, 3), name='triple'),
subtest((1, 4), name='quadruple')])
def test_two_things_custom_names(self, x, y):
print('Passed in:', x, y)
parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: '{}_{}'.format(x, y))
def test_two_things_custom_names_alternate(self, x, y):
print('Passed in:', x, y)
class TestDeviceBlah(TestCase):
parametrize("x", range(10))
def test_default_names(self, device, x):
print('Passed in:', device, x)
parametrize("x,y", [(1, 2), (3, 4), (5, 6)])
def test_two_things(self, device, x, y):
print('Passed in:', device, x, y)
deviceCountAtLeast(1)
def test_multiple_devices(self, devices):
print('Passed in:', devices)
ops(op_db)
parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled')
def test_op_parametrized(self, device, dtype, op, flag):
print('Passed in:', device, dtype, op, flag)
instantiate_parametrized_tests(TestBlah)
instantiate_device_type_tests(TestDeviceBlah, globals())
if __name__ == '__main__':
run_tests()
```
Generated tests:
```
TestBlah.test_custom_names_alternate_bias
TestBlah.test_custom_names_alternate_no_bias
TestBlah.test_custom_names_bias
TestBlah.test_custom_names_no_bias
TestBlah.test_default_names_expected_failure_x_0
TestBlah.test_default_names_expected_failure_x_1
TestBlah.test_default_names_expected_failure_x_2
TestBlah.test_default_names_expected_failure_x_3
TestBlah.test_default_names_expected_failure_x_4
TestBlah.test_default_names_x_0
TestBlah.test_default_names_x_1
TestBlah.test_default_names_x_2
TestBlah.test_default_names_x_3
TestBlah.test_default_names_x_4
TestBlah.test_three_things_composition_custom_names_1_3_5
TestBlah.test_three_things_composition_custom_names_1_3_6
TestBlah.test_three_things_composition_custom_names_1_4_5
TestBlah.test_three_things_composition_custom_names_1_4_6
TestBlah.test_three_things_composition_custom_names_2_3_5
TestBlah.test_three_things_composition_custom_names_2_3_6
TestBlah.test_three_things_composition_custom_names_2_4_5
TestBlah.test_three_things_composition_custom_names_2_4_6
TestBlah.test_three_things_composition_x_1_y_3_z_5
TestBlah.test_three_things_composition_x_1_y_3_z_6
TestBlah.test_three_things_composition_x_1_y_4_z_5
TestBlah.test_three_things_composition_x_1_y_4_z_6
TestBlah.test_three_things_composition_x_2_y_3_z_5
TestBlah.test_three_things_composition_x_2_y_3_z_6
TestBlah.test_three_things_composition_x_2_y_4_z_5
TestBlah.test_three_things_composition_x_2_y_4_z_6
TestBlah.test_two_things_composition_expected_failure_x_0_y_4
TestBlah.test_two_things_composition_expected_failure_x_0_y_5
TestBlah.test_two_things_composition_expected_failure_x_0_y_6
TestBlah.test_two_things_composition_expected_failure_x_1_y_4
TestBlah.test_two_things_composition_expected_failure_x_1_y_5
TestBlah.test_two_things_composition_expected_failure_x_1_y_6
TestBlah.test_two_things_composition_expected_failure_x_2_y_4
TestBlah.test_two_things_composition_expected_failure_x_2_y_5
TestBlah.test_two_things_composition_expected_failure_x_2_y_6
TestBlah.test_two_things_composition_x_1_y_4
TestBlah.test_two_things_composition_x_1_y_5
TestBlah.test_two_things_composition_x_1_y_6
TestBlah.test_two_things_composition_x_2_y_4
TestBlah.test_two_things_composition_x_2_y_5
TestBlah.test_two_things_composition_x_2_y_6
TestBlah.test_two_things_composition_x_3_y_4
TestBlah.test_two_things_composition_x_3_y_5
TestBlah.test_two_things_composition_x_3_y_6
TestBlah.test_two_things_custom_names_alternate_1_2
TestBlah.test_two_things_custom_names_alternate_1_3
TestBlah.test_two_things_custom_names_alternate_1_4
TestBlah.test_two_things_custom_names_double
TestBlah.test_two_things_custom_names_quadruple
TestBlah.test_two_things_custom_names_triple
TestBlah.test_two_things_default_names_x_1_y_2
TestBlah.test_two_things_default_names_x_1_y_3
TestBlah.test_two_things_default_names_x_1_y_4
TestBlah.test_two_things_product_x_0_y_0
TestBlah.test_two_things_product_x_0_y_1
TestBlah.test_two_things_product_x_0_y_2
TestBlah.test_two_things_product_x_1_y_0
TestBlah.test_two_things_product_x_1_y_1
TestBlah.test_two_things_product_x_1_y_2
TestDeviceBlahCPU.test_default_names_x_0_cpu
TestDeviceBlahCPU.test_default_names_x_1_cpu
TestDeviceBlahCPU.test_default_names_x_2_cpu
TestDeviceBlahCPU.test_default_names_x_3_cpu
TestDeviceBlahCPU.test_default_names_x_4_cpu
TestDeviceBlahCPU.test_default_names_x_5_cpu
TestDeviceBlahCPU.test_default_names_x_6_cpu
TestDeviceBlahCPU.test_default_names_x_7_cpu
TestDeviceBlahCPU.test_default_names_x_8_cpu
TestDeviceBlahCPU.test_default_names_x_9_cpu
TestDeviceBlahCPU.test_multiple_devices_cpu
TestDeviceBlahCPU.test_op_parametrized_<opname>_<variant>_cpu_uint8_flag_enabled_cpu
TestDeviceBlahCPU.test_two_things_x_1_y_2_cpu
TestDeviceBlahCPU.test_two_things_x_3_y_4_cpu
TestDeviceBlahCPU.test_two_things_x_5_y_6_cpu
TestDeviceBlahMETA.test_default_names_x_0_meta
TestDeviceBlahMETA.test_default_names_x_1_meta
TestDeviceBlahMETA.test_default_names_x_2_meta
TestDeviceBlahMETA.test_default_names_x_3_meta
TestDeviceBlahMETA.test_default_names_x_4_meta
TestDeviceBlahMETA.test_default_names_x_5_meta
TestDeviceBlahMETA.test_default_names_x_6_meta
TestDeviceBlahMETA.test_default_names_x_7_meta
TestDeviceBlahMETA.test_default_names_x_8_meta
TestDeviceBlahMETA.test_default_names_x_9_meta
TestDeviceBlahMETA.test_multiple_devices_meta
TestDeviceBlahMETA.test_op_parametrized_<opname>_<variant>_meta_uint8_flag_enabled_meta
TestDeviceBlahMETA.test_two_things_x_1_y_2_meta
TestDeviceBlahMETA.test_two_things_x_3_y_4_meta
TestDeviceBlahMETA.test_two_things_x_5_y_6_meta
```
Caveats:
* `parametrize` decorators cannot be "stacked" yet; each one overwrites the previous. This will change to either:
* Allow stacking of multiple decorators
* Error out with a nice error message if multiple decorators are specified
The PR introduces `instantiate_parametrized_tests()` in addition to `instantiate_device_type_tests()`. The former should be used for non-device-specific tests, and the latter should be used for device-specific tests, as usual. Both of these support the `parametrize` decorator. Only the latter supports the `ops` decorator (no change here- this was already the case).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60753
Reviewed By: saketh-are
Differential Revision: D30606615
Pulled By: jbschlosser
fbshipit-source-id: a34f36d643f68a6e221f419d9bb3e1ae1d84dd65
diff --git a/test/test_testing.py b/test/test_testing.py
index e45977f..d777587 100644
--- a/test/test_testing.py
+++ b/test/test_testing.py
@@ -12,11 +12,12 @@
from torch.testing import make_tensor
from torch.testing._internal.common_utils import \
- (IS_FBCODE, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, skipIfRocm, slowTest)
+ (IS_FBCODE, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, skipIfRocm, slowTest,
+ parametrize, subtest, instantiate_parametrized_tests, dtype_name)
from torch.testing._internal.common_device_type import \
(PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes,
get_device_type_test_bases, instantiate_device_type_tests, onlyCUDA, onlyOnCPUAndCUDA,
- deviceCountAtLeast)
+ deviceCountAtLeast, ops)
from torch.testing._internal.common_methods_invocations import op_db
import torch.testing._internal.opinfo_helper as opinfo_helper
from torch.testing._internal.common_dtype import get_all_dtypes
@@ -1425,5 +1426,312 @@
fn()
+def _get_test_names_for_test_class(test_cls):
+ """ Convenience function to get all test names for a given test class. """
+ test_names = ['{}.{}'.format(test_cls.__name__, key) for key in test_cls.__dict__
+ if key.startswith('test_')]
+ return sorted(test_names)
+
+
+class TestTestParametrization(TestCase):
+ def test_default_names(self):
+
+ class TestParametrized(TestCase):
+ @parametrize("x", range(5))
+ def test_default_names(self, x):
+ pass
+
+ @parametrize("x,y", [(1, 2), (2, 3), (3, 4)])
+ def test_two_things_default_names(self, x, y):
+ pass
+
+ instantiate_parametrized_tests(TestParametrized)
+
+ expected_test_names = [
+ 'TestParametrized.test_default_names_x_0',
+ 'TestParametrized.test_default_names_x_1',
+ 'TestParametrized.test_default_names_x_2',
+ 'TestParametrized.test_default_names_x_3',
+ 'TestParametrized.test_default_names_x_4',
+ 'TestParametrized.test_two_things_default_names_x_1_y_2',
+ 'TestParametrized.test_two_things_default_names_x_2_y_3',
+ 'TestParametrized.test_two_things_default_names_x_3_y_4',
+ ]
+ test_names = _get_test_names_for_test_class(TestParametrized)
+ self.assertEqual(expected_test_names, test_names)
+
+ def test_name_fn(self):
+
+ class TestParametrized(TestCase):
+ @parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias')
+ def test_custom_names(self, bias):
+ pass
+
+ @parametrize("x", [1, 2], name_fn=str)
+ @parametrize("y", [3, 4], name_fn=str)
+ @parametrize("z", [5, 6], name_fn=str)
+ def test_three_things_composition_custom_names(self, x, y, z):
+ pass
+
+ @parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: '{}__{}'.format(x, y))
+ def test_two_things_custom_names_alternate(self, x, y):
+ pass
+
+ instantiate_parametrized_tests(TestParametrized)
+
+ expected_test_names = [
+ 'TestParametrized.test_custom_names_bias',
+ 'TestParametrized.test_custom_names_no_bias',
+ 'TestParametrized.test_three_things_composition_custom_names_1_3_5',
+ 'TestParametrized.test_three_things_composition_custom_names_1_3_6',
+ 'TestParametrized.test_three_things_composition_custom_names_1_4_5',
+ 'TestParametrized.test_three_things_composition_custom_names_1_4_6',
+ 'TestParametrized.test_three_things_composition_custom_names_2_3_5',
+ 'TestParametrized.test_three_things_composition_custom_names_2_3_6',
+ 'TestParametrized.test_three_things_composition_custom_names_2_4_5',
+ 'TestParametrized.test_three_things_composition_custom_names_2_4_6',
+ 'TestParametrized.test_two_things_custom_names_alternate_1__2',
+ 'TestParametrized.test_two_things_custom_names_alternate_1__3',
+ 'TestParametrized.test_two_things_custom_names_alternate_1__4',
+ ]
+ test_names = _get_test_names_for_test_class(TestParametrized)
+ self.assertEqual(expected_test_names, test_names)
+
+ def test_subtest_names(self):
+
+ class TestParametrized(TestCase):
+ @parametrize("bias", [subtest(True, name='bias'),
+ subtest(False, name='no_bias')])
+ def test_custom_names(self, bias):
+ pass
+
+ @parametrize("x,y", [subtest((1, 2), name='double'),
+ subtest((1, 3), name='triple'),
+ subtest((1, 4), name='quadruple')])
+ def test_two_things_custom_names(self, x, y):
+ pass
+
+ instantiate_parametrized_tests(TestParametrized)
+
+ expected_test_names = [
+ 'TestParametrized.test_custom_names_bias',
+ 'TestParametrized.test_custom_names_no_bias',
+ 'TestParametrized.test_two_things_custom_names_double',
+ 'TestParametrized.test_two_things_custom_names_quadruple',
+ 'TestParametrized.test_two_things_custom_names_triple',
+ ]
+ test_names = _get_test_names_for_test_class(TestParametrized)
+ self.assertEqual(expected_test_names, test_names)
+
+ @parametrize("x", [1, subtest(2, decorators=[unittest.expectedFailure]), 3])
+ def test_subtest_expected_failure(self, x):
+ if x == 2:
+ raise RuntimeError('Boom')
+
+ @parametrize("x", [subtest(1, decorators=[unittest.expectedFailure]), 2, 3])
+ @parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])])
+ def test_two_things_subtest_expected_failure(self, x, y):
+ if x == 1 or y == 6:
+ raise RuntimeError('Boom')
+
+
+class TestTestParametrizationDeviceType(TestCase):
+ def test_unparametrized_names(self, device):
+ # This test exists to protect against regressions in device / dtype test naming
+ # due to parametrization logic.
+
+ device = self.device_type
+
+ class TestParametrized(TestCase):
+ def test_device_specific(self, device):
+ pass
+
+ @dtypes(torch.float32, torch.float64)
+ def test_device_dtype_specific(self, device, dtype):
+ pass
+
+ instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
+
+ device_cls = locals()['TestParametrized{}'.format(device.upper())]
+ expected_test_names = [name.format(device_cls.__name__, device) for name in (
+ '{}.test_device_dtype_specific_{}_float32',
+ '{}.test_device_dtype_specific_{}_float64',
+ '{}.test_device_specific_{}')
+ ]
+ test_names = _get_test_names_for_test_class(device_cls)
+ self.assertEqual(expected_test_names, test_names)
+
+ def test_default_names(self, device):
+ device = self.device_type
+
+ class TestParametrized(TestCase):
+ @parametrize("x", range(5))
+ def test_default_names(self, device, x):
+ pass
+
+ @parametrize("x,y", [(1, 2), (2, 3), (3, 4)])
+ def test_two_things_default_names(self, device, x, y):
+ pass
+
+
+ instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
+
+ device_cls = locals()['TestParametrized{}'.format(device.upper())]
+ expected_test_names = [name.format(device_cls.__name__, device) for name in (
+ '{}.test_default_names_x_0_{}',
+ '{}.test_default_names_x_1_{}',
+ '{}.test_default_names_x_2_{}',
+ '{}.test_default_names_x_3_{}',
+ '{}.test_default_names_x_4_{}',
+ '{}.test_two_things_default_names_x_1_y_2_{}',
+ '{}.test_two_things_default_names_x_2_y_3_{}',
+ '{}.test_two_things_default_names_x_3_y_4_{}')
+ ]
+ test_names = _get_test_names_for_test_class(device_cls)
+ self.assertEqual(expected_test_names, test_names)
+
+ # Note: Currently, the device string is inserted into the name multiple times.
+ # To fix this, the responsibility for adding the device string can be pushed outside
+ # into instantiate_device_type_tests(). This will result in the device string always being
+ # at the end of the test name, which is different from now for @ops tests. This possibly
+ # breaking change will be made in a future PR.
+ @unittest.expectedFailure
+ def test_name_fn(self, device):
+ device = self.device_type
+
+ class TestParametrized(TestCase):
+ @parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias')
+ def test_custom_names(self, device, bias):
+ pass
+
+ @parametrize("x", [1, 2], name_fn=str)
+ @parametrize("y", [3, 4], name_fn=str)
+ @parametrize("z", [5, 6], name_fn=str)
+ def test_three_things_composition_custom_names(self, device, x, y, z):
+ pass
+
+ @parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: '{}__{}'.format(x, y))
+ def test_two_things_custom_names_alternate(self, device, x, y):
+ pass
+
+ instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
+
+ device_cls = locals()['TestParametrized{}'.format(device.upper())]
+ expected_test_names = [name.format(device_cls.__name__, device) for name in (
+ '{}.test_custom_names_bias_{}',
+ '{}.test_custom_names_no_bias_{}',
+ '{}.test_three_things_composition_custom_names_1_3_5_{}',
+ '{}.test_three_things_composition_custom_names_1_3_6_{}',
+ '{}.test_three_things_composition_custom_names_1_4_5_{}',
+ '{}.test_three_things_composition_custom_names_1_4_6_{}',
+ '{}.test_three_things_composition_custom_names_2_3_5_{}',
+ '{}.test_three_things_composition_custom_names_2_3_6_{}',
+ '{}.test_three_things_composition_custom_names_2_4_5_{}',
+ '{}.test_three_things_composition_custom_names_2_4_6_{}',
+ '{}.test_two_things_custom_names_alternate_1__2_{}',
+ '{}.test_two_things_custom_names_alternate_1__3_{}',
+ '{}.test_two_things_custom_names_alternate_1__4_{}')
+ ]
+ test_names = _get_test_names_for_test_class(device_cls)
+ self.assertEqual(expected_test_names, test_names)
+
+ def test_subtest_names(self, device):
+ device = self.device_type
+
+ class TestParametrized(TestCase):
+ @parametrize("bias", [subtest(True, name='bias'),
+ subtest(False, name='no_bias')])
+ def test_custom_names(self, device, bias):
+ pass
+
+ @parametrize("x,y", [subtest((1, 2), name='double'),
+ subtest((1, 3), name='triple'),
+ subtest((1, 4), name='quadruple')])
+ def test_two_things_custom_names(self, device, x, y):
+ pass
+
+ instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
+
+ device_cls = locals()['TestParametrized{}'.format(device.upper())]
+ expected_test_names = [name.format(device_cls.__name__, device) for name in (
+ '{}.test_custom_names_bias_{}',
+ '{}.test_custom_names_no_bias_{}',
+ '{}.test_two_things_custom_names_double_{}',
+ '{}.test_two_things_custom_names_quadruple_{}',
+ '{}.test_two_things_custom_names_triple_{}')
+ ]
+ test_names = _get_test_names_for_test_class(device_cls)
+ self.assertEqual(expected_test_names, test_names)
+
+ # Note: Currently, the device string is inserted into the name multiple times.
+ # To fix this, the responsibility for adding the device string can be pushed outside
+ # into instantiate_device_type_tests(). This will result in the device string always being
+ # at the end of the test name, which is different from now for @ops tests. This possibly
+ # breaking change will be made in a future PR.
+ @unittest.expectedFailure
+ def test_ops_composition_names(self, device):
+ device = self.device_type
+
+ class TestParametrized(TestCase):
+ @ops(op_db)
+ @parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled')
+ def test_op_parametrized(self, device, dtype, op, flag):
+ pass
+
+ instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
+
+ device_cls = locals()['TestParametrized{}'.format(device.upper())]
+ expected_test_names = []
+ for op in op_db:
+ for dtype in op.default_test_dtypes(device):
+ for flag_part in ('_flag_disabled_', '_flag_enabled_'):
+ op_name = '{}{}'.format(op.name, '_' + op.variant_test_name if op.variant_test_name else '')
+ part1 = '{}.test_op_parametrized_{}'.format(device_cls.__name__, op_name)
+ expected_test_names.append(part1 + '_' + dtype_name(dtype) + flag_part + device)
+
+ test_names = _get_test_names_for_test_class(device_cls)
+ self.assertEqual(sorted(expected_test_names), sorted(test_names))
+
+ def test_dtypes_composition_names(self, device):
+ # Test checks that @parametrize and @dtypes compose as expected.
+
+ device = self.device_type
+
+ class TestParametrized(TestCase):
+ @dtypes(torch.float32, torch.float64)
+ @parametrize("x", range(3))
+ def test_parametrized(self, x, dtype):
+ pass
+
+ instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
+
+ device_cls = locals()['TestParametrized{}'.format(device.upper())]
+ expected_test_names = [name.format(device_cls.__name__, device) for name in (
+ '{}.test_parametrized_x_0_{}_float32',
+ '{}.test_parametrized_x_0_{}_float64',
+ '{}.test_parametrized_x_1_{}_float32',
+ '{}.test_parametrized_x_1_{}_float64',
+ '{}.test_parametrized_x_2_{}_float32',
+ '{}.test_parametrized_x_2_{}_float64')
+ ]
+ test_names = _get_test_names_for_test_class(device_cls)
+ self.assertEqual(sorted(expected_test_names), sorted(test_names))
+
+ @parametrize("x", [1, subtest(2, decorators=[unittest.expectedFailure]), 3])
+ def test_subtest_expected_failure(self, device, x):
+ if x == 2:
+ raise RuntimeError('Boom')
+
+ @parametrize("x", [subtest(1, decorators=[unittest.expectedFailure]), 2, 3])
+ @parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])])
+ def test_two_things_subtest_expected_failure(self, device, x, y):
+ if x == 1 or y == 6:
+ raise RuntimeError('Boom')
+
+
+instantiate_parametrized_tests(TestTestParametrization)
+instantiate_device_type_tests(TestTestParametrizationDeviceType, globals())
+
+
if __name__ == '__main__':
run_tests()
diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py
index 971b3a6..ee3c7ff 100644
--- a/torch/testing/_internal/common_device_type.py
+++ b/torch/testing/_internal/common_device_type.py
@@ -13,7 +13,7 @@
from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \
skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \
IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, DeterministicGuard, TEST_SKIP_NOARCH, \
- TEST_WITH_MIOPEN_SUGGEST_NHWC
+ _TestParametrizer, dtype_name, TEST_WITH_MIOPEN_SUGGEST_NHWC
from torch.testing._internal.common_cuda import _get_torch_cuda_version
from torch.testing._internal.common_dtype import get_all_dtypes
@@ -252,19 +252,14 @@
# then inherit from it for your generic test.
-def _dtype_name(dtype):
- """ Returns the pretty name of the dtype (e.g. torch.int64 -> int64). """
- return str(dtype).split('.')[1]
-
-
def _dtype_test_suffix(dtypes):
""" Returns the test suffix for a dtype, sequence of dtypes, or None. """
if isinstance(dtypes, list) or isinstance(dtypes, tuple):
if len(dtypes) == 0:
return ''
- return '_' + '_'.join((_dtype_name(d) for d in dtypes))
+ return '_' + '_'.join((dtype_name(d) for d in dtypes))
elif dtypes:
- return '_{}'.format(_dtype_name(dtypes))
+ return '_{}'.format(dtype_name(dtypes))
else:
return ''
@@ -382,22 +377,32 @@
return result
- assert not hasattr(cls, test_name), "Redefinition of test {0}".format(test_name)
- setattr(cls, test_name, instantiated_test)
+ assert not hasattr(cls, name), "Redefinition of test {0}".format(name)
+ setattr(cls, name, instantiated_test)
# Handles tests that need parametrization (e.g. those that run across a set of
# ops / modules using the @ops or @modules decorators).
- if hasattr(test, 'parametrize_fn'):
- for (test, test_name, param_kwargs) in test.parametrize_fn(test, generic_cls, cls):
- instantiate_test_helper(cls=cls, name=test_name, test=test, param_kwargs=param_kwargs)
- else:
- dtypes = cls._get_dtypes(test)
- dtypes = tuple(dtypes) if dtypes is not None else (None,)
- for dtype in dtypes:
- param_kwargs = {}
- _update_param_kwargs(param_kwargs, 'dtype', dtype)
- test_name = '{}_{}{}'.format(name, cls.device_type, _dtype_test_suffix(dtype))
- instantiate_test_helper(cls=cls, name=test_name, test=test, param_kwargs=param_kwargs)
+
+ def default_parametrize_fn(test, generic_cls, cls):
+ # By default, parametrize only over device.
+ test_suffix = cls.device_type
+ yield (test, test_suffix, {})
+
+ parametrize_fn = test.parametrize_fn if hasattr(test, 'parametrize_fn') else default_parametrize_fn
+ for (test, test_suffix, param_kwargs) in parametrize_fn(test, generic_cls, cls):
+ if hasattr(test, 'handles_dtypes') and test.handles_dtypes:
+ full_name = '{}_{}'.format(name, test_suffix)
+ instantiate_test_helper(cls=cls, name=full_name, test=test, param_kwargs=param_kwargs)
+ else:
+ # The parametrize_fn doesn't handle dtypes internally; handle them here instead by generating
+ # a test per dtype.
+ dtypes = cls._get_dtypes(test)
+ dtypes = tuple(dtypes) if dtypes is not None else (None,)
+ for dtype in dtypes:
+ all_param_kwargs = dict(param_kwargs)
+ _update_param_kwargs(all_param_kwargs, 'dtype', dtype)
+ full_name = '{}_{}{}'.format(name, test_suffix, _dtype_test_suffix(dtype))
+ instantiate_test_helper(cls=cls, name=full_name, test=test, param_kwargs=all_param_kwargs)
def run(self, result=None):
super().run(result=result)
@@ -634,43 +639,6 @@
none = 5 # Instantiate no dtype variants (no dtype kwarg needed)
-class _TestParametrizer(object):
- """
- Decorator class for parametrizing a test function, yielding a set of new tests spawned
- from the original generic test, each specialized for a specific set of test inputs. For
- example, parametrizing a test across the set of ops will result in a test function per op.
-
- The decision of how to parametrize / what to parametrize over is intended to be implemented
- by each derived class.
-
- In the details, the decorator adds a 'parametrize_fn' property to the test function that is called
- during device-specific test instantiation performed in instantiate_device_type_tests(). Because of this,
- there is no need to parametrize over device type, as that is already handled separately.
- """
- def _parametrize_test(self, test, generic_cls, device_cls):
- """
- Parametrizes the given test function across whatever dimension is specified by the derived class.
- Tests can be parametrized over any arbitrary dimension or combination of dimensions, such as all
- ops, all modules, or all ops + their associated dtypes.
-
- Args:
- test (fn): Test function to parametrize over; must support least a device arg
- generic_cls (class): Generic test class object containing tests (e.g. TestFoo)
- device_cls (class): Device-specialized test class object (e.g. TestFooCPU)
-
- Returns:
- Generator object returning 3-tuples of:
- test (fn): Parametrized test function; must support a device arg and args for any params
- test_name (str): Parametrized name of the test (e.g. test_bar_opname_int64)
- param_kwargs (dict): Param kwargs to pass to the test (e.g. {'op': 'add', 'dtype': torch.int64})
- """
- raise NotImplementedError
-
- def __call__(self, fn):
- fn.parametrize_fn = self._parametrize_test
- return fn
-
-
# Decorator that defines the OpInfos a test template should be instantiated for.
#
# Example usage:
@@ -712,6 +680,7 @@
class ops(_TestParametrizer):
def __init__(self, op_list, *, dtypes: OpDTypes = OpDTypes.basic,
allowed_dtypes: Optional[Sequence[torch.dtype]] = None):
+ super().__init__(handles_dtypes=True)
self.op_list = op_list
self.opinfo_dtypes = dtypes
self.allowed_dtypes = set(allowed_dtypes) if allowed_dtypes is not None else None
@@ -745,11 +714,10 @@
for dtype in dtypes:
# Construct the test name.
- test_name = '{}_{}{}_{}{}'.format(test.__name__,
- op.name.replace('.', '_'),
- '_' + op.variant_test_name if op.variant_test_name else '',
- device_cls.device_type,
- _dtype_test_suffix(dtype))
+ test_name = '{}{}_{}{}'.format(op.name.replace('.', '_'),
+ '_' + op.variant_test_name if op.variant_test_name else '',
+ device_cls.device_type,
+ _dtype_test_suffix(dtype))
# Construct parameter kwargs to pass to the test.
param_kwargs = {'op': op}
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 5dd1cb2..5aa8b67 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -758,6 +758,12 @@
return (supported if self._default_test_dtypes is None
else supported.intersection(self._default_test_dtypes))
+ @property
+ def formatted_name(self):
+ """Returns a formatted full name for this OpInfo that can be used in test names."""
+ variant = '_' + self.variant_test_name if self.variant_test_name else ''
+ return '{}{}'.format(self.name.replace('.', '_'), variant)
+
def _generate_reduction_inputs(device, dtype, requires_grad):
"""Generates input tensors for testing reduction operators"""
diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py
index b1cbbb3..a7133b7 100644
--- a/torch/testing/_internal/common_modules.py
+++ b/torch/testing/_internal/common_modules.py
@@ -49,6 +49,7 @@
""" PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """
def __init__(self, module_info_list):
+ super().__init__(handles_dtypes=True)
self.module_info_list = module_info_list
def _parametrize_test(self, test, generic_cls, device_cls):
@@ -56,10 +57,9 @@
# TODO: Factor some of this out since it's similar to OpInfo.
for dtype in floating_types():
# Construct the test name.
- test_name = '{}_{}_{}{}'.format(test.__name__,
- module_info.name.replace('.', '_'),
- device_cls.device_type,
- _dtype_test_suffix(dtype))
+ test_name = '{}_{}{}'.format(module_info.name.replace('.', '_'),
+ device_cls.device_type,
+ _dtype_test_suffix(dtype))
# Construct parameter kwargs to pass to the test.
param_kwargs = {'module_info': module_info}
@@ -153,6 +153,10 @@
def name(self):
return formatted_module_name(self.module_cls)
+ @property
+ def formatted_name(self):
+ return self.name.replace('.', '_')
+
def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 922d5c8..b32c3a1 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -76,6 +76,267 @@
slow_tests_dict: Optional[Dict[str, Any]] = None
disabled_tests_dict: Optional[Dict[str, Any]] = None
+
+class _TestParametrizer(object):
+ """
+ Decorator class for parametrizing a test function, yielding a set of new tests spawned
+ from the original generic test, each specialized for a specific set of test inputs. For
+ example, parametrizing a test across the set of ops will result in a test function per op.
+
+ The decision of how to parametrize / what to parametrize over is intended to be implemented
+ by each derived class.
+
+ In the details, the decorator adds a 'parametrize_fn' property to the test function that is called
+ during device-specific test instantiation performed in instantiate_device_type_tests(). Because of this,
+ there is no need to parametrize over device type, as that is already handled separately.
+
+ If the decorator is applied to a test function that already has a 'parametrize_fn' property, a new
+ composite 'parametrize_fn' will be created that generates tests with the product of the parameters
+ generated by the old and new parametrize_fns. This allows for convenient composability of decorators.
+
+ Args:
+ handles_dtypes (bool): If True, indicates that it is the responsibility of the decorator to handle
+ dtypes internally. This allows for more flexibility when needed (e.g. for op-specific dtype handling).
+ Default: True
+ """
+ def __init__(self, handles_dtypes=True):
+ self.handles_dtypes = handles_dtypes
+
+ def _parametrize_test(self, test, generic_cls, device_cls):
+ """
+ Parametrizes the given test function across whatever dimension is specified by the derived class.
+ Tests can be parametrized over any arbitrary dimension or combination of dimensions, such as all
+ ops, all modules, or all ops + their associated dtypes.
+
+ Args:
+ test (fn): Test function to parametrize over
+ generic_cls (class): Generic test class object containing tests (e.g. TestFoo)
+ device_cls (class): Device-specialized test class object (e.g. TestFooCPU); set to None
+ if the tests are not part of a device-specific set
+
+ Returns:
+ Generator object returning 3-tuples of:
+ test (fn): Parametrized test function; must support a device arg and args for any params
+ test_name (str): Parametrized suffix for the test (e.g. opname_int64); will be appended to
+ the base name of the test
+ param_kwargs (dict): Param kwargs to pass to the test (e.g. {'op': 'add', 'dtype': torch.int64})
+ """
+ raise NotImplementedError
+
+ def __call__(self, fn):
+ if hasattr(fn, 'parametrize_fn'):
+ # Do composition with the product of args.
+ old_parametrize_fn = fn.parametrize_fn
+ new_parametrize_fn = self._parametrize_test
+
+ def composite_fn(test, generic_cls, device_cls,
+ old_parametrize_fn=old_parametrize_fn,
+ new_parametrize_fn=new_parametrize_fn):
+ old_tests = [(test, test_name, param_kwargs) for (test, test_name, param_kwargs) in
+ old_parametrize_fn(test, generic_cls, device_cls)]
+ for (old_test, old_test_name, old_param_kwargs) in old_tests:
+ for (new_test, new_test_name, new_param_kwargs) in \
+ new_parametrize_fn(old_test, generic_cls, device_cls):
+ full_param_kwargs = {**old_param_kwargs, **new_param_kwargs}
+ yield (new_test, '{}_{}'.format(new_test_name, old_test_name), full_param_kwargs)
+
+ fn.parametrize_fn = composite_fn
+ old_handles_dtypes = fn.handles_dtypes if hasattr(fn, 'handles_dtypes') else False
+ if self.handles_dtypes and old_handles_dtypes:
+ raise RuntimeError('Cannot compose multiple parametrization decorators that handle dtypes; '
+ 'their dtype handling conflicts')
+ fn.handles_dtypes = self.handles_dtypes or old_handles_dtypes
+ else:
+ fn.parametrize_fn = self._parametrize_test
+ fn.handles_dtypes = self.handles_dtypes
+ return fn
+
+
+def instantiate_parametrized_tests(generic_cls):
+ """
+ Instantiates tests that have been decorated with a parametrize_fn. This is generally performed by a
+ decorator subclass of _TestParametrizer. The generic test will be replaced on the test class by
+ parametrized tests with specialized names.
+
+ Args:
+ generic_cls (class): Generic test class object containing tests (e.g. TestFoo)
+ """
+ for attr_name in tuple(dir(generic_cls)):
+ class_attr = getattr(generic_cls, attr_name)
+ if not hasattr(class_attr, 'parametrize_fn'):
+ continue
+
+ if hasattr(class_attr, 'handles_dtypes') and class_attr.handles_dtypes:
+ raise RuntimeError('instantiate_parametrized_tests() should not be used with decorators '
+ 'that handle dtypes internally (e.g. @ops, @modules, etc.). Use '
+ 'instantiate_device_type_tests() with these instead.')
+
+ # Remove the generic test from the test class.
+ delattr(generic_cls, attr_name)
+
+ # Add parametrized tests to the test class.
+ def instantiate_test_helper(cls, name, test, param_kwargs):
+ @wraps(test)
+ def instantiated_test(self, param_kwargs=param_kwargs):
+ test(self, **param_kwargs)
+
+ assert not hasattr(generic_cls, name), "Redefinition of test {0}".format(name)
+ setattr(generic_cls, name, instantiated_test)
+
+ for (test, test_suffix, param_kwargs) in class_attr.parametrize_fn(
+ class_attr, generic_cls=generic_cls, device_cls=None):
+ full_name = '{}_{}'.format(test.__name__, test_suffix)
+ instantiate_test_helper(cls=generic_cls, name=full_name, test=test, param_kwargs=param_kwargs)
+
+
+class subtest(object):
+ """
+ Explicit subtest case for use with test parametrization.
+ Allows for explicit naming of individual subtest cases as well as applying
+ decorators to the parametrized test.
+
+ Args:
+ arg_values (iterable): Iterable of arg values (e.g. range(10)) or
+ tuples of arg values (e.g. [(1, 2), (3, 4)]).
+ name (str): Optional name to use for the test.
+ decorators (iterable): Iterable of decorators to apply to the generated test.
+ """
+ __slots__ = ['arg_values', 'name', 'decorators']
+
+ def __init__(self, arg_values, name=None, decorators=None):
+ self.arg_values = arg_values
+ self.name = name
+ self.decorators = decorators if decorators else []
+
+
+class parametrize(_TestParametrizer):
+ """
+ Decorator for applying generic test parametrizations.
+
+ The interface for this decorator is modeled after `@pytest.mark.parametrize`.
+ Basic usage between this decorator and pytest's is identical. The first argument
+ should be a string containing comma-separated names of parameters for the test, and
+ the second argument should be an iterable returning values or tuples of values for
+ the case of multiple parameters.
+
+ Beyond this basic usage, the decorator provides some additional functionality that
+ pytest does not.
+
+ 1. Parametrized tests end up as generated test functions on unittest test classes.
+ Since this differs from how pytest works, this decorator takes on the additional
+ responsibility of naming these test functions. The default test names consists of
+ the test's base name followed by each parameter name + value (e.g. "test_bar_x_1_y_foo"),
+ but custom names can be defined using `name_fn` or the `subtest` structure (see below).
+
+ 2. The decorator specially handles parameter values of type `subtest`, which allows for
+ more fine-grained control over both test naming and test execution. In particular, it can
+ be used to tag subtests with explicit test names or apply arbitrary decorators (see examples
+ below).
+
+ Examples::
+
+ @parametrize("x", range(5))
+ def test_foo(self, x):
+ ...
+
+ @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')])
+ def test_bar(self, x, y):
+ ...
+
+ @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')],
+ name_fn=lambda x, y: '{}_{}'.format(x, y))
+ def test_bar_custom_names(self, x, y):
+ ...
+
+ @parametrize("x, y", [subtest((1, 2), name='double'),
+ subtest((1, 3), name='triple', decorators=[unittest.expectedFailure]),
+ subtest((1, 4), name='quadruple')])
+ def test_baz(self, x, y):
+ ...
+
+ Args:
+ arg_str (str): String of arg names separate by commas (e.g. "x,y").
+ arg_values (iterable): Iterable of arg values (e.g. range(10)) or
+ tuples of arg values (e.g. [(1, 2), (3, 4)]).
+ name_fn (callable): Optional function that takes in parameters and returns subtest name.
+ """
+ def __init__(self, arg_str, arg_values, name_fn=None):
+ super().__init__(handles_dtypes=False)
+ self.arg_names = arg_str.split(',')
+ self.arg_values = arg_values
+ self.name_fn = name_fn
+
+ def _formatted_str_repr(self, name, value):
+ """ Returns a string representation for the given arg that is suitable for use in test function names. """
+ if isinstance(value, torch.dtype):
+ return dtype_name(value)
+ elif isinstance(value, torch.device):
+ return str(value)
+ # Can't use isinstance as it would cause a circular import
+ elif value.__class__.__name__ == 'OpInfo' or value.__class__.__name__ == 'ModuleInfo':
+ return value.formatted_name
+ else:
+ # Include name and value separated by underscore.
+ return '{}_{}'.format(name, str(value).replace('.', '_'))
+
+ def _default_subtest_name(self, values):
+ return '_'.join([self._formatted_str_repr(a, v) for a, v in zip(self.arg_names, values)])
+
+ def _get_subtest_name(self, values, explicit_name=None):
+ if explicit_name:
+ subtest_name = explicit_name
+ elif self.name_fn:
+ subtest_name = self.name_fn(*values)
+ else:
+ subtest_name = self._default_subtest_name(values)
+ return subtest_name
+
+ def _parametrize_test(self, test, generic_cls, device_cls):
+ if len(self.arg_names) == 0:
+ # No additional parameters needed for the test.
+ test_name = device_cls.device_type if device_cls else ''
+ yield (test, test_name, {})
+ else:
+ # Each "values" item is expected to be either:
+ # * A tuple of values with one for each arg. For a single arg, a single item is expected.
+ # * A subtest instance with arg_values matching the previous.
+ for values in self.arg_values:
+ maybe_name = None
+ if isinstance(values, subtest):
+ sub = values
+ values = sub.arg_values
+ maybe_name = sub.name
+
+ # Apply decorators.
+ @wraps(test)
+ def test_wrapper(*args, **kwargs):
+ return test(*args, **kwargs)
+
+ for decorator in sub.decorators:
+ test_wrapper = decorator(test_wrapper)
+
+ gen_test = test_wrapper
+ else:
+ gen_test = test
+
+ values = list(values) if len(self.arg_names) > 1 else [values]
+ if len(values) != len(self.arg_names):
+ raise RuntimeError('Expected # values == # arg names, but got: {} '
+ 'values and {} names for test "{}"'.format(
+ len(values), len(self.arg_names), test.__name__))
+
+ param_kwargs = {
+ name: value for name, value in zip(self.arg_names, values)
+ }
+
+ subtest_name = self._get_subtest_name(values, explicit_name=maybe_name)
+ test_name = '{}{}'.format(subtest_name, '_' + device_cls.device_type if device_cls else '')
+ if '.' in test_name:
+ raise RuntimeError('Test name cannot contain periods, but got: {}'.format(test_name))
+
+ yield (gen_test, test_name, param_kwargs)
+
+
class ProfilingMode(Enum):
LEGACY = 1
SIMPLE = 2
@@ -271,6 +532,12 @@
def get_test_names(test_cases):
return ['.'.join(case.id().split('.')[-2:]) for case in test_cases]
+def _print_test_names():
+ suite = unittest.TestLoader().loadTestsFromModule(__main__)
+ test_cases = discover_test_cases_recursively(suite)
+ for name in get_test_names(test_cases):
+ print(name)
+
def chunk_list(lst, nchunks):
return [lst[i::nchunks] for i in range(nchunks)]
@@ -300,10 +567,7 @@
print(f'[WARNING] disabled test file provided but not found: {IMPORT_DISABLED_TESTS}')
# Determine the test launch mechanism
if TEST_DISCOVER:
- suite = unittest.TestLoader().loadTestsFromModule(__main__)
- test_cases = discover_test_cases_recursively(suite)
- for name in get_test_names(test_cases):
- print(name)
+ _print_test_names()
elif TEST_IN_SUBPROCESS:
suite = unittest.TestLoader().loadTestsFromModule(__main__)
test_cases = discover_test_cases_recursively(suite)
@@ -2585,3 +2849,7 @@
return wrapper
return decorator
+
+def dtype_name(dtype):
+ """ Returns the pretty name of the dtype (e.g. torch.int64 -> int64). """
+ return str(dtype).split('.')[1]