/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Enable the use of M_* math constants.
// NOTE: this must be first in the file to ensure that if cmath is transitively
// included by any other header it has the define set on first processing.
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/math-constants
#define _USE_MATH_DEFINES
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir-hlo/utils/broadcast_utils.h"
#include "mlir-hlo/utils/hlo_utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace chlo {
namespace {
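// Lowers chlo.constant_like to an mhlo constant. Statically shaped results
// become a plain mhlo constant; dynamically shaped results broadcast a scalar
// constant via mhlo.dynamic_broadcast_in_dim. For example, a constant_like
// with value 0.5 over a tensor<?xf32> operand becomes an mhlo constant of 0.5
// broadcast to the operand's runtime shape.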
struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
using OpConversionPattern<ConstantLikeOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
ConstantLikeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resultTy = op.getType().cast<ShapedType>();
// Unranked uses are not supported.
if (!resultTy.hasRank()) return failure();
// Lower to MHLO constant if statically shaped.
if (resultTy.hasStaticShape()) {
rewriter.replaceOpWithNewOp<mhlo::ConstOp>(
op, DenseElementsAttr::get(resultTy, op.value()));
return success();
}
// Lower to broadcasted constant.
auto loc = op.getLoc();
Value constant = rewriter.create<mhlo::ConstOp>(loc, op.value());
Value shape = rewriter.create<shape::ShapeOfOp>(loc, adaptor.operand());
rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
op, resultTy, constant, shape, rewriter.getI64TensorAttr({}));
return success();
}
};
template <typename FTy>
Value materializePolynomialApproximation(ConversionPatternRewriter &rewriter,
Location loc, Value x,
ArrayRef<FTy> coefficients) {
Value poly = chlo::getConstantLike(rewriter, loc, 0.0, x);
for (FTy c : coefficients) {
poly = rewriter.create<mhlo::MulOp>(loc, x.getType(), poly, x);
poly = rewriter.create<mhlo::AddOp>(
loc, x.getType(), poly, chlo::getConstantLike(rewriter, loc, c, x));
}
return poly;
}
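// For example, coefficients {c0, c1, c2} produce
//   ((0 * x + c0) * x + c1) * x + c2 = c0 * x^2 + c1 * x + c2,
// i.e. Horner's rule with coefficients ordered from highest to lowest degree.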
// Precondition is |x| >= 1. Use the erf approximation otherwise.
//
// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
Value materializeErfcApproximationF64ForMagnitudeGeOne(
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
"expect f64 element type");
const double kMaxlog = 7.09782712893383996843E2;
const double kErfcPCoefficients[] = {
2.46196981473530512524E-10, 5.64189564831068821977E-1,
7.46321056442269912687E0, 4.86371970985681366614E1,
1.96520832956077098242E2, 5.26445194995477358631E2,
9.34528527171957607540E2, 1.02755188689515710272E3,
5.57535335369399327526E2};
const double kErfcQCoefficients[] = {
1.00000000000000000000E0, 1.32281951154744992508E1,
8.67072140885989742329E1, 3.54937778887819891062E2,
9.75708501743205489753E2, 1.82390916687909736289E3,
2.24633760818710981792E3, 1.65666309194161350182E3,
5.57535340817727675546E2};
const double kErfcRCoefficients[] = {
5.64189583547755073984E-1, 1.27536670759978104416E0,
5.01905042251180477414E0, 6.16021097993053585195E0,
7.40974269950448939160E0, 2.97886665372100240670E0};
const double kErfcSCoefficients[] = {
1.00000000000000000000E0, 2.26052863220117276590E0,
9.39603524938001434673E0, 1.20489539808096656605E1,
1.70814450747565897222E1, 9.60896809063285878198E0,
3.36907645100081516050E0};
// Let z = -x^2.
Value xSq = rewriter.create<mhlo::MulOp>(loc, x, x);
Value z = rewriter.create<mhlo::NegOp>(loc, xSq);
// Materialize polynomial approximation for x in [1, 8) as
// erfc(x) = exp(z) P(|x|) / Q(|x|).
Value expZ = rewriter.create<mhlo::ExpOp>(loc, z);
Value absX = rewriter.create<mhlo::AbsOp>(loc, x);
Value polP = materializePolynomialApproximation(
rewriter, loc, absX, llvm::makeArrayRef(kErfcPCoefficients));
Value expZMulPolyP = rewriter.create<mhlo::MulOp>(loc, expZ, polP);
Value polQ = materializePolynomialApproximation(
rewriter, loc, absX, llvm::makeArrayRef(kErfcQCoefficients));
Value erfcApprox18 = rewriter.create<mhlo::DivOp>(loc, expZMulPolyP, polQ);
// Materialize polynomial approximation for x >= 8 as
// erfc(x) = exp(z) R(|x|) / S(|x|).
Value polR = materializePolynomialApproximation(
rewriter, loc, absX, llvm::makeArrayRef(kErfcRCoefficients));
Value expZMulPolyR = rewriter.create<mhlo::MulOp>(loc, expZ, polR);
Value polS = materializePolynomialApproximation(
rewriter, loc, absX, llvm::makeArrayRef(kErfcSCoefficients));
Value erfcApprox8Inf = rewriter.create<mhlo::DivOp>(loc, expZMulPolyR, polS);
// Combine polynomial approximations for x >= 1.
Value eight = chlo::getConstantLike(rewriter, loc, 8.0, x);
Value absXLt8 = rewriter.create<mhlo::CompareOp>(
loc, absX, eight, mhlo::ComparisonDirection::LT);
Value erfcApprox = rewriter.create<mhlo::SelectOp>(loc, absXLt8, erfcApprox18,
erfcApprox8Inf);
// Clamp to prevent overflow and materialize approximation for large x as
// erfc(x) = 0.
Value zLtNegMaxlog = rewriter.create<mhlo::CompareOp>(
loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x),
mhlo::ComparisonDirection::LT);
Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x);
Value erfcApproxClamped =
rewriter.create<mhlo::SelectOp>(loc, zLtNegMaxlog, zero, erfcApprox);
// Derive approximation for x <= -1 as
// erfc(x) = 2 - erfc(-x).
// Reuse previously materialized approximations, all of which take |x| as
// their argument.
Value xLtZero = rewriter.create<mhlo::CompareOp>(
loc, x, zero, mhlo::ComparisonDirection::LT);
Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
Value twoSubErfcApproxClamped =
rewriter.create<mhlo::SubOp>(loc, two, erfcApproxClamped);
return rewriter.create<mhlo::SelectOp>(loc, xLtZero, twoSubErfcApproxClamped,
erfcApproxClamped);
}
// Precondition is |x| <= 1. Use the erfc approximation otherwise.
// This implementation is based on Cephes.
Value materializeErfApproximationF64ForMagnitudeLeOne(
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
"expect f64 element type");
const double kErfTCoefficients[] = {
9.60497373987051638749E0, 9.00260197203842689217E1,
2.23200534594684319226E3, 7.00332514112805075473E3,
5.55923013010394962768E4};
const double kErfUCoefficients[] = {
1.00000000000000000000E0, 3.35617141647503099647E1,
5.21357949780152679795E2, 4.59432382970980127987E3,
2.26290000613890934246E4, 4.92673942608635921086E4};
// Materialize polynomial approximation for |x| <= 1 as
// erf(x) = x T(x^2) / U(x^2).
Value xSq = rewriter.create<mhlo::MulOp>(loc, x, x);
Value polyT = materializePolynomialApproximation(
rewriter, loc, xSq, llvm::makeArrayRef(kErfTCoefficients));
Value xMulPolyT = rewriter.create<mhlo::MulOp>(loc, x, polyT);
Value polyU = materializePolynomialApproximation(
rewriter, loc, xSq, llvm::makeArrayRef(kErfUCoefficients));
return rewriter.create<mhlo::DivOp>(loc, xMulPolyT, polyU);
}
// This implementation is based on Cephes.
Value materializeErfApproximationF64(ConversionPatternRewriter &rewriter,
Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
"expect f64 element type");
// Rely on erf approximation for |x| < 1
// erf(x) = erf_approx(x)
Value erfApprox =
materializeErfApproximationF64ForMagnitudeLeOne(rewriter, loc, x);
// Rely on erfc approximation for |x| >= 1 and materialize erf as
// erf(x) = 1 - erfc_approx(x)
Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
Value erfcApprox =
materializeErfcApproximationF64ForMagnitudeGeOne(rewriter, loc, x);
Value erfcBasedApprox = rewriter.create<mhlo::SubOp>(loc, one, erfcApprox);
// Materialize approximation selection based on argument.
Value absX = rewriter.create<mhlo::AbsOp>(loc, x);
Value absXLtOne = rewriter.create<mhlo::CompareOp>(
loc, absX, one, mhlo::ComparisonDirection::LT);
return rewriter.create<mhlo::SelectOp>(loc, absXLtOne, erfApprox,
erfcBasedApprox);
}
Value materializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
"expect f64 element type");
// Rely on erfc approximation for |x| >= 1
// erfc(x) = erfc_approx(x)
Value erfcApprox =
materializeErfcApproximationF64ForMagnitudeGeOne(rewriter, loc, x);
// Rely on erf approximation for |x| < 1 and materialize erfc as
// erfc(x) = 1 - erf_approx(x)
Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
Value erfApprox =
materializeErfApproximationF64ForMagnitudeLeOne(rewriter, loc, x);
Value erfBasedApprox = rewriter.create<mhlo::SubOp>(loc, one, erfApprox);
// Materialize approximation selection based on argument.
Value absX = rewriter.create<mhlo::AbsOp>(loc, x);
Value absXLtOne = rewriter.create<mhlo::CompareOp>(
loc, absX, one, mhlo::ComparisonDirection::LT);
return rewriter.create<mhlo::SelectOp>(loc, absXLtOne, erfBasedApprox,
erfcApprox);
}
// Precondition is |x| >= 1. Use the erf approximation otherwise.
//
// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
Value materializeErfcApproximationF32ForMagnitudeGeOne(
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type");
const double kMaxlog = 88.72283905206835;
const float kErfcPCoefficients[] = {
+2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1,
-5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1,
+3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1,
};
const float kErfcRCoefficients[] = {
-1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0,
+2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1,
-2.820767439740514E-1, +5.641895067754075E-1,
};
// Let z = -x^2.
Value xSq = rewriter.create<mhlo::MulOp>(loc, x, x);
Value z = rewriter.create<mhlo::NegOp>(loc, xSq);
// Materialize polynomial approximation for x >= 1 as
// erfc(x) = exp(z) 1/x P(1/x^2) if x in [1, 2)
// erfc(x) = exp(z) 1/x R(1/x^2) if x >= 2
Value absX = rewriter.create<mhlo::AbsOp>(loc, x);
Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
Value reciprocalXSq = rewriter.create<mhlo::DivOp>(loc, one, xSq);
Value expZ = rewriter.create<mhlo::ExpOp>(loc, z);
Value oneDivAbsX = rewriter.create<mhlo::DivOp>(loc, one, absX);
Value expZMulOneDivAbsX = rewriter.create<mhlo::MulOp>(loc, expZ, oneDivAbsX);
Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
Value absXLtTwo = rewriter.create<mhlo::CompareOp>(
loc, absX, two, mhlo::ComparisonDirection::LT);
Value polP = materializePolynomialApproximation(
rewriter, loc, reciprocalXSq, llvm::makeArrayRef(kErfcPCoefficients));
Value polR = materializePolynomialApproximation(
rewriter, loc, reciprocalXSq, llvm::makeArrayRef(kErfcRCoefficients));
Value poly = rewriter.create<mhlo::SelectOp>(loc, absXLtTwo, polP, polR);
Value erfcApprox = rewriter.create<mhlo::MulOp>(loc, expZMulOneDivAbsX, poly);
// Clamp to prevent overflow and materialize approximation for large x as
// erfc(x) = 0.
Value zLtNegMaxlog = rewriter.create<mhlo::CompareOp>(
loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x),
mhlo::ComparisonDirection::LT);
Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x);
Value erfcApproxClamped =
rewriter.create<mhlo::SelectOp>(loc, zLtNegMaxlog, zero, erfcApprox);
// Derive approximation for x <= -1 as
// erfc(x) = 2 - erfc(-x).
// Reuse previously materialized approximations, all of which take |x| as
// their argument.
Value xLtZero = rewriter.create<mhlo::CompareOp>(
loc, x, zero, mhlo::ComparisonDirection::LT);
Value twoSubErfcApprox =
rewriter.create<mhlo::SubOp>(loc, two, erfcApproxClamped);
return rewriter.create<mhlo::SelectOp>(loc, xLtZero, twoSubErfcApprox,
erfcApproxClamped);
}
// Precondition is |x| <= 1. Use the erfc approximation otherwise.
// This implementation is based on Cephes.
Value materializeErfApproximationF32ForMagnitudeLeOne(
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type");
const float kErfTCoefficients[] = {
+7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3,
-2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1,
+1.128379165726710E+0,
};
// Materialize polynomial approximation for |x| <= 1 as
// erf(x) = x T(x^2).
Value xSq = rewriter.create<mhlo::MulOp>(loc, x, x);
Value polyT = materializePolynomialApproximation(
rewriter, loc, xSq, llvm::makeArrayRef(kErfTCoefficients));
return rewriter.create<mhlo::MulOp>(loc, x, polyT);
}
// This is the same approximation as used in Eigen.
Value materializeErfApproximationF32(ConversionPatternRewriter &rewriter,
Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type");
const float kAlpha[] = {
-2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
-5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
-1.60960333262415e-02f,
};
const float kBeta[] = {
-1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
-7.37332916720468e-03f, -1.42647390514189e-02f,
};
// Clamp argument between -4 and 4.
Value lb = chlo::getConstantLike(rewriter, loc, -4.0, x);
Value ub = chlo::getConstantLike(rewriter, loc, 4.0, x);
x = rewriter.create<mhlo::ClampOp>(loc, x.getType(), lb, x, ub);
Value xSq = rewriter.create<mhlo::MulOp>(loc, x, x);
// Materialize polynomial approximation for x in [-4, 4] as
// erf(x) = x * Alpha(x^2) / Beta(x^2).
Value alphaPoly = materializePolynomialApproximation(
rewriter, loc, xSq, llvm::makeArrayRef(kAlpha));
Value betaPoly = materializePolynomialApproximation(
rewriter, loc, xSq, llvm::makeArrayRef(kBeta));
Value xMulAlphaPoly = rewriter.create<mhlo::MulOp>(loc, x, alphaPoly);
return rewriter.create<mhlo::DivOp>(loc, xMulAlphaPoly, betaPoly);
}
Value materializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type");
// Rely on erfc approximation for |x| >= 1
// erfc(x) = erfc_approx(x)
Value erfcApprox =
materializeErfcApproximationF32ForMagnitudeGeOne(rewriter, loc, x);
// Rely on erf approximation for |x| < 1 and materialize erfc as
// erfc(x) = 1 - erf_approx(x)
Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
Value erfApprox =
materializeErfApproximationF32ForMagnitudeLeOne(rewriter, loc, x);
Value erfBasedApprox = rewriter.create<mhlo::SubOp>(loc, one, erfApprox);
// Materialize approximation selection based on argument.
Value absX = rewriter.create<mhlo::AbsOp>(loc, x);
Value absXLtOne = rewriter.create<mhlo::CompareOp>(
loc, absX, one, mhlo::ComparisonDirection::LT);
return rewriter.create<mhlo::SelectOp>(loc, absXLtOne, erfBasedApprox,
erfcApprox);
}
Value materializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc,
ValueRange args, FloatType minPrecisionTy,
Value callback(ConversionPatternRewriter &,
Location, ValueRange)) {
auto originalTy = getElementTypeOrSelf(args.front().getType());
auto floatOriginalTy = originalTy.dyn_cast<FloatType>();
bool needsUpcast =
floatOriginalTy && floatOriginalTy.getWidth() < minPrecisionTy.getWidth();
// Upcast arguments if necessary.
llvm::SmallVector<Value, 2> castedArgs;
if (needsUpcast) {
for (Value a : args) {
castedArgs.push_back(
rewriter.create<mhlo::ConvertOp>(loc, a, minPrecisionTy));
}
args = castedArgs;
}
Value result = callback(rewriter, loc, args);
// Cast back if necessary.
if (needsUpcast) {
result = rewriter.create<mhlo::ConvertOp>(loc, result, originalTy);
}
return result;
}
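// For example, calling this helper on an f16 operand with minPrecisionTy set
// to f32 converts the operand to f32, evaluates the callback at f32
// precision, and converts the result back to f16. Operands whose element type
// is already at least as wide as minPrecisionTy pass through unchanged.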
struct ConvertErfOp : public OpConversionPattern<ErfOp> {
using OpConversionPattern<ErfOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
ErfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value x = adaptor.operand();
Type ty = x.getType().cast<ShapedType>().getElementType();
// For now, we support only f64, f32, f16 and bf16.
if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16())
return failure();
if (ty.isF64()) {
rewriter.replaceOp(op, materializeErfApproximationF64(rewriter, loc, x));
return success();
}
rewriter.replaceOp(
op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
rewriter.getF32Type(),
&materializeErfApproximationF32));
return success();
}
};
struct ConvertErfcOp : public OpConversionPattern<ErfcOp> {
using OpConversionPattern<ErfcOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
ErfcOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value x = adaptor.operand();
Type ty = x.getType().cast<ShapedType>().getElementType();
// For now, we support only f64, f32, f16 and bf16.
if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16())
return failure();
if (ty.isF64()) {
rewriter.replaceOp(op, materializeErfcApproximationF64(rewriter, loc, x));
return success();
}
rewriter.replaceOp(
op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
rewriter.getF32Type(),
&materializeErfcApproximationF32));
return success();
}
};
// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n (kLanczosGamma
// and kLanczosCoefficients.size() + 1). The coefficients below correspond to
// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and
// [7, 9] seemed to be the least sensitive to the quality of the log function.
// In particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
// for a particularly inaccurate log function.
constexpr double kLanczosGamma = 7; // aka g
constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
constexpr std::array<double, 8> kLanczosCoefficients = {
676.520368121885098567009190444019, -1259.13921672240287047156078755283,
771.3234287776530788486528258894, -176.61502916214059906584551354,
12.507343278686904814458936853, -0.13857109526572011689554707,
9.984369578019570859563e-6, 1.50563273514931155834e-7};
// Compute the Lgamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
// lgamma(z + 1) = (log(2) + log(pi)) / 2
// + (z + 1/2) * log(t(z))
// - t(z) + log(a(z))
// with t(z) = z + kLanczosGamma + 1/2
// a(z) = kBaseLanczosCoeff
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
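// As a sanity check, lgamma(1) = lgamma(2) = 0 since gamma(1) = gamma(2) = 1.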
Value materializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
ValueRange args) {
// If the input is less than 0.5, use Euler's reflection formula.
// gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
// Let z be
// z = -x if x < 1/2
// z = x - 1 otherwise
Value x = args.front();
Value half = getConstantLike(rewriter, loc, 0.5, x);
Value needToReflect = rewriter.create<mhlo::CompareOp>(
loc, x, half, mhlo::ComparisonDirection::LT);
Value negX = rewriter.create<mhlo::NegOp>(loc, x);
Value one = getConstantLike(rewriter, loc, 1, x);
Value xSubOne = rewriter.create<mhlo::SubOp>(loc, x, one);
Value z = rewriter.create<mhlo::SelectOp>(loc, needToReflect, negX, xSubOne);
// Materialize
// a(z) = kBaseLanczosCoeff
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
Value oneBasedIndex = getConstantLike(rewriter, loc, i + 1, x);
Value quotient = rewriter.create<mhlo::DivOp>(
loc, coeff, rewriter.create<mhlo::AddOp>(loc, z, oneBasedIndex));
a = rewriter.create<mhlo::AddOp>(loc, a, quotient);
}
// To improve accuracy on platforms with less-precise log implementations,
// compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
// device.
// Materialize as
// log(t) = log(kLanczosGamma + 1/2 + z)
// = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
Value lanczosPlusHalf =
getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
Value t = rewriter.create<mhlo::AddOp>(loc, lanczosPlusHalf, z);
Value logTerm =
getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
Value log1pTerm = rewriter.create<mhlo::Log1pOp>(
loc, rewriter.create<mhlo::DivOp>(loc, z, lanczosPlusHalf));
Value logT = rewriter.create<mhlo::AddOp>(loc, logTerm, log1pTerm);
// Note that t(z) may be large and we need to be careful not to overflow to
// infinity in the relevant term
// r = (z + 1/2) * log(t(z)) - t(z).
// Therefore, we compute this as
// r = (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
Value tDivLogT = rewriter.create<mhlo::DivOp>(loc, t, logT);
Value sum = rewriter.create<mhlo::SubOp>(
loc, rewriter.create<mhlo::AddOp>(loc, z, half), tDivLogT);
Value r = rewriter.create<mhlo::MulOp>(loc, sum, logT);
// Compute the final result (modulo reflection) as
// lgamma(z + 1) = (log(2) + log(pi)) / 2 + r + log(a(z)).
Value logA = rewriter.create<mhlo::LogOp>(loc, a);
Value lgamma = rewriter.create<mhlo::AddOp>(
loc,
rewriter.create<mhlo::AddOp>(
loc,
getConstantLike(rewriter, loc, (std::log(2) + std::log(M_PI)) / 2, x),
r),
logA);
// Compute the reflected value for x < 0.5 as
// lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
//
// The abs is needed because lgamma is the log of the absolute value of the
// gamma function.
//
// We have to be careful when computing the final term above. gamma(x) goes
// to +/-inf at every integer x < 0, and this is controlled by the sin(pi * x)
// term. The slope is large, so precision is particularly important.
//
// Because abs(sin(pi * x)) has a period of 1, we can equivalently use
// abs(sin(pi * frac(x))) where frac(x) is the fractional part of x. This is
// more numerically accurate: it doesn't overflow to inf like pi * x would,
// and if x is an integer it evaluates to exactly 0, which is important
// because we then take the log of this value, and log(0) is -inf.
//
// We don't have a frac(x) primitive in HLO and computing it is tricky, but
// because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for our
// purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
//
// Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
// to 1. To remedy this, we can use the fact that sin(pi * x) in the domain
// [0, 1] is symmetric across the line x = 0.5.
//
// Convert values of abs_frac > 0.5 to (1 - abs_frac) to improve precision of
// pi * abs_frac for values of abs_frac close to 1.
Value abs = rewriter.create<mhlo::AbsOp>(loc, x);
Value absFrac = rewriter.create<mhlo::SubOp>(
loc, abs, rewriter.create<mhlo::FloorOp>(loc, abs));
Value reduceAbsFrac = rewriter.create<mhlo::CompareOp>(
loc, half, absFrac, mhlo::ComparisonDirection::LT);
absFrac = rewriter.create<mhlo::SelectOp>(
loc, reduceAbsFrac, rewriter.create<mhlo::SubOp>(loc, one, absFrac),
absFrac);
// Materialize reflection.
Value reflectionDenom = rewriter.create<mhlo::LogOp>(
loc,
rewriter.create<mhlo::SinOp>(
loc, rewriter.create<mhlo::MulOp>(
loc, getConstantLike(rewriter, loc, M_PI, x), absFrac)));
Value lgammaReflection = rewriter.create<mhlo::SubOp>(
loc,
rewriter.create<mhlo::SubOp>(
loc, getConstantLike(rewriter, loc, std::log(M_PI), x),
reflectionDenom),
lgamma);
// Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf,
// then it "wins" and the result is +/-inf.
Value finiteReflectionDenom =
rewriter.create<mhlo::IsFiniteOp>(loc, reflectionDenom);
Value negReflectionDenom = rewriter.create<mhlo::NegOp>(loc, reflectionDenom);
lgammaReflection = rewriter.create<mhlo::SelectOp>(
loc, finiteReflectionDenom, lgammaReflection, negReflectionDenom);
// Select whether or not to rely on the reflection.
lgamma = rewriter.create<mhlo::SelectOp>(loc, needToReflect, lgammaReflection,
lgamma);
// Materialize +/-inf behavior as
// lgamma(+/-inf) = +inf.
Value xIsInf = rewriter.create<chlo::IsInfOp>(loc, x);
return rewriter.create<mhlo::SelectOp>(
loc, xIsInf,
chlo::getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false),
lgamma);
}
// Uses `rewriter` to materialize the IR for generating a constant tensor of
// log(1/2) values with the same shape and type as `operand`, and associates the
// generated IR to code location `loc`.
//
// Since we currently only support generating integer constants, we actually
// generate the code for -log(2) (which equals log(1/2)).
// TODO(b/190374484): Remove when mhlo::ConstantLikeOp supports complex types.
Value materializeLogOneHalf(ConversionPatternRewriter &rewriter, Location loc,
Value operand) {
auto resultTy = operand.getType().cast<ShapedType>();
Value two = rewriter.create<mhlo::ConstOp>(
loc, hlo::GetScalarOfType(getElementTypeOrSelf(operand.getType()), 2));
Value shape = rewriter.create<shape::ShapeOfOp>(loc, operand);
Value twoWithOperandShape = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc, resultTy, two, shape, rewriter.getI64TensorAttr({}));
Value logTwo = rewriter.create<mhlo::LogOp>(loc, twoWithOperandShape);
return rewriter.create<mhlo::NegOp>(loc, logTwo);
}
// Express `cosh` as
// cosh(x) = (e^x + e^-x) / 2
// = e^(x + log(1/2)) + e^(-x + log(1/2))
//
// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not.
//
// This incorrectly overflows to inf for two f32 input values, namely
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
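// As a sanity check, for x = 0 the second formulation yields
//   e^(0 + log(1/2)) + e^(-0 + log(1/2)) = 1/2 + 1/2 = 1 = cosh(0).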
Value materializeCoshApproximation(ConversionPatternRewriter &rewriter,
Location loc, ValueRange operands) {
CoshOp::Adaptor transformed(operands);
Value x = transformed.operand();
// TODO(b/190374484): Use mhlo::ConstantLikeOp when it supports complex types.
Value logOneHalf = materializeLogOneHalf(rewriter, loc, x);
Value expAdd = rewriter.create<mhlo::ExpOp>(
loc, rewriter.create<mhlo::AddOp>(loc, x, logOneHalf));
Value expSub = rewriter.create<mhlo::ExpOp>(
loc, rewriter.create<mhlo::SubOp>(loc, logOneHalf, x));
return rewriter.create<mhlo::AddOp>(loc, expAdd, expSub);
}
struct ConvertCoshOp : public OpConversionPattern<CoshOp> {
using OpConversionPattern<CoshOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
CoshOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(
op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
rewriter.getF32Type(),
&materializeCoshApproximation));
return success();
}
};
// Compute the Digamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
// digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z)
// with t(z) = z + kLanczosGamma + 1/2
// a(z) = kBaseLanczosCoeff
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
// a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
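// As a sanity check, digamma(1) = -0.5772156649... (the negated
// Euler-Mascheroni constant).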
Value materializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
ValueRange args) {
// If the input is less than 0.5, use Euler's reflection formula.
// digamma(x) = digamma(1 - x) - pi * cot(pi * x)
// Let z be
// z = -x if x < 1/2
// z = x - 1 otherwise
Value x = args.front();
Value half = getConstantLike(rewriter, loc, 0.5, x);
Value needToReflect = rewriter.create<mhlo::CompareOp>(
loc, x, half, mhlo::ComparisonDirection::LT);
Value negX = rewriter.create<mhlo::NegOp>(loc, x);
Value one = getConstantLike(rewriter, loc, 1, x);
Value xSubOne = rewriter.create<mhlo::SubOp>(loc, x, one);
Value z = rewriter.create<mhlo::SelectOp>(loc, needToReflect, negX, xSubOne);
// Materialize
// a(z) = kBaseLanczosCoeff
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
// a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
Value zero = getConstantLike(rewriter, loc, 0.0, x);
Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
Value aPrime = zero;
for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
Value oneBasedIndex = getConstantLike(rewriter, loc, i + 1, x);
Value zTerm = rewriter.create<mhlo::AddOp>(loc, z, oneBasedIndex);
aPrime = rewriter.create<mhlo::SubOp>(
loc, aPrime,
rewriter.create<mhlo::DivOp>(
loc, coeff, rewriter.create<mhlo::MulOp>(loc, zTerm, zTerm)));
a = rewriter.create<mhlo::AddOp>(
loc, a, rewriter.create<mhlo::DivOp>(loc, coeff, zTerm));
}
// To improve accuracy on platforms with less-precise log implementations,
// compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
// device.
// Materialize as
// log(t) = log(kLanczosGamma + 1/2 + z)
// = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
Value lanczosPlusHalf =
getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
Value t = rewriter.create<mhlo::AddOp>(loc, lanczosPlusHalf, z);
Value logTerm =
getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
Value log1pTerm = rewriter.create<mhlo::Log1pOp>(
loc, rewriter.create<mhlo::DivOp>(loc, z, lanczosPlusHalf));
Value logT = rewriter.create<mhlo::AddOp>(loc, logTerm, log1pTerm);
// Materialize the final result (modulo reflection) as
// digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z).
Value aPrimeDivA = rewriter.create<mhlo::DivOp>(loc, aPrime, a);
Value lanczosGammaDivT = rewriter.create<mhlo::DivOp>(
loc, getConstantLike(rewriter, loc, kLanczosGamma, x), t);
Value digamma = rewriter.create<mhlo::SubOp>(
loc, rewriter.create<mhlo::AddOp>(loc, logT, aPrimeDivA),
lanczosGammaDivT);
// We need to be careful how we compute cot(pi * input) below: For
// near-integral arguments, pi * input can lose precision.
//
// Input is already known to be less than 0.5 (otherwise we don't have to
// reflect). We shift values smaller than -0.5 into the range [-0.5, 0.5] to
// increase precision of pi * x and the resulting cotangent.
Value reducedX = rewriter.create<mhlo::AddOp>(
loc, x,
rewriter.create<mhlo::AbsOp>(
loc, rewriter.create<mhlo::FloorOp>(
loc, rewriter.create<mhlo::AddOp>(
loc, x, getConstantLike(rewriter, loc, 0.5, x)))));
// Materialize reflection for inputs less than 0.5 as
// digamma(x) = digamma(1 - x) - pi * cot(pi * x)
// = digamma(1 - x) - pi * cos(pi * x) / sin(pi * x)
Value pi = getConstantLike(rewriter, loc, M_PI, x);
Value piMulReducedX = rewriter.create<mhlo::MulOp>(loc, pi, reducedX);
Value cos = rewriter.create<mhlo::CosOp>(loc, piMulReducedX);
Value sin = rewriter.create<mhlo::SinOp>(loc, piMulReducedX);
Value reflection = rewriter.create<mhlo::SubOp>(
loc, digamma,
rewriter.create<mhlo::DivOp>(
loc, rewriter.create<mhlo::MulOp>(loc, pi, cos), sin));
// Select whether or not to rely on the reflection.
digamma =
rewriter.create<mhlo::SelectOp>(loc, needToReflect, reflection, digamma);
// Digamma has poles at negative integers and zero; return nan for those.
Value isLeZero = rewriter.create<mhlo::CompareOp>(
loc, x, zero, mhlo::ComparisonDirection::LE);
Value isInt = rewriter.create<mhlo::CompareOp>(
loc, x, rewriter.create<mhlo::FloorOp>(loc, x),
mhlo::ComparisonDirection::EQ);
Value isPole = rewriter.create<mhlo::AndOp>(loc, isLeZero, isInt);
return rewriter.create<mhlo::SelectOp>(
loc, isPole,
getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
x),
digamma);
}
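// Materializes an approximation of the Hurwitz zeta function
//   zeta(x, q) = sum_{k = 0}^{inf} 1 / (q + k)^x
// using an initial series estimate plus an Euler-Maclaurin correction term,
// in the spirit of the Cephes implementation referenced below.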
Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc,
ValueRange args) {
assert(args.size() == 2);
Value x = args[0];
Value q = args[1];
static const std::array<double, 12> kZetaCoeffs{
-7.1661652561756670113e18,
1.8152105401943546773e17,
-4.5979787224074726105e15,
1.1646782814350067249e14,
-2.950130727918164224e12,
7.47242496e10,
-1.8924375803183791606e9,
47900160.0,
-1209600.0,
30240.0,
-720.0,
12.0,
};
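// The reciprocals 1 / kZetaCoeffs[i] correspond to the Euler-Maclaurin
// coefficients B_{2k} / (2k)! in decreasing order of k, e.g.
// 1 / 12 = B_2 / 2! and -1 / 720 = B_4 / 4!.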
// For speed we'll always use 9 iterations for the initial series estimate,
// and a 12 term expansion for the Euler-Maclaurin formula.
Value a = q;
Value zero = chlo::getConstantLike(rewriter, loc, 0.0, a);
Value negPower = zero;
Value negX = rewriter.create<mhlo::NegOp>(loc, x);
Value initialSum = rewriter.create<mhlo::PowOp>(loc, q, negX);
Value one = chlo::getConstantLike(rewriter, loc, 1.0, a);
for (int i = 0; i < 9; ++i) {
a = rewriter.create<mhlo::AddOp>(loc, a, one);
negPower = rewriter.create<mhlo::PowOp>(loc, a, negX);
initialSum = rewriter.create<mhlo::AddOp>(loc, initialSum, negPower);
}
a = rewriter.create<mhlo::AddOp>(loc, a, one);
negPower = rewriter.create<mhlo::PowOp>(loc, a, negX);
Value oneLikeX = chlo::getConstantLike(rewriter, loc, 1.0, x);
Value xMinusOne = rewriter.create<mhlo::SubOp>(loc, x, oneLikeX);
Value negPowerMulA = rewriter.create<mhlo::MulOp>(loc, negPower, a);
Value negPowerMulADivXMinusOne =
rewriter.create<mhlo::DivOp>(loc, negPowerMulA, xMinusOne);
Value s =
rewriter.create<mhlo::AddOp>(loc, initialSum, negPowerMulADivXMinusOne);
Value aInverseSquare = rewriter.create<mhlo::DivOp>(
loc, one, rewriter.create<mhlo::MulOp>(loc, a, a));
Value hornerSum = zero;
Value factor = one;
// Use Horner's rule for this.
// Note this differs from Cephes, which does a 'naive' polynomial evaluation.
// Using Horner's rule allows us to avoid some NaNs and infs, resulting in
// more numerically stable code.
for (int i = 0; i < 11; ++i) {
Value factorLhs = rewriter.create<mhlo::SubOp>(
loc, x, chlo::getConstantLike(rewriter, loc, 22 - 2 * i, x));
Value factorRhs = rewriter.create<mhlo::SubOp>(
loc, x, chlo::getConstantLike(rewriter, loc, 21 - 2 * i, x));
factor = rewriter.create<mhlo::MulOp>(loc, factorLhs, factorRhs);
hornerSum = rewriter.create<mhlo::MulOp>(
loc, factor,
rewriter.create<mhlo::MulOp>(
loc, aInverseSquare,
rewriter.create<mhlo::AddOp>(
loc, hornerSum,
chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a))));
}
Value zeroPointFiveLikeNegPower =
chlo::getConstantLike(rewriter, loc, .5, negPower);
Value xDivA = rewriter.create<mhlo::DivOp>(loc, x, a);
s = rewriter.create<mhlo::AddOp>(
loc, s,
rewriter.create<mhlo::MulOp>(
loc, negPower,
rewriter.create<mhlo::AddOp>(
loc, zeroPointFiveLikeNegPower,
rewriter.create<mhlo::MulOp>(
loc, xDivA,
rewriter.create<mhlo::AddOp>(
loc,
chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11],
a),
hornerSum)))));
// Use the initial zeta sum without the correction term coming
// from Euler-Maclaurin if it is accurate enough.
Value absNegPower = rewriter.create<mhlo::AbsOp>(loc, negPower);
Value absInitialSum = rewriter.create<mhlo::AbsOp>(loc, initialSum);
Value output = rewriter.create<mhlo::SelectOp>(
loc,
rewriter.create<mhlo::CompareOp>(
loc, absNegPower,
rewriter.create<mhlo::MulOp>(
loc, absInitialSum,
chlo::getConstantLikeSmallestFiniteValue(rewriter, loc, a)),
mhlo::ComparisonDirection::LT),
initialSum, s);
// Function is not defined for x < 1.
Value nan = chlo::getConstantLike(
rewriter, loc, std::numeric_limits<double>::quiet_NaN(), x);
output = rewriter.create<mhlo::SelectOp>(
loc,
rewriter.create<mhlo::CompareOp>(loc, x, oneLikeX,
mhlo::ComparisonDirection::LT),
nan, output);
// For q <= 0, x must be an integer.
Value qLeZero = rewriter.create<mhlo::CompareOp>(
loc, q, zero, mhlo::ComparisonDirection::LE);
Value xNotInt = rewriter.create<mhlo::CompareOp>(
loc, x, rewriter.create<mhlo::FloorOp>(loc, x),
mhlo::ComparisonDirection::NE);
Value xDomainError = rewriter.create<mhlo::AndOp>(loc, qLeZero, xNotInt);
output = rewriter.create<mhlo::SelectOp>(loc, xDomainError, nan, output);
// For all integer q <= 0, zeta has a pole. The limit is only defined as
// +inf if x is an even integer.
Value inf = chlo::getConstantLike(rewriter, loc,
std::numeric_limits<double>::infinity(), x);
Value qIsInt = rewriter.create<mhlo::CompareOp>(
loc, q, rewriter.create<mhlo::FloorOp>(loc, q),
mhlo::ComparisonDirection::EQ);
Value atPole = rewriter.create<mhlo::AndOp>(loc, qLeZero, qIsInt);
Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
Value xIsInt = rewriter.create<mhlo::CompareOp>(
loc, x, rewriter.create<mhlo::FloorOp>(loc, x),
mhlo::ComparisonDirection::EQ);
Value xIsEven = rewriter.create<mhlo::CompareOp>(
loc, rewriter.create<mhlo::RemOp>(loc, x, two), zero,
mhlo::ComparisonDirection::EQ);
Value xIsEvenInt = rewriter.create<mhlo::AndOp>(loc, xIsInt, xIsEven);
output = rewriter.create<mhlo::SelectOp>(
loc, atPole, rewriter.create<mhlo::SelectOp>(loc, xIsEvenInt, inf, nan),
output);
// For x = 1, this is the harmonic series and diverges.
output = rewriter.create<mhlo::SelectOp>(
loc,
rewriter.create<mhlo::CompareOp>(loc, x, one,
mhlo::ComparisonDirection::EQ),
inf, output);
return output;
}
Value materializePolygamma(ConversionPatternRewriter &rewriter, Location loc,
ValueRange args) {
PolygammaOp::Adaptor transformed(args);
Value n = transformed.n();
Value x = transformed.x();
// Handle integer n > 0.
Value one = getConstantLike(rewriter, loc, 1.0, x);
Value two = getConstantLike(rewriter, loc, 2.0, x);
Value sign = rewriter.create<mhlo::SubOp>(
loc,
rewriter.create<mhlo::MulOp>(loc, two,
rewriter.create<mhlo::RemOp>(loc, n, two)),
one);
Value nPlusOne = rewriter.create<mhlo::AddOp>(loc, n, one);
Value expLgammaNp1 = rewriter.create<mhlo::ExpOp>(
loc, rewriter.create<chlo::LgammaOp>(loc, nPlusOne));
Value zeta = rewriter.create<chlo::ZetaOp>(loc, nPlusOne, x);
Value result = rewriter.create<mhlo::MulOp>(
loc, rewriter.create<mhlo::MulOp>(loc, sign, expLgammaNp1), zeta);
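// Altogether this materializes the identity
//   polygamma(n, x) = (-1)^(n + 1) * n! * zeta(n + 1, x),
// where (-1)^(n + 1) is computed as 2 * (n mod 2) - 1 and n! as
// exp(lgamma(n + 1)).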
// Handle n = 0.
Value zero = getConstantLike(rewriter, loc, 0.0, x);
Value nEqZero = rewriter.create<mhlo::CompareOp>(
loc, n, zero, mhlo::ComparisonDirection::EQ);
result = rewriter.create<mhlo::SelectOp>(
loc, nEqZero, rewriter.create<chlo::DigammaOp>(loc, x), result);
// Check that n is a natural number. Return nan otherwise.
Value nonInt = rewriter.create<mhlo::CompareOp>(
loc, n, rewriter.create<mhlo::FloorOp>(loc, n),
mhlo::ComparisonDirection::NE);
Value negative = rewriter.create<mhlo::CompareOp>(
loc, n, zero, mhlo::ComparisonDirection::LT);
Value nonNatural = rewriter.create<mhlo::OrOp>(loc, nonInt, negative);
return rewriter.create<mhlo::SelectOp>(
loc, nonNatural,
getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
x),
result);
}
struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
using OpConversionPattern<LgammaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
LgammaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FloatType minPrecisionTy = rewriter.getF32Type();
rewriter.replaceOp(
op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
minPrecisionTy, &materializeLgamma));
return success();
}
};
struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
using OpConversionPattern<DigammaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
DigammaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FloatType minPrecisionTy = rewriter.getF32Type();
rewriter.replaceOp(
op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
minPrecisionTy, &materializeDigamma));
return success();
}
};
Value materializeNextAfter(ConversionPatternRewriter &rewriter, Location loc,
ValueRange operands) {
NextAfterOp::Adaptor transformed(operands);
Value x = transformed.x();
Value y = transformed.y();
auto resultTy = x.getType().cast<ShapedType>();
auto bitwidth = resultTy.getElementType().getIntOrFloatBitWidth();
ImplicitLocOpBuilder b(loc, rewriter);
auto intTy = resultTy.clone(b.getIntegerType(bitwidth));
auto xAsInt = b.create<mhlo::BitcastConvertOp>(intTy, x);
auto yAsInt = b.create<mhlo::BitcastConvertOp>(intTy, y);
// The result is NaN if either "x" or "y" are NaN.
auto xIsNan = b.create<mhlo::CompareOp>(x, x, mhlo::ComparisonDirection::NE);
auto yIsNan = b.create<mhlo::CompareOp>(y, y, mhlo::ComparisonDirection::NE);
auto nanInput = b.create<mhlo::OrOp>(xIsNan, yIsNan);
auto resultForNan = getConstantLike(
rewriter, loc, std::numeric_limits<double>::quiet_NaN(), x);
auto resultForNanAsInt =
b.create<mhlo::BitcastConvertOp>(intTy, resultForNan);
// The sign bit is the MSB.
const int64_t signBit = int64_t{1} << (bitwidth - 1);
// Discard the sign bit to make the result non-negative.
auto signMask = getConstantLike(rewriter, loc, signBit, xAsInt);
auto negatedSignMask = getConstantLike(rewriter, loc, ~signBit, xAsInt);
auto xAbs = b.create<mhlo::AndOp>(xAsInt, negatedSignMask);
auto yAbs = b.create<mhlo::AndOp>(yAsInt, negatedSignMask);
// When both "x" and "y" are equal, the result is "y".
auto xAndYAreEqual =
b.create<mhlo::CompareOp>(x, y, mhlo::ComparisonDirection::EQ);
auto resultForEqual = yAsInt;
// When both "x" and "y" are 0, the result is "y". This is a separate case
// from above because "x" and "y" might have a different sign.
auto zero = getConstantLike(rewriter, loc, 0, xAsInt);
auto xIsZero =
b.create<mhlo::CompareOp>(xAbs, zero, mhlo::ComparisonDirection::EQ);
auto yIsZero =
b.create<mhlo::CompareOp>(yAbs, zero, mhlo::ComparisonDirection::EQ);
auto resultForBothZero = yAsInt;
auto xSign = b.create<mhlo::AndOp>(xAsInt, signMask);
auto ySign = b.create<mhlo::AndOp>(yAsInt, signMask);
// If from == 0 && to != 0, we need to return the smallest subnormal number
// signed like "to".
auto one = getConstantLike(rewriter, loc, 1, xAsInt);
auto resultForXZeroYNonZero = b.create<mhlo::OrOp>(ySign, one);
// If the sign of "x" and "y" disagree:
// - we need to make the magnitude of "from" smaller so that it is closer to
// zero.
//
// Otherwise the signs agree:
// - "x" with a magnitude larger than "y" means we need to make the magnitude
// smaller.
// - "x" with a magnitude smaller than "y" means we need to make the magnitude
// larger.
auto signsDisagree =
b.create<mhlo::CompareOp>(xSign, ySign, mhlo::ComparisonDirection::NE);
auto xMagnitudeLargerThanY =
b.create<mhlo::CompareOp>(xAbs, yAbs, mhlo::ComparisonDirection::GT);
auto resultHasSmallerMagnitude =
b.create<mhlo::OrOp>(xMagnitudeLargerThanY, signsDisagree);
auto minusOne = getConstantLike(rewriter, loc, -1, xAsInt);
auto magnitudeAdjustment =
b.create<mhlo::SelectOp>(resultHasSmallerMagnitude, minusOne, one);
Value result = b.create<mhlo::AddOp>(xAsInt, magnitudeAdjustment);
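// For example, with f32 arguments x = 1.0 (0x3f800000) and y = 2.0, the signs
// agree and |x| < |y|, so the adjustment is +1 and the result is 0x3f800001,
// the next representable float after 1.0.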
// Handle "x" == +-0.
result = b.create<mhlo::SelectOp>(
xIsZero,
b.create<mhlo::SelectOp>(yIsZero, resultForBothZero,
resultForXZeroYNonZero),
result);
// Handle "x" == "y".
result = b.create<mhlo::SelectOp>(xAndYAreEqual, resultForEqual, result);
// Handle isnan(x) || isnan(y).
result = b.create<mhlo::SelectOp>(nanInput, resultForNanAsInt, result);
// Cast back to the original type.
return b.create<mhlo::BitcastConvertOp>(resultTy, result);
}
struct ConvertNextAfterOp : public OpConversionPattern<NextAfterOp> {
using OpConversionPattern<NextAfterOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
NextAfterOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(
op, materializeNextAfter(rewriter, op.getLoc(), adaptor.getOperands()));
return success();
}
};
struct ConvertPolygammaOp : public OpConversionPattern<PolygammaOp> {
using OpConversionPattern<PolygammaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
PolygammaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
FloatType minPrecisionTy = rewriter.getF32Type();
rewriter.replaceOp(
op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
minPrecisionTy, &materializePolygamma));
return success();
}
};
// Sinh(x) = (e^x - e^-x) / 2
// = e^(x + log(1/2)) - e^(-x + log(1/2)).
//
// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not
// inf.
//
// This incorrectly overflows to +/-inf for two f32 input values, namely
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
Value materializeSinhApproximationForLargeX(ConversionPatternRewriter &rewriter,
Location loc, ValueRange operands) {
SinhOp::Adaptor transformed(operands);
Value x = transformed.operand();
// TODO(b/190374484): Use mhlo::ConstantLikeOp when it supports complex types.
Value logOneHalf = materializeLogOneHalf(rewriter, loc, x);
Value expAdd = rewriter.create<mhlo::ExpOp>(
loc, rewriter.create<mhlo::AddOp>(loc, x, logOneHalf));
Value expSub = rewriter.create<mhlo::ExpOp>(
loc, rewriter.create<mhlo::SubOp>(loc, logOneHalf, x));
return rewriter.create<mhlo::SubOp>(loc, expAdd, expSub);
}
// Express `sinh` as
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1
// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
Value materializeSinhApproximation(ConversionPatternRewriter &rewriter,
Location loc, ValueRange operands) {
Value largeSinhResult =
materializeSinhApproximationForLargeX(rewriter, loc, operands);
SinhOp::Adaptor transformed(operands);
Value x = transformed.operand();
// For smaller x, we get unwanted cancellations of e^x - e^-x, resulting in
// 0.
// Rewrite this to avoid that. We use expm1(x) because that preserves the
// first order term of the Taylor series of e^x:
//   (e^(x) - e^(-x)) / 2
//     = (e^(x) - 1 + 1 - e^(-x)) / 2
//     = (expm1(x) + (e^(x) - 1) / e^x) / 2
//     = (expm1(x) + expm1(x) / (expm1(x) + 1)) / 2.
Value expm1 = rewriter.create<mhlo::Expm1Op>(loc, x);
Value one = getConstantLike(rewriter, loc, 1.0, x);
Value oneHalf = getConstantLike(rewriter, loc, 0.5, x);
Value expm1PlusOne = rewriter.create<mhlo::AddOp>(loc, expm1, one);
Value ratio = rewriter.create<mhlo::DivOp>(loc, expm1, expm1PlusOne);
Value sum = rewriter.create<mhlo::AddOp>(loc, expm1, ratio);
Value smallSinhResult = rewriter.create<mhlo::MulOp>(loc, oneHalf, sum);
Value absX = rewriter.create<mhlo::AbsOp>(loc, x);
Value absXLtOne = rewriter.create<mhlo::CompareOp>(
loc, absX, one, mhlo::ComparisonDirection::LT);
return rewriter.create<mhlo::SelectOp>(loc, absXLtOne, smallSinhResult,
largeSinhResult);
}
struct ConvertSinhOp : public OpConversionPattern<SinhOp> {
using OpConversionPattern<SinhOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
SinhOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value x = adaptor.operand();
if (x.getType().cast<ShapedType>().getElementType().isa<ComplexType>()) {
rewriter.replaceOp(op, materializeSinhApproximationForLargeX(
rewriter, op.getLoc(), adaptor.getOperands()));
return success();
}
rewriter.replaceOp(
op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
rewriter.getF32Type(),
&materializeSinhApproximation));
return success();
}
};
Value materializeTan(ConversionPatternRewriter &rewriter, Location loc,
ValueRange operands) {
TanOp::Adaptor transformed(operands);
return rewriter.create<mhlo::DivOp>(
loc, rewriter.create<mhlo::SinOp>(loc, transformed.operand()),
rewriter.create<mhlo::CosOp>(loc, transformed.operand()));
}
struct ConvertTanOp : public OpConversionPattern<TanOp> {
using OpConversionPattern<TanOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
TanOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(
op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
rewriter.getF32Type(), &materializeTan));
return success();
}
};
// Converts chlo.top_k to MHLO iota, sort, and slice ops.
//
// chlo.top_k sorts along the last dimension of the input tensor and then
// returns the top K components' values and indices. This is translated into a
// few ops in MHLO: first generating an integer sequence for the indices, then
// sorting both the original input tensor and the indices together, and
// finally slicing out the top K components.
//
// For example, for the following IR:
//
// %0:2 = "chlo.top_k"(%input, k=8): tensor<16x16xf32> ->
// (tensor<16x8xf32>, tensor<16x8xi32>)
//
// We will get:
//
// %1 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32>
// %2 = "mhlo.sort"(%input, %1) ({
// ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>,
// %arg3: tensor<i32>, %arg4: tensor<i32>):
// %7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ...
// "mhlo.return"(%7) : (tensor<i1>) -> ()
// }) {dimension = 1 : i64, is_stable = true} : ...
// %3 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : ...
// %4 = "mhlo.get_tuple_element"(%2) {index = 1 : i32} : ...
// %5 = "mhlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>,
// start_indices dense<0> : tensor<2xi64>,
// strides = dense<1> : tensor<2xi64>} :
// (tensor<16x16xf32>) -> tensor<16x8xf32>
// %6 = "mhlo.slice"(%4) ...
struct ConvertTopKOp : public OpConversionPattern<TopKOp> {
using OpConversionPattern<TopKOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
TopKOp op, OpAdaptor /*adaptor*/,
ConversionPatternRewriter &rewriter) const override {
// The last dimension of the operand's shape should be known so we can have
// clamped end_indices for slices. This is verified by the op.
auto operandType = op.operand().getType().cast<RankedTensorType>();
int64_t operandRank = operandType.getRank();
int64_t lastDimIndex = operandRank - 1;
int64_t lastDimSize = operandType.getDimSize(lastDimIndex);
assert(lastDimSize != ShapedType::kDynamicSize);
// Create an Iota op for indices.
auto i32Type = rewriter.getIntegerType(32);
Type iotaType = RankedTensorType::get(operandType.getShape(), i32Type);
Value iotaOp = rewriter.create<mhlo::IotaOp>(
op.getLoc(), iotaType, rewriter.getI64IntegerAttr(lastDimIndex));
// Create the sort op. It takes two inputs, one for the original input, the
// other for the indices. Use TOTALORDER comparison type instead of the
// default comparison if the element type is of type float.
Type elementType = operandType.getElementType();
auto sortOp = CreateSortOp(&rewriter, op.getLoc(), {op.operand(), iotaOp},
{elementType, i32Type}, lastDimIndex,
/*is_stable=*/true,
/*direction=*/mhlo::ComparisonDirection::GT);
// Get the sorted values and the corresponding indices from the sort results.
auto sortedValues = sortOp.getResult(0);
auto sortedIndices = sortOp.getResult(1);
SmallVector<int64_t, 4> beginIndices(operandRank, 0);
auto endIndices = llvm::to_vector<4>(operandType.getShape());
endIndices.back() = std::min(static_cast<int64_t>(op.k()), lastDimSize);
SmallVector<int64_t, 4> strides(operandRank, 1);
// Get the slice for the top K elements.
auto indicesTy = RankedTensorType::get(operandRank, rewriter.getI64Type());
Value values = rewriter.create<mhlo::SliceOp>(
op.getLoc(), sortedValues,
DenseIntElementsAttr::get(indicesTy, beginIndices),
DenseIntElementsAttr::get(indicesTy, endIndices),
DenseIntElementsAttr::get(indicesTy, strides));
Value indices = rewriter.create<mhlo::SliceOp>(
op.getLoc(), sortedIndices,
DenseIntElementsAttr::get(indicesTy, beginIndices),
DenseIntElementsAttr::get(indicesTy, endIndices),
DenseIntElementsAttr::get(indicesTy, strides));
rewriter.replaceOp(op, {values, indices});
return success();
}
};
struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
using OpConversionPattern<ZetaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
ZetaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
FloatType minPrecisionTy = rewriter.getF32Type();
rewriter.replaceOp(
op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
minPrecisionTy, &materializeZeta));
return success();
}
};
struct ConvertSelectOp : public OpConversionPattern<BroadcastSelectOp> {
using OpConversionPattern<BroadcastSelectOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
BroadcastSelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only support ranked operands.
Value pred = adaptor.pred();
Value onTrue = adaptor.on_true();
Value onFalse = adaptor.on_false();
auto predType = pred.getType().dyn_cast<RankedTensorType>();
auto onTrueType = onTrue.getType().dyn_cast<RankedTensorType>();
auto onFalseType = onFalse.getType().dyn_cast<RankedTensorType>();
auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
if (!predType || !onTrueType || !onFalseType || !resultType) {
return failure();
}
auto loc = op.getLoc();
Value predShape = rewriter.createOrFold<shape::ShapeOfOp>(loc, pred);
Value onTrueShape = rewriter.createOrFold<shape::ShapeOfOp>(loc, onTrue);
Value onFalseShape = rewriter.createOrFold<shape::ShapeOfOp>(loc, onFalse);
int64_t resultRank = std::max(
{predType.getRank(), onTrueType.getRank(), onFalseType.getRank()});
Value broadcastableCstr = rewriter.createOrFold<shape::CstrBroadcastableOp>(
loc, ValueRange{predShape, onTrueShape, onFalseShape});
auto assumingOp = rewriter.create<shape::AssumingOp>(
loc, ArrayRef<Type>{resultType}, broadcastableCstr);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.createBlock(&assumingOp.getDoRegion());
Value resultExtents = rewriter.createOrFold<shape::BroadcastOp>(
loc, shape::getExtentTensorType(op.getContext()),
ValueRange{predShape, onTrueShape, onFalseShape},
/*error=*/nullptr);
auto shapeType =
RankedTensorType::get({resultRank}, rewriter.getIndexType());
resultExtents =
rewriter.createOrFold<tensor::CastOp>(loc, shapeType, resultExtents);
Value broadcastedPred = pred;
// Pred has an implicit broadcast for scalars, so use that when convenient.
if (predType.getRank() > 0) {
auto predBroadcastDimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(resultRank - predType.getRank(), resultRank));
broadcastedPred = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(resultType.getShape(),
predType.getElementType()),
pred, resultExtents,
rewriter.getI64TensorAttr(predBroadcastDimensions));
}
auto onTrueBroadcastDimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(resultRank - onTrueType.getRank(), resultRank));
Value broadcastedOnTrue = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(resultType.getShape(),
onTrueType.getElementType()),
onTrue, resultExtents,
rewriter.getI64TensorAttr(onTrueBroadcastDimensions));
auto onFalseBroadcastDimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(resultRank - onFalseType.getRank(), resultRank));
Value broadcastedOnFalse = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(resultType.getShape(),
onFalseType.getElementType()),
onFalse, resultExtents,
rewriter.getI64TensorAttr(onFalseBroadcastDimensions));
// And generate the final non-broadcasted ternary op.
Value finalResult =
rewriter.create<mhlo::SelectOp>(loc, resultType, broadcastedPred,
broadcastedOnTrue, broadcastedOnFalse);
rewriter.create<shape::AssumingYieldOp>(loc, finalResult);
rewriter.replaceOp(op, {assumingOp.getResult(0)});
return success();
}
};
// Converts binary ops that statically are determined to not broadcast directly
// to the corresponding mhlo non-broadcasting op.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertTrivialNonBroadcastBinaryOp
: public OpConversionPattern<ChloOpTy> {
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
ChloOpTy op, typename ChloOpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only rewrite for statically determinable non-broadcasting cases.
auto lhsType =
adaptor.lhs().getType().template dyn_cast<RankedTensorType>();
auto rhsType =
adaptor.rhs().getType().template dyn_cast<RankedTensorType>();
if (!lhsType || !rhsType) return failure();
// Mismatched ranks would require a rank broadcast.
if (lhsType.getRank() != rhsType.getRank()) return failure();
// Any dynamic dimension may require broadcasting and requires more
// analysis.
if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape())
return failure();
for (auto extents : llvm::zip(lhsType.getShape(), rhsType.getShape())) {
auto lhsExtent = std::get<0>(extents);
auto rhsExtent = std::get<1>(extents);
if (lhsExtent != rhsExtent) {
return failure();
}
}
rewriter.replaceOp(op,
{Adaptor::CreateOp(op, op.getResult().getType(),
adaptor.getOperands(), rewriter)});
return success();
}
};
// Converts a binary op with ranked broadcasting operands to explicitly
// broadcast and invoke the corresponding mhlo non-broadcasting op.
// Note that dynamic broadcasting supported by this pattern is only valid for
// "numpy" broadcasting semantics as defined here:
// https://docs.scipy.org/doc/numpy/reference/ufuncs.html
// Specifically, this includes the following cases:
// - Same rank broadcast (operands have the same static rank).
// - Different-rank broadcast, either without a broadcast_dims attribute or
// with the broadcast_dims attribute set to map to a prefix padding.
// - Legal combinations of degenerate (1-dim) implicit broadcasting.
// The restriction on broadcast_dims derives from the definition of the
// `shape.broadcast` op, which only supports prefix-padding.
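// For example, broadcasting lhs : tensor<?x4xf32> against rhs : tensor<4xf32>
// yields broadcast dimensions [0, 1] for lhs and [1] for rhs, i.e. rhs is
// prefix-padded to rank 2 before the mhlo op is applied.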
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertRankedDynamicBroadcastBinaryOp
: public OpConversionPattern<ChloOpTy> {
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
ChloOpTy op, typename ChloOpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only support ranked operands.
Value lhs = adaptor.lhs();
Value rhs = adaptor.rhs();
auto lhsType = lhs.getType().dyn_cast<RankedTensorType>();
auto rhsType = rhs.getType().dyn_cast<RankedTensorType>();
auto resultType =
op.getResult().getType().template dyn_cast<RankedTensorType>();
if (!lhsType || !rhsType || !resultType) return failure();
// Check for "numpy"-style rank broadcast.
auto broadcastDimensions = op.broadcast_dimensions();
if (broadcastDimensions &&
!hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcastDimensions)) {
// Note: It is unclear whether the general specification of explicit
// broadcast_dimensions on binary ops is a feature we want to carry
// forward. While it can technically be implemented for ranked-dynamic,
// it is incompatible with unranked inputs. If this warning is emitted
// in real programs, it is an indication that the feature should be
// implemented versus just falling back on the more standard definition
// of numpy-like prefix-padding.
op.emitWarning() << "unsupported non prefix-padded dynamic rank "
<< "broadcast_dimensions = " << *broadcastDimensions;
return failure();
}
// Compute result shape.
auto loc = op.getLoc();
// Insert a constraint on the shapes being broadcastable and insert all
// future code into an assuming block reliant on the constraint.
Value lhsShape = rewriter.create<shape::ShapeOfOp>(loc, lhs);
Value rhsShape = rewriter.create<shape::ShapeOfOp>(loc, rhs);
auto broadcastableCstr =
rewriter.create<shape::CstrBroadcastableOp>(loc, lhsShape, rhsShape);
auto assumingOp = rewriter.create<shape::AssumingOp>(
loc, ArrayRef<Type>{resultType}, broadcastableCstr.getResult());
OpBuilder::InsertionGuard guard(rewriter);
rewriter.createBlock(&assumingOp.getDoRegion());
int64_t resultRank = std::max(lhsType.getRank(), rhsType.getRank());
Value resultExtents =
hlo::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
rewriter);
// Note that we unconditionally emit DynamicBroadcastInDim ops and let
// downstream canonicalizations fold them away if possible. This is
// because, in the dynamic case, there are many corner cases regarding
// when it is safe to omit, and some of them require analysis to prove
// properly.
auto lhsBroadcastDimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(resultRank - lhsType.getRank(), resultRank));
Value broadcastedLhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(resultType.getShape(), lhsType.getElementType()),
lhs, resultExtents, rewriter.getI64TensorAttr(lhsBroadcastDimensions));
auto rhsBroadcastDimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(resultRank - rhsType.getRank(), resultRank));
Value broadcastedRhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(resultType.getShape(), rhsType.getElementType()),
rhs, resultExtents, rewriter.getI64TensorAttr(rhsBroadcastDimensions));
// And generate the final non-broadcasted binary op.
Value finalResult = Adaptor::CreateOp(
op, resultType, {broadcastedLhs, broadcastedRhs}, rewriter);
rewriter.create<shape::AssumingYieldOp>(loc, finalResult);
rewriter.replaceOp(op, {assumingOp.getResult(0)});
return success();
}
};
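// Lowers chlo.dynamic_reshape to mhlo.dynamic_reshape, guarded by a
// mhlo.cstr_reshapable constraint that is consumed by a shape.assuming
// region.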
class ConvertDynamicReshapeOp
: public OpRewritePattern<chlo::DynamicReshapeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(chlo::DynamicReshapeOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto tensor = op.operand();
auto shape = op.output_shape();
auto shapeTy = shape.getType().cast<ShapedType>();
auto resultTy = op.getType().cast<ShapedType>();
Value inputShape = rewriter.create<shape::ShapeOfOp>(loc, tensor);
Value numEls = rewriter.create<shape::NumElementsOp>(loc, inputShape);
Value cstr = rewriter.create<mhlo::CstrReshapableOp>(loc, numEls, shape);
rewriter.replaceOpWithNewOp<shape::AssumingOp>(
op, cstr, [&](OpBuilder &b, Location l) {
Value computedShape =
b.create<mhlo::ComputeReshapeShapeOp>(l, shapeTy, numEls, shape);
SmallVector<Value> result;
result.push_back(b.create<mhlo::DynamicReshapeOp>(l, resultTy, tensor,
computedShape));
return result;
});
return success();
}
};
#include "generated_chlo_legalize_to_hlo.inc"
} // namespace
void PopulateChloBroadcastingPatterns(MLIRContext *context,
RewritePatternSet *patterns) {
// Instantiate conversion templates for conforming binary elementwise ops
// that do not have different dtypes between operands and results and do
// not have special attributes that need to be preserved.
PopulateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>(
context, patterns, 10);
PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
context, patterns, 5);
patterns
->add<ConvertConstantLikeOp, ConvertDynamicReshapeOp, ConvertSelectOp>(
context);
}
void PopulateDecomposeChloPatterns(MLIRContext *context,
RewritePatternSet *patterns) {
populateWithGenerated(*patterns);
// Other patterns.
// clang-format off
patterns->add<ConvertCoshOp,
ConvertDigammaOp,
ConvertErfOp,
ConvertErfcOp,
ConvertLgammaOp,
ConvertNextAfterOp,
ConvertPolygammaOp,
ConvertSinhOp,
ConvertTanOp,
ConvertTopKOp,
ConvertZetaOp>(context);
// clang-format on
}
} // namespace chlo
} // namespace mlir