Declare constraints for distribution parameters and support (#4450)
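
This adds a torch.distributions.constraints module and declares, on each
distribution, a `params` dict mapping parameter names to Constraint objects,
plus a `support` constraint describing valid sample values. A rough sketch of
the intended validation pattern (illustrative only, not part of this patch):

    import torch
    from torch.distributions import Normal
    from torch.distributions.constraints import is_dependent

    d = Normal(torch.Tensor([0.0]), torch.Tensor([1.0]))
    for name, constraint in d.params.items():
        if is_dependent(constraint):
            continue  # e.g. Uniform's low/high cannot be checked in isolation
        assert constraint.check(getattr(d, name)).all()
    assert d.support.check(d.sample()).all()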

diff --git a/test/test_distributions.py b/test/test_distributions.py
index 543d59c..f7860c1 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -33,6 +33,7 @@
 from torch.distributions import (Bernoulli, Beta, Categorical, Cauchy, Chi2,
                                  Dirichlet, Exponential, Gamma, Laplace,
                                  Normal, OneHotCategorical, Pareto, Uniform)
+from torch.distributions.constraints import Constraint, is_dependent
 
 TEST_NUMPY = True
 try:
@@ -104,7 +105,7 @@
         },
         {
             'loc': Variable(torch.randn(1), requires_grad=True),
-            'scale': Variable(torch.randn(1), requires_grad=True),
+            'scale': Variable(torch.randn(1).abs(), requires_grad=True),
         },
         {
             'loc': torch.Tensor([1.0, 0.0]),
@@ -118,7 +119,7 @@
         },
         {
             'mean': Variable(torch.randn(1), requires_grad=True),
-            'std': Variable(torch.randn(1), requires_grad=True),
+            'std': Variable(torch.randn(1).abs(), requires_grad=True),
         },
         {
             'mean': torch.Tensor([1.0, 0.0]),
@@ -1055,5 +1056,36 @@
         self.assertRaises(ValueError, laplace.log_prob, self.tensor_sample_2)
 
 
+class TestConstraints(TestCase):
+    def test_params_contains(self):
+        for Dist, params in EXAMPLES:
+            for i, param in enumerate(params):
+                dist = Dist(**param)
+                for name, value in param.items():
+                    if not (torch.is_tensor(value) or isinstance(value, Variable)):
+                        value = torch.Tensor([value])
+                    if Dist in (Categorical, OneHotCategorical) and name == 'probs':
+                        # These distributions accept positive probs, but elsewhere we
+                        # use a stricter constraint to the simplex.
+                        value = value / value.sum(-1, True)
+                    constraint = dist.params[name]
+                    if is_dependent(constraint):
+                        continue
+                    message = '{} example {}/{} parameter {} = {}'.format(
+                        Dist.__name__, i + 1, len(params), name, value)
+                    self.assertTrue(constraint.check(value).all(), msg=message)
+
+    def test_support_contains(self):
+        for Dist, params in EXAMPLES:
+            self.assertIsInstance(Dist.support, Constraint)
+            for i, param in enumerate(params):
+                dist = Dist(**param)
+                value = dist.sample()
+                constraint = dist.support
+                message = '{} example {}/{} sample = {}'.format(
+                    Dist.__name__, i + 1, len(params), value)
+                self.assertTrue(constraint.check(value).all(), msg=message)
+
+
 if __name__ == '__main__':
     run_tests()
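
Note: the `is_dependent` guard in test_params_contains matters for
distributions like Uniform, whose parameter constraints cannot be checked in
isolation. A small sketch of what the test skips (illustrative only):

    import torch
    from torch.distributions import Uniform
    from torch.distributions.constraints import is_dependent

    d = Uniform(torch.Tensor([0.0]), torch.Tensor([1.0]))
    assert is_dependent(d.params['low'])  # skipped by test_params_contains
    # The support is still checkable: dependent_property resolves to a
    # concrete interval(low, high) on the instance.
    assert d.support.check(d.sample()).all()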
diff --git a/torch/distributions/bernoulli.py b/torch/distributions/bernoulli.py
index 25765e8..9aa53d9 100644
--- a/torch/distributions/bernoulli.py
+++ b/torch/distributions/bernoulli.py
@@ -2,6 +2,7 @@
 
 import torch
 from torch.autograd import Variable
+from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import broadcast_all
 
@@ -23,6 +24,8 @@
     Args:
         probs (Number, Tensor or Variable): the probability of sampling `1`
     """
+    params = {'probs': constraints.unit_interval}
+    support = constraints.boolean
     has_enumerate_support = True
 
     def __init__(self, probs):
diff --git a/torch/distributions/beta.py b/torch/distributions/beta.py
index 7953a8f..12b266d 100644
--- a/torch/distributions/beta.py
+++ b/torch/distributions/beta.py
@@ -2,6 +2,7 @@
 
 import torch
 from torch.autograd import Variable
+from torch.distributions import constraints
 from torch.distributions.dirichlet import Dirichlet
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import broadcast_all
@@ -21,6 +22,8 @@
     Args:
         alpha (Tensor or Variable): concentration parameter of the distribution
     """
+    params = {'alpha': constraints.positive, 'beta': constraints.positive}
+    support = constraints.unit_interval
     has_rsample = True
 
     def __init__(self, alpha, beta):
diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py
index dfaa2b4..4d0d7ae 100644
--- a/torch/distributions/categorical.py
+++ b/torch/distributions/categorical.py
@@ -1,5 +1,6 @@
 import torch
 from torch.autograd import Variable
+from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 
 
@@ -29,6 +30,7 @@
     Args:
         probs (Tensor or Variable): event probabilities
     """
+    params = {'probs': constraints.simplex}
     has_enumerate_support = True
 
     def __init__(self, probs):
@@ -36,6 +38,10 @@
         batch_shape = self.probs.size()[:-1]
         super(Categorical, self).__init__(batch_shape)
 
+    @constraints.dependent_property
+    def support(self):
+        return constraints.integer_interval(0, self.probs.size()[-1] - 1)
+
     def sample(self, sample_shape=torch.Size()):
         num_events = self.probs.size()[-1]
         sample_shape = self._extended_shape(sample_shape)
diff --git a/torch/distributions/cauchy.py b/torch/distributions/cauchy.py
index 15e9ae4..26f9068 100644
--- a/torch/distributions/cauchy.py
+++ b/torch/distributions/cauchy.py
@@ -1,8 +1,8 @@
+import math
 from numbers import Number
 
-import math
-
 import torch
+from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import broadcast_all
 
@@ -24,6 +24,8 @@
         loc (float or Tensor or Variable): mode or median of the distribution.
         scale (float or Tensor or Variable): half width at half maximum.
     """
+    params = {'loc': constraints.real, 'scale': constraints.positive}
+    support = constraints.real
     has_rsample = True
 
     def __init__(self, loc, scale):
diff --git a/torch/distributions/chi2.py b/torch/distributions/chi2.py
index 0e68086..66a1142 100644
--- a/torch/distributions/chi2.py
+++ b/torch/distributions/chi2.py
@@ -1,4 +1,5 @@
-from torch.distributions import Gamma
+from torch.distributions import constraints
+from torch.distributions.gamma import Gamma
 
 
 class Chi2(Gamma):
@@ -16,6 +17,7 @@
     Args:
         df (float or Tensor or Variable): shape parameter of the distribution
     """
+    params = {'df': constraints.positive}
 
     def __init__(self, df):
         super(Chi2, self).__init__(0.5 * df, 0.5)
diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py
new file mode 100644
index 0000000..f8b1545
--- /dev/null
+++ b/torch/distributions/constraints.py
@@ -0,0 +1,156 @@
+import torch
+
+
+__all__ = [
+    'Constraint',
+    'boolean',
+    'dependent',
+    'dependent_property',
+    'greater_than',
+    'integer_interval',
+    'interval',
+    'is_dependent',
+    'lower_triangular',
+    'nonnegative_integer',
+    'positive',
+    'real',
+    'simplex',
+    'unit_interval',
+]
+
+
+class Constraint(object):
+    """
+    Abstract base class for constraints.
+
+    A constraint object represents a region over which a variable is valid,
+    e.g. within which a variable can be optimized.
+    """
+    def check(self, value):
+        """
+        Returns a byte tensor of shape `sample_shape + batch_shape` indicating
+        whether each event in value satisfies this constraint.
+        """
+        raise NotImplementedError
+
+
+class _Dependent(Constraint):
+    """
+    Placeholder for variables whose support depends on other variables.
+    These variables obey no simple coordinate-wise constraints.
+    """
+    def check(self, x):
+        raise ValueError('Cannot determine validity of dependent constraint')
+
+
+def is_dependent(constraint):
+    return isinstance(constraint, _Dependent)
+
+
+class _DependentProperty(property, _Dependent):
+    """
+    Decorator that extends @property to act like a `Dependent` constraint when
+    called on a class and act like a property when called on an object.
+
+    Example::
+
+        class Uniform(Distribution):
+            def __init__(self, low, high):
+                self.low = low
+                self.high = high
+            @constraints.dependent_property
+            def support(self):
+                return constraints.interval(self.low, self.high)
+    """
+    pass
+
+
+class _Boolean(Constraint):
+    """
+    Constrain to the two values `{0, 1}`.
+    """
+    def check(self, value):
+        return (value == 0) | (value == 1)
+
+
+class _NonnegativeInteger(Constraint):
+    """
+    Constrain to non-negative integers `{0, 1, 2, ...}`.
+    """
+    def check(self, value):
+        return (value % 1 == 0) & (value >= 0)
+
+
+class _IntegerInterval(Constraint):
+    """
+    Constrain to an integer interval `[lower_bound, upper_bound]`.
+    """
+    def __init__(self, lower_bound, upper_bound):
+        self.lower_bound = lower_bound
+        self.upper_bound = upper_bound
+
+    def check(self, value):
+        return (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
+
+
+class _Real(Constraint):
+    """
+    Trivially constrain to the extended real line `[-inf, inf]`.
+    """
+    def check(self, value):
+        return value == value  # False for NaNs.
+
+
+class _GreaterThan(Constraint):
+    """
+    Constrain to a real half line `[lower_bound, inf]`.
+    """
+    def __init__(self, lower_bound):
+        self.lower_bound = lower_bound
+
+    def check(self, value):
+        return self.lower_bound <= value
+
+
+class _Interval(Constraint):
+    """
+    Constrain to a real interval `[lower_bound, upper_bound]`.
+    """
+    def __init__(self, lower_bound, upper_bound):
+        self.lower_bound = lower_bound
+        self.upper_bound = upper_bound
+
+    def check(self, value):
+        return (self.lower_bound <= value) & (value <= self.upper_bound)
+
+
+class _Simplex(Constraint):
+    """
+    Constrain to the unit simplex in the innermost (rightmost) dimension.
+    Specifically: `x >= 0` and `x.sum(-1) == 1`.
+    """
+    def check(self, value):
+        return (value >= 0) & ((value.sum(-1, True) - 1).abs() < 1e-6)
+
+
+class _LowerTriangular(Constraint):
+    """
+    Constrain to lower-triangular square matrices.
+    """
+    def check(self, value):
+        return (torch.tril(value) == value).min(-1)[0].min(-1)[0]  # min(dim) returns (values, indices)
+
+
+# Public interface.
+dependent = _Dependent()
+dependent_property = _DependentProperty
+boolean = _Boolean()
+nonnegative_integer = _NonnegativeInteger()
+integer_interval = _IntegerInterval
+real = _Real()
+positive = _GreaterThan(0)
+greater_than = _GreaterThan
+unit_interval = _Interval(0, 1)
+interval = _Interval
+simplex = _Simplex()
+lower_triangular = _LowerTriangular()
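
To make the public interface above concrete, a brief sketch (illustrative
only) of how the singletons and the parameterized factories behave:

    import torch
    from torch.distributions import constraints

    # Singletons check elementwise and return a ByteTensor:
    constraints.positive.check(torch.Tensor([1.0, -1.0]))        # 1, 0
    constraints.simplex.check(torch.Tensor([0.2, 0.3, 0.5]))     # all ones
    # Parameterized constraints are classes exposed under lowercase names:
    constraints.interval(-1, 1).check(torch.Tensor([0.0, 2.0]))  # 1, 0
    constraints.integer_interval(0, 2).check(torch.Tensor([1.0, 3.0]))  # 1, 0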
diff --git a/torch/distributions/dirichlet.py b/torch/distributions/dirichlet.py
index 6384308..97f96a9 100644
--- a/torch/distributions/dirichlet.py
+++ b/torch/distributions/dirichlet.py
@@ -3,6 +3,7 @@
 import torch
 from torch.autograd import Function, Variable
 from torch.autograd.function import once_differentiable
+from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import broadcast_all
 
@@ -43,6 +44,8 @@
     Args:
         alpha (Tensor or Variable): concentration parameter of the distribution
     """
+    params = {'alpha': constraints.positive}
+    support = constraints.simplex
     has_rsample = True
 
     def __init__(self, alpha):
diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py
index 58ed005..727ee70 100644
--- a/torch/distributions/distribution.py
+++ b/torch/distributions/distribution.py
@@ -29,6 +29,23 @@
         """
         return self._event_shape
 
+    @property
+    def params(self):
+        """
+        Returns a dictionary from param names to `Constraint` objects that
+        should be satisfied by each parameter of this distribution. For
+        distributions with multiple parameterizations, only one complete
+        set of parameters should be specified in `.params`.
+        """
+        raise NotImplementedError
+
+    @property
+    def support(self):
+        """
+        Returns a `Constraint` object representing this distribution's support.
+        """
+        raise NotImplementedError
+
     def sample(self, sample_shape=torch.Size()):
         """
         Generates a sample_shape shaped sample or sample_shape shaped batch of
diff --git a/torch/distributions/exponential.py b/torch/distributions/exponential.py
index 999feef..14895a3 100644
--- a/torch/distributions/exponential.py
+++ b/torch/distributions/exponential.py
@@ -1,5 +1,7 @@
 from numbers import Number
+
 import torch
+from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import broadcast_all
 
@@ -18,6 +20,8 @@
     Args:
         rate (float or Tensor or Variable): rate = 1 / scale of the distribution
     """
+    params = {'rate': constraints.positive}
+    support = constraints.positive
     has_rsample = True
 
     def __init__(self, rate):
diff --git a/torch/distributions/gamma.py b/torch/distributions/gamma.py
index 948d72d..f78bb30 100644
--- a/torch/distributions/gamma.py
+++ b/torch/distributions/gamma.py
@@ -3,6 +3,7 @@
 import torch
 from torch.autograd import Function, Variable
 from torch.autograd.function import once_differentiable
+from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import broadcast_all
 
@@ -28,6 +29,8 @@
         alpha (float or Tensor or Variable): shape parameter of the distribution
         beta (float or Tensor or Variable): rate = 1 / scale of the distribution
     """
+    params = {'alpha': constraints.positive, 'beta': constraints.positive}
+    support = constraints.positive
     has_rsample = True
 
     def __init__(self, alpha, beta):
diff --git a/torch/distributions/laplace.py b/torch/distributions/laplace.py
index 826fe3a..3938da7 100644
--- a/torch/distributions/laplace.py
+++ b/torch/distributions/laplace.py
@@ -1,5 +1,7 @@
 from numbers import Number
+
 import torch
+from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import broadcast_all
 
@@ -19,6 +21,8 @@
         loc (float or Tensor or Variable): mean of the distribution
         scale (float or Tensor or Variable): scale of the distribution
     """
+    params = {'loc': constraints.real, 'scale': constraints.positive}
+    support = constraints.real
     has_rsample = True
 
     def __init__(self, loc, scale):
diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py
index 0ded6ea..60aff6b 100644
--- a/torch/distributions/normal.py
+++ b/torch/distributions/normal.py
@@ -3,6 +3,7 @@
 
 import torch
 from torch.autograd import Variable
+from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import broadcast_all
 
@@ -23,6 +24,8 @@
         mean (float or Tensor or Variable): mean of the distribution
         std (float or Tensor or Variable): standard deviation of the distribution
     """
+    params = {'mean': constraints.real, 'std': constraints.positive}
+    support = constraints.real
     has_rsample = True
 
     def __init__(self, mean, std):
diff --git a/torch/distributions/one_hot_categorical.py b/torch/distributions/one_hot_categorical.py
index 514b157..52c2794 100644
--- a/torch/distributions/one_hot_categorical.py
+++ b/torch/distributions/one_hot_categorical.py
@@ -1,7 +1,8 @@
 import torch
 from torch.autograd import Variable
-from torch.distributions.distribution import Distribution
+from torch.distributions import constraints
 from torch.distributions.categorical import Categorical
+from torch.distributions.distribution import Distribution
 
 
 class OneHotCategorical(Distribution):
@@ -25,6 +26,8 @@
     Args:
         probs (Tensor or Variable): event probabilities
     """
+    params = {'probs': constraints.simplex}
+    support = constraints.simplex
     has_enumerate_support = True
 
     def __init__(self, probs):
diff --git a/torch/distributions/pareto.py b/torch/distributions/pareto.py
index 2c97cb8..e78222e 100644
--- a/torch/distributions/pareto.py
+++ b/torch/distributions/pareto.py
@@ -3,6 +3,7 @@
 import math
 
 import torch
+from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import broadcast_all
 
@@ -23,6 +24,7 @@
         alpha (float or Tensor or Variable): Shape parameter of the distribution
     """
     has_rsample = True
+    params = {'alpha': constraints.positive, 'scale': constraints.positive}
 
     def __init__(self, scale, alpha):
         self.scale, self.alpha = broadcast_all(scale, alpha)
@@ -32,6 +34,10 @@
             batch_shape = self.scale.size()
         super(Pareto, self).__init__(batch_shape)
 
+    @constraints.dependent_property
+    def support(self):
+        return constraints.greater_than(self.scale)
+
     def rsample(self, sample_shape=torch.Size()):
         shape = self._extended_shape(sample_shape)
         exp_dist = self.alpha.new(shape).exponential_()
diff --git a/torch/distributions/uniform.py b/torch/distributions/uniform.py
index f0bd196..ae278f2 100644
--- a/torch/distributions/uniform.py
+++ b/torch/distributions/uniform.py
@@ -3,6 +3,7 @@
 
 import torch
 from torch.autograd import Variable
+from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import broadcast_all
 
@@ -23,6 +24,8 @@
         low (float or Tensor or Variable): lower range (inclusive).
         high (float or Tensor or Variable): upper range (exclusive).
     """
+    # TODO allow a (loc, scale) parameterization so the constraints become independent.
+    params = {'low': constraints.dependent, 'high': constraints.dependent}
     has_rsample = True
 
     def __init__(self, low, high):
@@ -33,6 +36,10 @@
             batch_shape = self.low.size()
         super(Uniform, self).__init__(batch_shape)
 
+    @constraints.dependent_property
+    def support(self):
+        return constraints.interval(self.low, self.high)
+
     def rsample(self, sample_shape=torch.Size()):
         shape = self._extended_shape(sample_shape)
         rand = self.low.new(shape).uniform_()
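
Closing note: a sketch of the dual behavior that dependent_property gives
Uniform.support (class access vs. instance access); the names come from this
patch, the assertions are illustrative:

    import torch
    from torch.distributions import Uniform
    from torch.distributions.constraints import Constraint, is_dependent

    # On the class, support is itself a Constraint, but a dependent one:
    assert isinstance(Uniform.support, Constraint)
    assert is_dependent(Uniform.support)
    # On an instance, the property evaluates to a concrete interval:
    d = Uniform(torch.Tensor([0.0]), torch.Tensor([1.0]))
    assert d.support.check(torch.Tensor([0.5])).all()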