import torch
from torch.distributions.distribution import Distribution
from torch.distributions import Categorical
from torch.distributions import constraints


class MixtureSameFamily(Distribution):
|  | r""" | 
|  | The `MixtureSameFamily` distribution implements a (batch of) mixture | 
|  | distribution where all component are from different parameterizations of | 
|  | the same distribution type. It is parameterized by a `Categorical` | 
|  | "selecting distribution" (over `k` component) and a component | 
|  | distribution, i.e., a `Distribution` with a rightmost batch shape | 
|  | (equal to `[k]`) which indexes each (batch of) component. | 
|  |  | 
    Examples::

        # Construct Gaussian Mixture Model in 1D consisting of 5 equally
        # weighted normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
        >>> gmm = MixtureSameFamily(mix, comp)

        # Construct Gaussian Mixture Model in 2D consisting of 5 equally
        # weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Independent(D.Normal(
        ...          torch.randn(5,2), torch.rand(5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

        # Construct a batch of 3 Gaussian Mixture Models in 2D each
        # consisting of 5 random weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.rand(3,5))
        >>> comp = D.Independent(D.Normal(
        ...          torch.randn(3,5,2), torch.rand(3,5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

    Args:
        mixture_distribution: `torch.distributions.Categorical`-like
            instance. Manages the probability of selecting components.
            The number of categories must match the rightmost batch
            dimension of the `component_distribution`. Must have either
            scalar `batch_shape` or `batch_shape` matching
            `component_distribution.batch_shape[:-1]`
        component_distribution: `torch.distributions.Distribution`-like
            instance. Right-most batch dimension indexes components.
    """
    arg_constraints = {}
    has_rsample = False

    def __init__(self,
                 mixture_distribution,
                 component_distribution,
                 validate_args=None):
        self._mixture_distribution = mixture_distribution
        self._component_distribution = component_distribution

        if not isinstance(self._mixture_distribution, Categorical):
            raise ValueError("The Mixture distribution needs to be an "
                             "instance of torch.distributions.Categorical")

        if not isinstance(self._component_distribution, Distribution):
            raise ValueError("The Component distribution needs to be an "
                             "instance of torch.distributions.Distribution")

        # Check that the batch shapes are broadcastable
        mdbs = self._mixture_distribution.batch_shape
        cdbs = self._component_distribution.batch_shape[:-1]
        for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
            if size1 != 1 and size2 != 1 and size1 != size2:
                raise ValueError("`mixture_distribution.batch_shape` ({0}) is not "
                                 "compatible with `component_distribution."
                                 "batch_shape`({1})".format(mdbs, cdbs))

        # Check that the number of mixture components matches
        km = self._mixture_distribution.logits.shape[-1]
        kc = self._component_distribution.batch_shape[-1]
        if km is not None and kc is not None and km != kc:
            raise ValueError("`mixture_distribution component` ({0}) does not"
                             " equal `component_distribution.batch_shape[-1]`"
                             " ({1})".format(km, kc))
        self._num_component = km

        event_shape = self._component_distribution.event_shape
        self._event_ndims = len(event_shape)
        super(MixtureSameFamily, self).__init__(batch_shape=cdbs,
                                                event_shape=event_shape,
                                                validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        batch_shape = torch.Size(batch_shape)
        batch_shape_comp = batch_shape + (self._num_component,)
        new = self._get_checked_instance(MixtureSameFamily, _instance)
        new._component_distribution = \
            self._component_distribution.expand(batch_shape_comp)
        new._mixture_distribution = \
            self._mixture_distribution.expand(batch_shape)
        new._num_component = self._num_component
        new._event_ndims = self._event_ndims
        event_shape = new._component_distribution.event_shape
        super(MixtureSameFamily, new).__init__(batch_shape=batch_shape,
                                               event_shape=event_shape,
                                               validate_args=False)
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property
    def support(self):
        # FIXME this may have the wrong shape when support contains batched
        # parameters
        return self._component_distribution.support

    @property
    def mixture_distribution(self):
        return self._mixture_distribution

    @property
    def component_distribution(self):
        return self._component_distribution

    @property
    def mean(self):
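        # The mixture mean is the probability-weighted average of the
        # component means: E[Y] = sum_k pi_k * mu_k, with the weights
        # broadcast across the event dimensions.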
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        return torch.sum(probs * self.component_distribution.mean,
                         dim=-1 - self._event_ndims)  # [B, E]

    @property
    def variance(self):
        # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        mean_cond_var = torch.sum(probs * self.component_distribution.variance,
                                  dim=-1 - self._event_ndims)
        var_cond_mean = torch.sum(probs * (self.component_distribution.mean -
                                           self._pad(self.mean)).pow(2.0),
                                  dim=-1 - self._event_ndims)
        return mean_cond_var + var_cond_mean

    def cdf(self, x):
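        # The mixture CDF is the probability-weighted sum of the component
        # CDFs: F(x) = sum_k pi_k * F_k(x).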
        x = self._pad(x)
        cdf_x = self.component_distribution.cdf(x)
        mix_prob = self.mixture_distribution.probs

        return torch.sum(cdf_x * mix_prob, dim=-1)

    def log_prob(self, x):
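        # Marginal log-likelihood of the mixture, evaluated in log space for
        # numerical stability: log p(x) = logsumexp_k(log pi_k + log p_k(x)).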
        x = self._pad(x)
        log_prob_x = self.component_distribution.log_prob(x)  # [S, B, k]
        log_mix_prob = torch.log_softmax(self.mixture_distribution.logits,
                                         dim=-1)  # [B, k]
        return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1)  # [S, B]

    def sample(self, sample_shape=torch.Size()):
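        # Ancestral sampling: draw component indices from the mixture
        # distribution, draw from every component, then gather the draw
        # belonging to the selected component along the k dimension.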
        with torch.no_grad():
            sample_len = len(sample_shape)
            batch_len = len(self.batch_shape)
            gather_dim = sample_len + batch_len
            es = self.event_shape

            # mixture samples [n, B]
            mix_sample = self.mixture_distribution.sample(sample_shape)
            mix_shape = mix_sample.shape

            # component samples [n, B, k, E]
            comp_samples = self.component_distribution.sample(sample_shape)

            # Gather along the k dimension
            mix_sample_r = mix_sample.reshape(
                mix_shape + torch.Size([1] * (len(es) + 1)))
            mix_sample_r = mix_sample_r.repeat(
                torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es)

            samples = torch.gather(comp_samples, gather_dim, mix_sample_r)
            return samples.squeeze(gather_dim)

    def _pad(self, x):
        return x.unsqueeze(-1 - self._event_ndims)

    def _pad_mixture_dimensions(self, x):
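        # Reshape the mixture probabilities from [..., k] to
        # [..., 1, ..., 1, k, 1, ..., 1]: singleton dims are inserted before
        # the component dim for any batch dims the mixture distribution lacks,
        # and after it for the event dims, so the weights broadcast against
        # per-component statistics of shape [..., k, E].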
        dist_batch_ndims = len(self.batch_shape)
        cat_batch_ndims = len(self.mixture_distribution.batch_shape)
        pad_ndims = 0 if cat_batch_ndims == 1 else \
            dist_batch_ndims - cat_batch_ndims
        xs = x.shape
        x = x.reshape(xs[:-1] + torch.Size(pad_ndims * [1]) +
                      xs[-1:] + torch.Size(self._event_ndims * [1]))
        return x

    def __repr__(self):
        args_string = '\n  {},\n  {}'.format(self.mixture_distribution,
                                             self.component_distribution)
        return 'MixtureSameFamily' + '(' + args_string + ')'