from typing import Dict

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.independent import Independent
from torch.distributions.transforms import ComposeTransform, Transform
from torch.distributions.utils import _sum_rightmost

__all__ = ["TransformedDistribution"]


class TransformedDistribution(Distribution):
| r""" |
| Extension of the Distribution class, which applies a sequence of Transforms |
| to a base distribution. Let f be the composition of transforms applied:: |
| |
| X ~ BaseDistribution |
| Y = f(X) ~ TransformedDistribution(BaseDistribution, f) |
| log p(Y) = log p(X) + log |det (dX/dY)| |
| |
| Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the |
| maximum shape of its base distribution and its transforms, since transforms |
| can introduce correlations among events. |
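
    For example, a transform whose domain has ``event_dim=1`` consumes the
    rightmost batch dimension of the base distribution, which is then treated
    as an event dimension (a minimal sketch; the shapes are illustrative, and
    the stick-breaking bijection maps size ``n`` to size ``n + 1``)::

        # base: batch_shape (5, 3), event_shape ()
        base = Normal(torch.zeros(5, 3), torch.ones(5, 3))
        d = TransformedDistribution(base, [StickBreakingTransform()])
        # d.batch_shape == (5,), d.event_shape == (4,)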

    An example for the usage of :class:`TransformedDistribution` would be::

        # Building a Logistic Distribution
        # X ~ Uniform(0, 1)
        # f = a + b * logit(X)
        # Y = f(X) ~ Logistic(a, b)
        base_distribution = Uniform(0, 1)
        transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
        logistic = TransformedDistribution(base_distribution, transforms)

    For more examples, please look at the implementations of
    :class:`~torch.distributions.gumbel.Gumbel`,
    :class:`~torch.distributions.half_cauchy.HalfCauchy`,
    :class:`~torch.distributions.half_normal.HalfNormal`,
    :class:`~torch.distributions.log_normal.LogNormal`,
    :class:`~torch.distributions.pareto.Pareto`,
    :class:`~torch.distributions.weibull.Weibull`,
    :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
    :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
    """

    arg_constraints: Dict[str, constraints.Constraint] = {}

    def __init__(self, base_distribution, transforms, validate_args=None):
        if isinstance(transforms, Transform):
            self.transforms = [transforms]
        elif isinstance(transforms, list):
            if not all(isinstance(t, Transform) for t in transforms):
                raise ValueError(
                    "transforms must be a Transform or a list of Transforms"
                )
            self.transforms = transforms
        else:
            raise ValueError(
                f"transforms must be a Transform or list, but was {transforms}"
            )

        # Reshape base_distribution according to transforms.
        base_shape = base_distribution.batch_shape + base_distribution.event_shape
        base_event_dim = len(base_distribution.event_shape)
        transform = ComposeTransform(self.transforms)
        if len(base_shape) < transform.domain.event_dim:
            raise ValueError(
                "base_distribution needs to have shape with size at least "
                f"{transform.domain.event_dim}, but got {base_shape}."
            )
        forward_shape = transform.forward_shape(base_shape)
        expanded_base_shape = transform.inverse_shape(forward_shape)
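        # If the transform broadcasts against batched parameters (e.g. an
        # AffineTransform with a batched ``loc``), round-tripping the shape
        # through forward_shape/inverse_shape can yield a larger base shape;
        # in that case the base distribution is expanded to match.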
        if base_shape != expanded_base_shape:
            base_batch_shape = expanded_base_shape[
                : len(expanded_base_shape) - base_event_dim
            ]
            base_distribution = base_distribution.expand(base_batch_shape)
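        # If the transform consumes more event dims than the base distribution
        # provides, reinterpret the missing batch dims as event dims.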
        reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim
        if reinterpreted_batch_ndims > 0:
            base_distribution = Independent(
                base_distribution, reinterpreted_batch_ndims
            )
        self.base_dist = base_distribution

        # Compute shapes.
        transform_change_in_event_dim = (
            transform.codomain.event_dim - transform.domain.event_dim
        )
        event_dim = max(
            transform.codomain.event_dim,  # the transform couples output events
            base_event_dim + transform_change_in_event_dim,  # the base dist couples events
        )
        assert len(forward_shape) >= event_dim
        cut = len(forward_shape) - event_dim
        batch_shape = forward_shape[:cut]
        event_shape = forward_shape[cut:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(TransformedDistribution, _instance)
        batch_shape = torch.Size(batch_shape)
        shape = batch_shape + self.event_shape
        for t in reversed(self.transforms):
            shape = t.inverse_shape(shape)
        base_batch_shape = shape[: len(shape) - len(self.base_dist.event_shape)]
        new.base_dist = self.base_dist.expand(base_batch_shape)
        new.transforms = self.transforms
        super(TransformedDistribution, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property(is_discrete=False)
    def support(self):
        if not self.transforms:
            return self.base_dist.support
        support = self.transforms[-1].codomain
        if len(self.event_shape) > support.event_dim:
            support = constraints.independent(
                support, len(self.event_shape) - support.event_dim
            )
        return support

    @property
    def has_rsample(self):
        return self.base_dist.has_rsample

    def sample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped sample or sample_shape shaped batch of
        samples if the distribution parameters are batched. Samples first from
        the base distribution, then applies `transform()` for every transform
        in the list.
        """
        with torch.no_grad():
            x = self.base_dist.sample(sample_shape)
            for transform in self.transforms:
                x = transform(x)
            return x

    def rsample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched. Samples first from the base distribution, then applies
        `transform()` for every transform in the list.
        """
        x = self.base_dist.rsample(sample_shape)
        for transform in self.transforms:
            x = transform(x)
        return x

    def log_prob(self, value):
        """
        Scores the sample by inverting the transform(s) and computing the score
        using the score of the base distribution and the log abs det jacobian.
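
        For a single bijection ``t``, ignoring the event-dimension bookkeeping
        below, this amounts to (a minimal sketch)::

            x = t.inv(y)
            log_prob = base_dist.log_prob(x) - t.log_abs_det_jacobian(x, y)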
| """ |
| if self._validate_args: |
| self._validate_sample(value) |
| event_dim = len(self.event_shape) |
| log_prob = 0.0 |
| y = value |
| for transform in reversed(self.transforms): |
| x = transform.inv(y) |
| event_dim += transform.domain.event_dim - transform.codomain.event_dim |
| log_prob = log_prob - _sum_rightmost( |
| transform.log_abs_det_jacobian(x, y), |
| event_dim - transform.domain.event_dim, |
| ) |
| y = x |
| |
| log_prob = log_prob + _sum_rightmost( |
| self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape) |
| ) |
| return log_prob |

    def _monotonize_cdf(self, value):
        """
        This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is
        monotone increasing.
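
        If the composed transform is strictly decreasing (``sign == -1``), then
        ``P(Y <= y) = P(X >= f.inv(y)) = 1 - P(X <= f.inv(y))`` for a
        continuous base distribution, hence the flip.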
| """ |
| sign = 1 |
| for transform in self.transforms: |
| sign = sign * transform.sign |
| if isinstance(sign, int) and sign == 1: |
| return value |
| return sign * (value - 0.5) + 0.5 |

    def cdf(self, value):
        """
        Computes the cumulative distribution function by inverting the
        transform(s) and computing the CDF of the base distribution.
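
        For a monotone increasing composition ``f``, this reduces to
        ``base_dist.cdf(f.inv(value))``; decreasing compositions are handled
        by :meth:`_monotonize_cdf`.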
| """ |
| for transform in self.transforms[::-1]: |
| value = transform.inv(value) |
| if self._validate_args: |
| self.base_dist._validate_sample(value) |
| value = self.base_dist.cdf(value) |
| value = self._monotonize_cdf(value) |
| return value |

    def icdf(self, value):
        """
        Computes the inverse cumulative distribution function by applying the
        transform(s) to the inverse CDF of the base distribution.
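
        For a monotone increasing composition ``f``, this reduces to
        ``f(base_dist.icdf(value))``.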
| """ |
| value = self._monotonize_cdf(value) |
| value = self.base_dist.icdf(value) |
| for transform in self.transforms: |
| value = transform(value) |
| return value |