| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """The BatchReshape distribution.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import numpy as np |
| |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.framework import tensor_util |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import check_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops.distributions import distribution as distribution_lib |
| from tensorflow.python.util import deprecation |
| |
| |
| __all__ = [ |
| "BatchReshape", |
| ] |
| |
| |
| class BatchReshape(distribution_lib.Distribution): |
| """The Batch-Reshaping distribution. |
| |
| This "meta-distribution" reshapes the batch dimensions of another |
| distribution. |
| |
| #### Examples |
| |
| ```python |
| import tensorflow_probability as tfp |
| tfd = tfp.distributions |
| |
| dtype = np.float32 |
| dims = 2 |
| new_batch_shape = [1, 2, -1] |
| old_batch_shape = [6] |
| |
| scale = np.ones(old_batch_shape + [dims], dtype) |
| mvn = tfd.MultivariateNormalDiag(scale_diag=scale) |
| reshape_mvn = tfd.BatchReshape( |
| distribution=mvn, |
| batch_shape=new_batch_shape, |
| validate_args=True) |
| |
| reshape_mvn.batch_shape |
| # ==> [1, 2, 3] |
| |
| x = reshape_mvn.sample(sample_shape=[4, 5]) |
| x.shape |
| # ==> [4, 5, 1, 2, 3, 2] == sample_shape + new_batch_shape + [dims] |
| |
| reshape_mvn.log_prob(x).shape |
| # ==> [4, 5, 1, 2, 3] == sample_shape + new_batch_shape |
| ``` |
| |
| """ |
| |
| @deprecation.deprecated( |
| "2018-10-01", |
| "The TensorFlow Distributions library has moved to " |
| "TensorFlow Probability " |
| "(https://github.com/tensorflow/probability). You " |
| "should update all references to use `tfp.distributions` " |
| "instead of `tf.contrib.distributions`.", |
| warn_once=True) |
| def __init__(self, |
| distribution, |
| batch_shape, |
| validate_args=False, |
| allow_nan_stats=True, |
| name=None): |
| """Construct BatchReshape distribution. |
| |
| Args: |
| distribution: The base distribution instance to reshape. Typically an |
| instance of `Distribution`. |
| batch_shape: Positive `int`-like vector-shaped `Tensor` representing |
| the new shape of the batch dimensions. Up to one dimension may contain |
| `-1`, meaning the remainder of the batch size. |
| validate_args: Python `bool`, default `False`. When `True` distribution |
| parameters are checked for validity despite possibly degrading runtime |
| performance. When `False` invalid inputs may silently render incorrect |
| outputs. |
| allow_nan_stats: Python `bool`, default `True`. When `True`, statistics |
| (e.g., mean, mode, variance) use the value "`NaN`" to indicate the |
| result is undefined. When `False`, an exception is raised if one or |
| more of the statistic's batch members are undefined. |
| name: The name to give Ops created by the initializer. |
| Default value: `"BatchReshape" + distribution.name`. |
| |
| Raises: |
| ValueError: if `batch_shape` is not a vector. |
| ValueError: if `batch_shape` has non-positive elements. |
| ValueError: if `batch_shape` size is not the same as a |
| `distribution.batch_shape` size. |
| """ |
| parameters = dict(locals()) |
| name = name or "BatchReshape" + distribution.name |
| with ops.name_scope(name, values=[batch_shape]) as name: |
| # The unexpanded batch shape may contain up to one dimension of -1. |
| self._batch_shape_unexpanded = ops.convert_to_tensor( |
| batch_shape, dtype=dtypes.int32, name="batch_shape") |
| validate_init_args_statically(distribution, self._batch_shape_unexpanded) |
| batch_shape, batch_shape_static, runtime_assertions = calculate_reshape( |
| distribution.batch_shape_tensor(), self._batch_shape_unexpanded, |
| validate_args) |
| self._distribution = distribution |
| self._batch_shape_ = batch_shape |
| self._batch_shape_static = batch_shape_static |
| self._runtime_assertions = runtime_assertions |
| super(BatchReshape, self).__init__( |
| dtype=distribution.dtype, |
| reparameterization_type=distribution.reparameterization_type, |
| validate_args=validate_args, |
| allow_nan_stats=allow_nan_stats, |
| parameters=parameters, |
| graph_parents=( |
| [self._batch_shape_unexpanded] + distribution._graph_parents), # pylint: disable=protected-access |
| name=name) |
| |
| @property |
| def distribution(self): |
| return self._distribution |
| |
| def _batch_shape_tensor(self): |
| with ops.control_dependencies(self._runtime_assertions): |
| return array_ops.identity(self._batch_shape_) |
| |
| def _batch_shape(self): |
| return self._batch_shape_static |
| |
| def _event_shape_tensor(self): |
| with ops.control_dependencies(self._runtime_assertions): |
| return array_ops.identity(self.distribution.event_shape_tensor()) |
| |
| def _event_shape(self): |
| return self.distribution.event_shape |
| |
| def _sample_n(self, n, seed=None): |
| with ops.control_dependencies(self._runtime_assertions): |
| x = self.distribution.sample(sample_shape=n, seed=seed) |
| new_shape = array_ops.concat( |
| [ |
| [n], |
| self._batch_shape_unexpanded, |
| self.event_shape_tensor(), |
| ], |
| axis=0) |
| return array_ops.reshape(x, new_shape) |
| |
| def _log_prob(self, x): |
| return self._call_reshape_input_output( |
| self.distribution.log_prob, x) |
| |
| def _prob(self, x): |
| return self._call_reshape_input_output( |
| self.distribution.prob, x) |
| |
| def _log_cdf(self, x): |
| return self._call_reshape_input_output( |
| self.distribution.log_cdf, x) |
| |
| def _cdf(self, x): |
| return self._call_reshape_input_output( |
| self.distribution.cdf, x) |
| |
| def _log_survival_function(self, x): |
| return self._call_reshape_input_output( |
| self.distribution.log_survival_function, x) |
| |
| def _survival_function(self, x): |
| return self._call_reshape_input_output( |
| self.distribution.survival_function, x) |
| |
| def _entropy(self): |
| return self._call_and_reshape_output( |
| self.distribution.entropy, |
| [], |
| [tensor_shape.scalar()]) |
| |
| def _mean(self): |
| return self._call_and_reshape_output(self.distribution.mean) |
| |
| def _mode(self): |
| return self._call_and_reshape_output(self.distribution.mode) |
| |
| def _stddev(self): |
| return self._call_and_reshape_output(self.distribution.stddev) |
| |
| def _variance(self): |
| return self._call_and_reshape_output(self.distribution.variance) |
| |
| def _covariance(self): |
| return self._call_and_reshape_output( |
| self.distribution.covariance, |
| [self.event_shape_tensor()]*2, |
| [self.event_shape]*2) |
| |
| def _sample_shape(self, x): |
| """Computes graph and static `sample_shape`.""" |
| x_ndims = (array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims) |
| event_ndims = (array_ops.size(self.event_shape_tensor()) |
| if self.event_shape.ndims is None |
| else self.event_shape.ndims) |
| batch_ndims = ( |
| array_ops.size(self._batch_shape_unexpanded) |
| if self.batch_shape.ndims is None else self.batch_shape.ndims) |
| sample_ndims = x_ndims - batch_ndims - event_ndims |
| if isinstance(sample_ndims, int): |
| static_sample_shape = x.shape[:sample_ndims] |
| else: |
| static_sample_shape = tensor_shape.TensorShape(None) |
| if static_sample_shape.is_fully_defined(): |
| sample_shape = np.int32(static_sample_shape.as_list()) |
| else: |
| sample_shape = array_ops.shape(x)[:sample_ndims] |
| return sample_shape, static_sample_shape |
| |
| def _call_reshape_input_output(self, fn, x): |
| """Calls `fn`, appropriately reshaping its input `x` and output.""" |
| with ops.control_dependencies( |
| self._runtime_assertions + self._validate_sample_arg(x)): |
| sample_shape, static_sample_shape = self._sample_shape(x) |
| old_shape = array_ops.concat([ |
| sample_shape, |
| self.distribution.batch_shape_tensor(), |
| self.event_shape_tensor(), |
| ], axis=0) |
| result = fn(array_ops.reshape(x, old_shape)) |
| new_shape = array_ops.concat( |
| [ |
| sample_shape, |
| self._batch_shape_unexpanded, |
| ], axis=0) |
| result = array_ops.reshape(result, new_shape) |
| if (static_sample_shape.ndims is not None and |
| self.batch_shape.ndims is not None): |
| new_shape = static_sample_shape.concatenate(self.batch_shape) |
| result.set_shape(result.shape.merge_with(new_shape)) |
| return result |
| |
| def _call_and_reshape_output( |
| self, |
| fn, |
| event_shape_list=None, |
| static_event_shape_list=None): |
| """Calls `fn` and appropriately reshapes its output.""" |
| with ops.control_dependencies(self._runtime_assertions): |
| if event_shape_list is None: |
| event_shape_list = [self._event_shape_tensor()] |
| if static_event_shape_list is None: |
| static_event_shape_list = [self.event_shape] |
| new_shape = array_ops.concat( |
| [self._batch_shape_unexpanded] + event_shape_list, axis=0) |
| result = array_ops.reshape(fn(), new_shape) |
| if (self.batch_shape.ndims is not None and |
| self.event_shape.ndims is not None): |
| event_shape = tensor_shape.TensorShape([]) |
| for rss in static_event_shape_list: |
| event_shape = event_shape.concatenate(rss) |
| static_shape = result.shape.merge_with( |
| self.batch_shape.concatenate(event_shape)) |
| result.set_shape(static_shape) |
| return result |
| |
| def _validate_sample_arg(self, x): |
| """Helper which validates sample arg, e.g., input to `log_prob`.""" |
| with ops.name_scope(name="validate_sample_arg", values=[x]): |
| x_ndims = (array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims) |
| event_ndims = (array_ops.size(self.event_shape_tensor()) |
| if self.event_shape.ndims is None |
| else self.event_shape.ndims) |
| batch_ndims = ( |
| array_ops.size(self._batch_shape_unexpanded) |
| if self.batch_shape.ndims is None else self.batch_shape.ndims) |
| expected_batch_event_ndims = batch_ndims + event_ndims |
| |
| if (isinstance(x_ndims, int) and |
| isinstance(expected_batch_event_ndims, int)): |
| if x_ndims < expected_batch_event_ndims: |
| raise NotImplementedError( |
| "Broadcasting is not supported; too few batch and event dims " |
| "(expected at least {}, saw {}).".format( |
| expected_batch_event_ndims, x_ndims)) |
| ndims_assertion = [] |
| elif self.validate_args: |
| ndims_assertion = [ |
| check_ops.assert_greater_equal( |
| x_ndims, |
| expected_batch_event_ndims, |
| message=("Broadcasting is not supported; too few " |
| "batch and event dims."), |
| name="assert_batch_and_event_ndims_large_enough"), |
| ] |
| |
| if (self.batch_shape.is_fully_defined() and |
| self.event_shape.is_fully_defined()): |
| expected_batch_event_shape = np.int32(self.batch_shape.concatenate( |
| self.event_shape).as_list()) |
| else: |
| expected_batch_event_shape = array_ops.concat([ |
| self.batch_shape_tensor(), |
| self.event_shape_tensor(), |
| ], axis=0) |
| |
| sample_ndims = x_ndims - expected_batch_event_ndims |
| if isinstance(sample_ndims, int): |
| sample_ndims = max(sample_ndims, 0) |
| if (isinstance(sample_ndims, int) and |
| x.shape[sample_ndims:].is_fully_defined()): |
| actual_batch_event_shape = np.int32(x.shape[sample_ndims:].as_list()) |
| else: |
| sample_ndims = math_ops.maximum(sample_ndims, 0) |
| actual_batch_event_shape = array_ops.shape(x)[sample_ndims:] |
| |
| if (isinstance(expected_batch_event_shape, np.ndarray) and |
| isinstance(actual_batch_event_shape, np.ndarray)): |
| if any(expected_batch_event_shape != actual_batch_event_shape): |
| raise NotImplementedError("Broadcasting is not supported; " |
| "unexpected batch and event shape " |
| "(expected {}, saw {}).".format( |
| expected_batch_event_shape, |
| actual_batch_event_shape)) |
| # We need to set the final runtime-assertions to `ndims_assertion` since |
| # its possible this assertion was created. We could add a condition to |
| # only do so if `self.validate_args == True`, however this is redundant |
| # as `ndims_assertion` already encodes this information. |
| runtime_assertions = ndims_assertion |
| elif self.validate_args: |
| # We need to make the `ndims_assertion` a control dep because otherwise |
| # TF itself might raise an exception owing to this assertion being |
| # ill-defined, ie, one cannot even compare different rank Tensors. |
| with ops.control_dependencies(ndims_assertion): |
| shape_assertion = check_ops.assert_equal( |
| expected_batch_event_shape, |
| actual_batch_event_shape, |
| message=("Broadcasting is not supported; " |
| "unexpected batch and event shape."), |
| name="assert_batch_and_event_shape_same") |
| runtime_assertions = [shape_assertion] |
| else: |
| runtime_assertions = [] |
| |
| return runtime_assertions |
| |
| |
| @deprecation.deprecated( |
| "2018-10-01", |
| "The TensorFlow Distributions library has moved to " |
| "TensorFlow Probability " |
| "(https://github.com/tensorflow/probability). You " |
| "should update all references to use `tfp.distributions` " |
| "instead of `tf.contrib.distributions`.", |
| warn_once=True) |
| def calculate_reshape(original_shape, new_shape, validate=False, name=None): |
| """Calculates the reshaped dimensions (replacing up to one -1 in reshape).""" |
| batch_shape_static = tensor_util.constant_value_as_shape(new_shape) |
| if batch_shape_static.is_fully_defined(): |
| return np.int32(batch_shape_static.as_list()), batch_shape_static, [] |
| with ops.name_scope(name, "calculate_reshape", [original_shape, new_shape]): |
| original_size = math_ops.reduce_prod(original_shape) |
| implicit_dim = math_ops.equal(new_shape, -1) |
| size_implicit_dim = ( |
| original_size // math_ops.maximum(1, -math_ops.reduce_prod(new_shape))) |
| new_ndims = array_ops.shape(new_shape) |
| expanded_new_shape = array_ops.where( # Assumes exactly one `-1`. |
| implicit_dim, array_ops.fill(new_ndims, size_implicit_dim), new_shape) |
| validations = [] if not validate else [ |
| check_ops.assert_rank( |
| original_shape, 1, message="Original shape must be a vector."), |
| check_ops.assert_rank( |
| new_shape, 1, message="New shape must be a vector."), |
| check_ops.assert_less_equal( |
| math_ops.count_nonzero(implicit_dim, dtype=dtypes.int32), |
| 1, |
| message="At most one dimension can be unknown."), |
| check_ops.assert_positive( |
| expanded_new_shape, message="Shape elements must be >=-1."), |
| check_ops.assert_equal( |
| math_ops.reduce_prod(expanded_new_shape), |
| original_size, |
| message="Shape sizes do not match."), |
| ] |
| return expanded_new_shape, batch_shape_static, validations |
| |
| |
| @deprecation.deprecated( |
| "2018-10-01", |
| "The TensorFlow Distributions library has moved to " |
| "TensorFlow Probability " |
| "(https://github.com/tensorflow/probability). You " |
| "should update all references to use `tfp.distributions` " |
| "instead of `tf.contrib.distributions`.", |
| warn_once=True) |
| def validate_init_args_statically(distribution, batch_shape): |
| """Helper to __init__ which makes or raises assertions.""" |
| if batch_shape.shape.ndims is not None: |
| if batch_shape.shape.ndims != 1: |
| raise ValueError("`batch_shape` must be a vector " |
| "(saw rank: {}).".format(batch_shape.shape.ndims)) |
| |
| batch_shape_static = tensor_util.constant_value_as_shape(batch_shape) |
| batch_size_static = batch_shape_static.num_elements() |
| dist_batch_size_static = distribution.batch_shape.num_elements() |
| |
| if batch_size_static is not None and dist_batch_size_static is not None: |
| if batch_size_static != dist_batch_size_static: |
| raise ValueError("`batch_shape` size ({}) must match " |
| "`distribution.batch_shape` size ({}).".format( |
| batch_size_static, dist_batch_size_static)) |
| |
| if batch_shape_static.dims is not None: |
| if any( |
| dim.value is not None and dim.value < 1 for dim in batch_shape_static): |
| raise ValueError("`batch_shape` elements must be >=-1.") |