[XLA] Add more eigendecomposition tests.
Expose SortByEigenvalues to subclasses of EighExpander.
PiperOrigin-RevId: 372424661
Change-Id: Ide90f07504636de6fb65abd1189531aed498f969
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index 3a3a519..f9f4d06 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -468,6 +468,7 @@
name = "self_adjoint_eig_test",
srcs = ["self_adjoint_eig_test.cc"],
real_hardware_only = True,
+ shard_count = 5,
tags = ["optonly"],
deps = [
":arithmetic",
diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc
index 99f2591..b872b05 100644
--- a/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc
+++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc
@@ -285,7 +285,8 @@
INSTANTIATE_TEST_SUITE_P(
RandomEighTestInstantiation, RandomEighTest,
- ::testing::Values(0, 1, 2, 3, 8, 16, 32, 256, 512),
+ ::testing::Values(0, 1, 2, 3, 8, 16, 32, 77, 129, 203, 256, 257, 493, 511,
+ 512, 513),
[](const ::testing::TestParamInfo<EighTestCase>& info) {
const int64 size = info.param;
return absl::StrCat(size);
diff --git a/tensorflow/compiler/xla/service/eigh_expander.cc b/tensorflow/compiler/xla/service/eigh_expander.cc
index 33b1c2f..abf02e1 100644
--- a/tensorflow/compiler/xla/service/eigh_expander.cc
+++ b/tensorflow/compiler/xla/service/eigh_expander.cc
@@ -33,6 +33,7 @@
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"
// Parallel two-sided Jacobi symmetric eigendecomposition.
//
@@ -130,7 +131,7 @@
// same_sign = (same_sign == which_max_abs)
// cosine, sine = (np.where(same_sign, -sine, cosine),
// np.where(same_sign, cosine, sine))
-// return rt1, rt2, cosine, sine
+// return rt1, rt2, cosine, -sine
StatusOr<Eigh2x2> HermitianEigenDecomposition2x2(XlaOp w_tl, XlaOp w_tr,
XlaOp w_br) {
TF_ASSIGN_OR_RETURN(Shape w_tl_shape, w_tl.builder()->GetShape(w_tl));
@@ -408,7 +409,9 @@
"EighJacobiSweeps", builder);
}
-StatusOr<std::pair<XlaOp, XlaOp>> SortByEigenvalues(XlaOp v, XlaOp w) {
+} // namespace
+
+Status EighExpander::SortByEigenvalues(XlaOp& v, XlaOp& w) {
XlaBuilder* builder = v.builder();
TF_ASSIGN_OR_RETURN(Shape v_shape, builder->GetShape(v));
TF_ASSIGN_OR_RETURN(Shape w_shape, builder->GetShape(w));
@@ -427,11 +430,9 @@
num_dims - 1);
w = GetMatrixDiagonal(GetTupleElement(sort_result, 0));
v = GetTupleElement(sort_result, 1);
- return std::make_pair(v, w);
+ return Status::OK();
}
-} // namespace
-
// This is the cyclic Jacobi iteration.
//
// def jacobi(A):
@@ -586,7 +587,7 @@
}
v = MaybeConjugate(TransposeInMinorDims(v), true);
- TF_ASSIGN_OR_RETURN(std::tie(v, w), SortByEigenvalues(v, w));
+ TF_RETURN_IF_ERROR(SortByEigenvalues(v, w));
return Tuple(builder, {v, w});
});
}
diff --git a/tensorflow/compiler/xla/service/eigh_expander.h b/tensorflow/compiler/xla/service/eigh_expander.h
index ec282e7..6ce74e3 100644
--- a/tensorflow/compiler/xla/service/eigh_expander.h
+++ b/tensorflow/compiler/xla/service/eigh_expander.h
@@ -34,6 +34,8 @@
virtual XlaOp BuildEigh(XlaOp a, bool lower, int64 max_iter, float tol);
+ Status SortByEigenvalues(XlaOp& v, XlaOp& w);
+
private:
// Mapping from op signatures to existing computations.
absl::flat_hash_map<string, HloComputation*> computation_cache_;