Add LinearOperatorTridiag for tridiagonal matrices.
- Removed the restriction that shapes be statically known for tridiagonal_matmul and tridiagonal_solve.
PiperOrigin-RevId: 285830396
Change-Id: I2291ee1111f2a41e658b085cbd402d2977acb939
diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD
index 7294462..eb732fb 100644
--- a/tensorflow/python/kernel_tests/linalg/BUILD
+++ b/tensorflow/python/kernel_tests/linalg/BUILD
@@ -383,6 +383,29 @@
)
cuda_py_test(
+ name = "linear_operator_tridiag_test",
+ size = "medium",
+ srcs = ["linear_operator_tridiag_test.py"],
+ additional_deps = [
+ "//tensorflow/python/ops/linalg",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:linalg_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:random_ops",
+ ],
+ shard_count = 5,
+ tags = [
+ "noasan",
+ "optonly",
+ ],
+ xla_enable_strict_auto_jit = True,
+)
+
+cuda_py_test(
name = "linear_operator_zeros_test",
size = "medium",
srcs = ["linear_operator_zeros_test.py"],
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_tridiag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_tridiag_test.py
new file mode 100644
index 0000000..d69f872
--- /dev/null
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_tridiag_test.py
@@ -0,0 +1,184 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import manip_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables as variables_module
+from tensorflow.python.ops.linalg import linalg as linalg_lib
+from tensorflow.python.ops.linalg import linear_operator_test_util
+from tensorflow.python.platform import test
+
+
+class _LinearOperatorTriDiagBase(object):
+
+ def build_operator_and_matrix(
+ self, build_info, dtype, use_placeholder,
+ ensure_self_adjoint_and_pd=False,
+ diagonals_format='sequence'):
+ shape = list(build_info.shape)
+
+ # Ensure that diagonal has large enough values. If we generate a
+ # self adjoint PD matrix, then the diagonal will be dominant guaranteeing
+    # positive definiteness.
+ diag = linear_operator_test_util.random_sign_uniform(
+ shape[:-1], minval=4., maxval=6., dtype=dtype)
+ # We'll truncate these depending on the format
+ subdiag = linear_operator_test_util.random_sign_uniform(
+ shape[:-1], minval=1., maxval=2., dtype=dtype)
+ if ensure_self_adjoint_and_pd:
+ # Abs on complex64 will result in a float32, so we cast back up.
+ diag = math_ops.cast(math_ops.abs(diag), dtype=dtype)
+ # The first element of subdiag is ignored. We'll add a dummy element
+ # to superdiag to pad it.
+ superdiag = math_ops.conj(subdiag)
+ superdiag = manip_ops.roll(superdiag, shift=-1, axis=-1)
+ else:
+ superdiag = linear_operator_test_util.random_sign_uniform(
+ shape[:-1], minval=1., maxval=2., dtype=dtype)
+
+ matrix_diagonals = array_ops.stack(
+ [superdiag, diag, subdiag], axis=-2)
+ matrix = gen_array_ops.matrix_diag_v3(
+ matrix_diagonals,
+ k=(-1, 1),
+ num_rows=-1,
+ num_cols=-1,
+ align='LEFT_RIGHT',
+ padding_value=0.)
+
+ if diagonals_format == 'sequence':
+ diagonals = [superdiag, diag, subdiag]
+ elif diagonals_format == 'compact':
+ diagonals = array_ops.stack([superdiag, diag, subdiag], axis=-2)
+ elif diagonals_format == 'matrix':
+ diagonals = matrix
+
+ lin_op_diagonals = diagonals
+
+ if use_placeholder:
+ if diagonals_format == 'sequence':
+ lin_op_diagonals = [array_ops.placeholder_with_default(
+ d, shape=None) for d in lin_op_diagonals]
+ else:
+ lin_op_diagonals = array_ops.placeholder_with_default(
+ lin_op_diagonals, shape=None)
+
+ operator = linalg_lib.LinearOperatorTridiag(
+ diagonals=lin_op_diagonals,
+ diagonals_format=diagonals_format,
+ is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
+ is_positive_definite=True if ensure_self_adjoint_and_pd else None)
+ return operator, matrix
+
+ @staticmethod
+ def operator_shapes_infos():
+ shape_info = linear_operator_test_util.OperatorShapesInfo
+ # non-batch operators (n, n) and batch operators.
+ return [
+ shape_info((3, 3)),
+ shape_info((1, 6, 6)),
+ shape_info((3, 4, 4)),
+ shape_info((2, 1, 3, 3))
+ ]
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class LinearOperatorTriDiagCompactTest(
+ _LinearOperatorTriDiagBase,
+ linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
+ """Most tests done in the base class LinearOperatorDerivedClassTest."""
+
+ def operator_and_matrix(
+ self, build_info, dtype, use_placeholder,
+ ensure_self_adjoint_and_pd=False):
+ return self.build_operator_and_matrix(
+ build_info, dtype, use_placeholder,
+ ensure_self_adjoint_and_pd=ensure_self_adjoint_and_pd,
+ diagonals_format='compact')
+
+ def test_tape_safe(self):
+ diag = variables_module.Variable([[3., 6., 2.], [2., 4., 2.], [5., 1., 2.]])
+ operator = linalg_lib.LinearOperatorTridiag(
+ diag, diagonals_format='compact')
+ self.check_tape_safe(operator)
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class LinearOperatorTriDiagSequenceTest(
+ _LinearOperatorTriDiagBase,
+ linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
+ """Most tests done in the base class LinearOperatorDerivedClassTest."""
+
+ def operator_and_matrix(
+ self, build_info, dtype, use_placeholder,
+ ensure_self_adjoint_and_pd=False):
+ return self.build_operator_and_matrix(
+ build_info, dtype, use_placeholder,
+ ensure_self_adjoint_and_pd=ensure_self_adjoint_and_pd,
+ diagonals_format='sequence')
+
+ def test_tape_safe(self):
+ diagonals = [
+ variables_module.Variable([3., 6., 2.]),
+ variables_module.Variable([2., 4., 2.]),
+ variables_module.Variable([5., 1., 2.])]
+ operator = linalg_lib.LinearOperatorTridiag(
+ diagonals, diagonals_format='sequence')
+      # Skip the diag_part and trace tests since those only depend on the
+      # middle variable. We test this below.
+ self.check_tape_safe(operator, skip_options=['diag_part', 'trace'])
+
+ diagonals = [
+ [3., 6., 2.],
+ variables_module.Variable([2., 4., 2.]),
+ [5., 1., 2.]
+ ]
+ operator = linalg_lib.LinearOperatorTridiag(
+ diagonals, diagonals_format='sequence')
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class LinearOperatorTriDiagMatrixTest(
+ _LinearOperatorTriDiagBase,
+ linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
+ """Most tests done in the base class LinearOperatorDerivedClassTest."""
+
+ def operator_and_matrix(
+ self, build_info, dtype, use_placeholder,
+ ensure_self_adjoint_and_pd=False):
+ return self.build_operator_and_matrix(
+ build_info, dtype, use_placeholder,
+ ensure_self_adjoint_and_pd=ensure_self_adjoint_and_pd,
+ diagonals_format='matrix')
+
+ def test_tape_safe(self):
+ matrix = variables_module.Variable([[3., 2., 0.], [1., 6., 4.], [0., 2, 2]])
+ operator = linalg_lib.LinearOperatorTridiag(
+ matrix, diagonals_format='matrix')
+ self.check_tape_safe(operator)
+
+
+if __name__ == '__main__':
+ linear_operator_test_util.add_tests(LinearOperatorTriDiagCompactTest)
+ linear_operator_test_util.add_tests(LinearOperatorTriDiagSequenceTest)
+ linear_operator_test_util.add_tests(LinearOperatorTriDiagMatrixTest)
+ test.main()
diff --git a/tensorflow/python/ops/linalg/linalg.py b/tensorflow/python/ops/linalg/linalg.py
index 2bfbb37..94d85cb 100644
--- a/tensorflow/python/ops/linalg/linalg.py
+++ b/tensorflow/python/ops/linalg/linalg.py
@@ -39,6 +39,7 @@
from tensorflow.python.ops.linalg.linear_operator_lower_triangular import *
from tensorflow.python.ops.linalg.linear_operator_permutation import *
from tensorflow.python.ops.linalg.linear_operator_toeplitz import *
+from tensorflow.python.ops.linalg.linear_operator_tridiag import *
from tensorflow.python.ops.linalg.linear_operator_zeros import *
# pylint: enable=wildcard-import
diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py
index 18d2296..3412486 100644
--- a/tensorflow/python/ops/linalg/linalg_impl.py
+++ b/tensorflow/python/ops/linalg/linalg_impl.py
@@ -28,10 +28,8 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
@@ -486,14 +484,8 @@
'Expected last two dimensions of diagonals to be same, got {} and {}'
.format(m1, m2))
m = m1 or m2
- diagonals = gen_array_ops.matrix_diag_part_v2(
- diagonals, k=(-1, 1), padding_value=0.)
- # matrix_diag_part pads at the end. Because the subdiagonal has the
- # convention of having the padding in the front, we need to rotate the last
- # Tensor.
- superdiag, d, subdiag = array_ops.unstack(diagonals, num=3, axis=-2)
- subdiag = manip_ops.roll(subdiag, shift=1, axis=-1)
- diagonals = array_ops.stack((superdiag, d, subdiag), axis=-2)
+ diagonals = array_ops.matrix_diag_part(
+ diagonals, k=(-1, 1), padding_value=0., align='LEFT_RIGHT')
return _tridiagonal_solve_compact_format(
diagonals, rhs, transpose_rhs, conjugate_rhs, partial_pivoting, name)
@@ -614,20 +606,11 @@
raise ValueError(
'Expected last two dimensions of diagonals to be same, got {} and {}'
.format(m1, m2))
-
- maindiag = array_ops.matrix_diag_part(diagonals)
- superdiag = gen_array_ops.matrix_diag_part_v2(
- diagonals, k=1, padding_value=0.)
- superdiag = array_ops.concat(
- [superdiag,
- array_ops.zeros_like(
- superdiag[..., 0])[..., array_ops.newaxis]],
- axis=-1)
- subdiag = gen_array_ops.matrix_diag_part_v2(
- diagonals, k=-1, padding_value=0.)
- subdiag = array_ops.concat([
- array_ops.zeros_like(subdiag[..., 0])[..., array_ops.newaxis],
- subdiag], axis=-1)
+ diags = array_ops.matrix_diag_part(
+ diagonals, k=(-1, 1), padding_value=0., align='LEFT_RIGHT')
+ superdiag = diags[..., 0, :]
+ maindiag = diags[..., 1, :]
+ subdiag = diags[..., 2, :]
else:
raise ValueError('Unrecognized diagonals_format: %s' % diagonals_format)
diff --git a/tensorflow/python/ops/linalg/linear_operator_tridiag.py b/tensorflow/python/ops/linalg/linear_operator_tridiag.py
new file mode 100644
index 0000000..4227478
--- /dev/null
+++ b/tensorflow/python/ops/linalg/linear_operator_tridiag.py
@@ -0,0 +1,373 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""`LinearOperator` acting like a tridiagonal matrix."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import manip_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.linalg import linalg_impl as linalg
+from tensorflow.python.ops.linalg import linear_operator
+from tensorflow.python.ops.linalg import linear_operator_util
+from tensorflow.python.util.tf_export import tf_export
+
+__all__ = ['LinearOperatorTridiag',]
+
+_COMPACT = 'compact'
+_MATRIX = 'matrix'
+_SEQUENCE = 'sequence'
+_DIAGONAL_FORMATS = frozenset({_COMPACT, _MATRIX, _SEQUENCE})
+
+
+@tf_export('linalg.LinearOperatorTridiag')
+class LinearOperatorTridiag(linear_operator.LinearOperator):
+ """`LinearOperator` acting like a [batch] square tridiagonal matrix.
+
+ This operator acts like a [batch] square tridiagonal matrix `A` with shape
+ `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
+ batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
+  an `N x N` matrix. This matrix `A` is not materialized, but for
+ purposes of broadcasting this shape will be relevant.
+
+ Example usage:
+
+ Create a 3 x 3 tridiagonal linear operator.
+
+ >>> superdiag = [3., 4., 5.]
+ >>> diag = [1., -1., 2.]
+ >>> subdiag = [6., 7., 8]
+ >>> operator = tf.linalg.LinearOperatorTridiag(
+ ... [superdiag, diag, subdiag],
+ ... diagonals_format='sequence')
+ >>> operator.to_dense()
+ <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
+ array([[ 1., 3., 0.],
+ [ 7., -1., 4.],
+ [ 0., 8., 2.]], dtype=float32)>
+ >>> operator.shape
+ TensorShape([3, 3])
+
+ Scalar Tensor output.
+
+ >>> operator.log_abs_determinant()
+ <tf.Tensor: shape=(), dtype=float32, numpy=4.3307333>
+
+ Create a [2, 3] batch of 4 x 4 linear operators.
+
+ >>> diagonals = tf.random.normal(shape=[2, 3, 3, 4])
+ >>> operator = tf.linalg.LinearOperatorTridiag(
+ ... diagonals,
+ ... diagonals_format='compact')
+
+ Create a shape [2, 1, 4, 2] vector. Note that this shape is compatible
+ since the batch dimensions, [2, 1], are broadcast to
+ operator.batch_shape = [2, 3].
+
+ >>> y = tf.random.normal(shape=[2, 1, 4, 2])
+ >>> x = operator.solve(y)
+ >>> x
+ <tf.Tensor: shape=(2, 3, 4, 2), dtype=float32, numpy=...,
+ dtype=float32)>
+
+ #### Shape compatibility
+
+ This operator acts on [batch] matrix with compatible shape.
+ `x` is a batch matrix with compatible shape for `matmul` and `solve` if
+
+ ```
+ operator.shape = [B1,...,Bb] + [N, N], with b >= 0
+ x.shape = [C1,...,Cc] + [N, R],
+ and [C1,...,Cc] broadcasts with [B1,...,Bb].
+ ```
+
+ #### Performance
+
+ Suppose `operator` is a `LinearOperatorTridiag` of shape `[N, N]`,
+ and `x.shape = [N, R]`. Then
+
+ * `operator.matmul(x)` will take O(N * R) time.
+ * `operator.solve(x)` will take O(N * R) time.
+
+ If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
+ `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.
+
+ #### Matrix property hints
+
+ This `LinearOperator` is initialized with boolean flags of the form `is_X`,
+ for `X = non_singular, self_adjoint, positive_definite, square`.
+ These have the following meaning:
+
+ * If `is_X == True`, callers should expect the operator to have the
+ property `X`. This is a promise that should be fulfilled, but is *not* a
+ runtime assert. For example, finite floating point precision may result
+ in these promises being violated.
+ * If `is_X == False`, callers should expect the operator to not have `X`.
+ * If `is_X == None` (the default), callers should have no expectation either
+ way.
+ """
+
+ def __init__(self,
+ diagonals,
+ diagonals_format=_COMPACT,
+ is_non_singular=None,
+ is_self_adjoint=None,
+ is_positive_definite=None,
+ is_square=None,
+ name='LinearOperatorTridiag'):
+ r"""Initialize a `LinearOperatorTridiag`.
+
+ Args:
+ diagonals: `Tensor` or list of `Tensor`s depending on `diagonals_format`.
+
+ If `diagonals_format=sequence`, this is a list of three `Tensor`'s each
+ with shape `[B1, ..., Bb, N]`, `b >= 0, N >= 0`, representing the
+ superdiagonal, diagonal and subdiagonal in that order. Note the
+ superdiagonal is padded with an element in the last position, and the
+ subdiagonal is padded with an element in the front.
+
+ If `diagonals_format=matrix` this is a `[B1, ... Bb, N, N]` shaped
+ `Tensor` representing the full tridiagonal matrix.
+
+ If `diagonals_format=compact` this is a `[B1, ... Bb, 3, N]` shaped
+ `Tensor` with the second to last dimension indexing the
+ superdiagonal, diagonal and subdiagonal in that order. Note the
+ superdiagonal is padded with an element in the last position, and the
+ subdiagonal is padded with an element in the front.
+
+ In every case, these `Tensor`s are all floating dtype.
+ diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
+ `compact`.
+ is_non_singular: Expect that this operator is non-singular.
+ is_self_adjoint: Expect that this operator is equal to its hermitian
+ transpose. If `diag.dtype` is real, this is auto-set to `True`.
+ is_positive_definite: Expect that this operator is positive definite,
+ meaning the quadratic form `x^H A x` has positive real part for all
+ nonzero `x`. Note that we do not require the operator to be
+ self-adjoint to be positive-definite. See:
+ https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
+ is_square: Expect that this operator acts like square [batch] matrices.
+ name: A name for this `LinearOperator`.
+
+ Raises:
+ TypeError: If `diag.dtype` is not an allowed type.
+ ValueError: If `diag.dtype` is real, and `is_self_adjoint` is not `True`.
+ """
+
+ with ops.name_scope(name, values=[diagonals]):
+ if diagonals_format not in _DIAGONAL_FORMATS:
+ raise ValueError(
+ 'Diagonals Format must be one of compact, matrix, sequence'
+ ', got : {}'.format(diagonals_format))
+ if diagonals_format == _SEQUENCE:
+ self._diagonals = [linear_operator_util.convert_nonref_to_tensor(
+ d, name='diag_{}'.format(i)) for i, d in enumerate(diagonals)]
+ dtype = self._diagonals[0].dtype
+ else:
+ self._diagonals = linear_operator_util.convert_nonref_to_tensor(
+ diagonals, name='diagonals')
+ dtype = self._diagonals.dtype
+ self._diagonals_format = diagonals_format
+
+ super(LinearOperatorTridiag, self).__init__(
+ dtype=dtype,
+ is_non_singular=is_non_singular,
+ is_self_adjoint=is_self_adjoint,
+ is_positive_definite=is_positive_definite,
+ is_square=is_square,
+ name=name)
+
+ def _shape(self):
+ if self.diagonals_format == _MATRIX:
+ return self.diagonals.shape
+ if self.diagonals_format == _COMPACT:
+ # Remove the second to last dimension that contains the value 3.
+ d_shape = self.diagonals.shape[:-2].concatenate(
+ self.diagonals.shape[-1])
+ else:
+ broadcast_shape = array_ops.broadcast_static_shape(
+ self.diagonals[0].shape[:-1],
+ self.diagonals[1].shape[:-1])
+ broadcast_shape = array_ops.broadcast_static_shape(
+ broadcast_shape,
+ self.diagonals[2].shape[:-1])
+ d_shape = broadcast_shape.concatenate(self.diagonals[1].shape[-1])
+ return d_shape.concatenate(d_shape[-1])
+
+ def _shape_tensor(self, diagonals=None):
+ diagonals = diagonals if diagonals is not None else self.diagonals
+ if self.diagonals_format == _MATRIX:
+ return array_ops.shape(diagonals)
+ if self.diagonals_format == _COMPACT:
+ d_shape = array_ops.shape(diagonals[..., 0, :])
+ else:
+ broadcast_shape = array_ops.broadcast_dynamic_shape(
+ array_ops.shape(self.diagonals[0])[:-1],
+ array_ops.shape(self.diagonals[1])[:-1])
+ broadcast_shape = array_ops.broadcast_dynamic_shape(
+ broadcast_shape,
+ array_ops.shape(self.diagonals[2])[:-1])
+ d_shape = array_ops.concat(
+ [broadcast_shape, [array_ops.shape(self.diagonals[1])[-1]]], axis=0)
+ return array_ops.concat([d_shape, [d_shape[-1]]], axis=-1)
+
+ def _assert_self_adjoint(self):
+ # Check the diagonal has non-zero imaginary, and the super and subdiagonals
+ # are conjugate.
+
+ asserts = []
+ diag_message = (
+ 'This tridiagonal operator contained non-zero '
+ 'imaginary values on the diagonal.')
+ off_diag_message = (
+ 'This tridiagonal operator has non-conjugate '
+ 'subdiagonal and superdiagonal.')
+
+ if self.diagonals_format == _MATRIX:
+ asserts += [check_ops.assert_equal(
+ self.diagonals, linalg.adjoint(self.diagonals),
+ message='Matrix was not equal to its adjoint.')]
+ elif self.diagonals_format == _COMPACT:
+ diagonals = ops.convert_to_tensor(self.diagonals)
+ asserts += [linear_operator_util.assert_zero_imag_part(
+ diagonals[..., 1, :], message=diag_message)]
+ # Roll the subdiagonal so the shifted argument is at the end.
+ subdiag = manip_ops.roll(diagonals[..., 2, :], shift=-1, axis=-1)
+ asserts += [check_ops.assert_equal(
+ math_ops.conj(subdiag[..., :-1]),
+ diagonals[..., 0, :-1],
+ message=off_diag_message)]
+ else:
+ asserts += [linear_operator_util.assert_zero_imag_part(
+ self.diagonals[1], message=diag_message)]
+ subdiag = manip_ops.roll(self.diagonals[2], shift=-1, axis=-1)
+ asserts += [check_ops.assert_equal(
+ math_ops.conj(subdiag[..., :-1]),
+ self.diagonals[0][..., :-1],
+ message=off_diag_message)]
+ return control_flow_ops.group(asserts)
+
+ def _construct_adjoint_diagonals(self, diagonals):
+ # Constructs adjoint tridiagonal matrix from diagonals.
+ if self.diagonals_format == _SEQUENCE:
+ diagonals = [math_ops.conj(d) for d in reversed(diagonals)]
+ # The subdiag and the superdiag swap places, so we need to shift the
+ # padding argument.
+ diagonals[0] = manip_ops.roll(diagonals[0], shift=-1, axis=-1)
+ diagonals[2] = manip_ops.roll(diagonals[2], shift=1, axis=-1)
+ return diagonals
+ elif self.diagonals_format == _MATRIX:
+ return linalg.adjoint(diagonals)
+ else:
+ diagonals = math_ops.conj(diagonals)
+ superdiag, diag, subdiag = array_ops.unstack(
+ diagonals, num=3, axis=-2)
+ # The subdiag and the superdiag swap places, so we need
+ # to shift all arguments.
+ new_superdiag = manip_ops.roll(subdiag, shift=-1, axis=-1)
+ new_subdiag = manip_ops.roll(superdiag, shift=1, axis=-1)
+ return array_ops.stack([new_superdiag, diag, new_subdiag], axis=-2)
+
+ def _matmul(self, x, adjoint=False, adjoint_arg=False):
+ diagonals = self.diagonals
+ if adjoint:
+ diagonals = self._construct_adjoint_diagonals(diagonals)
+ x = linalg.adjoint(x) if adjoint_arg else x
+ return linalg.tridiagonal_matmul(
+ diagonals, x,
+ diagonals_format=self.diagonals_format)
+
+ def _solve(self, rhs, adjoint=False, adjoint_arg=False):
+ diagonals = self.diagonals
+ if adjoint:
+ diagonals = self._construct_adjoint_diagonals(diagonals)
+
+ # TODO(b/144860784): Remove the broadcasting code below once
+ # tridiagonal_solve broadcasts.
+
+ rhs_shape = array_ops.shape(rhs)
+ k = self._shape_tensor(diagonals)[-1]
+ broadcast_shape = array_ops.broadcast_dynamic_shape(
+ self._shape_tensor(diagonals)[:-2], rhs_shape[:-2])
+ rhs = array_ops.broadcast_to(
+ rhs, array_ops.concat(
+ [broadcast_shape, rhs_shape[-2:]], axis=-1))
+ if self.diagonals_format == _MATRIX:
+ diagonals = array_ops.broadcast_to(
+ diagonals, array_ops.concat(
+ [broadcast_shape, [k, k]], axis=-1))
+ elif self.diagonals_format == _COMPACT:
+ diagonals = array_ops.broadcast_to(
+ diagonals, array_ops.concat(
+ [broadcast_shape, [3, k]], axis=-1))
+ else:
+ diagonals = [
+ array_ops.broadcast_to(d, array_ops.concat(
+ [broadcast_shape, [k]], axis=-1)) for d in diagonals]
+
+ y = linalg.tridiagonal_solve(
+ diagonals, rhs,
+ diagonals_format=self.diagonals_format,
+ transpose_rhs=adjoint_arg,
+ conjugate_rhs=adjoint_arg)
+ return y
+
+ def _diag_part(self):
+ if self.diagonals_format == _MATRIX:
+ return array_ops.matrix_diag_part(self.diagonals)
+ elif self.diagonals_format == _SEQUENCE:
+ diagonal = self.diagonals[1]
+ return array_ops.broadcast_to(
+ diagonal, self.shape_tensor()[:-1])
+ else:
+ return self.diagonals[..., 1, :]
+
+ def _to_dense(self):
+ if self.diagonals_format == _MATRIX:
+ return self.diagonals
+
+ if self.diagonals_format == _COMPACT:
+ return gen_array_ops.matrix_diag_v3(
+ self.diagonals,
+ k=(-1, 1),
+ num_rows=-1,
+ num_cols=-1,
+ align='LEFT_RIGHT',
+ padding_value=0.)
+
+ diagonals = [ops.convert_to_tensor(d) for d in self.diagonals]
+ diagonals = array_ops.stack(diagonals, axis=-2)
+
+ return gen_array_ops.matrix_diag_v3(
+ diagonals,
+ k=(-1, 1),
+ num_rows=-1,
+ num_cols=-1,
+ align='LEFT_RIGHT',
+ padding_value=0.)
+
+ @property
+ def diagonals(self):
+ return self._diagonals
+
+ @property
+ def diagonals_format(self):
+ return self._diagonals_format
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-tridiag.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-tridiag.pbtxt
new file mode 100644
index 0000000..0609904
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-tridiag.pbtxt
@@ -0,0 +1,185 @@
+path: "tensorflow.linalg.LinearOperatorTridiag"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator_tridiag.LinearOperatorTridiag\'>"
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
+ is_instance: "<class \'tensorflow.python.module.module.Module\'>"
+ is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
+ is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "batch_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "diagonals"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "diagonals_format"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "domain_dimension"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph_parents"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_non_singular"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_positive_definite"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_self_adjoint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_square"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name_scope"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "range_dimension"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "submodules"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "tensor_rank"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'diagonals\', \'diagonals_format\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'compact\', \'None\', \'None\', \'None\', \'None\', \'LinearOperatorTridiag\'], "
+ }
+ member_method {
+ name: "add_to_tensor"
+ argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
+ }
+ member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
+ name: "assert_non_singular"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
+ }
+ member_method {
+ name: "assert_positive_definite"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], "
+ }
+ member_method {
+ name: "assert_self_adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], "
+ }
+ member_method {
+ name: "batch_shape_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], "
+ }
+ member_method {
+ name: "cholesky"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'cholesky\'], "
+ }
+ member_method {
+ name: "cond"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'cond\'], "
+ }
+ member_method {
+ name: "determinant"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], "
+ }
+ member_method {
+ name: "diag_part"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], "
+ }
+ member_method {
+ name: "domain_dimension_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], "
+ }
+ member_method {
+ name: "eigvals"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'eigvals\'], "
+ }
+ member_method {
+ name: "inverse"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse\'], "
+ }
+ member_method {
+ name: "log_abs_determinant"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], "
+ }
+ member_method {
+ name: "matmul"
+ argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], "
+ }
+ member_method {
+ name: "matvec"
+ argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], "
+ }
+ member_method {
+ name: "range_dimension_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], "
+ }
+ member_method {
+ name: "shape_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], "
+ }
+ member_method {
+ name: "solve"
+ argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], "
+ }
+ member_method {
+ name: "solvevec"
+ argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], "
+ }
+ member_method {
+ name: "tensor_rank_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], "
+ }
+ member_method {
+ name: "to_dense"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], "
+ }
+ member_method {
+ name: "trace"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], "
+ }
+ member_method {
+ name: "with_name_scope"
+ argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt
index 632400c..264294d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt
@@ -73,6 +73,10 @@
mtype: "<type \'type\'>"
}
member {
+ name: "LinearOperatorTridiag"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "LinearOperatorZeros"
mtype: "<type \'type\'>"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-tridiag.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-tridiag.pbtxt
new file mode 100644
index 0000000..0609904
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-tridiag.pbtxt
@@ -0,0 +1,185 @@
+path: "tensorflow.linalg.LinearOperatorTridiag"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator_tridiag.LinearOperatorTridiag\'>"
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
+ is_instance: "<class \'tensorflow.python.module.module.Module\'>"
+ is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
+ is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "batch_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "diagonals"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "diagonals_format"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "domain_dimension"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph_parents"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_non_singular"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_positive_definite"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_self_adjoint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_square"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name_scope"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "range_dimension"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "submodules"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "tensor_rank"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'diagonals\', \'diagonals_format\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'compact\', \'None\', \'None\', \'None\', \'None\', \'LinearOperatorTridiag\'], "
+ }
+ member_method {
+ name: "add_to_tensor"
+ argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
+ }
+ member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
+ name: "assert_non_singular"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
+ }
+ member_method {
+ name: "assert_positive_definite"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], "
+ }
+ member_method {
+ name: "assert_self_adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], "
+ }
+ member_method {
+ name: "batch_shape_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], "
+ }
+ member_method {
+ name: "cholesky"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'cholesky\'], "
+ }
+ member_method {
+ name: "cond"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'cond\'], "
+ }
+ member_method {
+ name: "determinant"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], "
+ }
+ member_method {
+ name: "diag_part"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], "
+ }
+ member_method {
+ name: "domain_dimension_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], "
+ }
+ member_method {
+ name: "eigvals"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'eigvals\'], "
+ }
+ member_method {
+ name: "inverse"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse\'], "
+ }
+ member_method {
+ name: "log_abs_determinant"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], "
+ }
+ member_method {
+ name: "matmul"
+ argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], "
+ }
+ member_method {
+ name: "matvec"
+ argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], "
+ }
+ member_method {
+ name: "range_dimension_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], "
+ }
+ member_method {
+ name: "shape_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], "
+ }
+ member_method {
+ name: "solve"
+ argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], "
+ }
+ member_method {
+ name: "solvevec"
+ argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], "
+ }
+ member_method {
+ name: "tensor_rank_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], "
+ }
+ member_method {
+ name: "to_dense"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], "
+ }
+ member_method {
+ name: "trace"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], "
+ }
+ member_method {
+ name: "with_name_scope"
+ argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
index 041041f..7d6f02a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
@@ -73,6 +73,10 @@
mtype: "<type \'type\'>"
}
member {
+ name: "LinearOperatorTridiag"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "LinearOperatorZeros"
mtype: "<type \'type\'>"
}