"""
The ``distributions`` package contains parameterizable probability distributions
and sampling functions.
The :meth:`log_prob` method is useful for policy gradient based methods. If the
parameters of the distribution are differentiable, then the result of ``log_prob``
is also differentiable.
Example::
probs = network(input)
m = Multinomial(probs)
action = m.sample()
loss = -m.log_prob(action) * get_reward(env, action)
loss.backward()
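
This pattern is the basis of the REINFORCE estimator; in practice the per-step
losses are summed over an episode before calling ``backward()``.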
"""
import math
from numbers import Number
import torch
__all__ = ['Distribution', 'Bernoulli', 'Multinomial', 'Normal']


class Distribution(object):
    r"""
    Distribution is the abstract base class for probability distributions.
    """

    def sample(self):
        """
        Generates a single sample or a single batch of samples if the
        distribution parameters are batched.
        """
        raise NotImplementedError

    def sample_n(self, n):
        """
        Generates n samples or n batches of samples if the distribution
        parameters are batched.
        """
        raise NotImplementedError

    def log_prob(self, value):
        """
        Returns the log of the probability density/mass function evaluated at
        `value`.

        Args:
            value (Tensor or Variable): the value at which to evaluate the
                log probability density/mass function
        """
        raise NotImplementedError
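

# Illustrative sketch (not part of the original module): any concrete
# ``Distribution`` supports a REINFORCE-style surrogate loss, since ``sample``
# draws an action and ``log_prob`` scores it differentiably. ``dist`` and
# ``reward`` are hypothetical stand-ins supplied by the caller.
def _example_surrogate_loss(dist, reward):
    action = dist.sample()
    return -dist.log_prob(action) * reward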


class Bernoulli(Distribution):
    r"""
    Creates a Bernoulli distribution parameterized by `probs`.

    Samples are binary (0 or 1). They take the value `1` with probability `p`
    and `0` with probability `1 - p`.

    Example::

        >>> m = Bernoulli(torch.Tensor([0.3]))
        >>> m.sample()  # 30% chance 1; 70% chance 0
        0.0
        [torch.FloatTensor of size 1]
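        >>> m.log_prob(torch.Tensor([1.0]))  # log(0.3)
        -1.2040
        [torch.FloatTensor of size 1]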

    Args:
        probs (Tensor or Variable): the probability of sampling `1`
    """

    def __init__(self, probs):
        self.probs = probs

    def sample(self):
        return torch.bernoulli(self.probs)

    def sample_n(self, n):
        # expand probs to (n, *probs.size()) so one call draws n batches
        return torch.bernoulli(self.probs.expand(n, *self.probs.size()))

    def log_prob(self, value):
        # stack the probabilities of 0 and 1 and take the log to get the
        # log-pmf over both outcomes
        log_pmf = torch.stack([1 - self.probs, self.probs]).log()
        # select the log-probability matching each sampled value
        return log_pmf.gather(0, value.unsqueeze(0).long()).squeeze(0)


class Multinomial(Distribution):
    r"""
    Creates a multinomial distribution parameterized by `probs`.

    Samples are integers from `0 ... K-1` where `K` is ``probs.size(-1)``.

    If `probs` is 1D with length `K`, each element is the relative probability
    of sampling the class at that index.

    If `probs` is 2D, it is treated as a batch of probability vectors.

    See also: :func:`torch.multinomial`

    Example::

        >>> m = Multinomial(torch.Tensor([0.25, 0.25, 0.25, 0.25]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
        3
        [torch.LongTensor of size 1]
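        >>> m.log_prob(torch.LongTensor([3]))  # log(0.25)
        -1.3863
        [torch.FloatTensor of size 1]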

    Args:
        probs (Tensor or Variable): event probabilities
    """

    def __init__(self, probs):
        if probs.dim() != 1 and probs.dim() != 2:
            # TODO: treat higher dimensions as part of the batch
            raise ValueError("probs must be 1D or 2D")
        self.probs = probs

    def sample(self):
        # draw one sample (with replacement) per probability vector
        return torch.multinomial(self.probs, 1, True).squeeze(-1)

    def sample_n(self, n):
        if n == 1:
            return self.sample().expand(1, 1)
        else:
            # transpose so that samples, not batches, lie along dim 0
            return torch.multinomial(self.probs, n, True).t()

    def log_prob(self, value):
        # normalize, since torch.multinomial accepts unnormalized weights
        p = self.probs / self.probs.sum(-1, keepdim=True)
        if value.dim() == 1 and self.probs.dim() == 1:
            # special handling until we have 0-dim tensor support
            return p.gather(-1, value).log()
        return p.gather(-1, value.unsqueeze(-1)).squeeze(-1).log()
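

# Batched usage sketch (shapes inferred from the code above): with 2D
# ``Multinomial.probs`` of size (B, K), ``sample`` returns one class index
# per row and ``log_prob`` expects a LongTensor of size (B,), e.g.:
#
#     m = Multinomial(torch.Tensor([[0.5, 0.5], [0.9, 0.1]]))
#     actions = m.sample()          # LongTensor of size 2
#     scores = m.log_prob(actions)  # FloatTensor of size 2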


class Normal(Distribution):
    r"""
    Creates a normal (also called Gaussian) distribution parameterized by
    `mean` and `std`.

    Example::

        >>> m = Normal(torch.Tensor([0.0]), torch.Tensor([1.0]))
        >>> m.sample()  # normally distributed with mean=0 and stddev=1
        0.1046
        [torch.FloatTensor of size 1]
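        >>> m.log_prob(torch.Tensor([0.0]))  # -log(sqrt(2 * pi))
        -0.9189
        [torch.FloatTensor of size 1]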

    Args:
        mean (float or Tensor or Variable): mean of the distribution
        std (float or Tensor or Variable): standard deviation of the distribution
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def sample(self):
        return torch.normal(self.mean, self.std)

    def sample_n(self, n):
        # cleanly expand float or Tensor or Variable parameters
        def expand(v):
            if isinstance(v, Number):
                return torch.Tensor([v]).expand(n, 1)
            else:
                return v.expand(n, *v.size())
        return torch.normal(expand(self.mean), expand(self.std))

    def log_prob(self, value):
        # log N(value; mean, std) = -(value - mean)^2 / (2 * var)
        #                           - log(std) - log(sqrt(2 * pi))
        var = self.std ** 2
        log_std = math.log(self.std) if isinstance(self.std, Number) else self.std.log()
        return -((value - self.mean) ** 2) / (2 * var) - log_std - math.log(math.sqrt(2 * math.pi))