Make LinearOperatorBlockDiag tape safety check different diagonal components.
PiperOrigin-RevId: 289172423
Change-Id: I337870843934bdfe2d49caf1977a2613560cb709
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py
index dc501b1..abaf9bf 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py
@@ -213,20 +213,16 @@
self.assertEqual(2, len(inverse.operators))
def test_tape_safe(self):
- matrix = variables_module.Variable([[1., 0.], [0., 1.]])
+ matrices = []
+ for _ in range(4):
+ matrices.append(variables_module.Variable(
+ linear_operator_test_util.random_positive_definite_matrix(
+ [2, 2], dtype=dtypes.float32, force_well_conditioned=True)))
+
operator = block_diag.LinearOperatorBlockDiag(
- [
- linalg.LinearOperatorFullMatrix(
- matrix,
- is_self_adjoint=True,
- is_positive_definite=True,
- ),
- linalg.LinearOperatorFullMatrix(
- matrix,
- is_self_adjoint=True,
- is_positive_definite=True,
- ),
- ],
+ [linalg.LinearOperatorFullMatrix(
+ matrix, is_self_adjoint=True,
+ is_positive_definite=True) for matrix in matrices],
is_self_adjoint=True,
is_positive_definite=True,
)