Rewrote the implementation of the complex sqrt and rsqrt methods.
The old implementations of sqrt and rsqrt just called the pow function, which
was very inefficient.
PiperOrigin-RevId: 263180636
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 2b90d77..c2d5ffc 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -776,18 +776,10 @@
FDiv(EmitExtractImag(operand_value), cplx_abs)));
}
case HloOpcode::kSqrt: {
- auto a = EmitExtractReal(operand_value);
- auto b = EmitExtractImag(operand_value);
- auto c = llvm::ConstantFP::get(a->getType(), 0.5);
- auto d = llvm::ConstantFP::get(b->getType(), 0.0);
- return EmitComplexPower(op, a, b, c, d);
+ return EmitComplexSqrt(op, component_type, operand_value);
}
case HloOpcode::kRsqrt: {
- auto a = EmitExtractReal(operand_value);
- auto b = EmitExtractImag(operand_value);
- auto c = llvm::ConstantFP::get(a->getType(), -0.5);
- auto d = llvm::ConstantFP::get(b->getType(), 0.0);
- return EmitComplexPower(op, a, b, c, d);
+ return EmitComplexRsqrt(op, component_type, operand_value);
}
case HloOpcode::kNegate:
return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)),
@@ -878,9 +870,18 @@
// Using sqrt(a^2 + b^2) can cause overflow errors. Therefore we can use
// sqrt(a^2 + b^2) = sqrt(a^2 * (1 + b^2/a^2))
// = |a| * sqrt(1 + (b/a)^2)
-// With the assumption that |a| >= |b|
-StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexAbs(
- PrimitiveType prim_type, llvm::Value* operand_value) {
+// With the assumption that |a| >= |b|.
+//
+// This method returns the min, max, and sqrt term for this calculation. This is
+// done to prevent potential overflow errors that can occur from multiplying the
+// max with the sqrt term. (i.e. when calculating the sqrt of the absolute
+// value, we can take the sqrt of the max and the sqrt term before multiplying
+// them together.) If return_sqrt is false, it returns 1 + (b/a)^2 instead of
+// sqrt(1 + (b/a)^2).
+StatusOr<std::tuple<llvm::Value*, llvm::Value*, llvm::Value*>>
+ElementalIrEmitter::EmitComplexAbsHelper(PrimitiveType prim_type,
+ llvm::Value* operand_value,
+ bool return_sqrt) {
llvm::Value* real = EmitExtractReal(operand_value);
llvm::Value* imag = EmitExtractImag(operand_value);
llvm::Value* abs_real = llvm_ir::EmitCallToIntrinsic(
@@ -893,15 +894,187 @@
llvm::Value* div = FDiv(min, max);
llvm::Value* div_sq = FMul(div, div);
llvm::Value* one = llvm::ConstantFP::get(max->getType(), 1);
- TF_ASSIGN_OR_RETURN(llvm::Value * sqrt,
- EmitSqrt(prim_type, FAdd(one, div_sq)));
+ llvm::Value* one_p_div_sq = FAdd(one, div_sq);
+ TF_ASSIGN_OR_RETURN(llvm::Value * sqrt, EmitSqrt(prim_type, one_p_div_sq));
+ return std::make_tuple(min, max, return_sqrt ? sqrt : one_p_div_sq);
+}
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexAbs(
+ PrimitiveType prim_type, llvm::Value* operand_value) {
+ llvm::Value* min;
+ llvm::Value* max;
+ llvm::Value* sqrt;
+ TF_ASSIGN_OR_RETURN(
+ std::tie(min, max, sqrt),
+ EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/true));
llvm::Value* result = FMul(max, sqrt);
- // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), result is NaN.
- // In such cases, we return min.
+ // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
+ // In such cases, we return `min` instead of `result`.
return Select(FCmpUNO(result, result), min, result);
}
+// Calculates ComplexAbs in the same way, except using:
+// sqrt(|a| * sqrt(1 + (b/a)^2)) = sqrt(|a|) * pow(1 + (b/a)^2, .25)
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrtComplexAbs(
+ PrimitiveType prim_type, llvm::Value* operand_value) {
+ llvm::Value* min;
+ llvm::Value* max;
+ llvm::Value* one_p_div_sq;
+ TF_ASSIGN_OR_RETURN(
+ std::tie(min, max, one_p_div_sq),
+ EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/false));
+ TF_ASSIGN_OR_RETURN(llvm::Value * sqrt_max, EmitSqrt(prim_type, max));
+ TF_ASSIGN_OR_RETURN(llvm::Value * pow,
+ EmitPow(prim_type, one_p_div_sq,
+ llvm::ConstantFP::get(max->getType(), .25)));
+ llvm::Value* result = FMul(sqrt_max, pow);
+ // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
+ // In such cases, we return `min` instead of `result`.
+ return Select(FCmpUNO(result, result), min, result);
+}
+
+// Calculates ComplexAbs in the same way, except using:
+// rsqrt(|a| * sqrt(1 + (b/a)^2)) = rsqrt(|a|) * rsqrt(sqrt(1 + (b/a)^2))
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitRsqrtComplexAbs(
+ PrimitiveType prim_type, llvm::Value* operand_value) {
+ llvm::Value* min;
+ llvm::Value* max;
+ llvm::Value* sqrt;
+ TF_ASSIGN_OR_RETURN(
+ std::tie(min, max, sqrt),
+ EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/true));
+ TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_max, EmitRsqrt(prim_type, max));
+ TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_sqrt, EmitRsqrt(prim_type, sqrt));
+ llvm::Value* result = FMul(rsqrt_max, rsqrt_sqrt);
+ TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_min, EmitRsqrt(prim_type, min));
+ // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
+ // In such cases, we return rsqrt(min) instead of `result`.
+ return Select(FCmpUNO(result, result), rsqrt_min, result);
+}
+
+// Using our EmitComplexPower formula, but setting c=0.5 and d=0, we get:
+// e^[ln(r)*c - t*d] * [cos(ln(r)*d + t*c) + i*sin(ln(r)*d + t*c)]
+// = e^[ln(r)*0.5] * [cos(t*0.5) + i*sin(t*0.5)]
+// = r^0.5 * [cos(t/2) + i*sin(t/2)]
+// = sqrt(r) * [cos(t/2) + i*sin(t/2)]
+// where r = |a+bi| and t = atan2(b,a)
+// TODO(bixia): See doc for implementation without atan2.
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexSqrt(
+ const HloInstruction* op, PrimitiveType prim_type,
+ llvm::Value* operand_value) {
+ llvm::Type* type = static_cast<llvm::StructType*>(operand_value->getType())
+ ->getElementType(0);
+
+ TF_ASSIGN_OR_RETURN(llvm::Value * r,
+ EmitSqrtComplexAbs(prim_type, operand_value));
+
+ llvm::Value* a = EmitExtractReal(operand_value);
+ llvm::Value* b = EmitExtractImag(operand_value);
+ TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a));
+
+ llvm::Value* c = llvm::ConstantFP::get(type, 0.5);
+ llvm::Value* angle = FMul(t, c);
+ TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle));
+ TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle));
+
+ llvm::Value* real_part;
+ llvm::Value* imag_part;
+
+ llvm::Value* zero = llvm::ConstantFP::get(type, 0);
+
+ if (!(b_->getFastMathFlags().noNaNs() && b_->getFastMathFlags().noInfs())) {
+ llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
+ llvm::Value* neg_inf = llvm::ConstantFP::getInfinity(type, true);
+ llvm::Value* nan = llvm::ConstantFP::getNaN(type);
+ llvm::Value* abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
+ {b}, {b->getType()}, b_);
+
+ real_part = Select(Or(FCmpOEQ(abs_b, inf), FCmpOEQ(a, inf)), inf,
+ Select(And(FCmpOEQ(a, neg_inf), FCmpONE(abs_b, inf)),
+ zero, FMul(r, cos)));
+
+ llvm::Value* b_signed_inf = llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::copysign, {inf, b}, {b->getType()}, b_);
+ imag_part =
+ Select(Or(FCmpOEQ(abs_b, inf), FCmpOEQ(a, neg_inf)), b_signed_inf,
+ Select(FCmpUNO(r, r), nan,
+ Select(FCmpOEQ(sin, zero), sin, FMul(r, sin))));
+ } else {
+ real_part = FMul(r, cos);
+ imag_part = Select(FCmpOEQ(sin, zero), sin, FMul(r, sin));
+ }
+
+ return Select(FCmpOEQ(r, zero), EmitComposeComplex(op, zero, zero),
+ EmitComposeComplex(op, real_part, imag_part));
+}
+
+// Similar to Sqrt, we can use our EmitComplexPower formula, but set
+// c=-0.5 and d=0. We get:
+// e^[ln(r)*c - t*d] * [cos(ln(r)*d + t*c) + i*sin(ln(r)*d + t*c)]
+// = e^[ln(r)*-0.5] * [cos(t*-0.5) + i*sin(t*-0.5)]
+// = r^(-0.5) * [cos(-t/2) + i*sin(-t/2)]
+// = rsqrt(r) * [cos(-t/2) + i*sin(-t/2)]
+// where r = |a+bi| and t = atan2(b,a).
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexRsqrt(
+ const HloInstruction* op, PrimitiveType prim_type,
+ llvm::Value* operand_value) {
+ llvm::Type* type = static_cast<llvm::StructType*>(operand_value->getType())
+ ->getElementType(0);
+
+ TF_ASSIGN_OR_RETURN(llvm::Value * r,
+ EmitRsqrtComplexAbs(prim_type, operand_value));
+
+ llvm::Value* a = EmitExtractReal(operand_value);
+ llvm::Value* b = EmitExtractImag(operand_value);
+ TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a));
+
+ llvm::Value* c = llvm::ConstantFP::get(type, -0.5);
+ llvm::Value* angle = FMul(t, c);
+ TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle));
+ TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle));
+
+ llvm::Value* real_part = FMul(r, cos);
+ llvm::Value* imag_part = FMul(r, sin);
+
+ if (!(b_->getFastMathFlags().noNaNs() && b_->getFastMathFlags().noInfs())) {
+ llvm::Value* zero = llvm::ConstantFP::get(type, 0);
+ llvm::Value* neg_one = llvm::ConstantFP::get(type, -1);
+ llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
+ llvm::Value* nan = llvm::ConstantFP::getNaN(type);
+ // llvm::Value* neg_inf = llvm::ConstantFP::getInfinity(type, true);
+ llvm::Value* a_signed_zero = llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::copysign, {zero, a}, {a->getType()}, b_);
+ llvm::Value* b_signed_zero = llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::copysign, {zero, b}, {b->getType()}, b_);
+ llvm::Value* neg_b_signed_zero = FMul(b_signed_zero, neg_one);
+
+ llvm::Value* abs_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
+ {a}, {a->getType()}, b_);
+ llvm::Value* abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
+ {b}, {b->getType()}, b_);
+
+ llvm::Value* is_zero_zero = And(FCmpOEQ(b, zero), FCmpOEQ(a, zero));
+ real_part = Select(
+ is_zero_zero, inf,
+ Select(Or(And(FCmpOEQ(abs_b, inf), FCmpUNO(a, a)), FCmpOEQ(abs_a, inf)),
+ a_signed_zero, FMul(r, cos)));
+ imag_part = Select(
+ is_zero_zero, nan,
+ Select(Or(And(FCmpOEQ(abs_b, inf), FCmpUNO(a, a)), FCmpOEQ(abs_a, inf)),
+ neg_b_signed_zero, FMul(r, sin)));
+ } else {
+ llvm::Value* zero = llvm::ConstantFP::get(type, 0);
+ llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
+ llvm::Value* nan = llvm::ConstantFP::getNaN(type);
+
+ llvm::Value* is_zero_zero = And(FCmpOEQ(b, zero), FCmpOEQ(a, zero));
+ real_part = Select(is_zero_zero, inf, FMul(r, cos));
+ imag_part = Select(is_zero_zero, nan, FMul(r, sin));
+ }
+
+ return EmitComposeComplex(op, real_part, imag_part);
+}
+
// (a+bi)^(c+di) =
// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
@@ -1149,7 +1322,7 @@
return Select(x_is_small, for_small_x, for_large_x);
}
-StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrt(PrimitiveType prim_type,
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrt(PrimitiveType,
llvm::Value* value) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {value},
{value->getType()}, b_);
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
index 3ba669c..99833a5 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -143,9 +143,26 @@
virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo,
llvm::Value* x);
+ virtual StatusOr<std::tuple<llvm::Value*, llvm::Value*, llvm::Value*>>
+ EmitComplexAbsHelper(PrimitiveType prim_type, llvm::Value* operand_value,
+ bool return_sqrt);
+
virtual StatusOr<llvm::Value*> EmitComplexAbs(PrimitiveType prim_type,
llvm::Value* operand_value);
+ virtual StatusOr<llvm::Value*> EmitSqrtComplexAbs(PrimitiveType prim_type,
+ llvm::Value* operand_value);
+ virtual StatusOr<llvm::Value*> EmitRsqrtComplexAbs(
+ PrimitiveType prim_type, llvm::Value* operand_value);
+
+ virtual StatusOr<llvm::Value*> EmitComplexSqrt(const HloInstruction* op,
+ PrimitiveType prim_type,
+ llvm::Value* operand_value);
+
+ virtual StatusOr<llvm::Value*> EmitComplexRsqrt(const HloInstruction* op,
+ PrimitiveType prim_type,
+ llvm::Value* operand_value);
+
virtual llvm::Value* EmitExtractReal(llvm::Value* value);
virtual llvm::Value* EmitExtractImag(llvm::Value* value);
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h
index fe54768..5014aa9 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h
@@ -250,6 +250,16 @@
}
template <class... Args>
+ llvm::Value* FCmpOGT(Args&&... args) {
+ return mixin_builder()->CreateFCmpOGT(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FCmpOGE(Args&&... args) {
+ return mixin_builder()->CreateFCmpOGE(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
llvm::Value* FCmpOLT(Args&&... args) {
return mixin_builder()->CreateFCmpOLT(std::forward<Args>(args)...);
}
diff --git a/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc b/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc
index f8eb738..ea6ecd1 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc
+++ b/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc
@@ -692,7 +692,7 @@
this->known_incorrect_fn_ = [&](int64 v) {
double f = this->ConvertValue(v);
return (T == C128 &&
- std::abs(f) > std::numeric_limits<float>::max() / 2) ||
+ std::abs(f) > std::numeric_limits<double>::max() / 2) ||
f == -std::numeric_limits<double>::infinity();
};
}
@@ -738,6 +738,20 @@
Run(Log, [](complex64 x) { return std::log<float>(x); });
}
+XLA_TEST_P(ExhaustiveC64UnaryTest, Sqrt) {
+ Run(Sqrt, [](complex64 x) {
+ return static_cast<complex64>(
+ std::sqrt<double>(static_cast<complex128>(x)));
+ });
+}
+
+XLA_TEST_P(ExhaustiveC64UnaryTest, Rsqrt) {
+ Run(Rsqrt, [](complex64 x) {
+ return static_cast<complex64>(
+ complex128(1, 0) / std::sqrt<double>(static_cast<complex128>(x)));
+ });
+}
+
// The current libc++ implementation of the complex tanh function provides
// less accurate results when the denomenator of a complex tanh is small, due
// to floating point precision loss. To avoid this issue for complex64 numbers,
@@ -807,6 +821,35 @@
Run(Log, [](complex128 x) { return std::log<double>(x); });
}
+XLA_TEST_P(ExhaustiveC128UnaryTest, Sqrt) {
+ // Similar to the Tanh bug.
+ known_incorrect_fn_ = [&](int64 v) {
+ double f = this->ConvertValue(v);
+ return std::abs(f) > std::numeric_limits<double>::max() / 2;
+ };
+ Run(Sqrt, [](complex128 x) { return std::sqrt<double>(x); });
+}
+
+XLA_TEST_P(ExhaustiveC128UnaryTest, Rsqrt) {
+ ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
+ if (platform_ == "CUDA") {
+ // Edge case on CUDA backend where the Log of a complex number made up of
+ // the smallest denormals is more accurate than the interpreter backend.
+ error_spec_gen = [](complex128 x) {
+ constexpr double denorm_min = std::numeric_limits<double>::denorm_min();
+ if (std::abs(x.real()) == denorm_min &&
+ std::abs(x.imag()) == denorm_min) {
+ return ErrorSpec(0.5, 0.5);
+ }
+ return GetDefaultSpecGenerator()(x);
+ };
+ }
+ Run(
+ Rsqrt,
+ [](complex128 x) { return complex128(1, 0) / std::sqrt<double>(x); },
+ error_spec_gen);
+}
+
XLA_TEST_P(ExhaustiveC128UnaryTest, Tanh) {
SetParamsForTanh();
Run(