blob: 95ef5b7358b27f70de11036713190c77c9333dd0 [file] [log] [blame]
import torch
# Names exported by `from ... constraints import *`; kept in the original order.
__all__ = [
    'Constraint',
    'boolean',
    'dependent',
    'dependent_property',
    'greater_than',
    'integer_interval',
    'interval',
    'is_dependent',
    'less_than',
    'lower_triangular',
    'nonnegative_integer',
    'positive',
    'real',
    'simplex',
    'unit_interval',
]
class Constraint(object):
    """
    Abstract base class for constraints.

    A constraint object describes a region over which a variable is valid,
    e.g. the region within which a variable can be optimized. Concrete
    constraints override :meth:`check`.
    """

    def check(self, value):
        """
        Return a byte tensor of shape `sample_shape + batch_shape` telling,
        for each event in `value`, whether that event satisfies this
        constraint.

        The base class has no implementation and always raises
        NotImplementedError.
        """
        raise NotImplementedError()
class _Dependent(Constraint):
    """
    Placeholder constraint for variables whose support depends on the values
    of other variables. Such variables obey no simple coordinate-wise
    constraint, so they cannot be checked on their own.
    """

    def check(self, x):
        # A dependent support cannot be validated in isolation; always refuse.
        message = 'Cannot determine validity of dependent constraint'
        raise ValueError(message)
def is_dependent(constraint):
    """
    Return True when `constraint` is an instance of `_Dependent` (a
    placeholder whose validity cannot be checked directly), else False.
    """
    is_placeholder = isinstance(constraint, _Dependent)
    return is_placeholder
class _DependentProperty(property, _Dependent):
    """
    Decorator that extends @property so it behaves like a `Dependent`
    constraint when accessed on a class and like an ordinary property when
    accessed on an object.

    Example::

        class Uniform(Distribution):
            def __init__(self, low, high):
                self.low = low
                self.high = high

            @constraints.dependent_property
            def support(self):
                return constraints.interval(self.low, self.high)
    """
class _Boolean(Constraint):
    """
    Constrain to the two values `{0, 1}`.
    """

    def check(self, value):
        # Element-wise: an entry passes iff it equals 0 or equals 1.
        equals_zero = value == 0
        equals_one = value == 1
        return equals_zero | equals_one
class _NonnegativeInteger(Constraint):
    """
    Constrain to the non-negative integers `{0, 1, 2, ...}`.
    """

    def check(self, value):
        # A zero remainder modulo 1 means the entry has no fractional part.
        is_whole = value % 1 == 0
        is_nonnegative = value >= 0
        return is_whole & is_nonnegative
class _IntegerInterval(Constraint):
    """
    Constrain to an integer interval `[lower_bound, upper_bound]`, with both
    ends included.
    """

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound, self.upper_bound = lower_bound, upper_bound

    def check(self, value):
        # Integrality test: a zero remainder modulo 1.
        is_whole = value % 1 == 0
        above_lower = self.lower_bound <= value
        below_upper = value <= self.upper_bound
        return is_whole & above_lower & below_upper
class _Real(Constraint):
    """
    Trivially constrain to the extended real line `[-inf, inf]`.
    """

    def check(self, value):
        # NaN is the only value unequal to itself, so this is False exactly
        # at NaN entries and True everywhere else (infinities included).
        is_not_nan = value == value
        return is_not_nan
class _GreaterThan(Constraint):
    """
    Constrain to the real half line `[lower_bound, inf]`.
    """

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound

    def check(self, value):
        # Inclusive comparison: `lower_bound` itself is accepted.
        inside = self.lower_bound <= value
        return inside
class _LessThan(Constraint):
    """
    Constrain to a real half line `[-inf, upper_bound]`.

    The comparison is inclusive, so `upper_bound` itself satisfies the
    constraint.
    """

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound

    def check(self, value):
        return value <= self.upper_bound
class _Interval(Constraint):
    """
    Constrain to a closed real interval `[lower_bound, upper_bound]`.
    """

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound, self.upper_bound = lower_bound, upper_bound

    def check(self, value):
        # Both bounds are inclusive.
        above_lower = self.lower_bound <= value
        below_upper = value <= self.upper_bound
        return above_lower & below_upper
class _Simplex(Constraint):
    """
    Constrain to the unit simplex in the innermost (rightmost) dimension.

    Specifically: `x >= 0` and `x.sum(-1) == 1`, where the sum test uses an
    absolute tolerance of 1e-6.
    """

    def check(self, value):
        nonnegative = value >= 0
        # keepdim=True (second positional arg) keeps the summed dim so the
        # result broadcasts against `nonnegative` element-wise.
        totals = value.sum(-1, True)
        sums_to_one = (totals - 1).abs() < 1e-6
        return nonnegative & sums_to_one
class _LowerTriangular(Constraint):
    """
    Constrain to lower-triangular square matrices.

    The two rightmost dimensions of `value` are reduced away, leaving one
    result per matrix.
    """

    def check(self, value):
        # Entry-wise: True wherever `value` agrees with its lower triangle,
        # i.e. every entry above the diagonal is zero.
        # NOTE(review): batched (>2-D) input relies on torch.tril supporting
        # batches in the targeted PyTorch version — confirm.
        matches_tril = torch.tril(value) == value
        # `Tensor.min(dim)` returns a (values, indices) pair, so take [0]
        # before reducing the next dimension. Chaining `.min(-1).min(-1)`
        # would call `.min` on a tuple and fail.
        return matches_tril.min(-1)[0].min(-1)[0]
# Public interface.

# Ready-made constraint instances.
dependent = _Dependent()
boolean = _Boolean()
nonnegative_integer = _NonnegativeInteger()
real = _Real()
positive = _GreaterThan(0)
unit_interval = _Interval(0, 1)
simplex = _Simplex()
lower_triangular = _LowerTriangular()

# Decorator for supports that depend on other attributes.
dependent_property = _DependentProperty

# Constraint classes exposed as factories: call them with bounds,
# e.g. ``interval(low, high)`` or ``greater_than(x)``.
integer_interval = _IntegerInterval
greater_than = _GreaterThan
less_than = _LessThan
interval = _Interval