blob: 02a45a7b53518969e0ab1690f0cdb2e1205cd4ba [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/math.h"
#include <cmath>
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace xla {
namespace {
// Evaluates a polynomial at `x` by repeated multiply-and-add. `coefficients`
// are listed from the highest-degree term down to the constant term; an empty
// list yields a zero scalar shaped like `x`.
template <typename FP>
XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const FP> coefficients) {
  static_assert(std::is_floating_point<FP>::value,
                "Template-argument 'FP' must be a floating-point type");
  XlaOp accumulator = ScalarLike(x, 0.0);
  for (auto it = coefficients.begin(); it != coefficients.end(); ++it) {
    const FP coefficient = *it;
    // accumulator <- accumulator * x + coefficient
    accumulator = accumulator * x + ScalarLike(x, coefficient);
  }
  return accumulator;
}
// Evaluates a Chebyshev series at `x` using a three-term recurrence over
// `coefficients`, which are given in decreasing order. The result is half the
// difference between the newest and the oldest recurrence terms.
template <typename FP>
XlaOp EvaluateChebyshevPolynomial(XlaOp x, absl::Span<const FP> coefficients) {
  static_assert(std::is_floating_point<FP>::value,
                "Template-argument 'FP' must be a floating-point type");
  XlaOp newest = ScalarLike(x, 0.0);
  XlaOp middle = ScalarLike(x, 0.0);
  XlaOp oldest = ScalarLike(x, 0.0);
  for (const FP coefficient : coefficients) {
    // Shift the recurrence window by one term before computing the next one.
    oldest = middle;
    middle = newest;
    newest = x * middle - oldest + ScalarLike(x, coefficient);
  }
  return ScalarLike(x, 0.5) * (newest - oldest);
}
} // namespace
// Applies `operation` to `operand`. When the operand's element type appears in
// `upcast_types`, the operand is first converted to F32, `operation` runs on
// the F32 value, and the result is converted back to the original element
// type. Otherwise `operation` is applied to the operand unchanged.
static XlaOp DoWithUpcastToF32(XlaOp operand,
                               absl::Span<const PrimitiveType> upcast_types,
                               const std::function<XlaOp(XlaOp)>& operation) {
  XlaBuilder& builder = *operand.builder();
  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto operand_shape, builder.GetShape(operand));
    const PrimitiveType original_type = operand_shape.element_type();
    if (!absl::c_linear_search(upcast_types, original_type)) {
      // No widening requested for this element type.
      return operation(operand);
    }
    XlaOp widened_operand = ConvertElementType(operand, F32);
    XlaOp widened_result = operation(widened_operand);
    return ConvertElementType(widened_result, original_type);
  });
}
// Returns OK when `operand` has a floating-point element type (as reported by
// primitive_util::IsFloatingPointType), and an InvalidArgument status naming
// `op_name` and the offending element type otherwise.
//
// TODO(jlebar): Use this function in more places in this file to restrict the
// domain of other functions.
static Status EnsureOperandIsRealFp(absl::string_view op_name, XlaOp operand) {
  XlaBuilder& builder = *operand.builder();
  TF_ASSIGN_OR_RETURN(auto operand_shape, builder.GetShape(operand));
  const auto element_type = operand_shape.element_type();
  if (primitive_util::IsFloatingPointType(element_type)) {
    return ::tensorflow::OkStatus();
  }
  return InvalidArgument(
      "Operands to %s must be real-valued floating-point, but got %s",
      op_name, PrimitiveType_Name(element_type));
}
// Elementwise test for +infinity. Rejects non-floating-point operands.
XlaOp IsPosInf(XlaOp operand) {
  XlaBuilder& builder = *operand.builder();
  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsPosInf", operand));
    TF_ASSIGN_OR_RETURN(auto operand_shape, builder.GetShape(operand));
    // Comparing against MaxValue is only right for floating-point types. A
    // version valid for every type would instead test Gt(MaxFiniteValue).
    XlaOp largest = MaxValue(&builder, operand_shape.element_type());
    return Eq(operand, largest);
  });
}
// Elementwise test for -infinity. Rejects non-floating-point operands.
XlaOp IsNegInf(XlaOp operand) {
  XlaBuilder& builder = *operand.builder();
  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegInf", operand));
    TF_ASSIGN_OR_RETURN(auto operand_shape, builder.GetShape(operand));
    // Comparing against MinValue is only right for floating-point types. A
    // version valid for every type would instead test Lt(MinFiniteValue).
    XlaOp smallest = MinValue(&builder, operand_shape.element_type());
    return Eq(operand, smallest);
  });
}
// Elementwise test for +/-infinity, implemented as IsPosInf(|operand|).
// Rejects non-floating-point operands.
XlaOp IsInf(XlaOp operand) {
  XlaBuilder& builder = *operand.builder();
  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsInf", operand));
    XlaOp magnitude = Abs(operand);
    return IsPosInf(magnitude);
  });
}
// Elementwise NaN test, expressed as "operand != operand". Rejects
// non-floating-point operands.
XlaOp IsNan(XlaOp operand) {
  XlaBuilder& builder = *operand.builder();
  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNan", operand));
    XlaOp differs_from_itself = Ne(operand, operand);
    return differs_from_itself;
  });
}
// Elementwise test for negative zero. Rejects non-floating-point operands.
//
// The bitwise representation of -0 in bfloat16 and IEEE 754 is 0x80...0 (sign
// bit on, all other bits off), so the operand is bitcast to an unsigned
// integer and compared against that pattern.
XlaOp IsNegZero(XlaOp operand) {
  XlaBuilder& builder = *operand.builder();
  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegZero", operand));
    TF_ASSIGN_OR_RETURN(auto operand_shape, builder.GetShape(operand));
    const PrimitiveType element_type = operand_shape.element_type();
    if (element_type == F64) {
      return Eq(BitcastConvertType(operand, U64),
                ConstantR0WithType(&builder, U64, uint64_t{1} << 63));
    }
    if (element_type == F32) {
      return Eq(BitcastConvertType(operand, U32),
                ConstantR0WithType(&builder, U32, uint32_t{1} << 31));
    }
    if (element_type == F16 || element_type == BF16) {
      // Not all XLA backends handle U16 well, so we convert to F32/U32.
      // TODO(jlebar): It would be nice if we could stay in (B)F16/U16 for
      // backends that *do* support it.
      return Eq(BitcastConvertType(ConvertElementType(operand, F32), U32),
                ConstantR0WithType(&builder, U32, uint32_t{1} << 31));
    }
    LOG(FATAL) << "Expected real fp type.";
  });
}
// Returns the elementwise product of `operand` with itself.
XlaOp Square(XlaOp operand) {
  XlaOp squared = operand * operand;
  return squared;
}
// Returns 1 / operand, elementwise.
XlaOp Reciprocal(XlaOp operand) {
  XlaOp one = ScalarLike(operand, 1.0);
  return one / operand;
}
// Computes an approximation of the error function complement (1 - erf(x)).
//
// Precondition: abs(x) >= 1. Otherwise, use ErfImpl.
//
// This follows Cephes's f32 implementation of erfc: with q = 1/|x|, the
// positive-side value is exp(-x^2) * q * poly(q^2), where `poly` is P for
// |x| < 2 and R otherwise. Negative inputs are handled through
// erfc(-x) = 2 - erfc(x).
static XlaOp ErfcImpl32(XlaOp x) {
  // Coefficients for erfc(f32), from Cephes.
  const double kMaxlog = 88.72283905206835;
  // erfc(x) = exp(-x^2) P(1/x^2), 1 < x < 2
  static const std::array<float, 9> kErfcPCoefficient{
      +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1,
      -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1,
      +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1,
  };
  // erfc(x) = exp(-x^2) R(1/x^2), 2 <= x < kMaxlog
  static const std::array<float, 8> kErfcRCoefficient{
      -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0,
      +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1,
      -2.820767439740514E-1, +5.641895067754075E-1,
  };
  XlaOp abs_x = Abs(x);
  // Keep the exponent -x^2 separate from exp(-x^2): the underflow check below
  // must compare the exponent (not the always-non-negative exponential)
  // against -kMaxlog, exactly as ErfcImpl64 does.
  XlaOp neg_x_squared = -x * x;
  XlaOp z = Exp(neg_x_squared);
  XlaOp q = ScalarLike(x, 1) / abs_x;
  XlaOp y = q * q;
  XlaOp p = Select(Lt(abs_x, ScalarLike(x, 2.0)),
                   EvaluatePolynomial<float>(y, kErfcPCoefficient),
                   EvaluatePolynomial<float>(y, kErfcRCoefficient));
  y = z * q * p;
  // Where exp(-x^2) would underflow, clamp the positive-side result to 0.
  XlaOp y_clamp = Select(Lt(neg_x_squared, ScalarLike(x, -kMaxlog)),
                         ScalarLike(x, 0), y);
  return Select(Lt(x, ScalarLike(x, 0)), ScalarLike(x, 2.0) - y_clamp, y_clamp);
}
// Polynomial approximation of the error function.
//
// Precondition: abs(x) <= 1. Otherwise, use ErfcImpl.
//
// This follows Cephes's f32 implementation of erf.
static XlaOp ErfImpl32Cephes(XlaOp x) {
  // Coefficients for erf(f32), from Cephes.
  //
  // erf(x) = x P(x^2), 0 < x < 1
  static const std::array<float, 7> kErfTCoefficient{
      +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3,
      -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1,
      +1.128379165726710E+0,
  };
  XlaOp x_squared = x * x;
  XlaOp polynomial = EvaluatePolynomial<float>(x_squared, kErfTCoefficient);
  return x * polynomial;
}
// Computes an approximation of the error function complement (1 - erf(x)) in
// double precision, as a rational function of |x| scaled by exp(-x^2).
// Negative inputs are handled through erfc(-x) = 2 - erfc(x), and the
// positive-side value is clamped to 0 where -x^2 < -kMaxlog.
//
// The public entry points in this file only select this result where
// abs(x) >= 1; for smaller magnitudes they use ErfImpl64.
static XlaOp ErfcImpl64(XlaOp x) {
  // Coefficients for erfc(f64), from Cephes.
  const double kMaxlog = 7.09782712893383996843E2;
  // erfc(x) = exp(-x^2) P(|x|) / Q(|x|), 1 < x < 8
  static const std::array<double, 9> kErfcPCoefficient{
      2.46196981473530512524E-10, 5.64189564831068821977E-1,
      7.46321056442269912687E0,   4.86371970985681366614E1,
      1.96520832956077098242E2,   5.26445194995477358631E2,
      9.34528527171957607540E2,   1.02755188689515710272E3,
      5.57535335369399327526E2};
  static const std::array<double, 9> kErfcQCoefficient{
      1.00000000000000000000E0, 1.32281951154744992508E1,
      8.67072140885989742329E1, 3.54937778887819891062E2,
      9.75708501743205489753E2, 1.82390916687909736289E3,
      2.24633760818710981792E3, 1.65666309194161350182E3,
      5.57535340817727675546E2};
  // erfc(x) = exp(-x^2) R(|x|) / S(|x|), 8 <= x < kMaxlog
  static const std::array<double, 6> kErfcRCoefficient{
      5.64189583547755073984E-1, 1.27536670759978104416E0,
      5.01905042251180477414E0,  6.16021097993053585195E0,
      7.40974269950448939160E0,  2.97886665372100240670E0};
  static const std::array<double, 7> kErfcSCoefficient{
      1.00000000000000000000E0, 2.26052863220117276590E0,
      9.39603524938001434673E0, 1.20489539808096656605E1,
      1.70814450747565897222E1, 9.60896809063285878198E0,
      3.36907645100081516050E0};
  XlaOp z = -x * x;
  XlaOp abs_x = Abs(x);
  // Both rational branches share the same exp(-x^2) factor; build it once.
  XlaOp exp_z = Exp(z);
  XlaOp y =
      Select(Lt(abs_x, ScalarLike(x, 8.0)),
             exp_z * EvaluatePolynomial<double>(abs_x, kErfcPCoefficient) /
                 EvaluatePolynomial<double>(abs_x, kErfcQCoefficient),
             exp_z * EvaluatePolynomial<double>(abs_x, kErfcRCoefficient) /
                 EvaluatePolynomial<double>(abs_x, kErfcSCoefficient));
  // Where exp(-x^2) would underflow, clamp the positive-side result to 0.
  XlaOp y_clamp = Select(Lt(z, ScalarLike(x, -kMaxlog)), ScalarLike(x, 0), y);
  return Select(Lt(x, ScalarLike(x, 0)), ScalarLike(x, 2.0) - y_clamp, y_clamp);
}
// Compute a rational approximation of the error function in double
// precision.
//
// Precondition: abs(x) <= 1. Otherwise, use ErfcImpl.
static XlaOp ErfImpl64(XlaOp x) {
  // Coefficients for erf(f64), from Cephes.
  //
  // erf(x) = x T(x^2) / U(x^2), 0 < x < 1
  //
  // The tables are read-only, so they are declared const like every other
  // coefficient table in this file.
  static const std::array<double, 5> kErfTCoefficient{
      9.60497373987051638749E0, 9.00260197203842689217E1,
      2.23200534594684319226E3, 7.00332514112805075473E3,
      5.55923013010394962768E4};
  static const std::array<double, 6> kErfUCoefficient{
      1.00000000000000000000E0, 3.35617141647503099647E1,
      5.21357949780152679795E2, 4.59432382970980127987E3,
      2.26290000613890934246E4, 4.92673942608635921086E4};
  XlaOp z = x * x;
  return x * EvaluatePolynomial<double>(z, kErfTCoefficient) /
         EvaluatePolynomial<double>(z, kErfUCoefficient);
}
// Complementary error function, 1 - erf(x), elementwise. Rejects
// non-floating-point operands.
//
//   erfc(x) = erfc_impl(x)     if |x| > 1
//             1 - erf_impl(x)  otherwise
XlaOp Erfc(XlaOp x) {
  XlaBuilder& builder = *x.builder();
  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erfc", x));
    TF_ASSIGN_OR_RETURN(auto operand_shape, builder.GetShape(x));
    if (operand_shape.element_type() != F64) {
      // Erf(c)Impl don't have enough precision when run with bf16
      // intermediates (not surprising!), so upcast to f32 in this case.
      return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp value) {
        return Select(Gt(Abs(value), ScalarLike(value, 1)), ErfcImpl32(value),
                      ScalarLike(value, 1) - ErfImpl32Cephes(value));
      });
    }
    return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl64(x),
                  ScalarLike(x, 1) - ErfImpl64(x));
  });
}
// Rational approximation of the error function for f32: the input is clamped
// to [-4, 4] and the result is x * alpha(x^2) / beta(x^2).
// This is the same approximation used by Eigen.
static XlaOp ErfImpl32(XlaOp x) {
  static const std::array<float, 7> kAlpha{
      -2.72614225801306e-10f, 2.77068142495902e-08f,  -2.10102402082508e-06f,
      -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
      -1.60960333262415e-02f,
  };
  static const std::array<float, 5> kBeta{
      -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
      -7.37332916720468e-03f, -1.42647390514189e-02f,
  };
  XlaOp clamped = Clamp(ScalarLike(x, -4.f), x, ScalarLike(x, 4.f));
  XlaOp clamped_squared = clamped * clamped;
  return clamped * EvaluatePolynomial<float>(clamped_squared, kAlpha) /
         EvaluatePolynomial<float>(clamped_squared, kBeta);
}
// Error function, elementwise. Rejects non-floating-point operands.
//
// For F64:
//   erf(x) = erf_impl(x)       if |x| < 1
//            1 - erfc_impl(x)  otherwise
// For every other floating-point type the single rational approximation in
// ErfImpl32 is used over the whole domain.
XlaOp Erf(XlaOp x) {
  XlaBuilder& builder = *x.builder();
  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erf", x));
    TF_ASSIGN_OR_RETURN(auto operand_shape, builder.GetShape(x));
    if (operand_shape.element_type() != F64) {
      // Erf(c)Impl don't have enough precision when run with bf16
      // intermediates (not surprising!), so upcast to f32 in this case.
      return DoWithUpcastToF32(x, {BF16, F16},
                               [](XlaOp value) { return ErfImpl32(value); });
    }
    return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl64(x),
                  ScalarLike(x, 1) - ErfcImpl64(x));
  });
}
namespace {
// Single-precision inverse error function, following
// Giles, M., "Approximating the erfinv function".
// The approximation has the form:
//   w = -log((1 - x) * (1 + x))
//   if ( w < 5 ) {
//     w = w - 2.5
//     p = sum_{i=1}^n lq[i]*w^i
//   } else {
//     w = sqrt(w) - 3
//     p = sum_{i=1}^n gq[i]*w^i
//   }
//   return p*x
// Both polynomials have the same number of coefficients, so a single loop
// evaluates them with a per-element choice of coefficient.
XlaOp ErfInv32(XlaOp x) {
  constexpr int kDegree = 9;
  constexpr std::array<float, 9> w_less_than_5_constants = {
      2.81022636e-08f,  3.43273939e-07f, -3.5233877e-06f,
      -4.39150654e-06f, 0.00021858087f,  -0.00125372503f,
      -0.00417768164f,  0.246640727f,    1.50140941f};
  constexpr std::array<float, 9> w_greater_than_5_constants = {
      -0.000200214257f, 0.000100950558f, 0.00134934322f,
      -0.00367342844f,  0.00573950773f,  -0.0076224613f,
      0.00943887047f,   1.00167406f,     2.83297682f};
  // Compute logarithm of (1+arg) using log1p(arg) which is more precise than
  // log(1+arg) when arg is close to zero. For more details, see
  // https://en.cppreference.com/w/cpp/numeric/math/log1p
  auto w = -Log1p(-x * x);
  auto in_central_region = Lt(w, ScalarLike(x, 5.0));
  // Picks, per element, the i-th coefficient of whichever polynomial applies.
  auto pick_coefficient = [&](int index) {
    return Select(in_central_region,
                  FullLike(x, w_less_than_5_constants[index]),
                  FullLike(x, w_greater_than_5_constants[index]));
  };
  w = Select(in_central_region, w - ScalarLike(x, 2.5),
             Sqrt(w) - ScalarLike(x, 3.0));
  auto polynomial = pick_coefficient(0);
  int index = 1;
  while (index < kDegree) {
    polynomial = pick_coefficient(index) + polynomial * w;
    ++index;
  }
  // Result modulo edge cases.
  XlaOp interior_result = polynomial * x;
  // Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is
  // indeterminate, and can give nan or -/+inf.)
  XlaBuilder& builder = *x.builder();
  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape operand_shape, builder.GetShape(x));
    return Select(Eq(Abs(x), ScalarLike(x, 1)),
                  x * MaxValue(&builder, operand_shape.element_type()),
                  interior_result);
  });
}
// Double-precision inverse error function. Has the same overall shape as
// ErfInv32 (presumably also taken from Giles' paper -- TODO confirm), but
// splits w = -log1p(-x^2) into three ranges, each with its own polynomial:
//   w < 6.25        : 23 coefficients, evaluated at w - 3.125
//   6.25 <= w < 16  : 19 coefficients, evaluated at sqrt(w) - 3.25
//   w >= 16         : 17 coefficients, evaluated at sqrt(w) - 5
// The result is p * x, except that |x| == 1 maps to x * MaxValue.
XlaOp ErfInv64(XlaOp x) {
constexpr std::array<double, 23> w_less_than_6_25_constants = {
-3.6444120640178196996e-21, -1.685059138182016589e-19,
1.2858480715256400167e-18, 1.115787767802518096e-17,
-1.333171662854620906e-16, 2.0972767875968561637e-17,
6.6376381343583238325e-15, -4.0545662729752068639e-14,
-8.1519341976054721522e-14, 2.6335093153082322977e-12,
-1.2975133253453532498e-11, -5.4154120542946279317e-11,
1.051212273321532285e-09, -4.1126339803469836976e-09,
-2.9070369957882005086e-08, 4.2347877827932403518e-07,
-1.3654692000834678645e-06, -1.3882523362786468719e-05,
0.0001867342080340571352, -0.00074070253416626697512,
-0.0060336708714301490533, 0.24015818242558961693,
1.6536545626831027356};
constexpr std::array<double, 19> w_less_than_16_constants = {
2.2137376921775787049e-09, 9.0756561938885390979e-08,
-2.7517406297064545428e-07, 1.8239629214389227755e-08,
1.5027403968909827627e-06, -4.013867526981545969e-06,
2.9234449089955446044e-06, 1.2475304481671778723e-05,
-4.7318229009055733981e-05, 6.8284851459573175448e-05,
2.4031110387097893999e-05, -0.0003550375203628474796,
0.00095328937973738049703, -0.0016882755560235047313,
0.0024914420961078508066, -0.0037512085075692412107,
0.005370914553590063617, 1.0052589676941592334,
3.0838856104922207635,
};
constexpr std::array<double, 17> w_greater_than_16_constants = {
-2.7109920616438573243e-11, -2.5556418169965252055e-10,
1.5076572693500548083e-09, -3.7894654401267369937e-09,
7.6157012080783393804e-09, -1.4960026627149240478e-08,
2.9147953450901080826e-08, -6.7711997758452339498e-08,
2.2900482228026654717e-07, -9.9298272942317002539e-07,
4.5260625972231537039e-06, -1.9681778105531670567e-05,
7.5995277030017761139e-05, -0.00021503011930044477347,
-0.00013871931833623122026, 1.0103004648645343977,
4.8499064014085844221,
};
// Compute logarithm of (1+arg) using log1p(arg) which is more precise than
// log(1+arg) when arg is close to zero. For more details, see
// https://en.cppreference.com/w/cpp/numeric/math/log1p
auto w = -Log1p(-x * x);
auto lt_6_25 = Lt(w, ScalarLike(x, 6.25));
auto lt_16 = Lt(w, ScalarLike(x, 16));
// Returns the i-th coefficient of the polynomial that applies to each
// element. The bounds 19 and 17 are the sizes of the two shorter tables, so
// they are only indexed where they actually have an entry.
auto coefficient = [&](int i) {
auto c = FullLike(x, w_less_than_6_25_constants[i]);
if (i < 19) {
c = Select(lt_6_25, c, FullLike(x, w_less_than_16_constants[i]));
}
if (i < 17) {
c = Select(lt_16, c, FullLike(x, w_greater_than_16_constants[i]));
}
return c;
};
auto sqrt_w = Sqrt(w);
// Shift w into the argument expected by the polynomial for its range.
w = Select(lt_6_25, w - ScalarLike(x, 3.125),
sqrt_w - Select(lt_16, ScalarLike(x, 3.25), ScalarLike(x, 5.0)));
auto p = coefficient(0);
// The first 17 steps apply to all three polynomials.
for (int i = 1; i < 17; ++i) {
p = coefficient(i) + p * w;
}
// Steps 17 and 18 only exist for the two polynomials used when w < 16;
// elements with w >= 16 keep the value they already have.
for (int i = 17; i < 19; ++i) {
p = Select(lt_16, coefficient(i) + p * w, p);
}
// Steps 19 through 22 only exist for the w < 6.25 polynomial.
for (int i = 19; i < 23; ++i) {
p = Select(lt_6_25, coefficient(i) + p * w, p);
}
// Result modulo edge cases.
XlaOp result = p * x;
// Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is
// indeterminate, and can give nan or -/+inf.)
auto& b = *x.builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, b.GetShape(x));
return Select(Eq(Abs(x), ScalarLike(x, 1)),
x * MaxValue(&b, shape.element_type()), result);
});
}
} // namespace
// Inverse error function, elementwise. Rejects non-floating-point operands.
// F64 inputs use ErfInv64; every other floating-point type uses ErfInv32,
// with BF16 and F16 computed in F32 and converted back.
XlaOp ErfInv(XlaOp x) {
  XlaBuilder& builder = *x.builder();
  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("ErfInv", x));
    TF_ASSIGN_OR_RETURN(auto operand_shape, builder.GetShape(x));
    if (operand_shape.element_type() != F64) {
      return DoWithUpcastToF32(x, {BF16, F16},
                               [](XlaOp value) { return ErfInv32(value); });
    }
    return ErfInv64(x);
  });
}
namespace {
// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n (kLanczosGamma
// and kLanczosCoefficients.size() + 1). The coefficients below correspond to
// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and [7,
// 9] seemed to be the least sensitive to the quality of the log function. In
// particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
// for a particularly inaccurate log function.
// NOTE(review): the last sentence singles out [5, 7] although the values below
// are for [7, 9] (g = 7, eight coefficients plus the base term) -- confirm
// whether "[7, 9]" was intended there.
//
// These constants are shared by Lgamma and Digamma.
static constexpr double kLanczosGamma = 7; // aka g
// Constant term of the series A(z); the entries of kLanczosCoefficients are
// the numerators of the subsequent 1 / (z + k) terms, for k = 1..8.
static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
static constexpr std::array<double, 8> kLanczosCoefficients = {
676.520368121885098567009190444019, -1259.13921672240287047156078755283,
771.3234287776530788486528258894, -176.61502916214059906584551354,
12.507343278686904814458936853, -0.13857109526572011689554707,
9.984369578019570859563e-6, 1.50563273514931155834e-7};
} // namespace
// Compute the Lgamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
// lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z)
// t(z) = z + kLanczosGamma + 1/2
// A(z) = kBaseLanczosCoeff
//        + sigma(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
//
// Rejects non-floating-point operands. BF16 and F16 inputs are computed in F32
// and converted back. Inputs below 0.5 go through the reflection formula, and
// +/-inf inputs yield +inf.
XlaOp Lgamma(XlaOp input) {
// The actual computation, run on an operand that has already been widened to
// F32 where necessary.
auto do_it = [](XlaOp input) {
XlaOp one_half = ScalarLike(input, 0.5);
XlaOp one = ScalarLike(input, 1);
XlaOp pi = ScalarLike(input, M_PI);
XlaOp log_pi = ScalarLike(input, std::log(M_PI));
XlaOp log_sqrt_two_pi =
ScalarLike(input, (std::log(2) + std::log(M_PI)) / 2);
XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5);
XlaOp log_lanczos_gamma_plus_one_half =
ScalarLike(input, std::log(kLanczosGamma + 0.5));
XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff);
// If the input is less than 0.5 use Euler's reflection formula:
// gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
XlaOp need_to_reflect = Lt(input, one_half);
// The series below evaluates lgamma(z + 1), so z is (1 - input) - 1 = -input
// when reflecting and input - 1 otherwise.
XlaOp z = Select(need_to_reflect, -input, input - one);
// Accumulate A(z) in x.
XlaOp x = base_lanczos_coeff;
for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]);
XlaOp index = ScalarLike(input, i);
x = x + lanczos_coefficient / (z + index + one);
}
// To improve accuracy on platforms with less-precise log implementations,
// compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on
// the device.
// log(t) = log(kLanczosGamma + 0.5 + z)
// = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
XlaOp t = lanczos_gamma_plus_one_half + z;
XlaOp log_t = log_lanczos_gamma_plus_one_half +
Log1p(z / lanczos_gamma_plus_one_half);
// Compute the final result (modulo reflection). t(z) may be large, and we
// need to be careful not to overflow to infinity in the first term of
//
// (z + 1/2) * log(t(z)) - t(z).
//
// Therefore we compute this as
//
// (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
//
XlaOp log_y = log_sqrt_two_pi + (z + one_half - t / log_t) * log_t + Log(x);
// Compute the reflected value, used when x < 0.5:
//
// lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
//
// (The abs is because lgamma is the log of the absolute value of the gamma
// function.)
//
// We have to be careful when computing the final term above. gamma(x) goes
// to +/-inf at every integer x < 0, and this is controlled by the
// sin(pi * x) term. The slope is large, so precision is particularly
// important.
//
// Because abs(sin(pi * x)) has period 1, we can equivalently use
// abs(sin(pi * frac(x))), where frac(x) is the fractional part of x. This
// is more numerically accurate: It doesn't overflow to inf like pi * x can,
// and if x is an integer, it evaluates to 0 exactly, which is significant
// because we then take the log of this value, and log(0) is -inf.
//
// We don't have a frac(x) primitive in XLA and computing it is tricky, but
// because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for
// our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
//
// Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
// to 1. To remedy this, we can use the fact that sin(pi * x) in the domain
// [0, 1] is symmetric about the line x = 0.5.
//
XlaOp abs_input = Abs(input);
XlaOp abs_frac_input = abs_input - Floor(abs_input);
// Convert values of abs_frac_input > 0.5 to (1 - frac_input) to improve
// precision of pi * abs_frac_input for values of abs_frac_input close to 1.
XlaOp reduced_frac_input =
Select(Gt(abs_frac_input, ScalarLike(abs_frac_input, 0.5)),
ScalarLike(abs_frac_input, 1) - abs_frac_input, abs_frac_input);
// Despite the name, this is the log(sin(...)) term that is subtracted in the
// reflection formula above.
XlaOp reflection_denom = Log(Sin(pi * reduced_frac_input));
// Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf,
// then it "wins" and the result is +/-inf.
XlaOp reflection =
Select(IsFinite(reflection_denom), log_pi - reflection_denom - log_y,
-reflection_denom);
XlaOp result = Select(need_to_reflect, reflection, log_y);
// lgamma(+/-inf) = +inf.
XlaOp inf_bcast = FullLike(input, std::numeric_limits<float>::infinity());
return Select(IsInf(input), inf_bcast, result);
};
auto& b = *input.builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Lgamma", input));
// F16 and BF16 don't provide sufficient precision for intermediate results
// here (although it's better than you might expect!), so do the
// computations in F32.
return DoWithUpcastToF32(input, {BF16, F16}, do_it);
});
}
// Computes an approximation of the lbeta function, which is equivalent to
// log(abs(Beta(a, b))) but avoids overflow by working with lgamma.
static XlaOp Lbeta(XlaOp a, XlaOp b) {
  // Per http://dlmf.nist.gov/5.12.E1, Beta can be written with Gamma:
  //   Beta(a, b) = (Gamma(a) * Gamma(b)) / Gamma(a + b)
  //
  // To avoid overflow, the computation is done in the log domain, using
  //   Log(a * b) = Log(a) + Log(b)   (http://dlmf.nist.gov/4.8.E2)
  //   Log(a / b) = Log(a) - Log(b)   (https://dlmf.nist.gov/4.8.E4)
  //
  // which gives
  //   Log(Beta(a, b)) = Log(Gamma(a)) + Log(Gamma(b)) - Log(Gamma(a + b))
  XlaOp lgamma_a = Lgamma(a);
  XlaOp lgamma_b = Lgamma(b);
  XlaOp lgamma_a_plus_b = Lgamma(a + b);
  return lgamma_a + lgamma_b - lgamma_a_plus_b;
}
// Compute the Digamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
// digamma(z + 1) = log(t(z)) + A'(z) / A(z) - kLanczosGamma / t(z)
// t(z) = z + kLanczosGamma + 1/2
// A(z) = kBaseLanczosCoeff
//        + sigma(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
// A'(z) = -sigma(k = 1, n, kLanczosCoefficients[k - 1] / (z + k) / (z + k))
//
// Rejects non-floating-point operands. BF16 and F16 inputs are computed in F32
// and converted back. Inputs below 0.5 go through the reflection formula, and
// non-positive integer inputs yield NaN.
XlaOp Digamma(XlaOp input) {
// The actual computation, run on an operand that has already been widened to
// F32 where necessary.
auto do_it = [](XlaOp input) {
XlaOp zero = ScalarLike(input, 0);
XlaOp one_half = ScalarLike(input, 0.5);
XlaOp one = ScalarLike(input, 1);
XlaOp pi = ScalarLike(input, M_PI);
XlaOp lanczos_gamma = ScalarLike(input, kLanczosGamma);
XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5);
XlaOp log_lanczos_gamma_plus_one_half =
ScalarLike(input, std::log(kLanczosGamma + 0.5));
XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff);
// If the input is less than 0.5 use Euler's reflection formula:
// digamma(x) = digamma(1 - x) - pi * cot(pi * x)
XlaOp need_to_reflect = Lt(input, one_half);
// The series below evaluates digamma(z + 1), so z is (1 - input) - 1 = -input
// when reflecting and input - 1 otherwise.
XlaOp z = Select(need_to_reflect, -input, input - one);
// num accumulates A'(z) (note the subtraction) and denom accumulates A(z).
XlaOp num = zero;
XlaOp denom = base_lanczos_coeff;
for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]);
XlaOp index = ScalarLike(input, i);
num = num - lanczos_coefficient / ((z + index + one) * (z + index + one));
denom = denom + lanczos_coefficient / (z + index + one);
}
// To improve accuracy on platforms with less-precise log implementations,
// compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on
// the device.
// log(t) = log(kLanczosGamma + 0.5 + z)
// = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
XlaOp t = lanczos_gamma_plus_one_half + z;
XlaOp log_t = log_lanczos_gamma_plus_one_half +
Log1p(z / lanczos_gamma_plus_one_half);
// digamma(z + 1), before any reflection is applied.
XlaOp y = log_t + num / denom - lanczos_gamma / t;
// We need to be careful how we compute cot(pi * input) below: For
// near-integral values of `input`, pi * input can lose precision.
//
// Input is already known to be less than 0.5 (otherwise we don't have to
// reflect). We shift values smaller than -0.5 into the range [-.5, .5] to
// increase precision of pi * input and the resulting cotangent.
XlaOp reduced_input = input + Abs(Floor(input + ScalarLike(input, 0.5)));
XlaOp reflection =
y - pi * Cos(pi * reduced_input) / Sin(pi * reduced_input);
XlaOp real_result = Select(need_to_reflect, reflection, y);
// Digamma has poles at negative integers and zero; return nan for those.
return Select(And(Le(input, zero), Eq(input, Floor(input))),
FullLike(input, std::numeric_limits<float>::quiet_NaN()),
real_result);
};
auto& b = *input.builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Digamma", input));
return DoWithUpcastToF32(input, {BF16, F16}, do_it);
});
}
// Incomplete gamma functions
namespace {
// Template argument selecting what IgammaSeries and IgammacContinuedFraction
// return:
//   VALUE             - the function value itself.
//   DERIVATIVE        - a result built from the `dans_da` accumulator
//                       (presumably the partial derivative with respect to
//                       `a` -- TODO confirm).
//   SAMPLE_DERIVATIVE - a differently scaled and negated variant of the
//                       DERIVATIVE result (presumably used for differentiating
//                       gamma samples -- TODO confirm).
enum kIgammaMode { VALUE, DERIVATIVE, SAMPLE_DERIVATIVE };
// Helper function for computing Igamma using a power series.
//
// `ax` is the prefactor supplied by the caller, `x` and `a` are the function
// arguments, `enabled` is a per-element mask of elements to iterate on, and
// `type` is the element type used to pick the convergence epsilon. The series
// terms and their `dc_da` / `dans_da` companions are accumulated in a while
// loop that runs until no element is still enabled.
template <kIgammaMode mode>
XlaOp IgammaSeries(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
xla::PrimitiveType type) {
// vals: (enabled, r, c, ans, x, dc_da, dans_da)
// 'enabled' is a predication mask that says for which elements we should
// execute the loop body. Disabled elements have no effect in the loop body.
// TODO(phawkins): in general this isn't an optimal implementation on any
// backend. For example, on GPU, we should probably vectorize to the warp
// size, and then run independent loops for each warp's worth of
// data.
// Keep looping while at least one element is still enabled.
auto cond = [&](absl::Span<const XlaOp> vals,
XlaBuilder* builder) -> StatusOr<XlaOp> {
XlaOp enabled = vals[0];
return Any(enabled);
};
auto body = [&](absl::Span<const XlaOp> vals,
XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
XlaOp enabled = vals[0];
XlaOp r = vals[1];
XlaOp c = vals[2];
XlaOp ans = vals[3];
XlaOp x = vals[4];
XlaOp dc_da = vals[5];
XlaOp dans_da = vals[6];
// Advance to the next term: c is multiplied by x / r, and dc_da is updated
// using the value c had before this step.
r = r + ScalarLike(r, 1);
dc_da = dc_da * (x / r) + (ScalarLike(r, -1) * c * x) / (r * r);
dans_da = dans_da + dc_da;
c = c * (x / r);
ans = ans + c;
// An element stays enabled while its latest term (or, in the derivative
// modes, its latest derivative term) is still larger than epsilon relative
// to the running sum.
XlaOp conditional;
if (mode == VALUE) {
conditional = And(enabled, Gt(c / ans, Epsilon(builder, type)));
} else {
conditional =
And(enabled, Gt(Abs(dc_da / dans_da), Epsilon(builder, type)));
}
// Elements that were disabled on entry keep their previous state.
return std::vector<XlaOp>{
conditional,
Select(enabled, r, vals[1]),
Select(enabled, c, vals[2]),
Select(enabled, ans, vals[3]),
Select(enabled, x, vals[4]),
Select(enabled, dc_da, vals[5]),
Select(enabled, dans_da, vals[6]),
};
};
auto& b = *ax.builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
// Initial state: r = a, c = 1, ans = 1, dc_da = 0, dans_da = 0.
std::vector<XlaOp> vals = {
enabled, a, FullLike(a, 1), FullLike(a, 1), x, FullLike(a, 0),
FullLike(a, 0),
};
TF_ASSIGN_OR_RETURN(vals, WhileLoopHelper(cond, body, vals, "igamma", &b));
XlaOp ans = vals[3];
XlaOp dans_da = vals[6];
if (mode == VALUE) {
return (ans * ax) / a;
}
XlaOp dlogax_da = Log(x) - Digamma(a + ScalarLike(a, 1));
switch (mode) {
case DERIVATIVE:
return ax * (ans * dlogax_da + dans_da) / a;
case SAMPLE_DERIVATIVE:
default:
return -(dans_da + ans * dlogax_da) * x / a;
}
});
}
// Helper function for computing Igammac using a continued fraction.
//
// `ax` is the prefactor supplied by the caller, `x` and `a` are the function
// arguments, `enabled` is a per-element mask of elements to iterate on, and
// `type` is the element type used to pick the convergence and rescaling
// epsilon. The loop tracks the two most recent numerators (pkm1, pkm2) and
// denominators (qkm1, qkm2) of the fraction, plus matching d*_da companions.
template <kIgammaMode mode>
XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
xla::PrimitiveType type) {
// vals: enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2, dpkm2_da, dqkm2_da,
// dpkm1_da, dqkm1_da, dans_da
// Keep looping while at least one element is still enabled, for at most 2000
// iterations (counted by c).
auto cond = [&](absl::Span<const XlaOp> vals,
XlaBuilder* builder) -> StatusOr<XlaOp> {
XlaOp enabled = vals[0];
XlaOp c = vals[5];
return And(Lt(c, ScalarLike(c, 2000)), Any(enabled));
};
auto body = [&](absl::Span<const XlaOp> vals,
XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
XlaOp enabled = vals[0];
XlaOp ans = vals[1];
XlaOp t = vals[2];
XlaOp y = vals[3];
XlaOp z = vals[4];
XlaOp c = vals[5];
XlaOp pkm1 = vals[6];
XlaOp qkm1 = vals[7];
XlaOp pkm2 = vals[8];
XlaOp qkm2 = vals[9];
XlaOp dpkm2_da = vals[10];
XlaOp dqkm2_da = vals[11];
XlaOp dpkm1_da = vals[12];
XlaOp dqkm1_da = vals[13];
XlaOp dans_da = vals[14];
c = c + ScalarLike(c, 1);
y = y + ScalarLike(y, 1);
z = z + ScalarLike(z, 2);
XlaOp yc = y * c;
// Next numerator and denominator from the previous two.
XlaOp pk = pkm1 * z - pkm2 * yc;
XlaOp qk = qkm1 * z - qkm2 * yc;
XlaOp qk_is_nonzero = Ne(qk, ScalarLike(qk, 0));
XlaOp r = pk / qk;
// t is the relative change of the estimate; where qk is zero the estimate is
// left alone and t is set to 1.
t = Select(qk_is_nonzero, Abs((ans - r) / r), FullLike(t, 1));
ans = Select(qk_is_nonzero, r, ans);
XlaOp dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c;
XlaOp dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c;
XlaOp dans_da_new =
Select(qk_is_nonzero, (dpk_da - ans * dqk_da) / qk, dans_da);
XlaOp grad_conditional =
Select(qk_is_nonzero, Abs(dans_da_new - dans_da), FullLike(dans_da, 1));
pkm2 = pkm1;
pkm1 = pk;
qkm2 = qkm1;
qkm1 = qk;
dpkm2_da = dpkm1_da;
dqkm2_da = dqkm1_da;
dpkm1_da = dpk_da;
dqkm1_da = dqk_da;
// When |pk| exceeds 1 / epsilon, scale all eight recurrence terms by epsilon.
XlaOp rescale = Gt(Abs(pk), Reciprocal(Epsilon(builder, type)));
pkm2 = Select(rescale, pkm2 * Epsilon(builder, type), pkm2);
pkm1 = Select(rescale, pkm1 * Epsilon(builder, type), pkm1);
qkm2 = Select(rescale, qkm2 * Epsilon(builder, type), qkm2);
qkm1 = Select(rescale, qkm1 * Epsilon(builder, type), qkm1);
dpkm2_da = Select(rescale, dpkm2_da * Epsilon(builder, type), dpkm2_da);
dqkm2_da = Select(rescale, dqkm2_da * Epsilon(builder, type), dqkm2_da);
dpkm1_da = Select(rescale, dpkm1_da * Epsilon(builder, type), dpkm1_da);
dqkm1_da = Select(rescale, dqkm1_da * Epsilon(builder, type), dqkm1_da);
// An element stays enabled while its change (t in VALUE mode, otherwise the
// change in dans_da) is still larger than epsilon.
XlaOp conditional;
if (mode == VALUE) {
conditional = And(enabled, Gt(t, Epsilon(builder, type)));
} else {
conditional = And(enabled, Gt(grad_conditional, Epsilon(builder, type)));
}
// Elements that were disabled on entry keep their previous state. The
// iteration counter c is the exception: it always advances, because cond uses
// it as a global iteration cap.
return std::vector<XlaOp>{conditional,
Select(enabled, ans, vals[1]),
Select(enabled, t, vals[2]),
Select(enabled, y, vals[3]),
Select(enabled, z, vals[4]),
c,
Select(enabled, pkm1, vals[6]),
Select(enabled, qkm1, vals[7]),
Select(enabled, pkm2, vals[8]),
Select(enabled, qkm2, vals[9]),
Select(enabled, dpkm2_da, vals[10]),
Select(enabled, dqkm2_da, vals[11]),
Select(enabled, dpkm1_da, vals[12]),
Select(enabled, dqkm1_da, vals[13]),
Select(enabled, dans_da_new, vals[14])};
};
auto& b = *ax.builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
// Initial values of the recurrence and of its d*_da companions.
XlaOp y = ScalarLike(a, 1) - a;
XlaOp z = x + y + ScalarLike(x, 1);
XlaOp c = ScalarLike(x, 0);
XlaOp pkm2 = FullLike(x, 1);
XlaOp qkm2 = x;
XlaOp pkm1 = x + ScalarLike(x, 1);
XlaOp qkm1 = z * x;
XlaOp ans = pkm1 / qkm1;
XlaOp t = FullLike(x, 1);
XlaOp dpkm2_da = FullLike(x, 0);
XlaOp dqkm2_da = FullLike(x, 0);
XlaOp dpkm1_da = FullLike(x, 0);
XlaOp dqkm1_da = -x;
XlaOp dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1;
std::vector<XlaOp> vals = {enabled, ans, t, y, z,
c, pkm1, qkm1, pkm2, qkm2,
dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da};
TF_ASSIGN_OR_RETURN(vals, WhileLoopHelper(cond, body, vals, "igammac", &b));
ans = vals[1];
if (mode == VALUE) {
return ans * ax;
}
dans_da = vals[14];
XlaOp dlogax_da = Log(x) - Digamma(a);
switch (mode) {
case DERIVATIVE:
return ax * (ans * dlogax_da + dans_da);
case SAMPLE_DERIVATIVE:
default:
return -(dans_da + ans * dlogax_da) * x;
}
});
}
} // namespace
// Computes Igamma(a, x) elementwise, combining IgammaSeries and
// IgammacContinuedFraction.
//
// `a` and `x` must have identical shapes and a floating-point element type;
// otherwise an error is reported on the builder. F16 and BF16 inputs are
// computed in F32 and converted back. The result is 0 where x == 0, 1 where
// x == +inf, and NaN where x < 0, a <= 0, or either input is NaN.
XlaOp Igamma(XlaOp a, XlaOp x) {
  XlaBuilder& builder = *a.builder();
  auto compute = [&builder](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
    XlaOp any_nan = Or(IsNan(a), IsNan(x));
    XlaOp x_zero = Eq(x, ScalarLike(x, 0));
    XlaOp x_infinite =
        Eq(x, ScalarLike(x, std::numeric_limits<float>::infinity()));
    XlaOp out_of_domain = Or(Lt(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0)));
    // The continued fraction is used where x > 1 and x > a; the power series
    // is used everywhere else.
    XlaOp prefer_igammac = And(Gt(x, ScalarLike(x, 1)), Gt(x, a));
    XlaOp log_ax = a * Log(x) - x - Lgamma(a);
    XlaOp underflows = Lt(log_ax, -Log(MaxFiniteValue(&builder, type)));
    XlaOp ax = Exp(log_ax);
    // Elements with an already-known answer do not take part in the loops.
    XlaOp active =
        Not(Or(Or(Or(x_zero, out_of_domain), underflows), any_nan));
    const double nan = std::numeric_limits<double>::quiet_NaN();
    XlaOp result = Select(
        prefer_igammac,
        ScalarLike(a, 1) - IgammacContinuedFraction<VALUE>(
                               ax, x, a, And(active, prefer_igammac), type),
        IgammaSeries<VALUE>(ax, x, a, And(active, Not(prefer_igammac)), type));
    result = Select(x_zero, ZerosLike(result), result);
    result = Select(x_infinite, FullLike(result, 1), result);
    return Select(Or(out_of_domain, any_nan), FullLike(a, nan), result);
  };
  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto a_shape, builder.GetShape(a));
    TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x));
    if (a_shape != x_shape) {
      return InvalidArgument(
          "Arguments to Igamma must have equal shapes and types; got %s and %s",
          a_shape.ToString(), x_shape.ToString());
    }
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a));
    const PrimitiveType original_type = a_shape.element_type();
    const bool needs_upcast = original_type == F16 || original_type == BF16;
    if (!needs_upcast) {
      return compute(a, x, original_type);
    }
    // Half-precision inputs: run the whole computation in F32.
    XlaOp a_f32 = ConvertElementType(a, F32);
    XlaOp x_f32 = ConvertElementType(x, F32);
    XlaOp result_f32 = compute(a_f32, x_f32, F32);
    return ConvertElementType(result_f32, original_type);
  });
}
// Computes the DERIVATIVE-mode result of the incomplete gamma evaluation for
// (a, x), elementwise.
//
// `a` and `x` must have identical shapes and a real floating-point element
// type. F16 and BF16 operands are evaluated in F32 and the result is converted
// back to the original element type.
XlaOp IgammaGradA(XlaOp a, XlaOp x) {
  auto& b = *a.builder();
  // `type` must be the element type of the operands actually passed in (i.e.
  // the upcast type when an upcast happened); it selects the MaxFiniteValue
  // used here and is forwarded to the series / continued-fraction helpers.
  auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
    XlaOp is_nan = Or(IsNan(a), IsNan(x));
    XlaOp x_is_zero = Eq(x, ScalarLike(x, 0));
    XlaOp domain_error = Or(Lt(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0)));
    XlaOp use_igammac = And(Gt(x, ScalarLike(x, 1)), Gt(x, a));
    XlaOp ax = a * Log(x) - x - Lgamma(a);
    XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type)));
    ax = Exp(ax);
    XlaOp enabled = Not(Or(Or(Or(x_is_zero, domain_error), underflow), is_nan));
    const double nan = std::numeric_limits<double>::quiet_NaN();
    XlaOp output = Select(use_igammac,
                          -IgammacContinuedFraction<DERIVATIVE>(
                              ax, x, a, And(enabled, use_igammac), type),
                          IgammaSeries<DERIVATIVE>(
                              ax, x, a, And(enabled, Not(use_igammac)), type));
    // x == 0 yields 0; domain errors and NaN inputs yield NaN.
    output = Select(x_is_zero, ZerosLike(output), output);
    output = Select(Or(domain_error, is_nan), FullLike(a, nan), output);
    return output;
  };
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a));
    TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x));
    if (a_shape != x_shape) {
      return InvalidArgument(
          "Arguments to IgammaGradA must have equal shapes and types; got %s "
          "and %s",
          a_shape.ToString(), x_shape.ToString());
    }
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a));
    // Track the element type the computation really runs in, so that the
    // type-dependent constants (max finite value, epsilon) match the F32
    // operands after an upcast, consistent with Igamma and Igammac.
    PrimitiveType a_x_type = a_shape.element_type();
    bool needs_upcast =
        a_shape.element_type() == F16 || a_shape.element_type() == BF16;
    if (needs_upcast) {
      a = ConvertElementType(a, F32);
      x = ConvertElementType(x, F32);
      a_x_type = F32;
    }
    XlaOp result = doit(a, x, a_x_type);
    if (needs_upcast) {
      result = ConvertElementType(result, a_shape.element_type());
    }
    return result;
  });
}
// Gradient of Gamma sample from Gamma(a, 1) with respect to `a`.
//
// `a` and `x` must have identical shapes and a real floating-point element
// type. F16 and BF16 operands are evaluated in F32 and the result is converted
// back to the original element type.
XlaOp RandomGammaGrad(XlaOp a, XlaOp x) {
  auto& b = *a.builder();
  // `type` must be the element type of the operands actually passed in (i.e.
  // the upcast type when an upcast happened); it selects the MaxFiniteValue
  // used here and is forwarded to the series / continued-fraction helpers.
  auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
    XlaOp is_nan = Or(IsNan(a), IsNan(x));
    XlaOp x_is_zero = Eq(x, ScalarLike(x, 0));
    XlaOp domain_error = Or(Lt(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0)));
    XlaOp use_igammac = And(Gt(x, ScalarLike(x, 1)), Gt(x, a));
    XlaOp ax = a * Log(x) - x - Lgamma(a);
    XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type)));
    ax = Exp(ax);
    XlaOp enabled = Not(Or(Or(Or(x_is_zero, domain_error), underflow), is_nan));
    const double nan = std::numeric_limits<double>::quiet_NaN();
    XlaOp output = Select(use_igammac,
                          -IgammacContinuedFraction<SAMPLE_DERIVATIVE>(
                              ax, x, a, And(enabled, use_igammac), type),
                          IgammaSeries<SAMPLE_DERIVATIVE>(
                              ax, x, a, And(enabled, Not(use_igammac)), type));
    // x == 0 yields 0; domain errors and NaN inputs yield NaN.
    output = Select(x_is_zero, ZerosLike(output), output);
    output = Select(Or(domain_error, is_nan), FullLike(a, nan), output);
    return output;
  };
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a));
    TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x));
    if (a_shape != x_shape) {
      return InvalidArgument(
          "Arguments to RandomGammaGrad must have equal shapes and types; got "
          "%s and %s",
          a_shape.ToString(), x_shape.ToString());
    }
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("RandomGammaGrad", a));
    // Track the element type the computation really runs in, so that the
    // type-dependent constants (max finite value, epsilon) match the F32
    // operands after an upcast, consistent with Igamma and Igammac.
    PrimitiveType a_x_type = a_shape.element_type();
    bool needs_upcast =
        a_shape.element_type() == F16 || a_shape.element_type() == BF16;
    if (needs_upcast) {
      a = ConvertElementType(a, F32);
      x = ConvertElementType(x, F32);
      a_x_type = F32;
    }
    XlaOp result = doit(a, x, a_x_type);
    if (needs_upcast) {
      result = ConvertElementType(result, a_shape.element_type());
    }
    return result;
  });
}
// Computes Igammac(a, x) elementwise.
//
// `a` and `x` must have identical shapes and a real floating-point element
// type. F16 and BF16 operands are evaluated in F32 and the result is converted
// back to the original element type.
XlaOp Igammac(XlaOp a, XlaOp x) {
  XlaBuilder& b = *a.builder();
  // Builds the graph for operands whose element type is `compute_type`.
  auto build = [&b](XlaOp a, XlaOp x, PrimitiveType compute_type) -> XlaOp {
    XlaOp outside_domain =
        Or(Le(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0)));
    // The series (via 1 - series) is used when x < 1 or x < a; otherwise the
    // continued fraction is used directly.
    XlaOp via_series = Or(Lt(x, ScalarLike(x, 1)), Lt(x, a));
    XlaOp log_prefactor = a * Log(x) - x - Lgamma(a);
    XlaOp underflows =
        Lt(log_prefactor, -Log(MaxFiniteValue(&b, compute_type)));
    XlaOp active = Not(Or(outside_domain, underflows));
    XlaOp prefactor = Exp(log_prefactor);

    XlaOp series_branch =
        ScalarLike(a, 1) -
        IgammaSeries<VALUE>(prefactor, x, a, And(active, via_series),
                            compute_type);
    XlaOp fraction_branch = IgammacContinuedFraction<VALUE>(
        prefactor, x, a, And(active, Not(via_series)), compute_type);
    XlaOp value = Select(via_series, series_branch, fraction_branch);

    // x == +inf -> 0; out-of-range inputs (x <= 0 or a <= 0) -> 1.
    XlaOp x_inf =
        Eq(x, ScalarLike(x, std::numeric_limits<float>::infinity()));
    value = Select(x_inf, ZerosLike(value), value);
    return Select(outside_domain, FullLike(a, 1), value);
  };
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a));
    TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x));
    if (a_shape != x_shape) {
      return InvalidArgument(
          "Arguments to Igammac must have equal shapes and types; "
          "got %s and %s",
          a_shape.ToString(), x_shape.ToString());
    }
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igammac", a));
    const PrimitiveType original_type = a_shape.element_type();
    const bool upcast = original_type == F16 || original_type == BF16;
    const PrimitiveType compute_type = upcast ? F32 : original_type;
    if (upcast) {
      a = ConvertElementType(a, F32);
      x = ConvertElementType(x, F32);
    }
    XlaOp result = build(a, x, compute_type);
    if (!upcast) {
      return result;
    }
    return ConvertElementType(result, original_type);
  });
}
// Banker's rounding: values exactly halfway between two integers round to the
// even neighbour; all other values round to the nearest integer.
XlaOp RoundToEven(XlaOp x) {
  XlaBuilder& builder = *x.builder();
  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Only real floating-point operands are accepted; rounding a complex
    // number has no single obvious meaning, so callers must round the
    // components explicitly if that is what they want.
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("RoundToEven", x));
    XlaOp half = ScalarLike(x, 0.5);
    XlaOp one = ScalarLike(x, 1.0);
    XlaOp two = ScalarLike(x, 2.0);

    XlaOp floor_x = Floor(x);
    XlaOp frac = x - floor_x;
    // floor(x) - 2 * floor(x / 2) is 1 exactly when floor(x) is odd.
    XlaOp floor_mod_two = floor_x - two * Floor(half * x);
    XlaOp floor_is_odd = Eq(floor_mod_two, one);

    // Round up when strictly past the midpoint, or on a tie whose lower
    // neighbour is odd (so that the result lands on the even integer).
    XlaOp past_midpoint = Gt(frac, half);
    XlaOp tie_from_odd = And(Eq(frac, half), floor_is_odd);
    XlaOp round_up = Or(past_midpoint, tie_from_odd);
    return Select(round_up, floor_x + one, floor_x);
  });
}
// Trigonometric functions.

// Real inputs:
//   acos(x) = 2 * atan2(sqrt(1 - x^2), 1 + x)   if x != -1
//           = pi                                if x == -1
// Complex inputs:
//   acos(x) = -(i * log(x + i * sqrt((1 + x) * (1 - x))))
XlaOp Acos(XlaOp x) {
  XlaBuilder* b = x.builder();
  return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
    const PrimitiveType element_type = shape.element_type();

    if (!primitive_util::IsComplexType(element_type)) {
      // x == -1 is special-cased to pi; everywhere else the half-angle form
      // is used.
      XlaOp not_minus_one = Ne(x, FullLike(x, -1));
      XlaOp half_angle = Atan2(Sqrt(ScalarLike(x, 1.0) - x * x),
                               ScalarLike(x, 1.0) + x);
      return Select(not_minus_one, ScalarLike(x, 2.0) * half_angle,
                    FullLike(x, M_PI));
    }

    XlaOp one = ScalarLike(x, 1);
    // The imaginary unit, built from the complex type's component type.
    XlaOp imag_unit =
        Complex(Zero(b, primitive_util::ComplexComponentType(element_type)),
                One(b, primitive_util::ComplexComponentType(element_type)));
    XlaOp root = Sqrt((one + x) * (one - x));
    return Neg(imag_unit * Log(x + imag_unit * root));
  });
}
// asin(x) = 2 * atan2(x, 1 + sqrt(1 - x^2))
XlaOp Asin(XlaOp x) {
  XlaOp denominator =
      ScalarLike(x, 1.0) + Sqrt(ScalarLike(x, 1.0) - x * x);
  XlaOp half_angle = Atan2(x, denominator);
  return ScalarLike(x, 2.0) * half_angle;
}
// atan(x) = atan2(x, 1)
XlaOp Atan(XlaOp x) {
  XlaOp unit = ScalarLike(x, 1.0);
  return Atan2(x, unit);
}
// tan(x) = sin(x) / cos(x). F16 operands are routed through
// DoWithUpcastToF32.
XlaOp Tan(XlaOp x) {
  auto sin_over_cos = [](XlaOp operand) {
    return Sin(operand) / Cos(operand);
  };
  return DoWithUpcastToF32(x, {F16}, sin_over_cos);
}
// Hyperbolic trigonometric functions.

// acosh(x) = log(x + sqrt(x^2 - 1))
//          = log(x + sqrt((x+1)*(x-1)))
//
// For real inputs the result is explicitly nan when x < -1. When x^2 would
// overflow, sqrt(x^2 - 1) is approximated by x and the result is computed as
// log(2*x) = log(2) + log(x). (This is safe because a negative x never reaches
// the overflow branch; x < -1 simply yields nan. This is quite different from
// asinh!)
XlaOp Acosh(XlaOp x) {
  XlaBuilder* b = x.builder();
  return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
    const PrimitiveType element_type = shape.element_type();

    XlaOp one = ScalarLike(x, 1);
    XlaOp minus_one = ScalarLike(x, -1);
    XlaOp nan_result = FullLike(x, std::numeric_limits<float>::quiet_NaN());

    // Result selection for real inputs:
    //
    //   nan                            if x < -1
    //   log(x) + log(2)                if x >= sqrt(max finite value)
    //   log(x + sqrt((x+1)*(x-1)))     otherwise
    //
    // TODO(jlebar): For now, we ignore the question of overflow if x is a
    // complex type, because we don't yet have exhaustive tests for complex
    // trig functions.
    XlaOp direct = Log(x + Sqrt((x + one) * (x - one)));
    if (primitive_util::IsComplexType(element_type)) {
      return direct;
    }

    XlaOp large_x = Log(x) + Log(ScalarLike(x, 2));
    XlaOp overflow_bound = Sqrt(MaxFiniteValue(b, element_type));
    XlaOp in_domain = Select(Ge(x, overflow_bound), large_x, direct);
    return Select(Lt(x, minus_one), nan_result, in_domain);
  });
}
// asinh(x) = log(x + sqrt(x^2 + 1))
//
// When x^2 would overflow and x is positive, x + sqrt(x^2 + 1) is approximated
// by 2*x and the result is log(2) + log(x).
//
// For negative x that approximation would collapse to x + abs(x) = 0, so the
// computation is done on abs(x) and the sign is restored afterwards, using
// asinh(-x) = -asinh(x).
XlaOp Asinh(XlaOp x) {
  XlaBuilder* b = x.builder();
  auto build = [&](XlaOp x) -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
    const PrimitiveType element_type = shape.element_type();
    XlaOp one = ScalarLike(x, 1);

    // TODO(jlebar): For now, we ignore the question of overflow if x is a
    // complex type, because we don't yet have exhaustive tests for complex
    // trig functions.
    if (primitive_util::IsComplexType(element_type)) {
      return Log(x + Sqrt(x * x + one));
    }

    // Real inputs. With a = abs(x), three formulas are combined:
    //
    //   a >= sqrt(max finite value):  log(a) + log(2)
    //   a <= 1:                       log1p(a + a^2 / (1 + sqrt(a^2 + 1)))
    //   otherwise:                    log(a + sqrt(a^2 + 1))
    //
    // and the selected value is multiplied by sign(x).
    //
    // The a <= 1 form exists because for small a, sqrt(a^2 + 1) rounds to 1
    // and the low-order term (about 0.5 * a^2) would be lost. With
    // z = sqrt(a^2 + 1):
    //   log(a + z) = log((a + z) * (1 + z) / (1 + z))
    //              = log((a + a^2 + 1 + a * z + z) / (1 + z))
    //              = log(1 + a + a^2 / (1 + z)),
    // which keeps that term.
    XlaOp magnitude = Abs(x);
    XlaOp near_zero = Log1p(
        magnitude + magnitude * magnitude /
                        (one + Sqrt(magnitude * magnitude + one)));
    XlaOp direct = Log(magnitude + Sqrt(magnitude * magnitude + one));
    XlaOp huge = Log(Abs(magnitude)) + Log(ScalarLike(magnitude, 2));
    XlaOp overflow_bound = Sqrt(MaxFiniteValue(b, element_type));

    XlaOp below_bound = Select(Le(magnitude, one), near_zero, direct);
    XlaOp unsigned_result =
        Select(Ge(magnitude, overflow_bound), huge, below_bound);
    return Sign(x) * unsigned_result;
  };
  // These upcasts are not strictly necessary on all platforms to get within
  // our error tolerances, so this could be relaxed if it ever mattered.
  auto report = [&](XlaOp operand) {
    return b->ReportErrorOrReturn(build(operand));
  };
  return DoWithUpcastToF32(x, {BF16, F16}, report);
}
// atanh(x) = 0.5 * log((1 + x) / (1 - x)) = 0.5 * (log1p(x) - log1p(-x))
//
// For real inputs the result is explicitly nan when abs(x) > 1.
XlaOp Atanh(XlaOp x) {
  XlaBuilder* b = x.builder();
  auto build = [&](XlaOp x) -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
    XlaOp direct = (Log1p(x) - Log1p(-x)) * ScalarLike(x, 0.5);

    // TODO(jlebar): For now, we ignore the nan edge case for complex inputs,
    // because we don't yet have exhaustive tests for complex trig functions.
    if (primitive_util::IsComplexType(shape.element_type())) {
      return direct;
    }

    XlaOp nan_result = FullLike(x, std::numeric_limits<float>::quiet_NaN());
    XlaOp outside_domain = Gt(Abs(x), ScalarLike(x, 1));
    return Select(outside_domain, nan_result, direct);
  };
  auto report = [&](XlaOp operand) {
    return b->ReportErrorOrReturn(build(operand));
  };
  return DoWithUpcastToF32(x, {BF16}, report);
}
// Cosh(x) = (e^x + e^-x) / 2
//         = e^(x + log(1/2)) + e^(-x + log(1/2)).
//
// Folding the division into the exponent avoids overflow in the case where
// e^x is inf but (e^x)/2 is still finite.
//
// Two f32 inputs, +/-89.4159851, still incorrectly overflow to inf because of
// rounding error in x +/- log(1/2). The correct answer there, 3.40281961e+38
// (0x7f7fffec), is very close to max-float, so this is deemed acceptable.
XlaOp Cosh(XlaOp x) {
  auto build = [](XlaOp operand) {
    XlaOp log_half = Log(ScalarLike(operand, 0.5));
    XlaOp growing_term = Exp(operand + log_half);
    XlaOp decaying_term = Exp(-operand + log_half);
    return growing_term + decaying_term;
  };
  return DoWithUpcastToF32(x, {BF16, F16}, build);
}
// Sinh(x) = (e^x - e^-x) / 2
//         = e^(x + log(1/2)) - e^(-x + log(1/2)).
//
// Folding the division into the exponent avoids overflow in the case where
// e^x is inf but (e^x)/2 is still finite.
//
// Two f32 inputs, +/-89.4159851, still incorrectly overflow to +/-inf because
// of rounding error in x +/- log(1/2). The correct answer there,
// 3.40281961e+38 (0x7f7fffec), is very close to max-float, so this is deemed
// acceptable.
XlaOp Sinh(XlaOp x) {
  XlaBuilder* b = x.builder();
  auto build = [&](XlaOp x) -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
    XlaOp half = ScalarLike(x, 0.5);
    XlaOp log_half = Log(ScalarLike(x, 0.5));
    // Used for complex inputs and for real inputs with abs(x) >= 1.
    XlaOp exp_difference = Exp(x + log_half) - Exp(-x + log_half);
    if (primitive_util::IsComplexType(shape.element_type())) {
      return exp_difference;
    }

    // For abs(x) < 1, e^x - e^-x suffers cancellation and can collapse to 0.
    // expm1(x) preserves the first-order term of the Taylor series of e^x, so
    // the difference is rewritten as:
    //   (e^x - e^-x) / 2
    //     = (e^x - 1 + 1 - e^-x) / 2
    //     = (expm1(x) + (e^x - 1) / e^x) / 2
    //     = (expm1(x) + expm1(x) / (expm1(x) + 1)) / 2.
    XlaOp em1 = Expm1(x);
    XlaOp one = ScalarLike(x, 1.);
    XlaOp cancellation_free = half * (em1 + em1 / (em1 + one));

    XlaOp is_small = Lt(Abs(x), one);
    return Select(is_small, cancellation_free, exp_difference);
  };
  auto report = [&](XlaOp operand) {
    return b->ReportErrorOrReturn(build(operand));
  };
  return DoWithUpcastToF32(x, {BF16, F16}, report);
}
// Returns Conj(x) when `conjugate` is true and `x` has a complex element type;
// otherwise returns `x` unchanged.
XlaOp MaybeConjugate(XlaOp x, bool conjugate) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const bool is_complex =
        primitive_util::IsComplexType(shape.element_type());
    if (is_complex && conjugate) {
      return Conj(x);
    }
    return x;
  });
}
// Elementwise "next representable value after `from` in the direction of
// `to`", computed on the bit patterns of the operands.
//
// Both operands are bitcast to the unsigned integer type with the same bit
// width as `from`'s element type. A candidate result is formed by adding +1 or
// -1 to `from`'s bit pattern, and is then overridden, in increasing priority,
// by the cases: `from` is +/-0, `from` and `to` are bitwise equal, and either
// input is NaN.
//
// NOTE(review): only `from`'s shape is queried; `to` is presumably expected to
// have the same shape and element type - confirm against callers.
XlaOp NextAfter(XlaOp from, XlaOp to) {
auto builder = from.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(from));
int bitwidth = primitive_util::BitWidth(shape.element_type());
auto int_type = primitive_util::UnsignedIntegralTypeForBitWidth(bitwidth);
auto from_as_int = BitcastConvertType(from, int_type);
auto to_as_int = BitcastConvertType(to, int_type);
// The result is NaN if either "from" or "to" are NaN.
// (x != x holds only for NaN, so Ne(v, v) is the NaN test.)
auto from_is_nan = Ne(from, from);
auto to_is_nan = Ne(to, to);
auto nan_input = Or(from_is_nan, to_is_nan);
auto result_for_nan =
Broadcast(ScalarLike(from, std::numeric_limits<double>::quiet_NaN()),
shape.dimensions());
result_for_nan = BitcastConvertType(result_for_nan, int_type);
// The sign bit is the MSB.
const int64_t sign_mask = int64_t{1} << (bitwidth - 1);
// Discard the sign bit to make the result non-negative.
auto from_abs = And(from_as_int, ScalarLike(from_as_int, ~sign_mask));
auto to_abs = And(to_as_int, ScalarLike(to_as_int, ~sign_mask));
// When both "from" and "to" are equal, the result is "to".
// N.B. It would not make a difference if we chose the result to be "from".
auto from_and_to_are_equal = Eq(from_as_int, to_as_int);
auto result_for_equal = to_as_int;
// When both "from" and "to" are both 0, the result is "to". This ensures we
// get a zero signed like "to".
auto from_is_zero = Eq(from_abs, ZerosLike(from_abs));
auto to_is_zero = Eq(to_abs, ZerosLike(to_abs));
auto result_for_both_zero = to_as_int;
// Isolate the sign bits so the signs of the two operands can be compared.
auto from_sign = And(from_as_int, ScalarLike(from_as_int, sign_mask));
auto to_sign = And(to_as_int, ScalarLike(to_as_int, sign_mask));
// If from == 0 && to != 0, we need to return the smallest subnormal number
// signed like "to".
auto result_for_from_zero_to_non_zero =
Or(to_sign, ScalarLike(from_as_int, 1));
// If the sign of "from" and "to" disagree:
// - we need to make the magnitude of "from" smaller so that it is closer to
// zero.
//
// Otherwise the signs agree:
// - "from" with a magnitude larger than "to" means we need to make the
// magnitude smaller.
// - "from" with a magnitude smaller than "to" means we need to make the
// magnitude larger.
// - "from" with the same magnitude and sign as "to" has already been
// handled.
auto signs_disagree = Ne(from_sign, to_sign);
auto from_magnitude_larger_than_to = Gt(from_abs, to_abs);
auto result_has_smaller_magnitude =
Or(from_magnitude_larger_than_to, signs_disagree);
// Adding -1 / +1 to the bit pattern moves the magnitude down / up by one
// step.
auto magnitude_adjustment =
Select(result_has_smaller_magnitude,
Broadcast(ScalarLike(from_as_int, -1), shape.dimensions()),
Broadcast(ScalarLike(from_as_int, 1), shape.dimensions()));
auto result = Add(from_as_int, magnitude_adjustment);
// The overrides below are applied in increasing priority: a later Select
// wins over an earlier one.
// Handle from == ±0.
result = Select(from_is_zero,
Select(to_is_zero, result_for_both_zero,
result_for_from_zero_to_non_zero),
result);
// Handle from == to.
result = Select(from_and_to_are_equal, result_for_equal, result);
// Handle isnan(from) || isnan(to).
result = Select(nan_input, result_for_nan, result);
// Cast back to the original type.
return BitcastConvertType(result, shape.element_type());
});
}
// Approximates the modified Bessel function of the first kind, zeroth order,
// following Cephes' F32 implementation of i0e.
//
// Two Chebyshev expansions are evaluated on abs(x): one for abs(x) <= 8 and
// one (divided by sqrt(abs(x))) for abs(x) > 8; the result is selected
// elementwise.
static XlaOp I0eImpl32(XlaOp x) {
  static const std::array<float, 18> kI0eCoeffsA{
      -1.30002500998624804212E-8f, 6.04699502254191894932E-8f,
      -2.67079385394061173391E-7f, 1.11738753912010371815E-6f,
      -4.41673835845875056359E-6f, 1.64484480707288970893E-5f,
      -5.75419501008210370398E-5f, 1.88502885095841655729E-4f,
      -5.76375574538582365885E-4f, 1.63947561694133579842E-3f,
      -4.32430999505057594430E-3f, 1.05464603945949983183E-2f,
      -2.37374148058994688156E-2f, 4.93052842396707084878E-2f,
      -9.49010970480476444210E-2f, 1.71620901522208775349E-1f,
      -3.04682672343198398683E-1f, 6.76795274409476084995E-1f};
  static const std::array<float, 7> kI0eCoeffsB{
      3.39623202570838634515E-9f, 2.26666899049817806459E-8f,
      2.04891858946906374183E-7f, 2.89137052083475648297E-6f,
      6.88975834691682398426E-5f, 3.36911647825569408990E-3f,
      8.04490411014108831608E-1f};

  XlaOp magnitude = Abs(x);
  XlaOp half = ScalarLike(magnitude, 0.5);
  XlaOp two = ScalarLike(magnitude, 2.0);
  XlaOp thirty_two = ScalarLike(magnitude, 32.0);

  // abs(x) <= 8: expansion in (abs(x) / 2 - 2).
  XlaOp small_arg = half * magnitude - two;
  XlaOp small_branch =
      EvaluateChebyshevPolynomial<float>(small_arg, kI0eCoeffsA);

  // abs(x) > 8: expansion in (32 / abs(x) - 2), scaled by 1 / sqrt(abs(x)).
  XlaOp large_arg = thirty_two / magnitude - two;
  XlaOp large_branch =
      EvaluateChebyshevPolynomial<float>(large_arg, kI0eCoeffsB) /
      Sqrt(magnitude);

  XlaOp is_small = Le(magnitude, ScalarLike(magnitude, 8.0));
  return Select(is_small, small_branch, large_branch);
}
// Double-precision counterpart of I0eImpl32: evaluates two Chebyshev
// expansions on abs(x), one for abs(x) <= 8 and one (divided by sqrt(abs(x)))
// for abs(x) > 8, and selects between them elementwise.
static XlaOp I0eImpl64(XlaOp x) {
// Chebyshev coefficients for the abs(x) <= 8 branch.
static const std::array<double, 30> kI0eCoeffsA{
-4.41534164647933937950E-18, 3.33079451882223809783E-17,
-2.43127984654795469359E-16, 1.71539128555513303061E-15,
-1.16853328779934516808E-14, 7.67618549860493561688E-14,
-4.85644678311192946090E-13, 2.95505266312963983461E-12,
-1.72682629144155570723E-11, 9.67580903537323691224E-11,
-5.18979560163526290666E-10, 2.65982372468238665035E-9,
-1.30002500998624804212E-8, 6.04699502254191894932E-8,
-2.67079385394061173391E-7, 1.11738753912010371815E-6,
-4.41673835845875056359E-6, 1.64484480707288970893E-5,
-5.75419501008210370398E-5, 1.88502885095841655729E-4,
-5.76375574538582365885E-4, 1.63947561694133579842E-3,
-4.32430999505057594430E-3, 1.05464603945949983183E-2,
-2.37374148058994688156E-2, 4.93052842396707084878E-2,
-9.49010970480476444210E-2, 1.71620901522208775349E-1,
-3.04682672343198398683E-1, 6.76795274409476084995E-1};
// Chebyshev coefficients for the abs(x) > 8 branch.
static const std::array<double, 25> kI0eCoeffsB{
-7.23318048787475395456E-18, -4.83050448594418207126E-18,
4.46562142029675999901E-17, 3.46122286769746109310E-17,
-2.82762398051658348494E-16, -3.42548561967721913462E-16,
1.77256013305652638360E-15, 3.81168066935262242075E-15,
-9.55484669882830764870E-15, -4.15056934728722208663E-14,
1.54008621752140982691E-14, 3.85277838274214270114E-13,
7.18012445138366623367E-13, -1.79417853150680611778E-12,
-1.32158118404477131188E-11, -3.14991652796324136454E-11,
1.18891471078464383424E-11, 4.94060238822496958910E-10,
3.39623202570838634515E-9, 2.26666899049817806459E-8,
2.04891858946906374183E-7, 2.89137052083475648297E-6,
6.88975834691682398426E-5, 3.36911647825569408990E-3,
8.04490411014108831608E-1};
// Only the magnitude of x is used from here on.
x = Abs(x);
auto half = xla::ScalarLike(x, 0.5);
auto two = xla::ScalarLike(x, 2.0);
auto thirty_two = xla::ScalarLike(x, 32.0);
// abs(x) <= 8: expansion in (abs(x) / 2 - 2).
auto result_le_8 =
EvaluateChebyshevPolynomial<double>(half * x - two, kI0eCoeffsA);
// abs(x) > 8: expansion in (32 / abs(x) - 2), scaled by 1 / sqrt(abs(x)).
auto result_gt_8 =
EvaluateChebyshevPolynomial<double>(thirty_two / x - two, kI0eCoeffsB) /
Sqrt(x);
// Both branches are built for every element; Select picks per element.
return Select(Le(x, xla::ScalarLike(x, 8.0)), result_le_8, result_gt_8);
}
// Public entry point for i0e. Requires a real floating-point operand. F64
// operands use the double-precision implementation; everything else uses the
// single-precision implementation.
XlaOp BesselI0e(XlaOp x) {
  XlaBuilder& builder = *x.builder();
  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("BesselI0e", x));
    TF_ASSIGN_OR_RETURN(auto shape, builder.GetShape(x));
    if (shape.element_type() != F64) {
      // I0eImpl32 does not have enough precision when run with bf16
      // intermediates (not surprising!), so BF16 and F16 operands go through
      // DoWithUpcastToF32.
      auto single_precision = [](XlaOp operand) { return I0eImpl32(operand); };
      return DoWithUpcastToF32(x, {BF16, F16}, single_precision);
    }
    return I0eImpl64(x);
  });
}
// Approximates the modified Bessel function of the first kind, first order,
// following Cephes' F32 implementation of i1e.
//
// Two Chebyshev expansions are evaluated on abs(x): one (multiplied by abs(x))
// for abs(x) <= 8 and one (divided by sqrt(abs(x))) for abs(x) > 8. The
// selected value is multiplied by sign(x).
static XlaOp I1eImpl32(XlaOp x) {
  static const std::array<float, 17> kI1eCoeffsA{
      9.38153738649577178388E-9f, -4.44505912879632808065E-8f,
      2.00329475355213526229E-7f, -8.56872026469545474066E-7f,
      3.47025130813767847674E-6f, -1.32731636560394358279E-5f,
      4.78156510755005422638E-5f, -1.61760815825896745588E-4f,
      5.12285956168575772895E-4f, -1.51357245063125314899E-3f,
      4.15642294431288815669E-3f, -1.05640848946261981558E-2f,
      2.47264490306265168283E-2f, -5.29459812080949914269E-2f,
      1.02643658689847095384E-1f, -1.76416518357834055153E-1f,
      2.52587186443633654823E-1f};
  static const std::array<float, 7> kI1eCoeffsB{
      -3.83538038596423702205E-9f, -2.63146884688951950684E-8f,
      -2.51223623787020892529E-7f, -3.88256480887769039346E-6f,
      -1.10588938762623716291E-4f, -9.76109749136146840777E-3f,
      7.78576235018280120474E-1f};

  XlaOp magnitude = Abs(x);
  XlaOp half = ScalarLike(x, 0.5);
  XlaOp two = ScalarLike(x, 2.0);
  XlaOp thirty_two = ScalarLike(x, 32.0);

  // abs(x) <= 8: abs(x) times the expansion in (abs(x) / 2 - 2).
  XlaOp small_arg = half * magnitude - two;
  XlaOp small_branch =
      magnitude * EvaluateChebyshevPolynomial<float>(small_arg, kI1eCoeffsA);

  // abs(x) > 8: expansion in (32 / abs(x) - 2), scaled by 1 / sqrt(abs(x)).
  XlaOp large_arg = thirty_two / magnitude - two;
  XlaOp large_branch =
      EvaluateChebyshevPolynomial<float>(large_arg, kI1eCoeffsB) /
      Sqrt(magnitude);

  XlaOp is_small = Le(magnitude, ScalarLike(x, 8.0));
  // Restore the sign of the input.
  return Sign(x) * Select(is_small, small_branch, large_branch);
}
// Double-precision counterpart of I1eImpl32: evaluates two Chebyshev
// expansions on abs(x), one (multiplied by abs(x)) for abs(x) <= 8 and one
// (divided by sqrt(abs(x))) for abs(x) > 8, selects between them elementwise,
// and multiplies the selected value by sign(x).
static XlaOp I1eImpl64(XlaOp x) {
// Chebyshev coefficients for the abs(x) <= 8 branch.
static const std::array<double, 29> kI1eCoeffsA{
2.77791411276104639959E-18, -2.11142121435816608115E-17,
1.55363195773620046921E-16, -1.10559694773538630805E-15,
7.60068429473540693410E-15, -5.04218550472791168711E-14,
3.22379336594557470981E-13, -1.98397439776494371520E-12,
1.17361862988909016308E-11, -6.66348972350202774223E-11,
3.62559028155211703701E-10, -1.88724975172282928790E-9,
9.38153738649577178388E-9, -4.44505912879632808065E-8,
2.00329475355213526229E-7, -8.56872026469545474066E-7,
3.47025130813767847674E-6, -1.32731636560394358279E-5,
4.78156510755005422638E-5, -1.61760815825896745588E-4,
5.12285956168575772895E-4, -1.51357245063125314899E-3,
4.15642294431288815669E-3, -1.05640848946261981558E-2,
2.47264490306265168283E-2, -5.29459812080949914269E-2,
1.02643658689847095384E-1, -1.76416518357834055153E-1,
2.52587186443633654823E-1};
// Chebyshev coefficients for the abs(x) > 8 branch.
static const std::array<double, 25> kI1eCoeffsB{
7.51729631084210481353E-18, 4.41434832307170791151E-18,
-4.65030536848935832153E-17, -3.20952592199342395980E-17,
2.96262899764595013876E-16, 3.30820231092092828324E-16,
-1.88035477551078244854E-15, -3.81440307243700780478E-15,
1.04202769841288027642E-14, 4.27244001671195135429E-14,
-2.10154184277266431302E-14, -4.08355111109219731823E-13,
-7.19855177624590851209E-13, 2.03562854414708950722E-12,
1.41258074366137813316E-11, 3.25260358301548823856E-11,
-1.89749581235054123450E-11, -5.58974346219658380687E-10,
-3.83538038596423702205E-9, -2.63146884688951950684E-8,
-2.51223623787020892529E-7, -3.88256480887769039346E-6,
-1.10588938762623716291E-4, -9.76109749136146840777E-3,
7.78576235018280120474E-1};
// z is the magnitude of x; the original x is kept for the final Sign(x).
XlaOp z = Abs(x);
auto half = xla::ScalarLike(x, 0.5);
auto two = xla::ScalarLike(x, 2.0);
auto thirty_two = xla::ScalarLike(x, 32.0);
// abs(x) <= 8: abs(x) times the expansion in (abs(x) / 2 - 2).
auto result_le_8 =
z * EvaluateChebyshevPolynomial<double>(half * z - two, kI1eCoeffsA);
// abs(x) > 8: expansion in (32 / abs(x) - 2), scaled by 1 / sqrt(abs(x)).
auto result_gt_8 =
EvaluateChebyshevPolynomial<double>(thirty_two / z - two, kI1eCoeffsB) /
Sqrt(z);
// Both branches are built for every element; Select picks per element and
// the sign of the input is restored.
return Sign(x) *
Select(Le(z, xla::ScalarLike(x, 8.0)), result_le_8, result_gt_8);
}
// Public entry point for i1e. Requires a real floating-point operand. F64
// operands use the double-precision implementation; everything else uses the
// single-precision implementation.
XlaOp BesselI1e(XlaOp x) {
  XlaBuilder& builder = *x.builder();
  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("BesselI1e", x));
    TF_ASSIGN_OR_RETURN(auto shape, builder.GetShape(x));
    if (shape.element_type() != F64) {
      // I1eImpl32 does not have enough precision when run with bf16
      // intermediates (not surprising!), so BF16 and F16 operands go through
      // DoWithUpcastToF32.
      auto single_precision = [](XlaOp operand) { return I1eImpl32(operand); };
      return DoWithUpcastToF32(x, {BF16, F16}, single_precision);
    }
    return I1eImpl64(x);
  });
}
// I J Thompson and A R Barnett. 1986. Coulomb and Bessel functions of complex
// arguments and order. J. Comput. Phys. 64, 2 (June 1986), 490-509.
// DOI=http://dx.doi.org/10.1016/0021-9991(86)90046-X
//
// Evaluates a continued fraction with a while loop and returns the final
// approximant.
//
//   num_iterations: the loop stops once the iteration counter (which starts at
//     1) reaches this value. Must be < INT32_MAX.
//   small: replacement value for c, d and the initial approximant whenever
//     their magnitude is below `small`.
//   threshold: the loop stops once abs(delta - 1) < threshold for every
//     element, where delta = c * d is the per-iteration multiplier.
//   nth_partial_numerator / nth_partial_denominator: callbacks invoked with
//     (iteration, inputs, builder); each must return exactly one XlaOp.
//   inputs: forwarded to the callbacks and carried through the loop state
//     unchanged. Must be non-empty, since the builder is taken from
//     inputs.front().
//   name: passed to WhileLoopHelper.
static XlaOp LentzThompsonBarnettAlgorithm(
int64_t num_iterations, double small, double threshold,
const ForEachIndexBodyFunction& nth_partial_numerator,
const ForEachIndexBodyFunction& nth_partial_denominator,
absl::Span<const XlaOp> inputs, absl::string_view name) {
auto& b = *inputs.front().builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_RET_CHECK(num_iterations < INT32_MAX);
// Layout of the while-loop state vector.
enum {
// Position in the evaluation.
kIterationIdx,
// Whether or not we have reached the desired tolerance.
kValuesUnconvergedIdx,
// Ratio between nth canonical numerator and the nth-1 canonical
// numerator.
kCIdx,
// Ratio between nth-1 canonical denominator and the nth canonical
// denominator.
kDIdx,
// Computed approximant in the evaluation.
kHIdx,
// Inputs follow all of the other state.
kFirstInputIdx,
};
// Continue while iterations remain AND the values have not converged.
auto while_cond_fn = [num_iterations](
absl::Span<const XlaOp> values,
XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
auto iteration = values[kIterationIdx];
auto iterations_remain_cond =
Lt(iteration, ScalarLike(iteration, num_iterations));
auto values_unconverged_cond = values[kValuesUnconvergedIdx];
return And(iterations_remain_cond, values_unconverged_cond);
};
// One step of the recurrence: updates c, d and h from the partial
// numerator/denominator for the current iteration.
auto while_body_fn =
[small, threshold, &nth_partial_numerator, &nth_partial_denominator](
absl::Span<const XlaOp> values,
XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
XlaOp iteration = values[kIterationIdx];
TF_ASSIGN_OR_RETURN(
std::vector<XlaOp> partial_numerator,
nth_partial_numerator(iteration, values.subspan(kFirstInputIdx),
body_builder));
TF_RET_CHECK(partial_numerator.size() == 1);
TF_ASSIGN_OR_RETURN(
std::vector<XlaOp> partial_denominator,
nth_partial_denominator(iteration, values.subspan(kFirstInputIdx),
body_builder));
TF_RET_CHECK(partial_denominator.size() == 1);
// c = denominator + numerator / c_prev, clamped to `small` when its
// magnitude is below `small`.
auto c = partial_denominator[0] + partial_numerator[0] / values[kCIdx];
auto small_constant = FullLike(c, small);
c = Select(Lt(Abs(c), small_constant), small_constant, c);
// d = 1 / (denominator + numerator * d_prev), with the same clamp applied
// before taking the reciprocal.
auto d = partial_denominator[0] + partial_numerator[0] * values[kDIdx];
d = Select(Lt(Abs(d), small_constant), small_constant, d);
d = Reciprocal(d);
// The approximant is multiplied by delta = c * d on each step.
auto delta = c * d;
auto h = values[kHIdx] * delta;
std::vector<XlaOp> updated_values(values.size());
updated_values[kIterationIdx] = Add(iteration, ScalarLike(iteration, 1));
updated_values[kCIdx] = c;
updated_values[kDIdx] = d;
updated_values[kHIdx] = h;
// The caller-supplied inputs pass through the loop unchanged.
std::copy(values.begin() + kFirstInputIdx, values.end(),
updated_values.begin() + kFirstInputIdx);
// If any values are greater than the tolerance, we have not converged.
auto tolerance_comparison =
Ge(Abs(Sub(delta, FullLike(delta, 1.0))), FullLike(delta, threshold));
// Or-reduce over all elements into a single scalar predicate.
updated_values[kValuesUnconvergedIdx] =
ReduceAll(tolerance_comparison, ConstantR0<bool>(body_builder, false),
CreateScalarOrComputation(PRED, body_builder));
return updated_values;
};
// Seed the recurrence with the partial denominator for iteration 0.
TF_ASSIGN_OR_RETURN(std::vector<XlaOp> partial_denominator,
nth_partial_denominator(Zero(&b, U32), inputs, &b));
TF_RET_CHECK(partial_denominator.size() == 1);
auto h = partial_denominator[0];
auto small_constant = FullLike(h, small);
h = Select(Lt(Abs(h), small_constant), small_constant, h);
// Initial state: iteration = 1, unconverged, c = h, d = 0, approximant = h,
// followed by the inputs.
std::vector<XlaOp> values(kFirstInputIdx + inputs.size());
values[kIterationIdx] = One(&b, U32);
values[kValuesUnconvergedIdx] = ConstantR0<bool>(&b, true);
values[kCIdx] = h;
values[kDIdx] = FullLike(h, 0.0);
values[kHIdx] = h;
std::copy(inputs.begin(), inputs.end(), values.begin() + kFirstInputIdx);
TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn,
values, name, &b));
return values[kHIdx];
});
}
// Builds the element-wise regularized incomplete beta function I_x(a, b).
//
// All three operands must share one real floating-point element type; F16 and
// BF16 inputs are evaluated in F32 and converted back at the end. The result
// is NaN wherever a <= 0, b <= 0, x < 0 or x > 1. Errors are reported through
// the builder that owns `x`.
XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x) {
  auto& builder = *x.builder();
  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder.GetShape(a));
    TF_ASSIGN_OR_RETURN(Shape b_shape, builder.GetShape(b));
    TF_ASSIGN_OR_RETURN(Shape x_shape, builder.GetShape(x));

    // Validate operand types before emitting any computation.
    const bool types_agree = b_shape.element_type() == shape.element_type() &&
                             x_shape.element_type() == shape.element_type();
    if (!types_agree) {
      return InvalidArgument(
          "Operands to RegularizedIncompleteBeta must have identical types, "
          "got shapes %s, %s, and %s",
          shape.ToString(), b_shape.ToString(), x_shape.ToString());
    }
    if (!primitive_util::IsFloatingPointType(shape.element_type())) {
      return InvalidArgument(
          "Operands to RegularizedIncompleteBeta must be real-valued "
          "floating-point, but got %s",
          PrimitiveType_Name(shape.element_type()));
    }

    // Half-precision inputs are evaluated in F32; `element_type` is the type
    // the computation actually runs in.
    PrimitiveType element_type = shape.element_type();
    if (element_type == F16 || element_type == BF16) {
      element_type = F32;
      a = ConvertElementType(a, F32);
      b = ConvertElementType(b, F32);
      x = ConvertElementType(x, F32);
    }

    // Partial numerator of the continued fraction, see
    // http://dlmf.nist.gov/8.17.E23. Special case: the partial numerator of
    // the first iteration is one. The operands are taken from `inputs` (in the
    // order a, b, x) rather than from the enclosing scope.
    auto partial_numerator =
        [&](XlaOp iteration, absl::Span<const XlaOp> inputs,
            XlaBuilder* /*unused*/) -> StatusOr<std::vector<XlaOp>> {
      XlaOp a_in = inputs[0];
      XlaOp b_in = inputs[1];
      XlaOp x_in = inputs[2];

      XlaOp iter = Broadcast(iteration, shape.dimensions());
      XlaOp is_even = Eq(iter % FullLike(iter, 2), FullLike(iter, 0));
      XlaOp is_first = Eq(iter, FullLike(iter, 1));

      // m = (iteration - 1) / 2, computed on the iteration's own type and
      // then converted to the compute type.
      XlaOp iter_minus_one = iter - FullLike(iter, 1);
      XlaOp m = ConvertElementType(
          iter_minus_one / FullLike(iter_minus_one, 2), element_type);

      XlaOp one = FullLike(a_in, 1.0);
      XlaOp two = FullLike(a_in, 2.0);

      XlaOp even_term = -(a_in + m) * (a_in + b_in + m) * x_in /
                        ((a_in + two * m) * (a_in + two * m + one));
      XlaOp odd_term = m * (b_in - m) * x_in /
                       ((a_in + two * m - one) * (a_in + two * m));
      XlaOp first_term = ScalarLike(x_in, 1.0);
      XlaOp general_term = Select(is_even, even_term, odd_term);
      return std::vector<XlaOp>{Select(is_first, first_term, general_term)};
    };

    // Partial denominator: zero for iteration 0 and one for every later
    // iteration.
    auto partial_denominator =
        [&shape](XlaOp iteration, absl::Span<const XlaOp> inputs,
                 XlaBuilder* /*unused*/) -> StatusOr<std::vector<XlaOp>> {
      XlaOp x_in = inputs[2];
      XlaOp iter = Broadcast(iteration, shape.dimensions());
      XlaOp is_zeroth = Eq(iter, ScalarLike(iter, 0));
      return std::vector<XlaOp>{
          Select(is_zeroth, ScalarLike(x_in, 0.0), ScalarLike(x_in, 1.0))};
    };

    // Flag inputs outside the domain a > 0, b > 0, 0 <= x <= 1.
    XlaOp a_not_positive = Le(a, ScalarLike(a, 0.0));
    XlaOp b_not_positive = Le(b, ScalarLike(b, 0.0));
    XlaOp x_below_zero = Lt(x, ScalarLike(x, 0.0));
    XlaOp x_above_one = Gt(x, ScalarLike(x, 1.0));
    XlaOp result_is_nan = Or(
        Or(Or(a_not_positive, b_not_positive), x_below_zero), x_above_one);

    // The continued fraction converges rapidly when x < (a+1)/(a+b+2), as per
    // http://dlmf.nist.gov/8.17.E23. Elsewhere the symmetry relation
    // http://dlmf.nist.gov/8.17.E4 is applied: swap a and b, replace x with
    // 1 - x, and undo the substitution on the final result.
    XlaOp converges_rapidly =
        Lt(x, (a + FullLike(a, 1.0)) / (a + b + FullLike(b, 2.0)));
    XlaOp a_cf = Select(converges_rapidly, a, b);
    XlaOp b_cf = Select(converges_rapidly, b, a);
    XlaOp x_cf = Select(converges_rapidly, x, Sub(FullLike(x, 1.0), x));

    // Thresholds and iteration counts taken from Cephes.
    XlaOp continued_fraction;
    if (element_type == F32) {
      continued_fraction = LentzThompsonBarnettAlgorithm(
          /*num_iterations=*/200,
          /*small=*/std::numeric_limits<float>::epsilon() / 2.0f,
          /*threshold=*/std::numeric_limits<float>::epsilon() / 2.0f,
          /*nth_partial_numerator=*/partial_numerator,
          /*nth_partial_denominator=*/partial_denominator,
          {a_cf, b_cf, x_cf}, "Betainc");
    } else {
      TF_RET_CHECK(element_type == F64);
      continued_fraction = LentzThompsonBarnettAlgorithm(
          /*num_iterations=*/600,
          /*small=*/std::numeric_limits<double>::epsilon() / 2.0f,
          /*threshold=*/std::numeric_limits<double>::epsilon() / 2.0f,
          /*nth_partial_numerator=*/partial_numerator,
          /*nth_partial_denominator=*/partial_denominator,
          {a_cf, b_cf, x_cf}, "Betainc");
    }

    // To obtain the regularized incomplete beta function, the continued
    // fraction is scaled by x^a * (1-x)^b / (a * Beta(a, b)). The scale factor
    // is evaluated in the log domain to avoid overflow. See
    // http://dlmf.nist.gov/8.17.E22 for an easier to read version of this
    // formula.
    XlaOp lbeta = Lbeta(a_cf, b_cf);
    XlaOp log_scale = Log(x_cf) * a_cf + Log1p(-x_cf) * b_cf - lbeta;
    XlaOp result = continued_fraction * Exp(log_scale) / a_cf;
    result = Select(result_is_nan, NanValue(&builder, element_type), result);

    // Undo the symmetry substitution where it was applied.
    XlaOp out =
        Select(converges_rapidly, result, Sub(FullLike(result, 1.0), result));

    // Convert back to the caller's element type if we upcast above.
    if (shape.element_type() != element_type) {
      out = ConvertElementType(out, shape.element_type());
    }
    return out;
  });
}
// Builds the element-wise polygamma function of order `n` evaluated at `x`.
//
// `n` and `x` must have identical shapes (including element type). F16 and
// BF16 operands are evaluated in F32 and the result is converted back to the
// original element type. For n == 0 the result is Digamma(x); otherwise it is
// (-1)^(n+1) * n! * Zeta(n + 1, x). The result is NaN wherever `n` is negative
// or not an integer. Errors are reported through the builder that owns `x`.
XlaOp Polygamma(XlaOp n, XlaOp x) {
  auto& builder = *x.builder();
  auto doit = [](XlaOp n, XlaOp x) -> XlaOp {
    XlaOp n_plus_one = n + ScalarLike(n, 1.);
    // sign = 2 * (n mod 2) - 1: +1 for odd n, -1 for even n.
    XlaOp sign =
        (ScalarLike(n, 2.) * Rem(n, ScalarLike(n, 2.)) - ScalarLike(n, 1.));
    const double nan = std::numeric_limits<double>::quiet_NaN();
    // Exp(Lgamma(n + 1)) evaluates n!.
    XlaOp output = Select(Eq(n, ScalarLike(n, 0.)), Digamma(x),
                          sign * Exp(Lgamma(n_plus_one)) * Zeta(n_plus_one, x));
    // Check that n is a natural number.
    output = Select(Or(Ne(n, Floor(n)), Lt(n, ScalarLike(n, 0.))),
                    ScalarLike(n, nan), output);
    return output;
  };
  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto n_shape, builder.GetShape(n));
    TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x));
    if (n_shape != x_shape) {
      return InvalidArgument(
          "Arguments to Polygamma must have equal shapes and types; "
          "got %s and %s",
          n_shape.ToString(), x_shape.ToString());
    }
    // Pass this op's own name so that a type error is attributed to Polygamma
    // (the first argument is presumably the op name used in the error text).
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Polygamma", x));
    // The shapes were checked equal above, so one element type describes both
    // operands.
    const PrimitiveType original_type = x_shape.element_type();
    const bool needs_upcast = original_type == F16 || original_type == BF16;
    if (needs_upcast) {
      n = ConvertElementType(n, F32);
      x = ConvertElementType(x, F32);
    }
    XlaOp result = doit(n, x);
    if (needs_upcast) {
      result = ConvertElementType(result, original_type);
    }
    return result;
  });
}
// Builds the element-wise Hurwitz zeta function
//   zeta(x, q) = sum_{k >= 0} (q + k)^(-x).
//
// `x` and `q` must have identical shapes (including element type). F16 and
// BF16 operands are evaluated in F32 and the result is converted back to the
// original element type. Special values, applied in this order (later rules
// override earlier ones):
//   * x == 1                        -> +inf (harmonic series)
//   * x < 1                         -> NaN
//   * q <= 0 and x not an integer   -> NaN
//   * q a non-positive integer      -> +inf if x is an even integer, else NaN
// Errors are reported through the builder that owns `x`.
XlaOp Zeta(XlaOp x, XlaOp q) {
  auto& builder = *x.builder();
  // `type` is the caller's element type (before any upcast); it selects the
  // epsilon used to decide whether the plain partial sum is accurate enough.
  auto doit = [&builder](XlaOp x, XlaOp q, PrimitiveType type) -> XlaOp {
    // (2k) ! / B_{2k}, where B_{2k} are the Bernoulli numbers.
    // These are ordered in reverse.
    static const std::array<double, 12> kZetaCoeffs{
        -7.1661652561756670113e18,
        1.8152105401943546773e17,
        -4.5979787224074726105e15,
        1.1646782814350067249e14,
        -2.950130727918164224e12,
        7.47242496e10,
        -1.8924375803183791606e9,
        47900160.0,
        -1209600.0,
        30240.0,
        -720.0,
        12.0,
    };
    // For speed we'll always use 9 iterations for the initial series estimate,
    // and a 12 term expansion for the Euler-Maclaurin formula.
    //
    // initial_sum = sum_{k=0}^{9} (q + k)^(-x).
    XlaOp a = q;
    XlaOp initial_sum = Pow(q, Neg(x));
    for (int i = 0; i < 9; ++i) {
      a = a + ScalarLike(a, 1.);
      initial_sum = initial_sum + Pow(a, Neg(x));
    }
    // a = q + 10 is the point at which the Euler-Maclaurin tail is expanded;
    // neg_power = a^(-x).
    a = a + ScalarLike(a, 1.);
    XlaOp neg_power = Pow(a, Neg(x));
    // Integral term of the tail: a^(1-x) / (x - 1).
    XlaOp s = initial_sum + neg_power * a / (x - ScalarLike(a, 1.));
    XlaOp a_inverse_square = Reciprocal(Square(a));
    XlaOp horner_sum = ScalarLike(a, 0.);
    // Use Horner's rule for this.
    // Note this differs from Cephes which does a 'naive' polynomial evaluation.
    // Using Horner's rule allows to avoid some NaN's and Infs from happening,
    // resulting in more numerically stable code.
    //
    // The j-th correction term is
    //   B_{2j} / (2j)! * x (x+1) ... (x+2j-2) * a^(-x-2j+1),
    // i.e. a rising factorial in x, so each Horner step multiplies by
    // (x + 2j-2) (x + 2j-3) / a^2. The factors must therefore be *added* to
    // x, matching Cephes' `a *= x + k` accumulation.
    for (int i = 0; i < 11; ++i) {
      XlaOp factor =
          (x + ScalarLike(x, 22 - 2 * i)) * (x + ScalarLike(x, 21 - 2 * i));
      horner_sum = factor * a_inverse_square *
                   (horner_sum + ScalarLike(a, 1. / kZetaCoeffs[i]));
    }
    // Add the boundary term a^(-x) / 2 and the Bernoulli correction series.
    s = s + neg_power *
                (ScalarLike(neg_power, 0.5) +
                 x / a * (ScalarLike(a, 1. / kZetaCoeffs[11]) + horner_sum));
    const double nan = std::numeric_limits<double>::quiet_NaN();
    const double inf = std::numeric_limits<double>::infinity();
    // Use the initial zeta sum without the correction term coming
    // from Euler-Maclaurin if it is accurate enough.
    XlaOp output =
        Select(Lt(Abs(neg_power), Abs(initial_sum) * Epsilon(&builder, type)),
               initial_sum, s);
    // This is the harmonic series.
    output = Select(Eq(x, ScalarLike(x, 1.)), ScalarLike(x, inf), output);
    // Function is not defined for x < 1.
    output = Select(Lt(x, ScalarLike(x, 1.)), ScalarLike(x, nan), output);
    // For q <= 0, x must be an integer.
    XlaOp x_domain_error = And(Le(q, ScalarLike(x, 0.)), Ne(x, Floor(x)));
    output = Select(x_domain_error, ScalarLike(x, nan), output);
    // For all integer q <= 0, zeta has a pole. The limit is only defined as
    // +inf if x is an even integer.
    XlaOp at_pole = And(Le(q, ScalarLike(x, 0.)), Eq(q, Floor(q)));
    XlaOp x_is_even_int =
        And(Eq(Rem(x, ScalarLike(x, 2.)), ScalarLike(x, 0.)), Eq(x, Floor(x)));
    output = Select(
        at_pole, Select(x_is_even_int, ScalarLike(x, inf), ScalarLike(x, nan)),
        output);
    return output;
  };
  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x));
    TF_ASSIGN_OR_RETURN(auto q_shape, builder.GetShape(q));
    if (x_shape != q_shape) {
      return InvalidArgument(
          "Arguments to Zeta must have equal shapes and types; got %s and %s",
          x_shape.ToString(), q_shape.ToString());
    }
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Zeta", x));
    // Half-precision inputs are evaluated in F32 and converted back below.
    bool needs_upcast =
        x_shape.element_type() == F16 || x_shape.element_type() == BF16;
    if (needs_upcast) {
      x = ConvertElementType(x, F32);
      q = ConvertElementType(q, F32);
    }
    XlaOp result = doit(x, q, x_shape.element_type());
    if (needs_upcast) {
      result = ConvertElementType(result, x_shape.element_type());
    }
    return result;
  });
}
} // namespace xla