Rewrote the implementation of the complex sqrt and rsqrt methods.

The old implementations of sqrt and rsqrt just called the pow function, which
was very inefficient.

PiperOrigin-RevId: 263180636
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 2b90d77..c2d5ffc 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -776,18 +776,10 @@
                              FDiv(EmitExtractImag(operand_value), cplx_abs)));
     }
     case HloOpcode::kSqrt: {
-      auto a = EmitExtractReal(operand_value);
-      auto b = EmitExtractImag(operand_value);
-      auto c = llvm::ConstantFP::get(a->getType(), 0.5);
-      auto d = llvm::ConstantFP::get(b->getType(), 0.0);
-      return EmitComplexPower(op, a, b, c, d);
+      return EmitComplexSqrt(op, component_type, operand_value);
     }
     case HloOpcode::kRsqrt: {
-      auto a = EmitExtractReal(operand_value);
-      auto b = EmitExtractImag(operand_value);
-      auto c = llvm::ConstantFP::get(a->getType(), -0.5);
-      auto d = llvm::ConstantFP::get(b->getType(), 0.0);
-      return EmitComplexPower(op, a, b, c, d);
+      return EmitComplexRsqrt(op, component_type, operand_value);
     }
     case HloOpcode::kNegate:
       return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)),
@@ -878,9 +870,18 @@
 // Using sqrt(a^2 + b^2) can cause overflow errors. Therefore we can use
 // sqrt(a^2 + b^2) = sqrt(a^2 * (1 + b^2/a^2))
 //                 = |a| * sqrt(1 + (b/a)^2)
-// With the assumption that |a| >= |b|
-StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexAbs(
-    PrimitiveType prim_type, llvm::Value* operand_value) {
+// With the assumption that |a| >= |b|.
+//
+// This method returns the min, max, and sqrt term for this calculation. This is
+// done to prevent potential overflow errors that can occur from multiplying the
+// max with the sqrt term. (i.e. when calculating the sqrt of the absolute
+// value, we can take the sqrt of the max and the sqrt term before multiplying
+// them together.) If return_sqrt is false, it returns 1 + (b/a)^2 instead of
+// sqrt(1 + (b/a)^2).
+StatusOr<std::tuple<llvm::Value*, llvm::Value*, llvm::Value*>>
+ElementalIrEmitter::EmitComplexAbsHelper(PrimitiveType prim_type,
+                                         llvm::Value* operand_value,
+                                         bool return_sqrt) {
   llvm::Value* real = EmitExtractReal(operand_value);
   llvm::Value* imag = EmitExtractImag(operand_value);
   llvm::Value* abs_real = llvm_ir::EmitCallToIntrinsic(
@@ -893,15 +894,187 @@
   llvm::Value* div = FDiv(min, max);
   llvm::Value* div_sq = FMul(div, div);
   llvm::Value* one = llvm::ConstantFP::get(max->getType(), 1);
-  TF_ASSIGN_OR_RETURN(llvm::Value * sqrt,
-                      EmitSqrt(prim_type, FAdd(one, div_sq)));
+  llvm::Value* one_p_div_sq = FAdd(one, div_sq);
+  TF_ASSIGN_OR_RETURN(llvm::Value * sqrt, EmitSqrt(prim_type, one_p_div_sq));
+  return std::make_tuple(min, max, return_sqrt ? sqrt : one_p_div_sq);
+}
 
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexAbs(
+    PrimitiveType prim_type, llvm::Value* operand_value) {
+  llvm::Value* min;
+  llvm::Value* max;
+  llvm::Value* sqrt;
+  TF_ASSIGN_OR_RETURN(
+      std::tie(min, max, sqrt),
+      EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/true));
   llvm::Value* result = FMul(max, sqrt);
-  // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), result is NaN.
-  // In such cases, we return min.
+  // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
+  // In such cases, we return `min` instead of `result`.
   return Select(FCmpUNO(result, result), min, result);
 }
 
+// Calculates ComplexAbs in the same way, except using:
+// sqrt(|a| * sqrt(1 + (b/a)^2)) = sqrt(|a|) * pow(1 + (b/a)^2, .25)
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrtComplexAbs(
+    PrimitiveType prim_type, llvm::Value* operand_value) {
+  llvm::Value* min;
+  llvm::Value* max;
+  llvm::Value* one_p_div_sq;
+  TF_ASSIGN_OR_RETURN(
+      std::tie(min, max, one_p_div_sq),
+      EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/false));
+  TF_ASSIGN_OR_RETURN(llvm::Value * sqrt_max, EmitSqrt(prim_type, max));
+  TF_ASSIGN_OR_RETURN(llvm::Value * pow,
+                      EmitPow(prim_type, one_p_div_sq,
+                              llvm::ConstantFP::get(max->getType(), .25)));
+  llvm::Value* result = FMul(sqrt_max, pow);
+  // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
+  // In such cases, we return `min` instead of `result`.
+  return Select(FCmpUNO(result, result), min, result);
+}
+
+// Calculates ComplexAbs in the same way, except using:
+// rsqrt(|a| * sqrt(1 + (b/a)^2)) = rsqrt(|a|) * rsqrt(sqrt(1 + (b/a)^2))
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitRsqrtComplexAbs(
+    PrimitiveType prim_type, llvm::Value* operand_value) {
+  llvm::Value* min;
+  llvm::Value* max;
+  llvm::Value* sqrt;
+  TF_ASSIGN_OR_RETURN(
+      std::tie(min, max, sqrt),
+      EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/true));
+  TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_max, EmitRsqrt(prim_type, max));
+  TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_sqrt, EmitRsqrt(prim_type, sqrt));
+  llvm::Value* result = FMul(rsqrt_max, rsqrt_sqrt);
+  TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_min, EmitRsqrt(prim_type, min));
+  // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
+  // In such cases, we return rsqrt(min) instead of `result`.
+  return Select(FCmpUNO(result, result), rsqrt_min, result);
+}
+
+// Using our EmitComplexPower formula, but setting c=0.5 and d=0, we get:
+//   e^[ln(r)*c - t*d] * [cos(ln(r)*d + t*c) + i*sin(ln(r)*d + t*c)]
+// = e^[ln(r)*0.5] * [cos(t*0.5) + i*sin(t*0.5)]
+// = r^0.5 * [cos(t/2) + i*sin(t/2)]
+// = sqrt(r) * [cos(t/2) + i*sin(t/2)]
+// where r = |a+bi| and t = atan2(b,a)
+// TODO(bixia): See doc for implementation without atan2.
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexSqrt(
+    const HloInstruction* op, PrimitiveType prim_type,
+    llvm::Value* operand_value) {
+  llvm::Type* type = static_cast<llvm::StructType*>(operand_value->getType())
+                         ->getElementType(0);
+
+  TF_ASSIGN_OR_RETURN(llvm::Value * r,
+                      EmitSqrtComplexAbs(prim_type, operand_value));
+
+  llvm::Value* a = EmitExtractReal(operand_value);
+  llvm::Value* b = EmitExtractImag(operand_value);
+  TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a));
+
+  llvm::Value* c = llvm::ConstantFP::get(type, 0.5);
+  llvm::Value* angle = FMul(t, c);
+  TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle));
+  TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle));
+
+  llvm::Value* real_part;
+  llvm::Value* imag_part;
+
+  llvm::Value* zero = llvm::ConstantFP::get(type, 0);
+
+  if (!(b_->getFastMathFlags().noNaNs() && b_->getFastMathFlags().noInfs())) {
+    llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
+    llvm::Value* neg_inf = llvm::ConstantFP::getInfinity(type, true);
+    llvm::Value* nan = llvm::ConstantFP::getNaN(type);
+    llvm::Value* abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
+                                                      {b}, {b->getType()}, b_);
+
+    real_part = Select(Or(FCmpOEQ(abs_b, inf), FCmpOEQ(a, inf)), inf,
+                       Select(And(FCmpOEQ(a, neg_inf), FCmpONE(abs_b, inf)),
+                              zero, FMul(r, cos)));
+
+    llvm::Value* b_signed_inf = llvm_ir::EmitCallToIntrinsic(
+        llvm::Intrinsic::copysign, {inf, b}, {b->getType()}, b_);
+    imag_part =
+        Select(Or(FCmpOEQ(abs_b, inf), FCmpOEQ(a, neg_inf)), b_signed_inf,
+               Select(FCmpUNO(r, r), nan,
+                      Select(FCmpOEQ(sin, zero), sin, FMul(r, sin))));
+  } else {
+    real_part = FMul(r, cos);
+    imag_part = Select(FCmpOEQ(sin, zero), sin, FMul(r, sin));
+  }
+
+  return Select(FCmpOEQ(r, zero), EmitComposeComplex(op, zero, zero),
+                EmitComposeComplex(op, real_part, imag_part));
+}
+
+// Similar to Sqrt, we can use our EmitComplexPower formula, but set
+// c=-0.5 and d=0. We get:
+//   e^[ln(r)*c - t*d] * [cos(ln(r)*d + t*c) + i*sin(ln(r)*d + t*c)]
+// = e^[ln(r)*-0.5] * [cos(t*-0.5) + i*sin(t*-0.5)]
+// = r^(-0.5) * [cos(-t/2) + i*sin(-t/2)]
+// = rsqrt(r) * [cos(-t/2) + i*sin(-t/2)]
+// where r = |a+bi| and t = atan2(b,a).
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexRsqrt(
+    const HloInstruction* op, PrimitiveType prim_type,
+    llvm::Value* operand_value) {
+  llvm::Type* type = static_cast<llvm::StructType*>(operand_value->getType())
+                         ->getElementType(0);
+
+  TF_ASSIGN_OR_RETURN(llvm::Value * r,
+                      EmitRsqrtComplexAbs(prim_type, operand_value));
+
+  llvm::Value* a = EmitExtractReal(operand_value);
+  llvm::Value* b = EmitExtractImag(operand_value);
+  TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a));
+
+  llvm::Value* c = llvm::ConstantFP::get(type, -0.5);
+  llvm::Value* angle = FMul(t, c);
+  TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle));
+  TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle));
+
+  llvm::Value* real_part = FMul(r, cos);
+  llvm::Value* imag_part = FMul(r, sin);
+
+  if (!(b_->getFastMathFlags().noNaNs() && b_->getFastMathFlags().noInfs())) {
+    llvm::Value* zero = llvm::ConstantFP::get(type, 0);
+    llvm::Value* neg_one = llvm::ConstantFP::get(type, -1);
+    llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
+    llvm::Value* nan = llvm::ConstantFP::getNaN(type);
+    // llvm::Value* neg_inf = llvm::ConstantFP::getInfinity(type, true);
+    llvm::Value* a_signed_zero = llvm_ir::EmitCallToIntrinsic(
+        llvm::Intrinsic::copysign, {zero, a}, {a->getType()}, b_);
+    llvm::Value* b_signed_zero = llvm_ir::EmitCallToIntrinsic(
+        llvm::Intrinsic::copysign, {zero, b}, {b->getType()}, b_);
+    llvm::Value* neg_b_signed_zero = FMul(b_signed_zero, neg_one);
+
+    llvm::Value* abs_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
+                                                      {a}, {a->getType()}, b_);
+    llvm::Value* abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
+                                                      {b}, {b->getType()}, b_);
+
+    llvm::Value* is_zero_zero = And(FCmpOEQ(b, zero), FCmpOEQ(a, zero));
+    real_part = Select(
+        is_zero_zero, inf,
+        Select(Or(And(FCmpOEQ(abs_b, inf), FCmpUNO(a, a)), FCmpOEQ(abs_a, inf)),
+               a_signed_zero, FMul(r, cos)));
+    imag_part = Select(
+        is_zero_zero, nan,
+        Select(Or(And(FCmpOEQ(abs_b, inf), FCmpUNO(a, a)), FCmpOEQ(abs_a, inf)),
+               neg_b_signed_zero, FMul(r, sin)));
+  } else {
+    llvm::Value* zero = llvm::ConstantFP::get(type, 0);
+    llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
+    llvm::Value* nan = llvm::ConstantFP::getNaN(type);
+
+    llvm::Value* is_zero_zero = And(FCmpOEQ(b, zero), FCmpOEQ(a, zero));
+    real_part = Select(is_zero_zero, inf, FMul(r, cos));
+    imag_part = Select(is_zero_zero, nan, FMul(r, sin));
+  }
+
+  return EmitComposeComplex(op, real_part, imag_part);
+}
+
 // (a+bi)^(c+di) =
 //    (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
 //    where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
@@ -1149,7 +1322,7 @@
   return Select(x_is_small, for_small_x, for_large_x);
 }
 
-StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrt(PrimitiveType prim_type,
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrt(PrimitiveType,
                                                     llvm::Value* value) {
   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {value},
                                       {value->getType()}, b_);
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
index 3ba669c..99833a5 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -143,9 +143,26 @@
   virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo,
                                                      llvm::Value* x);
 
+  virtual StatusOr<std::tuple<llvm::Value*, llvm::Value*, llvm::Value*>>
+  EmitComplexAbsHelper(PrimitiveType prim_type, llvm::Value* operand_value,
+                       bool return_sqrt);
+
   virtual StatusOr<llvm::Value*> EmitComplexAbs(PrimitiveType prim_type,
                                                 llvm::Value* operand_value);
 
+  virtual StatusOr<llvm::Value*> EmitSqrtComplexAbs(PrimitiveType prim_type,
+                                                    llvm::Value* operand_value);
+  virtual StatusOr<llvm::Value*> EmitRsqrtComplexAbs(
+      PrimitiveType prim_type, llvm::Value* operand_value);
+
+  virtual StatusOr<llvm::Value*> EmitComplexSqrt(const HloInstruction* op,
+                                                 PrimitiveType prim_type,
+                                                 llvm::Value* operand_value);
+
+  virtual StatusOr<llvm::Value*> EmitComplexRsqrt(const HloInstruction* op,
+                                                  PrimitiveType prim_type,
+                                                  llvm::Value* operand_value);
+
   virtual llvm::Value* EmitExtractReal(llvm::Value* value);
   virtual llvm::Value* EmitExtractImag(llvm::Value* value);
 
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h
index fe54768..5014aa9 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h
@@ -250,6 +250,16 @@
   }
 
   template <class... Args>
+  llvm::Value* FCmpOGT(Args&&... args) {
+    return mixin_builder()->CreateFCmpOGT(std::forward<Args>(args)...);
+  }
+
+  template <class... Args>
+  llvm::Value* FCmpOGE(Args&&... args) {
+    return mixin_builder()->CreateFCmpOGE(std::forward<Args>(args)...);
+  }
+
+  template <class... Args>
   llvm::Value* FCmpOLT(Args&&... args) {
     return mixin_builder()->CreateFCmpOLT(std::forward<Args>(args)...);
   }
diff --git a/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc b/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc
index f8eb738..ea6ecd1 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc
+++ b/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc
@@ -692,7 +692,7 @@
     this->known_incorrect_fn_ = [&](int64 v) {
       double f = this->ConvertValue(v);
       return (T == C128 &&
-              std::abs(f) > std::numeric_limits<float>::max() / 2) ||
+              std::abs(f) > std::numeric_limits<double>::max() / 2) ||
              f == -std::numeric_limits<double>::infinity();
     };
   }
@@ -738,6 +738,20 @@
   Run(Log, [](complex64 x) { return std::log<float>(x); });
 }
 
+XLA_TEST_P(ExhaustiveC64UnaryTest, Sqrt) {
+  Run(Sqrt, [](complex64 x) {
+    return static_cast<complex64>(
+        std::sqrt<double>(static_cast<complex128>(x)));
+  });
+}
+
+XLA_TEST_P(ExhaustiveC64UnaryTest, Rsqrt) {
+  Run(Rsqrt, [](complex64 x) {
+    return static_cast<complex64>(
+        complex128(1, 0) / std::sqrt<double>(static_cast<complex128>(x)));
+  });
+}
+
 // The current libc++ implementation of the complex tanh function provides
 // less accurate results when the denomenator of a complex tanh is small, due
 // to floating point precision loss. To avoid this issue for complex64 numbers,
@@ -807,6 +821,35 @@
   Run(Log, [](complex128 x) { return std::log<double>(x); });
 }
 
+XLA_TEST_P(ExhaustiveC128UnaryTest, Sqrt) {
+  // Similar to the Tanh bug.
+  known_incorrect_fn_ = [&](int64 v) {
+    double f = this->ConvertValue(v);
+    return std::abs(f) > std::numeric_limits<double>::max() / 2;
+  };
+  Run(Sqrt, [](complex128 x) { return std::sqrt<double>(x); });
+}
+
+XLA_TEST_P(ExhaustiveC128UnaryTest, Rsqrt) {
+  ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
+  if (platform_ == "CUDA") {
+    // Edge case on CUDA backend where the Log of a complex number made up of
+    // the smallest denormals is more accurate than the interpreter backend.
+    error_spec_gen = [](complex128 x) {
+      constexpr double denorm_min = std::numeric_limits<double>::denorm_min();
+      if (std::abs(x.real()) == denorm_min &&
+          std::abs(x.imag()) == denorm_min) {
+        return ErrorSpec(0.5, 0.5);
+      }
+      return GetDefaultSpecGenerator()(x);
+    };
+  }
+  Run(
+      Rsqrt,
+      [](complex128 x) { return complex128(1, 0) / std::sqrt<double>(x); },
+      error_spec_gen);
+}
+
 XLA_TEST_P(ExhaustiveC128UnaryTest, Tanh) {
   SetParamsForTanh();
   Run(