blob: 5c89607c1da4ccb9f66f2637e974ec634aa5617e [file] [log] [blame]
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.platform import test
linalg = linalg_lib
rng = np.random.RandomState(0)
class BaseLinearOperatorLowRankUpdatetest(object):
  """Base test for this type of operator.

  Subclasses mix this with a SquareLinearOperatorDerivedClassTest (or the
  NonSquare variant) and set the three `_use_*` / `_is_*` flags below to pick
  which variant of A = L + UDV^H is exercised.
  """

  # Subclasses should set these attributes to either True or False.

  # If True, A = L + UDV^H
  # If False, A = L + UV^H or A = L + UU^H, depending on _use_v.
  _use_diag_update = None

  # If True, diag is > 0, which means D is symmetric positive definite.
  _is_diag_update_positive = None

  # If True, A = L + UDV^H
  # If False, A = L + UDU^H or A = L + UU^H, depending on _use_diag_update
  _use_v = None

  @staticmethod
  def operator_shapes_infos():
    """Returns the list of operator shapes the test harness should exercise."""
    shape_info = linear_operator_test_util.OperatorShapesInfo
    # Previously we had a (2, 10, 10) shape at the end. We did this to test the
    # inversion and determinant lemmas on not-tiny matrices, since these are
    # known to have stability issues. This resulted in test timeouts, so this
    # shape has been removed, but rest assured, the tests did pass.
    return [
        shape_info((0, 0)),
        shape_info((1, 1)),
        shape_info((1, 3, 3)),
        shape_info((3, 4, 4)),
        shape_info((2, 1, 4, 4))]

  def _gen_positive_diag(self, dtype, diag_shape):
    """Returns a random diagonal with entries in (1e-4, 1), hence positive.

    For complex `dtype` the entries are drawn as float32 and cast, so the
    imaginary part is zero and "positive" is well defined.
    """
    if dtype.is_complex:
      diag = linear_operator_test_util.random_uniform(
          diag_shape, minval=1e-4, maxval=1., dtype=dtypes.float32)
      return math_ops.cast(diag, dtype=dtype)
    return linear_operator_test_util.random_uniform(
        diag_shape, minval=1e-4, maxval=1., dtype=dtype)

  def operator_and_matrix(self, shape_info, dtype, use_placeholder,
                          ensure_self_adjoint_and_pd=False):
    """Builds the operator under test and the dense matrix it should equal.

    Returns a `(operator, matrix)` pair where `operator` is a
    `LinearOperatorLowRankUpdate` configured per the class flags and `matrix`
    is the corresponding dense A = L + UDV^H (or a variant thereof), built
    from the same random tensors.  Also asserts that the operator's internal
    Cholesky decision matches what the configuration implies.
    """
    # Recall A = L + UDV^H
    shape = list(shape_info.shape)
    diag_shape = shape[:-1]
    # Rank of the update: roughly half the matrix dimension, at least 1.
    k = shape[-2] // 2 + 1
    u_perturbation_shape = shape[:-1] + [k]
    diag_update_shape = shape[:-2] + [k]
    # base_operator L will be a symmetric positive definite diagonal linear
    # operator, with condition number as high as 1e4.
    base_diag = self._gen_positive_diag(dtype, diag_shape)
    lin_op_base_diag = base_diag
    # U
    u = linear_operator_test_util.random_normal_correlated_columns(
        u_perturbation_shape, dtype=dtype)
    lin_op_u = u
    # V
    v = linear_operator_test_util.random_normal_correlated_columns(
        u_perturbation_shape, dtype=dtype)
    lin_op_v = v
    # D
    if self._is_diag_update_positive or ensure_self_adjoint_and_pd:
      diag_update = self._gen_positive_diag(dtype, diag_update_shape)
    else:
      diag_update = linear_operator_test_util.random_normal(
          diag_update_shape, stddev=1e-4, dtype=dtype)
    lin_op_diag_update = diag_update
    if use_placeholder:
      # Feed the same values through placeholders with unknown static shape,
      # so the dynamic-shape code paths are exercised.
      lin_op_base_diag = array_ops.placeholder_with_default(
          base_diag, shape=None)
      lin_op_u = array_ops.placeholder_with_default(u, shape=None)
      lin_op_v = array_ops.placeholder_with_default(v, shape=None)
      lin_op_diag_update = array_ops.placeholder_with_default(
          diag_update, shape=None)
    base_operator = linalg.LinearOperatorDiag(
        lin_op_base_diag,
        is_positive_definite=True,
        is_self_adjoint=True)
    operator = linalg.LinearOperatorLowRankUpdate(
        base_operator,
        lin_op_u,
        v=lin_op_v if self._use_v else None,
        diag_update=lin_op_diag_update if self._use_diag_update else None,
        is_diag_update_positive=self._is_diag_update_positive)
    # The matrix representing L
    base_diag_mat = array_ops.matrix_diag(base_diag)
    # The matrix representing D
    diag_update_mat = array_ops.matrix_diag(diag_update)
    # Set up mat as some variant of A = L + UDV^H
    if self._use_v and self._use_diag_update:
      # In this case, we have L + UDV^H and it isn't symmetric.
      expect_use_cholesky = False
      matrix = base_diag_mat + math_ops.matmul(
          u, math_ops.matmul(diag_update_mat, v, adjoint_b=True))
    elif self._use_v:
      # In this case, we have L + UDV^H and it isn't symmetric.
      expect_use_cholesky = False
      matrix = base_diag_mat + math_ops.matmul(u, v, adjoint_b=True)
    elif self._use_diag_update:
      # In this case, we have L + UDU^H, which is PD if D > 0, since L > 0.
      expect_use_cholesky = self._is_diag_update_positive
      matrix = base_diag_mat + math_ops.matmul(
          u, math_ops.matmul(diag_update_mat, u, adjoint_b=True))
    else:
      # In this case, we have L + UU^H, which is PD since L > 0.
      expect_use_cholesky = True
      matrix = base_diag_mat + math_ops.matmul(u, u, adjoint_b=True)
    if expect_use_cholesky:
      self.assertTrue(operator._use_cholesky)
    else:
      self.assertFalse(operator._use_cholesky)
    return operator, matrix
class LinearOperatorLowRankUpdatetestWithDiagUseCholesky(
    BaseLinearOperatorLowRankUpdatetest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Tests A = L + UDU^H with D > 0 and L > 0, so A > 0 and Cholesky works."""

  _use_diag_update = True
  _is_diag_update_positive = True
  _use_v = False

  def setUp(self):
    # The random matrices have condition numbers up to 1e4, so the default
    # comparison tolerances are too tight and must be loosened.
    self._atol.update({dtypes.float32: 1e-5, dtypes.float64: 1e-10})
    self._rtol.update({dtypes.float32: 1e-5,
                       dtypes.float64: 1e-10,
                       dtypes.complex64: 1e-4})
class LinearOperatorLowRankUpdatetestWithDiagCannotUseCholesky(
    BaseLinearOperatorLowRankUpdatetest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Tests A = L + UDU^H where D is not positive, so Cholesky is unusable."""

  @staticmethod
  def tests_to_skip():
    # A is not guaranteed positive definite, so the Cholesky tests must skip.
    return ["cholesky"]

  _use_diag_update = True
  _is_diag_update_positive = False
  _use_v = False

  def setUp(self):
    # Condition numbers reach 1e4, and without a Cholesky factorization the
    # solves are less accurate, so tolerances are looser than the Cholesky
    # variant's.
    self._atol.update({dtypes.float32: 1e-4, dtypes.float64: 1e-9})
    self._rtol.update({dtypes.float32: 1e-4,
                       dtypes.float64: 1e-9,
                       dtypes.complex64: 2e-4})
class LinearOperatorLowRankUpdatetestNoDiagUseCholesky(
    BaseLinearOperatorLowRankUpdatetest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Tests A = L + UU^H with L > 0, so A > 0 and Cholesky works."""

  _use_diag_update = False
  _is_diag_update_positive = None
  _use_v = False

  def setUp(self):
    # The random matrices have condition numbers up to 1e4, so the default
    # comparison tolerances are too tight and must be loosened.
    self._atol.update({dtypes.float32: 1e-5, dtypes.float64: 1e-10})
    self._rtol.update({dtypes.float32: 1e-5,
                       dtypes.float64: 1e-10,
                       dtypes.complex64: 1e-4})
class LinearOperatorLowRankUpdatetestNoDiagCannotUseCholesky(
    BaseLinearOperatorLowRankUpdatetest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Tests A = L + UV^H, which is not symmetric, so Cholesky is unusable."""

  @staticmethod
  def tests_to_skip():
    # A is not symmetric, so the Cholesky tests must skip.
    return ["cholesky"]

  _use_diag_update = False
  _is_diag_update_positive = None
  _use_v = True

  def setUp(self):
    # Condition numbers reach 1e4, and without a Cholesky factorization the
    # solves are less accurate, so tolerances are looser than the Cholesky
    # variant's.
    self._atol.update({dtypes.float32: 1e-4,
                       dtypes.float64: 1e-9,
                       dtypes.complex64: 1e-5})
    self._rtol.update({dtypes.float32: 1e-4,
                       dtypes.float64: 1e-9,
                       dtypes.complex64: 2e-4})
class LinearOperatorLowRankUpdatetestWithDiagNotSquare(
    BaseLinearOperatorLowRankUpdatetest,
    linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
  """A = L + UDV^H with a non-square base operator L.

  NOTE(review): the previous docstring ("A = L + UDU^H, D > 0, L > 0 ==>
  A > 0 and we can use a Cholesky") was copy-pasted from the square Cholesky
  case; here _use_v = True (so the update is UDV^H) and the operator is not
  square, where a Cholesky factorization does not apply.
  """
  _use_diag_update = True
  _is_diag_update_positive = True
  _use_v = True
class LinearOpearatorLowRankUpdateBroadcastsShape(test.TestCase):
  """Checks the operator's shape is the broadcast of its arguments' shapes."""

  def test_static_shape_broadcasts_up_from_operator_to_other_args(self):
    expected_shape = [2, 3, 3]
    identity = linalg.LinearOperatorIdentity(num_rows=3)
    u_arg = array_ops.ones(shape=[2, 3, 2])
    diag_arg = array_ops.ones(shape=[2, 2])
    operator = linalg.LinearOperatorLowRankUpdate(identity, u_arg, diag_arg)
    # domain_dimension is 3; the leading batch dim [2] broadcasts up from the
    # u and diag arguments.
    self.assertAllEqual(expected_shape, operator.shape)
    self.assertAllEqual(expected_shape,
                        self.evaluate(operator.to_dense()).shape)

  @test_util.run_deprecated_v1
  def test_dynamic_shape_broadcasts_up_from_operator_to_other_args(self):
    # Both the identity's size and u's shape are only known at run time.
    num_rows_ph = array_ops.placeholder(dtypes.int32)
    u_shape_ph = array_ops.placeholder(dtypes.int32)
    identity = linalg.LinearOperatorIdentity(num_rows=num_rows_ph)
    operator = linalg.LinearOperatorLowRankUpdate(
        identity, array_ops.ones(shape=u_shape_ph))
    feed_dict = {
        num_rows_ph: 3,
        u_shape_ph: [2, 3, 2],  # batch_shape = [2]
    }
    with self.cached_session():
      self.assertAllEqual(
          [2, 3, 3], operator.shape_tensor().eval(feed_dict=feed_dict))
      self.assertAllEqual(
          [2, 3, 3], operator.to_dense().eval(feed_dict=feed_dict).shape)

  def test_u_and_v_incompatible_batch_shape_raises(self):
    identity = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
    # Batch dims 5 vs 4 cannot broadcast.
    with self.assertRaisesRegexp(ValueError, "Incompatible shapes"):
      linalg.LinearOperatorLowRankUpdate(
          identity, u=rng.rand(5, 3, 2), v=rng.rand(4, 3, 2))

  def test_u_and_base_operator_incompatible_batch_shape_raises(self):
    identity = linalg.LinearOperatorIdentity(
        num_rows=3, batch_shape=[4], dtype=np.float64)
    # Operator batch dim 4 vs u batch dim 5 cannot broadcast.
    with self.assertRaisesRegexp(ValueError, "Incompatible shapes"):
      linalg.LinearOperatorLowRankUpdate(identity, u=rng.rand(5, 3, 2))

  def test_u_and_base_operator_incompatible_domain_dimension(self):
    identity = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
    # u has 4 rows but the operator's domain dimension is 3.
    with self.assertRaisesRegexp(ValueError, "not compatible"):
      linalg.LinearOperatorLowRankUpdate(identity, u=rng.rand(5, 4, 2))

  def test_u_and_diag_incompatible_low_rank_raises(self):
    identity = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
    bad_diag = rng.rand(5, 4)  # Last dimension should be 2
    with self.assertRaisesRegexp(ValueError, "not compatible"):
      linalg.LinearOperatorLowRankUpdate(
          identity, u=rng.rand(5, 3, 2), diag_update=bad_diag)

  def test_diag_incompatible_batch_shape_raises(self):
    identity = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
    bad_diag = rng.rand(4, 2)  # First dimension should be 5
    with self.assertRaisesRegexp(ValueError, "Incompatible shapes"):
      linalg.LinearOperatorLowRankUpdate(
          identity, u=rng.rand(5, 3, 2), diag_update=bad_diag)
if __name__ == "__main__":
  # Register the parameterized operator tests on each concrete test class,
  # then hand control to the test runner.
  for test_class in (
      LinearOperatorLowRankUpdatetestWithDiagUseCholesky,
      LinearOperatorLowRankUpdatetestWithDiagCannotUseCholesky,
      LinearOperatorLowRankUpdatetestNoDiagUseCholesky,
      LinearOperatorLowRankUpdatetestNoDiagCannotUseCholesky,
      LinearOperatorLowRankUpdatetestWithDiagNotSquare):
    linear_operator_test_util.add_tests(test_class)
  test.main()