#include "tensorflow/compiler/xla/service/eigh_expander.h"
#include <memory>
#include <vector>
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
// Parallel two-sided Jacobi symmetric eigendecomposition.
// The implementation follows the approach described in:
// Brent, Richard P., and Franklin T. Luk. "The solution of singular-value and
// symmetric eigenvalue problems on multiprocessor arrays." SIAM Journal on
// Scientific and Statistical Computing 6.1 (1985): 69-84.
// Where the Brent/Luk paper uses "processors", we use "vector elements".
namespace xla {
namespace {
// A 2x2 symmetric Eigendecomposition of a matrix A.
// If
//   G = [[ c, s],
//        [-s, c]]
//   matmul(G_T, G) = I
// and
//   G @ [[rt1, 0  ],  @ G.T = A
//        [  0, rt2]]
// then rt1/rt2 are the eigenvalues of A and (c, s) is the first row of G.
//
// As built by HermitianEigenDecomposition2x2, each member holds one value per
// diagonal element of the quadrants being processed (the values come from
// GetMatrixDiagonal), so many 2x2 problems are solved at once.
struct Eigh2x2 {
  // Eigenvalues
  XlaOp rt1;
  XlaOp rt2;
  // First row of Eigenvector matrix.
  XlaOp c;  // cosine.
  XlaOp s;  // sine.
};
// sqrt(x**2 + y**2), calculated avoiding overflow.
//
// Computes max(|x|, |y|) * sqrt(1 + (min/max)**2) so neither operand is squared
// directly. When |x| == |y| the result is taken as |x| * sqrt(2) instead. This
// also covers x == y == 0, where min/max would be 0/0 = NaN.
XlaOp Hypot(XlaOp x, XlaOp y) {
  x = Abs(x);
  y = Abs(y);
  auto xy_min = Min(x, y);
  auto xy_max = Max(x, y);
  auto out = xy_max * Sqrt(ScalarLike(x, 1) + Square(xy_min / xy_max));
  return Select(Eq(xy_min, xy_max), xy_min * ScalarLike(xy_min, std::sqrt(2.)),
                out);
}
// Given an n-by-n symmetric A and integers p and q that satisfy 0 <= p < q < n,
// a Jacobi rotation computes a rotation matrix G = [[c, s], [-s, c]], such that
//   G_T * A[[p, q], [p, q]] * G
// is diagonalized. We do this by computing a 2x2 eigendecomposition.
//
// In this parallel Jacobi algorithm, we compute Jacobi rotations for all of the
// matrix diagonal elements at the same time. The matrix diagonal elements
// correspond to different rows and columns of the original matrix, so their
// rotations do not interfere and can be computed in parallel.
//
// The algorithm is based on slaev2/claev2 from LAPACK, modified to allow for
// vectorization.
// In addition, slaev2 always returns the largest eigenvalue as rt1, which has
// the effect of swapping eigenvalues around in the Jacobi algorithm. This does
// not converge when used in a parallel Jacobi algorithm, so we modify the
// algorithm to maintain the following symmetry property:
//   slaev2(a, b, c) has the opposite Eigenvalue order from slaev2(c, b, a)
//
// def symmetric_eigendecomposition_2x2(a, b, c):
//   # Input matrix [[a, b], [b, c]].
//   ac_sum = a + c
//   ac_diff = a - c
//   two_b = 2*b
//
//   rt = hypot(ac_diff, two_b)
//
//   which_max_abs = np.abs(a) > np.abs(c)
//   ac_max = np.where(which_max_abs, a, c)
//   ac_min = np.where(which_max_abs, c, a)
//
//   rt1 = np.float32(0.5)*(ac_sum + np.where(ac_sum < 0, -rt, rt))
//   rt2 = np.where(ac_sum != 0, (ac_max / rt1)*ac_min - (b/rt1)*b,
//                  -np.float32(0.5)*rt)
//
//   # Modification: don't sort the Eigenvalues.
//   rt1, rt2 = (np.where(which_max_abs, rt1, rt2),
//               np.where(which_max_abs, rt2, rt1))
//
//   # Compute eigenvectors
//   cs = ac_diff + np.where(ac_diff >= 0, rt, -rt)
//   ct = -two_b / cs
//   tn = -cs / two_b
//
//   cosine = np.where(two_b != 0, np.float32(1) / np.sqrt(1 + tn*tn),
//                     np.float32(1))
//   sine = np.where(two_b != 0, tn * cosine, np.float32(0))
//
//   tmp = 1 / np.sqrt(1 + ct*ct)
//   cosine = np.where(np.abs(cs) > np.abs(two_b), ct*tmp, cosine)
//   sine = np.where(np.abs(cs) > np.abs(two_b), tmp, sine)
//   same_sign = (ac_sum >= 0) == (ac_diff >= 0)
//   # Modification: use Eigenvalues corresponding to the Eigenvectors above.
//   same_sign = (same_sign == which_max_abs)
//   cosine, sine = (np.where(same_sign, -sine, cosine),
//                   np.where(same_sign, cosine, sine))
//   return rt1, rt2, cosine, -sine
//
// Inputs are the top-left, top-right and bottom-right quadrants. Only their
// diagonals are read. For complex inputs the result fields are complex, and
// the eigenvalues have zero imaginary part.
StatusOr<Eigh2x2> HermitianEigenDecomposition2x2(XlaOp w_tl, XlaOp w_tr,
                                                 XlaOp w_br) {
  TF_ASSIGN_OR_RETURN(Shape w_tl_shape, w_tl.builder()->GetShape(w_tl));
  bool is_complex = primitive_util::IsComplexType(w_tl_shape.element_type());

  auto a = GetMatrixDiagonal(Real(w_tl));
  auto b = GetMatrixDiagonal(w_tr);
  auto abs_b = Abs(b);

  XlaOp w;
  if (is_complex) {
    // w = conj(b) / |b| (or 1 when b == 0) is the unit-modulus phase of
    // conj(b). The real algorithm below runs on |b|, and w is multiplied back
    // into the sine at the end.
    w = Select(Eq(abs_b, ZerosLike(abs_b)), FullLike(b, 1),
               Conj(b) / Complex(abs_b, ZerosLike(abs_b)));
    b = abs_b;
  }

  auto c = GetMatrixDiagonal(Real(w_br));
  auto zero = ScalarLike(a, 0.0);
  auto half = ScalarLike(a, 0.5);
  auto neg_half = ScalarLike(a, -0.5);
  auto one = ScalarLike(a, 1.0);
  auto two = ScalarLike(a, 2.0);

  auto ac_sum = a + c;
  auto ac_diff = a - c;
  auto two_b = two * b;

  auto rt = Hypot(ac_diff, two_b);

  // Compute eigenvalues
  auto which_max_abs = Gt(Abs(a), Abs(c));
  auto ac_max = Select(which_max_abs, a, c);
  auto ac_min = Select(which_max_abs, c, a);

  auto rt1 = half * (ac_sum + Select(Lt(ac_sum, zero), -rt, rt));
  auto rt2 = Select(Ne(ac_sum, zero), (ac_max / rt1) * ac_min - (b / rt1) * b,
                    neg_half * rt);
  // Don't sort the eigenvalues (see the symmetry requirement above).
  std::tie(rt1, rt2) = std::make_tuple(Select(which_max_abs, rt1, rt2),
                                       Select(which_max_abs, rt2, rt1));

  // Compute eigenvectors
  auto cs = ac_diff + Select(Ge(ac_diff, zero), rt, -rt);
  auto ct = -two_b / cs;
  auto tn = -cs / two_b;

  // Divisions by zero above only produce values in branches that the Selects
  // below discard.
  auto cosine = Select(Ne(two_b, zero), Rsqrt(one + Square(tn)), one);
  auto sine = Select(Ne(two_b, zero), tn * cosine, zero);

  auto tmp = Rsqrt(one + Square(ct));
  auto abs_cs_larger = Gt(Abs(cs), Abs(two_b));
  cosine = Select(abs_cs_larger, ct * tmp, cosine);
  sine = Select(abs_cs_larger, tmp, sine);

  auto same_sign = Eq(Ge(ac_sum, zero), Ge(ac_diff, zero));
  // Use the eigenvalue ordering that matches the eigenvectors above.
  same_sign = Eq(same_sign, which_max_abs);
  std::tie(cosine, sine) = std::make_tuple(Select(same_sign, -sine, cosine),
                                           Select(same_sign, cosine, sine));

  if (is_complex) {
    rt1 = Complex(rt1, ZerosLike(rt1));
    rt2 = Complex(rt2, ZerosLike(rt2));
    cosine = Complex(cosine, ZerosLike(cosine));
    sine = Complex(sine, ZerosLike(sine)) * w;
  }

  // Negate 'sine' because we are returning the first row of the rotation matrix
  // not the first eigenvector.
  return Eigh2x2{rt1, rt2, cosine, -sine};
}
// Mixes row i of the top quadrants with row i of the bottom quadrants using
// rotation.c[..., i] and rotation.s[..., i]. The real-case equivalent is:
//
// tl, tr, bl, br = (
//   tl * c[:, None] - bl * s[:, None],
//   tr * c[:, None] - br * s[:, None],
//   tl * s[:, None] + bl * c[:, None],
//   tr * s[:, None] + br * c[:, None],
// )
//
// In the complex case, conj(s) is used in the two top outputs. The quadrants
// are updated in place.
void ApplyJacobiRotationOverRows(Eigh2x2 rotation, XlaOp& tl, XlaOp& tr,
                                 XlaOp& bl, XlaOp& br) {
  Shape shape = tl.builder()->GetShape(tl).ValueOrDie();
  // Map the rank-(r-1) rotation vectors onto dims 0..r-2, so element i
  // scales row i.
  std::vector<int64> broadcast_dims(shape.dimensions().size() - 1);
  absl::c_iota(broadcast_dims, 0);
  auto c = BroadcastInDim(rotation.c, shape.dimensions(), broadcast_dims);
  auto s = BroadcastInDim(rotation.s, shape.dimensions(), broadcast_dims);

  auto s_conj = MaybeConjugate(s, true);
  std::tie(tl, tr, bl, br) =
      std::make_tuple(tl * c - bl * s_conj, tr * c - br * s_conj,
                      tl * s + bl * c, tr * s + br * c);
}
// Mixes column j of the left quadrants with column j of the right quadrants
// using rotation.c[..., j] and rotation.s[..., j]. The real-case equivalent is:
//
// tl, tr, bl, br = (
//   tl * c[None, :] - tr * s[None, :],
//   tl * s[None, :] + tr * c[None, :],
//   bl * c[None, :] - br * s[None, :],
//   bl * s[None, :] + br * c[None, :],
// )
//
// In the complex case, conj(s) is used in the two right-hand outputs. The
// quadrants are updated in place.
void ApplyJacobiRotationOverCols(Eigh2x2 rotation, XlaOp& tl, XlaOp& tr,
                                 XlaOp& bl, XlaOp& br) {
  Shape shape = tl.builder()->GetShape(tl).ValueOrDie();
  // Batch dims map to themselves. The last vector dim maps to the column
  // (minor-most) dim instead of the row dim, so element j scales column j.
  std::vector<int64> broadcast_dims(shape.dimensions().size() - 1);
  absl::c_iota(broadcast_dims, 0);
  broadcast_dims.back() = shape.dimensions().size() - 1;
  auto c = BroadcastInDim(rotation.c, shape.dimensions(), broadcast_dims);
  auto s = BroadcastInDim(rotation.s, shape.dimensions(), broadcast_dims);

  auto s_conj = MaybeConjugate(s, true);
  std::tie(tl, tr, bl, br) =
      std::make_tuple(tl * c - tr * s, tl * s_conj + tr * c, bl * c - br * s,
                      bl * s_conj + br * c);
}
// Brent/Luk row permutation between a top and a bottom quadrant (k rows each):
//
// def permute_rows_in_col(top, bottom):
//   top_out = np.zeros_like(top)
//   top_out[0] = top[0]
//   top_out[1] = bottom[0]
//   top_out[2:] = top[1:-1]
//
//   bottom_out = np.zeros_like(bottom)
//   bottom_out[:-1] = bottom[1:]
//   bottom_out[-1] = top[-1]
//   return top_out, bottom_out
//
// Quadrants are assumed square (k is read from the column dimension). With
// k <= 1 there is nothing to permute.
void PermuteRowsInColumn(XlaOp& top, XlaOp& bottom) {
  XlaBuilder* builder = top.builder();
  Shape shape = builder->GetShape(top).ValueOrDie();
  int64 k = ShapeUtil::GetDimension(shape, -1);
  if (k <= 1) {
    return;
  }
  int ndim = shape.dimensions_size();
  std::tie(top, bottom) =
      std::make_tuple(ConcatInDim(builder,
                                  {SliceInMinorDims(top, {0, 0}, {1, k}),
                                   SliceInMinorDims(bottom, {0, 0}, {1, k}),
                                   SliceInMinorDims(top, {1, 0}, {k - 1, k})},
                                  ndim - 2),
                      ConcatInDim(builder,
                                  {SliceInMinorDims(bottom, {1, 0}, {k, k}),
                                   SliceInMinorDims(top, {k - 1, 0}, {k, k})},
                                  ndim - 2));
}
// Column analogue of PermuteRowsInColumn, acting on the minor-most dimension
// of a left and a right quadrant:
//
// def permute_cols_in_row(left, right):
//   left_out = np.zeros_like(left)
//   left_out[:, 0] = left[:, 0]
//   left_out[:, 1] = right[:, 0]
//   left_out[:, 2:] = left[:, 1:-1]
//
//   right_out = np.zeros_like(right)
//   right_out[:, :-1] = right[:, 1:]
//   right_out[:, -1] = left[:, -1]
//   return left_out, right_out
//
// With k <= 1 columns there is nothing to permute.
void PermuteColumnsInRow(XlaOp& left, XlaOp& right) {
  XlaBuilder* builder = left.builder();
  Shape shape = builder->GetShape(left).ValueOrDie();
  int64 k = ShapeUtil::GetDimension(shape, -1);
  if (k <= 1) {
    return;
  }
  int ndim = shape.dimensions_size();
  std::tie(left, right) =
      std::make_tuple(ConcatInDim(builder,
                                  {SliceInMinorDims(left, {0}, {1}),
                                   SliceInMinorDims(right, {0}, {1}),
                                   SliceInMinorDims(left, {1}, {k - 1})},
                                  ndim - 1),
                      ConcatInDim(builder,
                                  {SliceInMinorDims(right, {1}, {k}),
                                   SliceInMinorDims(left, {k - 1}, {k})},
                                  ndim - 1));
}
// Performs one round of parallel Jacobi rotations; n-1 rounds make a sweep.
// After each rotation, we permute the rows and columns of the quadrants of the
// matrix. The effect of the permutations is that all pairs of rows end up
// on the diagonal of the quadrants after n-1 rounds. The permutations are an
// implicit way of computing a tournament for n players such that each player
// plays every other player exactly once in n - 1 rounds. See the Brent/Luk
// paper for more details.
//
// w_* are the quadrants of the matrix being diagonalized, and v_* are the
// quadrants of the accumulated eigenvector matrix. All are updated in place.
// `n` (the padded matrix dimension passed down from Sweeps) is not used here.
Status ApplyRotations(int64 n, XlaOp& w_tl, XlaOp& w_tr, XlaOp& w_bl,
                      XlaOp& w_br, XlaOp& v_tl, XlaOp& v_tr, XlaOp& v_bl,
                      XlaOp& v_br) {
  TF_ASSIGN_OR_RETURN(Eigh2x2 rotation,
                      HermitianEigenDecomposition2x2(w_tl, w_tr, w_br));

  ApplyJacobiRotationOverRows(rotation, w_tl, w_tr, w_bl, w_br);
  ApplyJacobiRotationOverCols(rotation, w_tl, w_tr, w_bl, w_br);

  // Overwrite the rotated 2x2 blocks with their exact diagonal form:
  // eigenvalues on the quadrant diagonals and zeros in the coupling entries.
  // Presumably this discards rounding error left by the products above.
  w_tl = SetMatrixDiagonal(w_tl, rotation.rt1);
  w_tr = SetMatrixDiagonal(w_tr, ZerosLike(rotation.rt1));
  w_bl = SetMatrixDiagonal(w_bl, ZerosLike(rotation.rt1));
  w_br = SetMatrixDiagonal(w_br, rotation.rt2);

  PermuteColumnsInRow(w_tl, w_tr);
  PermuteColumnsInRow(w_bl, w_br);
  PermuteRowsInColumn(w_tl, w_bl);
  PermuteRowsInColumn(w_tr, w_br);

  // Apply the rotations to the eigenvector matrix.
  // TODO(phawkins): we could omit this if we aren't interested in computing the
  // eigenvectors.
  ApplyJacobiRotationOverRows(rotation, v_tl, v_tr, v_bl, v_br);
  PermuteRowsInColumn(v_tl, v_bl);
  PermuteRowsInColumn(v_tr, v_br);

  return Status::OK();
}
// Squared norms of the matrix assembled from the four quadrants, as computed by
// ComputeFrobeniusNorms (reduced over the two minor dimensions).
struct FrobeniusNorms {
  // Sum of |x|^2 over all entries except the diagonals of tl and br.
  XlaOp off_diagonal_sq_norm;
  // Sum of |x|^2 over all entries.
  XlaOp frobenius_sq_norm;
};
// Computes the squared Frobenius norm of [[w_tl, w_tr], [w_bl, w_br]] and the
// squared norm of its off-diagonal part. The diagonals of w_tl and w_br form
// the main diagonal. Both results are real, with one value per batch element
// (the two minor dimensions are reduced away).
StatusOr<FrobeniusNorms> ComputeFrobeniusNorms(XlaOp w_tl, XlaOp w_tr,
                                               XlaOp w_bl, XlaOp w_br) {
  XlaBuilder* builder = w_tl.builder();
  TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w_tl));
  const int64 num_dims = shape.rank();

  // Element-wise |x|^2, real-valued even for complex x.
  auto square_norm = [](XlaOp x) -> XlaOp {
    return Real(x * MaybeConjugate(x, true));
  };
  // Zeroes the main diagonal of each (batched) matrix.
  auto off_diag = [](XlaOp x) {
    return Select(GetDiagonalMask(x), ZerosLike(x), x);
  };

  PrimitiveType norm_type =
      primitive_util::IsComplexType(shape.element_type())
          ? primitive_util::ComplexComponentType(shape.element_type())
          : shape.element_type();
  auto zero = ScalarLike(Real(w_tl), 0.0);

  FrobeniusNorms norms;
  norms.frobenius_sq_norm =
      Reduce(square_norm(w_tl) + square_norm(w_tr) + square_norm(w_bl) +
                 square_norm(w_br),
             zero, CreateScalarAddComputation(norm_type, builder),
             {num_dims - 2, num_dims - 1});
  norms.off_diagonal_sq_norm =
      Reduce(square_norm(off_diag(w_tl)) + square_norm(w_tr) +
                 square_norm(w_bl) + square_norm(off_diag(w_br)),
             zero, CreateScalarAddComputation(norm_type, builder),
             {num_dims - 2, num_dims - 1});
  return norms;
}
// Runs Jacobi sweeps (n - 1 rounds of ApplyRotations each) in a while loop.
//
// initial_values layout (as read below):
//   [0] iteration counter, [1] relative tolerance,
//   [2..5] w_tl, w_tr, w_bl, w_br, [6..9] v_tl, v_tr, v_bl, v_br.
// The loop stops after `max_iters` sweeps, or once every batch element
// satisfies off_diagonal_sq_norm <= tol^2 * frobenius_sq_norm. Returns the
// final loop state in the same layout. `index_type` is the element type of
// the inner per-round loop counter.
StatusOr<std::vector<XlaOp>> Sweeps(absl::Span<const XlaOp> initial_values,
                                    int64 n, int max_iters,
                                    PrimitiveType index_type,
                                    XlaBuilder* builder) {
  auto while_cond_fn = [&](absl::Span<const XlaOp> values,
                           XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
    auto iter_cond = Lt(values[0], ScalarLike(values[0], max_iters));

    XlaOp w_tl, w_tr, w_bl, w_br;
    std::tie(w_tl, w_tr, w_bl, w_br) =
        std::make_tuple(values[2], values[3], values[4], values[5]);
    TF_ASSIGN_OR_RETURN(auto norms,
                        ComputeFrobeniusNorms(w_tl, w_tr, w_bl, w_br));
    auto tol = norms.frobenius_sq_norm * Square(values[1]);
    // Keep iterating while any batch element is not yet converged.
    auto tol_cond = ReduceAll(Lt(tol, norms.off_diagonal_sq_norm),
                              xla::ConstantR0<bool>(cond_builder, false),
                              CreateScalarOrComputation(PRED, cond_builder));

    return And(iter_cond, tol_cond);
  };

  auto while_body_fn =
      [&](absl::Span<const XlaOp> values,
          XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
    // Everything except the sweep counter is threaded through the inner loop.
    std::vector<XlaOp> sweep_values(values.begin() + 1, values.end());
    TF_ASSIGN_OR_RETURN(
        sweep_values,
        ForEachIndex(
            n - 1, index_type,
            [&](XlaOp iter, absl::Span<const XlaOp> round_values,
                XlaBuilder* round_builder) -> StatusOr<std::vector<XlaOp>> {
              XlaOp tol, w_tl, w_tr, w_bl, w_br, v_tl, v_tr, v_bl, v_br;
              std::tie(tol, w_tl, w_tr, w_bl, w_br, v_tl, v_tr, v_bl, v_br) =
                  std::make_tuple(round_values[0], round_values[1],
                                  round_values[2], round_values[3],
                                  round_values[4], round_values[5],
                                  round_values[6], round_values[7],
                                  round_values[8]);
              TF_RETURN_IF_ERROR(ApplyRotations(n, w_tl, w_tr, w_bl, w_br,
                                                v_tl, v_tr, v_bl, v_br));
              return std::vector<XlaOp>{tol,  w_tl, w_tr, w_bl, w_br,
                                        v_tl, v_tr, v_bl, v_br};
            },
            sweep_values, "ApplyRotations", body_builder));

    std::vector<XlaOp> output(values.size());
    output[0] = values[0] + ScalarLike(values[0], 1);
    std::copy(sweep_values.begin(), sweep_values.end(), output.begin() + 1);
    return output;
  };
  return WhileLoopHelper(while_cond_fn, while_body_fn, initial_values,
                         "EighJacobiSweeps", builder);
}
} // namespace
// Sorts eigenvalues `w` (rank r-1) in ascending order and reorders the minor
// dimension of `v` (rank r) the same way. Both are updated in place.
//
// w is broadcast along the second-minor dimension, so each row of the
// broadcast holds a copy of w. Sorting every row of (w_b, v) by w_b applies
// the same permutation to each row of v. The sorted w is then recovered from
// the diagonal of the sorted broadcast.
Status EighExpander::SortByEigenvalues(XlaOp& v, XlaOp& w) {
  XlaBuilder* builder = v.builder();
  TF_ASSIGN_OR_RETURN(Shape v_shape, builder->GetShape(v));
  TF_ASSIGN_OR_RETURN(Shape w_shape, builder->GetShape(w));
  const int64 num_dims = v_shape.rank();
  auto dimensions = v_shape.dimensions();

  // Batch dims map to themselves; w's last dim maps to v's minor-most dim.
  std::vector<int64> broadcast_dims(num_dims - 1);
  std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
  broadcast_dims[num_dims - 2] = num_dims - 1;
  w = BroadcastInDim(w, dimensions, broadcast_dims);

  XlaOp sort_result =
      Sort({w, v},
           CreateScalarLtComputation(
               {w_shape.element_type(), v_shape.element_type()}, builder),
           num_dims - 1);
  w = GetMatrixDiagonal(GetTupleElement(sort_result, 0));
  v = GetTupleElement(sort_result, 1);
  return Status::OK();
}
// This is the cyclic Jacobi iteration.
// def jacobi(A):
//   n, _ = A.shape
//   tl = A[:n // 2, :n // 2]
//   bl = A[n // 2:, :n // 2]
//   tr = A[:n // 2, n // 2:]
//   br = A[n // 2:, n // 2:]
//   v_tl = np.eye(n // 2, dtype=A.dtype)
//   v_tr = np.zeros((n // 2, n // 2), A.dtype)
//   v_bl = np.zeros((n // 2, n // 2), A.dtype)
//   v_br = np.eye(n // 2, dtype=A.dtype)
//   frobenius_norm = np.sqrt(np.sum(np.square(tl) + np.square(tr) +
//                            np.square(bl) + np.square(br)))
//   diag_norm = np.sqrt(np.sum(np.square(np.diag(tl)) +
//                              np.square(np.diag(br))))
//   off_diag_norm = np.sqrt(frobenius_norm - diag_norm) * np.sqrt(
//     frobenius_norm + diag_norm)
//   while off_diag_norm > 1e-6 * frobenius_norm:
//     for i in range(n - 1):
//       c, s = sym_schur2x2(tl, tr, br)
//       tl, tr, bl, br = (
//         tl * c[:, None] - bl * s[:, None],
//         tr * c[:, None] - br * s[:, None],
//         tl * s[:, None] + bl * c[:, None],
//         tr * s[:, None] + br * c[:, None],
//       )
//       tl, tr, bl, br = (
//         tl * c[None, :] - tr * s[None, :],
//         tl * s[None, :] + tr * c[None, :],
//         bl * c[None, :] - br * s[None, :],
//         bl * s[None, :] + br * c[None, :],
//       )
//       tl, bl = permute_rows_in_col(tl, bl)
//       tr, br = permute_rows_in_col(tr, br)
//       tl, tr = permute_cols_in_row(tl, tr)
//       bl, br = permute_cols_in_row(bl, br)
//       v_tl, v_tr, v_bl, v_br = (
//         v_tl * c[:, None] - v_bl * s[:, None],
//         v_tr * c[:, None] - v_br * s[:, None],
//         v_tl * s[:, None] + v_bl * c[:, None],
//         v_tr * s[:, None] + v_br * c[:, None],
//       )
//       v_tl, v_bl = permute_rows_in_col(v_tl, v_bl)
//       v_tr, v_br = permute_rows_in_col(v_tr, v_br)
//     frobenius_norm = np.sqrt(np.sum(np.square(tl) + np.square(tr) +
//                              np.square(bl) + np.square(br)))
//     diag_norm = np.sqrt(np.sum(np.square(np.diag(tl)) +
//                                np.square(np.diag(br))))
//     off_diag_norm = np.sqrt(frobenius_norm - diag_norm) * np.sqrt(
//         frobenius_norm + diag_norm)
//   return A, V
//
// The implementation differs from the sketch in these ways:
//  * It splits at k = ceil(n / 2) and zero-pads the quadrants to k x k when n
//    is odd.
//  * The convergence test uses a caller-supplied `tol` and at most `max_iter`
//    sweeps.
//  * The padding is sliced off at the end, and the eigenpairs are sorted by
//    eigenvalue.
// Returns the tuple (eigenvectors, eigenvalues), with real eigenvalues.
XlaOp EighExpander::BuildEigh(XlaOp a, bool lower, int64 max_iter, float tol) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    const int64 num_dims = a_shape.rank();
    if (num_dims < 2) {
      return InvalidArgument(
          "Arguments to Eigen decomposition must have rank >= 2: got shape %s.",
          a_shape.ToString());
    }
    PrimitiveType type = a_shape.element_type();
    if (!primitive_util::IsFloatingPointType(type) &&
        !primitive_util::IsComplexType(type)) {
      return InvalidArgument(
          "Type of the input matrix must be floating point "
          "or complex: got %s.",
          a_shape.ToString());
    }

    const int64 m = ShapeUtil::GetDimension(a_shape, -2);
    const int64 n = ShapeUtil::GetDimension(a_shape, -1);

    if (m != n) {
      return InvalidArgument(
          "Arguments to symmetric eigendecomposition must be square matrices: "
          "got shape (%d, %d).",
          m, n);
    }

    const int64 num_batch_dims = num_dims - 2;
    std::vector<int64> batch_dims(num_batch_dims);
    for (int64 i = 0; i < num_batch_dims; ++i) {
      batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
    }

    // 0x0 and 1x1 matrices are already diagonal.
    if (m <= 1) {
      return Tuple(builder, {FullLike(a, 1), GetMatrixDiagonal(Real(a))});
    }

    a = Symmetrize(a, lower);

    const int64 k = CeilOfRatio(n, int64{2});
    // tl = A[:n // 2, :n // 2]
    // bl = A[n // 2:, :n // 2]
    // tr = A[:n // 2, n // 2:]
    // br = A[n // 2:, n // 2:]
    auto tl = SliceInMinorDims(a, {0, 0}, {k, k});
    auto bl = SliceInMinorDims(a, {k, 0}, {n, k});
    auto tr = SliceInMinorDims(a, {0, k}, {k, n});
    auto br = SliceInMinorDims(a, {k, k}, {n, n});
    if (n % 2) {
      // For odd n the right/bottom quadrants are one short; zero-pad them to
      // k x k so all four quadrants have the same shape.
      auto zero = Zero(builder, type);
      tr = PadInDim(tr, zero, num_dims - 1, /*pad_lo=*/0, /*pad_hi=*/1);
      bl = PadInDim(bl, zero, num_dims - 2, /*pad_lo=*/0, /*pad_hi=*/1);
      PaddingConfig config = MakeNoPaddingConfig(num_dims);
      config.mutable_dimensions(num_dims - 2)->set_edge_padding_high(1);
      config.mutable_dimensions(num_dims - 1)->set_edge_padding_high(1);
      br = Pad(br, zero, config);
    }
    // v_tl = np.eye(n // 2, dtype=A.dtype)
    // v_tr = np.zeros((n // 2, n // 2), A.dtype)
    // v_bl = np.zeros((n // 2, n // 2), A.dtype)
    // v_br = np.eye(n // 2, dtype=A.dtype)
    auto v_tl = Broadcast(IdentityMatrix(builder, type, k, k), batch_dims);
    auto v_br = v_tl;
    auto v_tr = ZerosLike(v_tl);
    auto v_bl = v_tr;

    // Loop state: {sweep counter, tol, w quadrants, v quadrants}.
    TF_ASSIGN_OR_RETURN(auto output,
                        Sweeps({Zero(builder, S32), ScalarLike(Real(a), tol),
                                tl, tr, bl, br, v_tl, v_tr, v_bl, v_br},
                               k * 2, max_iter, S32, builder));

    std::tie(tl, tr, bl, br) =
        std::make_tuple(output[2], output[3], output[4], output[5]);
    std::tie(v_tl, v_tr, v_bl, v_br) =
        std::make_tuple(output[6], output[7], output[8], output[9]);

    auto w = ConcatInDim(
        builder, {GetMatrixDiagonal(Real(tl)), GetMatrixDiagonal(Real(br))},
        num_dims - 2);
    auto v = ConcatInDim(builder,
                         {ConcatInDim(builder, {v_tl, v_tr}, num_dims - 1),
                          ConcatInDim(builder, {v_bl, v_br}, num_dims - 1)},
                         num_dims - 2);
    if (n % 2) {
      // Drop the padding row/column added above.
      w = SliceInMinorDims(w, {0}, {n});
      v = SliceInMinorDims(v, {0, 0}, {n, n});
    }
    v = MaybeConjugate(TransposeInMinorDims(v), true);

    TF_RETURN_IF_ERROR(SortByEigenvalues(v, w));
    return Tuple(builder, {v, w});
  });
}
// custom_call_target of the instructions this pass expands.
static const char* kEighCustomCallName = "Eigh";

// Matches custom calls whose target is kEighCustomCallName.
bool EighExpander::InstructionMatchesPattern(HloInstruction* instruction) {
  return instruction->opcode() == HloOpcode::kCustomCall &&
         instruction->custom_call_target() == kEighCustomCallName;
}
// Replaces an Eigh custom call with a call to a computation built by
// BuildEigh. The custom call's raw backend config must be the string
// "lower,max_iter,tol". Computations are cached per (target, operand shape,
// backend config), so identical Eigh calls share one expansion.
StatusOr<HloInstruction*> EighExpander::ExpandInstruction(
    HloInstruction* instruction) {
  // The expansion depends on the backend config (lower, max_iter, tol) as well
  // as on the operand shape, so the config must be part of the cache key.
  // Keying on the shape alone would let two Eigh calls with the same operand
  // shape but different configs share the first call's computation.
  const string name = absl::StrFormat(
      "xla.%s_%s_%s", instruction->custom_call_target(),
      instruction->operand(0)->shape().ToString(),
      instruction->raw_backend_config_string());
  HloModule* module = instruction->parent()->parent();

  HloComputation*& computation =
      computation_cache_.emplace(name, nullptr).first->second;
  if (!computation) {
    // Builds a new expansion.
    //
    // TODO(b/62327888): We do something unusual here: we build the computation
    // using the XlaBuilder API, which is nominally an XLA client API. We do
    // this because the external APIs for building complicated computations
    // (XlaBuilder) are much more ergonomic than the internal ones. As it turns
    // out, XlaBuilder isn't really a client API—what it does is build a
    // HloModuleProto protocol buffer, that we can then deserialize and clone
    // into our HloModule. Ideally we would avoid the protocol buffer step;
    // that is left as an exercise for future work.
    XlaBuilder builder(name);
    TF_RET_CHECK(instruction->operand_count() == 1);
    XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");

    std::vector<std::string> config_strs =
        absl::StrSplit(instruction->raw_backend_config_string(), ',');
    int lower = 0;
    int64 max_iter = 0;
    float tol = 0.0f;
    if (config_strs.size() != 3 || !absl::SimpleAtoi(config_strs[0], &lower) ||
        !absl::SimpleAtoi(config_strs[1], &max_iter) ||
        !absl::SimpleAtof(config_strs[2], &tol)) {
      // The cache entry stays null, so a later call will try to build again.
      return Internal("Unable to parse arguments to Eigh custom call, got: %s",
                      instruction->raw_backend_config_string());
    }
    XlaOp result = BuildEigh(a, lower, max_iter, tol);
    TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(result));

    TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                        xla_computation.GetProgramShape());
    HloModuleConfig config(program_shape);
    TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
                                             xla_computation.proto(), config));
    HloCloneContext context(module);
    computation =
        module->DeepCloneComputation(new_module->entry_computation(), &context);
  }

  return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
      instruction->shape(), instruction->operands(), computation));
}
} // namespace xla