# blob: 278360ad4f875153c62878cb7302d6ac79582336
"""
Note [Randomized statistical tests]
-----------------------------------
This note describes how to maintain tests in this file as random sources
change. This file contains two types of randomized tests:
1. The easier type of randomized test consists of tests that should always pass but
are initialized with random data. If these fail, something is wrong, but it's
fine to use a fixed seed by inheriting from common.TestCase.
2. The trickier tests are statistical tests. These tests explicitly call
set_rng_seed(n) and are marked "see Note [Randomized statistical tests]".
These statistical tests have a known, nonzero false-alarm rate
(we set failure_rate=1e-3 by default). We need to balance the strength of these
tests against the annoyance of false alarms. One way that works is to specifically
set seeds in each of the randomized tests. When a random generator
occasionally changes (as in #4312 vectorizing the Box-Muller sampler), some
of these statistical tests may (rarely) fail. If one fails in this case,
it's fine to increment the seed of the failing test (but you shouldn't need
to increment it more than once; otherwise something is probably actually
wrong).
"""
import math
import unittest
from collections import namedtuple
from itertools import product
import torch
from common import TestCase, run_tests, set_rng_seed
from torch.autograd import Variable, gradcheck
from torch.distributions import (Bernoulli, Beta, Categorical, Cauchy,
Dirichlet, Exponential, Gamma, Laplace,
Normal, OneHotCategorical, Uniform)
# Optional reference implementations: tests that compare against numpy/scipy
# are skipped when either package is missing.
try:
    import numpy as np
    import scipy.stats
    import scipy.special
except ImportError:
    TEST_NUMPY = False
else:
    TEST_NUMPY = True
# Register all distributions for generic tests.
# Each Example pairs a distribution class with a list of keyword-argument dicts;
# every dict must describe a *valid* distribution (e.g. positive std/scale),
# because generic tests such as entropy-shape checks construct them directly.
Example = namedtuple('Example', ['Dist', 'params'])
EXAMPLES = [
    Example(Bernoulli, [
        {'probs': Variable(torch.Tensor([0.7, 0.2, 0.4]), requires_grad=True)},
        {'probs': Variable(torch.Tensor([0.3]), requires_grad=True)},
        {'probs': 0.3},
    ]),
    Example(Beta, [
        {
            'alpha': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True),
            'beta': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True),
        },
        {
            'alpha': Variable(torch.exp(torch.randn(4)), requires_grad=True),
            'beta': Variable(torch.exp(torch.randn(4)), requires_grad=True),
        },
    ]),
    Example(Categorical, [
        {'probs': Variable(torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]), requires_grad=True)},
        {'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]), requires_grad=True)},
    ]),
    Example(OneHotCategorical, [
        {'probs': Variable(torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]), requires_grad=True)},
        {'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]), requires_grad=True)},
    ]),
    Example(Cauchy, [
        {'loc': 0.0, 'scale': 1.0},
        {'loc': Variable(torch.Tensor([0.0])), 'scale': 1.0},
        {'loc': Variable(torch.Tensor([[0.0], [0.0]])),
         'scale': Variable(torch.Tensor([[1.0], [1.0]]))}
    ]),
    Example(Gamma, [
        {
            'alpha': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True),
            'beta': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True),
        },
        {
            'alpha': Variable(torch.exp(torch.randn(1)), requires_grad=True),
            'beta': Variable(torch.exp(torch.randn(1)), requires_grad=True),
        },
    ]),
    Example(Dirichlet, [
        {'alpha': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)},
        {'alpha': Variable(torch.exp(torch.randn(4)), requires_grad=True)},
    ]),
    Example(Exponential, [
        {'rate': Variable(torch.randn(5, 5).abs(), requires_grad=True)},
        {'rate': Variable(torch.randn(1).abs(), requires_grad=True)},
    ]),
    Example(Normal, [
        {
            'mean': Variable(torch.randn(5, 5), requires_grad=True),
            'std': Variable(torch.randn(5, 5).abs(), requires_grad=True),
        },
        {
            'mean': Variable(torch.randn(1), requires_grad=True),
            # .abs(): a raw randn draw is negative half the time, giving an invalid std.
            'std': Variable(torch.randn(1).abs(), requires_grad=True),
        },
        {
            'mean': torch.Tensor([1.0, 0.0]),
            'std': torch.Tensor([1e-5, 1e-5]),
        },
    ]),
    Example(Uniform, [
        {
            'low': Variable(torch.zeros(5, 5), requires_grad=True),
            'high': Variable(torch.ones(5, 5), requires_grad=True),
        },
        {
            'low': Variable(torch.zeros(1), requires_grad=True),
            'high': Variable(torch.ones(1), requires_grad=True),
        },
        {
            'low': torch.Tensor([1.0, 1.0]),
            'high': torch.Tensor([2.0, 3.0]),
        },
    ]),
    Example(Laplace, [
        {
            'loc': Variable(torch.randn(5, 5), requires_grad=True),
            'scale': Variable(torch.randn(5, 5).abs(), requires_grad=True),
        },
        {
            'loc': Variable(torch.randn(1), requires_grad=True),
            # .abs(): a raw randn draw is negative half the time, giving an invalid scale.
            'scale': Variable(torch.randn(1).abs(), requires_grad=True),
        },
        {
            'loc': torch.Tensor([1.0, 0.0]),
            'scale': torch.Tensor([1e-5, 1e-5]),
        },
    ]),
]
class TestDistributions(TestCase):
    """Per-distribution tests: sample shapes, log_prob values, reparameterized
    gradients, enumerate_support, and statistical checks of the samplers."""

    def _gradcheck_log_prob(self, dist_ctor, ctor_params):
        """Check the sample shape of ``dist_ctor(*ctor_params)`` and gradcheck
        ``log_prob`` of one fixed sample with respect to the constructor params."""
        distribution = dist_ctor(*ctor_params)
        s = distribution.sample()
        expected_shape = distribution.batch_shape + distribution.event_shape
        if not expected_shape:
            expected_shape = torch.Size((1,))  # Work around lack of scalars.
        self.assertEqual(s.size(), expected_shape)

        def apply_fn(*params):
            return dist_ctor(*params).log_prob(s)

        gradcheck(apply_fn, ctor_params, raise_exception=True)

    def _check_log_prob(self, dist, asset_fn):
        """Draw one sample from ``dist`` and call ``asset_fn(i, value, log_prob)``
        for every flattened element, so the caller can compare ``log_prob``
        against a reference formula."""
        s = dist.sample()
        log_probs = dist.log_prob(s)
        for i, (val, log_prob) in enumerate(zip(s.data.view(-1), log_probs.data.view(-1))):
            asset_fn(i, val, log_prob)

    def _check_sampler_sampler(self, torch_dist, ref_dist, message, multivariate=False,
                               num_samples=10000, failure_rate=1e-3):
        """Statistically compare ``torch_dist.sample_n`` with a scipy reference sampler.

        Torch samples are labelled +1 and reference samples -1. The pooled values
        are sorted and the labels are split into contiguous bins. If both samplers
        draw from the same distribution, each bin's mean label is approximately
        zero-mean with standard deviation at most ``samples_per_bin ** -0.5``. The
        check fails when any bin mean leaves a two-sided band chosen so that the
        overall false-alarm probability is at most about ``failure_rate``.
        See Note [Randomized statistical tests].

        Args:
            torch_dist: distribution under test (must provide ``sample_n``).
            ref_dist: frozen scipy distribution providing ``rvs``.
            message: prefix for the assertion message.
            multivariate: project both sample sets onto a random unit axis first.
            num_samples: number of samples drawn from each sampler.
            failure_rate: target false-alarm probability of the whole check.
        """
        torch_samples = torch_dist.sample_n(num_samples).squeeze()
        if isinstance(torch_samples, Variable):
            torch_samples = torch_samples.data
        torch_samples = torch_samples.cpu().numpy()
        ref_samples = ref_dist.rvs(num_samples)
        if multivariate:
            # Project onto a random axis.
            axis = np.random.normal(size=torch_samples.shape[-1])
            axis /= np.linalg.norm(axis)
            torch_samples = np.dot(torch_samples, axis)
            ref_samples = np.dot(ref_samples, axis)
        # Pool both sample sets and keep only the +1/-1 labels, in sorted order of value.
        samples = [(x, +1) for x in torch_samples] + [(x, -1) for x in ref_samples]
        samples.sort()
        samples = np.array(samples)[:, 1]

        # Aggregate the +/-1 labels (zero-mean, unit-variance under the null
        # hypothesis) into bins and look at each bin's mean.
        num_bins = 10
        samples_per_bin = len(samples) // num_bins
        # Drop any remainder so the reshape cannot fail when the pooled sample
        # count is not a multiple of num_bins.
        samples = samples[:num_bins * samples_per_bin]
        bins = samples.reshape((num_bins, samples_per_bin)).mean(axis=1)
        stddev = samples_per_bin ** -0.5
        # For Z ~ N(0, s^2): P(|Z| > s * sqrt(2) * erfinv(1 - q)) = q. Using
        # q = failure_rate / num_bins per bin bounds the total (union) false-alarm
        # rate by failure_rate. The sqrt(2) is needed because erf is the CDF
        # transform for N(0, 1/2), not N(0, 1).
        threshold = stddev * math.sqrt(2) * scipy.special.erfinv(1 - failure_rate / num_bins)
        message = '{}.sample() is biased:\n{}'.format(message, bins)
        for bias in bins:
            self.assertLess(-threshold, bias, message)
            self.assertLess(bias, threshold, message)

    def _check_enumerate_support(self, dist, examples):
        """Check ``dist(param).enumerate_support()`` against each expected value,
        for both plain tensors and Variables."""
        for param, expected in examples:
            param = torch.Tensor(param)
            expected = torch.Tensor(expected)
            actual = dist(param).enumerate_support()
            self.assertEqual(actual, expected)
            param = Variable(param)
            expected = Variable(expected)
            actual = dist(param).enumerate_support()
            self.assertEqual(actual, expected)

    def test_bernoulli(self):
        p = Variable(torch.Tensor([0.7, 0.2, 0.4]), requires_grad=True)
        r = Variable(torch.Tensor([0.3]), requires_grad=True)
        s = 0.3
        self.assertEqual(Bernoulli(p).sample_n(8).size(), (8, 3))
        self.assertEqual(Bernoulli(r).sample_n(8).size(), (8, 1))
        self.assertEqual(Bernoulli(r).sample().size(), (1,))
        self.assertEqual(Bernoulli(r).sample((3, 2)).size(), (3, 2, 1))
        self.assertEqual(Bernoulli(s).sample().size(), (1,))
        self._gradcheck_log_prob(Bernoulli, (p,))

        def ref_log_prob(idx, val, log_prob):
            prob = p.data[idx]
            self.assertEqual(log_prob, math.log(prob if val else 1 - prob))

        self._check_log_prob(Bernoulli(p), ref_log_prob)
        # Discrete distribution: no reparameterized sampler.
        self.assertRaises(NotImplementedError, Bernoulli(r).rsample)

    def test_bernoulli_enumerate_support(self):
        examples = [
            ([0.1], [[0], [1]]),
            ([0.1, 0.9], [[0, 0], [1, 1]]),
            ([[0.1, 0.2], [0.3, 0.4]], [[[0, 0], [0, 0]], [[1, 1], [1, 1]]]),
        ]
        self._check_enumerate_support(Bernoulli, examples)

    def test_bernoulli_3d(self):
        p = Variable(torch.Tensor(2, 3, 5).fill_(0.5), requires_grad=True)
        self.assertEqual(Bernoulli(p).sample().size(), (2, 3, 5))
        self.assertEqual(Bernoulli(p).sample(sample_shape=(2, 5)).size(),
                         (2, 5, 2, 3, 5))
        self.assertEqual(Bernoulli(p).sample_n(2).size(), (2, 2, 3, 5))

    def test_categorical_1d(self):
        p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
        # TODO: this should return a 0-dim tensor once we have Scalar support
        self.assertEqual(Categorical(p).sample().size(), (1,))
        self.assertEqual(Categorical(p).sample((2, 2)).size(), (2, 2))
        self.assertEqual(Categorical(p).sample_n(1).size(), (1,))
        self._gradcheck_log_prob(Categorical, (p,))
        self.assertRaises(NotImplementedError, Categorical(p).rsample)

    def test_categorical_2d(self):
        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
        p = Variable(torch.Tensor(probabilities), requires_grad=True)
        s = Variable(torch.Tensor(probabilities_1), requires_grad=True)
        self.assertEqual(Categorical(p).sample().size(), (2,))
        self.assertEqual(Categorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2))
        self.assertEqual(Categorical(p).sample_n(6).size(), (6, 2))
        self._gradcheck_log_prob(Categorical, (p,))

        # sample check for extreme value of probs
        set_rng_seed(0)
        self.assertEqual(Categorical(s).sample(sample_shape=(2,)).data,
                         torch.Tensor([[0, 1], [0, 1]]))

        def ref_log_prob(idx, val, log_prob):
            # probs are normalized along the last dimension before use.
            sample_prob = p.data[idx][val] / p.data[idx].sum()
            self.assertEqual(log_prob, math.log(sample_prob))

        self._check_log_prob(Categorical(p), ref_log_prob)

    def test_categorical_enumerate_support(self):
        examples = [
            ([0.1, 0.2, 0.7], [0, 1, 2]),
            ([[0.1, 0.9], [0.3, 0.7]], [[0, 0], [1, 1]]),
        ]
        self._check_enumerate_support(Categorical, examples)

    def test_one_hot_categorical_1d(self):
        p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
        self.assertEqual(OneHotCategorical(p).sample().size(), (3,))
        self.assertEqual(OneHotCategorical(p).sample((2, 2)).size(), (2, 2, 3))
        self.assertEqual(OneHotCategorical(p).sample_n(1).size(), (1, 3))
        self._gradcheck_log_prob(OneHotCategorical, (p,))
        self.assertRaises(NotImplementedError, OneHotCategorical(p).rsample)

    def test_one_hot_categorical_2d(self):
        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
        p = Variable(torch.Tensor(probabilities), requires_grad=True)
        s = Variable(torch.Tensor(probabilities_1), requires_grad=True)
        self.assertEqual(OneHotCategorical(p).sample().size(), (2, 3))
        self.assertEqual(OneHotCategorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
        self.assertEqual(OneHotCategorical(p).sample_n(6).size(), (6, 2, 3))
        self._gradcheck_log_prob(OneHotCategorical, (p,))

        # A one-hot sample must score the same as the index of its hot entry
        # under the equivalent Categorical.
        dist = OneHotCategorical(p)
        x = dist.sample()
        self.assertEqual(dist.log_prob(x), Categorical(p).log_prob(x.max(-1)[1]))

    def test_one_hot_categorical_enumerate_support(self):
        examples = [
            ([0.1, 0.2, 0.7], [[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
            ([[0.1, 0.9], [0.3, 0.7]], [[[1, 0], [1, 0]], [[0, 1], [0, 1]]]),
        ]
        self._check_enumerate_support(OneHotCategorical, examples)

    def test_uniform(self):
        low = Variable(torch.zeros(5, 5), requires_grad=True)
        high = Variable(torch.ones(5, 5) * 3, requires_grad=True)
        low_1d = Variable(torch.zeros(1), requires_grad=True)
        high_1d = Variable(torch.ones(1) * 3, requires_grad=True)
        self.assertEqual(Uniform(low, high).sample().size(), (5, 5))
        self.assertEqual(Uniform(low, high).sample_n(7).size(), (7, 5, 5))
        self.assertEqual(Uniform(low_1d, high_1d).sample().size(), (1,))
        self.assertEqual(Uniform(low_1d, high_1d).sample_n(1).size(), (1, 1))
        self.assertEqual(Uniform(0.0, 1.0).sample_n(1).size(), (1,))

        # Check log_prob computation when value outside range
        uniform = Uniform(low_1d, high_1d)
        above_high = Variable(torch.Tensor([4.0]))
        below_low = Variable(torch.Tensor([-1.0]))
        self.assertEqual(uniform.log_prob(above_high).data[0], -float('inf'), allow_inf=True)
        self.assertEqual(uniform.log_prob(below_low).data[0], -float('inf'), allow_inf=True)

        set_rng_seed(1)
        self._gradcheck_log_prob(Uniform, (low, high))
        self._gradcheck_log_prob(Uniform, (low, 1.0))
        self._gradcheck_log_prob(Uniform, (0.0, high))

        # Replay the RNG so `rand` equals the noise rsample() draws internally;
        # u = low + rand * (high - low) gives du/dlow = 1 - rand, du/dhigh = rand.
        state = torch.get_rng_state()
        rand = low.new(low.size()).uniform_()
        torch.set_rng_state(state)
        u = Uniform(low, high).rsample()
        u.backward(torch.ones_like(u))
        self.assertEqual(low.grad, 1 - rand)
        self.assertEqual(high.grad, rand)
        low.grad.zero_()
        high.grad.zero_()

    def test_cauchy(self):
        loc = Variable(torch.zeros(5, 5), requires_grad=True)
        scale = Variable(torch.ones(5, 5), requires_grad=True)
        loc_1d = Variable(torch.zeros(1), requires_grad=True)
        scale_1d = Variable(torch.ones(1), requires_grad=True)
        self.assertEqual(Cauchy(loc, scale).sample().size(), (5, 5))
        self.assertEqual(Cauchy(loc, scale).sample_n(7).size(), (7, 5, 5))
        self.assertEqual(Cauchy(loc_1d, scale_1d).sample().size(), (1,))
        self.assertEqual(Cauchy(loc_1d, scale_1d).sample_n(1).size(), (1, 1))
        self.assertEqual(Cauchy(0.0, 1.0).sample_n(1).size(), (1,))

        set_rng_seed(1)
        # Gradcheck Cauchy.log_prob itself. These calls previously used
        # Uniform, so Cauchy's log_prob gradients were never checked.
        self._gradcheck_log_prob(Cauchy, (loc, scale))
        self._gradcheck_log_prob(Cauchy, (loc, 1.0))
        self._gradcheck_log_prob(Cauchy, (0.0, scale))

        # Replay the RNG so `eps` equals the noise rsample() draws internally;
        # c = loc + scale * eps gives dc/dloc = 1, dc/dscale = eps.
        state = torch.get_rng_state()
        eps = loc.new(loc.size()).cauchy_()
        torch.set_rng_state(state)
        c = Cauchy(loc, scale).rsample()
        c.backward(torch.ones_like(c))
        self.assertEqual(loc.grad, torch.ones_like(scale))
        self.assertEqual(scale.grad, eps)
        loc.grad.zero_()
        scale.grad.zero_()

    def test_normal(self):
        mean = Variable(torch.randn(5, 5), requires_grad=True)
        std = Variable(torch.randn(5, 5).abs(), requires_grad=True)
        mean_1d = Variable(torch.randn(1), requires_grad=True)
        # .abs() keeps std valid; it draws no extra random numbers.
        std_1d = Variable(torch.randn(1).abs(), requires_grad=True)
        mean_delta = torch.Tensor([1.0, 0.0])
        std_delta = torch.Tensor([1e-5, 1e-5])
        self.assertEqual(Normal(mean, std).sample().size(), (5, 5))
        self.assertEqual(Normal(mean, std).sample_n(7).size(), (7, 5, 5))
        self.assertEqual(Normal(mean_1d, std_1d).sample_n(1).size(), (1, 1))
        self.assertEqual(Normal(mean_1d, std_1d).sample().size(), (1,))
        self.assertEqual(Normal(0.2, .6).sample_n(1).size(), (1,))
        self.assertEqual(Normal(-0.7, 50.0).sample_n(1).size(), (1,))

        # sample check for extreme value of mean, std
        set_rng_seed(1)
        self.assertEqual(Normal(mean_delta, std_delta).sample(sample_shape=(1, 2)),
                         torch.Tensor([[[1.0, 0.0], [1.0, 0.0]]]),
                         prec=1e-4)

        self._gradcheck_log_prob(Normal, (mean, std))
        self._gradcheck_log_prob(Normal, (mean, 1.0))
        self._gradcheck_log_prob(Normal, (0.0, std))

        # Replay the RNG so `eps` equals the standard-normal noise rsample()
        # draws; z = mean + std * eps gives dz/dmean = 1, dz/dstd = eps.
        state = torch.get_rng_state()
        eps = torch.normal(torch.zeros_like(mean), torch.ones_like(std))
        torch.set_rng_state(state)
        z = Normal(mean, std).rsample()
        z.backward(torch.ones_like(z))
        self.assertEqual(mean.grad, torch.ones_like(mean))
        self.assertEqual(std.grad, eps)
        mean.grad.zero_()
        std.grad.zero_()
        self.assertEqual(z.size(), (5, 5))

        def ref_log_prob(idx, x, log_prob):
            m = mean.data.view(-1)[idx]
            s = std.data.view(-1)[idx]
            expected = (math.exp(-(x - m) ** 2 / (2 * s ** 2)) /
                        math.sqrt(2 * math.pi * s ** 2))
            self.assertAlmostEqual(log_prob, math.log(expected), places=3)

        self._check_log_prob(Normal(mean, std), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_normal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for mean, std in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(Normal(mean, std),
                                        scipy.stats.norm(loc=mean, scale=std),
                                        'Normal(mean={}, std={})'.format(mean, std))

    def test_exponential(self):
        rate = Variable(torch.randn(5, 5).abs(), requires_grad=True)
        rate_1d = Variable(torch.randn(1).abs(), requires_grad=True)
        self.assertEqual(Exponential(rate).sample().size(), (5, 5))
        self.assertEqual(Exponential(rate).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Exponential(rate_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Exponential(rate_1d).sample().size(), (1,))
        self.assertEqual(Exponential(0.2).sample((1,)).size(), (1,))
        self.assertEqual(Exponential(50.0).sample((1,)).size(), (1,))

        self._gradcheck_log_prob(Exponential, (rate,))

        # Replay the RNG so `eps` equals the Exp(1) noise rsample() draws;
        # z = eps / rate gives dz/drate = -eps / rate**2.
        state = torch.get_rng_state()
        eps = rate.new(rate.size()).exponential_()
        torch.set_rng_state(state)
        z = Exponential(rate).rsample()
        z.backward(torch.ones_like(z))
        self.assertEqual(rate.grad, -eps / rate**2)
        rate.grad.zero_()
        self.assertEqual(z.size(), (5, 5))

        def ref_log_prob(idx, x, log_prob):
            m = rate.data.view(-1)[idx]
            expected = math.log(m) - m * x
            self.assertAlmostEqual(log_prob, expected, places=3)

        self._check_log_prob(Exponential(rate), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_exponential_sample(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        for rate in [1e-5, 1.0, 10.]:
            self._check_sampler_sampler(Exponential(rate),
                                        scipy.stats.expon(scale=1. / rate),
                                        'Exponential(rate={})'.format(rate))

    def test_laplace(self):
        loc = Variable(torch.randn(5, 5), requires_grad=True)
        scale = Variable(torch.randn(5, 5).abs(), requires_grad=True)
        loc_1d = Variable(torch.randn(1), requires_grad=True)
        # .abs() keeps scale valid; it draws no extra random numbers.
        scale_1d = Variable(torch.randn(1).abs(), requires_grad=True)
        loc_delta = torch.Tensor([1.0, 0.0])
        scale_delta = torch.Tensor([1e-5, 1e-5])
        self.assertEqual(Laplace(loc, scale).sample().size(), (5, 5))
        self.assertEqual(Laplace(loc, scale).sample_n(7).size(), (7, 5, 5))
        self.assertEqual(Laplace(loc_1d, scale_1d).sample_n(1).size(), (1, 1))
        self.assertEqual(Laplace(loc_1d, scale_1d).sample().size(), (1,))
        self.assertEqual(Laplace(0.2, .6).sample_n(1).size(), (1,))
        self.assertEqual(Laplace(-0.7, 50.0).sample_n(1).size(), (1,))

        # sample check for extreme value of loc, scale
        set_rng_seed(0)
        self.assertEqual(Laplace(loc_delta, scale_delta).sample(sample_shape=(1, 2)),
                         torch.Tensor([[[1.0, 0.0], [1.0, 0.0]]]),
                         prec=1e-4)

        self._gradcheck_log_prob(Laplace, (loc, scale))
        self._gradcheck_log_prob(Laplace, (loc, 1.0))
        self._gradcheck_log_prob(Laplace, (0.0, scale))

        # Replay the RNG so `eps` equals the U(-0.5, 0.5) noise that rsample()
        # is expected to draw; the asserted gradients follow from
        # z = loc - scale * sign(eps) * log1p(-2|eps|).
        state = torch.get_rng_state()
        eps = torch.ones_like(loc).uniform_(-.5, .5)
        torch.set_rng_state(state)
        z = Laplace(loc, scale).rsample()
        z.backward(torch.ones_like(z))
        self.assertEqual(loc.grad, torch.ones_like(loc))
        self.assertEqual(scale.grad, -eps.sign() * torch.log1p(-2 * eps.abs()))
        loc.grad.zero_()
        scale.grad.zero_()
        self.assertEqual(z.size(), (5, 5))

        def ref_log_prob(idx, x, log_prob):
            m = loc.data.view(-1)[idx]
            s = scale.data.view(-1)[idx]
            expected = (-math.log(2 * s) - abs(x - m) / s)
            self.assertAlmostEqual(log_prob, expected, places=3)

        self._check_log_prob(Laplace(loc, scale), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_laplace_sample(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        for loc, scale in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(Laplace(loc, scale),
                                        scipy.stats.laplace(loc=loc, scale=scale),
                                        'Laplace(loc={}, scale={})'.format(loc, scale))

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_gamma_shape(self):
        alpha = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
        beta = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
        alpha_1d = Variable(torch.exp(torch.randn(1)), requires_grad=True)
        beta_1d = Variable(torch.exp(torch.randn(1)), requires_grad=True)
        self.assertEqual(Gamma(alpha, beta).sample().size(), (2, 3))
        self.assertEqual(Gamma(alpha, beta).sample_n(5).size(), (5, 2, 3))
        self.assertEqual(Gamma(alpha_1d, beta_1d).sample_n(1).size(), (1, 1))
        self.assertEqual(Gamma(alpha_1d, beta_1d).sample().size(), (1,))
        self.assertEqual(Gamma(0.5, 0.5).sample().size(), (1,))
        self.assertEqual(Gamma(0.5, 0.5).sample_n(1).size(), (1,))

        def ref_log_prob(idx, x, log_prob):
            a = alpha.data.view(-1)[idx]
            b = beta.data.view(-1)[idx]
            # beta is a rate, so scipy's scale parameter is 1 / beta.
            expected = scipy.stats.gamma.logpdf(x, a, scale=1 / b)
            self.assertAlmostEqual(log_prob, expected, places=3)

        self._check_log_prob(Gamma(alpha, beta), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_gamma_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(Gamma(alpha, beta),
                                        scipy.stats.gamma(alpha, scale=1.0 / beta),
                                        'Gamma(alpha={}, beta={})'.format(alpha, beta))

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_gamma_sample_grad(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        num_samples = 100
        for alpha in [1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]:
            alphas = Variable(torch.FloatTensor([alpha] * num_samples), requires_grad=True)
            betas = Variable(torch.ones(num_samples).type_as(alphas))
            x = Gamma(alphas, betas).rsample()
            x.sum().backward()
            x, ind = x.data.sort()
            x = x.numpy()
            actual_grad = alphas.grad.data[ind].numpy()
            # Compare with expected gradient dx/dalpha along constant cdf(x,alpha).
            # Implicit differentiation: dx/dalpha = -(dcdf/dalpha) / (dcdf/dx),
            # with dcdf/dalpha estimated by a central finite difference.
            cdf = scipy.stats.gamma.cdf
            pdf = scipy.stats.gamma.pdf
            eps = 0.01 * alpha / (1.0 + alpha ** 0.5)
            cdf_alpha = (cdf(x, alpha + eps) - cdf(x, alpha - eps)) / (2 * eps)
            cdf_x = pdf(x, alpha)
            expected_grad = -cdf_alpha / cdf_x
            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
            self.assertLess(np.max(rel_error), 0.0005,
                            '\n'.join(['Bad gradients for Gamma({}, 1)'.format(alpha),
                                       'x {}'.format(x),
                                       'expected {}'.format(expected_grad),
                                       'actual {}'.format(actual_grad),
                                       'rel error {}'.format(rel_error),
                                       'max error {}'.format(rel_error.max()),
                                       'at alpha={}, x={}'.format(alpha, x[rel_error.argmax()])]))

    def test_dirichlet_shape(self):
        alpha = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
        alpha_1d = Variable(torch.exp(torch.randn(4)), requires_grad=True)
        self.assertEqual(Dirichlet(alpha).sample().size(), (2, 3))
        self.assertEqual(Dirichlet(alpha).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Dirichlet(alpha_1d).sample().size(), (4,))
        self.assertEqual(Dirichlet(alpha_1d).sample((1,)).size(), (1, 4))

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_dirichlet_log_prob(self):
        num_samples = 10
        alpha = torch.exp(torch.randn(5))
        dist = Dirichlet(alpha)
        x = dist.sample((num_samples,))
        actual_log_prob = dist.log_prob(x)
        for i in range(num_samples):
            expected_log_prob = scipy.stats.dirichlet.logpdf(x[i].numpy(), alpha.numpy())
            self.assertAlmostEqual(actual_log_prob[i], expected_log_prob, places=3)

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_dirichlet_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        alpha = torch.exp(torch.randn(3))
        self._check_sampler_sampler(Dirichlet(alpha),
                                    scipy.stats.dirichlet(alpha.numpy()),
                                    'Dirichlet(alpha={})'.format(list(alpha)),
                                    multivariate=True)

    def test_beta_shape(self):
        alpha = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
        beta = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
        alpha_1d = Variable(torch.exp(torch.randn(4)), requires_grad=True)
        beta_1d = Variable(torch.exp(torch.randn(4)), requires_grad=True)
        self.assertEqual(Beta(alpha, beta).sample().size(), (2, 3))
        self.assertEqual(Beta(alpha, beta).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Beta(alpha_1d, beta_1d).sample().size(), (4,))
        self.assertEqual(Beta(alpha_1d, beta_1d).sample((1,)).size(), (1, 4))
        self.assertEqual(Beta(0.1, 0.3).sample().size(), (1,))
        self.assertEqual(Beta(0.1, 0.3).sample((5,)).size(), (5,))

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_beta_log_prob(self):
        for _ in range(100):
            alpha = np.exp(np.random.normal())
            beta = np.exp(np.random.normal())
            dist = Beta(alpha, beta)
            x = dist.sample()
            actual_log_prob = dist.log_prob(x).sum()
            expected_log_prob = scipy.stats.beta.logpdf(x, alpha, beta)[0]
            self.assertAlmostEqual(actual_log_prob, expected_log_prob, places=3, allow_inf=True)

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_beta_sample(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        for alpha, beta in product([0.1, 1.0, 10.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(Beta(alpha, beta),
                                        scipy.stats.beta(alpha, beta),
                                        'Beta(alpha={}, beta={})'.format(alpha, beta))
        # Check that small alphas do not cause NANs.
        for Tensor in [torch.FloatTensor, torch.DoubleTensor]:
            x = Beta(Tensor([1e-6]), Tensor([1e-6])).sample()[0]
            self.assertTrue(np.isfinite(x) and x > 0, 'Invalid Beta.sample(): {}'.format(x))

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_beta_sample_grad(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        num_samples = 20
        for alpha, beta in product([1e-2, 1e0, 1e2], [1e-2, 1e0, 1e2]):
            alphas = Variable(torch.Tensor([alpha] * num_samples), requires_grad=True)
            betas = Variable(torch.Tensor([beta] * num_samples))
            x = Beta(alphas, betas).rsample()
            x.sum().backward()
            x, ind = x.data.sort()
            x = x.numpy()
            actual_grad = alphas.grad.data[ind].numpy()
            # Compare with expected gradient dx/dalpha along constant cdf(x,alpha,beta).
            cdf = scipy.stats.beta.cdf
            pdf = scipy.stats.beta.pdf
            eps = 0.02 * alpha / (1.0 + np.sqrt(alpha))
            cdf_alpha = (cdf(x, alpha + eps, beta) - cdf(x, alpha - eps, beta)) / (2 * eps)
            cdf_x = pdf(x, alpha, beta)
            expected_grad = -cdf_alpha / cdf_x
            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-100)
            self.assertLess(np.max(rel_error), 0.01,
                            '\n'.join(['Bad gradients for Beta({}, {})'.format(alpha, beta),
                                       'x {}'.format(x),
                                       'expected {}'.format(expected_grad),
                                       'actual {}'.format(actual_grad),
                                       'rel error {}'.format(rel_error),
                                       'max error {}'.format(rel_error.max())]))

    def test_valid_parameter_broadcasting(self):
        # Test correct broadcasting of parameter sizes for distributions that have multiple
        # parameters.
        # example type (distribution instance, expected sample shape)
        valid_examples = [
            (Normal(mean=torch.Tensor([0, 0]), std=1),
             (2,)),
            (Normal(mean=0, std=torch.Tensor([1, 1])),
             (2,)),
            (Normal(mean=torch.Tensor([0, 0]), std=torch.Tensor([1])),
             (2,)),
            (Normal(mean=torch.Tensor([0, 0]), std=torch.Tensor([[1], [1]])),
             (2, 2)),
            (Normal(mean=torch.Tensor([0, 0]), std=torch.Tensor([[1]])),
             (1, 2)),
            (Normal(mean=torch.Tensor([0]), std=torch.Tensor([[1]])),
             (1, 1)),
            (Gamma(alpha=torch.Tensor([1, 1]), beta=1),
             (2,)),
            (Gamma(alpha=1, beta=torch.Tensor([1, 1])),
             (2,)),
            (Gamma(alpha=torch.Tensor([1, 1]), beta=torch.Tensor([[1], [1], [1]])),
             (3, 2)),
            (Gamma(alpha=torch.Tensor([1, 1]), beta=torch.Tensor([[1], [1]])),
             (2, 2)),
            (Gamma(alpha=torch.Tensor([1, 1]), beta=torch.Tensor([[1]])),
             (1, 2)),
            (Gamma(alpha=torch.Tensor([1]), beta=torch.Tensor([[1]])),
             (1, 1)),
            (Laplace(loc=torch.Tensor([0, 0]), scale=1),
             (2,)),
            (Laplace(loc=0, scale=torch.Tensor([1, 1])),
             (2,)),
            (Laplace(loc=torch.Tensor([0, 0]), scale=torch.Tensor([1])),
             (2,)),
            (Laplace(loc=torch.Tensor([0, 0]), scale=torch.Tensor([[1], [1]])),
             (2, 2)),
            (Laplace(loc=torch.Tensor([0, 0]), scale=torch.Tensor([[1]])),
             (1, 2)),
            (Laplace(loc=torch.Tensor([0]), scale=torch.Tensor([[1]])),
             (1, 1)),
        ]

        for dist, expected_size in valid_examples:
            dist_sample_size = dist.sample().size()
            self.assertEqual(dist_sample_size, expected_size,
                             'actual size: {} != expected size: {}'.format(dist_sample_size, expected_size))

    def test_invalid_parameter_broadcasting(self):
        # invalid broadcasting cases; should throw error
        # example type (distribution class, distribution params)
        invalid_examples = [
            (Normal, {
                'mean': torch.Tensor([[0, 0]]),
                'std': torch.Tensor([1, 1, 1, 1])
            }),
            (Normal, {
                'mean': torch.Tensor([[[0, 0, 0], [0, 0, 0]]]),
                'std': torch.Tensor([1, 1])
            }),
            (Gamma, {
                'alpha': torch.Tensor([0, 0]),
                'beta': torch.Tensor([1, 1, 1])
            }),
            (Laplace, {
                'loc': torch.Tensor([0, 0]),
                'scale': torch.Tensor([1, 1, 1])
            })
        ]

        for dist, kwargs in invalid_examples:
            self.assertRaises(RuntimeError, dist, **kwargs)
class TestDistributionShapes(TestCase):
def setUp(self):
super(TestCase, self).setUp()
self.scalar_sample = 1
self.tensor_sample_1 = torch.ones(3, 2)
self.tensor_sample_2 = torch.ones(3, 2, 3)
def test_entropy_shape(self):
for Dist, params in EXAMPLES:
for i, param in enumerate(params):
dist = Dist(**param)
actual_shape = dist.entropy().size()
expected_shape = dist._batch_shape
if not expected_shape:
expected_shape = torch.Size((1,)) # TODO Remove this once scalars are supported.
message = '{} example {}/{}, shape mismatch. expected {}, actual {}'.format(
Dist.__name__, i, len(params), expected_shape, actual_shape)
self.assertEqual(actual_shape, expected_shape, message=message)
def test_bernoulli_shape_scalar_params(self):
bernoulli = Bernoulli(0.3)
self.assertEqual(bernoulli._batch_shape, torch.Size())
self.assertEqual(bernoulli._event_shape, torch.Size())
self.assertEqual(bernoulli.sample().size(), torch.Size((1,)))
self.assertEqual(bernoulli.sample((3, 2)).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, bernoulli.log_prob, self.scalar_sample)
self.assertEqual(bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertEqual(bernoulli.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
def test_bernoulli_shape_tensor_params(self):
bernoulli = Bernoulli(torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
self.assertEqual(bernoulli._batch_shape, torch.Size((3, 2)))
self.assertEqual(bernoulli._event_shape, torch.Size(()))
self.assertEqual(bernoulli.sample().size(), torch.Size((3, 2)))
self.assertEqual(bernoulli.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
self.assertEqual(bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, bernoulli.log_prob, self.tensor_sample_2)
def test_beta_shape_scalar_params(self):
dist = Beta(0.1, 0.1)
self.assertEqual(dist._batch_shape, torch.Size())
self.assertEqual(dist._event_shape, torch.Size())
self.assertEqual(dist.sample().size(), torch.Size((1,)))
self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
def test_beta_shape_tensor_params(self):
dist = Beta(torch.Tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
torch.Tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]))
self.assertEqual(dist._batch_shape, torch.Size((3, 2)))
self.assertEqual(dist._event_shape, torch.Size(()))
self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
def test_categorical_shape(self):
dist = Categorical(torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
self.assertEqual(dist._batch_shape, torch.Size((3,)))
self.assertEqual(dist._event_shape, torch.Size(()))
self.assertEqual(dist.sample().size(), torch.Size((3,)))
self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3,)))
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
def test_one_hot_categorical_shape(self):
dist = OneHotCategorical(torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
self.assertEqual(dist._batch_shape, torch.Size((3,)))
self.assertEqual(dist._event_shape, torch.Size((2,)))
self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
self.assertEqual(dist.log_prob(dist.enumerate_support()).size(), torch.Size((2, 3)))
def test_cauchy_shape_scalar_params(self):
cauchy = Cauchy(0, 1)
self.assertEqual(cauchy._batch_shape, torch.Size())
self.assertEqual(cauchy._event_shape, torch.Size())
self.assertEqual(cauchy.sample().size(), torch.Size((1,)))
self.assertEqual(cauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, cauchy.log_prob, self.scalar_sample)
self.assertEqual(cauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertEqual(cauchy.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
def test_cauchy_shape_tensor_params(self):
cauchy = Cauchy(torch.Tensor([0, 0]), torch.Tensor([1, 1]))
self.assertEqual(cauchy._batch_shape, torch.Size((2,)))
self.assertEqual(cauchy._event_shape, torch.Size(()))
self.assertEqual(cauchy.sample().size(), torch.Size((2,)))
self.assertEqual(cauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2)))
self.assertEqual(cauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, cauchy.log_prob, self.tensor_sample_2)
def test_dirichlet_shape(self):
    """Dirichlet over a (3, 2) parameter tensor: batch (3,), event (2,)."""
    concentration = torch.Tensor([[0.6, 0.3], [1.6, 1.3], [2.6, 2.3]])
    dirichlet = Dirichlet(concentration)
    self.assertEqual(dirichlet._batch_shape, torch.Size((3,)))
    self.assertEqual(dirichlet._event_shape, torch.Size((2,)))
    self.assertEqual(dirichlet.sample().size(), torch.Size((3, 2)))
    draws = dirichlet.sample((5, 4))
    self.assertEqual(draws.size(), torch.Size((5, 4, 3, 2)))
    # log_prob is expected to reduce over the event dim of a matching value.
    self.assertEqual(dirichlet.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
    self.assertRaises(ValueError, dirichlet.log_prob, self.tensor_sample_2)
def test_gamma_shape_scalar_params(self):
    """Gamma with Python-scalar parameters should have empty batch/event shapes."""
    gamma_dist = Gamma(1, 1)
    for declared in (gamma_dist._batch_shape, gamma_dist._event_shape):
        self.assertEqual(declared, torch.Size())
    # A default draw from scalar parameters is expected to be shape (1,).
    single = gamma_dist.sample()
    self.assertEqual(single.size(), torch.Size((1,)))
    self.assertEqual(gamma_dist.sample((3, 2)).size(), torch.Size((3, 2)))
    self.assertRaises(ValueError, gamma_dist.log_prob, self.scalar_sample)
    for value, dims in ((self.tensor_sample_1, (3, 2)),
                        (self.tensor_sample_2, (3, 2, 3))):
        self.assertEqual(gamma_dist.log_prob(value).size(), torch.Size(dims))
def test_gamma_shape_tensor_params(self):
    """Gamma with length-2 tensor parameters should have batch shape (2,)."""
    first, second = torch.Tensor([1, 1]), torch.Tensor([1, 1])
    gamma_dist = Gamma(first, second)
    self.assertEqual(gamma_dist._batch_shape, torch.Size((2,)))
    self.assertEqual(gamma_dist._event_shape, torch.Size())
    self.assertEqual(gamma_dist.sample().size(), torch.Size((2,)))
    draws = gamma_dist.sample((3, 2))
    self.assertEqual(draws.size(), torch.Size((3, 2, 2)))
    self.assertEqual(gamma_dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
    # tensor_sample_2 is expected not to broadcast against batch shape (2,).
    self.assertRaises(ValueError, gamma_dist.log_prob, self.tensor_sample_2)
def test_normal_shape_scalar_params(self):
    """Normal with Python-scalar loc/scale should have empty batch/event shapes."""
    dist = Normal(0, 1)
    scalar_shape = torch.Size()
    self.assertEqual(dist._batch_shape, scalar_shape)
    self.assertEqual(dist._event_shape, scalar_shape)
    # A default draw from scalar parameters is expected to be shape (1,).
    self.assertEqual(dist.sample().size(), torch.Size((1,)))
    draws = dist.sample((3, 2))
    self.assertEqual(draws.size(), torch.Size((3, 2)))
    self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
    expectations = [(self.tensor_sample_1, (3, 2)), (self.tensor_sample_2, (3, 2, 3))]
    for value, dims in expectations:
        self.assertEqual(dist.log_prob(value).size(), torch.Size(dims))
def test_normal_shape_tensor_params(self):
    """Normal with length-2 tensor loc/scale should have batch shape (2,)."""
    loc, scale = torch.Tensor([0, 0]), torch.Tensor([1, 1])
    dist = Normal(loc, scale)
    self.assertEqual(dist._batch_shape, torch.Size((2,)))
    self.assertEqual(dist._event_shape, torch.Size())
    self.assertEqual(dist.sample().size(), torch.Size((2,)))
    draws = dist.sample((3, 2))
    self.assertEqual(draws.size(), torch.Size((3, 2, 2)))
    self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
    # tensor_sample_2 is expected not to broadcast against batch shape (2,).
    self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
def test_uniform_shape_scalar_params(self):
    """Uniform with Python-scalar bounds should have empty batch/event shapes."""
    dist = Uniform(0, 1)
    for declared in (dist._batch_shape, dist._event_shape):
        self.assertEqual(declared, torch.Size())
    # A default draw from scalar bounds is expected to be shape (1,).
    self.assertEqual(dist.sample().size(), torch.Size((1,)))
    draws = dist.sample(torch.Size((3, 2)))
    self.assertEqual(draws.size(), torch.Size((3, 2)))
    self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
    for value, dims in ((self.tensor_sample_1, (3, 2)),
                        (self.tensor_sample_2, (3, 2, 3))):
        self.assertEqual(dist.log_prob(value).size(), torch.Size(dims))
def test_uniform_shape_tensor_params(self):
    """Uniform with length-2 tensor bounds should have batch shape (2,)."""
    low, high = torch.Tensor([0, 0]), torch.Tensor([1, 1])
    dist = Uniform(low, high)
    self.assertEqual(dist._batch_shape, torch.Size((2,)))
    self.assertEqual(dist._event_shape, torch.Size())
    self.assertEqual(dist.sample().size(), torch.Size((2,)))
    draws = dist.sample(torch.Size((3, 2)))
    self.assertEqual(draws.size(), torch.Size((3, 2, 2)))
    self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
    # tensor_sample_2 is expected not to broadcast against batch shape (2,).
    self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
def test_exponential_shape_scalar_param(self):
    """Exponential with a Python-float rate should have empty batch/event shapes."""
    dist = Exponential(1.)
    scalar_shape = torch.Size()
    self.assertEqual(dist._batch_shape, scalar_shape)
    self.assertEqual(dist._event_shape, scalar_shape)
    # A default draw from a scalar parameter is expected to be shape (1,).
    self.assertEqual(dist.sample().size(), torch.Size((1,)))
    draws = dist.sample((3, 2))
    self.assertEqual(draws.size(), torch.Size((3, 2)))
    self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
    expectations = [(self.tensor_sample_1, (3, 2)), (self.tensor_sample_2, (3, 2, 3))]
    for value, dims in expectations:
        self.assertEqual(dist.log_prob(value).size(), torch.Size(dims))
def test_exponential_shape_tensor_param(self):
    """Exponential with a length-2 tensor rate should have batch shape (2,)."""
    rate = torch.Tensor([1, 1])
    dist = Exponential(rate)
    self.assertEqual(dist._batch_shape, torch.Size((2,)))
    self.assertEqual(dist._event_shape, torch.Size())
    self.assertEqual(dist.sample().size(), torch.Size((2,)))
    draws = dist.sample((3, 2))
    self.assertEqual(draws.size(), torch.Size((3, 2, 2)))
    self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
    # tensor_sample_2 is expected not to broadcast against batch shape (2,).
    self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
def test_laplace_shape_scalar_params(self):
    """Laplace with Python-scalar loc/scale should have empty batch/event shapes."""
    dist = Laplace(0, 1)
    for declared in (dist._batch_shape, dist._event_shape):
        self.assertEqual(declared, torch.Size())
    # A default draw from scalar parameters is expected to be shape (1,).
    self.assertEqual(dist.sample().size(), torch.Size((1,)))
    draws = dist.sample((3, 2))
    self.assertEqual(draws.size(), torch.Size((3, 2)))
    self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
    for value, dims in ((self.tensor_sample_1, (3, 2)),
                        (self.tensor_sample_2, (3, 2, 3))):
        self.assertEqual(dist.log_prob(value).size(), torch.Size(dims))
def test_laplace_shape_tensor_params(self):
    """Laplace with length-2 tensor loc/scale should have batch shape (2,)."""
    loc, scale = torch.Tensor([0, 0]), torch.Tensor([1, 1])
    dist = Laplace(loc, scale)
    self.assertEqual(dist._batch_shape, torch.Size((2,)))
    self.assertEqual(dist._event_shape, torch.Size())
    self.assertEqual(dist.sample().size(), torch.Size((2,)))
    draws = dist.sample((3, 2))
    self.assertEqual(draws.size(), torch.Size((3, 2, 2)))
    self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
    # tensor_sample_2 is expected not to broadcast against batch shape (2,).
    self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
# Script entry point: run this file's test cases via the shared
# `run_tests` helper imported from the local `common` module.
if __name__ == '__main__':
    run_tests()