import torch
from torch.distributions import constraints
from torch.distributions.categorical import Categorical
from torch.distributions.distribution import Distribution
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import ExpTransform
from torch.distributions.utils import broadcast_all, clamp_probs

__all__ = ["ExpRelaxedCategorical", "RelaxedOneHotCategorical"]


class ExpRelaxedCategorical(Distribution):
|  | r""" | 
|  | Creates a ExpRelaxedCategorical parameterized by | 
|  | :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both). | 
|  | Returns the log of a point in the simplex. Based on the interface to | 
|  | :class:`OneHotCategorical`. | 
|  |  | 
|  | Implementation based on [1]. | 
|  |  | 
|  | See also: :func:`torch.distributions.OneHotCategorical` | 
|  |  | 
|  | Args: | 
|  | temperature (Tensor): relaxation temperature | 
|  | probs (Tensor): event probabilities | 
|  | logits (Tensor): unnormalized log probability for each event | 
|  |  | 
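    Example (a minimal sketch; the draw is random, so only the exp-sum of
    the sample is checked)::

        >>> m = ExpRelaxedCategorical(torch.tensor([2.2]),
        ...                           torch.tensor([0.1, 0.2, 0.3, 0.4]))
        >>> log_sample = m.sample()  # log of a point on the simplex
        >>> bool(log_sample.exp().sum(-1).isclose(torch.tensor(1.0)))
        True
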
    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
    (Maddison et al., 2017)

    [2] Categorical Reparametrization with Gumbel-Softmax
    (Jang et al., 2017)
    """
    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = (
        constraints.real_vector
    )  # The true support is actually a submanifold of this.
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        self._categorical = Categorical(probs, logits)
        self.temperature = temperature
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(ExpRelaxedCategorical, _instance)
        batch_shape = torch.Size(batch_shape)
        new.temperature = self.temperature
        new._categorical = self._categorical.expand(batch_shape)
        super(ExpRelaxedCategorical, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        return self._categorical._new(*args, **kwargs)

    @property
    def param_shape(self):
        return self._categorical.param_shape

    @property
    def logits(self):
        return self._categorical.logits

    @property
    def probs(self):
        return self._categorical.probs

    def rsample(self, sample_shape=torch.Size()):
        # Gumbel-softmax reparameterization in log space: perturb the logits
        # with Gumbel(0, 1) noise, scale by the temperature, and renormalize.
        shape = self._extended_shape(sample_shape)
        uniforms = clamp_probs(
            torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
        )
        gumbels = -((-(uniforms.log())).log())  # Gumbel(0, 1) via inverse CDF
        scores = (self.logits + gumbels) / self.temperature
        return scores - scores.logsumexp(dim=-1, keepdim=True)

    def log_prob(self, value):
        K = self._categorical._num_events
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        # log normalizer of the ExpConcrete density:
        # log((K - 1)!) + (K - 1) * log(temperature)
        log_scale = torch.full_like(
            self.temperature, float(K)
        ).lgamma() - self.temperature.log().mul(-(K - 1))
        score = logits - value.mul(self.temperature)
        score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
        return score + log_scale


class RelaxedOneHotCategorical(TransformedDistribution):
    r"""
    Creates a RelaxedOneHotCategorical distribution parametrized by
    :attr:`temperature`, and either :attr:`probs` or :attr:`logits`.
    This is a relaxed version of the :class:`OneHotCategorical` distribution, so
    its samples are on the simplex and are reparametrizable.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = RelaxedOneHotCategorical(torch.tensor([2.2]),
        ...                              torch.tensor([0.1, 0.2, 0.3, 0.4]))
        >>> m.sample()
        tensor([ 0.1294,  0.2324,  0.3859,  0.2523])

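    Because :meth:`rsample` is reparameterized, gradients can flow from a
    sample back to the parameters (a minimal sketch; the sampled values, and
    hence the gradients, are random)::

        >>> logits = torch.zeros(4, requires_grad=True)
        >>> m = RelaxedOneHotCategorical(torch.tensor(2.2), logits=logits)
        >>> loss = (m.rsample() * torch.arange(4.0)).sum()
        >>> loss.backward()
        >>> logits.grad is not None
        True
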
    Args:
        temperature (Tensor): relaxation temperature
        probs (Tensor): event probabilities
        logits (Tensor): unnormalized log probability for each event
    """
    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = constraints.simplex
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        base_dist = ExpRelaxedCategorical(
            temperature, probs, logits, validate_args=validate_args
        )
        super().__init__(base_dist, ExpTransform(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(RelaxedOneHotCategorical, _instance)
        return super().expand(batch_shape, _instance=new)

    @property
    def temperature(self):
        return self.base_dist.temperature

    @property
    def logits(self):
        return self.base_dist.logits

    @property
    def probs(self):
        return self.base_dist.probs