blob: 507c5d36794df75c09d2293ed66111c17c06af37 [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Deterministic distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import six
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.util import deprecation
# Public API of this module: the scalar and vector Deterministic classes.
__all__ = ["Deterministic", "VectorDeterministic"]
@six.add_metaclass(abc.ABCMeta)
class _BaseDeterministic(distribution.Distribution):
  """Base class for Deterministic distributions.

  Stores the shared `loc`, `atol` and `rtol` parameters, precomputes the
  comparison slack `atol + rtol * |loc|`, and implements the statistics and
  sampling that are identical for the scalar and vector variants. Subclasses
  provide the batch/event shape methods and `_prob` (and optionally `_cdf`).
  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               loc,
               atol=None,
               rtol=None,
               is_vector=False,
               validate_args=False,
               allow_nan_stats=True,
               name="_BaseDeterministic"):
    """Initialize a batch of `_BaseDeterministic` distributions.

    The `atol` and `rtol` parameters allow for some slack in `pmf`, `cdf`
    computations, e.g. due to floating-point error.

    ```
    pmf(x; loc)
      = 1, if Abs(x - loc) <= atol + rtol * Abs(loc),
      = 0, otherwise.
    ```

    Args:
      loc: Numeric `Tensor`.  The point (or batch of points) on which this
        distribution is supported.
      atol:  Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
        shape.  The absolute tolerance for comparing closeness to `loc`.
        Default is `0`.
      rtol:  Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
        shape.  The relative tolerance for comparing closeness to `loc`.
        Default is `0`.
      is_vector:  Python `bool`.  If `True`, this is for `VectorDeterministic`,
        else `Deterministic`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError:  If `is_vector` and `validate_args` are both `True` and `loc`
        is statically known to have rank 0 (a scalar). When the rank is only
        known at run time, a graph assertion is attached to `loc` instead.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[loc, atol, rtol]) as name:
      loc = ops.convert_to_tensor(loc, name="loc")
      if is_vector and validate_args:
        msg = "Argument loc must be at least rank 1."
        if loc.get_shape().ndims is not None:
          # Static rank is known: fail immediately at construction time.
          if loc.get_shape().ndims < 1:
            raise ValueError(msg)
        else:
          # Rank only known at run time: gate `loc` behind an assertion.
          loc = control_flow_ops.with_dependencies(
              [check_ops.assert_rank_at_least(loc, 1, message=msg)], loc)
      self._loc = loc

      super(_BaseDeterministic, self).__init__(
          dtype=self._loc.dtype,
          reparameterization_type=distribution.NOT_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=[self._loc],
          name=name)

      # `_get_tol` reads `self.validate_args`, so it must run after the
      # superclass constructor has been called.
      self._atol = self._get_tol(atol)
      self._rtol = self._get_tol(rtol)
      # Avoid using the large broadcast with self.loc if possible.
      if rtol is None:
        self._slack = self.atol
      else:
        self._slack = self.atol + self.rtol * math_ops.abs(self.loc)

  def _get_tol(self, tol):
    """Convert a tolerance argument to a `Tensor` with `loc`'s dtype.

    Args:
      tol: `None` or a value convertible to a `Tensor` of `self.loc.dtype`.

    Returns:
      A scalar zero when `tol` is `None`; otherwise `tol` as a `Tensor`, with
      a non-negativity assertion attached when `validate_args` is `True`.
    """
    if tol is None:
      return ops.convert_to_tensor(0, dtype=self.loc.dtype)

    tol = ops.convert_to_tensor(tol, dtype=self.loc.dtype)
    if self.validate_args:
      tol = control_flow_ops.with_dependencies([
          check_ops.assert_non_negative(
              tol, message="Argument 'tol' must be non-negative")
      ], tol)
    return tol

  @property
  def loc(self):
    """Point (or batch of points) at which this distribution is supported."""
    return self._loc

  @property
  def atol(self):
    """Absolute tolerance for comparing points to `self.loc`."""
    return self._atol

  @property
  def rtol(self):
    """Relative tolerance for comparing points to `self.loc`."""
    return self._rtol

  def _entropy(self):
    # A point mass has zero entropy; one zero per batch member.
    return array_ops.zeros(self.batch_shape_tensor(), dtype=self.dtype)

  def _mean(self):
    return array_ops.identity(self.loc)

  def _variance(self):
    return array_ops.zeros_like(self.loc)

  def _mode(self):
    # The only supported point is `loc`, which is also the mean.
    return self.mean()

  def _sample_n(self, n, seed=None):  # pylint: disable=unused-argument
    """Return `n` copies of `loc` stacked along a new leading axis.

    Sampling is deterministic, so `seed` is accepted only for interface
    compatibility and is ignored.
    """
    n_static = tensor_util.constant_value(ops.convert_to_tensor(n))
    if n_static is not None and self.loc.get_shape().ndims is not None:
      # Both `n` and the rank of `loc` are static: use Python-list multiples
      # so the tiled result keeps a fully static shape.
      ones = [1] * self.loc.get_shape().ndims
      multiples = [n_static] + ones
    else:
      ones = array_ops.ones_like(array_ops.shape(self.loc))
      multiples = array_ops.concat(([n], ones), axis=0)

    return array_ops.tile(self.loc[array_ops.newaxis, ...], multiples=multiples)
class Deterministic(_BaseDeterministic):
  """Scalar `Deterministic` distribution on the real line.

  A `Deterministic` distribution puts all of its probability mass on a single
  (possibly batched) real point `loc`; the corresponding random variable is
  constant and always equals `loc`.

  See [Degenerate rv](https://en.wikipedia.org/wiki/Degenerate_distribution).

  #### Mathematical Details

  With zero tolerance, the probability mass function (pmf) and the cumulative
  distribution function (cdf) are

  ```none
  pmf(x; loc) = 1, if x == loc, else 0
  cdf(x; loc) = 1, if x >= loc, else 0
  ```

  #### Examples

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # A single Deterministic supported at zero.
  constant = tfd.Deterministic(0.)
  constant.prob(0.)
  ==> 1.
  constant.prob(2.)
  ==> 0.

  # A [2, 2] batch of scalar constants.
  loc = [[0., 1.], [2., 3.]]
  x = [[0., 1.1], [1.99, 3.]]
  constant = tfd.Deterministic(loc)
  constant.prob(x)
  ==> [[1., 0.], [0., 1.]]
  ```
  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               loc,
               atol=None,
               rtol=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Deterministic"):
    """Build a scalar `Deterministic` distribution.

    Points `x` are treated as equal to `loc` when they fall within the slack
    `atol + rtol * Abs(loc)`, which absorbs e.g. floating-point error in the
    `pmf` and `cdf` computations:

    ```
    pmf(x; loc)
      = 1, if Abs(x - loc) <= atol + rtol * Abs(loc),
      = 0, otherwise.
    ```

    Args:
      loc: Numeric `Tensor` of shape `[B1, ..., Bb]`, with `b >= 0`.
        The point (or batch of points) on which this distribution is supported.
      atol:  Non-negative `Tensor` with the same `dtype` as `loc` and a shape
        broadcastable against it. Absolute tolerance used when comparing to
        `loc`. Defaults to `0`.
      rtol:  Non-negative `Tensor` with the same `dtype` as `loc` and a shape
        broadcastable against it. Relative tolerance used when comparing to
        `loc`. Defaults to `0`.
      validate_args: Python `bool`, default `False`. If `True`, parameters are
        validated at some cost in runtime performance; if `False`, invalid
        inputs may silently produce incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. If `True`, undefined
        statistics (e.g., mean, mode, variance) are reported as "`NaN`"; if
        `False`, an exception is raised when any batch member's statistic is
        undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    super(Deterministic, self).__init__(
        loc=loc,
        atol=atol,
        rtol=rtol,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)

  # Every dimension of `loc` is a batch dimension; events are scalars.
  def _batch_shape(self):
    return self.loc.get_shape()

  def _batch_shape_tensor(self):
    return array_ops.shape(self.loc)

  def _event_shape(self):
    return tensor_shape.scalar()

  def _event_shape_tensor(self):
    return constant_op.constant([], dtype=dtypes.int32)

  def _prob(self, x):
    # Mass 1 wherever `x` lies within the slack of `loc`, else 0.
    distance = math_ops.abs(x - self.loc)
    hit = math_ops.less_equal(distance, self._slack)
    return math_ops.cast(hit, dtype=self.dtype)

  def _cdf(self, x):
    # Step function whose jump is moved left by the slack.
    lower_edge = self.loc - self._slack
    reached = math_ops.greater_equal(x, lower_edge)
    return math_ops.cast(reached, dtype=self.dtype)
class VectorDeterministic(_BaseDeterministic):
  """Vector `Deterministic` distribution on `R^k`.

  The `VectorDeterministic` distribution is parameterized by a [batch] point
  `loc in R^k`.  The distribution is supported at this point only,
  and corresponds to a random variable that is constant, equal to `loc`.

  See [Degenerate rv](https://en.wikipedia.org/wiki/Degenerate_distribution).

  #### Mathematical Details

  The probability mass function (pmf) is

  ```none
  pmf(x; loc)
    = 1, if All[Abs(x - loc) <= atol + rtol * Abs(loc)],
    = 0, otherwise.
  ```

  #### Examples

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Initialize a single VectorDeterministic supported at [0., 2.] in R^2.
  constant = tfd.VectorDeterministic([0., 2.])
  constant.prob([0., 2.])
  ==> 1.
  constant.prob([0., 3.])
  ==> 0.

  # Initialize a [3] batch of constants on R^2.
  loc = [[0., 1.], [2., 3.], [4., 5.]]
  constant = tfd.VectorDeterministic(loc)
  constant.prob([[0., 1.], [1.9, 3.], [3.99, 5.]])
  ==> [1., 0., 0.]
  ```
  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               loc,
               atol=None,
               rtol=None,
               validate_args=False,
               allow_nan_stats=True,
               name="VectorDeterministic"):
    """Initialize a `VectorDeterministic` distribution on `R^k`, for `k >= 0`.

    Note that there is only one point in `R^0`, the "point" `[]`.  So if `k = 0`
    then `self.prob([]) == 1`.

    The `atol` and `rtol` parameters allow for some slack in `pmf`
    computations, e.g. due to floating-point error.

    ```
    pmf(x; loc)
      = 1, if All[Abs(x - loc) <= atol + rtol * Abs(loc)],
      = 0, otherwise
    ```

    Args:
      loc: Numeric `Tensor` of shape `[B1, ..., Bb, k]`, with `b >= 0`,
        `k >= 0`. The point (or batch of points) on which this distribution
        is supported.
      atol:  Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
        shape.  The absolute tolerance for comparing closeness to `loc`.
        Default is `0`.
      rtol:  Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
        shape.  The relative tolerance for comparing closeness to `loc`.
        Default is `0`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: If `validate_args` is `True` and `loc` is statically known
        to be a scalar (rank 0).
    """
    super(VectorDeterministic, self).__init__(
        loc,
        atol=atol,
        rtol=rtol,
        is_vector=True,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)

  # The trailing dimension of `loc` is the event (R^k); the rest are batch.
  def _batch_shape_tensor(self):
    return array_ops.shape(self.loc)[:-1]

  def _batch_shape(self):
    return self.loc.get_shape()[:-1]

  def _event_shape_tensor(self):
    # Slice (rather than index) so the dynamic event shape is the rank-1
    # vector `[k]`, consistent with the rank-1 static `_event_shape` below.
    return array_ops.shape(self.loc)[-1:]

  def _event_shape(self):
    return self.loc.get_shape()[-1:]

  def _prob(self, x):
    """Return 1 where every coordinate of `x` is within slack of `loc`."""
    if self.validate_args:
      is_vector_check = check_ops.assert_rank_at_least(x, 1)
      # The dimension check gathers the last entry of `shape(x)`, which is
      # only valid once `x` is known to have rank >= 1. Build it under the
      # rank check's control dependency so that, in graph mode, the rank
      # assertion (with its clear message) always fires first.
      with ops.control_dependencies([is_vector_check]):
        right_vec_space_check = check_ops.assert_equal(
            self.event_shape_tensor(),
            array_ops.gather(array_ops.shape(x), array_ops.rank(x) - 1),
            message=
            "Argument 'x' not defined in the same space R^k as this distribution")
        with ops.control_dependencies([right_vec_space_check]):
          x = array_ops.identity(x)
    return math_ops.cast(
        math_ops.reduce_all(math_ops.abs(x - self.loc) <= self._slack, axis=-1),
        dtype=self.dtype)