# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Real NVP bijector."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.layers import core as layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import template as template_ops
from tensorflow.python.ops.distributions import bijector
from tensorflow.python.util import deprecation

__all__ = [
    "RealNVP",
    "real_nvp_default_template",
]


class RealNVP(bijector.Bijector):
  """RealNVP "affine coupling layer" for vector-valued events.

  Real NVP models a normalizing flow on a `D`-dimensional distribution via a
  single `D-d`-dimensional conditional distribution [(Dinh et al., 2017)][1]:

  `y[0:d] = x[0:d]`
  `y[d:D] = x[d:D] * math_ops.exp(log_scale_fn(x[0:d])) + shift_fn(x[0:d])`

  The last `D-d` units are scaled and shifted based on the first `d` units
  only, while the first `d` units are 'masked' and left unchanged. Real NVP's
  `shift_and_log_scale_fn` computes vector-valued quantities. For
  scale-and-shift transforms that do not depend on any masked units, i.e.
  `d=0`, use the `tfb.Affine` bijector with learned parameters instead.
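
  For intuition, here is a minimal NumPy sketch of the coupling transform and
  its inverse (illustrative only; `d`, `shift_fn`, and `log_scale_fn` are
  placeholders for `num_masked` and the learned `shift_and_log_scale_fn`):

  ```python
  import numpy as np

  def coupling_forward(x, d, shift_fn, log_scale_fn):
    x0, x1 = x[..., :d], x[..., d:]  # Split into masked / transformed units.
    y1 = x1 * np.exp(log_scale_fn(x0)) + shift_fn(x0)
    return np.concatenate([x0, y1], axis=-1)

  def coupling_inverse(y, d, shift_fn, log_scale_fn):
    y0, y1 = y[..., :d], y[..., d:]  # The masked units pass through unchanged.
    x1 = (y1 - shift_fn(y0)) * np.exp(-log_scale_fn(y0))
    return np.concatenate([y0, x1], axis=-1)
  ```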

  Masking is currently only supported for base distributions with
  `event_ndims=1`. For more sophisticated masking schemes like checkerboard or
  channel-wise masking [(Dinh et al., 2017)][1], use the `tfb.Permute`
  bijector to re-order desired masked units into the first `d` units. For base
  distributions with `event_ndims > 1`, use the `tfb.Reshape` bijector to
  flatten the event shape.
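
  For example, a sketch of such a chain on a 3-dimensional event (the
  `permutation` below is an assumption for illustration):

  ```python
  import tensorflow_probability as tfp
  tfb = tfp.bijectors

  flow = tfb.Chain([
      tfb.RealNVP(
          num_masked=2,
          shift_and_log_scale_fn=tfb.real_nvp_default_template(
              hidden_layers=[64, 64])),
      # Applied first (`Chain` composes right-to-left): move the units to be
      # masked into the leading positions.
      tfb.Permute(permutation=[2, 0, 1]),
  ])
  ```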

  Recall that the MAF bijector [(Papamakarios et al., 2017)][4] implements a
  normalizing flow via an autoregressive transformation. MAF and IAF have
  opposite computational tradeoffs: MAF can train all units in parallel but
  must sample units sequentially, while IAF must train units sequentially but
  can sample in parallel. In contrast, Real NVP can compute both the forward
  and inverse computations in parallel. However, the lack of an autoregressive
  transformation makes it less expressive on a per-bijector basis.

  A "valid" `shift_and_log_scale_fn` must compute each `shift` (aka `loc` or
  "mu" in [Papamakarios et al. (2017)][4]) and `log(scale)` (aka "alpha" in
  [Papamakarios et al. (2017)][4]) such that each is broadcastable with the
  arguments to `forward` and `inverse`, i.e., such that the calculations in
  `forward` and `inverse` [below] are possible. For convenience,
  `real_nvp_default_template` is offered as a possible `shift_and_log_scale_fn`
  function.

  NICE [(Dinh et al., 2014)][2] is a special case of the Real NVP bijector
  which discards the scale transformation, resulting in a constant-time
  inverse-log-determinant-Jacobian. To use a NICE bijector instead of Real
  NVP, `shift_and_log_scale_fn` should return `(shift, None)`, and
  `is_constant_jacobian` should be set to `True` in the `RealNVP` constructor.
  Calling `real_nvp_default_template` with `shift_only=True` returns one such
  NICE-compatible `shift_and_log_scale_fn`.
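
  For example, a minimal NICE-style layer built exactly as described above:

  ```python
  import tensorflow_probability as tfp
  tfb = tfp.bijectors

  nice = tfb.RealNVP(
      num_masked=2,
      shift_and_log_scale_fn=tfb.real_nvp_default_template(
          hidden_layers=[512, 512], shift_only=True),
      is_constant_jacobian=True)
  ```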

  Caching: the scalar input depth `D` of the base distribution is not known at
  construction time. The first call to any of `forward(x)`, `inverse(x)`,
  `inverse_log_det_jacobian(x)`, or `forward_log_det_jacobian(x)` memoizes
  `D`, which is re-used in subsequent calls. This shape must be known prior to
  graph execution (which is the case if using tf.layers).
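
  For example (a hypothetical illustration; `fn` stands for any valid
  `shift_and_log_scale_fn`):

  ```python
  bij = tfb.RealNVP(num_masked=2, shift_and_log_scale_fn=fn)
  y = bij.forward(tf.ones([10, 5]))  # First call memoizes D = 5.
  x = bij.inverse(y)                 # Re-uses the cached D.
  ```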

  #### Example Use

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions
  tfb = tfp.bijectors

  # A common choice for a normalizing flow is to use a Gaussian for the base
  # distribution. (However, any continuous distribution would work.) E.g.,
  nvp = tfd.TransformedDistribution(
      distribution=tfd.MultivariateNormalDiag(loc=[0., 0., 0.]),
      bijector=tfb.RealNVP(
          num_masked=2,
          shift_and_log_scale_fn=tfb.real_nvp_default_template(
              hidden_layers=[512, 512])))

  x = nvp.sample()
  nvp.log_prob(x)
  nvp.log_prob(0.)
  ```

  For more examples, see [Jang (2018)][3].

  #### References

  [1]: Laurent Dinh, Jascha Sohl-Dickstein, and Samy Bengio. Density Estimation
       using Real NVP. In _International Conference on Learning
       Representations_, 2017. https://arxiv.org/abs/1605.08803

  [2]: Laurent Dinh, David Krueger, and Yoshua Bengio. NICE: Non-linear
       Independent Components Estimation. _arXiv preprint arXiv:1410.8516_,
       2014. https://arxiv.org/abs/1410.8516

  [3]: Eric Jang. Normalizing Flows Tutorial, Part 2: Modern Normalizing Flows.
       _Technical Report_, 2018. http://blog.evjang.com/2018/01/nf2.html

  [4]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked
       Autoregressive Flow for Density Estimation. In _Neural Information
       Processing Systems_, 2017. https://arxiv.org/abs/1705.07057
  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               num_masked,
               shift_and_log_scale_fn,
               is_constant_jacobian=False,
               validate_args=False,
               name=None):
"""Creates the Real NVP or NICE bijector.
Args:
num_masked: Python `int` indicating that the first `d` units of the event
should be masked. Must be in the closed interval `[1, D-1]`, where `D`
is the event size of the base distribution.
shift_and_log_scale_fn: Python `callable` which computes `shift` and
`log_scale` from both the forward domain (`x`) and the inverse domain
(`y`). Calculation must respect the "autoregressive property" (see class
docstring). Suggested default
`masked_autoregressive_default_template(hidden_layers=...)`.
Typically the function contains `tf.Variables` and is wrapped using
`tf.make_template`. Returning `None` for either (both) `shift`,
`log_scale` is equivalent to (but more efficient than) returning zero.
is_constant_jacobian: Python `bool`. Default: `False`. When `True` the
implementation assumes `log_scale` does not depend on the forward domain
(`x`) or inverse domain (`y`) values. (No validation is made;
`is_constant_jacobian=False` is always safe but possibly computationally
inefficient.)
validate_args: Python `bool` indicating whether arguments should be
checked for correctness.
name: Python `str`, name given to ops managed by this object.
Raises:
ValueError: If num_masked < 1.
"""
name = name or "real_nvp"
if num_masked <= 0:
raise ValueError("num_masked must be a positive integer.")
self._num_masked = num_masked
# At construction time, we don't know input_depth.
self._input_depth = None
self._shift_and_log_scale_fn = shift_and_log_scale_fn
super(RealNVP, self).__init__(
forward_min_event_ndims=1,
is_constant_jacobian=is_constant_jacobian,
validate_args=validate_args,
name=name)

  def _cache_input_depth(self, x):
    if self._input_depth is None:
      self._input_depth = x.shape.with_rank_at_least(1)[-1].value
      if self._input_depth is None:
        raise NotImplementedError(
            "Rightmost dimension must be known prior to graph execution.")
      if self._num_masked >= self._input_depth:
        raise ValueError(
            "Number of masked units must be smaller than the event size.")

  def _forward(self, x):
    self._cache_input_depth(x)
    # Performs scale and shift.
    x0, x1 = x[:, :self._num_masked], x[:, self._num_masked:]
    shift, log_scale = self._shift_and_log_scale_fn(
        x0, self._input_depth - self._num_masked)
    y1 = x1
    if log_scale is not None:
      y1 *= math_ops.exp(log_scale)
    if shift is not None:
      y1 += shift
    y = array_ops.concat([x0, y1], axis=-1)
    return y

  def _inverse(self, y):
    self._cache_input_depth(y)
    # Performs un-shift and un-scale.
    y0, y1 = y[:, :self._num_masked], y[:, self._num_masked:]
    shift, log_scale = self._shift_and_log_scale_fn(
        y0, self._input_depth - self._num_masked)
    x1 = y1
    if shift is not None:
      x1 -= shift
    if log_scale is not None:
      x1 *= math_ops.exp(-log_scale)
    x = array_ops.concat([y0, x1], axis=-1)
    return x

  def _inverse_log_det_jacobian(self, y):
    self._cache_input_depth(y)
    y0 = y[:, :self._num_masked]
    _, log_scale = self._shift_and_log_scale_fn(
        y0, self._input_depth - self._num_masked)
    if log_scale is None:
      # NICE case: a shift-only coupling has unit Jacobian determinant.
      return constant_op.constant(0., dtype=y.dtype, name="ildj")
    # The Jacobian is triangular, so log|det J^{-1}| reduces to the negative
    # sum of `log_scale` over the transformed units.
    return -math_ops.reduce_sum(log_scale, axis=-1)

  def _forward_log_det_jacobian(self, x):
    self._cache_input_depth(x)
    x0 = x[:, :self._num_masked]
    _, log_scale = self._shift_and_log_scale_fn(
        x0, self._input_depth - self._num_masked)
    if log_scale is None:
      # NICE case: a shift-only coupling has unit Jacobian determinant.
      return constant_op.constant(0., dtype=x.dtype, name="fldj")
    # log|det J| is the sum of `log_scale` over the transformed units.
    return math_ops.reduce_sum(log_scale, axis=-1)


@deprecation.deprecated(
    "2018-10-01",
    "The TensorFlow Distributions library has moved to "
    "TensorFlow Probability "
    "(https://github.com/tensorflow/probability). You "
    "should update all references to use `tfp.distributions` "
    "instead of `tf.contrib.distributions`.",
    warn_once=True)
def real_nvp_default_template(
    hidden_layers,
    shift_only=False,
    activation=nn_ops.relu,
    name=None,
    *args,
    **kwargs):
"""Build a scale-and-shift function using a multi-layer neural network.
This will be wrapped in a make_template to ensure the variables are only
created once. It takes the `d`-dimensional input x[0:d] and returns the `D-d`
dimensional outputs `loc` ("mu") and `log_scale` ("alpha").
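
  A minimal usage sketch (the layer sizes and shapes here are assumptions for
  illustration; `x0` is any float `Tensor` of shape `[batch, d]`):

  ```python
  shift_and_log_scale_fn = real_nvp_default_template(hidden_layers=[32, 32])
  # Request parameters for the remaining `output_units = D - d` dimensions:
  shift, log_scale = shift_and_log_scale_fn(x0, 3)  # Each of shape [batch, 3].
  ```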

  Arguments:
    hidden_layers: Python `list`-like of non-negative integer scalars
      indicating the number of units in each hidden layer, e.g., `[512, 512]`.
    shift_only: Python `bool` indicating if only the `shift` term shall be
      computed (i.e. NICE bijector). Default: `False`.
    activation: Activation function (callable). Explicitly setting to `None`
      implies a linear activation.
    name: A name for ops managed by this function. Default:
      "real_nvp_default_template".
    *args: `tf.layers.dense` arguments.
    **kwargs: `tf.layers.dense` keyword arguments.

  Returns:
    shift: `Float`-like `Tensor` of shift terms ("mu" in
      [Papamakarios et al. (2017)][1]).
    log_scale: `Float`-like `Tensor` of log(scale) terms ("alpha" in
      [Papamakarios et al. (2017)][1]).

  Raises:
    NotImplementedError: if rightmost dimension of `inputs` is unknown prior to
      graph execution.

  #### References

  [1]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked
       Autoregressive Flow for Density Estimation. In _Neural Information
       Processing Systems_, 2017. https://arxiv.org/abs/1705.07057
  """
  with ops.name_scope(name, "real_nvp_default_template"):

    def _fn(x, output_units):
      """Fully connected MLP parameterized via `real_nvp_default_template`."""
      for units in hidden_layers:
        x = layers.dense(
            inputs=x,
            units=units,
            activation=activation,
            *args,
            **kwargs)
      x = layers.dense(
          inputs=x,
          units=(1 if shift_only else 2) * output_units,
          activation=None,
          *args,
          **kwargs)
      if shift_only:
        return x, None
      shift, log_scale = array_ops.split(x, 2, axis=-1)
      return shift, log_scale

    return template_ops.make_template(
        "real_nvp_default_template", _fn)