| import torch |
| from torch.distributions import constraints |
| from torch.distributions.distribution import Distribution |
| from torch.distributions.transforms import Transform |
| from torch.distributions.utils import _sum_rightmost |
| |
| |
class TransformedDistribution(Distribution):
    r"""
    Extension of the Distribution class, which applies a sequence of Transforms
    to a base distribution. Let f be the composition of transforms applied::

        X ~ BaseDistribution
        Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
        log p(Y) = log p(X) + log |det (dX/dY)|

    Note that the ``.event_shape`` of a :class:`TransformedDistribution` has as
    many dimensions as the largest ``event_dim`` among its base distribution
    and its transforms, since transforms can introduce correlations among
    events.

    An example for the usage of :class:`TransformedDistribution` would be::

        # Building a Logistic Distribution
        # X ~ Uniform(0, 1)
        # f = a + b * logit(X)
        # Y = f(X) ~ Logistic(a, b)
        base_distribution = Uniform(0, 1)
        transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
        logistic = TransformedDistribution(base_distribution, transforms)

    For more examples, please look at the implementations of
    :class:`~torch.distributions.gumbel.Gumbel`,
    :class:`~torch.distributions.half_cauchy.HalfCauchy`,
    :class:`~torch.distributions.half_normal.HalfNormal`,
    :class:`~torch.distributions.log_normal.LogNormal`,
    :class:`~torch.distributions.pareto.Pareto`,
    :class:`~torch.distributions.weibull.Weibull`,
    :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
    :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
    """
    arg_constraints = {}

    def __init__(self, base_distribution, transforms, validate_args=None):
        """
        Args:
            base_distribution (Distribution): the distribution of ``X``.
            transforms (Transform or list of Transform): transform(s) applied,
                in order, to samples of the base distribution. A list is
                stored as given (it is not copied).
            validate_args (bool, optional): forwarded to :class:`Distribution`.

        Raises:
            ValueError: if ``transforms`` is neither a :class:`Transform` nor a
                list whose elements are all :class:`Transform` instances.
        """
        self.base_dist = base_distribution
        if isinstance(transforms, Transform):
            self.transforms = [transforms]
        elif isinstance(transforms, list):
            if not all(isinstance(t, Transform) for t in transforms):
                raise ValueError("transforms must be a Transform or a list of Transforms")
            self.transforms = transforms
        else:
            raise ValueError("transforms must be a Transform or list, but was {}".format(transforms))
        # Re-split the base distribution's full shape: a transform whose
        # event_dim exceeds the base event rank absorbs the rightmost batch
        # dimensions into the event shape.
        shape = self.base_dist.batch_shape + self.base_dist.event_shape
        event_dim = max([len(self.base_dist.event_shape)] + [t.event_dim for t in self.transforms])
        batch_shape = shape[:len(shape) - event_dim]
        event_shape = shape[len(shape) - event_dim:]
        super(TransformedDistribution, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """
        Returns a new distribution with its batch dimensions expanded to
        ``batch_shape``. The base distribution is expanded accordingly, and the
        transforms list is shared with ``self``.
        """
        new = self._get_checked_instance(TransformedDistribution, _instance)
        batch_shape = torch.Size(batch_shape)
        # Base batch dims beyond len(self.batch_shape) were folded into this
        # distribution's event_shape; keep them unchanged on the base.
        base_dist_batch_shape = batch_shape + self.base_dist.batch_shape[len(self.batch_shape):]
        new.base_dist = self.base_dist.expand(base_dist_batch_shape)
        new.transforms = self.transforms
        super(TransformedDistribution, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property
    def support(self):
        # The support is the codomain of the last transform. If there are no
        # transforms, it is the base distribution's support.
        return self.transforms[-1].codomain if self.transforms else self.base_dist.support

    @property
    def has_rsample(self):
        # Reparameterized sampling is available exactly when the base supports it.
        return self.base_dist.has_rsample

    def sample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped sample or sample_shape shaped batch of
        samples if the distribution parameters are batched. Samples first from
        base distribution and applies `transform()` for every transform in the
        list. Runs under ``torch.no_grad()``, so no gradient graph is recorded.
        """
        with torch.no_grad():
            x = self.base_dist.sample(sample_shape)
            for transform in self.transforms:
                x = transform(x)
            return x

    def rsample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched. Samples first from base distribution and applies
        `transform()` for every transform in the list.
        """
        x = self.base_dist.rsample(sample_shape)
        for transform in self.transforms:
            x = transform(x)
        return x

    def log_prob(self, value):
        """
        Scores the sample by inverting the transform(s) and computing the score
        using the score of the base distribution and the log abs det jacobian
        (change of variables: log p(y) = log p(x) - sum log|det dy/dx|).
        """
        event_dim = len(self.event_shape)
        log_prob = 0.0
        y = value
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            # Sum the Jacobian term over event dims the transform itself does
            # not already reduce over.
            log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),
                                                 event_dim - transform.event_dim)
            y = x

        log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y),
                                             event_dim - len(self.base_dist.event_shape))
        return log_prob

    def _monotonize_cdf(self, value):
        """
        This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is
        monotone increasing. The flip is driven by the product of the
        transforms' ``sign`` attributes; when that product is a tensor, the
        flip is applied elementwise.
        """
        sign = 1
        for transform in self.transforms:
            sign = sign * transform.sign
        if isinstance(sign, int) and sign == 1:
            return value
        return sign * (value - 0.5) + 0.5

    def cdf(self, value):
        """
        Computes the cumulative distribution function by inverting the
        transform(s) and evaluating the cdf of the base distribution, flipped
        when the overall transform is decreasing.
        """
        for transform in self.transforms[::-1]:
            value = transform.inv(value)
        if self._validate_args:
            self.base_dist._validate_sample(value)
        value = self.base_dist.cdf(value)
        value = self._monotonize_cdf(value)
        return value

    def icdf(self, value):
        """
        Computes the inverse cumulative distribution function by evaluating the
        icdf of the base distribution and pushing the result forward through
        the transform(s).

        ``value`` is a probability, so when argument validation is enabled it
        is checked against the unit interval [0, 1]. It is not checked against
        the base distribution's support, which is unrelated to probabilities.

        Raises:
            ValueError: if validation is enabled and ``value`` lies outside
                [0, 1].
        """
        if self._validate_args:
            in_unit_interval = constraints.unit_interval.check(torch.as_tensor(value))
            if not in_unit_interval.all():
                raise ValueError("The value argument to icdf must be within the unit interval [0, 1]")
        value = self._monotonize_cdf(value)
        value = self.base_dist.icdf(value)
        for transform in self.transforms:
            value = transform(value)
        return value