Don't crash on empty RHS in matrix_triangular_solve on GPU.
PiperOrigin-RevId: 279424548
Change-Id: I59b10866bd99f9e92e8714fc1b1a05674b301ce8
diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op.cc b/tensorflow/core/kernels/matrix_triangular_solve_op.cc
index 16fb29f..61bc4aa 100644
--- a/tensorflow/core/kernels/matrix_triangular_solve_op.cc
+++ b/tensorflow/core/kernels/matrix_triangular_solve_op.cc
@@ -175,7 +175,7 @@
const ConstMatrixMap& rhs = inputs[1];
MatrixMap& output = outputs->at(0);
- if (matrix.rows() == 0 || rhs.cols() == 0) {
+ if (matrix.rows() == 0 || rhs.rows() == 0 || rhs.cols() == 0) {
// To be consistent with the MatrixInverse op, we define the solution for
// an empty set of equation as the empty matrix.
return;
diff --git a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
index ec2ed12..32ab612 100644
--- a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
@@ -153,7 +153,7 @@
def testNonSquareMatrix(self):
# A non-square matrix should cause an error.
matrix = np.array([[1., 2., 3.], [3., 4., 5.]])
- with self.cached_session():
+ with self.cached_session(use_gpu=True):
with self.assertRaises(ValueError):
self._verifySolve(matrix, matrix)
with self.assertRaises(ValueError):
@@ -165,7 +165,7 @@
# right-hand sides.
matrix = np.array([[1., 0.], [0., 1.]])
rhs = np.array([[1., 0.]])
- with self.cached_session():
+ with self.cached_session(use_gpu=True):
with self.assertRaises(ValueError):
self._verifySolve(matrix, rhs)
with self.assertRaises(ValueError):
@@ -176,6 +176,7 @@
def testNotInvertible(self):
# The input should be invertible.
# The matrix is singular because it has a zero on the diagonal.
+ # FIXME(rmlarsen): The GPU kernel does not check for singularity.
singular_matrix = np.array([[1., 0., -1.], [-1., 0., 1.], [0., -1., 1.]])
with self.cached_session():
with self.assertRaisesOpError("Input matrix is not invertible."):