# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib import distributions
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test

tfd = distributions


class DistributionTest(test.TestCase):

  def testParamShapesAndFromParams(self):
    classes = [
        tfd.Normal,
        tfd.Bernoulli,
        tfd.Beta,
        tfd.Chi2,
        tfd.Exponential,
        tfd.Gamma,
        tfd.InverseGamma,
        tfd.Laplace,
        tfd.StudentT,
        tfd.Uniform,
    ]
    sample_shapes = [(), (10,), (10, 20, 30)]
    with self.cached_session():
      for cls in classes:
        for sample_shape in sample_shapes:
          param_shapes = cls.param_shapes(sample_shape)
          params = dict([(name, random_ops.random_normal(shape))
                         for name, shape in param_shapes.items()])
          dist = cls(**params)
          self.assertAllEqual(sample_shape,
                              array_ops.shape(dist.sample()).eval())
          dist_copy = dist.copy()
          self.assertAllEqual(sample_shape,
                              array_ops.shape(dist_copy.sample()).eval())
          self.assertEqual(dist.parameters, dist_copy.parameters)
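
  # For reference: `cls.param_shapes(sample_shape)` maps each parameter name to
  # the shape needed so that `cls(**params).sample()` has shape `sample_shape`,
  # which is exactly what the loop above asserts. A minimal sketch, assuming the
  # contrib `tfd.Normal` API exercised here (shapes come back as tensors):
  #
  #   tfd.Normal.param_shapes([10, 20, 30])
  #   # => roughly {"loc": [10, 20, 30], "scale": [10, 20, 30]}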

  def testCopyExtraArgs(self):
    with self.cached_session():
      # Note: we cannot easily test all distributions since each requires
      # different initialization arguments. We therefore spot test a few.
      normal = tfd.Normal(loc=1., scale=2., validate_args=True)
      self.assertEqual(normal.parameters, normal.copy().parameters)
      wishart = tfd.WishartFull(df=2, scale=[[1., 2], [2, 5]],
                                validate_args=True)
      self.assertEqual(wishart.parameters, wishart.copy().parameters)

  def testCopyOverride(self):
    with self.cached_session():
      normal = tfd.Normal(loc=1., scale=2., validate_args=True)
      unused_normal_copy = normal.copy(validate_args=False)
      base_params = normal.parameters.copy()
      copy_params = normal.copy(validate_args=False).parameters.copy()
      self.assertNotEqual(
          base_params.pop("validate_args"), copy_params.pop("validate_args"))
      self.assertEqual(base_params, copy_params)
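
  # For reference: `dist.copy(**overrides)` rebuilds the distribution from
  # `dist.parameters`, with any keyword arguments replacing the originals, as
  # the two copy tests above exercise. A minimal sketch of the same idea:
  #
  #   normal = tfd.Normal(loc=1., scale=2., validate_args=True)
  #   relaxed = normal.copy(validate_args=False)
  #   # `relaxed` keeps loc/scale but has validate_args=False.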

  def testIsScalar(self):
    with self.cached_session():
      mu = 1.
      sigma = 2.

      normal = tfd.Normal(mu, sigma, validate_args=True)
      self.assertTrue(tensor_util.constant_value(normal.is_scalar_event()))
      self.assertTrue(tensor_util.constant_value(normal.is_scalar_batch()))

      normal = tfd.Normal([mu], [sigma], validate_args=True)
      self.assertTrue(tensor_util.constant_value(normal.is_scalar_event()))
      self.assertFalse(tensor_util.constant_value(normal.is_scalar_batch()))

      mvn = tfd.MultivariateNormalDiag([mu], [sigma], validate_args=True)
      self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event()))
      self.assertTrue(tensor_util.constant_value(mvn.is_scalar_batch()))

      mvn = tfd.MultivariateNormalDiag([[mu]], [[sigma]], validate_args=True)
      self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event()))
      self.assertFalse(tensor_util.constant_value(mvn.is_scalar_batch()))

      # We now test every codepath within the underlying is_scalar_helper
      # function.

      # Test case 1, 2.
      x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
      # None would fire an exception were it actually executed.
      self.assertTrue(normal._is_scalar_helper(x.get_shape(), lambda: None))
      self.assertTrue(
          normal._is_scalar_helper(tensor_shape.TensorShape(None),
                                   lambda: array_ops.shape(x)))

      x = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      # None would fire an exception were it actually executed.
      self.assertFalse(normal._is_scalar_helper(x.get_shape(), lambda: None))
      self.assertFalse(
          normal._is_scalar_helper(tensor_shape.TensorShape(None),
                                   lambda: array_ops.shape(x)))

      # Test case 3.
      x = array_ops.placeholder(dtype=dtypes.int32)
      is_scalar = normal._is_scalar_helper(x.get_shape(),
                                           lambda: array_ops.shape(x))
      self.assertTrue(is_scalar.eval(feed_dict={x: 1}))
      self.assertFalse(is_scalar.eval(feed_dict={x: [1]}))
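
  # For reference: `is_scalar_event()` / `is_scalar_batch()` report whether the
  # event shape / batch shape has rank 0, as the four constructions above show
  # (same contrib API, illustrative values):
  #
  #   tfd.Normal(1., 2.)                      # scalar event, scalar batch
  #   tfd.Normal([1.], [2.])                  # scalar event, batch_shape=[1]
  #   tfd.MultivariateNormalDiag([1.], [2.])  # event_shape=[1], scalar batch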

  def _GetFakeDistribution(self):
    class FakeDistribution(tfd.Distribution):
      """Fake Distribution for testing _set_sample_static_shape."""

      def __init__(self, batch_shape=None, event_shape=None):
        self._static_batch_shape = tensor_shape.TensorShape(batch_shape)
        self._static_event_shape = tensor_shape.TensorShape(event_shape)
        super(FakeDistribution, self).__init__(
            dtype=dtypes.float32,
            reparameterization_type=distributions.NOT_REPARAMETERIZED,
            validate_args=True,
            allow_nan_stats=True,
            name="DummyDistribution")

      def _batch_shape(self):
        return self._static_batch_shape

      def _event_shape(self):
        return self._static_event_shape

    return FakeDistribution
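
  # For reference: the fake distribution only overrides `_batch_shape` and
  # `_event_shape`, which is all the base class's static-shape bookkeeping
  # (`_set_sample_static_shape`, exercised next) needs; it never has to
  # implement sampling or densities.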

  def testSampleShapeHints(self):
    fake_distribution = self._GetFakeDistribution()

    with self.cached_session():
      # Make a new session since we're playing with static shapes. [And below.]
      x = array_ops.placeholder(dtype=dtypes.float32)
      dist = fake_distribution(batch_shape=[2, 3], event_shape=[5])
      sample_shape = ops.convert_to_tensor([6, 7], dtype=dtypes.int32)
      y = dist._set_sample_static_shape(x, sample_shape)
      # We use as_list since TensorShape comparison does not work correctly for
      # unknown values, ie, Dimension(None).
      self.assertAllEqual([6, 7, 2, 3, 5], y.get_shape().as_list())

    with self.cached_session():
      x = array_ops.placeholder(dtype=dtypes.float32)
      dist = fake_distribution(batch_shape=[None, 3], event_shape=[5])
      sample_shape = ops.convert_to_tensor([6, 7], dtype=dtypes.int32)
      y = dist._set_sample_static_shape(x, sample_shape)
      self.assertAllEqual([6, 7, None, 3, 5], y.get_shape().as_list())

    with self.cached_session():
      x = array_ops.placeholder(dtype=dtypes.float32)
      dist = fake_distribution(batch_shape=[None, 3], event_shape=[None])
      sample_shape = ops.convert_to_tensor([6, 7], dtype=dtypes.int32)
      y = dist._set_sample_static_shape(x, sample_shape)
      self.assertAllEqual([6, 7, None, 3, None], y.get_shape().as_list())

    with self.cached_session():
      x = array_ops.placeholder(dtype=dtypes.float32)
      dist = fake_distribution(batch_shape=None, event_shape=None)
      sample_shape = ops.convert_to_tensor([6, 7], dtype=dtypes.int32)
      y = dist._set_sample_static_shape(x, sample_shape)
      self.assertTrue(y.get_shape().ndims is None)

    with self.cached_session():
      x = array_ops.placeholder(dtype=dtypes.float32)
      dist = fake_distribution(batch_shape=[None, 3], event_shape=None)
      sample_shape = ops.convert_to_tensor([6, 7], dtype=dtypes.int32)
      y = dist._set_sample_static_shape(x, sample_shape)
      self.assertTrue(y.get_shape().ndims is None)
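
  # For reference: `_set_sample_static_shape` hints the sample's static shape
  # as `sample_shape + batch_shape + event_shape`, keeping `None` for any
  # dimension that is not statically known, and it gives up entirely (ndims is
  # None) when the batch or event shape has unknown rank, as the cases above
  # demonstrate.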

  def testNameScopeWorksCorrectly(self):
    x = tfd.Normal(loc=0., scale=1., name="x")
    x_duplicate = tfd.Normal(loc=0., scale=1., name="x")
    with ops.name_scope("y") as name:
      y = tfd.Bernoulli(logits=0., name=name)
    x_sample = x.sample(name="custom_sample")
    x_sample_duplicate = x.sample(name="custom_sample")
    x_log_prob = x.log_prob(0., name="custom_log_prob")
    x_duplicate_sample = x_duplicate.sample(name="custom_sample")

    self.assertEqual(x.name, "x/")
    self.assertEqual(x_duplicate.name, "x_1/")
    self.assertEqual(y.name, "y/")
    self.assertTrue(x_sample.name.startswith("x/custom_sample"))
    self.assertTrue(x_sample_duplicate.name.startswith("x/custom_sample_1"))
    self.assertTrue(x_log_prob.name.startswith("x/custom_log_prob"))
    self.assertTrue(x_duplicate_sample.name.startswith(
        "x_1/custom_sample"))

  def testStrWorksCorrectlyScalar(self):
    normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1))
    self.assertEqual(
        ("tfp.distributions.Normal("
         "\"Normal/\", "
         "batch_shape=(), "
         "event_shape=(), "
         "dtype=float16)"),  # Got the dtype right.
        str(normal))

    chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly")
    self.assertEqual(
        ("tfp.distributions.Chi2("
         "\"silly/\", "  # What a silly name that is!
         "batch_shape=(2,), "
         "event_shape=(), "
         "dtype=float32)"),
        str(chi2))

    exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32))
    self.assertEqual(
        ("tfp.distributions.Exponential(\"Exponential/\", "
         # No batch shape.
         "event_shape=(), "
         "dtype=float32)"),
        str(exp))

  def testStrWorksCorrectlyMultivariate(self):
    mvn_static = tfd.MultivariateNormalDiag(
        loc=np.zeros([2, 2]), name="MVN")
    self.assertEqual(
        ("tfp.distributions.MultivariateNormalDiag("
         "\"MVN/\", "
         "batch_shape=(2,), "
         "event_shape=(2,), "
         "dtype=float64)"),
        str(mvn_static))

    mvn_dynamic = tfd.MultivariateNormalDiag(
        loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32),
        name="MVN2")
    self.assertEqual(
        ("tfp.distributions.MultivariateNormalDiag("
         "\"MVN2/\", "
         "batch_shape=(?,), "  # Partially known.
         "event_shape=(3,), "
         "dtype=float32)"),
        str(mvn_dynamic))
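
  # For reference: with a `loc` of static shape [2, 2], the rightmost dimension
  # becomes event_shape=(2,) and the remaining leading dimension becomes
  # batch_shape=(2,); a [None, 3] placeholder therefore reports
  # batch_shape=(?,) and event_shape=(3,), matching the expected strings above.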

  def testReprWorksCorrectlyScalar(self):
    normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1))
    self.assertEqual(
        ("<tfp.distributions.Normal"
         " 'Normal/'"
         " batch_shape=()"
         " event_shape=()"
         " dtype=float16>"),  # Got the dtype right.
        repr(normal))

    chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly")
    self.assertEqual(
        ("<tfp.distributions.Chi2"
         " 'silly/'"  # What a silly name that is!
         " batch_shape=(2,)"
         " event_shape=()"
         " dtype=float32>"),
        repr(chi2))

    exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32))
    self.assertEqual(
        ("<tfp.distributions.Exponential"
         " 'Exponential/'"
         " batch_shape=<unknown>"
         " event_shape=()"
         " dtype=float32>"),
        repr(exp))
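
  # For reference: when the batch shape is statically unknown (the placeholder
  # `rate` above), `str` simply omits the batch_shape field while `repr` prints
  # "batch_shape=<unknown>"; compare the Exponential cases here and in
  # testStrWorksCorrectlyScalar.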

  def testReprWorksCorrectlyMultivariate(self):
    mvn_static = tfd.MultivariateNormalDiag(
        loc=np.zeros([2, 2]), name="MVN")
    self.assertEqual(
        ("<tfp.distributions.MultivariateNormalDiag"
         " 'MVN/'"
         " batch_shape=(2,)"
         " event_shape=(2,)"
         " dtype=float64>"),
        repr(mvn_static))

    mvn_dynamic = tfd.MultivariateNormalDiag(
        loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32),
        name="MVN2")
    self.assertEqual(
        ("<tfp.distributions.MultivariateNormalDiag"
         " 'MVN2/'"
         " batch_shape=(?,)"  # Partially known.
         " event_shape=(3,)"
         " dtype=float32>"),
        repr(mvn_dynamic))


if __name__ == "__main__":
  test.main()