allow_inf on test_beta_log_prob (#4354)
* allow_inf on test_beta_log_prob
* Support allow_inf on assertAlmostEqual
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
diff --git a/test/common.py b/test/common.py
index c5b9174..bfde53a 100644
--- a/test/common.py
+++ b/test/common.py
@@ -239,11 +239,11 @@
else:
super(TestCase, self).assertEqual(x, y, message)
- def assertAlmostEqual(self, x, y, places=None, msg=None, delta=None):
+ def assertAlmostEqual(self, x, y, places=None, msg=None, delta=None, allow_inf=None):
prec = delta
if places:
prec = 10**(-places)
- self.assertEqual(x, y, prec, msg)
+ self.assertEqual(x, y, prec, msg, allow_inf)
def assertNotEqual(self, x, y, prec=None, message=''):
if prec is None:
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 776cc60..efe7572 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -529,7 +529,7 @@
x = dist.sample()
actual_log_prob = dist.log_prob(x).sum()
expected_log_prob = scipy.stats.beta.logpdf(x, alpha, beta)[0]
- self.assertAlmostEqual(actual_log_prob, expected_log_prob, places=3)
+ self.assertAlmostEqual(actual_log_prob, expected_log_prob, places=3, allow_inf=True)
# This is a randomized test.
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")