Implement Multinomial distribution (#4624)
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 956e5d6..273b884 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -31,8 +31,8 @@
from common import TestCase, run_tests, set_rng_seed
from torch.autograd import Variable, grad, gradcheck
from torch.distributions import (Bernoulli, Beta, Categorical, Cauchy, Chi2,
- Dirichlet, Exponential, Gamma, Gumbel,
- Laplace, Normal, OneHotCategorical, Pareto,
+ Dirichlet, Exponential, Gamma, Gumbel, Laplace,
+ Normal, OneHotCategorical, Multinomial, Pareto,
StudentT, Uniform, kl_divergence)
from torch.distributions.dirichlet import _Dirichlet_backward
from torch.distributions.constraints import Constraint, is_dependent
@@ -69,6 +69,10 @@
{'probs': Variable(torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]), requires_grad=True)},
{'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]), requires_grad=True)},
]),
+ Example(Multinomial, [
+ {'probs': Variable(torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]), requires_grad=True), 'total_count': 10},
+ {'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]), requires_grad=True), 'total_count': 10},
+ ]),
Example(Cauchy, [
{'loc': 0.0, 'scale': 1.0},
{'loc': Variable(torch.Tensor([0.0])), 'scale': 1.0},
@@ -294,6 +298,53 @@
(2, 5, 2, 3, 5))
self.assertEqual(Bernoulli(p).sample_n(2).size(), (2, 2, 3, 5))
+ def test_multinomial_1d(self):
+ total_count = 10
+ p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
+ self.assertEqual(Multinomial(total_count, p).sample().size(), (3,))
+ self.assertEqual(Multinomial(total_count, p).sample((2, 2)).size(), (2, 2, 3))
+ self.assertEqual(Multinomial(total_count, p).sample_n(1).size(), (1, 3))
+ self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
+ self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])
+ self.assertRaises(NotImplementedError, Multinomial(10, p).rsample)
+
+ @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+ def test_multinomial_1d_log_prob(self):
+ total_count = 10
+ p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
+ dist = Multinomial(total_count, probs=p)
+ x = dist.sample()
+ log_prob = dist.log_prob(x)
+ expected = torch.Tensor(scipy.stats.multinomial.logpmf(x.numpy(), n=total_count, p=dist.probs.detach().numpy()))
+ self.assertEqual(log_prob.data, expected)
+
+ dist = Multinomial(total_count, logits=p.log())
+ x = dist.sample()
+ log_prob = dist.log_prob(x)
+ expected = torch.Tensor(scipy.stats.multinomial.logpmf(x.numpy(), n=total_count, p=dist.probs.detach().numpy()))
+ self.assertEqual(log_prob.data, expected)
+
+ def test_multinomial_2d(self):
+ total_count = 10
+ probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
+ probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
+ p = Variable(torch.Tensor(probabilities), requires_grad=True)
+ s = Variable(torch.Tensor(probabilities_1), requires_grad=True)
+ self.assertEqual(Multinomial(total_count, p).sample().size(), (2, 3))
+ self.assertEqual(Multinomial(total_count, p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
+ self.assertEqual(Multinomial(total_count, p).sample_n(6).size(), (6, 2, 3))
+ set_rng_seed(0)
+ self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
+ p.grad.zero_()
+ self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])
+
+ # sample check for extreme values of probs
+ self.assertEqual(Multinomial(total_count, s).sample().data,
+ torch.Tensor([[total_count, 0], [0, total_count]]))
+
+ # entropy is not yet implemented for Multinomial
+ self.assertRaises(NotImplementedError, Multinomial(10, p).entropy)
+
def test_categorical_1d(self):
p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
# TODO: this should return a 0-dim tensor once we have Scalar support
@@ -1096,13 +1147,16 @@
for Dist, params in EXAMPLES:
for i, param in enumerate(params):
dist = Dist(**param)
- actual_shape = dist.entropy().size()
- expected_shape = dist._batch_shape
- if not expected_shape:
- expected_shape = torch.Size((1,)) # TODO Remove this once scalars are supported.
- message = '{} example {}/{}, shape mismatch. expected {}, actual {}'.format(
- Dist.__name__, i, len(params), expected_shape, actual_shape)
- self.assertEqual(actual_shape, expected_shape, message=message)
+ try:
+ actual_shape = dist.entropy().size()
+ expected_shape = dist._batch_shape
+ if not expected_shape:
+ expected_shape = torch.Size((1,)) # TODO Remove this once scalars are supported.
+ message = '{} example {}/{}, shape mismatch. expected {}, actual {}'.format(
+ Dist.__name__, i, len(params), expected_shape, actual_shape)
+ self.assertEqual(actual_shape, expected_shape, message=message)
+ except NotImplementedError:
+ continue
def test_bernoulli_shape_scalar_params(self):
bernoulli = Bernoulli(0.3)
@@ -1145,6 +1199,16 @@
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
self.assertEqual(dist.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2)))
+ def test_multinomial_shape(self):
+ dist = Multinomial(10, torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
+ self.assertEqual(dist._batch_shape, torch.Size((3,)))
+ self.assertEqual(dist._event_shape, torch.Size((2,)))
+ self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
+ self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
+ self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
+ self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
+ self.assertEqual(dist.log_prob(torch.ones(3, 1, 2)).size(), torch.Size((3, 3)))
+
def test_categorical_shape(self):
dist = Categorical(torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
self.assertEqual(dist._batch_shape, torch.Size((3,)))
@@ -1375,11 +1439,14 @@
for name, value in param.items():
if not (torch.is_tensor(value) or isinstance(value, Variable)):
value = torch.Tensor([value])
- if Dist in (Categorical, OneHotCategorical) and name == 'probs':
+ if Dist in (Categorical, OneHotCategorical, Multinomial) and name == 'probs':
# These distributions accept positive probs, but elsewhere we
# use a stricter constraint to the simplex.
value = value / value.sum(-1, True)
- constraint = dist.params[name]
+ try:
+ constraint = dist.params[name]
+ except KeyError:
+ continue # ignore optional parameters
if is_dependent(constraint):
continue
message = '{} example {}/{} parameter {} = {}'.format(
@@ -1499,6 +1566,23 @@
log_pdf_prob_0 = categorical.log_prob(Variable(tensor_type([1, 0])))
self.assertEqual(log_pdf_prob_0.data[0], -float('inf'), allow_inf=True)
+ def test_multinomial_log_prob(self):
+ for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
+ p = Variable(tensor_type([0, 1]), requires_grad=True)
+ s = Variable(tensor_type([0, 10]))
+ multinomial = Multinomial(10, p)
+ log_pdf = multinomial.log_prob(s)
+ self.assertEqual(log_pdf.data[0], 0)
+
+ def test_multinomial_log_prob_with_logits(self):
+ for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
+ p = Variable(tensor_type([-float('inf'), 0]), requires_grad=True)
+ multinomial = Multinomial(10, logits=p)
+ log_pdf_prob_1 = multinomial.log_prob(Variable(tensor_type([0, 10])))
+ self.assertEqual(log_pdf_prob_1.data[0], 0)
+ log_pdf_prob_0 = multinomial.log_prob(Variable(tensor_type([10, 0])))
+ self.assertEqual(log_pdf_prob_0.data[0], -float('inf'), allow_inf=True)
+
if __name__ == '__main__':
run_tests()
diff --git a/torch/distributions/__init__.py b/torch/distributions/__init__.py
index 574f7c0..1496edc 100644
--- a/torch/distributions/__init__.py
+++ b/torch/distributions/__init__.py
@@ -42,6 +42,7 @@
from .gumbel import Gumbel
from .kl import kl_divergence, register_kl
from .laplace import Laplace
+from .multinomial import Multinomial
from .normal import Normal
from .one_hot_categorical import OneHotCategorical
from .pareto import Pareto
@@ -60,6 +61,7 @@
'Gamma',
'Gumbel',
'Laplace',
+ 'Multinomial',
'Normal',
'OneHotCategorical',
'Pareto',
diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py
index bd741ad..1e12c3b 100644
--- a/torch/distributions/categorical.py
+++ b/torch/distributions/categorical.py
@@ -7,10 +7,12 @@
class Categorical(Distribution):
r"""
- Creates a categorical distribution parameterized by `probs`.
+ Creates a categorical distribution parameterized by either `probs` or
+ `logits` (but not both).
.. note::
- It is equivalent to the distribution that ``multinomial()`` samples from.
+ It is equivalent to the distribution that :func:`torch.multinomial`
+ samples from.
Samples are integers from `0 ... K-1` where `K` is probs.size(-1).
@@ -30,6 +32,7 @@
Args:
probs (Tensor or Variable): event probabilities
+ logits (Tensor or Variable): event log probabilities
"""
params = {'probs': constraints.simplex}
has_enumerate_support = True
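As a quick sketch of the `probs`/`logits` parameterizations documented above (illustrative values; the two constructions below describe the same distribution, since `Categorical` normalizes either argument)::

    import torch
    from torch.distributions import Categorical

    d_probs = Categorical(probs=torch.Tensor([0.25, 0.25, 0.5]))
    d_logits = Categorical(logits=torch.Tensor([0.25, 0.25, 0.5]).log())

    x = d_probs.sample()      # integer in {0, 1, 2}
    d_logits.log_prob(x)      # same log probability under either parameterization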
diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py
index f8b1545..95ef5b7 100644
--- a/torch/distributions/constraints.py
+++ b/torch/distributions/constraints.py
@@ -10,6 +10,7 @@
'integer_interval',
'interval',
'is_dependent',
+ 'less_than',
'lower_triangular',
'nonnegative_integer',
'positive',
@@ -112,6 +113,17 @@
return self.lower_bound <= value
+class _LessThan(Constraint):
+ """
+ Constrain to a real half line `(-inf, upper_bound]`.
+ """
+ def __init__(self, upper_bound):
+ self.upper_bound = upper_bound
+
+ def check(self, value):
+ return value <= self.upper_bound
+
+
class _Interval(Constraint):
"""
Constrain to a real interval `[lower_bound, upper_bound]`.
@@ -150,6 +162,7 @@
real = _Real()
positive = _GreaterThan(0)
greater_than = _GreaterThan
+less_than = _LessThan
unit_interval = _Interval(0, 1)
interval = _Interval
simplex = _Simplex()
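For reference, a minimal usage sketch of the new `less_than` constraint factory (illustrative values; `check` is an elementwise comparison, so the exact tensor type of the result depends on the torch version)::

    import torch
    from torch.distributions import constraints

    upper_bounded = constraints.less_than(1.0)
    upper_bounded.check(torch.Tensor([0.5, 2.0]))   # elementwise [True, False]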
diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py
new file mode 100644
index 0000000..414edc0
--- /dev/null
+++ b/torch/distributions/multinomial.py
@@ -0,0 +1,85 @@
+import torch
+from torch.distributions.distribution import Distribution
+from torch.autograd import Variable
+from torch.distributions import Categorical
+from numbers import Number
+from torch.distributions import constraints
+from torch.distributions.utils import log_sum_exp, broadcast_all
+
+
+class Multinomial(Distribution):
+ r"""
+ Creates a Multinomial distribution parameterized by `total_count` and
+ either `probs` or `logits` (but not both). The innermost dimension of
+ `probs` indexes over categories. All other dimensions index over batches.
+
+ Note that `total_count` need not be specified if only :meth:`log_prob` is
+ called (see the example below).
+
+ - :meth:`sample` requires a single shared `total_count` for all
+ parameters and samples.
+ - :meth:`log_prob` allows different `total_count` for each parameter and
+ sample.
+
+ Example::
+
+ >>> m = Multinomial(100, torch.Tensor([1, 1, 1, 1]))
+ >>> x = m.sample() # equal probability of 0, 1, 2, 3
+ 21
+ 24
+ 30
+ 25
+ [torch.FloatTensor of size 4]
+
+ >>> Multinomial(probs=torch.Tensor([1, 1, 1, 1])).log_prob(x)
+ -4.1338
+ [torch.FloatTensor of size 1]
+
+ Args:
+ total_count (int): number of trials
+ probs (Tensor or Variable): event probabilities
+ logits (Tensor or Variable): event log probabilities
+ """
+ params = {'logits': constraints.real} # Let logits be the canonical parameterization.
+
+ def __init__(self, total_count=1, probs=None, logits=None):
+ if not isinstance(total_count, Number):
+ raise NotImplementedError('inhomogeneous total_count is not supported')
+ self.total_count = total_count
+ self._categorical = Categorical(probs=probs, logits=logits)
+ batch_shape = probs.size()[:-1] if probs is not None else logits.size()[:-1]
+ event_shape = probs.size()[-1:] if probs is not None else logits.size()[-1:]
+ super(Multinomial, self).__init__(batch_shape, event_shape)
+
+ @constraints.dependent_property
+ def support(self):
+ return constraints.integer_interval(0, self.total_count)
+
+ @property
+ def logits(self):
+ return self._categorical.logits
+
+ @property
+ def probs(self):
+ return self._categorical.probs
+
+ def sample(self, sample_shape=torch.Size()):
+ sample_shape = torch.Size(sample_shape)
+ samples = self._categorical.sample(torch.Size((self.total_count,)) + sample_shape)
+ # samples has shape (total_count, sample_shape, batch_shape); move the total_count
+ # dimension to the end to get (sample_shape, batch_shape, total_count)
+ shifted_idx = list(range(samples.dim()))
+ shifted_idx.append(shifted_idx.pop(0))
+ samples = samples.permute(*shifted_idx)
+ counts = samples.new(self._extended_shape(sample_shape)).zero_()
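+ # scatter_add_ accumulates a 1 into the drawn category index for each of the total_count trials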
+ counts.scatter_add_(-1, samples, torch.ones_like(samples))
+ return counts.type_as(self.probs)
+
+ def log_prob(self, value):
+ self._validate_log_prob_arg(value)
+ logits, value = broadcast_all(self.logits.clone(), value)
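+ # multinomial log-pmf: log n! - sum_i log(x_i!) + sum_i x_i * log(p_i), using lgamma(k + 1) == log(k!)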
+ log_factorial_n = torch.lgamma(value.sum(-1) + 1)
+ log_factorial_xs = torch.lgamma(value + 1).sum(-1)
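+ # where the count is zero, zero out -inf logits so that 0 * -inf below does not produce nan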
+ logits[(value == 0) & (logits == -float('inf'))] = 0
+ log_powers = (logits * value).sum(-1)
+ return log_factorial_n - log_factorial_xs + log_powers
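As a rough end-to-end sanity sketch (illustrative values and names, mirroring the `sample`/`log_prob` logic above): build counts with `scatter_add_` and compare `log_prob` against the closed-form multinomial log-pmf::

    import torch
    from torch.distributions import Multinomial

    total_count = 10
    probs = torch.Tensor([0.2, 0.3, 0.5])

    # Draw total_count category indices, then accumulate per-category counts.
    indices = torch.multinomial(probs, total_count, replacement=True)
    counts = torch.zeros(probs.size(0))
    counts.scatter_add_(0, indices, torch.ones(total_count))

    # Closed-form log-pmf: log n! - sum_i log(x_i!) + sum_i x_i * log(p_i)
    expected = (torch.lgamma(counts.sum() + 1)
                - torch.lgamma(counts + 1).sum()
                + (counts * probs.log()).sum())
    actual = Multinomial(total_count, probs).log_prob(counts)
    # actual and expected agree up to floating-point precision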