# Owner(s): ["module: distributions"]
import io
from numbers import Number
import pytest
import torch
from torch.autograd import grad
from torch.autograd.functional import jacobian
from torch.distributions import Dirichlet, Independent, Normal, TransformedDistribution, constraints
from torch.distributions.transforms import (AbsTransform, AffineTransform, ComposeTransform,
CorrCholeskyTransform, CumulativeDistributionTransform,
ExpTransform, IndependentTransform,
LowerCholeskyTransform, PowerTransform,
ReshapeTransform, SigmoidTransform, TanhTransform,
SoftmaxTransform, SoftplusTransform, StickBreakingTransform,
identity_transform, Transform, _InverseTransform,
PositiveDefiniteTransform)
from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix
from torch.testing._internal.common_utils import run_tests
def get_transforms(cache_size):
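    """Return a representative list of transforms, each constructed with the given
    ``cache_size``, followed by the inverse of every transform in the list."""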
transforms = [
AbsTransform(cache_size=cache_size),
ExpTransform(cache_size=cache_size),
PowerTransform(exponent=2,
cache_size=cache_size),
PowerTransform(exponent=-2,
cache_size=cache_size),
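        # Two PowerTransforms with independently drawn random exponents: same type
        # but different parameters, so they compare unequal in test_equality.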
PowerTransform(exponent=torch.tensor(5.).normal_(),
cache_size=cache_size),
PowerTransform(exponent=torch.tensor(5.).normal_(),
cache_size=cache_size),
SigmoidTransform(cache_size=cache_size),
TanhTransform(cache_size=cache_size),
AffineTransform(0, 1, cache_size=cache_size),
AffineTransform(1, -2, cache_size=cache_size),
AffineTransform(torch.randn(5),
torch.randn(5),
cache_size=cache_size),
AffineTransform(torch.randn(4, 5),
torch.randn(4, 5),
cache_size=cache_size),
SoftmaxTransform(cache_size=cache_size),
SoftplusTransform(cache_size=cache_size),
StickBreakingTransform(cache_size=cache_size),
LowerCholeskyTransform(cache_size=cache_size),
CorrCholeskyTransform(cache_size=cache_size),
PositiveDefiniteTransform(cache_size=cache_size),
ComposeTransform([
AffineTransform(torch.randn(4, 5),
torch.randn(4, 5),
cache_size=cache_size),
]),
ComposeTransform([
AffineTransform(torch.randn(4, 5),
torch.randn(4, 5),
cache_size=cache_size),
ExpTransform(cache_size=cache_size),
]),
ComposeTransform([
AffineTransform(0, 1, cache_size=cache_size),
AffineTransform(torch.randn(4, 5),
torch.randn(4, 5),
cache_size=cache_size),
AffineTransform(1, -2, cache_size=cache_size),
AffineTransform(torch.randn(4, 5),
torch.randn(4, 5),
cache_size=cache_size),
]),
ReshapeTransform((4, 5), (2, 5, 2)),
IndependentTransform(
AffineTransform(torch.randn(5),
torch.randn(5),
cache_size=cache_size),
1),
CumulativeDistributionTransform(Normal(0, 1)),
]
transforms += [t.inv for t in transforms]
return transforms
def reshape_transform(transform, shape):
    """Reshape the parameters of an AffineTransform (or of the AffineTransforms
    inside a ComposeTransform) so that they broadcast with ``shape``.

    Needed to squash batch dims into a single batch dim when testing jacobians;
    other transforms are returned unchanged.
    """
if isinstance(transform, AffineTransform):
if isinstance(transform.loc, Number):
return transform
try:
            return AffineTransform(transform.loc.expand(shape), transform.scale.expand(shape),
                                   cache_size=transform._cache_size)
        except RuntimeError:
            return AffineTransform(transform.loc.reshape(shape), transform.scale.reshape(shape),
                                   cache_size=transform._cache_size)
if isinstance(transform, ComposeTransform):
reshaped_parts = []
for p in transform.parts:
reshaped_parts.append(reshape_transform(p, shape))
return ComposeTransform(reshaped_parts, cache_size=transform._cache_size)
if isinstance(transform.inv, AffineTransform):
return reshape_transform(transform.inv, shape).inv
if isinstance(transform.inv, ComposeTransform):
return reshape_transform(transform.inv, shape).inv
return transform
# Generate pytest ids
def transform_id(x):
assert isinstance(x, Transform)
name = f'Inv({type(x._inv).__name__})' if isinstance(x, _InverseTransform) else f'{type(x).__name__}'
return f'{name}(cache_size={x._cache_size})'
def generate_data(transform):
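    """Draw a reproducible sample lying in ``transform.domain`` (the seed is fixed,
    so repeated calls return the same data)."""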
torch.manual_seed(1)
while isinstance(transform, IndependentTransform):
transform = transform.base_transform
if isinstance(transform, ReshapeTransform):
return torch.randn(transform.in_shape)
if isinstance(transform.inv, ReshapeTransform):
return torch.randn(transform.inv.out_shape)
domain = transform.domain
while (isinstance(domain, constraints.independent) and
domain is not constraints.real_vector):
domain = domain.base_constraint
codomain = transform.codomain
x = torch.empty(4, 5)
positive_definite_constraints = [constraints.lower_cholesky, constraints.positive_definite]
if domain in positive_definite_constraints:
x = torch.randn(6, 6)
x = x.tril(-1) + x.diag().exp().diag_embed()
if domain is constraints.positive_definite:
return x @ x.T
return x
elif codomain in positive_definite_constraints:
return torch.randn(6, 6)
elif domain is constraints.real:
return x.normal_()
elif domain is constraints.real_vector:
        # For corr_cholesky the last dim of the vector
        # must be of size dim * (dim - 1) // 2 (here 6, i.e. dim = 4).
x = torch.empty(3, 6)
x = x.normal_()
return x
elif domain is constraints.positive:
return x.normal_().exp()
elif domain is constraints.unit_interval:
return x.uniform_()
elif isinstance(domain, constraints.interval):
x = x.uniform_()
x = x.mul_(domain.upper_bound - domain.lower_bound).add_(domain.lower_bound)
return x
elif domain is constraints.simplex:
x = x.normal_().exp()
x /= x.sum(-1, True)
return x
elif domain is constraints.corr_cholesky:
x = torch.empty(4, 5, 5)
x = x.normal_().tril()
x /= x.norm(dim=-1, keepdim=True)
        x.diagonal(dim1=-2, dim2=-1).copy_(x.diagonal(dim1=-2, dim2=-1).abs())
return x
raise ValueError(f'Unsupported domain: {domain}')
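# Module-level fixtures shared by the parametrized tests below: every transform with
# caching enabled, every transform without, plus the identity transform.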
TRANSFORMS_CACHE_ACTIVE = get_transforms(cache_size=1)
TRANSFORMS_CACHE_INACTIVE = get_transforms(cache_size=0)
ALL_TRANSFORMS = TRANSFORMS_CACHE_ACTIVE + TRANSFORMS_CACHE_INACTIVE + [identity_transform]
@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
def test_inv_inv(transform):
assert transform.inv.inv is transform
@pytest.mark.parametrize('x', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
@pytest.mark.parametrize('y', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_equality(x, y):
if x is y:
assert x == y
else:
assert x != y
assert identity_transform == identity_transform.inv
@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
def test_with_cache(transform):
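    """With a cache of size 1, calling the transform twice on the same input should
    return the identical (cached) output tensor."""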
if transform._cache_size == 0:
transform = transform.with_cache(1)
assert transform._cache_size == 1
x = generate_data(transform).requires_grad_()
try:
y = transform(x)
except NotImplementedError:
pytest.skip('Not implemented.')
y2 = transform(x)
assert y2 is y
@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
@pytest.mark.parametrize('test_cached', [True, False])
def test_forward_inverse(transform, test_cached):
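    """Check the round trip x -> y = t(x) -> x2 = t.inv(y): x2 should match x for
    bijective transforms; otherwise t(x2) should match y (a weaker pseudo-inverse
    property), in both cases up to a small tolerance."""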
x = generate_data(transform).requires_grad_()
assert transform.domain.check(x).all() # verify that the input data are valid
try:
y = transform(x)
except NotImplementedError:
pytest.skip('Not implemented.')
assert y.shape == transform.forward_shape(x.shape)
if test_cached:
x2 = transform.inv(y) # should be implemented at least by caching
else:
try:
x2 = transform.inv(y.clone()) # bypass cache
except NotImplementedError:
pytest.skip('Not implemented.')
assert x2.shape == transform.inverse_shape(y.shape)
y2 = transform(x2)
if transform.bijective:
# verify function inverse
assert torch.allclose(x2, x, atol=1e-4, equal_nan=True), '\n'.join([
f'{transform} t.inv(t(-)) error',
f'x = {x}',
f'y = t(x) = {y}',
f'x2 = t.inv(y) = {x2}',
])
else:
# verify weaker function pseudo-inverse
assert torch.allclose(y2, y, atol=1e-4, equal_nan=True), '\n'.join([
f'{transform} t(t.inv(t(-))) error',
f'x = {x}',
f'y = t(x) = {y}',
f'x2 = t.inv(y) = {x2}',
f'y2 = t(x2) = {y2}',
])
def test_compose_transform_shapes():
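    """The event_dim of a ComposeTransform should be the maximum event_dim of its parts."""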
transform0 = ExpTransform()
transform1 = SoftmaxTransform()
transform2 = LowerCholeskyTransform()
assert transform0.event_dim == 0
assert transform1.event_dim == 1
assert transform2.event_dim == 2
assert ComposeTransform([transform0, transform1]).event_dim == 1
assert ComposeTransform([transform0, transform2]).event_dim == 2
assert ComposeTransform([transform1, transform2]).event_dim == 2
transform0 = ExpTransform()
transform1 = SoftmaxTransform()
transform2 = LowerCholeskyTransform()
base_dist0 = Normal(torch.zeros(4, 4), torch.ones(4, 4))
base_dist1 = Dirichlet(torch.ones(4, 4))
base_dist2 = Normal(torch.zeros(3, 4, 4), torch.ones(3, 4, 4))
@pytest.mark.parametrize(('batch_shape', 'event_shape', 'dist'), [
((4, 4), (), base_dist0),
((4,), (4,), base_dist1),
((4, 4), (), TransformedDistribution(base_dist0, [transform0])),
((4,), (4,), TransformedDistribution(base_dist0, [transform1])),
((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])),
((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])),
((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])),
((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])),
((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform0])),
((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])),
((4,), (4,), TransformedDistribution(base_dist1, [transform0])),
((4,), (4,), TransformedDistribution(base_dist1, [transform1])),
((), (4, 4), TransformedDistribution(base_dist1, [transform2])),
((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])),
((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])),
((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])),
((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])),
((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])),
((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])),
((3, 4, 4), (), base_dist2),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])),
])
def test_transformed_distribution_shapes(batch_shape, event_shape, dist):
assert dist.batch_shape == batch_shape
assert dist.event_shape == event_shape
x = dist.rsample()
try:
dist.log_prob(x) # this should not crash
except NotImplementedError:
pytest.skip('Not implemented.')
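# The next three tests check that torch.jit.trace agrees with eager execution for the
# forward transform, the inverse transform, and log_abs_det_jacobian on fresh inputs.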
@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_fwd(transform):
x = generate_data(transform).requires_grad_()
def f(x):
return transform(x)
try:
traced_f = torch.jit.trace(f, (x,))
except NotImplementedError:
pytest.skip('Not implemented.')
# check on different inputs
x = generate_data(transform).requires_grad_()
assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True)
@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_inv(transform):
y = generate_data(transform.inv).requires_grad_()
def f(y):
return transform.inv(y)
try:
traced_f = torch.jit.trace(f, (y,))
except NotImplementedError:
pytest.skip('Not implemented.')
# check on different inputs
y = generate_data(transform.inv).requires_grad_()
assert torch.allclose(f(y), traced_f(y), atol=1e-5, equal_nan=True)
@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_jacobian(transform):
x = generate_data(transform).requires_grad_()
def f(x):
y = transform(x)
return transform.log_abs_det_jacobian(x, y)
try:
traced_f = torch.jit.trace(f, (x,))
except NotImplementedError:
pytest.skip('Not implemented.')
# check on different inputs
x = generate_data(transform).requires_grad_()
assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True)
@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
def test_jacobian(transform):
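    """Compare transform.log_abs_det_jacobian against a log-determinant computed from
    the full autograd jacobian (torch.autograd.functional.jacobian)."""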
x = generate_data(transform)
try:
y = transform(x)
actual = transform.log_abs_det_jacobian(x, y)
except NotImplementedError:
pytest.skip('Not implemented.')
# Test shape
target_shape = x.shape[:x.dim() - transform.domain.event_dim]
assert actual.shape == target_shape
# Expand if required
transform = reshape_transform(transform, x.shape)
ndims = len(x.shape)
    batch_ndims = ndims - transform.domain.event_dim  # number of leading batch dims
    x_ = x.view((-1,) + x.shape[batch_ndims:])
n = x_.shape[0]
# Reshape to squash batch dims to a single batch dim
transform = reshape_transform(transform, x_.shape)
# 1. Transforms with unit jacobian
if isinstance(transform, ReshapeTransform) or isinstance(transform.inv, ReshapeTransform):
        expected = x.new_zeros(x.shape[:x.dim() - transform.domain.event_dim])
# 2. Transforms with 0 off-diagonal elements
elif transform.domain.event_dim == 0:
jac = jacobian(transform, x_)
# assert off-diagonal elements are zero
assert torch.allclose(jac, jac.diagonal().diag_embed())
expected = jac.diagonal().abs().log().reshape(x.shape)
# 3. Transforms with non-0 off-diagonal elements
else:
if isinstance(transform, CorrCholeskyTransform):
jac = jacobian(lambda x: tril_matrix_to_vec(transform(x), diag=-1), x_)
elif isinstance(transform.inv, CorrCholeskyTransform):
jac = jacobian(lambda x: transform(vec_to_tril_matrix(x, diag=-1)),
tril_matrix_to_vec(x_, diag=-1))
elif isinstance(transform, StickBreakingTransform):
jac = jacobian(lambda x: transform(x)[..., :-1], x_)
else:
jac = jacobian(transform, x_)
        # Note that jacobian has shape (batch_dims, y_event_dims, batch_dims, x_event_dims).
        # Batches are independent, so gathering the matching batch index below reduces
        # this to (batch_dims, y_event_dims, x_event_dims); after trimming the extra
        # zero-valued columns it is a batched square matrix whose log-determinant can
        # be computed.
gather_idx_shape = list(jac.shape)
gather_idx_shape[-2] = 1
gather_idxs = torch.arange(n).reshape((n,) + (1,) * (len(jac.shape) - 1)).expand(gather_idx_shape)
jac = jac.gather(-2, gather_idxs).squeeze(-2)
out_ndims = jac.shape[-2]
jac = jac[..., :out_ndims] # Remove extra zero-valued dims (for inverse stick-breaking).
expected = torch.slogdet(jac).logabsdet
assert torch.allclose(actual, expected, atol=1e-5)
@pytest.mark.parametrize("event_dims",
[(0,), (1,), (2, 3), (0, 1, 2), (1, 2, 0), (2, 0, 1)],
ids=str)
def test_compose_affine(event_dims):
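    """Composing AffineTransforms with different event_dims should yield a domain and
    codomain event_dim equal to the maximum, and TransformedDistribution should report
    a support with a matching event_dim."""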
transforms = [AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims]
transform = ComposeTransform(transforms)
assert transform.codomain.event_dim == max(event_dims)
assert transform.domain.event_dim == max(event_dims)
base_dist = Normal(0, 1)
if transform.domain.event_dim:
base_dist = base_dist.expand((1,) * transform.domain.event_dim)
dist = TransformedDistribution(base_dist, transform.parts)
assert dist.support.event_dim == max(event_dims)
base_dist = Dirichlet(torch.ones(5))
if transform.domain.event_dim > 1:
base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1))
dist = TransformedDistribution(base_dist, transforms)
assert dist.support.event_dim == max(1, *event_dims)
@pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)], ids=str)
def test_compose_reshape(batch_shape):
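    """A chain of ReshapeTransforms should compose shapes correctly and yield the
    expected event_dim, sample shape, and support."""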
transforms = [ReshapeTransform((), ()),
ReshapeTransform((2,), (1, 2)),
ReshapeTransform((3, 1, 2), (6,)),
ReshapeTransform((6,), (2, 3))]
transform = ComposeTransform(transforms)
assert transform.codomain.event_dim == 2
assert transform.domain.event_dim == 2
data = torch.randn(batch_shape + (3, 2))
assert transform(data).shape == batch_shape + (2, 3)
dist = TransformedDistribution(Normal(data, 1), transforms)
assert dist.batch_shape == batch_shape
assert dist.event_shape == (2, 3)
assert dist.support.event_dim == 2
@pytest.mark.parametrize("sample_shape", [(), (7,)], ids=str)
@pytest.mark.parametrize("transform_dim", [0, 1, 2])
@pytest.mark.parametrize("base_batch_dim", [0, 1, 2])
@pytest.mark.parametrize("base_event_dim", [0, 1, 2])
@pytest.mark.parametrize("num_transforms", [0, 1, 2, 3])
def test_transformed_distribution(base_batch_dim, base_event_dim, transform_dim,
num_transforms, sample_shape):
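    """Exercise TransformedDistribution shape handling across combinations of base
    batch/event dims, transform event dims, and number of transforms, including the
    ValueError raised when the base distribution has too few dims for the transform."""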
shape = torch.Size([2, 3, 4, 5])
base_dist = Normal(0, 1)
base_dist = base_dist.expand(shape[4 - base_batch_dim - base_event_dim:])
if base_event_dim:
base_dist = Independent(base_dist, base_event_dim)
transforms = [AffineTransform(torch.zeros(shape[4 - transform_dim:]), 1),
ReshapeTransform((4, 5), (20,)),
ReshapeTransform((3, 20), (6, 10))]
transforms = transforms[:num_transforms]
transform = ComposeTransform(transforms)
# Check validation in .__init__().
if base_batch_dim + base_event_dim < transform.domain.event_dim:
with pytest.raises(ValueError):
TransformedDistribution(base_dist, transforms)
return
d = TransformedDistribution(base_dist, transforms)
# Check sampling is sufficiently expanded.
x = d.sample(sample_shape)
assert x.shape == sample_shape + d.batch_shape + d.event_shape
num_unique = len(set(x.reshape(-1).tolist()))
assert num_unique >= 0.9 * x.numel()
# Check log_prob shape on full samples.
log_prob = d.log_prob(x)
assert log_prob.shape == sample_shape + d.batch_shape
# Check log_prob shape on partial samples.
y = x
while y.dim() > len(d.event_shape):
y = y[0]
log_prob = d.log_prob(y)
assert log_prob.shape == d.batch_shape
def test_save_load_transform():
    # Evaluating `log_prob` creates a weakref `_inv` which cannot be pickled. Check that
    # `__getstate__` handles the weakref correctly and that the density can still be
    # evaluated after the distribution is saved and re-loaded.
dist = TransformedDistribution(Normal(0, 1), [AffineTransform(2, 3)])
x = torch.linspace(0, 1, 10)
log_prob = dist.log_prob(x)
stream = io.BytesIO()
torch.save(dist, stream)
stream.seek(0)
other = torch.load(stream)
assert torch.allclose(log_prob, other.log_prob(x))
@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
def test_transform_sign(transform: Transform):
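    """For transforms that define a ``sign``, the gradient of sum(t(x)) with respect
    to x should have that sign everywhere."""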
try:
sign = transform.sign
except NotImplementedError:
pytest.skip('Not implemented.')
x = generate_data(transform).requires_grad_()
y = transform(x).sum()
derivatives, = grad(y, [x])
    assert (derivatives * sign > 0).all()
if __name__ == "__main__":
run_tests()