blob: 5a7009db060a8723b1503005e576536d0bb03de9 [file] [log] [blame]
from numbers import Number
import torch
from torch.autograd import Function, Variable
from torch.autograd.function import once_differentiable
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _finfo, broadcast_all
def _dirichlet_sample_nograd(concentration):
    """Draw a Dirichlet sample without recording anything for autograd.

    Standard-gamma draws are taken with ``concentration`` as the shape
    parameter and normalized along the last dimension so each row sums to
    one.  The result is then clamped in place to ``[eps, 1 - eps]`` (``eps``
    being the machine epsilon reported by ``_finfo`` for the sample) so no
    component is exactly 0 or 1.
    """
    simplex = torch._C._standard_gamma(concentration)
    # In-place normalization; the denominator keeps the reduced dim for broadcasting.
    simplex.div_(simplex.sum(-1, True))
    tiny = _finfo(simplex).eps
    simplex.clamp_(min=tiny, max=1 - tiny)
    return simplex
# This helper is exposed for testing.
def _Dirichlet_backward(x, concentration, grad_output):
    """Compute the gradient of a Dirichlet sample w.r.t. its concentration.

    Args:
        x: the sample that was produced in the forward pass.
        concentration: the concentration the sample was drawn with.
        grad_output: gradient flowing back into ``x``.

    Returns:
        ``_dirichlet_grad(x, concentration, total)`` scaled elementwise by
        ``grad_output`` minus the ``x``-weighted sum of ``grad_output`` over
        the last dimension.
    """
    # Per-row concentration total, broadcast back to the full parameter shape.
    row_total = concentration.sum(-1, True)
    total = row_total.expand_as(concentration)
    elementwise_grad = torch._C._dirichlet_grad(x, concentration, total)
    weighted_sum = (x * grad_output).sum(-1, True)
    return elementwise_grad * (grad_output - weighted_sum)
class _Dirichlet(Function):
    """Autograd function pairing the Dirichlet sampler with its custom backward."""

    @staticmethod
    def forward(ctx, concentration):
        """Draw a sample and stash it, with the concentration, for backward."""
        sample = _dirichlet_sample_nograd(concentration)
        ctx.save_for_backward(sample, concentration)
        return sample

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Delegate to ``_Dirichlet_backward`` using the tensors saved in forward."""
        sample, alpha = ctx.saved_tensors
        return _Dirichlet_backward(sample, alpha, grad_output)
class Dirichlet(Distribution):
    r"""
    Creates a Dirichlet distribution parameterized by concentration `concentration`.

    Example::

        >>> m = Dirichlet(torch.Tensor([0.5, 0.5]))
        >>> m.sample()  # Dirichlet distributed with concentration concentration
         0.1046
         0.8954
        [torch.FloatTensor of size 2]

    Args:
        concentration (Tensor or Variable): concentration parameter of the distribution
            (often referred to as alpha)
    """
    params = {'concentration': constraints.positive}
    support = constraints.simplex
    has_rsample = True

    def __init__(self, concentration):
        self.concentration, = broadcast_all(concentration)
        # The last dimension is the event dimension; everything before it is batch.
        batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
        super(Dirichlet, self).__init__(batch_shape, event_shape)

    def rsample(self, sample_shape=()):
        """Draw a sample of shape ``sample_shape + batch_shape + event_shape``.

        The concentration is expanded to the extended shape first.  Variables
        are routed through ``_Dirichlet.apply`` so that a backward pass is
        defined; anything else is sampled directly without gradient tracking.
        """
        shape = self._extended_shape(sample_shape)
        concentration = self.concentration.expand(shape)
        if isinstance(concentration, Variable):
            return _Dirichlet.apply(concentration)
        return _dirichlet_sample_nograd(concentration)

    def log_prob(self, value):
        """Return the log-density at ``value``.

        Computed as ``sum((alpha - 1) * log(value)) + lgamma(sum(alpha))
        - sum(lgamma(alpha))`` with all sums taken over the last dimension.
        """
        self._validate_log_prob_arg(value)
        return ((torch.log(value) * (self.concentration - 1.0)).sum(-1) +
                torch.lgamma(self.concentration.sum(-1)) -
                torch.lgamma(self.concentration).sum(-1))

    @property
    def mean(self):
        """Per-component mean ``alpha_i / sum(alpha)``, same shape as the concentration."""
        # keepdim=True so the total broadcasts row-wise against batched
        # concentrations instead of aligning with the event dimension.
        return self.concentration / self.concentration.sum(-1, True)

    @property
    def variance(self):
        """Per-component variance ``alpha_i * (a0 - alpha_i) / (a0**2 * (a0 + 1))``.

        ``a0`` is the sum of the concentration over the last dimension.
        """
        # keepdim=True for the same broadcasting reason as in ``mean``.
        con0 = self.concentration.sum(-1, True)
        return self.concentration * (con0 - self.concentration) / (con0.pow(2) * (con0 + 1))

    def entropy(self):
        """Return the entropy, reduced over the last (event) dimension.

        Computed as ``sum(lgamma(alpha)) - lgamma(a0) - (k - a0) * digamma(a0)
        - sum((alpha - 1) * digamma(alpha))`` where ``k`` is the size of the
        last dimension and ``a0`` the sum of the concentration over it.
        """
        k = self.concentration.size(-1)
        a0 = self.concentration.sum(-1)
        return (torch.lgamma(self.concentration).sum(-1) - torch.lgamma(a0) -
                (k - a0) * torch.digamma(a0) -
                ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1))