Feature/vonmises upstream (#33418)

Summary:
Third try of https://github.com/pytorch/pytorch/issues/33177 😄
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33418

Differential Revision: D20069683

Pulled By: ezyang

fbshipit-source-id: f58e45e91b672bfde2e41a4480215ba4c613f9de
diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst
index 35857ad..3ace6bd 100644
--- a/docs/source/distributions.rst
+++ b/docs/source/distributions.rst
@@ -311,6 +311,15 @@
     :undoc-members:
     :show-inheritance:
 
+:hidden:`VonMises`
+~~~~~~~~~~~~~~~~~~~~~~~
+
+.. currentmodule:: torch.distributions.von_mises
+.. autoclass:: VonMises
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    
 :hidden:`Weibull`
 ~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 15e31ff..3a86697 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -50,7 +50,7 @@
                                  NegativeBinomial, Normal, OneHotCategorical, Pareto,
                                  Poisson, RelaxedBernoulli, RelaxedOneHotCategorical,
                                  StudentT, TransformedDistribution, Uniform,
-                                 Weibull, constraints, kl_divergence)
+                                 VonMises, Weibull, constraints, kl_divergence)
 from torch.distributions.constraint_registry import biject_to, transform_to
 from torch.distributions.constraints import Constraint, is_dependent
 from torch.distributions.dirichlet import _Dirichlet_backward
@@ -443,7 +443,17 @@
                 loc=torch.randn(5, 2, requires_grad=True),
                 covariance_matrix=torch.tensor([[2.0, 0.3], [0.3, 0.25]], requires_grad=True)),
         },     
-    ]) 
+    ]),
+    Example(VonMises, [
+        {
+            'loc': torch.tensor(1.0, requires_grad=True),
+            'concentration': torch.tensor(10.0, requires_grad=True)
+        },
+        {
+            'loc': torch.tensor([0.0, math.pi / 2], requires_grad=True),
+            'concentration': torch.tensor([1.0, 10.0], requires_grad=True)
+        }
+    ])
 ]
 
 BAD_EXAMPLES = [
@@ -696,7 +706,7 @@
             asset_fn(i, val.squeeze(), log_prob)
 
     def _check_sampler_sampler(self, torch_dist, ref_dist, message, multivariate=False,
-                               num_samples=10000, failure_rate=1e-3):
+                               circular=False, num_samples=10000, failure_rate=1e-3):
         # Checks that the .sample() method matches a reference function.
         torch_samples = torch_dist.sample((num_samples,)).squeeze()
         torch_samples = torch_samples.cpu().numpy()
@@ -708,6 +718,8 @@
             torch_samples = np.dot(torch_samples, axis)
             ref_samples = np.dot(ref_samples, axis)
         samples = [(x, +1) for x in torch_samples] + [(x, -1) for x in ref_samples]
+        if circular:
+            samples = [(np.cos(x), v) for (x, v) in samples]
         shuffle(samples)  # necessary to prevent stable sort from making uneven bins for discrete
         samples.sort(key=lambda x: x[0])
         samples = np.array(samples)[:, 1]
@@ -1361,6 +1373,23 @@
         low.grad.zero_()
         high.grad.zero_()
 
+    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+    def test_vonmises_sample(self):
+        for loc in [0.0, math.pi / 2.0]:
+            for concentration in [0.03, 0.3, 1.0, 10.0, 100.0]:
+                self._check_sampler_sampler(VonMises(loc, concentration),
+                                            scipy.stats.vonmises(loc=loc, kappa=concentration),
+                                            "VonMises(loc={}, concentration={})".format(loc, concentration),
+                                            num_samples=int(1e5), circular=True)
+
+    def test_vonmises_logprob(self):
+        concentrations = [0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0, 30.0, 100.0]
+        for concentration in concentrations:
+            grid = torch.arange(0., 2 * math.pi, 1e-4)
+            prob = VonMises(0.0, concentration).log_prob(grid).exp()
+            norm = prob.mean().item() * 2 * math.pi
+            self.assertLess(abs(norm - 1), 1e-3)
+
     def test_cauchy(self):
         loc = torch.zeros(5, 5, requires_grad=True)
         scale = torch.ones(5, 5, requires_grad=True)
@@ -3132,6 +3161,27 @@
         self.assertEqual(gumbel.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
         self.assertEqual(gumbel.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
 
+    def test_vonmises_shape_tensor_params(self):
+        von_mises = VonMises(torch.tensor([0., 0.]), torch.tensor([1., 1.]))
+        self.assertEqual(von_mises._batch_shape, torch.Size((2,)))
+        self.assertEqual(von_mises._event_shape, torch.Size(()))
+        self.assertEqual(von_mises.sample().size(), torch.Size((2,)))
+        self.assertEqual(von_mises.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2)))
+        self.assertEqual(von_mises.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
+        self.assertEqual(von_mises.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
+
+    def test_vonmises_shape_scalar_params(self):
+        von_mises = VonMises(0., 1.)
+        self.assertEqual(von_mises._batch_shape, torch.Size())
+        self.assertEqual(von_mises._event_shape, torch.Size())
+        self.assertEqual(von_mises.sample().size(), torch.Size())
+        self.assertEqual(von_mises.sample(torch.Size((3, 2))).size(),
+                         torch.Size((3, 2)))
+        self.assertEqual(von_mises.log_prob(self.tensor_sample_1).size(),
+                         torch.Size((3, 2)))
+        self.assertEqual(von_mises.log_prob(self.tensor_sample_2).size(),
+                         torch.Size((3, 2, 3)))
+
     def test_weibull_scale_scalar_params(self):
         weibull = Weibull(1, 1)
         self.assertEqual(weibull._batch_shape, torch.Size())
@@ -3883,6 +3933,10 @@
                 scipy.stats.uniform(random_var, positive_var)
             ),
             (
+                VonMises(random_var, positive_var),
+                scipy.stats.vonmises(positive_var, loc=random_var)
+            ),
+            (
                 Weibull(positive_var[0], positive_var2[0]),  # scipy var for Weibull only supports scalars
                 scipy.stats.weibull_min(c=positive_var2[0], scale=positive_var[0])
             )
@@ -3900,8 +3954,9 @@
 
     def test_variance_stddev(self):
         for pytorch_dist, scipy_dist in self.distribution_pairs:
-            if isinstance(pytorch_dist, (Cauchy, HalfCauchy)):
+            if isinstance(pytorch_dist, (Cauchy, HalfCauchy, VonMises)):
                 # Cauchy, HalfCauchy distributions' standard deviation is nan, skipping check
+                # VonMises variance is circular and scipy doesn't produce a correct result
                 continue
             elif isinstance(pytorch_dist, (Multinomial, OneHotCategorical)):
                 self.assertEqual(pytorch_dist.variance, np.diag(scipy_dist.cov()), message=pytorch_dist)
@@ -4233,9 +4288,9 @@
 
 class TestFunctors(TestCase):
     def test_cat_transform(self):
-        x1 = -1 * torch.range(1, 100).view(-1, 100)
-        x2 = (torch.range(1, 100).view(-1, 100) - 1) / 100
-        x3 = torch.range(1, 100).view(-1, 100)
+        x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100)
+        x2 = (torch.arange(1, 101, dtype=torch.float).view(-1, 100) - 1) / 100
+        x3 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
         t1, t2, t3 = ExpTransform(), AffineTransform(1, 100), identity_transform
         dim = 0
         x = torch.cat([x1, x2, x3], dim=dim)
@@ -4248,9 +4303,9 @@
         actual = t(x)
         expected = torch.cat([t1(x1), t2(x2), t3(x3)], dim=dim)
         self.assertEqual(expected, actual)
-        y1 = torch.range(1, 100).view(-1, 100)
-        y2 = torch.range(1, 100).view(-1, 100)
-        y3 = torch.range(1, 100).view(-1, 100)
+        y1 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
+        y2 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
+        y3 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
         y = torch.cat([y1, y2, y3], dim=dim)
         actual_cod_check = t.codomain.check(y)
         expected_cod_check = torch.cat([t1.codomain.check(y1),
@@ -4267,9 +4322,9 @@
         self.assertEqual(actual_jac, expected_jac)
 
     def test_cat_transform_non_uniform(self):
-        x1 = -1 * torch.range(1, 100).view(-1, 100)
-        x2 = torch.cat([(torch.range(1, 100).view(-1, 100) - 1) / 100,
-                        torch.range(1, 100).view(-1, 100)])
+        x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100)
+        x2 = torch.cat([(torch.arange(1, 101, dtype=torch.float).view(-1, 100) - 1) / 100,
+                        torch.arange(1, 101, dtype=torch.float).view(-1, 100)])
         t1 = ExpTransform()
         t2 = CatTransform([AffineTransform(1, 100), identity_transform], dim=0)
         dim = 0
@@ -4282,9 +4337,9 @@
         actual = t(x)
         expected = torch.cat([t1(x1), t2(x2)], dim=dim)
         self.assertEqual(expected, actual)
-        y1 = torch.range(1, 100).view(-1, 100)
-        y2 = torch.cat([torch.range(1, 100).view(-1, 100),
-                        torch.range(1, 100).view(-1, 100)])
+        y1 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
+        y2 = torch.cat([torch.arange(1, 101, dtype=torch.float).view(-1, 100),
+                        torch.arange(1, 101, dtype=torch.float).view(-1, 100)])
         y = torch.cat([y1, y2], dim=dim)
         actual_cod_check = t.codomain.check(y)
         expected_cod_check = torch.cat([t1.codomain.check(y1),
@@ -4299,9 +4354,9 @@
         self.assertEqual(actual_jac, expected_jac)
 
     def test_stack_transform(self):
-        x1 = -1 * torch.range(1, 100)
-        x2 = (torch.range(1, 100) - 1) / 100
-        x3 = torch.range(1, 100)
+        x1 = -1 * torch.arange(1, 101, dtype=torch.float)
+        x2 = (torch.arange(1, 101, dtype=torch.float) - 1) / 100
+        x3 = torch.arange(1, 101, dtype=torch.float)
         t1, t2, t3 = ExpTransform(), AffineTransform(1, 100), identity_transform
         dim = 0
         x = torch.stack([x1, x2, x3], dim=dim)
@@ -4314,9 +4369,9 @@
         actual = t(x)
         expected = torch.stack([t1(x1), t2(x2), t3(x3)], dim=dim)
         self.assertEqual(expected, actual)
-        y1 = torch.range(1, 100)
-        y2 = torch.range(1, 100)
-        y3 = torch.range(1, 100)
+        y1 = torch.arange(1, 101, dtype=torch.float)
+        y2 = torch.arange(1, 101, dtype=torch.float)
+        y3 = torch.arange(1, 101, dtype=torch.float)
         y = torch.stack([y1, y2, y3], dim=dim)
         actual_cod_check = t.codomain.check(y)
         expected_cod_check = torch.stack([t1.codomain.check(y1),
@@ -4509,6 +4564,7 @@
             xfail = [
                 Cauchy,  # aten::cauchy(Double(2,1), float, float, Generator)
                 HalfCauchy,  # aten::cauchy(Double(2, 1), float, float, Generator)
+                VonMises  # Variance is not Euclidean
             ]
             if Dist in xfail:
                 continue
diff --git a/torch/distributions/__init__.py b/torch/distributions/__init__.py
index 9c71a96..8a16ec1 100644
--- a/torch/distributions/__init__.py
+++ b/torch/distributions/__init__.py
@@ -108,6 +108,7 @@
 from .transformed_distribution import TransformedDistribution
 from .transforms import *
 from .uniform import Uniform
+from .von_mises import VonMises
 from .weibull import Weibull
 
 __all__ = [
@@ -144,6 +145,7 @@
     'StudentT',
     'Poisson',
     'Uniform',
+    'VonMises',
     'Weibull',
     'TransformedDistribution',
     'biject_to',
diff --git a/torch/distributions/von_mises.py b/torch/distributions/von_mises.py
new file mode 100644
index 0000000..8240d6d
--- /dev/null
+++ b/torch/distributions/von_mises.py
@@ -0,0 +1,140 @@
+from __future__ import absolute_import, division, print_function

+

+import math

+

+import torch

+import torch.jit

+from torch.distributions import constraints

+from torch.distributions.distribution import Distribution

+from torch.distributions.utils import broadcast_all, lazy_property

+

+

+def _eval_poly(y, coef):

+    coef = list(coef)

+    result = coef.pop()

+    while coef:

+        result = coef.pop() + y * result

+    return result

+

+

+_I0_COEF_SMALL = [1.0, 3.5156229, 3.0899424, 1.2067492, 0.2659732, 0.360768e-1, 0.45813e-2]

+_I0_COEF_LARGE = [0.39894228, 0.1328592e-1, 0.225319e-2, -0.157565e-2, 0.916281e-2,

+                  -0.2057706e-1, 0.2635537e-1, -0.1647633e-1, 0.392377e-2]

+_I1_COEF_SMALL = [0.5, 0.87890594, 0.51498869, 0.15084934, 0.2658733e-1, 0.301532e-2, 0.32411e-3]

+_I1_COEF_LARGE = [0.39894228, -0.3988024e-1, -0.362018e-2, 0.163801e-2, -0.1031555e-1,

+                  0.2282967e-1, -0.2895312e-1, 0.1787654e-1, -0.420059e-2]

+

+_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL]

+_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE]

+

+

+def _log_modified_bessel_fn(x, order=0):

+    """

+    Returns ``log(I_order(x))`` for ``x > 0``,

+    where `order` is either 0 or 1.

+    """

+    assert order == 0 or order == 1

+

+    # compute small solution

+    y = (x / 3.75)

+    y = y * y

+    small = _eval_poly(y, _COEF_SMALL[order])

+    if order == 1:

+        small = x.abs() * small

+    small = small.log()

+

+    # compute large solution

+    y = 3.75 / x

+    large = x - 0.5 * x.log() + _eval_poly(y, _COEF_LARGE[order]).log()

+

+    result = torch.where(x < 3.75, small, large)

+    return result

+

+

+@torch.jit.script

+def _rejection_sample(loc, concentration, proposal_r, x):

+    done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)

+    while not done.all():

+        u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)

+        u1, u2, u3 = u.unbind()

+        z = torch.cos(math.pi * u1)

+        f = (1 + proposal_r * z) / (proposal_r + z)

+        c = concentration * (proposal_r - f)

+        accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)

+        if accept.any():

+            x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)

+            done = done | accept

+    return (x + math.pi + loc) % (2 * math.pi) - math.pi

+

+

+class VonMises(Distribution):

+    """

+    A circular von Mises distribution.

+

+    This implementation uses polar coordinates. The ``loc`` and ``value`` args

+    can be any real number (to facilitate unconstrained optimization), but are

+    interpreted as angles modulo 2 pi.

+

+    Example::

+        >>> m = dist.VonMises(torch.tensor([1.0]), torch.tensor([1.0]))

+        >>> m.sample() # von Mises distributed with loc=1 and concentration=1

+        tensor([1.9777])

+

+    :param torch.Tensor loc: an angle in radians.

+    :param torch.Tensor concentration: concentration parameter

+    """

+    arg_constraints = {'loc': constraints.real, 'concentration': constraints.positive}

+    support = constraints.real

+    has_rsample = False

+

+    def __init__(self, loc, concentration, validate_args=None):

+        self.loc, self.concentration = broadcast_all(loc, concentration)

+        batch_shape = self.loc.shape

+        event_shape = torch.Size()

+

+        # Parameters for sampling

+        tau = 1 + (1 + 4 * self.concentration ** 2).sqrt()

+        rho = (tau - (2 * tau).sqrt()) / (2 * self.concentration)

+        self._proposal_r = (1 + rho ** 2) / (2 * rho)

+

+        super(VonMises, self).__init__(batch_shape, event_shape, validate_args)

+

+    def log_prob(self, value):

+        log_prob = self.concentration * torch.cos(value - self.loc)

+        log_prob = log_prob - math.log(2 * math.pi) - _log_modified_bessel_fn(self.concentration, order=0)

+        return log_prob

+

+    @torch.no_grad()

+    def sample(self, sample_shape=torch.Size()):

+        """

+        The sampling algorithm for the von Mises distribution is based on the following paper:

+        Best, D. J., and Nicholas I. Fisher.

+        "Efficient simulation of the von Mises distribution." Applied Statistics (1979): 152-157.

+        """

+        shape = self._extended_shape(sample_shape)

+        x = torch.empty(shape, dtype=self.loc.dtype, device=self.loc.device)

+        return _rejection_sample(self.loc, self.concentration, self._proposal_r, x)

+

+    def expand(self, batch_shape):

+        try:

+            return super(VonMises, self).expand(batch_shape)

+        except NotImplementedError:

+            validate_args = self.__dict__.get('_validate_args')

+            loc = self.loc.expand(batch_shape)

+            concentration = self.concentration.expand(batch_shape)

+            return type(self)(loc, concentration, validate_args=validate_args)

+

+    @property

+    def mean(self):

+        """

+        The provided mean is the circular one.

+        """

+        return self.loc

+

+    @lazy_property

+    def variance(self):

+        """

+        The provided variance is the circular one.

+        """

+        return 1 - (_log_modified_bessel_fn(self.concentration, order=1) -

+                    _log_modified_bessel_fn(self.concentration, order=0)).exp()