Implement Multinomial distribution (#4624)

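A rough usage sketch of the new distribution (parameter values and the
sampled counts described in the comments are illustrative only):

    import torch
    from torch.distributions import Multinomial

    probs = torch.Tensor([0.2, 0.3, 0.5])
    m = Multinomial(total_count=10, probs=probs)
    counts = m.sample()         # counts over the 3 categories, summing to 10
    log_p = m.log_prob(counts)  # log-probability of those counts

As noted in the docstring, `total_count` may be omitted when only
`log_prob` is used.
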
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 956e5d6..273b884 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -31,8 +31,8 @@
 from common import TestCase, run_tests, set_rng_seed
 from torch.autograd import Variable, grad, gradcheck
 from torch.distributions import (Bernoulli, Beta, Categorical, Cauchy, Chi2,
-                                 Dirichlet, Exponential, Gamma, Gumbel,
-                                 Laplace, Normal, OneHotCategorical, Pareto,
+                                 Dirichlet, Exponential, Gamma, Gumbel, Laplace,
+                                 Normal, OneHotCategorical, Multinomial, Pareto,
                                  StudentT, Uniform, kl_divergence)
 from torch.distributions.dirichlet import _Dirichlet_backward
 from torch.distributions.constraints import Constraint, is_dependent
@@ -69,6 +69,10 @@
         {'probs': Variable(torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]), requires_grad=True)},
         {'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]), requires_grad=True)},
     ]),
+    Example(Multinomial, [
+        {'probs': Variable(torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]), requires_grad=True), 'total_count': 10},
+        {'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]), requires_grad=True), 'total_count': 10},
+    ]),
     Example(Cauchy, [
         {'loc': 0.0, 'scale': 1.0},
         {'loc': Variable(torch.Tensor([0.0])), 'scale': 1.0},
@@ -294,6 +298,53 @@
                          (2, 5, 2, 3, 5))
         self.assertEqual(Bernoulli(p).sample_n(2).size(), (2, 2, 3, 5))
 
+    def test_multinomial_1d(self):
+        total_count = 10
+        p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
+        self.assertEqual(Multinomial(total_count, p).sample().size(), (3,))
+        self.assertEqual(Multinomial(total_count, p).sample((2, 2)).size(), (2, 2, 3))
+        self.assertEqual(Multinomial(total_count, p).sample_n(1).size(), (1, 3))
+        self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
+        self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])
+        self.assertRaises(NotImplementedError, Multinomial(10, p).rsample)
+
+    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+    def test_multinomial_1d_log_prob(self):
+        total_count = 10
+        p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
+        dist = Multinomial(total_count, probs=p)
+        x = dist.sample()
+        log_prob = dist.log_prob(x)
+        expected = torch.Tensor(scipy.stats.multinomial.logpmf(x.numpy(), n=total_count, p=dist.probs.detach().numpy()))
+        self.assertEqual(log_prob.data, expected)
+
+        dist = Multinomial(total_count, logits=p.log())
+        x = dist.sample()
+        log_prob = dist.log_prob(x)
+        expected = torch.Tensor(scipy.stats.multinomial.logpmf(x.numpy(), n=total_count, p=dist.probs.detach().numpy()))
+        self.assertEqual(log_prob.data, expected)
+
+    def test_multinomial_2d(self):
+        total_count = 10
+        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
+        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
+        p = Variable(torch.Tensor(probabilities), requires_grad=True)
+        s = Variable(torch.Tensor(probabilities_1), requires_grad=True)
+        self.assertEqual(Multinomial(total_count, p).sample().size(), (2, 3))
+        self.assertEqual(Multinomial(total_count, p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
+        self.assertEqual(Multinomial(total_count, p).sample_n(6).size(), (6, 2, 3))
+        set_rng_seed(0)
+        self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
+        p.grad.zero_()
+        self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])
+
+        # check sampling with extreme (one-hot) values of probs
+        self.assertEqual(Multinomial(total_count, s).sample().data,
+                         torch.Tensor([[total_count, 0], [0, total_count]]))
+
+        # entropy is not implemented for Multinomial yet
+        self.assertRaises(NotImplementedError, Multinomial(10, p).entropy)
+
     def test_categorical_1d(self):
         p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
         # TODO: this should return a 0-dim tensor once we have Scalar support
@@ -1096,13 +1147,16 @@
         for Dist, params in EXAMPLES:
             for i, param in enumerate(params):
                 dist = Dist(**param)
-                actual_shape = dist.entropy().size()
-                expected_shape = dist._batch_shape
-                if not expected_shape:
-                    expected_shape = torch.Size((1,))  # TODO Remove this once scalars are supported.
-                message = '{} example {}/{}, shape mismatch. expected {}, actual {}'.format(
-                    Dist.__name__, i, len(params), expected_shape, actual_shape)
-                self.assertEqual(actual_shape, expected_shape, message=message)
+                try:
+                    actual_shape = dist.entropy().size()
+                    expected_shape = dist._batch_shape
+                    if not expected_shape:
+                        expected_shape = torch.Size((1,))  # TODO Remove this once scalars are supported.
+                    message = '{} example {}/{}, shape mismatch. expected {}, actual {}'.format(
+                        Dist.__name__, i, len(params), expected_shape, actual_shape)
+                    self.assertEqual(actual_shape, expected_shape, message=message)
+                except NotImplementedError:
+                    continue
 
     def test_bernoulli_shape_scalar_params(self):
         bernoulli = Bernoulli(0.3)
@@ -1145,6 +1199,16 @@
         self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
         self.assertEqual(dist.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2)))
 
+    def test_multinomial_shape(self):
+        dist = Multinomial(10, torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
+        self.assertEqual(dist._batch_shape, torch.Size((3,)))
+        self.assertEqual(dist._event_shape, torch.Size((2,)))
+        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
+        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
+        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
+        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
+        self.assertEqual(dist.log_prob(torch.ones(3, 1, 2)).size(), torch.Size((3, 3)))
+
     def test_categorical_shape(self):
         dist = Categorical(torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
         self.assertEqual(dist._batch_shape, torch.Size((3,)))
@@ -1375,11 +1439,14 @@
                 for name, value in param.items():
                     if not (torch.is_tensor(value) or isinstance(value, Variable)):
                         value = torch.Tensor([value])
-                    if Dist in (Categorical, OneHotCategorical) and name == 'probs':
+                    if Dist in (Categorical, OneHotCategorical, Multinomial) and name == 'probs':
                         # These distributions accept positive probs, but elsewhere we
                         # use a stricter constraint to the simplex.
                         value = value / value.sum(-1, True)
-                    constraint = dist.params[name]
+                    try:
+                        constraint = dist.params[name]
+                    except KeyError:
+                        continue  # ignore optional parameters
                     if is_dependent(constraint):
                         continue
                     message = '{} example {}/{} parameter {} = {}'.format(
@@ -1499,6 +1566,23 @@
             log_pdf_prob_0 = categorical.log_prob(Variable(tensor_type([1, 0])))
             self.assertEqual(log_pdf_prob_0.data[0], -float('inf'), allow_inf=True)
 
+    def test_multinomial_log_prob(self):
+        for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
+            p = Variable(tensor_type([0, 1]), requires_grad=True)
+            s = Variable(tensor_type([0, 10]))
+            multinomial = Multinomial(10, p)
+            log_pdf = multinomial.log_prob(s)
+            self.assertEqual(log_pdf.data[0], 0)
+
+    def test_multinomial_log_prob_with_logits(self):
+        for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
+            p = Variable(tensor_type([-float('inf'), 0]), requires_grad=True)
+            multinomial = Multinomial(10, logits=p)
+            log_pdf_prob_1 = multinomial.log_prob(Variable(tensor_type([0, 10])))
+            self.assertEqual(log_pdf_prob_1.data[0], 0)
+            log_pdf_prob_0 = multinomial.log_prob(Variable(tensor_type([10, 0])))
+            self.assertEqual(log_pdf_prob_0.data[0], -float('inf'), allow_inf=True)
+
 
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/distributions/__init__.py b/torch/distributions/__init__.py
index 574f7c0..1496edc 100644
--- a/torch/distributions/__init__.py
+++ b/torch/distributions/__init__.py
@@ -42,6 +42,7 @@
 from .gumbel import Gumbel
 from .kl import kl_divergence, register_kl
 from .laplace import Laplace
+from .multinomial import Multinomial
 from .normal import Normal
 from .one_hot_categorical import OneHotCategorical
 from .pareto import Pareto
@@ -60,6 +61,7 @@
     'Gamma',
     'Gumbel',
     'Laplace',
+    'Multinomial',
     'Normal',
     'OneHotCategorical',
     'Pareto',
diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py
index bd741ad..1e12c3b 100644
--- a/torch/distributions/categorical.py
+++ b/torch/distributions/categorical.py
@@ -7,10 +7,12 @@
 
 class Categorical(Distribution):
     r"""
-    Creates a categorical distribution parameterized by `probs`.
+    Creates a categorical distribution parameterized by either `probs` or
+    `logits` (but not both).
 
     .. note::
-        It is equivalent to the distribution that ``multinomial()`` samples from.
+        It is equivalent to the distribution that :func:`torch.multinomial`
+        samples from.
 
     Samples are integers from `0 ... K-1` where `K` is probs.size(-1).
 
@@ -30,6 +32,7 @@
 
     Args:
         probs (Tensor or Variable): event probabilities
+        logits (Tensor or Variable): event log probabilities
     """
     params = {'probs': constraints.simplex}
     has_enumerate_support = True
diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py
index f8b1545..95ef5b7 100644
--- a/torch/distributions/constraints.py
+++ b/torch/distributions/constraints.py
@@ -10,6 +10,7 @@
     'integer_interval',
     'interval',
     'is_dependent',
+    'less_than',
     'lower_triangular',
     'nonnegative_integer',
     'positive',
@@ -112,6 +113,17 @@
         return self.lower_bound <= value
 
 
+class _LessThan(Constraint):
+    """
+    Constrain to the real half line `(-inf, upper_bound]`.
+    """
+    def __init__(self, upper_bound):
+        self.upper_bound = upper_bound
+
+    def check(self, value):
+        return value <= self.upper_bound
+
+
 class _Interval(Constraint):
     """
     Constrain to a real interval `[lower_bound, upper_bound]`.
@@ -150,6 +162,7 @@
 real = _Real()
 positive = _GreaterThan(0)
 greater_than = _GreaterThan
+less_than = _LessThan
 unit_interval = _Interval(0, 1)
 interval = _Interval
 simplex = _Simplex()
diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py
new file mode 100644
index 0000000..414edc0
--- /dev/null
+++ b/torch/distributions/multinomial.py
@@ -0,0 +1,85 @@
+import torch
+from torch.distributions.distribution import Distribution
+from torch.autograd import Variable
+from torch.distributions import Categorical
+from numbers import Number
+from torch.distributions import constraints
+from torch.distributions.utils import log_sum_exp, broadcast_all
+
+
+class Multinomial(Distribution):
+    r"""
+    Creates a Multinomial distribution parameterized by `total_count` and
+    either `probs` or `logits` (but not both). The innermost dimension of
+    `probs` indexes over categories. All other dimensions index over batches.
+
+    Note that `total_count` need not be specified if only :meth:`log_prob` is
+    called (see example below).
+
+    -   :meth:`sample` requires a single shared `total_count` for all
+        parameters and samples.
+    -   :meth:`log_prob` allows different `total_count` for each parameter and
+        sample.
+
+    Example::
+
+        >>> m = Multinomial(100, torch.Tensor([ 1, 1, 1, 1]))
+        >>> x = m.sample()  # equal probability of 0, 1, 2, 3
+         21
+         24
+         30
+         25
+        [torch.FloatTensor of size 4]
+
+        >>> Multinomial(probs=torch.Tensor([1, 1, 1, 1])).log_prob(x)
+        -4.1338
+        [torch.FloatTensor of size 1]
+
+    Args:
+        total_count (int): number of trials
+        probs (Tensor or Variable): event probabilities
+        logits (Tensor or Variable): event log probabilities
+    """
+    params = {'logits': constraints.real}  # Let logits be the canonical parameterization.
+
+    def __init__(self, total_count=1, probs=None, logits=None):
+        if not isinstance(total_count, Number):
+            raise NotImplementedError('inhomogeneous total_count is not supported')
+        self.total_count = total_count
+        self._categorical = Categorical(probs=probs, logits=logits)
+        batch_shape = probs.size()[:-1] if probs is not None else logits.size()[:-1]
+        event_shape = probs.size()[-1:] if probs is not None else logits.size()[-1:]
+        super(Multinomial, self).__init__(batch_shape, event_shape)
+
+    @constraints.dependent_property
+    def support(self):
+        return constraints.integer_interval(0, self.total_count)
+
+    @property
+    def logits(self):
+        return self._categorical.logits
+
+    @property
+    def probs(self):
+        return self._categorical.probs
+
+    def sample(self, sample_shape=torch.Size()):
+        sample_shape = torch.Size(sample_shape)
+        samples = self._categorical.sample(torch.Size((self.total_count,)) + sample_shape)
+        # samples have shape (total_count, sample_shape, batch_shape); move the
+        # total_count dimension to the end so draws can be counted along it
+        shifted_idx = list(range(samples.dim()))
+        shifted_idx.append(shifted_idx.pop(0))
+        samples = samples.permute(*shifted_idx)
+        counts = samples.new(self._extended_shape(sample_shape)).zero_()
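+        # accumulate one count per draw at the category index given by each sample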
+        counts.scatter_add_(-1, samples, torch.ones_like(samples))
+        return counts.type_as(self.probs)
+
+    def log_prob(self, value):
+        self._validate_log_prob_arg(value)
+        logits, value = broadcast_all(self.logits.clone(), value)
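+        # log P(x) = log(n!) - sum_i log(x_i!) + sum_i x_i * log(p_i)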
+        log_factorial_n = torch.lgamma(value.sum(-1) + 1)
+        log_factorial_xs = torch.lgamma(value + 1).sum(-1)
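+        # zero out -inf logits where the count is 0 so that 0 * -inf does not produce nan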
+        logits[(value == 0) & (logits == -float('inf'))] = 0
+        log_powers = (logits * value).sum(-1)
+        return log_factorial_n - log_factorial_xs + log_powers