Setting validation flag for Distributions tests to work with TorchDynamo (#80081)
Is this is a good enough workaround? @jansel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80081
Approved by: https://github.com/jansel
diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py
index c8b5551..4db246a 100644
--- a/test/distributions/test_distributions.py
+++ b/test/distributions/test_distributions.py
@@ -793,7 +793,14 @@
]
-class TestDistributions(TestCase):
+class DistributionsTestCase(TestCase):
+ def setUp(self):
+ """The tests assume that the validation flag is set."""
+ torch.distributions.Distribution.set_default_validate_args(True)
+ super(DistributionsTestCase, self).setUp()
+
+
+class TestDistributions(DistributionsTestCase):
_do_cuda_memory_leak_check = True
_do_cuda_non_default_stream = True
@@ -3240,7 +3247,7 @@
# These tests are only needed for a few distributions that implement custom
# reparameterized gradients. Most .rsample() implementations simply rely on
# the reparameterization trick and do not need to be tested for accuracy.
-class TestRsample(TestCase):
+class TestRsample(DistributionsTestCase):
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_gamma(self):
num_samples = 100
@@ -3448,7 +3455,7 @@
]))
-class TestDistributionShapes(TestCase):
+class TestDistributionShapes(DistributionsTestCase):
def setUp(self):
super(TestDistributionShapes, self).setUp()
self.scalar_sample = 1
@@ -3910,7 +3917,7 @@
self.assertEqual(continuous_bernoulli.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2)))
-class TestKL(TestCase):
+class TestKL(DistributionsTestCase):
def setUp(self):
super(TestKL, self).setUp()
@@ -4304,7 +4311,7 @@
]))
-class TestConstraints(TestCase):
+class TestConstraints(DistributionsTestCase):
def test_params_constraints(self):
normalize_probs_dists = (
Categorical,
@@ -4356,7 +4363,7 @@
self.assertTrue(ok.all(), msg=message)
-class TestNumericalStability(TestCase):
+class TestNumericalStability(DistributionsTestCase):
def _test_pdf_score(self,
dist_class,
x,
@@ -4573,7 +4580,7 @@
# TODO: make this a pytest parameterized test
-class TestLazyLogitsInitialization(TestCase):
+class TestLazyLogitsInitialization(DistributionsTestCase):
def setUp(self):
super(TestLazyLogitsInitialization, self).setUp()
# ContinuousBernoulli is not tested because log_prob is not computed simply
@@ -4620,7 +4627,7 @@
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
-class TestAgainstScipy(TestCase):
+class TestAgainstScipy(DistributionsTestCase):
def setUp(self):
super(TestAgainstScipy, self).setUp()
positive_var = torch.randn(20).exp()
@@ -4794,7 +4801,7 @@
self.assertEqual(icdf, scipy_dist.ppf(samples), msg=pytorch_dist)
-class TestFunctors(TestCase):
+class TestFunctors(DistributionsTestCase):
def test_cat_transform(self):
x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100)
x2 = (torch.arange(1, 101, dtype=torch.float).view(-1, 100) - 1) / 100
@@ -4912,9 +4919,9 @@
self.assertEqual(actual_jac, expected_jac)
-class TestValidation(TestCase):
+class TestValidation(DistributionsTestCase):
def setUp(self):
- super(TestCase, self).setUp()
+ super(TestValidation, self).setUp()
def test_valid(self):
for Dist, params in EXAMPLES:
@@ -5007,7 +5014,7 @@
super(TestValidation, self).tearDown()
-class TestJit(TestCase):
+class TestJit(DistributionsTestCase):
def _examples(self):
for Dist, params in EXAMPLES:
for param in params: