| # Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Quantized distribution.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import numpy as np |
| |
| from tensorflow.python.framework import ops |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import check_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops.distributions import distribution as distributions |
| from tensorflow.python.ops.distributions import util as distribution_util |
| from tensorflow.python.util import deprecation |
| |
| __all__ = ["QuantizedDistribution"] |
| |
| |
| @deprecation.deprecated( |
| "2018-10-01", |
| "The TensorFlow Distributions library has moved to " |
| "TensorFlow Probability " |
| "(https://github.com/tensorflow/probability). You " |
| "should update all references to use `tfp.distributions` " |
| "instead of `tf.contrib.distributions`.", |
| warn_once=True) |
| def _logsum_expbig_minus_expsmall(big, small): |
| """Stable evaluation of `Log[exp{big} - exp{small}]`. |
| |
| To work correctly, we should have the pointwise relation: `small <= big`. |
| |
| Args: |
| big: Floating-point `Tensor` |
| small: Floating-point `Tensor` with same `dtype` as `big` and broadcastable |
| shape. |
| |
| Returns: |
| `Tensor` of same `dtype` as `big` and broadcast shape. |
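| |
| For example (illustrative values, using the public `tf` API for brevity): |
| |
| ```python |
| big = tf.constant([0., 10.]) |
| small = tf.constant([-2., 9.]) |
| # Stable evaluation of log(exp(big) - exp(small)): |
| # [log(1 - exp(-2)), 10 + log(1 - exp(-1))] ~= [-0.1454, 9.5413] |
| _logsum_expbig_minus_expsmall(big, small) |
| ``` |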
| """ |
| with ops.name_scope("logsum_expbig_minus_expsmall", values=[small, big]): |
| return math_ops.log(1. - math_ops.exp(small - big)) + big |
| |
| |
| _prob_base_note = """ |
| For whole numbers `y`, |
| |
| ``` |
| P[Y = y] := P[X <= low], if y == low, |
|          := P[X > high - 1], if y == high, |
|          := 0, if y < low or y > high, |
|          := P[y - 1 < X <= y], all other y. |
| ``` |
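| |
| For example, with `low = 0` and `high = 2`: `P[Y = 0] = P[X <= 0]`, |
| `P[Y = 1] = P[0 < X <= 1]`, and `P[Y = 2] = P[X > 1]`. |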
| |
| """ |
| |
| _prob_note = _prob_base_note + """ |
| The base distribution's `cdf` method must be defined on `y - 1`. If the |
| base distribution has a `survival_function` method, results will be more |
| accurate for large values of `y`, and in this case the `survival_function` must |
| also be defined on `y - 1`. |
| """ |
| |
| _log_prob_note = _prob_base_note + """ |
| The base distribution's `log_cdf` method must be defined on `y - 1`. If the |
| base distribution has a `log_survival_function` method, results will be more |
| accurate for large values of `y`, and in this case the `log_survival_function` |
| must also be defined on `y - 1`. |
| """ |
| |
| |
| _cdf_base_note = """ |
| |
| For whole numbers `y`, |
| |
| ``` |
| cdf(y) := P[Y <= y] |
|         = 1, if y >= high, |
|         = 0, if y < low, |
|         = P[X <= y], otherwise. |
| ``` |
| |
| Since `Y` only has mass at whole numbers, `P[Y <= y] = P[Y <= floor(y)]`. |
| This dictates that fractional `y` are first floored to a whole number, and |
| then the above definition applies. |
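| |
| For example, `cdf(1.5) = cdf(1.0) = P[X <= 1]` whenever `low <= 1 < high`. |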
| """ |
| |
| _cdf_note = _cdf_base_note + """ |
| The base distribution's `cdf` method must be defined on `y - 1`. |
| """ |
| |
| _log_cdf_note = _cdf_base_note + """ |
| The base distribution's `log_cdf` method must be defined on `y - 1`. |
| """ |
| |
| |
| _sf_base_note = """ |
| |
| For whole numbers `y`, |
| |
| ``` |
| survival_function(y) := P[Y > y] |
|                       = 0, if y >= high, |
|                       = 1, if y < low, |
|                       = P[X > y], otherwise. |
| ``` |
| |
| Since `Y` only has mass at whole numbers, fractional `y` are first rounded |
| up to a whole number via `ceiling`, and then the above definition applies. |
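| |
| For example, `survival_function(0.5)` is evaluated at `ceiling(0.5) = 1`, |
| giving `P[X > 1]` whenever `low <= 1 < high`. |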
| """ |
| |
| _sf_note = _sf_base_note + """ |
| The base distribution's `cdf` method must be defined on `y - 1`. |
| """ |
| |
| _log_sf_note = _sf_base_note + """ |
| The base distribution's `log_cdf` method must be defined on `y - 1`. |
| """ |
| |
| |
| class QuantizedDistribution(distributions.Distribution): |
| """Distribution representing the quantization `Y = ceiling(X)`. |
| |
| #### Definition in Terms of Sampling |
| |
| ``` |
| 1. Draw X |
| 2. Set Y <-- ceiling(X) |
| 3. If Y < low, reset Y <-- low |
| 4. If Y > high, reset Y <-- high |
| 5. Return Y |
| ``` |
| |
| #### Definition in Terms of the Probability Mass Function |
| |
| Given scalar random variable `X`, we define a discrete random variable `Y` |
| supported on the integers as follows: |
| |
| ``` |
| P[Y = j] := P[X <= low], if j == low, |
|          := P[X > high - 1], if j == high, |
|          := 0, if j < low or j > high, |
|          := P[j - 1 < X <= j], all other j. |
| ``` |
| |
| Conceptually, without cutoffs, the quantization process partitions the real |
| line `R` into half open intervals, and identifies an integer `j` with the |
| right endpoints: |
| |
| ``` |
| R = ... (-2, -1](-1, 0](0, 1](1, 2](2, 3](3, 4] ... |
| j = ...      -1      0     1     2     3     4  ... |
| ``` |
| |
| `P[Y = j]` is the mass of `X` within the `jth` interval. |
| If `low = 0`, and `high = 2`, then the intervals are redrawn |
| and `j` is re-assigned: |
| |
| ``` |
| R = (-infty, 0](0, 1](1, infty) |
| j =          0     1         2 |
| ``` |
| |
| `P[Y = j]` is still the mass of `X` within the `jth` interval. |
| |
| #### Examples |
| |
| We illustrate a mixture of discretized logistic distributions |
| [(Salimans et al., 2017)][1]. This is used, for example, for capturing 16-bit |
| audio in WaveNet [(van den Oord et al., 2017)][2]. The values range in |
| a 1-D integer domain of `[0, 2**16-1]`, and the discretization captures |
| `P(x - 0.5 < X <= x + 0.5)` for all `x` in the domain excluding the endpoints. |
| The lowest value has probability `P(X <= 0.5)` and the highest value has |
| probability `P(2**16 - 1.5 < X)`. |
| |
| Below we assume a `wavenet` function. It takes as input right-shifted audio |
| samples of shape `[..., sequence_length]`. It returns a real-valued tensor of |
| shape `[..., num_mixtures * 3]`, i.e., each mixture component has a `loc` and |
| `scale` parameter belonging to the logistic distribution, and a `logits` |
| parameter determining the unnormalized probability of that component. |
| |
| ```python |
| import tensorflow as tf |
| import tensorflow_probability as tfp |
| tfd = tfp.distributions |
| tfb = tfp.bijectors |
| |
| net = wavenet(inputs) |
| loc, unconstrained_scale, logits = tf.split(net, |
|                                             num_or_size_splits=3, |
|                                             axis=-1) |
| scale = tf.nn.softplus(unconstrained_scale) |
| |
| # Form mixture of discretized logistic distributions. Note we shift the |
| # logistic distribution by -0.5. This lets the quantization capture "rounding" |
| # intervals, `(x-0.5, x+0.5]`, and not "ceiling" intervals, `(x-1, x]`. |
| discretized_logistic_dist = tfd.QuantizedDistribution( |
|     distribution=tfd.TransformedDistribution( |
|         distribution=tfd.Logistic(loc=loc, scale=scale), |
|         bijector=tfb.AffineScalar(shift=-0.5)), |
|     low=0., |
|     high=2**16 - 1.) |
| mixture_dist = tfd.MixtureSameFamily( |
|     mixture_distribution=tfd.Categorical(logits=logits), |
|     components_distribution=discretized_logistic_dist) |
| |
| neg_log_likelihood = -tf.reduce_sum(mixture_dist.log_prob(targets)) |
| train_op = tf.train.AdamOptimizer().minimize(neg_log_likelihood) |
| ``` |
| |
| After instantiating `mixture_dist`, we illustrate maximum likelihood by |
| calculating its log-probability of audio samples `targets` and optimizing. |
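| |
| As a smaller, self-contained sketch (the `Normal` base distribution and |
| cutoff values here are illustrative, not part of the WaveNet example), the |
| same quantization applies to any scalar continuous distribution: |
| |
| ```python |
| # Y = ceiling(X) for X ~ Normal(0, 1), with Y clipped to {-4, ..., 4}. |
| qdist = tfd.QuantizedDistribution( |
|     distribution=tfd.Normal(loc=0., scale=1.), |
|     low=-4., |
|     high=4.) |
| samples = qdist.sample(5)        # Whole-number values (float dtype). |
| pmf = qdist.prob([-1., 0., 1.])  # P[Y = y] at whole-number `y`. |
| ``` |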
| |
| #### References |
| |
| [1]: Tim Salimans, Andrej Karpathy, Xi Chen, and Diederik P. Kingma. |
| PixelCNN++: Improving the PixelCNN with discretized logistic mixture |
| likelihood and other modifications. |
| _International Conference on Learning Representations_, 2017. |
| https://arxiv.org/abs/1701.05517 |
| [2]: Aaron van den Oord et al. Parallel WaveNet: Fast High-Fidelity Speech |
| Synthesis. _arXiv preprint arXiv:1711.10433_, 2017. |
| https://arxiv.org/abs/1711.10433 |
| """ |
| |
| @deprecation.deprecated( |
| "2018-10-01", |
| "The TensorFlow Distributions library has moved to " |
| "TensorFlow Probability " |
| "(https://github.com/tensorflow/probability). You " |
| "should update all references to use `tfp.distributions` " |
| "instead of `tf.contrib.distributions`.", |
| warn_once=True) |
| def __init__(self, |
| distribution, |
| low=None, |
| high=None, |
| validate_args=False, |
| name="QuantizedDistribution"): |
| """Construct a Quantized Distribution representing `Y = ceiling(X)`. |
| |
| Some properties are inherited from the distribution defining `X`. Example: |
| `allow_nan_stats` is determined for this `QuantizedDistribution` by reading |
| the `distribution`. |
| |
| Args: |
| distribution: The base distribution instance to transform. Typically an |
| instance of `Distribution`. |
| low: `Tensor` with same `dtype` as this distribution and shape |
| able to be added to samples. Should be a whole number. Default `None`. |
| If provided, base distribution's `prob` should be defined at |
| `low`. |
| high: `Tensor` with same `dtype` as this distribution and shape |
| able to be added to samples. Should be a whole number. Default `None`. |
| If provided, base distribution's `prob` should be defined at |
| `high - 1`. |
| `high` must be strictly greater than `low`. |
| validate_args: Python `bool`, default `False`. When `True` distribution |
| parameters are checked for validity despite possibly degrading runtime |
| performance. When `False` invalid inputs may silently render incorrect |
| outputs. |
| name: Python `str` name prefixed to Ops created by this class. |
| |
| Raises: |
| TypeError: If `distribution` is not an instance of `Distribution`, or if |
| it is not continuous. |
| NotImplementedError: If the base distribution does not implement `cdf`. |
| """ |
| parameters = dict(locals()) |
| values = ( |
| list(distribution.parameters.values()) + |
| [low, high]) |
| with ops.name_scope(name, values=values) as name: |
| self._dist = distribution |
| |
| if low is not None: |
| low = ops.convert_to_tensor(low, name="low") |
| if high is not None: |
| high = ops.convert_to_tensor(high, name="high") |
| check_ops.assert_same_float_dtype( |
| tensors=[self.distribution, low, high]) |
| |
| # We let QuantizedDistribution access _graph_parents since this class is |
| # more like a base class. |
| graph_parents = self._dist._graph_parents # pylint: disable=protected-access |
| |
| checks = [] |
| if validate_args and low is not None and high is not None: |
| message = "low must be strictly less than high." |
| checks.append( |
| check_ops.assert_less( |
| low, high, message=message)) |
| self._validate_args = validate_args # self._check_integer uses this. |
| with ops.control_dependencies(checks if validate_args else []): |
| if low is not None: |
| self._low = self._check_integer(low) |
| graph_parents += [self._low] |
| else: |
| self._low = None |
| if high is not None: |
| self._high = self._check_integer(high) |
| graph_parents += [self._high] |
| else: |
| self._high = None |
| |
| super(QuantizedDistribution, self).__init__( |
| dtype=self._dist.dtype, |
| reparameterization_type=distributions.NOT_REPARAMETERIZED, |
| validate_args=validate_args, |
| allow_nan_stats=self._dist.allow_nan_stats, |
| parameters=parameters, |
| graph_parents=graph_parents, |
| name=name) |
| |
| @property |
| def distribution(self): |
| """Base distribution, p(x).""" |
| return self._dist |
| |
| @property |
| def low(self): |
| """Lowest value that quantization returns.""" |
| return self._low |
| |
| @property |
| def high(self): |
| """Highest value that quantization returns.""" |
| return self._high |
| |
| def _batch_shape_tensor(self): |
| return self.distribution.batch_shape_tensor() |
| |
| def _batch_shape(self): |
| return self.distribution.batch_shape |
| |
| def _event_shape_tensor(self): |
| return self.distribution.event_shape_tensor() |
| |
| def _event_shape(self): |
| return self.distribution.event_shape |
| |
| def _sample_n(self, n, seed=None): |
| low = self._low |
| high = self._high |
| with ops.name_scope("transform"): |
| n = ops.convert_to_tensor(n, name="n") |
| x_samps = self.distribution.sample(n, seed=seed) |
| ones = array_ops.ones_like(x_samps) |
| |
| # Snap values to the intervals (j - 1, j]. |
| result_so_far = math_ops.ceil(x_samps) |
| |
| if low is not None: |
| result_so_far = array_ops.where(result_so_far < low, |
| low * ones, result_so_far) |
| |
| if high is not None: |
| result_so_far = array_ops.where(result_so_far > high, |
| high * ones, result_so_far) |
| |
| return result_so_far |
| |
| @distribution_util.AppendDocstring(_log_prob_note) |
| def _log_prob(self, y): |
| if not hasattr(self.distribution, "_log_cdf"): |
| raise NotImplementedError( |
| "'log_prob' not implemented unless the base distribution implements " |
| "'log_cdf'") |
| y = self._check_integer(y) |
| try: |
| return self._log_prob_with_logsf_and_logcdf(y) |
| except NotImplementedError: |
| return self._log_prob_with_logcdf(y) |
| |
| def _log_prob_with_logcdf(self, y): |
| return _logsum_expbig_minus_expsmall(self.log_cdf(y), self.log_cdf(y - 1)) |
| |
| def _log_prob_with_logsf_and_logcdf(self, y): |
| """Compute log_prob(y) using log survival_function and cdf together.""" |
| # There are two options that would be equal if we had infinite precision: |
| # Log[ sf(y - 1) - sf(y) ] |
| # = Log[ exp{logsf(y - 1)} - exp{logsf(y)} ] |
| # Log[ cdf(y) - cdf(y - 1) ] |
| # = Log[ exp{logcdf(y)} - exp{logcdf(y - 1)} ] |
| logsf_y = self.log_survival_function(y) |
| logsf_y_minus_1 = self.log_survival_function(y - 1) |
| logcdf_y = self.log_cdf(y) |
| logcdf_y_minus_1 = self.log_cdf(y - 1) |
| |
| # Important: Here we use `where` (select) so that no input to `where` is |
| # inf. This prevents the troublesome case where the output of `where` is |
| # finite, but the gradient of `where` is NaN (gradients flow through both |
| # branches of `where`). |
| |
| # In either case, we are computing Log[ exp{big} - exp{small} ]. |
| # We want to use the sf terms precisely when we are on the right side of |
| # the median, which occurs when logsf_y < logcdf_y. |
| big = array_ops.where(logsf_y < logcdf_y, logsf_y_minus_1, logcdf_y) |
| small = array_ops.where(logsf_y < logcdf_y, logsf_y, logcdf_y_minus_1) |
| |
| return _logsum_expbig_minus_expsmall(big, small) |
| |
| @distribution_util.AppendDocstring(_prob_note) |
| def _prob(self, y): |
| if not hasattr(self.distribution, "_cdf"): |
| raise NotImplementedError( |
| "'prob' not implemented unless the base distribution implements " |
| "'cdf'") |
| y = self._check_integer(y) |
| try: |
| return self._prob_with_sf_and_cdf(y) |
| except NotImplementedError: |
| return self._prob_with_cdf(y) |
| |
| def _prob_with_cdf(self, y): |
| return self.cdf(y) - self.cdf(y - 1) |
| |
| def _prob_with_sf_and_cdf(self, y): |
| # There are two options that would be equal if we had infinite precision: |
| # sf(y - 1) - sf(y) |
| # cdf(y) - cdf(y - 1) |
| sf_y = self.survival_function(y) |
| sf_y_minus_1 = self.survival_function(y - 1) |
| cdf_y = self.cdf(y) |
| cdf_y_minus_1 = self.cdf(y - 1) |
| |
| # sf_prob has greater precision iff we're on the right side of the median. |
| return array_ops.where( |
| sf_y < cdf_y, # True iff we're on the right side of the median. |
| sf_y_minus_1 - sf_y, |
| cdf_y - cdf_y_minus_1) |
| |
| @distribution_util.AppendDocstring(_log_cdf_note) |
| def _log_cdf(self, y): |
| low = self._low |
| high = self._high |
| |
| # Recall the promise: |
| # cdf(y) := P[Y <= y] |
| #         = 1, if y >= high, |
| #         = 0, if y < low, |
| #         = P[X <= y], otherwise. |
| |
| # P[Y <= y] = P[Y <= floor(y)] since mass is only at integers, not in |
| # between. |
| j = math_ops.floor(y) |
| |
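| # Log[P[X <= j]], used when low < X < high. |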
| result_so_far = self.distribution.log_cdf(j) |
| |
| # Broadcast `j` against the result, e.g. when a single distribution is |
| # evaluated on a batch of samples. |
| j += array_ops.zeros_like(result_so_far) |
| |
| # Re-define values at the cutoffs. |
| if low is not None: |
| neg_inf = -np.inf * array_ops.ones_like(result_so_far) |
| result_so_far = array_ops.where(j < low, neg_inf, result_so_far) |
| if high is not None: |
| result_so_far = array_ops.where(j >= high, |
| array_ops.zeros_like(result_so_far), |
| result_so_far) |
| |
| return result_so_far |
| |
| @distribution_util.AppendDocstring(_cdf_note) |
| def _cdf(self, y): |
| low = self._low |
| high = self._high |
| |
| # Recall the promise: |
| # cdf(y) := P[Y <= y] |
| #         = 1, if y >= high, |
| #         = 0, if y < low, |
| #         = P[X <= y], otherwise. |
| |
| # P[Y <= y] = P[Y <= floor(y)] since mass is only at integers, not in |
| # between. |
| j = math_ops.floor(y) |
| |
| # P[X <= j], used when low < X < high. |
| result_so_far = self.distribution.cdf(j) |
| |
| # Broadcast `j` against the result, e.g. when a single distribution is |
| # evaluated on a batch of samples. |
| j += array_ops.zeros_like(result_so_far) |
| |
| # Re-define values at the cutoffs. |
| if low is not None: |
| result_so_far = array_ops.where(j < low, |
| array_ops.zeros_like(result_so_far), |
| result_so_far) |
| if high is not None: |
| result_so_far = array_ops.where(j >= high, |
| array_ops.ones_like(result_so_far), |
| result_so_far) |
| |
| return result_so_far |
| |
| @distribution_util.AppendDocstring(_log_sf_note) |
| def _log_survival_function(self, y): |
| low = self._low |
| high = self._high |
| |
| # Recall the promise: |
| # survival_function(y) := P[Y > y] |
| #                       = 0, if y >= high, |
| #                       = 1, if y < low, |
| #                       = P[X > y], otherwise. |
| |
| # By convention, fractional `y` are rounded up: evaluate at j = ceiling(y), |
| # since mass is only at integers, not in between. |
| j = math_ops.ceil(y) |
| |
| # P[X > j], used when low < X < high. |
| result_so_far = self.distribution.log_survival_function(j) |
| |
| # Broadcast `j` against the result, e.g. when a single distribution is |
| # evaluated on a batch of samples. |
| j += array_ops.zeros_like(result_so_far) |
| |
| # Re-define values at the cutoffs. |
| if low is not None: |
| result_so_far = array_ops.where(j < low, |
| array_ops.zeros_like(result_so_far), |
| result_so_far) |
| if high is not None: |
| neg_inf = -np.inf * array_ops.ones_like(result_so_far) |
| result_so_far = array_ops.where(j >= high, neg_inf, result_so_far) |
| |
| return result_so_far |
| |
| @distribution_util.AppendDocstring(_sf_note) |
| def _survival_function(self, y): |
| low = self._low |
| high = self._high |
| |
| # Recall the promise: |
| # survival_function(y) := P[Y > y] |
| #                       = 0, if y >= high, |
| #                       = 1, if y < low, |
| #                       = P[X > y], otherwise. |
| |
| # By convention, fractional `y` are rounded up: evaluate at j = ceiling(y), |
| # since mass is only at integers, not in between. |
| j = math_ops.ceil(y) |
| |
| # P[X > j], used when low < X < high. |
| result_so_far = self.distribution.survival_function(j) |
| |
| # Broadcast `j` against the result, e.g. when a single distribution is |
| # evaluated on a batch of samples. |
| j += array_ops.zeros_like(result_so_far) |
| |
| # Re-define values at the cutoffs. |
| if low is not None: |
| result_so_far = array_ops.where(j < low, |
| array_ops.ones_like(result_so_far), |
| result_so_far) |
| if high is not None: |
| result_so_far = array_ops.where(j >= high, |
| array_ops.zeros_like(result_so_far), |
| result_so_far) |
| |
| return result_so_far |
| |
| def _check_integer(self, value): |
| with ops.name_scope("check_integer", values=[value]): |
| value = ops.convert_to_tensor(value, name="value") |
| if not self.validate_args: |
| return value |
| dependencies = [distribution_util.assert_integer_form( |
| value, message="value has non-integer components.")] |
| return control_flow_ops.with_dependencies(dependencies, value) |