blob: 178c3c94bfd319e3182a60054ea55c4ccaf01607 [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Permutation bijectors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import bijector
from tensorflow.python.util import deprecation
__all__ = [
"Permute",
]
class Permute(bijector.Bijector):
"""Permutes the rightmost dimension of a `Tensor`.
```python
import tensorflow_probability as tfp
tfb = tfp.bijectors
reverse = tfb.Permute(permutation=[2, 1, 0])
reverse.forward([-1., 0., 1.])
# ==> [1., 0., -1]
reverse.inverse([1., 0., -1])
# ==> [-1., 0., 1.]
reverse.forward_log_det_jacobian(any_value)
# ==> 0.
reverse.inverse_log_det_jacobian(any_value)
# ==> 0.
```
Warning: `tf.estimator` may repeatedly build the graph thus
`Permute(np.random.permutation(event_size)).astype("int32"))` is not a
reliable parameterization (nor would it be even if using `tf.constant`). A
safe alternative is to use `tf.get_variable` to achieve "init once" behavior,
i.e.,
```python
def init_once(x, name):
return tf.get_variable(name, initializer=x, trainable=False)
Permute(permutation=init_once(
np.random.permutation(event_size).astype("int32"),
name="permutation"))
```
"""
@deprecation.deprecated(
"2018-10-01",
"The TensorFlow Distributions library has moved to "
"TensorFlow Probability "
"(https://github.com/tensorflow/probability). You "
"should update all references to use `tfp.distributions` "
"instead of `tf.contrib.distributions`.",
warn_once=True)
def __init__(self, permutation, validate_args=False, name=None):
"""Creates the `Permute` bijector.
Args:
permutation: An `int`-like vector-shaped `Tensor` representing the
permutation to apply to the rightmost dimension of the transformed
`Tensor`.
validate_args: Python `bool` indicating whether arguments should be
checked for correctness.
name: Python `str`, name given to ops managed by this object.
Raises:
TypeError: if `not permutation.dtype.is_integer`.
ValueError: if `permutation` does not contain exactly one of each of
`{0, 1, ..., d}`.
"""
with ops.name_scope(name, "permute", values=[permutation]):
permutation = ops.convert_to_tensor(
permutation,
name="permutation")
if not permutation.dtype.is_integer:
raise TypeError("permutation.dtype ({}) should be `int`-like.".format(
permutation.dtype.name))
p = tensor_util.constant_value(permutation)
if p is not None:
if set(p) != set(np.arange(p.size)):
raise ValueError("Permutation over `d` must contain exactly one of "
"each of `{0, 1, ..., d}`.")
elif validate_args:
p, _ = nn_ops.top_k(-permutation,
k=array_ops.shape(permutation)[-1],
sorted=True)
permutation = control_flow_ops.with_dependencies([
check_ops.assert_equal(
-p, math_ops.range(array_ops.size(p)),
message=("Permutation over `d` must contain exactly one of "
"each of `{0, 1, ..., d}`.")),
], permutation)
self._permutation = permutation
super(Permute, self).__init__(
forward_min_event_ndims=1,
is_constant_jacobian=True,
validate_args=validate_args,
name=name or "permute")
@property
def permutation(self):
return self._permutation
def _forward(self, x):
return array_ops.gather(x, self.permutation, axis=-1)
def _inverse(self, y):
return array_ops.gather(
y,
array_ops.invert_permutation(self.permutation),
axis=-1)
def _inverse_log_det_jacobian(self, y):
# is_constant_jacobian = True for this bijector, hence the
# `log_det_jacobian` need only be specified for a single input, as this will
# be tiled to match `event_ndims`.
return constant_op.constant(0., dtype=y.dtype.base_dtype)
def _forward_log_det_jacobian(self, x):
return constant_op.constant(0., dtype=x.dtype.base_dtype)