[traced-graph][sparse] cleanup test guards (#133375)
Rather than repeating the same guard for every test, simply express it once on the test fixture instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133375
Approved by: https://github.com/ezyang
diff --git a/test/export/test_sparse.py b/test/export/test_sparse.py
index 023cbbe..587b9e5 100644
--- a/test/export/test_sparse.py
+++ b/test/export/test_sparse.py
@@ -86,6 +86,9 @@
@unittest.skipIf(is_fbcode(), "See torch._dynamo.config")
+@unittest.skipIf(
+ sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
+)
class TestSparseProp(TestCase):
def setUp(self):
TestCase.setUp(self)
@@ -121,9 +124,6 @@
self.assertEqual(x_meta2, y_meta2, exact_layout=True)
self.assertEqual(x.values(), y.values(), exact_layout=True)
- @unittest.skipIf(
- sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
- )
@parametrize("dtype", DTYPES)
@parametrize("itype", ITYPES)
@all_sparse_layouts("layout")
@@ -145,9 +145,6 @@
else:
self.assertEqual(meta, None)
- @unittest.skipIf(
- sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
- )
@parametrize("dtype", DTYPES)
@parametrize("itype", ITYPES)
@all_sparse_layouts("layout")
@@ -172,9 +169,6 @@
else:
self.assertEqual(meta, None)
- @unittest.skipIf(
- sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
- )
@parametrize("dtype", DTYPES)
@parametrize("itype", ITYPES)
@all_sparse_layouts("layout")
@@ -197,9 +191,6 @@
else:
self.assertEqual(meta, None)
- @unittest.skipIf(
- sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
- )
@parametrize("dtype", DTYPES)
@parametrize("itype", ITYPES)
@all_sparse_layouts("layout")
@@ -227,9 +218,6 @@
else:
self.assertEqual(meta, None)
- @unittest.skipIf(
- sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
- )
def test_activation_coo(self):
net = SparseActivationCOO()
x = [torch.randn(3, 3) for _ in range(3)]
@@ -246,9 +234,6 @@
else:
self.assertEqual(meta, None)
- @unittest.skipIf(
- sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
- )
def test_activation_csr(self):
net = SparseActivationCSR()
x = [torch.randn(3, 3) for _ in range(3)]