[XLA] Add more eigendecomposition tests.

Expose SortByEigenvalues to subclasses of EighExpander.

PiperOrigin-RevId: 372424661
Change-Id: Ide90f07504636de6fb65abd1189531aed498f969
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index 3a3a519..f9f4d06 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -468,6 +468,7 @@
     name = "self_adjoint_eig_test",
     srcs = ["self_adjoint_eig_test.cc"],
     real_hardware_only = True,
+    shard_count = 5,
     tags = ["optonly"],
     deps = [
         ":arithmetic",
diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc
index 99f2591..b872b05 100644
--- a/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc
+++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc
@@ -285,7 +285,8 @@
 
 INSTANTIATE_TEST_SUITE_P(
     RandomEighTestInstantiation, RandomEighTest,
-    ::testing::Values(0, 1, 2, 3, 8, 16, 32, 256, 512),
+    ::testing::Values(0, 1, 2, 3, 8, 16, 32, 77, 129, 203, 256, 257, 493, 511,
+                      512, 513),
     [](const ::testing::TestParamInfo<EighTestCase>& info) {
       const int64 size = info.param;
       return absl::StrCat(size);
diff --git a/tensorflow/compiler/xla/service/eigh_expander.cc b/tensorflow/compiler/xla/service/eigh_expander.cc
index 33b1c2f..abf02e1 100644
--- a/tensorflow/compiler/xla/service/eigh_expander.cc
+++ b/tensorflow/compiler/xla/service/eigh_expander.cc
@@ -33,6 +33,7 @@
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"
 
 // Parallel two-sided Jacobi symmetric eigendecomposition.
 //
@@ -130,7 +131,7 @@
 //   same_sign = (same_sign == which_max_abs)
 //   cosine, sine = (np.where(same_sign, -sine, cosine),
 //                   np.where(same_sign, cosine, sine))
-//   return rt1, rt2, cosine, sine
+//   return rt1, rt2, cosine, -sine
 StatusOr<Eigh2x2> HermitianEigenDecomposition2x2(XlaOp w_tl, XlaOp w_tr,
                                                  XlaOp w_br) {
   TF_ASSIGN_OR_RETURN(Shape w_tl_shape, w_tl.builder()->GetShape(w_tl));
@@ -408,7 +409,9 @@
                          "EighJacobiSweeps", builder);
 }
 
-StatusOr<std::pair<XlaOp, XlaOp>> SortByEigenvalues(XlaOp v, XlaOp w) {
+}  // namespace
+
+Status EighExpander::SortByEigenvalues(XlaOp& v, XlaOp& w) {
   XlaBuilder* builder = v.builder();
   TF_ASSIGN_OR_RETURN(Shape v_shape, builder->GetShape(v));
   TF_ASSIGN_OR_RETURN(Shape w_shape, builder->GetShape(w));
@@ -427,11 +430,9 @@
            num_dims - 1);
   w = GetMatrixDiagonal(GetTupleElement(sort_result, 0));
   v = GetTupleElement(sort_result, 1);
-  return std::make_pair(v, w);
+  return Status::OK();
 }
 
-}  // namespace
-
 // This is the cyclic Jacobi iteration.
 //
 // def jacobi(A):
@@ -586,7 +587,7 @@
     }
     v = MaybeConjugate(TransposeInMinorDims(v), true);
 
-    TF_ASSIGN_OR_RETURN(std::tie(v, w), SortByEigenvalues(v, w));
+    TF_RETURN_IF_ERROR(SortByEigenvalues(v, w));
     return Tuple(builder, {v, w});
   });
 }
diff --git a/tensorflow/compiler/xla/service/eigh_expander.h b/tensorflow/compiler/xla/service/eigh_expander.h
index ec282e7..6ce74e3 100644
--- a/tensorflow/compiler/xla/service/eigh_expander.h
+++ b/tensorflow/compiler/xla/service/eigh_expander.h
@@ -34,6 +34,8 @@
 
   virtual XlaOp BuildEigh(XlaOp a, bool lower, int64 max_iter, float tol);
 
+  Status SortByEigenvalues(XlaOp& v, XlaOp& w);
+
  private:
   // Mapping from op signatures to existing computations.
   absl::flat_hash_map<string, HloComputation*> computation_cache_;