[XLA] Add support for complex numbers to the QR decomposition expander.
PiperOrigin-RevId: 333208193
Change-Id: Ic9adc699a11ffcc23a0ae518b54ee29cce8569ce
diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py
index f396e61..d2d6fe7 100644
--- a/tensorflow/compiler/tests/qr_op_test.py
+++ b/tensorflow/compiler/tests/qr_op_test.py
@@ -74,8 +74,14 @@
def _test(self, dtype, shape, full_matrices):
np.random.seed(1)
- x_np = np.random.uniform(
- low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)
+
+ def rng():
+ return np.random.uniform(
+ low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)
+
+ x_np = rng()
+ if np.issubdtype(dtype, np.complexfloating):
+ x_np += rng() * dtype(1j)
with self.session() as sess:
x_tf = array_ops.placeholder(dtype)
@@ -102,7 +108,7 @@
self.CheckUnitary(q_tf_val)
SIZES = [1, 2, 5, 10, 32, 100, 300]
- DTYPES = [np.float32]
+ DTYPES = [np.float32, np.complex64]
PARAMS = itertools.product(SIZES, SIZES, DTYPES)
@parameterized.parameters(*PARAMS)
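
Note on what the complex path must verify: unitarity for a complex Q means Q^H Q = I (conjugate transpose), not Q^T Q. A minimal numpy sketch of the property under test, using np.linalg.qr as a stand-in for the XLA Qr kernel (the rng construction mirrors the test change above; the tolerances are illustrative):

  import numpy as np

  shape, dtype = (5, 5), np.complex64

  def rng():
    return np.random.uniform(
        low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)

  x_np = rng() + rng() * dtype(1j)
  q, r = np.linalg.qr(x_np)  # stand-in for the XLA op under test
  np.testing.assert_allclose(np.conj(q.T) @ q, np.eye(shape[0]), atol=1e-5)
  np.testing.assert_allclose(q @ r, x_np, atol=1e-4)
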
diff --git a/tensorflow/compiler/tf2xla/kernels/qr_op.cc b/tensorflow/compiler/tf2xla/kernels/qr_op.cc
index 66ec40a..7aebb76 100644
--- a/tensorflow/compiler/tf2xla/kernels/qr_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/qr_op.cc
@@ -41,7 +41,7 @@
bool full_matrices_;
};
-REGISTER_XLA_OP(Name("Qr").TypeConstraint("T", kFloatTypes), QROp);
+REGISTER_XLA_OP(Name("Qr"), QROp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/xla/service/qr_expander.cc b/tensorflow/compiler/xla/service/qr_expander.cc
index 8bc06f9..d1b1526 100644
--- a/tensorflow/compiler/xla/service/qr_expander.cc
+++ b/tensorflow/compiler/xla/service/qr_expander.cc
@@ -63,13 +63,16 @@
// x_copy = np.copy(x)
// x_copy[:k+1] = 0
// xnorm = norm2(x_copy)
-// if xnorm == 0:
+// if xnorm == 0 and np.imag(alpha) == 0:
// beta = alpha
// tau = 0
// v = np.zeros_like(x)
// else:
-// beta = - np.sign(alpha) * dlapy2(alpha, xnorm)
-// tau = (beta - alpha) / beta
+//     beta = -np.sign(np.real(alpha)) * np.sqrt(alpha * np.conj(alpha) + xnorm ** 2)
+//     if np.issubdtype(x.dtype, np.complexfloating):
+//       tau = ((beta - np.real(alpha)) / beta) + (-np.imag(alpha) / beta) * 1j
+//     else:
+//       tau = (beta - alpha) / beta
// v = x / (alpha - beta)
// v[k] = 1
// return (v, tau, beta)
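
As a cross-check of the algebra above, a runnable numpy sketch of house() (illustrative only: entries before k are masked out the way the expander's x_after_k does, and xnorm ** 2 is computed directly as sigma):

  import numpy as np

  def house(x, k):
    alpha = x[k]
    x_after_k = np.where(np.arange(x.shape[0]) > k, x, 0)
    sigma = np.real(np.vdot(x_after_k, x_after_k))  # == xnorm ** 2
    if sigma == 0 and np.imag(alpha) == 0:
      return np.zeros_like(x), x.dtype.type(0), np.real(alpha)
    beta = -np.sign(np.real(alpha)) * np.sqrt(
        np.real(alpha * np.conj(alpha)) + sigma)
    if np.issubdtype(x.dtype, np.complexfloating):
      tau = ((beta - np.real(alpha)) / beta) + (-np.imag(alpha) / beta) * 1j
    else:
      tau = (beta - alpha) / beta
    v = x_after_k / (alpha - beta)
    v[k] = 1
    return v, tau, beta

  # Applying H^H = I - conj(tau) * v @ conj(v).T (the form used in the QR
  # loop below) zeroes x[k+1:] and leaves the real scalar beta at index k:
  x = np.array([1 + 2j, -3j, 0.5 + 0.5j], dtype=np.complex64)
  v, tau, beta = house(x, 0)
  print(np.round(x - np.conj(tau) * v * np.vdot(v, x), 5))  # ~ [beta, 0, 0]
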
@@ -86,7 +89,6 @@
const int64 minor_dim = batch_dims.size();
XlaOp zero = ScalarLike(x, 0.0);
- XlaOp one = ScalarLike(x, 1.0);
// alpha = x[k]
XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);
@@ -96,20 +98,46 @@
XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type),
/*broadcast_dimensions=*/{minor_dim});
- // sigma = np.dot(x[k+1:], x[k+1:])
- // TODO(phawkins): this calculation may be numerically unstable.
- auto sigma = Reduce(x_after_k * x_after_k, zero,
- CreateScalarAddComputation(type, builder), {minor_dim});
- // mu = np.sqrt(x[k]*x[k] + sigma)
- auto mu = Sqrt(Square(alpha) + sigma);
+ XlaOp sigma_is_zero;
+ if (primitive_util::IsComplexType(type)) {
+ // sigma = np.dot(x[k+1:], np.conj(x[k+1:]))
+ // TODO(phawkins): this calculation may be numerically unstable.
+ auto x_squared = Real(x_after_k * Conj(x_after_k));
+ auto sigma =
+ Reduce(x_squared, ScalarLike(x_squared, 0.0),
+ CreateScalarAddComputation(
+ primitive_util::ComplexComponentType(type), builder),
+ {minor_dim});
+      // mu = np.sqrt(x[k]*np.conj(x[k]) + sigma)
+ auto mu = Sqrt(Real(alpha * Conj(alpha)) + sigma);
- auto sigma_is_zero = Eq(sigma, zero);
+ sigma_is_zero = Eq(sigma, ScalarLike(sigma, 0));
+ sigma_is_zero = And(sigma_is_zero, Eq(Imag(alpha), ScalarLike(sigma, 0)));
- *beta = Select(sigma_is_zero, alpha, Select(Lt(alpha, zero), one, -one) * mu);
- *tau = Select(sigma_is_zero, Broadcast(zero, batch_dims),
- (*beta - alpha) / *beta);
+      *beta = Select(Lt(Real(alpha), ScalarLike(sigma, 0)), ScalarLike(mu, 1),
+                     ScalarLike(mu, -1)) *
+              mu;
+      *beta = Select(sigma_is_zero, Real(alpha), *beta);
+      *tau = Complex((*beta - Real(alpha)) / *beta, -Imag(alpha) / *beta);
+ } else {
+ // sigma = np.dot(x[k+1:], x[k+1:])
+ // TODO(phawkins): this calculation may be numerically unstable.
+ auto sigma = Reduce(x_after_k * x_after_k, zero,
+ CreateScalarAddComputation(type, builder), {minor_dim});
+ // mu = np.sqrt(x[k]*x[k] + sigma)
+ auto mu = Sqrt(Square(alpha) + sigma);
+ sigma_is_zero = Eq(sigma, zero);
+
+ XlaOp one = ScalarLike(x, 1.0);
+ *beta = Select(Lt(alpha, zero), one, -one) * mu;
+ *beta = Select(sigma_is_zero, alpha, *beta);
+ *tau = (*beta - alpha) / *beta;
+ }
+ *tau = Select(sigma_is_zero, ZerosLike(*tau), *tau);
+
auto divisor =
- Select(sigma_is_zero, Broadcast(one, batch_dims), alpha - *beta);
+ Select(sigma_is_zero, Broadcast(ScalarLike(alpha, 1), batch_dims),
+ alpha - ConvertElementType(*beta, type));
auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type),
std::vector<int64>(batch_dims.size(), 1));
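
Side note on the Select/mask pattern here: it avoids data-dependent control flow by forcing the divisor to 1 wherever sigma_is_zero holds (so the division is always defined), with e_k marking the unit position of v. A hypothetical numpy equivalent of just this step (beta is hard-coded to the value house() would produce for this x and k; all values are illustrative):

  import numpy as np

  x = np.array([0.3 + 0j, 1 + 2j, -3j, 0.5], dtype=np.complex64)
  k, iota = 1, np.arange(4)
  alpha, beta = x[k], np.complex64(-3.7749)  # -sqrt(|alpha|**2 + sigma)
  sigma_is_zero = False
  divisor = np.where(sigma_is_zero, np.complex64(1), alpha - beta)
  e_k = (iota == k).astype(x.dtype)
  v = e_k + np.where(iota > k, x, 0) / divisor  # v[k] == 1, zeros before k
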
@@ -136,8 +164,8 @@
// taus = np.zeros([n])
// for j in xrange(min(m, n)):
// v, tau, beta = house(a[:, j], j)
-// a[:, j+1:] -= tau * np.dot(v[:, np.newaxis],
-// np.dot(v[np.newaxis, :], a[:, j+1:]))
+// a[:, j+1:] -= np.conj(tau) * np.dot(v[:, np.newaxis],
+// np.dot(np.conj(v[np.newaxis, :]), a[:, j+1:]))
// # Form column j explicitly rather than relying on the precision of the
// # Householder update.
// a[j, j] = beta
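
For reference, this loop runs as-is in numpy once house() is available (using the sketch from the earlier note; the expander itself expresses the j+1: slicing with masks, as the code below the comment shows):

  import numpy as np

  def qr_unblocked(a):
    a = a.copy()
    m, n = a.shape
    taus = np.zeros(n, dtype=a.dtype)
    for j in range(min(m, n)):
      v, tau, beta = house(a[:, j], j)
      a[:, j + 1:] -= np.conj(tau) * np.dot(
          v[:, np.newaxis], np.dot(np.conj(v[np.newaxis, :]), a[:, j + 1:]))
      a[j, j] = beta            # form column j explicitly
      a[j + 1:, j] = v[j + 1:]  # stash the reflector below the diagonal
      taus[j] = tau
    return a, taus
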
@@ -187,13 +215,14 @@
shape.push_back(1);
shape.push_back(m);
auto v_broadcast = Reshape(v, shape);
- // a[:, j+1:] -= tau * (v[:, np.newaxis] @ (v[np.newaxis, :] @ a[:, j+1:]))
+ // a[:, j+1:] -= np.conj(tau) * (v[:, np.newaxis] @
+ // (np.conj(v[np.newaxis, :]) @ a[:, j+1:]))
// We use masking rather than a loop-variant shape to handle the j+1:
// indexing.
- auto vva = BatchDot(v_broadcast, Select(Lt(j, iota_mn), a, ZerosLike(a)),
- precision);
+ auto vva = BatchDot(MaybeConjugate(v_broadcast, true),
+ Select(Lt(j, iota_mn), a, ZerosLike(a)), precision);
vva = BatchDot(v_broadcast, true, vva, false, precision);
- a = a - Mul(tau, vva,
+ a = a - Mul(MaybeConjugate(tau, true), vva,
/*broadcast_dimensions=*/batch_dim_indices);
// a[j, j] = beta
@@ -205,7 +234,8 @@
auto successor_mask = Gt(Iota(a.builder(), S32, m), j);
auto new_x = Mul(x, predecessor_mask,
/*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
- Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices);
+ Mul(ConvertElementType(beta, type), mask,
+ /*broadcast_dimensions=*/batch_dim_indices);
new_x = Add(
new_x, Select(Broadcast(successor_mask, batch_dims), v, ZerosLike(v)),
/*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {minor_dim}));
@@ -257,7 +287,7 @@
// t = np.eye(n) * -taus
// # We premultiply Y.T @ vs, since we would prefer to compute a single matrix
// # multiplication to many matrix-vector products.
-// vtv = -taus[None, :] * np.triu(vs.T @ vs, 1) + np.eye(n)
+// vtv = -taus[None, :] * np.triu(np.conj(vs.T) @ vs, 1) + np.eye(n)
// for i in range(1, n):
// t[:, i] = scipy.linalg.blas.strmm(t, vtv[:, i])
// return t
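
The same routine in plain numpy, with strmm replaced by an ordinary matmul (t stays triangular by construction, so this only gives up performance). The check at the end confirms the identity the representation relies on, I + Y @ T @ Y^H == product of the reflectors (I - tau_j * v_j @ conj(v_j).T); vs and taus are arbitrary illustrative values:

  import numpy as np

  def compact_wy(vs, taus):
    n = taus.shape[0]
    t = np.eye(n, dtype=vs.dtype) * -taus
    vtv = -taus[np.newaxis, :] * np.triu(np.conj(vs.T) @ vs, 1) + np.eye(
        n, dtype=vs.dtype)
    for i in range(1, n):
      t[:, i] = t @ vtv[:, i]
    return t

  rng = np.random.default_rng(0)
  vs = np.tril(rng.normal(size=(4, 2)) + 1j * rng.normal(size=(4, 2)), -1)
  vs += np.eye(4, 2)  # unit diagonal, zeros above, like the y panel below
  taus = rng.normal(size=2) + 1j * rng.normal(size=2)
  t = compact_wy(vs, taus)
  p = np.eye(4, dtype=complex)
  for j in range(2):
    p = p @ (np.eye(4) - taus[j] * np.outer(vs[:, j], np.conj(vs[:, j])))
  np.testing.assert_allclose(np.eye(4) + vs @ t @ np.conj(vs.T), p, atol=1e-12)
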
@@ -293,8 +323,8 @@
auto eye = Broadcast(IdentityMatrix(builder, type, n, n), batch_dims);
auto t = eye;
- auto vtv =
- BatchDot(vs, /*transpose_x=*/true, vs, /*transpose_y=*/false, precision);
+ auto vtv = BatchDot(MaybeConjugate(vs, true), /*transpose_x=*/true, vs,
+ /*transpose_y=*/false, precision);
vtv = Select(TriangleMask(vtv, 0), ZerosLike(vtv), vtv);
vtv = (vtv + eye) * tau_scale;
@@ -313,8 +343,8 @@
// (a, taus) = qr(a[i:, i:i+k])
// y = np.eye(m, n) + np.tril(a, -1)
// t = CompactWYRepresentation(vs, taus, m-i, k)
-// a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
-// q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
+// a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:])
+// q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T)
// return (q, a)
StatusOr<XlaOp> QrExpander::BuildQrDecomposition(
XlaOp a, int64 block_size, PrecisionConfig::Precision precision) {
@@ -361,21 +391,23 @@
auto t, CompactWYRepresentation(type, batch_dims, y, qr_block.taus,
m - i, k, precision));
- // a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
- auto yt =
- BatchDot(y, /*transpose_x=*/false, t, /*transpose_y=*/true, precision);
+ // a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:])
+ auto yt = BatchDot(y, /*transpose_x=*/false, MaybeConjugate(t, true),
+ /*transpose_y=*/true, precision);
auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
- auto a_update = BatchDot(y, /*transpose_x=*/true, a_panel,
- /*transpose_y=*/false, precision);
+ auto a_update =
+ BatchDot(MaybeConjugate(y, true), /*transpose_x=*/true, a_panel,
+ /*transpose_y=*/false, precision);
a_update = BatchDot(yt, a_update, precision);
a_panel = a_panel + a_update;
a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});
- // q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
+ // q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T)
auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
auto q_update = BatchDot(q_panel, y, precision);
- q_update = BatchDot(q_update, /*transpose_x=*/false, yt,
- /*transpose_y=*/true, precision);
+ q_update =
+ BatchDot(q_update, /*transpose_x=*/false, MaybeConjugate(yt, true),
+ /*transpose_y=*/true, precision);
q_panel = q_panel + q_update;
q = UpdateSliceInMinorDims(q, q_panel, {0, i});
}
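
Putting the pieces together, an end-to-end numpy rendering of the blocked algorithm described above BuildQrDecomposition (a sketch assuming the qr_unblocked and compact_wy helpers from the earlier notes; XLA's batching, padding, and dynamic-slice handling are elided):

  import numpy as np

  def qr_blocked(a, block_size):
    a = a.copy()
    m, n = a.shape
    q = np.eye(m, dtype=a.dtype)
    for i in range(0, min(m, n), block_size):
      k = min(block_size, min(m, n) - i)
      panel, taus = qr_unblocked(a[i:, i:i + k])
      a[i:, i:i + k] = panel
      y = np.eye(m - i, k, dtype=a.dtype) + np.tril(panel, -1)
      t = compact_wy(y, taus)
      yt = y @ np.conj(t.T)
      # a[i:, i+k:] += (y @ conj(t.T)) @ (conj(y.T) @ a[i:, i+k:])
      a[i:, i + k:] += yt @ (np.conj(y.T) @ a[i:, i + k:])
      # q[:, i:] += (q[:, i:] @ y) @ conj((y @ conj(t.T)).T)
      q[:, i:] += (q[:, i:] @ y) @ np.conj(yt.T)
    return q, np.triu(a)

  a = np.random.randn(6, 4) + 1j * np.random.randn(6, 4)
  q, r = qr_blocked(a, block_size=2)
  np.testing.assert_allclose(q @ np.conj(q.T), np.eye(6), atol=1e-10)
  np.testing.assert_allclose(q @ r, a, atol=1e-10)
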