Setting validation flag for Distributions tests to work with TorchDynamo (#80081)

Is this a good enough workaround? @jansel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80081
Approved by: https://github.com/jansel
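
For context, a minimal sketch of what the flag controls (illustrative only, not part of the patch): with default argument validation enabled, constructing a distribution with out-of-constraint parameters raises a `ValueError`, which several of these tests rely on; the new `DistributionsTestCase.setUp` re-enables it so the assumption holds even if something (e.g. the TorchDynamo test harness) has turned it off globally.

```python
import torch
from torch.distributions import Normal

# With validation on, invalid parameters are rejected at construction time.
torch.distributions.Distribution.set_default_validate_args(True)
try:
    Normal(loc=0.0, scale=-1.0)  # scale must be positive; raises ValueError
except ValueError as e:
    print("validation caught:", e)

# With validation off, the same call constructs silently with an invalid scale.
torch.distributions.Distribution.set_default_validate_args(False)
Normal(loc=0.0, scale=-1.0)
```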
diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py
index c8b5551..4db246a 100644
--- a/test/distributions/test_distributions.py
+++ b/test/distributions/test_distributions.py
@@ -793,7 +793,14 @@
 ]
 
 
-class TestDistributions(TestCase):
+class DistributionsTestCase(TestCase):
+    def setUp(self):
+        """The tests assume that the validation flag is set."""
+        torch.distributions.Distribution.set_default_validate_args(True)
+        super(DistributionsTestCase, self).setUp()
+
+
+class TestDistributions(DistributionsTestCase):
     _do_cuda_memory_leak_check = True
     _do_cuda_non_default_stream = True
 
@@ -3240,7 +3247,7 @@
 # These tests are only needed for a few distributions that implement custom
 # reparameterized gradients. Most .rsample() implementations simply rely on
 # the reparameterization trick and do not need to be tested for accuracy.
-class TestRsample(TestCase):
+class TestRsample(DistributionsTestCase):
     @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
     def test_gamma(self):
         num_samples = 100
@@ -3448,7 +3455,7 @@
             ]))
 
 
-class TestDistributionShapes(TestCase):
+class TestDistributionShapes(DistributionsTestCase):
     def setUp(self):
         super(TestDistributionShapes, self).setUp()
         self.scalar_sample = 1
@@ -3910,7 +3917,7 @@
         self.assertEqual(continuous_bernoulli.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2)))
 
 
-class TestKL(TestCase):
+class TestKL(DistributionsTestCase):
 
     def setUp(self):
         super(TestKL, self).setUp()
@@ -4304,7 +4311,7 @@
                 ]))
 
 
-class TestConstraints(TestCase):
+class TestConstraints(DistributionsTestCase):
     def test_params_constraints(self):
         normalize_probs_dists = (
             Categorical,
@@ -4356,7 +4363,7 @@
                 self.assertTrue(ok.all(), msg=message)
 
 
-class TestNumericalStability(TestCase):
+class TestNumericalStability(DistributionsTestCase):
     def _test_pdf_score(self,
                         dist_class,
                         x,
@@ -4573,7 +4580,7 @@
 
 
 # TODO: make this a pytest parameterized test
-class TestLazyLogitsInitialization(TestCase):
+class TestLazyLogitsInitialization(DistributionsTestCase):
     def setUp(self):
         super(TestLazyLogitsInitialization, self).setUp()
         # ContinuousBernoulli is not tested because log_prob is not computed simply
@@ -4620,7 +4627,7 @@
 
 
 @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
-class TestAgainstScipy(TestCase):
+class TestAgainstScipy(DistributionsTestCase):
     def setUp(self):
         super(TestAgainstScipy, self).setUp()
         positive_var = torch.randn(20).exp()
@@ -4794,7 +4801,7 @@
             self.assertEqual(icdf, scipy_dist.ppf(samples), msg=pytorch_dist)
 
 
-class TestFunctors(TestCase):
+class TestFunctors(DistributionsTestCase):
     def test_cat_transform(self):
         x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100)
         x2 = (torch.arange(1, 101, dtype=torch.float).view(-1, 100) - 1) / 100
@@ -4912,9 +4919,9 @@
         self.assertEqual(actual_jac, expected_jac)
 
 
-class TestValidation(TestCase):
+class TestValidation(DistributionsTestCase):
     def setUp(self):
-        super(TestCase, self).setUp()
+        super(TestValidation, self).setUp()
 
     def test_valid(self):
         for Dist, params in EXAMPLES:
@@ -5007,7 +5014,7 @@
         super(TestValidation, self).tearDown()
 
 
-class TestJit(TestCase):
+class TestJit(DistributionsTestCase):
     def _examples(self):
         for Dist, params in EXAMPLES:
             for param in params: