/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Enable the use of M_* math constants.
// NOTE: this must be first in the file to ensure that if cmath is transitively
// included by any other header it has the define set on first processing.
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/math-constants
#define _USE_MATH_DEFINES
#include <cmath>
#include <numeric>
#include <vector>
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir-hlo/utils/broadcast_utils.h"
#include "mlir-hlo/utils/hlo_utils.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace chlo {
namespace {
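// Lowers chlo.constant_like to an mhlo constant, broadcast dynamically to the
// shape of the operand when the result type is not statically shaped.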
struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
using OpConversionPattern<ConstantLikeOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
ConstantLikeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto result_ty = op.getType().cast<ShapedType>();
// Unranked uses are not supported. Consider `mhlo-transform-unranked-hlo`.
if (!result_ty.hasRank()) return failure();
// Lower to MHLO constant if statically shaped.
if (result_ty.hasStaticShape()) {
rewriter.replaceOpWithNewOp<mhlo::ConstOp>(
op, DenseElementsAttr::get(result_ty, op.value()));
return success();
}
// Lower to broadcasted constant.
ConstantLikeOp::Adaptor transformed(operands);
auto loc = op.getLoc();
Type extent_tensor_type = shape::getExtentTensorType(op.getContext());
Value constant = rewriter.create<mhlo::ConstOp>(loc, op.value());
Value uncasted_shape = rewriter.create<shape::ShapeOfOp>(
loc, extent_tensor_type, transformed.operand());
Type shape_ty =
RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType());
Value shape =
rewriter.create<tensor::CastOp>(loc, shape_ty, uncasted_shape);
rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
op, result_ty, constant, shape, rewriter.getI64TensorAttr({}));
return success();
}
};
template <typename FTy>
Value MaterializePolynomialApproximation(ConversionPatternRewriter &rewriter,
Location loc, Value x,
const std::vector<FTy> &coefficients) {
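// Evaluate the polynomial with Horner's scheme; `coefficients` are expected in
// order of decreasing degree, so each step multiplies the accumulator by x and
// adds the next coefficient.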
Value poly = chlo::getConstantLike(rewriter, loc, 0.0, x);
for (FTy c : coefficients) {
poly = rewriter.create<mhlo::MulOp>(loc, x.getType(), poly, x);
poly = rewriter.create<mhlo::AddOp>(
loc, x.getType(), poly, chlo::getConstantLike(rewriter, loc, c, x));
}
return poly;
}
// Precondition: |x| >= 1. Use the erf approximation otherwise.
//
// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
Value MaterializeErfcApproximationF64ForMagnitudeGEOne(
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
"expect f64 element type");
const double kMaxlog = 7.09782712893383996843E2;
const std::vector<double> kErfcPCoefficients{
2.46196981473530512524E-10, 5.64189564831068821977E-1,
7.46321056442269912687E0, 4.86371970985681366614E1,
1.96520832956077098242E2, 5.26445194995477358631E2,
9.34528527171957607540E2, 1.02755188689515710272E3,
5.57535335369399327526E2};
const std::vector<double> kErfcQCoefficients{
1.00000000000000000000E0, 1.32281951154744992508E1,
8.67072140885989742329E1, 3.54937778887819891062E2,
9.75708501743205489753E2, 1.82390916687909736289E3,
2.24633760818710981792E3, 1.65666309194161350182E3,
5.57535340817727675546E2};
const std::vector<double> kErfcRCoefficients{
5.64189583547755073984E-1, 1.27536670759978104416E0,
5.01905042251180477414E0, 6.16021097993053585195E0,
7.40974269950448939160E0, 2.97886665372100240670E0};
const std::vector<double> kErfcSCoefficients{
1.00000000000000000000E0, 2.26052863220117276590E0,
9.39603524938001434673E0, 1.20489539808096656605E1,
1.70814450747565897222E1, 9.60896809063285878198E0,
3.36907645100081516050E0};
// Let z = -x^2.
Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
Value z = rewriter.create<mhlo::NegOp>(loc, x_sq);
// Materialize polynomial approximation for x in [1, 8) as
// erfc(x) = exp(z) P(|x|) / Q(|x|).
Value exp_z = rewriter.create<mhlo::ExpOp>(loc, z);
Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
Value poly_p = MaterializePolynomialApproximation(rewriter, loc, abs_x,
kErfcPCoefficients);
Value exp_z_mul_poly_p = rewriter.create<mhlo::MulOp>(loc, exp_z, poly_p);
Value poly_q = MaterializePolynomialApproximation(rewriter, loc, abs_x,
kErfcQCoefficients);
Value erfc_approx_1_8 =
rewriter.create<mhlo::DivOp>(loc, exp_z_mul_poly_p, poly_q);
// Materialize polynomial approximation for x >= 8 as
// erfc(x) = exp(z) R(|x|) / S(|x|).
Value poly_r = MaterializePolynomialApproximation(rewriter, loc, abs_x,
kErfcRCoefficients);
Value exp_z_mul_poly_r = rewriter.create<mhlo::MulOp>(loc, exp_z, poly_r);
Value poly_s = MaterializePolynomialApproximation(rewriter, loc, abs_x,
kErfcSCoefficients);
Value erfc_approx_8_inf =
rewriter.create<mhlo::DivOp>(loc, exp_z_mul_poly_r, poly_s);
// Combine polynomial approximations for x >= 1.
const StringAttr kLT = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
Value eight = chlo::getConstantLike(rewriter, loc, 8.0, x);
Value abs_x_lt_8 = rewriter.create<mhlo::CompareOp>(loc, abs_x, eight, kLT);
Value erfc_approx = rewriter.create<mhlo::SelectOp>(
loc, abs_x_lt_8, erfc_approx_1_8, erfc_approx_8_inf);
// Clamp to prevent overflow and materialize approximation for large x as
// erfc(x) = 0.
Value z_lt_neg_maxlog = rewriter.create<mhlo::CompareOp>(
loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x), kLT);
Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x);
Value erfc_approx_clamped =
rewriter.create<mhlo::SelectOp>(loc, z_lt_neg_maxlog, zero, erfc_approx);
// Derive approximation for x <= -1 as
// erfc(x) = 2 - erfc(-x).
// Reuse the previously materialized approximations, all of which take |x| as
// their argument.
Value x_lt_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLT);
Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
Value two_sub_erfc_approx_clamped =
rewriter.create<mhlo::SubOp>(loc, two, erfc_approx_clamped);
return rewriter.create<mhlo::SelectOp>(
loc, x_lt_zero, two_sub_erfc_approx_clamped, erfc_approx_clamped);
}
// Precondition: |x| <= 1. Use the erfc approximation otherwise.
// This implementation is based on Cephes.
Value MaterializeErfApproximationF64ForMagnitudeLEOne(
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
"expect f64 element type");
const std::vector<double> kErfTCoefficients{
9.60497373987051638749E0, 9.00260197203842689217E1,
2.23200534594684319226E3, 7.00332514112805075473E3,
5.55923013010394962768E4};
const std::vector<double> kErfUCoefficients{
1.00000000000000000000E0, 3.35617141647503099647E1,
5.21357949780152679795E2, 4.59432382970980127987E3,
2.26290000613890934246E4, 4.92673942608635921086E4};
// Materialize polynomial approximation for |x| <= 1 as
// erf(x) = x T(x^2) / U(x^2).
Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
Value poly_t = MaterializePolynomialApproximation(rewriter, loc, x_sq,
kErfTCoefficients);
Value x_mul_poly_t = rewriter.create<mhlo::MulOp>(loc, x, poly_t);
Value poly_u = MaterializePolynomialApproximation(rewriter, loc, x_sq,
kErfUCoefficients);
return rewriter.create<mhlo::DivOp>(loc, x_mul_poly_t, poly_u);
}
// This implementation is based on Cephes.
Value MaterializeErfApproximationF64(ConversionPatternRewriter &rewriter,
Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
"expect f64 element type");
// Rely on erf approximation for |x| < 1
// erf(x) = erf_approx(x)
Value erf_approx =
MaterializeErfApproximationF64ForMagnitudeLEOne(rewriter, loc, x);
// Rely on erfc approximation for |x| >= 1 and materialize erf as
// erf(x) = 1 - erfc_approx(x)
Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
Value erfc_approx =
MaterializeErfcApproximationF64ForMagnitudeGEOne(rewriter, loc, x);
Value erfc_based_approx = rewriter.create<mhlo::SubOp>(loc, one, erfc_approx);
// Materialize approximation selection based on argument.
Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
const StringAttr kLT = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, erf_approx,
erfc_based_approx);
}
Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
"expect f64 element type");
// Rely on erfc approximation for |x| >= 1
// erfc(x) = erfc_approx(x)
Value erfc_approx =
MaterializeErfcApproximationF64ForMagnitudeGEOne(rewriter, loc, x);
// Rely on erf approximation for |x| < 1 and materialize erfc as
// erfc(x) = 1 - erf_approx(x)
Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
Value erf_approx =
MaterializeErfApproximationF64ForMagnitudeLEOne(rewriter, loc, x);
Value erf_based_approx = rewriter.create<mhlo::SubOp>(loc, one, erf_approx);
// Materialize approximation selection based on argument.
Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
const StringAttr kLT = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, erf_based_approx,
erfc_approx);
}
// Precondition: |x| >= 1. Use the erf approximation otherwise.
//
// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
Value MaterializeErfcApproximationF32ForMagnitudeGEOne(
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type");
const double kMaxlog = 88.72283905206835;
const std::vector<float> kErfcPCoefficients{
+2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1,
-5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1,
+3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1,
};
const std::vector<float> kErfcRCoefficients{
-1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0,
+2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1,
-2.820767439740514E-1, +5.641895067754075E-1,
};
// Let z = -x^2.
Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
Value z = rewriter.create<mhlo::NegOp>(loc, x_sq);
// Materialize polynomial approximation for x >= 1 as
// erfc(x) = exp(z) 1/x P(1/x^2) if x in [1, 2)
// erfc(x) = exp(z) 1/x R(1/x^2) if x >= 2
const StringAttr kLT = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
Value reciprocal_x_sq = rewriter.create<mhlo::DivOp>(loc, one, x_sq);
Value exp_z = rewriter.create<mhlo::ExpOp>(loc, z);
Value one_div_abs_x = rewriter.create<mhlo::DivOp>(loc, one, abs_x);
Value exp_z_mul_one_div_abs_x =
rewriter.create<mhlo::MulOp>(loc, exp_z, one_div_abs_x);
Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
Value abs_x_lt_two = rewriter.create<mhlo::CompareOp>(loc, abs_x, two, kLT);
Value poly_p = MaterializePolynomialApproximation(
rewriter, loc, reciprocal_x_sq, kErfcPCoefficients);
Value poly_r = MaterializePolynomialApproximation(
rewriter, loc, reciprocal_x_sq, kErfcRCoefficients);
Value poly =
rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_two, poly_p, poly_r);
Value erfc_approx =
rewriter.create<mhlo::MulOp>(loc, exp_z_mul_one_div_abs_x, poly);
// Clamp to prevent overflow and materialize approximation for large x as
// erfc(x) = 0.
Value z_lt_neg_maxlog = rewriter.create<mhlo::CompareOp>(
loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x), kLT);
Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x);
Value erfc_approx_clamped =
rewriter.create<mhlo::SelectOp>(loc, z_lt_neg_maxlog, zero, erfc_approx);
// Derive approximation for x <= -1 as
// erfc(x) = 2 - erfc(-x).
// Reuse the previously materialized approximations, all of which take |x| as
// their argument.
Value x_lt_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLT);
Value two_sub_erfc_approx =
rewriter.create<mhlo::SubOp>(loc, two, erfc_approx_clamped);
return rewriter.create<mhlo::SelectOp>(loc, x_lt_zero, two_sub_erfc_approx,
erfc_approx_clamped);
}
// Precondition: |x| <= 1. Use the erfc approximation otherwise.
// This implementation is based on Cephes.
Value MaterializeErfApproximationF32ForMagnitudeLEOne(
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type");
const std::vector<float> kErfTCoefficients{
+7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3,
-2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1,
+1.128379165726710E+0,
};
// Materialize polynomial approximation for |x| <= 1 as
// erf(x) = x T(x^2).
Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
Value poly_t = MaterializePolynomialApproximation(rewriter, loc, x_sq,
kErfTCoefficients);
return rewriter.create<mhlo::MulOp>(loc, x, poly_t);
}
// This is the same approximation as used in Eigen.
Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type");
const std::vector<float> kAlpha{
-2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
-5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
-1.60960333262415e-02f,
};
const std::vector<float> kBeta{
-1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
-7.37332916720468e-03f, -1.42647390514189e-02f,
};
// Clamp argument between -4 and 4.
Value lb = chlo::getConstantLike(rewriter, loc, -4.0, x);
Value ub = chlo::getConstantLike(rewriter, loc, 4.0, x);
x = rewriter.create<mhlo::ClampOp>(loc, x.getType(), lb, x, ub);
Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
// Materialize polynomial approximation for x in [-4, 4] as
// erf(x) = x * Alpha(x^2) / Beta(x^2).
Value alpha_poly =
MaterializePolynomialApproximation(rewriter, loc, x_sq, kAlpha);
Value beta_poly =
MaterializePolynomialApproximation(rewriter, loc, x_sq, kBeta);
Value x_mul_alpha_poly = rewriter.create<mhlo::MulOp>(loc, x, alpha_poly);
return rewriter.create<mhlo::DivOp>(loc, x_mul_alpha_poly, beta_poly);
}
Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type");
// Rely on erfc approximation for |x| >= 1
// erfc(x) = erfc_approx(x)
Value erfc_approx =
MaterializeErfcApproximationF32ForMagnitudeGEOne(rewriter, loc, x);
// Rely on erf approximation for |x| < 1 and materialize erfc as
// erfc(x) = 1 - erf_approx(x)
Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
Value erf_approx =
MaterializeErfApproximationF32ForMagnitudeLEOne(rewriter, loc, x);
Value erf_based_approx = rewriter.create<mhlo::SubOp>(loc, one, erf_approx);
// Materialize approximation selection based on argument.
const StringAttr kLT = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, erf_based_approx,
erfc_approx);
}
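// Materializes the result of `callback`, upcasting the arguments to
// `min_precision_ty` when their element type is narrower and casting the
// result back to the original element type afterwards.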
Value MaterializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc,
ValueRange args, FloatType min_precision_ty,
Value callback(ConversionPatternRewriter &,
Location, ValueRange)) {
auto original_ty =
getElementTypeOrSelf(args.front().getType()).cast<FloatType>();
bool needs_upcast = original_ty.getWidth() < min_precision_ty.getWidth();
// Upcast arguments if necessary.
llvm::SmallVector<Value, 2> casted_args;
if (needs_upcast) {
for (Value a : args) {
casted_args.push_back(
rewriter.create<mhlo::ConvertOp>(loc, a, min_precision_ty));
}
args = casted_args;
}
Value result = callback(rewriter, loc, args);
// Cast back if necessary.
if (needs_upcast) {
result = rewriter.create<mhlo::ConvertOp>(loc, result, original_ty);
}
return result;
}
struct ConvertErfOp : public OpConversionPattern<ErfOp> {
using OpConversionPattern<ErfOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
ErfOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
ErfOp::Adaptor transformed(operands);
Value x = transformed.operand();
Type ty = x.getType().cast<ShapedType>().getElementType();
// For now, we support only f64, f32, and f16.
if (!ty.isF64() && !ty.isF32() && !ty.isF16()) return failure();
if (ty.isF64()) {
rewriter.replaceOp(op, MaterializeErfApproximationF64(rewriter, loc, x));
return success();
}
rewriter.replaceOp(op, MaterializeWithUpcast(
rewriter, loc, operands, rewriter.getF32Type(),
&MaterializeErfApproximationF32));
return success();
}
};
struct ConvertErfcOp : public OpConversionPattern<ErfcOp> {
using OpConversionPattern<ErfcOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
ErfcOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
ErfcOp::Adaptor transformed(operands);
Value x = transformed.operand();
Type ty = x.getType().cast<ShapedType>().getElementType();
// For now, we support only f64, f32, and f16.
if (!ty.isF64() && !ty.isF32() && !ty.isF16()) return failure();
if (ty.isF64()) {
rewriter.replaceOp(op, MaterializeErfcApproximationF64(rewriter, loc, x));
return success();
}
rewriter.replaceOp(op, MaterializeWithUpcast(
rewriter, loc, operands, rewriter.getF32Type(),
&MaterializeErfcApproximationF32));
return success();
}
};
// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n (kLanczosGamma
// and kLanczosCoefficients.size() + 1). The coefficients below correspond to
// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and
// [7, 9] seemed to be the least sensitive to the quality of the log function.
// In particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
// for a particularly inaccurate log function.
constexpr double kLanczosGamma = 7; // aka g
constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
constexpr std::array<double, 8> kLanczosCoefficients = {
676.520368121885098567009190444019, -1259.13921672240287047156078755283,
771.3234287776530788486528258894, -176.61502916214059906584551354,
12.507343278686904814458936853, -0.13857109526572011689554707,
9.984369578019570859563e-6, 1.50563273514931155834e-7};
// Compute the Lgamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
// lgamma(z + 1) = (log(2) + log(pi)) / 2
// + (z + 1/2) * log(t(z))
// - t(z) + log(a(z))
// with t(z) = z + kLanczosGamma + 1/2
// a(z) = kBaseLanczosCoeff
// + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
ValueRange args) {
// If the input is less than 0.5 use Euler's reflection formula.
// gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
// Let z be
// z = -x if x < 1/2
// z = x - 1 otherwise
Value x = args.front();
const StringAttr kLT = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
Value half = getConstantLike(rewriter, loc, 0.5, x);
Value need_to_reflect = rewriter.create<mhlo::CompareOp>(loc, x, half, kLT);
Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
Value one = getConstantLike(rewriter, loc, 1, x);
Value x_sub_one = rewriter.create<mhlo::SubOp>(loc, x, one);
Value z =
rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, neg_x, x_sub_one);
// Materialize
// a(z) = kBaseLanczosCoeff
// + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
Value one_based_index = getConstantLike(rewriter, loc, i + 1, x);
Value quotient = rewriter.create<mhlo::DivOp>(
loc, coeff, rewriter.create<mhlo::AddOp>(loc, z, one_based_index));
a = rewriter.create<mhlo::AddOp>(loc, a, quotient);
}
// To improve accuracy on platforms with less-precise log implementations,
// compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
// device.
// Materialize as
// log(t) = log(kLanczosGamma + 1/2 + z)
// = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
Value lanczos_plus_half =
getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
Value t = rewriter.create<mhlo::AddOp>(loc, lanczos_plus_half, z);
Value log_term =
getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
Value log1p_term = rewriter.create<mhlo::Log1pOp>(
loc, rewriter.create<mhlo::DivOp>(loc, z, lanczos_plus_half));
Value log_t = rewriter.create<mhlo::AddOp>(loc, log_term, log1p_term);
// Note that t(z) may be large and we need to be careful not to overflow to
// infinity in the relevant term
// r = (z + 1/2) * log(t(z)) - t(z).
// Therefore, we compute this as
// r = (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
Value t_div_log_t = rewriter.create<mhlo::DivOp>(loc, t, log_t);
Value sum = rewriter.create<mhlo::SubOp>(
loc, rewriter.create<mhlo::AddOp>(loc, z, half), t_div_log_t);
Value r = rewriter.create<mhlo::MulOp>(loc, sum, log_t);
// Compute the final result (modulo reflection) as
// lgamma(z + 1) = (log(2) + log(pi)) / 2 + r + log(a(z)).
Value log_a = rewriter.create<mhlo::LogOp>(loc, a);
Value lgamma = rewriter.create<mhlo::AddOp>(
loc,
rewriter.create<mhlo::AddOp>(
loc,
getConstantLike(rewriter, loc, (std::log(2) + std::log(M_PI)) / 2, x),
r),
log_a);
// Compute the reflected value for x < 0.5 as
// lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
//
// The abs is needed because lgamma is the log of the absolute value of the
// gamma function.
//
// We have to be careful when computing the final term above. gamma(x) goes
// to +/-inf at every integer x < 0, and this is controlled by the sin(pi * x)
// term. The slope is large, so precision is particularly important.
//
// Because abs(sin(pi * x)) has a period of 1, we can equivalently use
// abs(sin(pi * frac(x))) where frac(x) is the fractional part of x. This is
// more numerically accurate: it doesn't overflow to inf like pi * x would,
// and if x is an integer it evaluates to exactly 0, which is important
// because we then take the log of this value, and log(0) is inf.
//
// We don't have a frac(x) primitive in HLO and computing it is tricky, but
// because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for our
// purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
//
// Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
// to 1. To remedy this, we can use the fact that sin(pi * x) in the domain
// [0, 1] is symmetric across the line x = 0.5.
//
// Convert values of abs_frac > 0.5 to (1 - abs_frac) to improve precision of
// pi * abs_frac for values of abs_frac close to 1.
Value abs = rewriter.create<mhlo::AbsOp>(loc, x);
Value abs_frac = rewriter.create<mhlo::SubOp>(
loc, abs, rewriter.create<mhlo::FloorOp>(loc, abs));
Value reduce_abs_frac =
rewriter.create<mhlo::CompareOp>(loc, half, abs_frac, kLT);
abs_frac = rewriter.create<mhlo::SelectOp>(
loc, reduce_abs_frac, rewriter.create<mhlo::SubOp>(loc, one, abs_frac),
abs_frac);
// Materialize reflection.
Value reflection_denom = rewriter.create<mhlo::LogOp>(
loc,
rewriter.create<mhlo::SinOp>(
loc, rewriter.create<mhlo::MulOp>(
loc, getConstantLike(rewriter, loc, M_PI, x), abs_frac)));
Value lgamma_reflection = rewriter.create<mhlo::SubOp>(
loc,
rewriter.create<mhlo::SubOp>(
loc, getConstantLike(rewriter, loc, std::log(M_PI), x),
reflection_denom),
lgamma);
// Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf,
// then it "wins" and the result is +/-inf.
Value finite_reflection_denom =
rewriter.create<mhlo::IsFiniteOp>(loc, reflection_denom);
Value neg_reflection_denom =
rewriter.create<mhlo::NegOp>(loc, reflection_denom);
lgamma_reflection = rewriter.create<mhlo::SelectOp>(
loc, finite_reflection_denom, lgamma_reflection, neg_reflection_denom);
// Select whether or not to rely on the reflection.
lgamma = rewriter.create<mhlo::SelectOp>(loc, need_to_reflect,
lgamma_reflection, lgamma);
// Materialize +/-inf behavior as
// lgamma(+/-inf) = +inf.
Value x_is_inf = rewriter.create<chlo::IsInfOp>(loc, x);
return rewriter.create<mhlo::SelectOp>(
loc, x_is_inf,
chlo::getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false),
lgamma);
}
// Express `cosh` as
// cosh(x) = (e^x + e^-x) / 2
// = e^(x + log(1/2)) + e^(-x + log(1/2))
//
// The second formulation avoids overflowing in cases where e^x = inf but
// (e^x) / 2 is still representable.
//
// This incorrectly overflows to inf for two f32 input values, namely
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
Value MaterializeCoshApproximation(ConversionPatternRewriter &rewriter,
Location loc, ValueRange operands) {
CoshOp::Adaptor transformed(operands);
Value x = transformed.operand();
Value log_one_half =
rewriter.create<mhlo::LogOp>(loc, getConstantLike(rewriter, loc, 0.5, x));
Value exp_add = rewriter.create<mhlo::ExpOp>(
loc, rewriter.create<mhlo::AddOp>(loc, x, log_one_half));
Value exp_sub = rewriter.create<mhlo::ExpOp>(
loc, rewriter.create<mhlo::SubOp>(loc, log_one_half, x));
return rewriter.create<mhlo::AddOp>(loc, exp_add, exp_sub);
}
struct ConvertCoshOp : public OpConversionPattern<CoshOp> {
using OpConversionPattern<CoshOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
CoshOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
CoshOp::Adaptor transformed(operands);
Value x = transformed.operand();
if (x.getType().cast<ShapedType>().getElementType().isa<ComplexType>()) {
// TODO(hinsu): Support operands with complex element types by always
// using the formula for large x. The compare op is not legal for complex
// numbers.
return failure();
}
rewriter.replaceOp(op,
MaterializeWithUpcast(rewriter, op.getLoc(), operands,
rewriter.getF32Type(),
&MaterializeCoshApproximation));
return success();
}
};
// Compute the Digamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
// digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z)
// with t(z) = z + kLanczosGamma + 1/2
// a(z) = kBaseLanczosCoeff
// + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
// a'(z) = - sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k) / (z + k))
Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
ValueRange args) {
// If the input is less than 0.5 use Euler's reflection formula.
// digamma(x) = digamma(1 - x) - pi * cot(pi * x)
// Let z be
// z = -x if x < 1/2
// z = x - 1 otherwise
Value x = args.front();
const StringAttr kLT = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
Value half = getConstantLike(rewriter, loc, 0.5, x);
Value need_to_reflect = rewriter.create<mhlo::CompareOp>(loc, x, half, kLT);
Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
Value one = getConstantLike(rewriter, loc, 1, x);
Value x_sub_one = rewriter.create<mhlo::SubOp>(loc, x, one);
Value z =
rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, neg_x, x_sub_one);
// Materialize
// a(z) = kBaseLanczosCoeff
// + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
// a'(z) = - sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k) / (z + k))
Value zero = getConstantLike(rewriter, loc, 0.0, x);
Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
Value a_prime = zero;
for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
Value one_based_index = getConstantLike(rewriter, loc, i + 1, x);
Value z_term = rewriter.create<mhlo::AddOp>(loc, z, one_based_index);
a_prime = rewriter.create<mhlo::SubOp>(
loc, a_prime,
rewriter.create<mhlo::DivOp>(
loc, coeff, rewriter.create<mhlo::MulOp>(loc, z_term, z_term)));
a = rewriter.create<mhlo::AddOp>(
loc, a, rewriter.create<mhlo::DivOp>(loc, coeff, z_term));
}
// To improve accuracy on platforms with less-precise log implementations,
// compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
// device.
// Materialize as
// log(t) = log(kLanczosGamma + 1/2 + z)
// = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
Value lanczos_plus_half =
getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
Value t = rewriter.create<mhlo::AddOp>(loc, lanczos_plus_half, z);
Value log_term =
getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
Value log1p_term = rewriter.create<mhlo::Log1pOp>(
loc, rewriter.create<mhlo::DivOp>(loc, z, lanczos_plus_half));
Value log_t = rewriter.create<mhlo::AddOp>(loc, log_term, log1p_term);
// Materialize the final result (modulo reflection) as
// digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z).
Value a_prime_div_a = rewriter.create<mhlo::DivOp>(loc, a_prime, a);
Value lanczos_gamma_div_t = rewriter.create<mhlo::DivOp>(
loc, getConstantLike(rewriter, loc, kLanczosGamma, x), t);
Value digamma = rewriter.create<mhlo::SubOp>(
loc, rewriter.create<mhlo::AddOp>(loc, log_t, a_prime_div_a),
lanczos_gamma_div_t);
// We need to be careful how we compute cot(pi * input) below: For
// near-integral arguments, pi * input can lose precision.
//
// Input is already known to be less than 0.5 (otherwise we don't have to
// reflect). We shift values smaller than -0.5 into the range [-0.5, 0.5] to
// increase precision of pi * x and the resulting cotangent.
Value reduced_x = rewriter.create<mhlo::AddOp>(
loc, x,
rewriter.create<mhlo::AbsOp>(
loc, rewriter.create<mhlo::FloorOp>(
loc, rewriter.create<mhlo::AddOp>(
loc, x, getConstantLike(rewriter, loc, 0.5, x)))));
// Materialize reflection for inputs less than 0.5 as
// digamma(x) = digamma(1 - x) - pi * cot(pi * x)
// = digamma(1 - x) - pi * cos(pi * x) / sin(pi * x)
Value pi = getConstantLike(rewriter, loc, M_PI, x);
Value pi_mul_reduced_x = rewriter.create<mhlo::MulOp>(loc, pi, reduced_x);
Value cos = rewriter.create<mhlo::CosOp>(loc, pi_mul_reduced_x);
Value sin = rewriter.create<mhlo::SinOp>(loc, pi_mul_reduced_x);
Value reflection = rewriter.create<mhlo::SubOp>(
loc, digamma,
rewriter.create<mhlo::DivOp>(
loc, rewriter.create<mhlo::MulOp>(loc, pi, cos), sin));
// Select whether or not to rely on the reflection.
digamma = rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, reflection,
digamma);
// Digamma has poles at negative integers and zero; return nan for those.
const StringAttr kLE = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE));
Value is_le_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLE);
const StringAttr kEQ = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
Value is_int = rewriter.create<mhlo::CompareOp>(
loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kEQ);
Value is_pole = rewriter.create<mhlo::AndOp>(loc, is_le_zero, is_int);
return rewriter.create<mhlo::SelectOp>(
loc, is_pole,
getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
x),
digamma);
}
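// Materializes an approximation of the Hurwitz zeta function
// zeta(x, q) = sum(k = 0, inf, (q + k)^(-x))
// as a truncated series plus an Euler-Maclaurin correction, following Cephes.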
Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc,
ValueRange args) {
assert(args.size() == 2);
Value x = args[0];
Value q = args[1];
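// The reciprocals 1 / kZetaCoeffs[k] match the Euler-Maclaurin coefficients
// B(2k) / (2k)! (Bernoulli numbers over factorials), listed with the
// highest-order term first.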
static const std::array<double, 12> kZetaCoeffs{
-7.1661652561756670113e18,
1.8152105401943546773e17,
-4.5979787224074726105e15,
1.1646782814350067249e14,
-2.950130727918164224e12,
7.47242496e10,
-1.8924375803183791606e9,
47900160.0,
-1209600.0,
30240.0,
-720.0,
12.0,
};
// For speed we'll always use 9 iterations for the initial series estimate,
// and a 12-term expansion for the Euler-Maclaurin formula.
Value a = q;
Value zero = chlo::getConstantLike(rewriter, loc, 0.0, a);
Value neg_power = zero;
Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
Value initial_sum = rewriter.create<mhlo::PowOp>(loc, q, neg_x);
Value one = chlo::getConstantLike(rewriter, loc, 1.0, a);
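// The initial estimate is the truncated series sum(k = 0, 9, (q + k)^(-x));
// after the loop, a holds q + 9.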
for (int i = 0; i < 9; ++i) {
a = rewriter.create<mhlo::AddOp>(loc, a, one);
neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
initial_sum = rewriter.create<mhlo::AddOp>(loc, initial_sum, neg_power);
}
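// Euler-Maclaurin integral tail: with a = q + 10, add a^(1 - x) / (x - 1),
// materialized as (a^(-x) * a) / (x - 1).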
a = rewriter.create<mhlo::AddOp>(loc, a, one);
neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
Value one_like_x = chlo::getConstantLike(rewriter, loc, 1.0, x);
Value x_minus_one = rewriter.create<mhlo::SubOp>(loc, x, one_like_x);
Value neg_power_mul_a = rewriter.create<mhlo::MulOp>(loc, neg_power, a);
Value neg_power_mul_a_div_x_minus_one =
rewriter.create<mhlo::DivOp>(loc, neg_power_mul_a, x_minus_one);
Value s = rewriter.create<mhlo::AddOp>(loc, initial_sum,
neg_power_mul_a_div_x_minus_one);
Value a_inverse_square = rewriter.create<mhlo::DivOp>(
loc, one, rewriter.create<mhlo::MulOp>(loc, a, a));
Value horner_sum = zero;
Value factor = one;
// Use Horner's rule for this.
// Note this differs from Cephes, which does a 'naive' polynomial evaluation.
// Using Horner's rule avoids some NaNs and Infs, resulting in more
// numerically stable code.
for (int i = 0; i < 11; ++i) {
Value factor_lhs = rewriter.create<mhlo::SubOp>(
loc, x, chlo::getConstantLike(rewriter, loc, 22 - 2 * i, x));
Value factor_rhs = rewriter.create<mhlo::SubOp>(
loc, x, chlo::getConstantLike(rewriter, loc, 21 - 2 * i, x));
factor = rewriter.create<mhlo::MulOp>(loc, factor_lhs, factor_rhs);
horner_sum = rewriter.create<mhlo::MulOp>(
loc, factor,
rewriter.create<mhlo::MulOp>(
loc, a_inverse_square,
rewriter.create<mhlo::AddOp>(
loc, horner_sum,
chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a))));
}
Value zero_point_five_like_neg_power =
chlo::getConstantLike(rewriter, loc, .5, neg_power);
Value x_div_a = rewriter.create<mhlo::DivOp>(loc, x, a);
s = rewriter.create<mhlo::AddOp>(
loc, s,
rewriter.create<mhlo::MulOp>(
loc, neg_power,
rewriter.create<mhlo::AddOp>(
loc, zero_point_five_like_neg_power,
rewriter.create<mhlo::MulOp>(
loc, x_div_a,
rewriter.create<mhlo::AddOp>(
loc,
chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11],
a),
horner_sum)))));
// Use the initial zeta sum without the correction term coming
// from Euler-Maclaurin if it is accurate enough.
const StringAttr kLT = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
Value abs_neg_power = rewriter.create<mhlo::AbsOp>(loc, neg_power);
Value abs_initial_sum = rewriter.create<mhlo::AbsOp>(loc, initial_sum);
Value output = rewriter.create<mhlo::SelectOp>(
loc,
rewriter.create<mhlo::CompareOp>(
loc, abs_neg_power,
rewriter.create<mhlo::MulOp>(
loc, abs_initial_sum,
chlo::getConstantLikeSmallestFiniteValue(rewriter, loc, a)),
kLT),
initial_sum, s);
// The function is not defined for x < 1.
Value nan = chlo::getConstantLike(
rewriter, loc, std::numeric_limits<double>::quiet_NaN(), x);
output = rewriter.create<mhlo::SelectOp>(
loc, rewriter.create<mhlo::CompareOp>(loc, x, one_like_x, kLT), nan,
output);
// For q <= 0, x must be an integer.
const StringAttr kLE = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE));
const StringAttr kNE = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
Value q_le_zero = rewriter.create<mhlo::CompareOp>(loc, q, zero, kLE);
Value x_not_int = rewriter.create<mhlo::CompareOp>(
loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kNE);
Value x_domain_error =
rewriter.create<mhlo::AndOp>(loc, q_le_zero, x_not_int);
output = rewriter.create<mhlo::SelectOp>(loc, x_domain_error, nan, output);
// For all integer q <= 0, zeta has a pole. The limit is only defined as
// +inf if x is an even integer.
const StringAttr kEQ = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
Value inf = chlo::getConstantLike(rewriter, loc,
std::numeric_limits<double>::infinity(), x);
Value q_is_int = rewriter.create<mhlo::CompareOp>(
loc, q, rewriter.create<mhlo::FloorOp>(loc, q), kEQ);
Value at_pole = rewriter.create<mhlo::AndOp>(loc, q_le_zero, q_is_int);
Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
Value x_is_int = rewriter.create<mhlo::CompareOp>(
loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kEQ);
Value x_is_even = rewriter.create<mhlo::CompareOp>(
loc, rewriter.create<mhlo::RemOp>(loc, x, two), zero, kEQ);
Value x_is_even_int = rewriter.create<mhlo::AndOp>(loc, x_is_int, x_is_even);
output = rewriter.create<mhlo::SelectOp>(
loc, at_pole,
rewriter.create<mhlo::SelectOp>(loc, x_is_even_int, inf, nan), output);
// For x = 1, this is the harmonic series and diverges.
output = rewriter.create<mhlo::SelectOp>(
loc, rewriter.create<mhlo::CompareOp>(loc, x, one, kEQ), inf, output);
return output;
}
Value MaterializePolygamma(ConversionPatternRewriter &rewriter, Location loc,
ValueRange args) {
PolygammaOp::Adaptor transformed(args);
Value n = transformed.n();
Value x = transformed.x();
// Handle integer n > 0.
Value one = getConstantLike(rewriter, loc, 1.0, x);
Value two = getConstantLike(rewriter, loc, 2.0, x);
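// Materialize polygamma(n, x) = (-1)^(n + 1) * n! * zeta(n + 1, x), where
// sign = 2 * rem(n, 2) - 1 evaluates to (-1)^(n + 1) for non-negative integer
// n and exp(lgamma(n + 1)) materializes n!.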
Value sign = rewriter.create<mhlo::SubOp>(
loc,
rewriter.create<mhlo::MulOp>(loc, two,
rewriter.create<mhlo::RemOp>(loc, n, two)),
one);
Value n_plus_one = rewriter.create<mhlo::AddOp>(loc, n, one);
Value exp_lgamma_np1 = rewriter.create<mhlo::ExpOp>(
loc, rewriter.create<chlo::LgammaOp>(loc, n_plus_one));
Value zeta = rewriter.create<chlo::ZetaOp>(loc, n_plus_one, x);
Value result = rewriter.create<mhlo::MulOp>(
loc, rewriter.create<mhlo::MulOp>(loc, sign, exp_lgamma_np1), zeta);
// Handle n = 0.
const StringAttr kEQ = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
Value zero = getConstantLike(rewriter, loc, 0.0, x);
Value n_eq_zero = rewriter.create<mhlo::CompareOp>(loc, n, zero, kEQ);
result = rewriter.create<mhlo::SelectOp>(
loc, n_eq_zero, rewriter.create<chlo::DigammaOp>(loc, x), result);
// Check that n is a natural number; return nan otherwise.
const StringAttr kNE = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
Value non_int = rewriter.create<mhlo::CompareOp>(
loc, n, rewriter.create<mhlo::FloorOp>(loc, n), kNE);
const StringAttr kLT = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
Value negative = rewriter.create<mhlo::CompareOp>(loc, n, zero, kLT);
Value non_natural = rewriter.create<mhlo::OrOp>(loc, non_int, negative);
return rewriter.create<mhlo::SelectOp>(
loc, non_natural,
getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
x),
result);
}
struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
using OpConversionPattern<LgammaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
LgammaOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
FloatType min_precision_ty = rewriter.getF32Type();
rewriter.replaceOp(
op, MaterializeWithUpcast(rewriter, op.getLoc(), operands,
min_precision_ty, &MaterializeLgamma));
return success();
}
};
struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
using OpConversionPattern<DigammaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
DigammaOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
FloatType min_precision_ty = rewriter.getF32Type();
rewriter.replaceOp(
op, MaterializeWithUpcast(rewriter, op.getLoc(), operands,
min_precision_ty, &MaterializeDigamma));
return success();
}
};
struct ConvertPolygammaOp : public OpConversionPattern<PolygammaOp> {
using OpConversionPattern<PolygammaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
PolygammaOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
FloatType min_precision_ty = rewriter.getF32Type();
rewriter.replaceOp(
op, MaterializeWithUpcast(rewriter, loc, operands, min_precision_ty,
&MaterializePolygamma));
return success();
}
};
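// Express `sinh` for large |x| as
// sinh(x) = e^(x + log(1/2)) - e^(-x + log(1/2))
// to avoid overflow when e^x alone would overflow. The constant log(1/2) is
// materialized as -log(2) broadcast to the shape of x, so that this also
// works for complex element types.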
Value MaterializeSinhApproximationForLargeX(ConversionPatternRewriter &rewriter,
Location loc, ValueRange operands) {
SinhOp::Adaptor transformed(operands);
Value x = transformed.operand();
auto result_ty = x.getType().cast<ShapedType>();
// TODO(b/190374484): Use mhlo::ConstantLikeOp when it supports complex types.
Value two = rewriter.create<mhlo::ConstOp>(
loc, hlo::GetScalarOfType(getElementTypeOrSelf(x.getType()), 2));
Type extent_tensor_type = shape::getExtentTensorType(x.getContext());
Value uncasted_shape =
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, x);
Type shape_ty =
RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType());
Value shape = rewriter.create<tensor::CastOp>(loc, shape_ty, uncasted_shape);
Value two_with_x_shape = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc, result_ty, two, shape, rewriter.getI64TensorAttr({}));
Value log_two = rewriter.create<mhlo::LogOp>(loc, two_with_x_shape);
Value log_one_half = rewriter.create<mhlo::NegOp>(loc, log_two);
Value exp_add = rewriter.create<mhlo::ExpOp>(
loc, rewriter.create<mhlo::AddOp>(loc, x, log_one_half));
Value exp_sub = rewriter.create<mhlo::ExpOp>(
loc, rewriter.create<mhlo::SubOp>(loc, log_one_half, x));
return rewriter.create<mhlo::SubOp>(loc, exp_add, exp_sub);
}
// Express `sinh` as
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1
// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
Value MaterializeSinhApproximation(ConversionPatternRewriter &rewriter,
Location loc, ValueRange operands) {
Value large_sinh_result =
MaterializeSinhApproximationForLargeX(rewriter, loc, operands);
SinhOp::Adaptor transformed(operands);
Value x = transformed.operand();
const StringAttr kLT = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
Value exp_x = rewriter.create<mhlo::ExpOp>(loc, x);
Value exp_neg_x =
rewriter.create<mhlo::ExpOp>(loc, rewriter.create<mhlo::NegOp>(loc, x));
Value exp_difference = rewriter.create<mhlo::SubOp>(loc, exp_x, exp_neg_x);
Value two = getConstantLike(rewriter, loc, 2.0, x);
Value small_sinh_result =
rewriter.create<mhlo::DivOp>(loc, exp_difference, two);
Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
Value one = getConstantLike(rewriter, loc, 1.0, x);
Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, small_sinh_result,
large_sinh_result);
}
struct ConvertSinhOp : public OpConversionPattern<SinhOp> {
using OpConversionPattern<SinhOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
SinhOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
SinhOp::Adaptor transformed(operands);
Value x = transformed.operand();
if (x.getType().cast<ShapedType>().getElementType().isa<ComplexType>()) {
rewriter.replaceOp(op, MaterializeSinhApproximationForLargeX(
rewriter, op.getLoc(), operands));
return success();
}
rewriter.replaceOp(op,
MaterializeWithUpcast(rewriter, op.getLoc(), operands,
rewriter.getF32Type(),
&MaterializeSinhApproximation));
return success();
}
};
struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
using OpConversionPattern<ZetaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
ZetaOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
FloatType min_precision_ty = rewriter.getF32Type();
rewriter.replaceOp(
op, MaterializeWithUpcast(rewriter, loc, operands, min_precision_ty,
&MaterializeZeta));
return success();
}
};
struct ConvertSelectOp : public OpConversionPattern<BroadcastSelectOp> {
using OpConversionPattern<BroadcastSelectOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
BroadcastSelectOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Only support ranked operands.
typename BroadcastSelectOp::Adaptor transformed(operands);
Value pred = transformed.pred();
Value on_true = transformed.on_true();
Value on_false = transformed.on_false();
auto pred_type = pred.getType().dyn_cast<RankedTensorType>();
auto on_true_type = on_true.getType().dyn_cast<RankedTensorType>();
auto on_false_type = on_false.getType().dyn_cast<RankedTensorType>();
auto result_type = op.getResult().getType().dyn_cast<RankedTensorType>();
if (!pred_type || !on_true_type || !on_false_type || !result_type) {
return failure();
}
auto loc = op.getLoc();
Value pred_shape = rewriter.createOrFold<shape::ShapeOfOp>(loc, pred);
Value on_true_shape = rewriter.createOrFold<shape::ShapeOfOp>(loc, on_true);
Value on_false_shape =
rewriter.createOrFold<shape::ShapeOfOp>(loc, on_false);
int64_t result_rank = std::max(
{pred_type.getRank(), on_true_type.getRank(), on_false_type.getRank()});
Value broadcastable_cstr =
rewriter.createOrFold<shape::CstrBroadcastableOp>(
loc, ValueRange{pred_shape, on_true_shape, on_false_shape});
auto assuming_op = rewriter.create<shape::AssumingOp>(
loc, ArrayRef<Type>{result_type}, broadcastable_cstr);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.createBlock(&assuming_op.doRegion());
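// Within the assuming region, compute the common result extents from all
// three operand shapes and broadcast each operand to them.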
Value result_extents = rewriter.createOrFold<shape::BroadcastOp>(
loc, shape::getExtentTensorType(op.getContext()),
ValueRange{pred_shape, on_true_shape, on_false_shape},
/*error=*/nullptr);
auto shape_type =
RankedTensorType::get({result_rank}, rewriter.getIndexType());
result_extents =
rewriter.createOrFold<tensor::CastOp>(loc, shape_type, result_extents);
Value broadcasted_pred = pred;
// Pred has an implicit broadcast for scalars, so use that when convenient.
if (pred_type.getRank() > 0) {
auto pred_broadcast_dimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(result_rank - pred_type.getRank(), result_rank));
broadcasted_pred = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(result_type.getShape(),
pred_type.getElementType()),
pred, result_extents,
rewriter.getI64TensorAttr(pred_broadcast_dimensions));
}
auto on_true_broadcast_dimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(result_rank - on_true_type.getRank(), result_rank));
Value broadcasted_on_true = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(result_type.getShape(),
on_true_type.getElementType()),
on_true, result_extents,
rewriter.getI64TensorAttr(on_true_broadcast_dimensions));
auto on_false_broadcast_dimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(result_rank - on_false_type.getRank(), result_rank));
Value broadcasted_on_false = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(result_type.getShape(),
on_false_type.getElementType()),
on_false, result_extents,
rewriter.getI64TensorAttr(on_false_broadcast_dimensions));
// And generate the final non-broadcasted ternary op.
Value final_result = rewriter.create<mhlo::SelectOp>(
loc, result_type, broadcasted_pred, broadcasted_on_true,
broadcasted_on_false);
rewriter.create<shape::AssumingYieldOp>(loc, final_result);
rewriter.replaceOp(op, {assuming_op.getResult(0)});
return success();
}
};
// Converts binary ops that are statically determined not to broadcast directly
// to the corresponding mhlo non-broadcasting op.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertTrivialNonBroadcastBinaryOp
: public OpConversionPattern<ChloOpTy> {
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
ChloOpTy op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Only rewrite for statically determinable non-broadcasting cases.
typename ChloOpTy::Adaptor transformed(operands);
auto lhs_type =
transformed.lhs().getType().template dyn_cast<RankedTensorType>();
auto rhs_type =
transformed.rhs().getType().template dyn_cast<RankedTensorType>();
if (!lhs_type || !rhs_type) return failure();
// Requires rank broadcast.
if (lhs_type.getRank() != rhs_type.getRank()) return failure();
// Any dynamic dimension may require broadcasting and requires more
// analysis.
if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape())
return failure();
for (auto extents : llvm::zip(lhs_type.getShape(), rhs_type.getShape())) {
auto lhs_extent = std::get<0>(extents);
auto rhs_extent = std::get<1>(extents);
if (lhs_extent != rhs_extent) {
return failure();
}
}
rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(),
operands, rewriter)});
return success();
}
};
// Converts a binary op with ranked broadcasting operands to explicitly
// broadcast and invoke the corresponding mhlo non-broadcasting op.
// Note that dynamic broadcasting supported by this pattern is only valid for
// "numpy" broadcasting semantics as defined here:
// https://docs.scipy.org/doc/numpy/reference/ufuncs.html
// Specifically, this includes the following cases:
// - Same rank broadcast (operands have the same static rank).
// - Different-rank broadcast, either without a broadcast_dims attribute or
// with the broadcast_dims attribute set to map to a prefix padding.
// - Legal combinations of degenerate (1-dim) implicit broadcasting.
// The restriction on broadcast_dims derives from the definition of the
// `shape.broadcast` op, which only supports prefix-padding.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertRankedDynamicBroadcastBinaryOp
: public OpConversionPattern<ChloOpTy> {
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
ChloOpTy op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Only support ranked operands.
typename ChloOpTy::Adaptor transformed(operands);
Value lhs = transformed.lhs();
Value rhs = transformed.rhs();
auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
auto result_type =
op.getResult().getType().template dyn_cast<RankedTensorType>();
if (!lhs_type || !rhs_type || !result_type) return failure();
// Check for "numpy"-style rank broadcast.
auto broadcast_dimensions = op.broadcast_dimensions();
if (broadcast_dimensions &&
!hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) {
// Note: It is unclear whether the general specification of explicit
// broadcast_dimensions on binary ops is a feature we want to carry
// forward. While it can technically be implemented for ranked-dynamic,
// it is incompatible with unranked inputs. If this warning is emitted
// in real programs, it is an indication that the feature should be
// implemented versus just falling back on the more standard definition
// of numpy-like prefix-padding.
op.emitWarning() << "unsupported non prefix-padded dynamic rank "
<< "broadcast_dimensions = " << *broadcast_dimensions;
return failure();
}
// Compute result shape.
auto loc = op.getLoc();
// Insert a constraint on the shapes being broadcastable and insert all
// future code into an assuming block reliant on the constraint.
Value lhs_shape = rewriter.create<shape::ShapeOfOp>(loc, lhs);
Value rhs_shape = rewriter.create<shape::ShapeOfOp>(loc, rhs);
auto broadcastable_cstr =
rewriter.create<shape::CstrBroadcastableOp>(loc, lhs_shape, rhs_shape);
auto assuming_op = rewriter.create<shape::AssumingOp>(
loc, ArrayRef<Type>{result_type}, broadcastable_cstr.result());
OpBuilder::InsertionGuard guard(rewriter);
rewriter.createBlock(&assuming_op.doRegion());
int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
Value result_extents =
hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
loc, lhs, rhs, rewriter, /*unsafe_as_extent_tensor=*/true);
// Note that we unconditionally emit DynamicBroadcastInDim ops and let
// downstream canonicalizations fold them away if possible. This is
// because, in the dynamic case, there are many corner cases regarding
// when it is safe to omit, and some of them require analysis to prove
// properly.
auto lhs_broadcast_dimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(result_rank - lhs_type.getRank(), result_rank));
Value broadcasted_lhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(result_type.getShape(),
lhs_type.getElementType()),
lhs, result_extents,
rewriter.getI64TensorAttr(lhs_broadcast_dimensions));
auto rhs_broadcast_dimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(result_rank - rhs_type.getRank(), result_rank));
Value broadcasted_rhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(result_type.getShape(),
rhs_type.getElementType()),
rhs, result_extents,
rewriter.getI64TensorAttr(rhs_broadcast_dimensions));
// And generate the final non-broadcasted binary op.
Value final_result = Adaptor::CreateOp(
op, result_type, {broadcasted_lhs, broadcasted_rhs}, rewriter);
rewriter.create<shape::AssumingYieldOp>(loc, final_result);
rewriter.replaceOp(op, {assuming_op.getResult(0)});
return success();
}
};
#include "generated_chlo_legalize_to_hlo.inc"
} // namespace
void PopulateChloBroadcastingPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
// Instantiate conversion templates for conforming binary elementwise ops
// that do not have different dtypes between operands and results and do
// not have special attributes that need to be preserved.
PopulateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>(
context, patterns, 10);
PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
context, patterns, 5);
patterns->insert<ConvertSelectOp>(context);
patterns->insert<ConvertConstantLikeOp>(context);
}
void PopulateDecomposeChloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
populateWithGenerated(*patterns);
// Other patterns.
// clang-format off
patterns->insert<ConvertCoshOp,
ConvertDigammaOp,
ConvertErfOp,
ConvertErfcOp,
ConvertLgammaOp,
ConvertPolygammaOp,
ConvertSinhOp,
ConvertZetaOp>(context);
// clang-format on
}
} // namespace chlo
} // namespace mlir