[XLA] Add support for complex numbers to the QR decomposition expander.

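For reference, a small NumPy sketch of the complex Householder reflection that
House() now implements (illustrative only, not part of this change; the
`house` helper and the 6-element test vector are made up for the check).
Applying H^H = I - conj(tau) * v * v^H to x zeroes x[k+1:], places the real
scalar beta at x[k], and leaves x[:k] unchanged:

    import numpy as np

    def house(x, k):
      # Mirrors the pseudocode in qr_expander.cc; entries before k are masked
      # out of v, as in the House() implementation.
      alpha = x[k]
      x_after_k = np.where(np.arange(x.size) > k, x, 0)
      sigma = np.real(np.vdot(x_after_k, x_after_k))  # = xnorm ** 2
      if sigma == 0 and np.imag(alpha) == 0:
        beta = np.real(alpha)
        tau = 0
        v = np.zeros_like(x)
      else:
        mu = np.sqrt(np.abs(alpha) ** 2 + sigma)
        beta = mu if np.real(alpha) < 0 else -mu
        if np.issubdtype(x.dtype, np.complexfloating):
          tau = (beta - np.real(alpha)) / beta + (-np.imag(alpha) / beta) * 1j
        else:
          tau = (beta - alpha) / beta
        v = x_after_k / (alpha - beta)
      v[k] = 1
      return v, tau, beta

    rng = np.random.RandomState(0)
    x = (rng.randn(6) + 1j * rng.randn(6)).astype(np.complex64)
    v, tau, beta = house(x, 2)
    y = x - np.conj(tau) * v * np.vdot(v, x)  # y = (I - conj(tau) v v^H) x
    print(np.round(y, 5))  # y[2] == beta (real), y[3:] == 0, y[:2] unchanged

Note that tau above equals (beta - alpha) / beta; it is spelled out via real
and imaginary parts because the expander computes beta in the real component
type and only converts it back to the complex type where needed.
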
PiperOrigin-RevId: 333208193
Change-Id: Ic9adc699a11ffcc23a0ae518b54ee29cce8569ce
diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py
index f396e61..d2d6fe7 100644
--- a/tensorflow/compiler/tests/qr_op_test.py
+++ b/tensorflow/compiler/tests/qr_op_test.py
@@ -74,8 +74,14 @@
 
   def _test(self, dtype, shape, full_matrices):
     np.random.seed(1)
-    x_np = np.random.uniform(
-        low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)
+
+    def rng():
+      return np.random.uniform(
+          low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)
+
+    x_np = rng()
+    if np.issubdtype(dtype, np.complexfloating):
+      x_np += rng() * dtype(1j)
 
     with self.session() as sess:
       x_tf = array_ops.placeholder(dtype)
@@ -102,7 +108,7 @@
       self.CheckUnitary(q_tf_val)
 
   SIZES = [1, 2, 5, 10, 32, 100, 300]
-  DTYPES = [np.float32]
+  DTYPES = [np.float32, np.complex64]
   PARAMS = itertools.product(SIZES, SIZES, DTYPES)
 
   @parameterized.parameters(*PARAMS)
diff --git a/tensorflow/compiler/tf2xla/kernels/qr_op.cc b/tensorflow/compiler/tf2xla/kernels/qr_op.cc
index 66ec40a..7aebb76 100644
--- a/tensorflow/compiler/tf2xla/kernels/qr_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/qr_op.cc
@@ -41,7 +41,7 @@
   bool full_matrices_;
 };
 
-REGISTER_XLA_OP(Name("Qr").TypeConstraint("T", kFloatTypes), QROp);
+REGISTER_XLA_OP(Name("Qr"), QROp);
 
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/xla/service/qr_expander.cc b/tensorflow/compiler/xla/service/qr_expander.cc
index 8bc06f9..d1b1526 100644
--- a/tensorflow/compiler/xla/service/qr_expander.cc
+++ b/tensorflow/compiler/xla/service/qr_expander.cc
@@ -63,13 +63,16 @@
 //   x_copy = np.copy(x)
 //   x_copy[:k+1] = 0
 //   xnorm = norm2(x_copy)
-//   if xnorm == 0:
+//   if xnorm == 0 and np.imag(alpha) == 0:
 //     beta = alpha
 //     tau = 0
 //     v = np.zeros_like(x)
 //   else:
-//     beta = - np.sign(alpha) * dlapy2(alpha, xnorm)
-//     tau = (beta - alpha) / beta
+//     beta = -np.sign(np.real(alpha)) * np.sqrt(np.abs(alpha)**2 + xnorm**2)
+//     if np.issubdtype(x.dtype, np.complexfloating):
+//       tau = (beta - np.real(alpha)) / beta + (-np.imag(alpha) / beta) * 1j
+//     else:
+//       tau = (beta - alpha) / beta
 //     v = x / (alpha - beta)
 //   v[k] = 1
 //   return (v, tau, beta)
@@ -86,7 +89,6 @@
   const int64 minor_dim = batch_dims.size();
 
   XlaOp zero = ScalarLike(x, 0.0);
-  XlaOp one = ScalarLike(x, 1.0);
 
   // alpha = x[k]
   XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);
@@ -96,20 +98,46 @@
   XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type),
                         /*broadcast_dimensions=*/{minor_dim});
 
-  // sigma = np.dot(x[k+1:], x[k+1:])
-  // TODO(phawkins): this calculation may be numerically unstable.
-  auto sigma = Reduce(x_after_k * x_after_k, zero,
-                      CreateScalarAddComputation(type, builder), {minor_dim});
-  // mu = np.sqrt(x[k]*x[k] + sigma)
-  auto mu = Sqrt(Square(alpha) + sigma);
+  XlaOp sigma_is_zero;
+  if (primitive_util::IsComplexType(type)) {
+    // sigma = np.dot(x[k+1:], np.conj(x[k+1:]))
+    // TODO(phawkins): this calculation may be numerically unstable.
+    auto x_squared = Real(x_after_k * Conj(x_after_k));
+    auto sigma =
+        Reduce(x_squared, ScalarLike(x_squared, 0.0),
+               CreateScalarAddComputation(
+                   primitive_util::ComplexComponentType(type), builder),
+               {minor_dim});
+    // mu = np.sqrt(x[k]*np.conj(x[k]) + sigma)
+    auto mu = Sqrt(Real(alpha * Conj(alpha)) + sigma);
 
-  auto sigma_is_zero = Eq(sigma, zero);
+    sigma_is_zero = Eq(sigma, ScalarLike(sigma, 0));
+    sigma_is_zero = And(sigma_is_zero, Eq(Imag(alpha), ScalarLike(sigma, 0)));
 
-  *beta = Select(sigma_is_zero, alpha, Select(Lt(alpha, zero), one, -one) * mu);
-  *tau = Select(sigma_is_zero, Broadcast(zero, batch_dims),
-                (*beta - alpha) / *beta);
+    *beta = Select(Lt(Real(alpha), ScalarLike(sigma, 0)), ScalarLike(mu, 1),
+                   ScalarLike(mu, -1)) *
+            mu;
+    *beta = Select(sigma_is_zero, Real(alpha), *beta);
+    *tau = Complex((*beta - Real(alpha)) / *beta, -Imag(alpha) / *beta);
+  } else {
+    // sigma = np.dot(x[k+1:], x[k+1:])
+    // TODO(phawkins): this calculation may be numerically unstable.
+    auto sigma = Reduce(x_after_k * x_after_k, zero,
+                        CreateScalarAddComputation(type, builder), {minor_dim});
+    // mu = np.sqrt(x[k]*x[k] + sigma)
+    auto mu = Sqrt(Square(alpha) + sigma);
+    sigma_is_zero = Eq(sigma, zero);
+
+    XlaOp one = ScalarLike(x, 1.0);
+    *beta = Select(Lt(alpha, zero), one, -one) * mu;
+    *beta = Select(sigma_is_zero, alpha, *beta);
+    *tau = (*beta - alpha) / *beta;
+  }
+  *tau = Select(sigma_is_zero, ZerosLike(*tau), *tau);
+
   auto divisor =
-      Select(sigma_is_zero, Broadcast(one, batch_dims), alpha - *beta);
+      Select(sigma_is_zero, Broadcast(ScalarLike(alpha, 1), batch_dims),
+             alpha - ConvertElementType(*beta, type));
 
   auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type),
                        std::vector<int64>(batch_dims.size(), 1));
@@ -136,8 +164,8 @@
 //   taus = np.zeros([n])
 //   for j in xrange(min(m, n)):
 //     v, tau, beta = house(a[:, j], j)
-//     a[:, j+1:] -= tau * np.dot(v[:, np.newaxis],
-//                                np.dot(v[np.newaxis, :], a[:, j+1:]))
+//     a[:, j+1:] -= np.conj(tau) * np.dot(
+//         v[:, np.newaxis], np.dot(np.conj(v[np.newaxis, :]), a[:, j+1:]))
 //     # Form column j explicitly rather than relying on the precision of the
 //     # Householder update.
 //     a[j, j] = beta
@@ -187,13 +215,14 @@
     shape.push_back(1);
     shape.push_back(m);
     auto v_broadcast = Reshape(v, shape);
-    // a[:, j+1:] -= tau * (v[:, np.newaxis] @ (v[np.newaxis, :] @ a[:, j+1:]))
+    // a[:, j+1:] -= np.conj(tau) * (v[:, np.newaxis] @
+    //     (np.conj(v[np.newaxis, :]) @ a[:, j+1:]))
     // We use masking rather than a loop-variant shape to handle the j+1:
     // indexing.
-    auto vva = BatchDot(v_broadcast, Select(Lt(j, iota_mn), a, ZerosLike(a)),
-                        precision);
+    auto vva = BatchDot(MaybeConjugate(v_broadcast, true),
+                        Select(Lt(j, iota_mn), a, ZerosLike(a)), precision);
     vva = BatchDot(v_broadcast, true, vva, false, precision);
-    a = a - Mul(tau, vva,
+    a = a - Mul(MaybeConjugate(tau, true), vva,
                 /*broadcast_dimensions=*/batch_dim_indices);
 
     // a[j, j] = beta
@@ -205,7 +234,8 @@
     auto successor_mask = Gt(Iota(a.builder(), S32, m), j);
     auto new_x = Mul(x, predecessor_mask,
                      /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
-                 Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices);
+                 Mul(ConvertElementType(beta, type), mask,
+                     /*broadcast_dimensions=*/batch_dim_indices);
     new_x = Add(
         new_x, Select(Broadcast(successor_mask, batch_dims), v, ZerosLike(v)),
         /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {minor_dim}));
@@ -257,7 +287,7 @@
 //   t = np.eye(n) * -taus
 //   # We premultiply Y.T @ vs, since we would prefer to compute a single matrix
 //   # multiplication to many matrix-vector products.
-//   vtv = -taus[None, :] * np.triu(vs.T @ vs, 1) + np.eye(n)
+//   vtv = -taus[None, :] * np.triu(np.conj(vs.T) @ vs, 1) + np.eye(n)
 //   for i in range(1, n):
 //     t[:, i] = scipy.linalg.blas.strmm(t, vtv[:, i])
 //   return t
@@ -293,8 +323,8 @@
   auto eye = Broadcast(IdentityMatrix(builder, type, n, n), batch_dims);
   auto t = eye;
 
-  auto vtv =
-      BatchDot(vs, /*transpose_x=*/true, vs, /*transpose_y=*/false, precision);
+  auto vtv = BatchDot(MaybeConjugate(vs, true), /*transpose_x=*/true, vs,
+                      /*transpose_y=*/false, precision);
   vtv = Select(TriangleMask(vtv, 0), ZerosLike(vtv), vtv);
   vtv = (vtv + eye) * tau_scale;
 
@@ -313,8 +343,8 @@
 //     (a, taus) = qr(a[i:, i:i+k])
 //     y = np.eye(m, n) + np.tril(a, -1)
 //     t = CompactWYRepresentation(vs, taus, m-i, k)
-//     a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
-//     q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
+//     a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:])
+//     q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T)
 //   return (q, a)
 StatusOr<XlaOp> QrExpander::BuildQrDecomposition(
     XlaOp a, int64 block_size, PrecisionConfig::Precision precision) {
@@ -361,21 +391,23 @@
         auto t, CompactWYRepresentation(type, batch_dims, y, qr_block.taus,
                                         m - i, k, precision));
 
-    // a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
-    auto yt =
-        BatchDot(y, /*transpose_x=*/false, t, /*transpose_y=*/true, precision);
+    // a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:])
+    auto yt = BatchDot(y, /*transpose_x=*/false, MaybeConjugate(t, true),
+                       /*transpose_y=*/true, precision);
     auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
-    auto a_update = BatchDot(y, /*transpose_x=*/true, a_panel,
-                             /*transpose_y=*/false, precision);
+    auto a_update =
+        BatchDot(MaybeConjugate(y, true), /*transpose_x=*/true, a_panel,
+                 /*transpose_y=*/false, precision);
     a_update = BatchDot(yt, a_update, precision);
     a_panel = a_panel + a_update;
     a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});
 
-    // q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
+    // q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T)
     auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
     auto q_update = BatchDot(q_panel, y, precision);
-    q_update = BatchDot(q_update, /*transpose_x=*/false, yt,
-                        /*transpose_y=*/true, precision);
+    q_update =
+        BatchDot(q_update, /*transpose_x=*/false, MaybeConjugate(yt, true),
+                 /*transpose_y=*/true, precision);
     q_panel = q_panel + q_update;
     q = UpdateSliceInMinorDims(q, q_panel, {0, i});
   }