Declare constraints for distribution parameters and support (#4450)
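A quick sketch of the API this patch introduces, with illustrative values (not part of the patch): every distribution declares a `params` dict mapping parameter names to `Constraint` objects and a `support` constraint, and `Constraint.check(value)` returns a byte tensor that is nonzero wherever `value` is valid::

    import torch
    from torch.distributions import Normal
    from torch.distributions.constraints import is_dependent

    d = Normal(torch.Tensor([0.0]), torch.Tensor([1.0]))

    # Validate each declared parameter against its constraint.
    for name, constraint in d.params.items():
        if is_dependent(constraint):
            continue  # depends on other parameters; cannot be checked in isolation
        assert constraint.check(getattr(d, name)).all()

    # Validate samples against the declared support.
    assert d.support.check(d.sample()).all()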
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 543d59c..f7860c1 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -33,6 +33,7 @@
from torch.distributions import (Bernoulli, Beta, Categorical, Cauchy, Chi2,
Dirichlet, Exponential, Gamma, Laplace,
Normal, OneHotCategorical, Pareto, Uniform)
+from torch.distributions.constraints import Constraint, is_dependent
TEST_NUMPY = True
try:
@@ -104,7 +105,7 @@
},
{
'loc': Variable(torch.randn(1), requires_grad=True),
- 'scale': Variable(torch.randn(1), requires_grad=True),
+ 'scale': Variable(torch.randn(1).abs(), requires_grad=True),
},
{
'loc': torch.Tensor([1.0, 0.0]),
@@ -118,7 +119,7 @@
},
{
'mean': Variable(torch.randn(1), requires_grad=True),
- 'std': Variable(torch.randn(1), requires_grad=True),
+ 'std': Variable(torch.randn(1).abs(), requires_grad=True),
},
{
'mean': torch.Tensor([1.0, 0.0]),
@@ -1055,5 +1056,36 @@
self.assertRaises(ValueError, laplace.log_prob, self.tensor_sample_2)
+class TestConstraints(TestCase):
+ def test_params_contains(self):
+ for Dist, params in EXAMPLES:
+ for i, param in enumerate(params):
+ dist = Dist(**param)
+ for name, value in param.items():
+ if not (torch.is_tensor(value) or isinstance(value, Variable)):
+ value = torch.Tensor([value])
+ if Dist in (Categorical, OneHotCategorical) and name == 'probs':
+                        # These distributions accept unnormalized positive probs, but their
+                        # declared constraint is the stricter simplex, so normalize first.
+ value = value / value.sum(-1, True)
+ constraint = dist.params[name]
+ if is_dependent(constraint):
+ continue
+                    message = '{} example {}/{} parameter {} = {}'.format(
+                        Dist.__name__, i + 1, len(params), name, value)
+ self.assertTrue(constraint.check(value).all(), msg=message)
+
+ def test_support_contains(self):
+ for Dist, params in EXAMPLES:
+ self.assertIsInstance(Dist.support, Constraint)
+ for i, param in enumerate(params):
+ dist = Dist(**param)
+ value = dist.sample()
+ constraint = dist.support
+                message = '{} example {}/{} sample = {}'.format(
+                    Dist.__name__, i + 1, len(params), value)
+ self.assertTrue(constraint.check(value).all(), msg=message)
+
+
if __name__ == '__main__':
run_tests()
diff --git a/torch/distributions/bernoulli.py b/torch/distributions/bernoulli.py
index 25765e8..9aa53d9 100644
--- a/torch/distributions/bernoulli.py
+++ b/torch/distributions/bernoulli.py
@@ -2,6 +2,7 @@
import torch
from torch.autograd import Variable
+from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
@@ -23,6 +24,8 @@
Args:
        probs (Number, Tensor or Variable): the probability of sampling `1`
"""
+ params = {'probs': constraints.unit_interval}
+ support = constraints.boolean
has_enumerate_support = True
def __init__(self, probs):
diff --git a/torch/distributions/beta.py b/torch/distributions/beta.py
index 7953a8f..12b266d 100644
--- a/torch/distributions/beta.py
+++ b/torch/distributions/beta.py
@@ -2,6 +2,7 @@
import torch
from torch.autograd import Variable
+from torch.distributions import constraints
from torch.distributions.dirichlet import Dirichlet
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
@@ -21,6 +22,8 @@
Args:
        alpha, beta (Tensor or Variable): concentration parameters of the distribution
"""
+ params = {'alpha': constraints.positive, 'beta': constraints.positive}
+ support = constraints.unit_interval
has_rsample = True
def __init__(self, alpha, beta):
diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py
index dfaa2b4..4d0d7ae 100644
--- a/torch/distributions/categorical.py
+++ b/torch/distributions/categorical.py
@@ -1,5 +1,6 @@
import torch
from torch.autograd import Variable
+from torch.distributions import constraints
from torch.distributions.distribution import Distribution
@@ -29,6 +30,7 @@
Args:
probs (Tensor or Variable): event probabilities
"""
+ params = {'probs': constraints.simplex}
has_enumerate_support = True
def __init__(self, probs):
@@ -36,6 +38,10 @@
batch_shape = self.probs.size()[:-1]
super(Categorical, self).__init__(batch_shape)
+ @constraints.dependent_property
+ def support(self):
+ return constraints.integer_interval(0, self.probs.size()[-1] - 1)
+
def sample(self, sample_shape=torch.Size()):
num_events = self.probs.size()[-1]
sample_shape = self._extended_shape(sample_shape)
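A note on `dependent_property`: accessed on the class, `Categorical.support` acts as a dependent `Constraint` that cannot be checked; accessed on an instance, it resolves to a concrete `integer_interval`. A small sketch with an illustrative probs vector::

    import torch
    from torch.distributions import Categorical
    from torch.distributions.constraints import is_dependent

    assert is_dependent(Categorical.support)  # unknown until probs is bound

    d = Categorical(torch.Tensor([0.25, 0.25, 0.5]))
    assert d.support.check(d.sample()).all()  # samples lie in {0, 1, 2}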
diff --git a/torch/distributions/cauchy.py b/torch/distributions/cauchy.py
index 15e9ae4..26f9068 100644
--- a/torch/distributions/cauchy.py
+++ b/torch/distributions/cauchy.py
@@ -1,8 +1,8 @@
+import math
from numbers import Number
-import math
-
import torch
+from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
@@ -24,6 +24,8 @@
loc (float or Tensor or Variable): mode or median of the distribution.
scale (float or Tensor or Variable): half width at half maximum.
"""
+ params = {'loc': constraints.real, 'scale': constraints.positive}
+ support = constraints.real
has_rsample = True
def __init__(self, loc, scale):
diff --git a/torch/distributions/chi2.py b/torch/distributions/chi2.py
index 0e68086..66a1142 100644
--- a/torch/distributions/chi2.py
+++ b/torch/distributions/chi2.py
@@ -1,4 +1,5 @@
-from torch.distributions import Gamma
+from torch.distributions import constraints
+from torch.distributions.gamma import Gamma
class Chi2(Gamma):
@@ -16,6 +17,7 @@
Args:
df (float or Tensor or Variable): shape parameter of the distribution
"""
+ params = {'df': constraints.positive}
def __init__(self, df):
super(Chi2, self).__init__(0.5 * df, 0.5)
diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py
new file mode 100644
index 0000000..f8b1545
--- /dev/null
+++ b/torch/distributions/constraints.py
@@ -0,0 +1,156 @@
+import torch
+
+
+__all__ = [
+ 'Constraint',
+ 'boolean',
+ 'dependent',
+ 'dependent_property',
+ 'greater_than',
+ 'integer_interval',
+ 'interval',
+ 'is_dependent',
+ 'lower_triangular',
+ 'nonnegative_integer',
+ 'positive',
+ 'real',
+ 'simplex',
+ 'unit_interval',
+]
+
+
+class Constraint(object):
+ """
+ Abstract base class for constraints.
+
+ A constraint object represents a region over which a variable is valid,
+ e.g. within which a variable can be optimized.
+ """
+ def check(self, value):
+ """
+ Returns a byte tensor of `sample_shape + batch_shape` indicating
+ whether each event in value satisfies this constraint.
+ """
+ raise NotImplementedError
+
+
+class _Dependent(Constraint):
+ """
+ Placeholder for variables whose support depends on other variables.
+ These variables obey no simple coordinate-wise constraints.
+ """
+ def check(self, x):
+ raise ValueError('Cannot determine validity of dependent constraint')
+
+
+def is_dependent(constraint):
+ return isinstance(constraint, _Dependent)
+
+
+class _DependentProperty(property, _Dependent):
+ """
+ Decorator that extends @property to act like a `Dependent` constraint when
+ called on a class and act like a property when called on an object.
+
+ Example::
+
+ class Uniform(Distribution):
+ def __init__(self, low, high):
+ self.low = low
+ self.high = high
+ @constraints.dependent_property
+ def support(self):
+ return constraints.interval(self.low, self.high)
+ """
+ pass
+
+
+class _Boolean(Constraint):
+ """
+ Constrain to the two values `{0, 1}`.
+ """
+ def check(self, value):
+ return (value == 0) | (value == 1)
+
+
+class _NonnegativeInteger(Constraint):
+ """
+ Constrain to non-negative integers `{0, 1, 2, ...}`.
+ """
+ def check(self, value):
+ return (value % 1 == 0) & (value >= 0)
+
+
+class _IntegerInterval(Constraint):
+ """
+ Constrain to an integer interval `[lower_bound, upper_bound]`.
+ """
+ def __init__(self, lower_bound, upper_bound):
+ self.lower_bound = lower_bound
+ self.upper_bound = upper_bound
+
+ def check(self, value):
+ return (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
+
+
+class _Real(Constraint):
+ """
+ Trivially constrain to the extended real line `[-inf, inf]`.
+ """
+ def check(self, value):
+        return value == value  # False for NaNs.
+
+
+class _GreaterThan(Constraint):
+ """
+ Constrain to a real half line `[lower_bound, inf]`.
+ """
+ def __init__(self, lower_bound):
+ self.lower_bound = lower_bound
+
+ def check(self, value):
+ return self.lower_bound <= value
+
+
+class _Interval(Constraint):
+ """
+ Constrain to a real interval `[lower_bound, upper_bound]`.
+ """
+ def __init__(self, lower_bound, upper_bound):
+ self.lower_bound = lower_bound
+ self.upper_bound = upper_bound
+
+ def check(self, value):
+ return (self.lower_bound <= value) & (value <= self.upper_bound)
+
+
+class _Simplex(Constraint):
+ """
+ Constrain to the unit simplex in the innermost (rightmost) dimension.
+ Specifically: `x >= 0` and `x.sum(-1) == 1`.
+ """
+ def check(self, value):
+ return (value >= 0) & ((value.sum(-1, True) - 1).abs() < 1e-6)
+
+
+class _LowerTriangular(Constraint):
+ """
+ Constrain to lower-triangular square matrices.
+ """
+ def check(self, value):
+        return (torch.tril(value) == value).min(-1)[0].min(-1)[0]
+
+
+# Public interface.
+dependent = _Dependent()
+dependent_property = _DependentProperty
+boolean = _Boolean()
+nonnegative_integer = _NonnegativeInteger()
+integer_interval = _IntegerInterval
+real = _Real()
+positive = _GreaterThan(0)
+greater_than = _GreaterThan
+unit_interval = _Interval(0, 1)
+interval = _Interval
+simplex = _Simplex()
+lower_triangular = _LowerTriangular()
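A brief tour of the public interface above, using illustrative values: names like `unit_interval` are ready-made singleton instances, while parameterized constraints such as `interval` and `greater_than` are classes to instantiate::

    import torch
    from torch.distributions import constraints

    probs = torch.Tensor([0.2, 0.8])
    assert constraints.unit_interval.check(probs).all()  # each element in [0, 1]
    assert constraints.simplex.check(probs).all()        # nonnegative, sums to 1 along dim -1
    assert not constraints.positive.check(torch.Tensor([-1.0])).all()

    bounded = constraints.interval(-1.0, 1.0)            # parameterized constraint
    assert bounded.check(torch.Tensor([0.0, -0.5])).all()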
diff --git a/torch/distributions/dirichlet.py b/torch/distributions/dirichlet.py
index 6384308..97f96a9 100644
--- a/torch/distributions/dirichlet.py
+++ b/torch/distributions/dirichlet.py
@@ -3,6 +3,7 @@
import torch
from torch.autograd import Function, Variable
from torch.autograd.function import once_differentiable
+from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
@@ -43,6 +44,8 @@
Args:
alpha (Tensor or Variable): concentration parameter of the distribution
"""
+ params = {'alpha': constraints.positive}
+ support = constraints.simplex
has_rsample = True
def __init__(self, alpha):
diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py
index 58ed005..727ee70 100644
--- a/torch/distributions/distribution.py
+++ b/torch/distributions/distribution.py
@@ -29,6 +29,23 @@
"""
return self._event_shape
+ @property
+ def params(self):
+ """
+ Returns a dictionary from param names to `Constraint` objects that
+ should be satisfied by each parameter of this distribution. For
+        distributions with multiple parameterizations, only one complete
+ set of parameters should be specified in `.params`.
+ """
+ raise NotImplementedError
+
+ @property
+ def support(self):
+ """
+ Returns a `Constraint` object representing this distribution's support.
+ """
+ raise NotImplementedError
+
def sample(self, sample_shape=torch.Size()):
"""
Generates a sample_shape shaped sample or sample_shape shaped batch of
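For subclass authors, the two override styles in one hedged sketch (`HalfLine` is a hypothetical name, not part of this patch): fixed constraints are plain class attributes, while value-dependent regions use `constraints.dependent_property`, as `Categorical` does above and `Pareto` and `Uniform` do below::

    from torch.distributions import constraints
    from torch.distributions.distribution import Distribution

    class HalfLine(Distribution):
        """Hypothetical subclass for illustration; sampling methods omitted."""
        params = {'rate': constraints.positive}  # fixed constraint on the parameter
        support = constraints.positive           # fixed support on the real half line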
diff --git a/torch/distributions/exponential.py b/torch/distributions/exponential.py
index 999feef..14895a3 100644
--- a/torch/distributions/exponential.py
+++ b/torch/distributions/exponential.py
@@ -1,5 +1,7 @@
from numbers import Number
+
import torch
+from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
@@ -18,6 +20,8 @@
Args:
rate (float or Tensor or Variable): rate = 1 / scale of the distribution
"""
+ params = {'rate': constraints.positive}
+ support = constraints.positive
has_rsample = True
def __init__(self, rate):
diff --git a/torch/distributions/gamma.py b/torch/distributions/gamma.py
index 948d72d..f78bb30 100644
--- a/torch/distributions/gamma.py
+++ b/torch/distributions/gamma.py
@@ -3,6 +3,7 @@
import torch
from torch.autograd import Function, Variable
from torch.autograd.function import once_differentiable
+from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
@@ -28,6 +29,8 @@
alpha (float or Tensor or Variable): shape parameter of the distribution
beta (float or Tensor or Variable): rate = 1 / scale of the distribution
"""
+ params = {'alpha': constraints.positive, 'beta': constraints.positive}
+ support = constraints.positive
has_rsample = True
def __init__(self, alpha, beta):
diff --git a/torch/distributions/laplace.py b/torch/distributions/laplace.py
index 826fe3a..3938da7 100644
--- a/torch/distributions/laplace.py
+++ b/torch/distributions/laplace.py
@@ -1,5 +1,7 @@
from numbers import Number
+
import torch
+from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
@@ -19,6 +21,8 @@
loc (float or Tensor or Variable): mean of the distribution
scale (float or Tensor or Variable): scale of the distribution
"""
+ params = {'loc': constraints.real, 'scale': constraints.positive}
+ support = constraints.real
has_rsample = True
def __init__(self, loc, scale):
diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py
index 0ded6ea..60aff6b 100644
--- a/torch/distributions/normal.py
+++ b/torch/distributions/normal.py
@@ -3,6 +3,7 @@
import torch
from torch.autograd import Variable
+from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
@@ -23,6 +24,8 @@
mean (float or Tensor or Variable): mean of the distribution
std (float or Tensor or Variable): standard deviation of the distribution
"""
+ params = {'mean': constraints.real, 'std': constraints.positive}
+ support = constraints.real
has_rsample = True
def __init__(self, mean, std):
diff --git a/torch/distributions/one_hot_categorical.py b/torch/distributions/one_hot_categorical.py
index 514b157..52c2794 100644
--- a/torch/distributions/one_hot_categorical.py
+++ b/torch/distributions/one_hot_categorical.py
@@ -1,7 +1,8 @@
import torch
from torch.autograd import Variable
-from torch.distributions.distribution import Distribution
+from torch.distributions import constraints
from torch.distributions.categorical import Categorical
+from torch.distributions.distribution import Distribution
class OneHotCategorical(Distribution):
@@ -25,6 +26,8 @@
Args:
probs (Tensor or Variable): event probabilities
"""
+ params = {'probs': constraints.simplex}
+ support = constraints.simplex
has_enumerate_support = True
def __init__(self, probs):
diff --git a/torch/distributions/pareto.py b/torch/distributions/pareto.py
index 2c97cb8..e78222e 100644
--- a/torch/distributions/pareto.py
+++ b/torch/distributions/pareto.py
@@ -3,6 +3,7 @@
import math
import torch
+from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
@@ -23,6 +24,7 @@
alpha (float or Tensor or Variable): Shape parameter of the distribution
"""
+    params = {'alpha': constraints.positive, 'scale': constraints.positive}
     has_rsample = True
def __init__(self, scale, alpha):
self.scale, self.alpha = broadcast_all(scale, alpha)
@@ -32,6 +34,10 @@
batch_shape = self.scale.size()
super(Pareto, self).__init__(batch_shape)
+ @constraints.dependent_property
+ def support(self):
+ return constraints.greater_than(self.scale)
+
def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
exp_dist = self.alpha.new(shape).exponential_()
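Because the lower bound of the Pareto support is the `scale` parameter itself, the support is exposed via `dependent_property`; a sketch with illustrative scale/alpha values::

    import torch
    from torch.distributions import Pareto
    from torch.distributions.constraints import is_dependent

    assert is_dependent(Pareto.support)  # undetermined at the class level

    d = Pareto(torch.Tensor([2.0]), torch.Tensor([3.0]))  # scale=2, alpha=3
    assert d.support.check(d.sample()).all()  # samples lie in [scale, inf]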
diff --git a/torch/distributions/uniform.py b/torch/distributions/uniform.py
index f0bd196..ae278f2 100644
--- a/torch/distributions/uniform.py
+++ b/torch/distributions/uniform.py
@@ -3,6 +3,7 @@
import torch
from torch.autograd import Variable
+from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
@@ -23,6 +24,8 @@
low (float or Tensor or Variable): lower range (inclusive).
high (float or Tensor or Variable): upper range (exclusive).
"""
+    # TODO allow a (loc, scale) parameterization so that constraints can be declared independently.
+ params = {'low': constraints.dependent, 'high': constraints.dependent}
has_rsample = True
def __init__(self, low, high):
@@ -33,6 +36,10 @@
batch_shape = self.low.size()
super(Uniform, self).__init__(batch_shape)
+ @constraints.dependent_property
+ def support(self):
+ return constraints.interval(self.low, self.high)
+
def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
rand = self.low.new(shape).uniform_()
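For `Uniform`, both parameters are declared dependent (so `test_params_contains` skips them via `is_dependent`), while the support is concrete per instance; a sketch with illustrative bounds::

    import torch
    from torch.distributions import Uniform
    from torch.distributions.constraints import is_dependent

    d = Uniform(torch.Tensor([0.0]), torch.Tensor([2.0]))
    assert all(is_dependent(c) for c in d.params.values())  # low/high constrain each other
    assert d.support.check(d.sample()).all()                # samples lie in [low, high]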