Add LinearOperatorTridiag for tridiagonal matrices.
  - Removed the restriction that tridiagonal_matmul and tridiagonal_solve
    require statically known shapes.
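
  For example, with diagonals_format='matrix' the shapes no longer need to be
  fully static (a sketch; the placeholder shapes are illustrative):

    matrix = tf.compat.v1.placeholder(tf.float32, shape=[None, 5])
    rhs = tf.compat.v1.placeholder(tf.float32, shape=[5, 1])
    x = tf.linalg.tridiagonal_solve(matrix, rhs, diagonals_format='matrix')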

PiperOrigin-RevId: 285830396
Change-Id: I2291ee1111f2a41e658b085cbd402d2977acb939
diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD
index 7294462..eb732fb 100644
--- a/tensorflow/python/kernel_tests/linalg/BUILD
+++ b/tensorflow/python/kernel_tests/linalg/BUILD
@@ -383,6 +383,29 @@
 )
 
 cuda_py_test(
+    name = "linear_operator_tridiag_test",
+    size = "medium",
+    srcs = ["linear_operator_tridiag_test.py"],
+    additional_deps = [
+        "//tensorflow/python/ops/linalg",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:linalg_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:random_ops",
+    ],
+    shard_count = 5,
+    tags = [
+        "noasan",
+        "optonly",
+    ],
+    xla_enable_strict_auto_jit = True,
+)
+
+cuda_py_test(
     name = "linear_operator_zeros_test",
     size = "medium",
     srcs = ["linear_operator_zeros_test.py"],
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_tridiag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_tridiag_test.py
new file mode 100644
index 0000000..d69f872
--- /dev/null
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_tridiag_test.py
@@ -0,0 +1,184 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import manip_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables as variables_module
+from tensorflow.python.ops.linalg import linalg as linalg_lib
+from tensorflow.python.ops.linalg import linear_operator_test_util
+from tensorflow.python.platform import test
+
+
+class _LinearOperatorTriDiagBase(object):
+
+  def build_operator_and_matrix(
+      self, build_info, dtype, use_placeholder,
+      ensure_self_adjoint_and_pd=False,
+      diagonals_format='sequence'):
+    shape = list(build_info.shape)
+
+    # Ensure that the diagonal has large enough values. If we generate a
+    # self-adjoint PD matrix, the diagonal will be dominant, guaranteeing
+    # positive definiteness.
+    diag = linear_operator_test_util.random_sign_uniform(
+        shape[:-1], minval=4., maxval=6., dtype=dtype)
+    # Depending on the format, parts of these act as ignored padding.
+    subdiag = linear_operator_test_util.random_sign_uniform(
+        shape[:-1], minval=1., maxval=2., dtype=dtype)
+    if ensure_self_adjoint_and_pd:
+      # Abs on complex64 will result in a float32, so we cast back up.
+      diag = math_ops.cast(math_ops.abs(diag), dtype=dtype)
+      # The first element of subdiag is ignored as padding. Rolling
+      # conj(subdiag) left by one aligns it with superdiag (whose last
+      # element is the ignored padding), making the matrix Hermitian.
+      superdiag = math_ops.conj(subdiag)
+      superdiag = manip_ops.roll(superdiag, shift=-1, axis=-1)
+    else:
+      superdiag = linear_operator_test_util.random_sign_uniform(
+          shape[:-1], minval=1., maxval=2., dtype=dtype)
+
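+    # With align='LEFT_RIGHT', matrix_diag_v3 treats the superdiagonal as
+    # left-aligned (padded at the end) and the subdiagonal as right-aligned
+    # (padded at the front), matching the padding convention above.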
+    matrix_diagonals = array_ops.stack(
+        [superdiag, diag, subdiag], axis=-2)
+    matrix = gen_array_ops.matrix_diag_v3(
+        matrix_diagonals,
+        k=(-1, 1),
+        num_rows=-1,
+        num_cols=-1,
+        align='LEFT_RIGHT',
+        padding_value=0.)
+
+    if diagonals_format == 'sequence':
+      diagonals = [superdiag, diag, subdiag]
+    elif diagonals_format == 'compact':
+      diagonals = array_ops.stack([superdiag, diag, subdiag], axis=-2)
+    elif diagonals_format == 'matrix':
+      diagonals = matrix
+
+    lin_op_diagonals = diagonals
+
+    if use_placeholder:
+      if diagonals_format == 'sequence':
+        lin_op_diagonals = [array_ops.placeholder_with_default(
+            d, shape=None) for d in lin_op_diagonals]
+      else:
+        lin_op_diagonals = array_ops.placeholder_with_default(
+            lin_op_diagonals, shape=None)
+
+    operator = linalg_lib.LinearOperatorTridiag(
+        diagonals=lin_op_diagonals,
+        diagonals_format=diagonals_format,
+        is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
+        is_positive_definite=True if ensure_self_adjoint_and_pd else None)
+    return operator, matrix
+
+  @staticmethod
+  def operator_shapes_infos():
+    shape_info = linear_operator_test_util.OperatorShapesInfo
+    # non-batch operators (n, n) and batch operators.
+    return [
+        shape_info((3, 3)),
+        shape_info((1, 6, 6)),
+        shape_info((3, 4, 4)),
+        shape_info((2, 1, 3, 3))
+    ]
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class LinearOperatorTriDiagCompactTest(
+    _LinearOperatorTriDiagBase,
+    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
+  """Most tests done in the base class LinearOperatorDerivedClassTest."""
+
+  def operator_and_matrix(
+      self, build_info, dtype, use_placeholder,
+      ensure_self_adjoint_and_pd=False):
+    return self.build_operator_and_matrix(
+        build_info, dtype, use_placeholder,
+        ensure_self_adjoint_and_pd=ensure_self_adjoint_and_pd,
+        diagonals_format='compact')
+
+  def test_tape_safe(self):
+    diag = variables_module.Variable([[3., 6., 2.], [2., 4., 2.], [5., 1., 2.]])
+    operator = linalg_lib.LinearOperatorTridiag(
+        diag, diagonals_format='compact')
+    self.check_tape_safe(operator)
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class LinearOperatorTriDiagSequenceTest(
+    _LinearOperatorTriDiagBase,
+    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
+  """Most tests done in the base class LinearOperatorDerivedClassTest."""
+
+  def operator_and_matrix(
+      self, build_info, dtype, use_placeholder,
+      ensure_self_adjoint_and_pd=False):
+    return self.build_operator_and_matrix(
+        build_info, dtype, use_placeholder,
+        ensure_self_adjoint_and_pd=ensure_self_adjoint_and_pd,
+        diagonals_format='sequence')
+
+  def test_tape_safe(self):
+    diagonals = [
+        variables_module.Variable([3., 6., 2.]),
+        variables_module.Variable([2., 4., 2.]),
+        variables_module.Variable([5., 1., 2.])]
+    operator = linalg_lib.LinearOperatorTridiag(
+        diagonals, diagonals_format='sequence')
+    # Skip the diagonal part and trace since these depend only on the
+    # middle variable. We test them below.
+    self.check_tape_safe(operator, skip_options=['diag_part', 'trace'])
+
+    diagonals = [
+        [3., 6., 2.],
+        variables_module.Variable([2., 4., 2.]),
+        [5., 1., 2.]
+    ]
+    operator = linalg_lib.LinearOperatorTridiag(
+        diagonals, diagonals_format='sequence')
+    # With only the middle diagonal being a variable, diag_part and trace
+    # are exercised as well.
+    self.check_tape_safe(operator)
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class LinearOperatorTriDiagMatrixTest(
+    _LinearOperatorTriDiagBase,
+    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
+  """Most tests done in the base class LinearOperatorDerivedClassTest."""
+
+  def operator_and_matrix(
+      self, build_info, dtype, use_placeholder,
+      ensure_self_adjoint_and_pd=False):
+    return self.build_operator_and_matrix(
+        build_info, dtype, use_placeholder,
+        ensure_self_adjoint_and_pd=ensure_self_adjoint_and_pd,
+        diagonals_format='matrix')
+
+  def test_tape_safe(self):
+    matrix = variables_module.Variable(
+        [[3., 2., 0.], [1., 6., 4.], [0., 2., 2.]])
+    operator = linalg_lib.LinearOperatorTridiag(
+        matrix, diagonals_format='matrix')
+    self.check_tape_safe(operator)
+
+
+if __name__ == '__main__':
+  linear_operator_test_util.add_tests(LinearOperatorTriDiagCompactTest)
+  linear_operator_test_util.add_tests(LinearOperatorTriDiagSequenceTest)
+  linear_operator_test_util.add_tests(LinearOperatorTriDiagMatrixTest)
+  test.main()
diff --git a/tensorflow/python/ops/linalg/linalg.py b/tensorflow/python/ops/linalg/linalg.py
index 2bfbb37..94d85cb 100644
--- a/tensorflow/python/ops/linalg/linalg.py
+++ b/tensorflow/python/ops/linalg/linalg.py
@@ -39,6 +39,7 @@
 from tensorflow.python.ops.linalg.linear_operator_lower_triangular import *
 from tensorflow.python.ops.linalg.linear_operator_permutation import *
 from tensorflow.python.ops.linalg.linear_operator_toeplitz import *
+from tensorflow.python.ops.linalg.linear_operator_tridiag import *
 from tensorflow.python.ops.linalg.linear_operator_zeros import *
 # pylint: enable=wildcard-import
 
diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py
index 18d2296..3412486 100644
--- a/tensorflow/python/ops/linalg/linalg_impl.py
+++ b/tensorflow/python/ops/linalg/linalg_impl.py
@@ -28,10 +28,8 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import gen_linalg_ops
 from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import manip_ops
 from tensorflow.python.ops import map_fn
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import special_math_ops
@@ -486,14 +484,8 @@
           'Expected last two dimensions of diagonals to be same, got {} and {}'
           .format(m1, m2))
     m = m1 or m2
-    diagonals = gen_array_ops.matrix_diag_part_v2(
-        diagonals, k=(-1, 1), padding_value=0.)
-    # matrix_diag_part pads at the end. Because the subdiagonal has the
-    # convention of having the padding in the front, we need to rotate the last
-    # Tensor.
-    superdiag, d, subdiag = array_ops.unstack(diagonals, num=3, axis=-2)
-    subdiag = manip_ops.roll(subdiag, shift=1, axis=-1)
-    diagonals = array_ops.stack((superdiag, d, subdiag), axis=-2)
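+    # With align='LEFT_RIGHT', matrix_diag_part pads the superdiagonal at the
+    # end and the subdiagonal at the front, which is exactly the compact
+    # format, so no extra rolling is needed.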
+    diagonals = array_ops.matrix_diag_part(
+        diagonals, k=(-1, 1), padding_value=0., align='LEFT_RIGHT')
     return _tridiagonal_solve_compact_format(
         diagonals, rhs, transpose_rhs, conjugate_rhs, partial_pivoting, name)
 
@@ -614,20 +606,11 @@
       raise ValueError(
           'Expected last two dimensions of diagonals to be same, got {} and {}'
           .format(m1, m2))
-
-    maindiag = array_ops.matrix_diag_part(diagonals)
-    superdiag = gen_array_ops.matrix_diag_part_v2(
-        diagonals, k=1, padding_value=0.)
-    superdiag = array_ops.concat(
-        [superdiag,
-         array_ops.zeros_like(
-             superdiag[..., 0])[..., array_ops.newaxis]],
-        axis=-1)
-    subdiag = gen_array_ops.matrix_diag_part_v2(
-        diagonals, k=-1, padding_value=0.)
-    subdiag = array_ops.concat([
-        array_ops.zeros_like(subdiag[..., 0])[..., array_ops.newaxis],
-        subdiag], axis=-1)
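+    # A single matrix_diag_part call with align='LEFT_RIGHT' extracts all
+    # three diagonals already padded in the compact-format convention.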
+    diags = array_ops.matrix_diag_part(
+        diagonals, k=(-1, 1), padding_value=0., align='LEFT_RIGHT')
+    superdiag = diags[..., 0, :]
+    maindiag = diags[..., 1, :]
+    subdiag = diags[..., 2, :]
   else:
     raise ValueError('Unrecognized diagonals_format: %s' % diagonals_format)
 
diff --git a/tensorflow/python/ops/linalg/linear_operator_tridiag.py b/tensorflow/python/ops/linalg/linear_operator_tridiag.py
new file mode 100644
index 0000000..4227478
--- /dev/null
+++ b/tensorflow/python/ops/linalg/linear_operator_tridiag.py
@@ -0,0 +1,373 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""`LinearOperator` acting like a tridiagonal matrix."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import manip_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.linalg import linalg_impl as linalg
+from tensorflow.python.ops.linalg import linear_operator
+from tensorflow.python.ops.linalg import linear_operator_util
+from tensorflow.python.util.tf_export import tf_export
+
+__all__ = ['LinearOperatorTridiag',]
+
+_COMPACT = 'compact'
+_MATRIX = 'matrix'
+_SEQUENCE = 'sequence'
+_DIAGONAL_FORMATS = frozenset({_COMPACT, _MATRIX, _SEQUENCE})
+
+
+@tf_export('linalg.LinearOperatorTridiag')
+class LinearOperatorTridiag(linear_operator.LinearOperator):
+  """`LinearOperator` acting like a [batch] square tridiagonal matrix.
+
+  This operator acts like a [batch] square tridiagonal matrix `A` with shape
+  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
+  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
+  an `N x N` matrix.  This matrix `A` is not materialized, but for
+  purposes of broadcasting this shape will be relevant.
+
+  Example usage:
+
+  Create a 3 x 3 tridiagonal linear operator.
+
+  >>> superdiag = [3., 4., 5.]
+  >>> diag = [1., -1., 2.]
+  >>> subdiag = [6., 7., 8.]
+  >>> operator = tf.linalg.LinearOperatorTridiag(
+  ...    [superdiag, diag, subdiag],
+  ...    diagonals_format='sequence')
+  >>> operator.to_dense()
+  <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
+  array([[ 1.,  3.,  0.],
+         [ 7., -1.,  4.],
+         [ 0.,  8.,  2.]], dtype=float32)>
+  >>> operator.shape
+  TensorShape([3, 3])
+
+  Scalar Tensor output.
+
+  >>> operator.log_abs_determinant()
+  <tf.Tensor: shape=(), dtype=float32, numpy=4.3307333>
+
+  Create a [2, 3] batch of 4 x 4 linear operators.
+
+  >>> diagonals = tf.random.normal(shape=[2, 3, 3, 4])
+  >>> operator = tf.linalg.LinearOperatorTridiag(
+  ...   diagonals,
+  ...   diagonals_format='compact')
+
+  Create a shape [2, 1, 4, 2] batch of right-hand sides.  This shape is
+  compatible since the batch dimensions, [2, 1], broadcast with
+  operator.batch_shape = [2, 3].
+
+  >>> y = tf.random.normal(shape=[2, 1, 4, 2])
+  >>> x = operator.solve(y)
+  >>> x
+  <tf.Tensor: shape=(2, 3, 4, 2), dtype=float32, numpy=...,
+  dtype=float32)>
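+
+  The same shapes are compatible with `matmul`:
+
+  >>> operator.matmul(x).shape
+  TensorShape([2, 3, 4, 2])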
+
+  #### Shape compatibility
+
+  This operator acts on [batch] matrices with compatible shape.
+  `x` is a batch matrix with compatible shape for `matmul` and `solve` if
+
+  ```
+  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
+  x.shape =   [C1,...,Cc] + [N, R],
+  and [C1,...,Cc] broadcasts with [B1,...,Bb].
+  ```
+
+  #### Performance
+
+  Suppose `operator` is a `LinearOperatorTridiag` of shape `[N, N]`,
+  and `x.shape = [N, R]`.  Then
+
+  * `operator.matmul(x)` will take O(N * R) time.
+  * `operator.solve(x)` will take O(N * R) time.
+
+  If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
+  `[B1,...,Bb, N, R]`, every operation's complexity increases by a factor of
+  `B1*...*Bb`.
+
+  #### Matrix property hints
+
+  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
+  for `X = non_singular, self_adjoint, positive_definite, square`.
+  These have the following meaning:
+
+  * If `is_X == True`, callers should expect the operator to have the
+    property `X`.  This is a promise that should be fulfilled, but is *not* a
+    runtime assert.  For example, finite floating point precision may result
+    in these promises being violated.
+  * If `is_X == False`, callers should expect the operator to not have `X`.
+  * If `is_X == None` (the default), callers should have no expectation either
+    way.
+  """
+
+  def __init__(self,
+               diagonals,
+               diagonals_format=_COMPACT,
+               is_non_singular=None,
+               is_self_adjoint=None,
+               is_positive_definite=None,
+               is_square=None,
+               name='LinearOperatorTridiag'):
+    r"""Initialize a `LinearOperatorTridiag`.
+
+    Args:
+      diagonals: `Tensor` or list of `Tensor`s depending on `diagonals_format`.
+
+        If `diagonals_format=sequence`, this is a list of three `Tensor`s,
+        each with shape `[B1, ..., Bb, N]`, `b >= 0`, `N >= 0`, representing
+        the superdiagonal, diagonal and subdiagonal in that order. Note the
+        superdiagonal is padded with an element in the last position, and the
+        subdiagonal is padded with an element in the front.
+
+        If `diagonals_format=matrix` this is a `[B1, ..., Bb, N, N]` shaped
+        `Tensor` representing the full tridiagonal matrix.
+
+        If `diagonals_format=compact` this is a `[B1, ..., Bb, 3, N]` shaped
+        `Tensor` with the second to last dimension indexing the
+        superdiagonal, diagonal and subdiagonal in that order. Note the
+        superdiagonal is padded with an element in the last position, and the
+        subdiagonal is padded with an element in the front.
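+
+        For example, the 3 x 3 matrix with diagonal `[1, 2, 3]`,
+        superdiagonal `[4, 5]` and subdiagonal `[6, 7]` has compact form
+        `[[4, 5, 0], [1, 2, 3], [0, 6, 7]]`.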
+
+        In every case, these `Tensor`s must all have a floating-point dtype.
+      diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
+        `compact`.
+      is_non_singular:  Expect that this operator is non-singular.
+      is_self_adjoint:  Expect that this operator is equal to its hermitian
+        transpose.
+      is_positive_definite:  Expect that this operator is positive definite,
+        meaning the quadratic form `x^H A x` has positive real part for all
+        nonzero `x`.  Note that we do not require the operator to be
+        self-adjoint to be positive-definite.  See:
+        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
+      is_square:  Expect that this operator acts like square [batch] matrices.
+      name: A name for this `LinearOperator`.
+
+    Raises:
+      TypeError:  If `diagonals` is not an allowed type.
+      ValueError:  If `diagonals_format` is not one of `compact`, `matrix` or
+        `sequence`.
+    """
+
+    with ops.name_scope(name, values=[diagonals]):
+      if diagonals_format not in _DIAGONAL_FORMATS:
+        raise ValueError(
+            'diagonals_format must be one of compact, matrix or sequence'
+            ', got: {}'.format(diagonals_format))
+      if diagonals_format == _SEQUENCE:
+        self._diagonals = [linear_operator_util.convert_nonref_to_tensor(
+            d, name='diag_{}'.format(i)) for i, d in enumerate(diagonals)]
+        dtype = self._diagonals[0].dtype
+      else:
+        self._diagonals = linear_operator_util.convert_nonref_to_tensor(
+            diagonals, name='diagonals')
+        dtype = self._diagonals.dtype
+      self._diagonals_format = diagonals_format
+
+      super(LinearOperatorTridiag, self).__init__(
+          dtype=dtype,
+          is_non_singular=is_non_singular,
+          is_self_adjoint=is_self_adjoint,
+          is_positive_definite=is_positive_definite,
+          is_square=is_square,
+          name=name)
+
+  def _shape(self):
+    if self.diagonals_format == _MATRIX:
+      return self.diagonals.shape
+    if self.diagonals_format == _COMPACT:
+      # Drop the second-to-last dimension, which has size 3 (one entry per
+      # diagonal).
+      d_shape = self.diagonals.shape[:-2].concatenate(
+          self.diagonals.shape[-1])
+    else:
+      broadcast_shape = array_ops.broadcast_static_shape(
+          self.diagonals[0].shape[:-1],
+          self.diagonals[1].shape[:-1])
+      broadcast_shape = array_ops.broadcast_static_shape(
+          broadcast_shape,
+          self.diagonals[2].shape[:-1])
+      d_shape = broadcast_shape.concatenate(self.diagonals[1].shape[-1])
+    return d_shape.concatenate(d_shape[-1])
+
+  def _shape_tensor(self, diagonals=None):
+    diagonals = diagonals if diagonals is not None else self.diagonals
+    if self.diagonals_format == _MATRIX:
+      return array_ops.shape(diagonals)
+    if self.diagonals_format == _COMPACT:
+      d_shape = array_ops.shape(diagonals[..., 0, :])
+    else:
+      broadcast_shape = array_ops.broadcast_dynamic_shape(
+          array_ops.shape(self.diagonals[0])[:-1],
+          array_ops.shape(self.diagonals[1])[:-1])
+      broadcast_shape = array_ops.broadcast_dynamic_shape(
+          broadcast_shape,
+          array_ops.shape(self.diagonals[2])[:-1])
+      d_shape = array_ops.concat(
+          [broadcast_shape, [array_ops.shape(self.diagonals[1])[-1]]], axis=0)
+    return array_ops.concat([d_shape, [d_shape[-1]]], axis=-1)
+
+  def _assert_self_adjoint(self):
+    # Check that the diagonal has zero imaginary part and that the
+    # superdiagonal and subdiagonal are conjugates of each other.
+
+    asserts = []
+    diag_message = (
+        'This tridiagonal operator contained non-zero '
+        'imaginary values on the diagonal.')
+    off_diag_message = (
+        'This tridiagonal operator has non-conjugate '
+        'subdiagonal and superdiagonal.')
+
+    if self.diagonals_format == _MATRIX:
+      asserts += [check_ops.assert_equal(
+          self.diagonals, linalg.adjoint(self.diagonals),
+          message='Matrix was not equal to its adjoint.')]
+    elif self.diagonals_format == _COMPACT:
+      diagonals = ops.convert_to_tensor(self.diagonals)
+      asserts += [linear_operator_util.assert_zero_imag_part(
+          diagonals[..., 1, :], message=diag_message)]
+      # Roll the subdiagonal so its padding element is at the end, aligned
+      # with the superdiagonal's padding.
+      subdiag = manip_ops.roll(diagonals[..., 2, :], shift=-1, axis=-1)
+      asserts += [check_ops.assert_equal(
+          math_ops.conj(subdiag[..., :-1]),
+          diagonals[..., 0, :-1],
+          message=off_diag_message)]
+    else:
+      asserts += [linear_operator_util.assert_zero_imag_part(
+          self.diagonals[1], message=diag_message)]
+      subdiag = manip_ops.roll(self.diagonals[2], shift=-1, axis=-1)
+      asserts += [check_ops.assert_equal(
+          math_ops.conj(subdiag[..., :-1]),
+          self.diagonals[0][..., :-1],
+          message=off_diag_message)]
+    return control_flow_ops.group(asserts)
+
+  def _construct_adjoint_diagonals(self, diagonals):
+    # Constructs adjoint tridiagonal matrix from diagonals.
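+    # For example, with N=3 diagonals (compact or sequence layout)
+    #   superdiag = [s0, s1, *], diag = [d0, d1, d2], subdiag = [*, u1, u2],
+    # where * marks ignored padding, the adjoint has
+    #   superdiag = [conj(u1), conj(u2), *], diag = conj(diag),
+    #   subdiag = [*, conj(s0), conj(s1)],
+    # which the conjugation and rolls below produce.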
+    if self.diagonals_format == _SEQUENCE:
+      diagonals = [math_ops.conj(d) for d in reversed(diagonals)]
+      # The subdiag and the superdiag swap places, so we need to shift the
+      # padding element accordingly.
+      diagonals[0] = manip_ops.roll(diagonals[0], shift=-1, axis=-1)
+      diagonals[2] = manip_ops.roll(diagonals[2], shift=1, axis=-1)
+      return diagonals
+    elif self.diagonals_format == _MATRIX:
+      return linalg.adjoint(diagonals)
+    else:
+      diagonals = math_ops.conj(diagonals)
+      superdiag, diag, subdiag = array_ops.unstack(
+          diagonals, num=3, axis=-2)
+      # The subdiag and the superdiag swap places, so we need
+      # to shift the padding elements accordingly.
+      new_superdiag = manip_ops.roll(subdiag, shift=-1, axis=-1)
+      new_subdiag = manip_ops.roll(superdiag, shift=1, axis=-1)
+      return array_ops.stack([new_superdiag, diag, new_subdiag], axis=-2)
+
+  def _matmul(self, x, adjoint=False, adjoint_arg=False):
+    diagonals = self.diagonals
+    if adjoint:
+      diagonals = self._construct_adjoint_diagonals(diagonals)
+    x = linalg.adjoint(x) if adjoint_arg else x
+    return linalg.tridiagonal_matmul(
+        diagonals, x,
+        diagonals_format=self.diagonals_format)
+
+  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
+    diagonals = self.diagonals
+    if adjoint:
+      diagonals = self._construct_adjoint_diagonals(diagonals)
+
+    # TODO(b/144860784): Remove the broadcasting code below once
+    # tridiagonal_solve broadcasts.
+
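+    # Manually broadcast the operator's batch shape against rhs's batch
+    # shape; k is the matrix dimension N.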
+    rhs_shape = array_ops.shape(rhs)
+    k = self._shape_tensor(diagonals)[-1]
+    broadcast_shape = array_ops.broadcast_dynamic_shape(
+        self._shape_tensor(diagonals)[:-2], rhs_shape[:-2])
+    rhs = array_ops.broadcast_to(
+        rhs, array_ops.concat(
+            [broadcast_shape, rhs_shape[-2:]], axis=-1))
+    if self.diagonals_format == _MATRIX:
+      diagonals = array_ops.broadcast_to(
+          diagonals, array_ops.concat(
+              [broadcast_shape, [k, k]], axis=-1))
+    elif self.diagonals_format == _COMPACT:
+      diagonals = array_ops.broadcast_to(
+          diagonals, array_ops.concat(
+              [broadcast_shape, [3, k]], axis=-1))
+    else:
+      diagonals = [
+          array_ops.broadcast_to(d, array_ops.concat(
+              [broadcast_shape, [k]], axis=-1)) for d in diagonals]
+
+    y = linalg.tridiagonal_solve(
+        diagonals, rhs,
+        diagonals_format=self.diagonals_format,
+        transpose_rhs=adjoint_arg,
+        conjugate_rhs=adjoint_arg)
+    return y
+
+  def _diag_part(self):
+    if self.diagonals_format == _MATRIX:
+      return array_ops.matrix_diag_part(self.diagonals)
+    elif self.diagonals_format == _SEQUENCE:
+      diagonal = self.diagonals[1]
+      return array_ops.broadcast_to(
+          diagonal, self.shape_tensor()[:-1])
+    else:
+      return self.diagonals[..., 1, :]
+
+  def _to_dense(self):
+    if self.diagonals_format == _MATRIX:
+      return self.diagonals
+
+    if self.diagonals_format == _COMPACT:
+      return gen_array_ops.matrix_diag_v3(
+          self.diagonals,
+          k=(-1, 1),
+          num_rows=-1,
+          num_cols=-1,
+          align='LEFT_RIGHT',
+          padding_value=0.)
+
+    diagonals = [ops.convert_to_tensor(d) for d in self.diagonals]
+    diagonals = array_ops.stack(diagonals, axis=-2)
+
+    return gen_array_ops.matrix_diag_v3(
+        diagonals,
+        k=(-1, 1),
+        num_rows=-1,
+        num_cols=-1,
+        align='LEFT_RIGHT',
+        padding_value=0.)
+
+  @property
+  def diagonals(self):
+    return self._diagonals
+
+  @property
+  def diagonals_format(self):
+    return self._diagonals_format
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-tridiag.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-tridiag.pbtxt
new file mode 100644
index 0000000..0609904
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-tridiag.pbtxt
@@ -0,0 +1,185 @@
+path: "tensorflow.linalg.LinearOperatorTridiag"
+tf_class {
+  is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator_tridiag.LinearOperatorTridiag\'>"
+  is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
+  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
+  is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
+  is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "H"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "batch_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "diagonals"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "diagonals_format"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "domain_dimension"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "dtype"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "graph_parents"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "is_non_singular"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "is_positive_definite"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "is_self_adjoint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "is_square"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "name"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "name_scope"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "range_dimension"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "submodules"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "tensor_rank"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "variables"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'diagonals\', \'diagonals_format\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'compact\', \'None\', \'None\', \'None\', \'None\', \'LinearOperatorTridiag\'], "
+  }
+  member_method {
+    name: "add_to_tensor"
+    argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
+  }
+  member_method {
+    name: "adjoint"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+  }
+  member_method {
+    name: "assert_non_singular"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
+  }
+  member_method {
+    name: "assert_positive_definite"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], "
+  }
+  member_method {
+    name: "assert_self_adjoint"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], "
+  }
+  member_method {
+    name: "batch_shape_tensor"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], "
+  }
+  member_method {
+    name: "cholesky"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'cholesky\'], "
+  }
+  member_method {
+    name: "cond"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'cond\'], "
+  }
+  member_method {
+    name: "determinant"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], "
+  }
+  member_method {
+    name: "diag_part"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], "
+  }
+  member_method {
+    name: "domain_dimension_tensor"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], "
+  }
+  member_method {
+    name: "eigvals"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'eigvals\'], "
+  }
+  member_method {
+    name: "inverse"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse\'], "
+  }
+  member_method {
+    name: "log_abs_determinant"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], "
+  }
+  member_method {
+    name: "matmul"
+    argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], "
+  }
+  member_method {
+    name: "matvec"
+    argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], "
+  }
+  member_method {
+    name: "range_dimension_tensor"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], "
+  }
+  member_method {
+    name: "shape_tensor"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], "
+  }
+  member_method {
+    name: "solve"
+    argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], "
+  }
+  member_method {
+    name: "solvevec"
+    argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], "
+  }
+  member_method {
+    name: "tensor_rank_tensor"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], "
+  }
+  member_method {
+    name: "to_dense"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], "
+  }
+  member_method {
+    name: "trace"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], "
+  }
+  member_method {
+    name: "with_name_scope"
+    argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt
index 632400c..264294d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt
@@ -73,6 +73,10 @@
     mtype: "<type \'type\'>"
   }
   member {
+    name: "LinearOperatorTridiag"
+    mtype: "<type \'type\'>"
+  }
+  member {
     name: "LinearOperatorZeros"
     mtype: "<type \'type\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-tridiag.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-tridiag.pbtxt
new file mode 100644
index 0000000..0609904
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-tridiag.pbtxt
@@ -0,0 +1,185 @@
+path: "tensorflow.linalg.LinearOperatorTridiag"
+tf_class {
+  is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator_tridiag.LinearOperatorTridiag\'>"
+  is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
+  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
+  is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
+  is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "H"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "batch_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "diagonals"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "diagonals_format"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "domain_dimension"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "dtype"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "graph_parents"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "is_non_singular"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "is_positive_definite"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "is_self_adjoint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "is_square"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "name"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "name_scope"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "range_dimension"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "submodules"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "tensor_rank"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "variables"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'diagonals\', \'diagonals_format\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'compact\', \'None\', \'None\', \'None\', \'None\', \'LinearOperatorTridiag\'], "
+  }
+  member_method {
+    name: "add_to_tensor"
+    argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
+  }
+  member_method {
+    name: "adjoint"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+  }
+  member_method {
+    name: "assert_non_singular"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
+  }
+  member_method {
+    name: "assert_positive_definite"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], "
+  }
+  member_method {
+    name: "assert_self_adjoint"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], "
+  }
+  member_method {
+    name: "batch_shape_tensor"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], "
+  }
+  member_method {
+    name: "cholesky"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'cholesky\'], "
+  }
+  member_method {
+    name: "cond"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'cond\'], "
+  }
+  member_method {
+    name: "determinant"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], "
+  }
+  member_method {
+    name: "diag_part"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], "
+  }
+  member_method {
+    name: "domain_dimension_tensor"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], "
+  }
+  member_method {
+    name: "eigvals"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'eigvals\'], "
+  }
+  member_method {
+    name: "inverse"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse\'], "
+  }
+  member_method {
+    name: "log_abs_determinant"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], "
+  }
+  member_method {
+    name: "matmul"
+    argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], "
+  }
+  member_method {
+    name: "matvec"
+    argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], "
+  }
+  member_method {
+    name: "range_dimension_tensor"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], "
+  }
+  member_method {
+    name: "shape_tensor"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], "
+  }
+  member_method {
+    name: "solve"
+    argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], "
+  }
+  member_method {
+    name: "solvevec"
+    argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], "
+  }
+  member_method {
+    name: "tensor_rank_tensor"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], "
+  }
+  member_method {
+    name: "to_dense"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], "
+  }
+  member_method {
+    name: "trace"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], "
+  }
+  member_method {
+    name: "with_name_scope"
+    argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
index 041041f..7d6f02a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
@@ -73,6 +73,10 @@
     mtype: "<type \'type\'>"
   }
   member {
+    name: "LinearOperatorTridiag"
+    mtype: "<type \'type\'>"
+  }
+  member {
     name: "LinearOperatorZeros"
     mtype: "<type \'type\'>"
   }