| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // Enable the use of M_* math constants. |
| // NOTE: this must be first in the file to ensure that if cmath is transitively |
| // included by any other header it has the define set on first processing. |
| // https://docs.microsoft.com/en-us/cpp/c-runtime-library/math-constants |
| #define _USE_MATH_DEFINES |
| #include <cmath> |
| #include <numeric> |
| #include <vector> |
| |
| #include "llvm/ADT/SmallVector.h" |
| #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" |
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" |
| #include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h" |
| #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" |
| #include "mlir-hlo/utils/broadcast_utils.h" |
| #include "mlir-hlo/utils/hlo_utils.h" |
| #include "mlir/Dialect/SCF/SCF.h" |
| #include "mlir/Dialect/Shape/IR/Shape.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/ImplicitLocOpBuilder.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/IR/OperationSupport.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| |
| namespace mlir { |
| namespace chlo { |
| namespace { |
| |
| struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> { |
| using OpConversionPattern<ConstantLikeOp>::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| ConstantLikeOp op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto result_ty = op.getType().cast<ShapedType>(); |
| |
| // Unranked uses are not supported. Consider `mhlo-transform-unranked-hlo`. |
| if (!result_ty.hasRank()) return failure(); |
| |
| // Lower to MHLO constant if statically shaped. |
| if (result_ty.hasStaticShape()) { |
| rewriter.replaceOpWithNewOp<mhlo::ConstOp>( |
| op, DenseElementsAttr::get(result_ty, op.value())); |
| return success(); |
| } |
| |
| // Lower to broadcasted constant. |
| ConstantLikeOp::Adaptor transformed(operands); |
| auto loc = op.getLoc(); |
| Type extent_tensor_type = shape::getExtentTensorType(op.getContext()); |
| Value constant = rewriter.create<mhlo::ConstOp>(loc, op.value()); |
| Value uncasted_shape = rewriter.create<shape::ShapeOfOp>( |
| loc, extent_tensor_type, transformed.operand()); |
| Type shape_ty = |
| RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType()); |
| Value shape = |
| rewriter.create<tensor::CastOp>(loc, shape_ty, uncasted_shape); |
| rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>( |
| op, result_ty, constant, shape, rewriter.getI64TensorAttr({})); |
| return success(); |
| } |
| }; |
| |
| template <typename FTy> |
| Value MaterializePolynomialApproximation(ConversionPatternRewriter &rewriter, |
| Location loc, Value x, |
| const std::vector<FTy> &coefficients) { |
| Value poly = chlo::getConstantLike(rewriter, loc, 0.0, x); |
| for (FTy c : coefficients) { |
| poly = rewriter.create<mhlo::MulOp>(loc, x.getType(), poly, x); |
| poly = rewriter.create<mhlo::AddOp>( |
| loc, x.getType(), poly, chlo::getConstantLike(rewriter, loc, c, x)); |
| } |
| return poly; |
| } |
| |
| // Precondition is |x| >= 1. Use erf approximation, otherwise. |
| // |
| // We rely on multiple polynomial approximations for x >= 1. We pass |x| as an |
| // argument and derive the final approximation for all |x| >= 1. |
| // This implementation is based on Cephes. |
| Value MaterializeErfcApproximationF64ForMagnituteGEOne( |
| ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { |
| Value x = args.front(); |
| assert(x.getType().cast<ShapedType>().getElementType().isF64() && |
| "expect f64 element type"); |
| const double kMaxlog = 7.09782712893383996843E2; |
| const std::vector<double> kErfcPCoefficients{ |
| 2.46196981473530512524E-10, 5.64189564831068821977E-1, |
| 7.46321056442269912687E0, 4.86371970985681366614E1, |
| 1.96520832956077098242E2, 5.26445194995477358631E2, |
| 9.34528527171957607540E2, 1.02755188689515710272E3, |
| 5.57535335369399327526E2}; |
| const std::vector<double> kErfcQCoefficients{ |
| 1.00000000000000000000E0, 1.32281951154744992508E1, |
| 8.67072140885989742329E1, 3.54937778887819891062E2, |
| 9.75708501743205489753E2, 1.82390916687909736289E3, |
| 2.24633760818710981792E3, 1.65666309194161350182E3, |
| 5.57535340817727675546E2}; |
| const std::vector<double> kErfcRCoefficients{ |
| 5.64189583547755073984E-1, 1.27536670759978104416E0, |
| 5.01905042251180477414E0, 6.16021097993053585195E0, |
| 7.40974269950448939160E0, 2.97886665372100240670E0}; |
| const std::vector<double> kErfcSCoefficients{ |
| 1.00000000000000000000E0, 2.26052863220117276590E0, |
| 9.39603524938001434673E0, 1.20489539808096656605E1, |
| 1.70814450747565897222E1, 9.60896809063285878198E0, |
| 3.36907645100081516050E0}; |
| |
| // Let z = -x^2. |
| Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x); |
| Value z = rewriter.create<mhlo::NegOp>(loc, x_sq); |
| |
| // Materialize polynomial approximation for x in [1, 8) as |
| // erfc(x) = exp(z) P(|x|) / Q(|x|). |
| Value exp_z = rewriter.create<mhlo::ExpOp>(loc, z); |
| Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x); |
| Value poly_p = MaterializePolynomialApproximation(rewriter, loc, abs_x, |
| kErfcPCoefficients); |
| Value exp_z_mul_poly_p = rewriter.create<mhlo::MulOp>(loc, exp_z, poly_p); |
| Value poly_q = MaterializePolynomialApproximation(rewriter, loc, abs_x, |
| kErfcQCoefficients); |
| Value erfc_approx_1_8 = |
| rewriter.create<mhlo::DivOp>(loc, exp_z_mul_poly_p, poly_q); |
| |
| // Materialize polynomial approximation for x in >= 8 as |
| // erfc(x) exp(z) R(|x|) / S(|x|). |
| Value poly_r = MaterializePolynomialApproximation(rewriter, loc, abs_x, |
| kErfcRCoefficients); |
| Value exp_z_mul_poly_r = rewriter.create<mhlo::MulOp>(loc, exp_z, poly_r); |
| Value poly_s = MaterializePolynomialApproximation(rewriter, loc, abs_x, |
| kErfcSCoefficients); |
| Value erfc_approx_8_inf = |
| rewriter.create<mhlo::DivOp>(loc, exp_z_mul_poly_r, poly_s); |
| |
| // Combine polynomial approximations for x >= 1. |
| const StringAttr kLT = rewriter.getStringAttr( |
| mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT)); |
| Value eight = chlo::getConstantLike(rewriter, loc, 8.0, x); |
| Value abs_x_lt_8 = rewriter.create<mhlo::CompareOp>(loc, abs_x, eight, kLT); |
| Value erfc_approx = rewriter.create<mhlo::SelectOp>( |
| loc, abs_x_lt_8, erfc_approx_1_8, erfc_approx_8_inf); |
| |
| // Clamp to prevent overflow and materialize approximation for large x as |
| // erfc(x) = 0. |
| Value z_lt_neg_maxlog = rewriter.create<mhlo::CompareOp>( |
| loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x), kLT); |
| Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x); |
| Value erfc_approx_clamped = |
| rewriter.create<mhlo::SelectOp>(loc, z_lt_neg_maxlog, zero, erfc_approx); |
| |
| // Derive approximation for x <= -1 as |
| // erfc(x) = 2 - erfc(-x). |
| // Reuse previously materialized approximations all of which take |x| as their |
| // argument. |
| Value x_lt_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLT); |
| Value two = chlo::getConstantLike(rewriter, loc, 2.0, x); |
| Value two_sub_erfc_approx_clamped = |
| rewriter.create<mhlo::SubOp>(loc, two, erfc_approx_clamped); |
| return rewriter.create<mhlo::SelectOp>( |
| loc, x_lt_zero, two_sub_erfc_approx_clamped, erfc_approx_clamped); |
| } |
| |
| // Precondition is |x| <= 1. Use erfc approximation, otherwise. |
| // This implementation is based on Cephes. |
| Value MaterializeErfApproximationF64ForMagnituteLEOne( |
| ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { |
| Value x = args.front(); |
| assert(x.getType().cast<ShapedType>().getElementType().isF64() && |
| "expect f64 element type"); |
| const std::vector<double> kErfTCoefficients{ |
| 9.60497373987051638749E0, 9.00260197203842689217E1, |
| 2.23200534594684319226E3, 7.00332514112805075473E3, |
| 5.55923013010394962768E4}; |
| const std::vector<double> kErfUCoefficients{ |
| 1.00000000000000000000E0, 3.35617141647503099647E1, |
| 5.21357949780152679795E2, 4.59432382970980127987E3, |
| 2.26290000613890934246E4, 4.92673942608635921086E4}; |
| |
| // Materialize polynomial approximation for |x| <= 1 as |
| // erf(x) = x T(x^2) / U(x^2). |
| Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x); |
| Value poly_t = MaterializePolynomialApproximation(rewriter, loc, x_sq, |
| kErfTCoefficients); |
| Value x_mul_poly_t = rewriter.create<mhlo::MulOp>(loc, x, poly_t); |
| Value poly_u = MaterializePolynomialApproximation(rewriter, loc, x_sq, |
| kErfUCoefficients); |
| return rewriter.create<mhlo::DivOp>(loc, x_mul_poly_t, poly_u); |
| } |
| |
| // This implementation is based on Cephes. |
| Value MaterializeErfApproximationF64(ConversionPatternRewriter &rewriter, |
| Location loc, ValueRange args) { |
| Value x = args.front(); |
| assert(x.getType().cast<ShapedType>().getElementType().isF64() && |
| "expect f64 element type"); |
| |
| // Rely on erf approximation for |x| < 1 |
| // erf(x) = erf_approx(x) |
| Value erf_approx = |
| MaterializeErfApproximationF64ForMagnituteLEOne(rewriter, loc, x); |
| |
| // Rely on erfc approximation for |x| >= 1 and materialize erf as |
| // erf(x) = 1 - erfc_approx(x) |
| Value one = chlo::getConstantLike(rewriter, loc, 1.0, x); |
| Value erfc_approx = |
| MaterializeErfcApproximationF64ForMagnituteGEOne(rewriter, loc, x); |
| Value erfc_based_approx = rewriter.create<mhlo::SubOp>(loc, one, erfc_approx); |
| |
| // Materialize approximation selection based on argument. |
| Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x); |
| const StringAttr kLT = rewriter.getStringAttr( |
| mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT)); |
| Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT); |
| return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, erf_approx, |
| erfc_based_approx); |
| } |
| |
| Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter, |
| Location loc, ValueRange args) { |
| Value x = args.front(); |
| assert(x.getType().cast<ShapedType>().getElementType().isF64() && |
| "expect f64 element type"); |
| |
| // Rely on erfc approximation for |x| >= 1 |
| // erfc(x) = erfc_approx(x) |
| Value erfc_approx = |
| MaterializeErfcApproximationF64ForMagnituteGEOne(rewriter, loc, x); |
| |
| // Rely on erf approximation for |x| < 1 and materialize erfc as |
| // erfc(x) = 1 - erf_approx(x) |
| Value one = chlo::getConstantLike(rewriter, loc, 1.0, x); |
| Value erf_approx = |
| MaterializeErfApproximationF64ForMagnituteLEOne(rewriter, loc, x); |
| Value erf_based_approx = rewriter.create<mhlo::SubOp>(loc, one, erf_approx); |
| |
| // Materialize approximation selection based on argument. |
| Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x); |
| const StringAttr kLT = rewriter.getStringAttr( |
| mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT)); |
| Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT); |
| return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, erf_based_approx, |
| erfc_approx); |
| } |
| |
| // Precondition is |x| >= 1. Use erf approximation, otherwise. |
| // |
| // We rely on multiple polynomial approximations for x >= 1. We pass |x| as an |
| // argument and derive the final approximation for all |x| >= 1. |
| // This implementation is based on Cephes. |
| Value MaterializeErfcApproximationF32ForMagnitudeGEOne( |
| ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { |
| Value x = args.front(); |
| assert(x.getType().cast<ShapedType>().getElementType().isF32() && |
| "expect f32 element type"); |
| const double kMaxlog = 88.72283905206835; |
| const std::vector<float> kErfcPCoefficients{ |
| +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1, |
| -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1, |
| +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1, |
| }; |
| const std::vector<float> kErfcRCoefficients{ |
| -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0, |
| +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1, |
| -2.820767439740514E-1, +5.641895067754075E-1, |
| }; |
| |
| // Let z = -x^2. |
| Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x); |
| Value z = rewriter.create<mhlo::NegOp>(loc, x_sq); |
| |
| // Materialize polynomial approximation for x >= 1 as |
| // erfc(x) = exp(z) 1/x P(1/x^2) if x in [1, 2) |
| // erfc(x) = exp(z) 1/x R(1/x^2) if x >= 2 |
| const StringAttr kLT = rewriter.getStringAttr( |
| mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT)); |
| Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x); |
| Value one = chlo::getConstantLike(rewriter, loc, 1.0, x); |
| Value reciprocal_x_sq = rewriter.create<mhlo::DivOp>(loc, one, x_sq); |
| Value exp_z = rewriter.create<mhlo::ExpOp>(loc, z); |
| Value one_div_abs_x = rewriter.create<mhlo::DivOp>(loc, one, abs_x); |
| Value exp_z_mul_one_div_abs_x = |
| rewriter.create<mhlo::MulOp>(loc, exp_z, one_div_abs_x); |
| Value two = chlo::getConstantLike(rewriter, loc, 2.0, x); |
| Value abs_x_lt_two = rewriter.create<mhlo::CompareOp>(loc, abs_x, two, kLT); |
| Value poly_p = MaterializePolynomialApproximation( |
| rewriter, loc, reciprocal_x_sq, kErfcPCoefficients); |
| Value poly_r = MaterializePolynomialApproximation( |
| rewriter, loc, reciprocal_x_sq, kErfcRCoefficients); |
| Value poly = |
| rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_two, poly_p, poly_r); |
| Value erfc_approx = |
| rewriter.create<mhlo::MulOp>(loc, exp_z_mul_one_div_abs_x, poly); |
| |
| // Clamp to prevent overflow and materialize approximation for large x as |
| // erfc(x) = 0. |
| Value z_lt_neq_maxlog = rewriter.create<mhlo::CompareOp>( |
| loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x), kLT); |
| Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x); |
| Value erfc_approx_clamped = |
| rewriter.create<mhlo::SelectOp>(loc, z_lt_neq_maxlog, zero, erfc_approx); |
| |
| // Derive approximation for x <= -1 as |
| // erfc(x) = 2 - erfc(-x). |
| // Reuse previously materialized approximations all of which take |x| as their |
| // argument. |
| Value x_lt_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLT); |
| Value two_sub_erfc_approx = |
| rewriter.create<mhlo::SubOp>(loc, two, erfc_approx_clamped); |
| return rewriter.create<mhlo::SelectOp>(loc, x_lt_zero, two_sub_erfc_approx, |
| erfc_approx_clamped); |
| } |
| |
| // Precondition is |x| <= 1. Use erfc approximation, otherwise. |
| // This implementation is based on Cephes. |
| Value MaterializeErfApproximationF32ForMagnitudeLEOne( |
| ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { |
| Value x = args.front(); |
| assert(x.getType().cast<ShapedType>().getElementType().isF32() && |
| "expect f32 element type"); |
| const std::vector<float> kErfTCoefficients{ |
| +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3, |
| -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1, |
| +1.128379165726710E+0, |
| }; |
| |
| // Materialize polynomial approximation for |x| <= 1 as |
| // erf(x) = x T(x^2). |
| Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x); |
| Value poly_t = MaterializePolynomialApproximation(rewriter, loc, x_sq, |
| kErfTCoefficients); |
| return rewriter.create<mhlo::MulOp>(loc, x, poly_t); |
| } |
| |
| // This is the same approximation as used in Eigen. |
| Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter, |
| Location loc, ValueRange args) { |
| Value x = args.front(); |
| assert(x.getType().cast<ShapedType>().getElementType().isF32() && |
| "expect f32 element type"); |
| const std::vector<float> kAlpha{ |
| -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f, |
| -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f, |
| -1.60960333262415e-02f, |
| }; |
| const std::vector<float> kBeta{ |
| -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f, |
| -7.37332916720468e-03f, -1.42647390514189e-02f, |
| }; |
| |
| // Clamp argument between -4 and 4. |
| Value lb = chlo::getConstantLike(rewriter, loc, -4.0, x); |
| Value ub = chlo::getConstantLike(rewriter, loc, 4.0, x); |
| x = rewriter.create<mhlo::ClampOp>(loc, x.getType(), lb, x, ub); |
| Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x); |
| |
| // Materialize polynomial approximation for x in [-4, 4] as |
| // erf(x) = x * Alpha(x^2) / Beta(x^2). |
| Value alpha_poly = |
| MaterializePolynomialApproximation(rewriter, loc, x_sq, kAlpha); |
| Value beta_poly = |
| MaterializePolynomialApproximation(rewriter, loc, x_sq, kBeta); |
| Value x_mul_alpha_poly = rewriter.create<mhlo::MulOp>(loc, x, alpha_poly); |
| return rewriter.create<mhlo::DivOp>(loc, x_mul_alpha_poly, beta_poly); |
| } |
| |
| Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter, |
| Location loc, ValueRange args) { |
| Value x = args.front(); |
| assert(x.getType().cast<ShapedType>().getElementType().isF32() && |
| "expect f32 element type"); |
| |
| // Rely on erfc approximation for |x| >= 1 |
| // erfc(x) = erfc_approx(x) |
| Value erfc_approx = |
| MaterializeErfcApproximationF32ForMagnitudeGEOne(rewriter, loc, x); |
| |
| // Rely on erf approximation for |x| < 1 and materialize erfc as |
| // erfc(x) = 1 - erf_approx(x) |
| Value one = chlo::getConstantLike(rewriter, loc, 1.0, x); |
| Value erf_approx = |
| MaterializeErfApproximationF32ForMagnitudeLEOne(rewriter, loc, x); |
| Value erf_based_approx = rewriter.create<mhlo::SubOp>(loc, one, erf_approx); |
| |
| // Materialize approximation selection based on argument. |
| const StringAttr kLT = rewriter.getStringAttr( |
| mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT)); |
| Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x); |
| Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT); |
| return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, erf_based_approx, |
| erfc_approx); |
| } |
| |
| Value MaterializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc, |
| ValueRange args, FloatType min_precision_ty, |
| Value callback(ConversionPatternRewriter &, |
| Location, ValueRange)) { |
| auto original_ty = |
| getElementTypeOrSelf(args.front().getType()).cast<FloatType>(); |
| bool needs_upcast = original_ty.getWidth() < min_precision_ty.getWidth(); |
| |
| // Upcast arguments if necessary. |
| llvm::SmallVector<Value, 2> casted_args; |
| if (needs_upcast) { |
| for (Value a : args) { |
| casted_args.push_back( |
| rewriter.create<mhlo::ConvertOp>(loc, a, min_precision_ty)); |
| } |
| args = casted_args; |
| } |
| |
| Value result = callback(rewriter, loc, args); |
| |
| // Cast back if necessary. |
| if (needs_upcast) { |
| result = rewriter.create<mhlo::ConvertOp>(loc, result, original_ty); |
| } |
| |
| return result; |
| } |
| |
| struct ConvertErfOp : public OpConversionPattern<ErfOp> { |
| using OpConversionPattern<ErfOp>::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| ErfOp op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| ErfOp::Adaptor transformed(operands); |
| Value x = transformed.operand(); |
| Type ty = x.getType().cast<ShapedType>().getElementType(); |
| |
| // For now, we support only f64, f32, and f16. |
| if (!ty.isF64() && !ty.isF32() && !ty.isF16()) return failure(); |
| |
| if (ty.isF64()) { |
| rewriter.replaceOp(op, MaterializeErfApproximationF64(rewriter, loc, x)); |
| return success(); |
| } |
| |
| rewriter.replaceOp(op, MaterializeWithUpcast( |
| rewriter, loc, operands, rewriter.getF32Type(), |
| &MaterializeErfApproximationF32)); |
| return success(); |
| } |
| }; |
| |
| struct ConvertErfcOp : public OpConversionPattern<ErfcOp> { |
| using OpConversionPattern<ErfcOp>::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| ErfcOp op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| ErfcOp::Adaptor transformed(operands); |
| Value x = transformed.operand(); |
| Type ty = x.getType().cast<ShapedType>().getElementType(); |
| |
| // For now, we support only f64, f32, and f16. |
| if (!ty.isF64() && !ty.isF32() && !ty.isF16()) return failure(); |
| |
| if (ty.isF64()) { |
| rewriter.replaceOp(op, MaterializeErfcApproximationF64(rewriter, loc, x)); |
| return success(); |
| } |
| |
| rewriter.replaceOp(op, MaterializeWithUpcast( |
| rewriter, loc, operands, rewriter.getF32Type(), |
| &MaterializeErfcApproximationF32)); |
| return success(); |
| } |
| }; |
| |
// Constants of the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n (kLanczosGamma
// and kLanczosCoefficients.size() + 1); the values below correspond to [7, 9].
//
// [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and [7, 9]
// seemed to be the least sensitive to the quality of the log function. In
// particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
// for a particularly inaccurate log function.
constexpr double kLanczosGamma = 7.0;  // aka g
constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
constexpr std::array<double, 8> kLanczosCoefficients{
    676.520368121885098567009190444019, -1259.13921672240287047156078755283,
    771.3234287776530788486528258894,   -176.61502916214059906584551354,
    12.507343278686904814458936853,     -0.13857109526572011689554707,
    9.984369578019570859563e-6,         1.50563273514931155834e-7};
| |
| // Compute the Lgamma function using Lanczos' approximation from "A Precision |
| // Approximation of the Gamma Function". SIAM Journal on Numerical Analysis |
| // series B. Vol. 1: |
| // lgamma(z + 1) = (log(2) + log(pi)) / 2 |
| // + (z + 1/2) * log(t(z)) |
| // - t(z) + log(a(z)) |
| // with t(z) = z + kLanczosGamma + 1/2 |
| // a(z) = kBaseLanczosCoeff |
| // + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) |
| Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc, |
| ValueRange args) { |
| // If the input is less than 0.5 use Euler's reflection formula. |
| // gamma(x) = pi / (sin(pi * x) * gamma(1 - x)) |
| // Let z be |
| // z = -x if x < 1/2 |
| // z = x - 1 otheriwse |
| Value x = args.front(); |
| const StringAttr kLT = rewriter.getStringAttr( |
| mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT)); |
| Value half = getConstantLike(rewriter, loc, 0.5, x); |
| Value need_to_reflect = rewriter.create<mhlo::CompareOp>(loc, x, half, kLT); |
| Value neg_x = rewriter.create<mhlo::NegOp>(loc, x); |
| Value one = getConstantLike(rewriter, loc, 1, x); |
| Value x_sub_one = rewriter.create<mhlo::SubOp>(loc, x, one); |
| Value z = |
| rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, neg_x, x_sub_one); |
| |
| // Materialize |
| // a(z) = kBaseLanczosCoeff |
| // + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) |
| Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x); |
| for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) { |
| Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x); |
| Value one_based_index = getConstantLike(rewriter, loc, i + 1, x); |
| Value quotient = rewriter.create<mhlo::DivOp>( |
| loc, coeff, rewriter.create<mhlo::AddOp>(loc, z, one_based_index)); |
| a = rewriter.create<mhlo::AddOp>(loc, a, quotient); |
| } |
| |
| // To improve accuracy on platforms with less-precise log implementations, |
| // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the |
| // device. |
| // Materialize as |
| // log(t) = log(kLanczosGamma + 1/2 + z) |
| // = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)). |
| Value lanczos_plus_half = |
| getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x); |
| Value t = rewriter.create<mhlo::AddOp>(loc, lanczos_plus_half, z); |
| Value log_term = |
| getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x); |
| Value log1p_term = rewriter.create<mhlo::Log1pOp>( |
| loc, rewriter.create<mhlo::DivOp>(loc, z, lanczos_plus_half)); |
| Value log_t = rewriter.create<mhlo::AddOp>(loc, log_term, log1p_term); |
| |
| // Note that t(z) may be large and we need to be careful not to overflow to |
| // infinity in the relevant term |
| // r = (z + 1/2) * log(t(z)) - t(z). |
| // Therefore, we compute this as |
| // r = (z + 1/2 - t(z) / log(t(z))) * log(t(z)). |
| Value t_div_log_t = rewriter.create<mhlo::DivOp>(loc, t, log_t); |
| Value sum = rewriter.create<mhlo::SubOp>( |
| loc, rewriter.create<mhlo::AddOp>(loc, z, half), t_div_log_t); |
| Value r = rewriter.create<mhlo::MulOp>(loc, sum, log_t); |
| |
| // Compute the final result (modulo reflection) as |
| // lgamma(z + 1) = (log(2) + log(pi)) / 2 + r + log(a(z)). |
| Value log_a = rewriter.create<mhlo::LogOp>(loc, a); |
| Value lgamma = rewriter.create<mhlo::AddOp>( |
| loc, |
| rewriter.create<mhlo::AddOp>( |
| loc, |
| getConstantLike(rewriter, loc, (std::log(2) + std::log(M_PI)) / 2, x), |
| r), |
| log_a); |
| |
| // Compute the reflected value for x < 0.5 as |
| // lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))). |
| // |
| // The abs is needed because lgamma is the log of the absolute value of the |
| // gamma function. |
| // |
| // We have to be careful when computing the final term above. gamma(x) goes |
| // to +/-inf at every integer x < 0, and this is controlled by the sin(pi * x) |
| // term. The slope is large, so precision is particularly important. |
| // |
| // Because abs(sin(pi * x)) has period of 1 we can equivalently use |
| // abs(sin(pi * frac(x))) where frac(x) is the fractional part of x. This is |
| // more numerically accurate: It doesn't overflow to inf like pi * x would and |
| // if x is an integer it evaluates to exactly 0 which is important because we |
| // then take the log of this value, and log(0) is inf. |
| // |
| // We don't have a frac(x) primitive in HLO and computing it is tricky, but |
| // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for our |
| // purposes to use abs(frac(x)) = abs(x) - floor(abs(x)). |
| // |
| // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close |
| // to 1. To remedy this, we can use the fact that sin(pi * x) in the domain |
| // [0, 1] is symmetric across the line Y=0.5. |
| // |
| |
| // Convert values of abs_frac > 0.5 to (1 - abs_frac) to improve precision of |
| // pi * abs_frac for values of abs_frac close to 1. |
| Value abs = rewriter.create<mhlo::AbsOp>(loc, x); |
| Value abs_frac = rewriter.create<mhlo::SubOp>( |
| loc, abs, rewriter.create<mhlo::FloorOp>(loc, abs)); |
| Value reduce_abs_frac = |
| rewriter.create<mhlo::CompareOp>(loc, half, abs_frac, kLT); |
| abs_frac = rewriter.create<mhlo::SelectOp>( |
| loc, reduce_abs_frac, rewriter.create<mhlo::SubOp>(loc, one, abs_frac), |
| abs_frac); |
| |
| // Materialize reflection. |
| Value reflection_denom = rewriter.create<mhlo::LogOp>( |
| loc, |
| rewriter.create<mhlo::SinOp>( |
| loc, rewriter.create<mhlo::MulOp>( |
| loc, getConstantLike(rewriter, loc, M_PI, x), abs_frac))); |
| Value lgamma_reflection = rewriter.create<mhlo::SubOp>( |
| loc, |
| rewriter.create<mhlo::SubOp>( |
| loc, getConstantLike(rewriter, loc, std::log(M_PI), x), |
| reflection_denom), |
| lgamma); |
| |
| // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf, |
| // then it "wins" and the result is +/-inf. |
| Value finite_reflection_denom = |
| rewriter.create<mhlo::IsFiniteOp>(loc, reflection_denom); |
| Value neg_reflection_denom = |
| rewriter.create<mhlo::NegOp>(loc, reflection_denom); |
| lgamma_reflection = rewriter.create<mhlo::SelectOp>( |
| loc, finite_reflection_denom, lgamma_reflection, neg_reflection_denom); |
| |
| // Select whether or not to rely on the reflection. |
| lgamma = rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, |
| lgamma_reflection, lgamma); |
| |
| // Materialize +/-inf behavior as |
| // lgamma(+/-inf) = +inf. |
| Value x_is_inf = rewriter.create<chlo::IsInfOp>(loc, x); |
| return rewriter.create<mhlo::SelectOp>( |
| loc, x_is_inf, |
| chlo::getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false), |
| lgamma); |
| } |
| |
| // Express `cosh` as |
| // cosh(x) = (e^x + e^-x) / 2 |
| // = e^(x + log(1/2)) + e^(-x + log(1/2)) |
| // |
| // The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not. |
| // |
| // This incorrectly overflows to inf for two f32 input values, namely |
| // +/-89.4159851, due to rounding error when computing x +/- log(1/2). The |
| // correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so |
| // we deem this acceptable. |
| Value MaterializeCoshApproximation(ConversionPatternRewriter &rewriter, |
| Location loc, ValueRange operands) { |
| CoshOp::Adaptor transformed(operands); |
| Value x = transformed.operand(); |
| |
| Value log_one_half = |
| rewriter.create<mhlo::LogOp>(loc, getConstantLike(rewriter, loc, 0.5, x)); |
| Value exp_add = rewriter.create<mhlo::ExpOp>( |
| loc, rewriter.create<mhlo::AddOp>(loc, x, log_one_half)); |
| Value exp_sub = rewriter.create<mhlo::ExpOp>( |
| loc, rewriter.create<mhlo::SubOp>(loc, log_one_half, x)); |
| return rewriter.create<mhlo::AddOp>(loc, exp_add, exp_sub); |
| } |
| |
| struct ConvertCoshOp : public OpConversionPattern<CoshOp> { |
| using OpConversionPattern<CoshOp>::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| CoshOp op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| CoshOp::Adaptor transformed(operands); |
| Value x = transformed.operand(); |
| if (x.getType().cast<ShapedType>().getElementType().isa<ComplexType>()) { |
| // TODO(hinsu): Support operands with complex element types by always |
| // using the formula for large x. The compare op is not legal for complex |
| // numbers. |
| return failure(); |
| } |
| rewriter.replaceOp(op, |
| MaterializeWithUpcast(rewriter, op.getLoc(), operands, |
| rewriter.getF32Type(), |
| &MaterializeCoshApproximation)); |
| return success(); |
| } |
| }; |
| |
| // Compute the Digamma function using Lanczos' approximation from "A Precision |
| // Approximation of the Gamma Function". SIAM Journal on Numerical Analysis |
| // series B. Vol. 1: |
| // digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z) |
| // with t(z) = z + kLanczosGamma + 1/2 |
//          a(z) = kBaseLanczosCoeff
//                   + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
//          a'(z) = - sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k) / (z + k))
| Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc, |
| ValueRange args) { |
| // If the input is less than 0.5 use Euler's reflection formula. |
| // digamma(x) = digamma(1 - x) - pi * cot(pi * x) |
| // Let z be |
| // z = -x if x < 1/2 |
| // z = x - 1 otheriwse |
| Value x = args.front(); |
| const StringAttr kLT = rewriter.getStringAttr( |
| mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT)); |
| Value half = getConstantLike(rewriter, loc, 0.5, x); |
| Value need_to_reflect = rewriter.create<mhlo::CompareOp>(loc, x, half, kLT); |
| Value neg_x = rewriter.create<mhlo::NegOp>(loc, x); |
| Value one = getConstantLike(rewriter, loc, 1, x); |
| Value x_sub_one = rewriter.create<mhlo::SubOp>(loc, x, one); |
| Value z = |
| rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, neg_x, x_sub_one); |
| |
| // Materialize |
| // a(z) = kBaseLanczosCoeff |
| // + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) |
| // a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k)) |
| Value zero = getConstantLike(rewriter, loc, 0.0, x); |
| Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x); |
| Value a_prime = zero; |
| for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) { |
| Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x); |
| Value one_based_index = getConstantLike(rewriter, loc, i + 1, x); |
| Value z_term = rewriter.create<mhlo::AddOp>(loc, z, one_based_index); |
| a_prime = rewriter.create<mhlo::SubOp>( |
| loc, a_prime, |
| rewriter.create<mhlo::DivOp>( |
| loc, coeff, rewriter.create<mhlo::MulOp>(loc, z_term, z_term))); |
| a = rewriter.create<mhlo::AddOp>( |
| loc, a, rewriter.create<mhlo::DivOp>(loc, coeff, z_term)); |
| } |
| |
| // To improve accuracy on platforms with less-precise log implementations, |
| // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the |
| // device. |
| // Materialize as |
| // log(t) = log(kLanczosGamma + 1/2 + z) |
| // = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)). |
| Value lanczos_plus_half = |
| getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x); |
| Value t = rewriter.create<mhlo::AddOp>(loc, lanczos_plus_half, z); |
| Value log_term = |
| getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x); |
| Value log1p_term = rewriter.create<mhlo::Log1pOp>( |
| loc, rewriter.create<mhlo::DivOp>(loc, z, lanczos_plus_half)); |
| Value log_t = rewriter.create<mhlo::AddOp>(loc, log_term, log1p_term); |
| |
| // Materialize the final result (modulo reflection) as |
| // digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z). |
| Value a_prime_div_a = rewriter.create<mhlo::DivOp>(loc, a_prime, a); |
| Value lanczos_gamma_div_t = rewriter.create<mhlo::DivOp>( |
| loc, getConstantLike(rewriter, loc, kLanczosGamma, x), t); |
| Value digamma = rewriter.create<mhlo::SubOp>( |
| loc, rewriter.create<mhlo::AddOp>(loc, log_t, a_prime_div_a), |
| lanczos_gamma_div_t); |
| |
| // We need to be careful how we compute cot(pi * input) below: For |
| // near-integral arguments, pi * input can lose precision. |
| // |
| // Input is already known to be less than 0.5 (otherwise we don't have to |
| // reflect). We shift values smaller than -0.5 into the range [-0.5, 0.5] to |
| // increase precision of pi * x and the resulting cotangent. |
| Value reduced_x = rewriter.create<mhlo::AddOp>( |
| loc, x, |
| rewriter.create<mhlo::AbsOp>( |
| loc, rewriter.create<mhlo::FloorOp>( |
| loc, rewriter.create<mhlo::AddOp>( |
| loc, x, getConstantLike(rewriter, loc, 0.5, x))))); |
| |
| // Materialize reflection for inputs less than 0.5 as |
| // digamma(x) = digamma(1 - x) - pi * cot(pi * x) |
| // = digamma(1 - x) - pi * cos(pi * x) / sin(pi * x) |
| Value pi = getConstantLike(rewriter, loc, M_PI, x); |
| Value pi_mul_reduced_x = rewriter.create<mhlo::MulOp>(loc, pi, reduced_x); |
| Value cos = rewriter.create<mhlo::CosOp>(loc, pi_mul_reduced_x); |
| Value sin = rewriter.create<mhlo::SinOp>(loc, pi_mul_reduced_x); |
| Value reflection = rewriter.create<mhlo::SubOp>( |
| loc, digamma, |
| rewriter.create<mhlo::DivOp>( |
| loc, rewriter.create<mhlo::MulOp>(loc, pi, cos), sin)); |
| |
| // Select whether or not to rely on the reflection. |
| digamma = rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, reflection, |
| digamma); |
| |
| // Digamma has poles at negative integers and zero; return nan for those. |
| const StringAttr kLE = rewriter.getStringAttr( |
| mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE)); |
| Value is_le_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLE); |
| const StringAttr kEQ = rewriter.getStringAttr( |
| mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ)); |
| Value is_int = rewriter.create<mhlo::CompareOp>( |
| loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kEQ); |
| Value is_pole = rewriter.create<mhlo::AndOp>(loc, is_le_zero, is_int); |
| return rewriter.create<mhlo::SelectOp>( |
| loc, is_pole, |
| getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(), |
| x), |
| digamma); |
| } |
| |
| Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc, |
| ValueRange args) { |
| assert(args.size() == 2); |
| Value x = args[0]; |
| Value q = args[1]; |
| static const std::array<double, 12> kZetaCoeffs{ |
| -7.1661652561756670113e18, |
| 1.8152105401943546773e17, |
| -4.5979787224074726105e15, |
| 1.1646782814350067249e14, |
| -2.950130727918164224e12, |
| 7.47242496e10, |
| -1.8924375803183791606e9, |
| 47900160.0, |
| -1209600.0, |
| 30240.0, |
| -720.0, |
| 12.0, |
| }; |
| |
| // For speed we'll always use 9 iterations for the initial series estimate, |
| // and a 12 term expansion for the Euler-Maclaurin formula. |
| Value a = q; |
| Value zero = chlo::getConstantLike(rewriter, loc, 0.0, a); |
| Value neg_power = zero; |
| Value neg_x = rewriter.create<mhlo::NegOp>(loc, x); |
| Value initial_sum = rewriter.create<mhlo::PowOp>(loc, q, neg_x); |
| Value one = chlo::getConstantLike(rewriter, loc, 1.0, a); |
| for (int i = 0; i < 9; ++i) { |
| a = rewriter.create<mhlo::AddOp>(loc, a, one); |
| neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x); |
| initial_sum = rewriter.create<mhlo::AddOp>(loc, initial_sum, neg_power); |
| } |
| a = rewriter.create<mhlo::AddOp>(loc, a, one); |
| neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x); |
| Value one_like_x = chlo::getConstantLike(rewriter, loc, 1.0, x); |
| Value x_minus_one = rewriter.create<mhlo::SubOp>(loc, x, one_like_x); |
| Value neg_power_mul_a = rewriter.create<mhlo::MulOp>(loc, neg_power, a); |
| Value neg_power_mul_a_div_x_minus_one = |
| rewriter.create<mhlo::DivOp>(loc, neg_power_mul_a, x_minus_one); |
| Value s = rewriter.create<mhlo::AddOp>(loc, initial_sum, |
| neg_power_mul_a_div_x_minus_one); |
| Value a_inverse_square = rewriter.create<mhlo::DivOp>( |
| loc, one, rewriter.create<mhlo::MulOp>(loc, a, a)); |
| |
| Value horner_sum = zero; |
| Value factor = one; |
| // Use Horner's rule for this. |
| // Note this differs from Cephes which does a 'naive' polynomial evaluation. |
| // Using Horner's rule allows to avoid some NaN's and Infs from happening, |
| // resulting in more numerically stable code. |
| for (int i = 0; i < 11; ++i) { |
| Value factor_lhs = rewriter.create<mhlo::SubOp>( |
| loc, x, chlo::getConstantLike(rewriter, loc, 22 - 2 * i, x)); |
| Value factor_rhs = rewriter.create<mhlo::SubOp>( |
| loc, x, chlo::getConstantLike(rewriter, loc, 21 - 2 * i, x)); |
| factor = rewriter.create<mhlo::MulOp>(loc, factor_lhs, factor_rhs); |
| horner_sum = rewriter.create<mhlo::MulOp>( |
| loc, factor, |
| rewriter.create<mhlo::MulOp>( |
| loc, a_inverse_square, |
| rewriter.create<mhlo::AddOp>( |
| loc, horner_sum, |
| chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a)))); |
| } |
| Value zero_point_five_like_neg_power = |
| chlo::getConstantLike(rewriter, loc, .5, neg_power); |
| Value x_div_a = rewriter.create<mhlo::DivOp>(loc, x, a); |
| s = rewriter.create<mhlo::AddOp>( |
| loc, s, |
| rewriter.create<mhlo::MulOp>( |
| loc, neg_power, |
| rewriter.create<mhlo::AddOp>( |
| loc, zero_point_five_like_neg_power, |
| rewriter.create<mhlo::MulOp>( |
| loc, x_div_a, |
| rewriter.create<mhlo::AddOp>( |
| loc, |
| chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11], |
| a), |
| horner_sum))))); |
| |
| // Use the initial zeta sum without the correction term coming |
| // from Euler-Maclaurin if it is accurate enough. |
| const StringAttr kLT = rewriter.getStringAttr( |
| mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT)); |
| Value abs_neg_power = rewriter.create<mhlo::AbsOp>(loc, neg_power); |
| Value abs_initial_sum = rewriter.create<mhlo::AbsOp>(loc, initial_sum); |
| Value output = rewriter.create<mhlo::SelectOp>( |
| loc, |
| rewriter.create<mhlo::CompareOp>( |
| loc, abs_neg_power, |
| rewriter.create<mhlo::MulOp>( |
| loc, abs_initial_sum, |
| chlo::getConstantLikeSmallestFiniteValue(rewriter, loc, a)), |
| kLT), |
| initial_sum, s); |
| |
| // Function is not defined for x < 1. |
| Value nan = chlo::getConstantLike( |
| rewriter, loc, std::numeric_limits<double>::quiet_NaN(), x); |
| output = rewriter.create<mhlo::SelectOp>( |
| loc, rewriter.create<mhlo::CompareOp>(loc, x, one_like_x, kLT), nan, |
| output); |
| |
| // For q <= 0, x must be an integer. |
| const StringAttr kLE = rewriter.getStringAttr( |
| mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE)); |
| const StringAttr kNE = rewriter.getStringAttr( |
| mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE)); |
| Value q_le_zero = rewriter.create<mhlo::CompareOp>(loc, q, zero, kLE); |
| Value x_not_int = rewriter.create<mhlo::CompareOp>( |
| loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kNE); |
| Value x_domain_error = |
| rewriter.create<mhlo::AndOp>(loc, q_le_zero, x_not_int); |
| output = rewriter.create<mhlo::SelectOp>(loc, x_domain_error, nan, output); |
| |
| // For all integer q <= 0, zeta has a pole. The limit is only defined as |
| // +inf if x is and even integer. |
| const StringAttr kEQ = rewriter.getStringAttr( |
| mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ)); |
| Value inf = chlo::getConstantLike(rewriter, loc, |
| std::numeric_limits<double>::infinity(), x); |
| Value q_is_int = rewriter.create<mhlo::CompareOp>( |
| loc, q, rewriter.create<mhlo::FloorOp>(loc, q), kEQ); |
| Value at_pole = rewriter.create<mhlo::AndOp>(loc, q_le_zero, q_is_int); |
| Value two = chlo::getConstantLike(rewriter, loc, 2.0, x); |
| Value x_is_int = rewriter.create<mhlo::CompareOp>( |
| loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kEQ); |
| Value x_is_even = rewriter.create<mhlo::CompareOp>( |
| loc, rewriter.create<mhlo::RemOp>(loc, x, two), zero, kEQ); |
| Value x_is_even_int = rewriter.create<mhlo::AndOp>(loc, x_is_int, x_is_even); |
| output = rewriter.create<mhlo::SelectOp>( |
| loc, at_pole, |
| rewriter.create<mhlo::SelectOp>(loc, x_is_even_int, inf, nan), output); |
| |
| // For x = 1, this is the harmonic series and diverges. |
| output = rewriter.create<mhlo::SelectOp>( |
| loc, rewriter.create<mhlo::CompareOp>(loc, x, one, kEQ), inf, output); |
| |
| return output; |
| } |
| |
| Value MaterializePolygamma(ConversionPatternRewriter &rewriter, Location loc, |
| ValueRange args) { |
| PolygammaOp::Adaptor transformed(args); |
| Value n = transformed.n(); |
| Value x = transformed.x(); |
| |
| // Handle integer n > 0. |
| Value one = getConstantLike(rewriter, loc, 1.0, x); |
| Value two = getConstantLike(rewriter, loc, 2.0, x); |
| Value sign = rewriter.create<mhlo::SubOp>( |
| loc, |
| rewriter.create<mhlo::MulOp>(loc, two, |
| rewriter.create<mhlo::RemOp>(loc, n, two)), |
| one); |
| Value n_plus_one = rewriter.create<mhlo::AddOp>(loc, n, one); |
| Value exp_lgamma_np1 = rewriter.create<mhlo::ExpOp>( |
| loc, rewriter.create<chlo::LgammaOp>(loc, n_plus_one)); |
| Value zeta = rewriter.create<chlo::ZetaOp>(loc, n_plus_one, x); |
| Value result = rewriter.create<mhlo::MulOp>( |
| loc, rewriter.create<mhlo::MulOp>(loc, sign, exp_lgamma_np1), zeta); |
| |
| // Handle n = 0. |
| const StringAttr kEQ = rewriter.getStringAttr( |
| mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ)); |
| Value zero = getConstantLike(rewriter, loc, 0.0, x); |
| Value n_eq_zero = rewriter.create<mhlo::CompareOp>(loc, n, zero, kEQ); |
| result = rewriter.create<mhlo::SelectOp>( |
| loc, n_eq_zero, rewriter.create<chlo::DigammaOp>(loc, x), result); |
| |
| // Check that n is a natural number. Return nan, otherwise. |
| const StringAttr kNE = rewriter.getStringAttr( |
| mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE)); |
| Value non_int = rewriter.create<mhlo::CompareOp>( |
| loc, n, rewriter.create<mhlo::FloorOp>(loc, n), kNE); |
| const StringAttr kLT = rewriter.getStringAttr( |
| mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT)); |
| Value negative = rewriter.create<mhlo::CompareOp>(loc, n, zero, kLT); |
| Value non_natural = rewriter.create<mhlo::OrOp>(loc, non_int, negative); |
| return rewriter.create<mhlo::SelectOp>( |
| loc, non_natural, |
| getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(), |
| x), |
| result); |
| } |
| |
| struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> { |
| using OpConversionPattern<LgammaOp>::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| LgammaOp op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| FloatType min_precision_ty = rewriter.getF32Type(); |
| rewriter.replaceOp( |
| op, MaterializeWithUpcast(rewriter, op.getLoc(), operands, |
| min_precision_ty, &MaterializeLgamma)); |
| return success(); |
| } |
| }; |
| |
| struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> { |
| using OpConversionPattern<DigammaOp>::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| DigammaOp op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| FloatType min_precision_ty = rewriter.getF32Type(); |
| rewriter.replaceOp( |
| op, MaterializeWithUpcast(rewriter, op.getLoc(), operands, |
| min_precision_ty, &MaterializeDigamma)); |
| return success(); |
| } |
| }; |
| |
| struct ConvertPolygammaOp : public OpConversionPattern<PolygammaOp> { |
| using OpConversionPattern<PolygammaOp>::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| PolygammaOp op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| FloatType min_precision_ty = rewriter.getF32Type(); |
| rewriter.replaceOp( |
| op, MaterializeWithUpcast(rewriter, loc, operands, min_precision_ty, |
| &MaterializePolygamma)); |
| return success(); |
| } |
| }; |
| |
| Value MaterializeSinhApproximationForLargeX(ConversionPatternRewriter &rewriter, |
| Location loc, ValueRange operands) { |
| SinhOp::Adaptor transformed(operands); |
| Value x = transformed.operand(); |
| auto result_ty = x.getType().cast<ShapedType>(); |
| |
| // TODO(b/190374484): Use mhlo::ConstantLikeOp when it supports complex types. |
| Value two = rewriter.create<mhlo::ConstOp>( |
| loc, hlo::GetScalarOfType(getElementTypeOrSelf(x.getType()), 2)); |
| Type extent_tensor_type = shape::getExtentTensorType(x.getContext()); |
| Value uncasted_shape = |
| rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, x); |
| Type shape_ty = |
| RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType()); |
| Value shape = rewriter.create<tensor::CastOp>(loc, shape_ty, uncasted_shape); |
| Value two_with_x_shape = rewriter.create<mhlo::DynamicBroadcastInDimOp>( |
| loc, result_ty, two, shape, rewriter.getI64TensorAttr({})); |
| |
| Value log_two = rewriter.create<mhlo::LogOp>(loc, two_with_x_shape); |
| Value log_one_half = rewriter.create<mhlo::NegOp>(loc, log_two); |
| Value exp_add = rewriter.create<mhlo::ExpOp>( |
| loc, rewriter.create<mhlo::AddOp>(loc, x, log_one_half)); |
| Value exp_sub = rewriter.create<mhlo::ExpOp>( |
| loc, rewriter.create<mhlo::SubOp>(loc, log_one_half, x)); |
| return rewriter.create<mhlo::SubOp>(loc, exp_add, exp_sub); |
| } |
| |
| // Express `sinh` as |
| // sinh(x) = (e^x - e^-x) / 2 if |x| < 1 |
| // = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise. |
| Value MaterializeSinhApproximation(ConversionPatternRewriter &rewriter, |
| Location loc, ValueRange operands) { |
| Value large_sinh_result = |
| MaterializeSinhApproximationForLargeX(rewriter, loc, operands); |
| |
| SinhOp::Adaptor transformed(operands); |
| Value x = transformed.operand(); |
| const StringAttr kLT = rewriter.getStringAttr( |
| mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT)); |
| Value exp_x = rewriter.create<mhlo::ExpOp>(loc, x); |
| Value exp_neg_x = |
| rewriter.create<mhlo::ExpOp>(loc, rewriter.create<mhlo::NegOp>(loc, x)); |
| Value exp_difference = rewriter.create<mhlo::SubOp>(loc, exp_x, exp_neg_x); |
| Value two = getConstantLike(rewriter, loc, 2.0, x); |
| Value small_sinh_result = |
| rewriter.create<mhlo::DivOp>(loc, exp_difference, two); |
| |
| Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x); |
| Value one = getConstantLike(rewriter, loc, 1.0, x); |
| Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT); |
| return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, small_sinh_result, |
| large_sinh_result); |
| } |
| |
| struct ConvertSinhOp : public OpConversionPattern<SinhOp> { |
| using OpConversionPattern<SinhOp>::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| SinhOp op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| SinhOp::Adaptor transformed(operands); |
| Value x = transformed.operand(); |
| if (x.getType().cast<ShapedType>().getElementType().isa<ComplexType>()) { |
| rewriter.replaceOp(op, MaterializeSinhApproximationForLargeX( |
| rewriter, op.getLoc(), operands)); |
| return success(); |
| } |
| rewriter.replaceOp(op, |
| MaterializeWithUpcast(rewriter, op.getLoc(), operands, |
| rewriter.getF32Type(), |
| &MaterializeSinhApproximation)); |
| return success(); |
| } |
| }; |
| |
| struct ConvertZetaOp : public OpConversionPattern<ZetaOp> { |
| using OpConversionPattern<ZetaOp>::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| ZetaOp op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| FloatType min_precision_ty = rewriter.getF32Type(); |
| rewriter.replaceOp( |
| op, MaterializeWithUpcast(rewriter, loc, operands, min_precision_ty, |
| &MaterializeZeta)); |
| return success(); |
| } |
| }; |
| |
| struct ConvertSelectOp : public OpConversionPattern<BroadcastSelectOp> { |
| using OpConversionPattern<BroadcastSelectOp>::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| BroadcastSelectOp op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| // Only support ranked operands. |
| typename BroadcastSelectOp::Adaptor transformed(operands); |
| Value pred = transformed.pred(); |
| Value on_true = transformed.on_true(); |
| Value on_false = transformed.on_false(); |
| auto pred_type = pred.getType().dyn_cast<RankedTensorType>(); |
| auto on_true_type = on_true.getType().dyn_cast<RankedTensorType>(); |
| auto on_false_type = on_false.getType().dyn_cast<RankedTensorType>(); |
| auto result_type = op.getResult().getType().dyn_cast<RankedTensorType>(); |
| if (!pred_type || !on_true_type || !on_false_type || !result_type) { |
| return failure(); |
| } |
| |
| auto loc = op.getLoc(); |
| |
| Value pred_shape = rewriter.createOrFold<shape::ShapeOfOp>(loc, pred); |
| Value on_true_shape = rewriter.createOrFold<shape::ShapeOfOp>(loc, on_true); |
| Value on_false_shape = |
| rewriter.createOrFold<shape::ShapeOfOp>(loc, on_false); |
| int64_t result_rank = std::max( |
| {pred_type.getRank(), on_true_type.getRank(), on_false_type.getRank()}); |
| |
| Value broadcastable_cstr = |
| rewriter.createOrFold<shape::CstrBroadcastableOp>( |
| loc, ValueRange{pred_shape, on_true_shape, on_false_shape}); |
| auto assuming_op = rewriter.create<shape::AssumingOp>( |
| loc, ArrayRef<Type>{result_type}, broadcastable_cstr); |
| |
| OpBuilder::InsertionGuard guard(rewriter); |
| rewriter.createBlock(&assuming_op.doRegion()); |
| |
| Value result_extents = rewriter.createOrFold<shape::BroadcastOp>( |
| loc, shape::getExtentTensorType(op.getContext()), |
| ValueRange{pred_shape, on_true_shape, on_false_shape}, |
| /*error=*/nullptr); |
| auto shape_type = |
| RankedTensorType::get({result_rank}, rewriter.getIndexType()); |
| result_extents = |
| rewriter.createOrFold<tensor::CastOp>(loc, shape_type, result_extents); |
| |
| Value broadcasted_pred = pred; |
| // Pred has an implicit broadcast for scalars, so use that when convenient. |
| if (pred_type.getRank() > 0) { |
| auto pred_broadcast_dimensions = llvm::to_vector<4>( |
| llvm::seq<int64_t>(result_rank - pred_type.getRank(), result_rank)); |
| broadcasted_pred = rewriter.create<mhlo::DynamicBroadcastInDimOp>( |
| loc, |
| RankedTensorType::get(result_type.getShape(), |
| pred_type.getElementType()), |
| pred, result_extents, |
| rewriter.getI64TensorAttr(pred_broadcast_dimensions)); |
| } |
| auto on_true_broadcast_dimensions = llvm::to_vector<4>( |
| llvm::seq<int64_t>(result_rank - on_true_type.getRank(), result_rank)); |
| Value broadcasted_on_true = rewriter.create<mhlo::DynamicBroadcastInDimOp>( |
| loc, |
| RankedTensorType::get(result_type.getShape(), |
| on_true_type.getElementType()), |
| on_true, result_extents, |
| rewriter.getI64TensorAttr(on_true_broadcast_dimensions)); |
| auto on_false_broadcast_dimensions = llvm::to_vector<4>( |
| llvm::seq<int64_t>(result_rank - on_false_type.getRank(), result_rank)); |
| Value broadcasted_on_false = rewriter.create<mhlo::DynamicBroadcastInDimOp>( |
| loc, |
| RankedTensorType::get(result_type.getShape(), |
| on_false_type.getElementType()), |
| on_false, result_extents, |
| rewriter.getI64TensorAttr(on_false_broadcast_dimensions)); |
| |
| // And generate the final non-broadcasted ternary op. |
| Value final_result = rewriter.create<mhlo::SelectOp>( |
| loc, result_type, broadcasted_pred, broadcasted_on_true, |
| broadcasted_on_false); |
| rewriter.create<shape::AssumingYieldOp>(loc, final_result); |
| rewriter.replaceOp(op, {assuming_op.getResult(0)}); |
| return success(); |
| } |
| }; |
| |
| // Converts binary ops that statically are determined to not broadcast directly |
| // to the corresponding mhlo non-broadcasting op. |
| template <typename ChloOpTy, typename HloOpTy, typename Adaptor> |
| struct ConvertTrivialNonBroadcastBinaryOp |
| : public OpConversionPattern<ChloOpTy> { |
| using OpConversionPattern<ChloOpTy>::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| ChloOpTy op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| // Only rewrite for statically determinable non-broadcasting cases. |
| typename ChloOpTy::Adaptor transformed(operands); |
| auto lhs_type = |
| transformed.lhs().getType().template dyn_cast<RankedTensorType>(); |
| auto rhs_type = |
| transformed.rhs().getType().template dyn_cast<RankedTensorType>(); |
| if (!lhs_type || !rhs_type) return failure(); |
| |
| // Requires rank broadcast. |
| if (lhs_type.getRank() != rhs_type.getRank()) return failure(); |
| // Any dynamic dimension may require broadcasting and requires more |
| // analysis. |
| if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape()) |
| return failure(); |
| |
| for (auto extents : llvm::zip(lhs_type.getShape(), rhs_type.getShape())) { |
| auto lhs_extent = std::get<0>(extents); |
| auto rhs_extent = std::get<1>(extents); |
| if (lhs_extent != rhs_extent) { |
| return failure(); |
| } |
| } |
| |
| rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(), |
| operands, rewriter)}); |
| return success(); |
| } |
| }; |
| |
| // Converts a binary op with ranked broadcasting operands to explicitly |
| // broadcast and invoke the corresponding mhlo non-broadcasting op. |
| // Note that dynamic broadcasting supported by this pattern is only valid for |
| // "numpy" broadcasting semantics as defined here: |
| // https://docs.scipy.org/doc/numpy/reference/ufuncs.html |
| // Specifically, this includes the following cases: |
| // - Same rank broadcast (operands have the same static rank). |
| // - Different-rank broadcast, either without a broadcast_dims attribte or |
| // with the broadcast_dims attribute set to map to a prefix padding. |
| // - Legal combinations of degenerate (1-dim) implicit broadcasting. |
| // The restriction on broadcast_dims derives from the definition of the |
| // `shape.broadcast` op, which only supports prefix-padding. |
| template <typename ChloOpTy, typename HloOpTy, typename Adaptor> |
| struct ConvertRankedDynamicBroadcastBinaryOp |
| : public OpConversionPattern<ChloOpTy> { |
| using OpConversionPattern<ChloOpTy>::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| ChloOpTy op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| // Only support ranked operands. |
| typename ChloOpTy::Adaptor transformed(operands); |
| Value lhs = transformed.lhs(); |
| Value rhs = transformed.rhs(); |
| auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>(); |
| auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>(); |
| auto result_type = |
| op.getResult().getType().template dyn_cast<RankedTensorType>(); |
| if (!lhs_type || !rhs_type || !result_type) return failure(); |
| |
| // Check for "numpy"-style rank broadcast. |
| auto broadcast_dimensions = op.broadcast_dimensions(); |
| if (broadcast_dimensions && |
| !hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) { |
| // Note: It is unclear whether the general specification of explicit |
| // broadcast_dimensions on binary ops is a feature we want to carry |
| // forward. While it can technically be implemented for ranked-dynamic, |
| // it is incompatible with unranked inputs. If this warning is emitted |
| // in real programs, it is an indication that the feature should be |
| // implemented versus just falling back on the more standard definition |
| // of numpy-like prefix-padding. |
| op.emitWarning() << "unsupported non prefix-padded dynamic rank " |
| << "broadcast_dimensions = " << *broadcast_dimensions; |
| return failure(); |
| } |
| |
| // Compute result shape. |
| auto loc = op.getLoc(); |
| |
| // Insert a constraint on the shapes being broadcastable and insert all |
| // future code into an assuming block reliant on the constraint. |
| Value lhs_shape = rewriter.create<shape::ShapeOfOp>(loc, lhs); |
| Value rhs_shape = rewriter.create<shape::ShapeOfOp>(loc, rhs); |
| auto broadcastable_cstr = |
| rewriter.create<shape::CstrBroadcastableOp>(loc, lhs_shape, rhs_shape); |
| auto assuming_op = rewriter.create<shape::AssumingOp>( |
| loc, ArrayRef<Type>{result_type}, broadcastable_cstr.result()); |
| |
| OpBuilder::InsertionGuard guard(rewriter); |
| rewriter.createBlock(&assuming_op.doRegion()); |
| |
| int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); |
| Value result_extents = |
| hlo::ComputeBinaryElementwiseBroadcastingResultExtents( |
| loc, lhs, rhs, rewriter, /*unsafe_as_extent_tensor=*/true); |
| |
| // Note that we unconditionally emit DynamicBroadcastInDim ops and let |
| // downstream canonicalizations fold them away if possible. This is |
| // because, in the dynamic case, there are many corner cases regarding |
| // when it is safe to omit, and some of them require analysis to prove |
| // properly. |
| auto lhs_broadcast_dimensions = llvm::to_vector<4>( |
| llvm::seq<int64_t>(result_rank - lhs_type.getRank(), result_rank)); |
| Value broadcasted_lhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>( |
| loc, |
| RankedTensorType::get(result_type.getShape(), |
| lhs_type.getElementType()), |
| lhs, result_extents, |
| rewriter.getI64TensorAttr(lhs_broadcast_dimensions)); |
| auto rhs_broadcast_dimensions = llvm::to_vector<4>( |
| llvm::seq<int64_t>(result_rank - rhs_type.getRank(), result_rank)); |
| Value broadcasted_rhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>( |
| loc, |
| RankedTensorType::get(result_type.getShape(), |
| rhs_type.getElementType()), |
| rhs, result_extents, |
| rewriter.getI64TensorAttr(rhs_broadcast_dimensions)); |
| |
| // And generate the final non-broadcasted binary op. |
| Value final_result = Adaptor::CreateOp( |
| op, result_type, {broadcasted_lhs, broadcasted_rhs}, rewriter); |
| rewriter.create<shape::AssumingYieldOp>(loc, final_result); |
| rewriter.replaceOp(op, {assuming_op.getResult(0)}); |
| return success(); |
| } |
| }; |
| |
| #include "generated_chlo_legalize_to_hlo.inc" |
| } // namespace |
| |
| void PopulateChloBroadcastingPatterns(MLIRContext *context, |
| OwningRewritePatternList *patterns) { |
| // Instantiate conversion templates for conforming binary elementwise ops |
| // that do not have different dtypes between operands and results and do |
| // not have special attributes that need to be preserved. |
| PopulateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>( |
| context, patterns, 10); |
| PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>( |
| context, patterns, 5); |
| patterns->insert<ConvertSelectOp>(context); |
| patterns->insert<ConvertConstantLikeOp>(context); |
| } |
| |
| void PopulateDecomposeChloPatterns(MLIRContext *context, |
| OwningRewritePatternList *patterns) { |
| populateWithGenerated(*patterns); |
| |
| // Other patterns. |
| // clang-format off |
| patterns->insert<ConvertCoshOp, |
| ConvertDigammaOp, |
| ConvertErfOp, |
| ConvertErfcOp, |
| ConvertLgammaOp, |
| ConvertPolygammaOp, |
| ConvertSinhOp, |
| ConvertZetaOp>(context); |
| // clang-format on |
| } |
| |
| } // namespace chlo |
| } // namespace mlir |