Don't crash on empty RHS in matrix_triangular_solve on GPU.

PiperOrigin-RevId: 279424548
Change-Id: I59b10866bd99f9e92e8714fc1b1a05674b301ce8
diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op.cc b/tensorflow/core/kernels/matrix_triangular_solve_op.cc
index 16fb29f..61bc4aa 100644
--- a/tensorflow/core/kernels/matrix_triangular_solve_op.cc
+++ b/tensorflow/core/kernels/matrix_triangular_solve_op.cc
@@ -175,7 +175,7 @@
     const ConstMatrixMap& rhs = inputs[1];
     MatrixMap& output = outputs->at(0);
 
-    if (matrix.rows() == 0 || rhs.cols() == 0) {
+    if (matrix.rows() == 0 || rhs.rows() == 0 || rhs.cols() == 0) {
       // To be consistent with the MatrixInverse op, we define the solution for
       // an empty set of equation as the empty matrix.
       return;
diff --git a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
index ec2ed12..32ab612 100644
--- a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
@@ -153,7 +153,7 @@
   def testNonSquareMatrix(self):
     # A non-square matrix should cause an error.
     matrix = np.array([[1., 2., 3.], [3., 4., 5.]])
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       with self.assertRaises(ValueError):
         self._verifySolve(matrix, matrix)
       with self.assertRaises(ValueError):
@@ -165,7 +165,7 @@
     # right-hand sides.
     matrix = np.array([[1., 0.], [0., 1.]])
     rhs = np.array([[1., 0.]])
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       with self.assertRaises(ValueError):
         self._verifySolve(matrix, rhs)
       with self.assertRaises(ValueError):
@@ -176,6 +176,7 @@
   def testNotInvertible(self):
     # The input should be invertible.
     # The matrix is singular because it has a zero on the diagonal.
+    # FIXME(rmlarsen): The GPU kernel does not check for singularity.
     singular_matrix = np.array([[1., 0., -1.], [-1., 0., 1.], [0., -1., 1.]])
     with self.cached_session():
       with self.assertRaisesOpError("Input matrix is not invertible."):