# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Monte Carlo Ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from tensorflow.contrib import layers as layers_lib
from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo_lib
from tensorflow.contrib.bayesflow.python.ops.monte_carlo_impl import _get_samples
from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.platform import test


layers = layers_lib
mc = monte_carlo_lib


class ExpectationImportanceSampleTest(test.TestCase):

  def test_normal_integral_mean_and_var_correctly_estimated(self):
n = int(1e6)
with self.test_session():
mu_p = constant_op.constant([-1.0, 1.0], dtype=dtypes.float64)
mu_q = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
sigma_p = constant_op.constant([0.5, 0.5], dtype=dtypes.float64)
sigma_q = constant_op.constant([1.0, 1.0], dtype=dtypes.float64)
p = normal_lib.Normal(loc=mu_p, scale=sigma_p)
q = normal_lib.Normal(loc=mu_q, scale=sigma_q)
# Compute E_p[X].
e_x = mc.expectation_importance_sampler(
f=lambda x: x, log_p=p.log_prob, sampling_dist_q=q, n=n, seed=42)
# Compute E_p[X^2].
e_x2 = mc.expectation_importance_sampler(
f=math_ops.square, log_p=p.log_prob, sampling_dist_q=q, n=n, seed=42)
stddev = math_ops.sqrt(e_x2 - math_ops.square(e_x))
      # Relative tolerance (rtol) chosen to be twice the minimum needed to
      # pass.
      # Convergence of the mean is +- 0.003 if n = 100M.
      # Convergence of the stddev is +- 0.00001 if n = 100M.
self.assertEqual(p.batch_shape, e_x.get_shape())
self.assertAllClose(p.mean().eval(), e_x.eval(), rtol=0.01)
self.assertAllClose(p.stddev().eval(), stddev.eval(), rtol=0.02)
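
  # The estimator exercised above relies on the importance-sampling identity
  #   E_p[f(X)] = E_q[f(X) * p(X) / q(X)],
  # i.e., with draws z ~ q, a minimal NumPy-style sketch of `e_x` (using
  # hypothetical helpers `p_log_prob` and `q_log_prob`) would be:
  #   w = np.exp(p_log_prob(z) - q_log_prob(z))
  #   e_x = np.mean(z * w)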

  def test_multivariate_normal_prob_positive_product_of_components(self):
# Test that importance sampling can correctly estimate the probability that
# the product of components in a MultivariateNormal are > 0.
n = 1000
with self.test_session():
      p = mvn_diag_lib.MultivariateNormalDiag(
          loc=[0., 0.], scale_diag=[1.0, 1.0])
      q = mvn_diag_lib.MultivariateNormalDiag(
          loc=[0.5, 0.5], scale_diag=[3., 3.])
      # Compute E_p[1{X_1 * X_2 > 0}], with X_i the ith component of X ~ p(x).
      # Should equal 1/2: p is a spherical Gaussian centered at (0, 0), so by
      # symmetry P(both components positive) = P(both negative) = 1/4.
def indicator(x):
        x1_times_x2 = math_ops.reduce_prod(x, axis=[-1])
return 0.5 * (math_ops.sign(x1_times_x2) + 1.0)
prob = mc.expectation_importance_sampler(
f=indicator, log_p=p.log_prob, sampling_dist_q=q, n=n, seed=42)
      # Relative tolerance (rtol) chosen to be twice the minimum needed to
      # pass.
      # Convergence is +- 0.004 if n = 100k.
self.assertEqual(p.batch_shape, prob.get_shape())
self.assertAllClose(0.5, prob.eval(), rtol=0.05)
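

# The log-space variant exercised below estimates log E_p[f(X)] via
#   log E_p[f(X)] = log E_q[exp(log_f(Z) + log_p(Z) - log_q(Z))],  Z ~ q,
# the preferable form when f(x) * p(x) / q(x) would overflow or underflow
# if accumulated in linear space.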


class ExpectationImportanceSampleLogspaceTest(test.TestCase):

  def test_normal_distribution_second_moment_estimated_correctly(self):
# Test the importance sampled estimate against an analytical result.
n = int(1e6)
with self.test_session():
mu_p = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
mu_q = constant_op.constant([-1.0, 1.0], dtype=dtypes.float64)
sigma_p = constant_op.constant([1.0, 2 / 3.], dtype=dtypes.float64)
sigma_q = constant_op.constant([1.0, 1.0], dtype=dtypes.float64)
p = normal_lib.Normal(loc=mu_p, scale=sigma_p)
q = normal_lib.Normal(loc=mu_q, scale=sigma_q)
      # Compute E_p[X^2]. Since mu_p = 0 this is Var_p[X], so it should
      # equal sigma_p^2 = [1., (2 / 3.)**2].
log_e_x2 = mc.expectation_importance_sampler_logspace(
log_f=lambda x: math_ops.log(math_ops.square(x)),
log_p=p.log_prob,
sampling_dist_q=q,
n=n,
seed=42)
e_x2 = math_ops.exp(log_e_x2)
      # Relative tolerance (rtol) chosen to be twice the minimum needed to
      # pass.
self.assertEqual(p.batch_shape, e_x2.get_shape())
self.assertAllClose([1., (2 / 3.)**2], e_x2.eval(), rtol=0.02)
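

# `_get_samples(dist, z, n, seed)` is a small plumbing helper: per the tests
# below it requires exactly one of `z` (pre-drawn samples, returned as-is)
# or `n` (number of fresh draws to take from `dist`).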


class GetSamplesTest(test.TestCase):
  """Tests for the private function `_get_samples`."""

  def test_raises_if_both_z_and_n_are_none(self):
with self.test_session():
dist = normal_lib.Normal(loc=0., scale=1.)
z = None
n = None
seed = None
with self.assertRaisesRegexp(ValueError, 'exactly one'):
_get_samples(dist, z, n, seed)

  def test_raises_if_both_z_and_n_are_not_none(self):
with self.test_session():
dist = normal_lib.Normal(loc=0., scale=1.)
z = dist.sample(seed=42)
n = 1
seed = None
with self.assertRaisesRegexp(ValueError, 'exactly one'):
_get_samples(dist, z, n, seed)

  def test_returns_n_samples_if_n_provided(self):
with self.test_session():
dist = normal_lib.Normal(loc=0., scale=1.)
z = None
n = 10
seed = None
z = _get_samples(dist, z, n, seed)
self.assertEqual((10,), z.get_shape())

  def test_returns_z_if_z_provided(self):
with self.test_session():
dist = normal_lib.Normal(loc=0., scale=1.)
z = dist.sample(10, seed=42)
n = None
seed = None
z = _get_samples(dist, z, n, seed)
self.assertEqual((10,), z.get_shape())
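

# `mc.expectation(f, samples, log_prob, use_reparametrization=...)` averages
# f over the given samples; the flag selects how gradients are estimated.
# With reparameterization, gradients flow through the samples themselves.
# Without it, gradients follow the score-function (REINFORCE) identity
#   d/dtheta E_p[f(X)] = E_p[f(X) * d/dtheta log p(X; theta)],
# which is unbiased but typically has higher variance -- hence the looser,
# partial check on `efx_score_grad_` in the test below.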


class ExpectationTest(test.TestCase):

  def test_works_correctly(self):
with self.test_session() as sess:
x = constant_op.constant([-1e6, -100, -10, -1, 1, 10, 100, 1e6])
p = normal_lib.Normal(loc=x, scale=1.)
      # We use the prefix "efx" to mean "E_p[f(X)]".
f = lambda u: u
efx_true = x
samples = p.sample(int(1e5), seed=1)
efx_reparam = mc.expectation(f, samples, p.log_prob)
efx_score = mc.expectation(f, samples, p.log_prob,
use_reparametrization=False)
[
efx_true_,
efx_reparam_,
efx_score_,
efx_true_grad_,
efx_reparam_grad_,
efx_score_grad_,
] = sess.run([
efx_true,
efx_reparam,
efx_score,
gradients_impl.gradients(efx_true, x)[0],
gradients_impl.gradients(efx_reparam, x)[0],
gradients_impl.gradients(efx_score, x)[0],
])
self.assertAllEqual(np.ones_like(efx_true_grad_), efx_true_grad_)
self.assertAllClose(efx_true_, efx_reparam_, rtol=0.005, atol=0.)
self.assertAllClose(efx_true_, efx_score_, rtol=0.005, atol=0.)
      self.assertAllEqual(np.ones_like(efx_true_grad_, dtype=np.bool_),
                          np.isfinite(efx_reparam_grad_))
      self.assertAllEqual(np.ones_like(efx_true_grad_, dtype=np.bool_),
                          np.isfinite(efx_score_grad_))
self.assertAllClose(efx_true_grad_, efx_reparam_grad_,
rtol=0.03, atol=0.)
      # The score-function gradient estimate has too much variance at the
      # extreme inputs to be meaningful, so only check the entries that
      # converge.
self.assertAllClose(efx_true_grad_[2:-2],
efx_score_grad_[2:-2],
rtol=0.05, atol=0.)

  def test_docstring_example_normal(self):
with self.test_session() as sess:
num_draws = int(1e5)
mu_p = constant_op.constant(0.)
mu_q = constant_op.constant(1.)
p = normal_lib.Normal(loc=mu_p, scale=1.)
q = normal_lib.Normal(loc=mu_q, scale=2.)
exact_kl_normal_normal = kullback_leibler.kl_divergence(p, q)
approx_kl_normal_normal = monte_carlo_lib.expectation(
f=lambda x: p.log_prob(x) - q.log_prob(x),
samples=p.sample(num_draws, seed=42),
log_prob=p.log_prob,
use_reparametrization=(p.reparameterization_type
== distribution_lib.FULLY_REPARAMETERIZED))
[exact_kl_normal_normal_, approx_kl_normal_normal_] = sess.run([
exact_kl_normal_normal, approx_kl_normal_normal])
      self.assertTrue(
          p.reparameterization_type == distribution_lib.FULLY_REPARAMETERIZED)
self.assertAllClose(exact_kl_normal_normal_, approx_kl_normal_normal_,
rtol=0.01, atol=0.)
      # Compare gradients. (Not present in the docstring example.)
gradp = lambda fp: gradients_impl.gradients(fp, mu_p)[0]
gradq = lambda fq: gradients_impl.gradients(fq, mu_q)[0]
[
gradp_exact_kl_normal_normal_,
gradq_exact_kl_normal_normal_,
gradp_approx_kl_normal_normal_,
gradq_approx_kl_normal_normal_,
] = sess.run([
gradp(exact_kl_normal_normal),
gradq(exact_kl_normal_normal),
gradp(approx_kl_normal_normal),
gradq(approx_kl_normal_normal),
])
self.assertAllClose(gradp_exact_kl_normal_normal_,
gradp_approx_kl_normal_normal_,
rtol=0.01, atol=0.)
self.assertAllClose(gradq_exact_kl_normal_normal_,
gradq_approx_kl_normal_normal_,
rtol=0.01, atol=0.)
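

# For reference, the exact Normal-Normal KL above has the closed form
#   KL(N(m_p, s_p) || N(m_q, s_q))
#     = log(s_q / s_p) + (s_p**2 + (m_p - m_q)**2) / (2 * s_q**2) - 0.5,
# which for m_p=0, s_p=1, m_q=1, s_q=2 evaluates to log(2) - 0.25 ~= 0.443.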


if __name__ == '__main__':
test.main()