# Owner(s): ["module: distributions"]
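"""Tests for torch.distributions.transforms.

Covers round-trip inversion and caching, log_abs_det_jacobian values checked
against autograd jacobians, shape propagation through ComposeTransform and
TransformedDistribution, TorchScript tracing, sign, and (de)serialization.
"""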
import io
from numbers import Number
import pytest
import torch
from torch.autograd import grad
from torch.autograd.functional import jacobian
from torch.distributions import (
constraints,
Dirichlet,
Independent,
Normal,
TransformedDistribution,
)
from torch.distributions.transforms import (
_InverseTransform,
AbsTransform,
AffineTransform,
ComposeTransform,
CorrCholeskyTransform,
CumulativeDistributionTransform,
ExpTransform,
identity_transform,
IndependentTransform,
LowerCholeskyTransform,
PositiveDefiniteTransform,
PowerTransform,
ReshapeTransform,
SigmoidTransform,
SoftmaxTransform,
SoftplusTransform,
StickBreakingTransform,
TanhTransform,
Transform,
)
from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix
from torch.testing._internal.common_utils import run_tests
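
# Build one instance of each transform under test at the given cache size.
# The inverse of every transform is appended as well, so both directions are
# exercised by the parametrized tests below.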
def get_transforms(cache_size):
transforms = [
AbsTransform(cache_size=cache_size),
ExpTransform(cache_size=cache_size),
PowerTransform(exponent=2, cache_size=cache_size),
PowerTransform(exponent=-2, cache_size=cache_size),
PowerTransform(exponent=torch.tensor(5.0).normal_(), cache_size=cache_size),
PowerTransform(exponent=torch.tensor(5.0).normal_(), cache_size=cache_size),
SigmoidTransform(cache_size=cache_size),
TanhTransform(cache_size=cache_size),
AffineTransform(0, 1, cache_size=cache_size),
AffineTransform(1, -2, cache_size=cache_size),
AffineTransform(torch.randn(5), torch.randn(5), cache_size=cache_size),
AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
SoftmaxTransform(cache_size=cache_size),
SoftplusTransform(cache_size=cache_size),
StickBreakingTransform(cache_size=cache_size),
LowerCholeskyTransform(cache_size=cache_size),
CorrCholeskyTransform(cache_size=cache_size),
PositiveDefiniteTransform(cache_size=cache_size),
ComposeTransform(
[
AffineTransform(
torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
),
]
),
ComposeTransform(
[
AffineTransform(
torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
),
ExpTransform(cache_size=cache_size),
]
),
ComposeTransform(
[
AffineTransform(0, 1, cache_size=cache_size),
AffineTransform(
torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
),
AffineTransform(1, -2, cache_size=cache_size),
AffineTransform(
torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
),
]
),
ReshapeTransform((4, 5), (2, 5, 2)),
IndependentTransform(
AffineTransform(torch.randn(5), torch.randn(5), cache_size=cache_size), 1
),
CumulativeDistributionTransform(Normal(0, 1)),
]
transforms += [t.inv for t in transforms]
return transforms
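
# Rebuild a transform so its parameters match `shape`: AffineTransform
# parameters are expanded (or reshaped as a fallback), ComposeTransform parts
# and inverse transforms are handled recursively, and everything else is
# returned unchanged.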
def reshape_transform(transform, shape):
# Needed to squash batch dims for testing jacobian
if isinstance(transform, AffineTransform):
if isinstance(transform.loc, Number):
return transform
try:
return AffineTransform(
transform.loc.expand(shape),
transform.scale.expand(shape),
cache_size=transform._cache_size,
)
except RuntimeError:
return AffineTransform(
transform.loc.reshape(shape),
transform.scale.reshape(shape),
cache_size=transform._cache_size,
)
if isinstance(transform, ComposeTransform):
reshaped_parts = []
for p in transform.parts:
reshaped_parts.append(reshape_transform(p, shape))
return ComposeTransform(reshaped_parts, cache_size=transform._cache_size)
if isinstance(transform.inv, AffineTransform):
return reshape_transform(transform.inv, shape).inv
if isinstance(transform.inv, ComposeTransform):
return reshape_transform(transform.inv, shape).inv
return transform
# Generate pytest ids
def transform_id(x):
assert isinstance(x, Transform)
name = (
f"Inv({type(x._inv).__name__})"
if isinstance(x, _InverseTransform)
else f"{type(x).__name__}"
)
return f"{name}(cache_size={x._cache_size})"
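
# Draw a sample lying in `transform`'s domain: shape (4, 5) for elementwise
# domains, with special cases for matrix- and vector-valued domains. The seed
# is fixed, so repeated calls return identical data.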
def generate_data(transform):
torch.manual_seed(1)
while isinstance(transform, IndependentTransform):
transform = transform.base_transform
if isinstance(transform, ReshapeTransform):
return torch.randn(transform.in_shape)
if isinstance(transform.inv, ReshapeTransform):
return torch.randn(transform.inv.out_shape)
domain = transform.domain
while (
isinstance(domain, constraints.independent)
and domain is not constraints.real_vector
):
domain = domain.base_constraint
codomain = transform.codomain
x = torch.empty(4, 5)
positive_definite_constraints = [
constraints.lower_cholesky,
constraints.positive_definite,
]
if domain in positive_definite_constraints:
x = torch.randn(6, 6)
x = x.tril(-1) + x.diag().exp().diag_embed()
if domain is constraints.positive_definite:
return x @ x.T
return x
elif codomain in positive_definite_constraints:
return torch.randn(6, 6)
elif domain is constraints.real:
return x.normal_()
elif domain is constraints.real_vector:
        # For corr_cholesky the last dim of the vector
        # must be of size dim * (dim - 1) // 2 (here 6, for a 4 x 4 factor)
x = torch.empty(3, 6)
x = x.normal_()
return x
elif domain is constraints.positive:
return x.normal_().exp()
elif domain is constraints.unit_interval:
return x.uniform_()
elif isinstance(domain, constraints.interval):
x = x.uniform_()
x = x.mul_(domain.upper_bound - domain.lower_bound).add_(domain.lower_bound)
return x
elif domain is constraints.simplex:
x = x.normal_().exp()
x /= x.sum(-1, True)
return x
elif domain is constraints.corr_cholesky:
x = torch.empty(4, 5, 5)
x = x.normal_().tril()
x /= x.norm(dim=-1, keepdim=True)
x.diagonal(dim1=-1).copy_(x.diagonal(dim1=-1).abs())
return x
raise ValueError(f"Unsupported domain: {domain}")
TRANSFORMS_CACHE_ACTIVE = get_transforms(cache_size=1)
TRANSFORMS_CACHE_INACTIVE = get_transforms(cache_size=0)
ALL_TRANSFORMS = (
TRANSFORMS_CACHE_ACTIVE + TRANSFORMS_CACHE_INACTIVE + [identity_transform]
)
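
# Each transform appears twice (with and without caching); identity_transform
# is added once as an extra edge case.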
@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_inv_inv(transform):
assert transform.inv.inv is transform
@pytest.mark.parametrize("x", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
@pytest.mark.parametrize("y", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_equality(x, y):
if x is y:
assert x == y
else:
assert x != y
assert identity_transform == identity_transform.inv
@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_with_cache(transform):
if transform._cache_size == 0:
transform = transform.with_cache(1)
assert transform._cache_size == 1
x = generate_data(transform).requires_grad_()
try:
y = transform(x)
except NotImplementedError:
pytest.skip("Not implemented.")
y2 = transform(x)
assert y2 is y
@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
@pytest.mark.parametrize("test_cached", [True, False])
def test_forward_inverse(transform, test_cached):
x = generate_data(transform).requires_grad_()
assert transform.domain.check(x).all() # verify that the input data are valid
try:
y = transform(x)
except NotImplementedError:
pytest.skip("Not implemented.")
assert y.shape == transform.forward_shape(x.shape)
if test_cached:
x2 = transform.inv(y) # should be implemented at least by caching
else:
try:
x2 = transform.inv(y.clone()) # bypass cache
except NotImplementedError:
pytest.skip("Not implemented.")
assert x2.shape == transform.inverse_shape(y.shape)
y2 = transform(x2)
if transform.bijective:
# verify function inverse
assert torch.allclose(x2, x, atol=1e-4, equal_nan=True), "\n".join(
[
f"{transform} t.inv(t(-)) error",
f"x = {x}",
f"y = t(x) = {y}",
f"x2 = t.inv(y) = {x2}",
]
)
else:
# verify weaker function pseudo-inverse
assert torch.allclose(y2, y, atol=1e-4, equal_nan=True), "\n".join(
[
f"{transform} t(t.inv(t(-))) error",
f"x = {x}",
f"y = t(x) = {y}",
f"x2 = t.inv(y) = {x2}",
f"y2 = t(x2) = {y2}",
]
)
def test_compose_transform_shapes():
transform0 = ExpTransform()
transform1 = SoftmaxTransform()
transform2 = LowerCholeskyTransform()
assert transform0.event_dim == 0
assert transform1.event_dim == 1
assert transform2.event_dim == 2
assert ComposeTransform([transform0, transform1]).event_dim == 1
assert ComposeTransform([transform0, transform2]).event_dim == 2
assert ComposeTransform([transform1, transform2]).event_dim == 2
transform0 = ExpTransform()
transform1 = SoftmaxTransform()
transform2 = LowerCholeskyTransform()
base_dist0 = Normal(torch.zeros(4, 4), torch.ones(4, 4))
base_dist1 = Dirichlet(torch.ones(4, 4))
base_dist2 = Normal(torch.zeros(3, 4, 4), torch.ones(3, 4, 4))
@pytest.mark.parametrize(
("batch_shape", "event_shape", "dist"),
[
((4, 4), (), base_dist0),
((4,), (4,), base_dist1),
((4, 4), (), TransformedDistribution(base_dist0, [transform0])),
((4,), (4,), TransformedDistribution(base_dist0, [transform1])),
((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])),
((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])),
((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])),
((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])),
((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform0])),
((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])),
((4,), (4,), TransformedDistribution(base_dist1, [transform0])),
((4,), (4,), TransformedDistribution(base_dist1, [transform1])),
((), (4, 4), TransformedDistribution(base_dist1, [transform2])),
((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])),
((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])),
((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])),
((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])),
((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])),
((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])),
((3, 4, 4), (), base_dist2),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])),
],
)
def test_transformed_distribution_shapes(batch_shape, event_shape, dist):
assert dist.batch_shape == batch_shape
assert dist.event_shape == event_shape
x = dist.rsample()
try:
dist.log_prob(x) # this should not crash
except NotImplementedError:
pytest.skip("Not implemented.")
@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_fwd(transform):
x = generate_data(transform).requires_grad_()
def f(x):
return transform(x)
try:
traced_f = torch.jit.trace(f, (x,))
except NotImplementedError:
pytest.skip("Not implemented.")
# check on different inputs
x = generate_data(transform).requires_grad_()
assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True)
@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_inv(transform):
y = generate_data(transform.inv).requires_grad_()
def f(y):
return transform.inv(y)
try:
traced_f = torch.jit.trace(f, (y,))
except NotImplementedError:
pytest.skip("Not implemented.")
# check on different inputs
y = generate_data(transform.inv).requires_grad_()
assert torch.allclose(f(y), traced_f(y), atol=1e-5, equal_nan=True)
@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_jacobian(transform):
x = generate_data(transform).requires_grad_()
def f(x):
y = transform(x)
return transform.log_abs_det_jacobian(x, y)
try:
traced_f = torch.jit.trace(f, (x,))
except NotImplementedError:
pytest.skip("Not implemented.")
# check on different inputs
x = generate_data(transform).requires_grad_()
assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True)
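
# Check Transform.log_abs_det_jacobian against a log-determinant computed from
# torch.autograd.functional.jacobian. Three cases are handled below: reshape
# transforms (unit jacobian), elementwise transforms (diagonal jacobian), and
# transforms that mix coordinates within an event (full jacobian).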
@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_jacobian(transform):
x = generate_data(transform)
try:
y = transform(x)
actual = transform.log_abs_det_jacobian(x, y)
except NotImplementedError:
pytest.skip("Not implemented.")
# Test shape
target_shape = x.shape[: x.dim() - transform.domain.event_dim]
assert actual.shape == target_shape
# Expand if required
transform = reshape_transform(transform, x.shape)
ndims = len(x.shape)
    event_dim = ndims - transform.domain.event_dim  # index where the event dims start
x_ = x.view((-1,) + x.shape[event_dim:])
n = x_.shape[0]
# Reshape to squash batch dims to a single batch dim
transform = reshape_transform(transform, x_.shape)
# 1. Transforms with unit jacobian
if isinstance(transform, ReshapeTransform) or isinstance(
transform.inv, ReshapeTransform
):
        expected = x.new_zeros(x.shape[: x.dim() - transform.domain.event_dim])
    # 2. Transforms with zero off-diagonal Jacobian elements (elementwise)
elif transform.domain.event_dim == 0:
jac = jacobian(transform, x_)
# assert off-diagonal elements are zero
assert torch.allclose(jac, jac.diagonal().diag_embed())
expected = jac.diagonal().abs().log().reshape(x.shape)
    # 3. Transforms with nonzero off-diagonal Jacobian elements
else:
if isinstance(transform, CorrCholeskyTransform):
jac = jacobian(lambda x: tril_matrix_to_vec(transform(x), diag=-1), x_)
elif isinstance(transform.inv, CorrCholeskyTransform):
jac = jacobian(
lambda x: transform(vec_to_tril_matrix(x, diag=-1)),
tril_matrix_to_vec(x_, diag=-1),
)
elif isinstance(transform, StickBreakingTransform):
jac = jacobian(lambda x: transform(x)[..., :-1], x_)
else:
jac = jacobian(transform, x_)
        # `jacobian` returns shape (batch_dims, y_event_dims, batch_dims, x_event_dims).
        # Batches are independent, so gathering the block-diagonal entries below
        # reduces this to a (batch_dims, event_dims, event_dims) batch of square
        # matrices whose determinants can be computed.
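        # E.g. for x_ of shape (4, 5) (n = 4 batch rows, 5-element events), jac
        # has shape (4, 5, 4, 5); the gather keeps the block-diagonal, giving a
        # (4, 5, 5) stack of square matrices whose slogdet is the per-batch
        # expected value.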
gather_idx_shape = list(jac.shape)
gather_idx_shape[-2] = 1
gather_idxs = (
torch.arange(n)
.reshape((n,) + (1,) * (len(jac.shape) - 1))
.expand(gather_idx_shape)
)
jac = jac.gather(-2, gather_idxs).squeeze(-2)
out_ndims = jac.shape[-2]
jac = jac[
..., :out_ndims
] # Remove extra zero-valued dims (for inverse stick-breaking).
expected = torch.slogdet(jac).logabsdet
assert torch.allclose(actual, expected, atol=1e-5)
@pytest.mark.parametrize(
"event_dims", [(0,), (1,), (2, 3), (0, 1, 2), (1, 2, 0), (2, 0, 1)], ids=str
)
def test_compose_affine(event_dims):
transforms = [
AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims
]
transform = ComposeTransform(transforms)
assert transform.codomain.event_dim == max(event_dims)
assert transform.domain.event_dim == max(event_dims)
base_dist = Normal(0, 1)
if transform.domain.event_dim:
base_dist = base_dist.expand((1,) * transform.domain.event_dim)
dist = TransformedDistribution(base_dist, transform.parts)
assert dist.support.event_dim == max(event_dims)
base_dist = Dirichlet(torch.ones(5))
if transform.domain.event_dim > 1:
base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1))
dist = TransformedDistribution(base_dist, transforms)
assert dist.support.event_dim == max(1, *event_dims)
@pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)], ids=str)
def test_compose_reshape(batch_shape):
transforms = [
ReshapeTransform((), ()),
ReshapeTransform((2,), (1, 2)),
ReshapeTransform((3, 1, 2), (6,)),
ReshapeTransform((6,), (2, 3)),
]
transform = ComposeTransform(transforms)
assert transform.codomain.event_dim == 2
assert transform.domain.event_dim == 2
data = torch.randn(batch_shape + (3, 2))
assert transform(data).shape == batch_shape + (2, 3)
dist = TransformedDistribution(Normal(data, 1), transforms)
assert dist.batch_shape == batch_shape
assert dist.event_shape == (2, 3)
assert dist.support.event_dim == 2
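
# Sweep combinations of base-distribution batch/event dims, transform event
# dims, and number of stacked transforms, checking constructor validation and
# the shapes of samples and log_prob for both full and partial samples.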
@pytest.mark.parametrize("sample_shape", [(), (7,)], ids=str)
@pytest.mark.parametrize("transform_dim", [0, 1, 2])
@pytest.mark.parametrize("base_batch_dim", [0, 1, 2])
@pytest.mark.parametrize("base_event_dim", [0, 1, 2])
@pytest.mark.parametrize("num_transforms", [0, 1, 2, 3])
def test_transformed_distribution(
base_batch_dim, base_event_dim, transform_dim, num_transforms, sample_shape
):
shape = torch.Size([2, 3, 4, 5])
base_dist = Normal(0, 1)
base_dist = base_dist.expand(shape[4 - base_batch_dim - base_event_dim :])
if base_event_dim:
base_dist = Independent(base_dist, base_event_dim)
transforms = [
AffineTransform(torch.zeros(shape[4 - transform_dim :]), 1),
ReshapeTransform((4, 5), (20,)),
ReshapeTransform((3, 20), (6, 10)),
]
transforms = transforms[:num_transforms]
transform = ComposeTransform(transforms)
# Check validation in .__init__().
if base_batch_dim + base_event_dim < transform.domain.event_dim:
with pytest.raises(ValueError):
TransformedDistribution(base_dist, transforms)
return
d = TransformedDistribution(base_dist, transforms)
# Check sampling is sufficiently expanded.
x = d.sample(sample_shape)
assert x.shape == sample_shape + d.batch_shape + d.event_shape
num_unique = len(set(x.reshape(-1).tolist()))
assert num_unique >= 0.9 * x.numel()
# Check log_prob shape on full samples.
log_prob = d.log_prob(x)
assert log_prob.shape == sample_shape + d.batch_shape
# Check log_prob shape on partial samples.
y = x
while y.dim() > len(d.event_shape):
y = y[0]
log_prob = d.log_prob(y)
assert log_prob.shape == d.batch_shape
def test_save_load_transform():
    # Evaluating `log_prob` creates a weakref `_inv` which cannot be pickled. Here, we check
    # that `__getstate__` handles the weakref correctly and that the density can still be
    # evaluated after a save/load round trip.
dist = TransformedDistribution(Normal(0, 1), [AffineTransform(2, 3)])
x = torch.linspace(0, 1, 10)
log_prob = dist.log_prob(x)
stream = io.BytesIO()
torch.save(dist, stream)
stream.seek(0)
other = torch.load(stream)
assert torch.allclose(log_prob, other.log_prob(x))
@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_transform_sign(transform: Transform):
try:
sign = transform.sign
except NotImplementedError:
pytest.skip("Not implemented.")
x = generate_data(transform).requires_grad_()
y = transform(x).sum()
(derivatives,) = grad(y, [x])
assert torch.less(torch.as_tensor(0.0), derivatives * sign).all()
if __name__ == "__main__":
run_tests()