[tfrt:jitrt] Approximate ExpM1 with Eigen Exp approximation
This change approximates ExpM1 with Eigen Exp approximation
PiperOrigin-RevId: 466993181
diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_math_approximation.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_math_approximation.cc
index dc5b4ca..05359fe 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_math_approximation.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_math_approximation.cc
@@ -245,11 +245,67 @@
return mlir::success();
}
+struct EigenExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
+ public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(math::ExpM1Op op,
+ PatternRewriter &rewriter) const final;
+};
+
+LogicalResult EigenExpM1Approximation::matchAndRewrite(
+ math::ExpM1Op op, PatternRewriter &rewriter) const {
+ auto shape = vectorShape(op.getOperand().getType(), isF32);
+ if (!shape.hasValue())
+ return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ auto bcast = [&](Value value) -> Value {
+ return broadcast(builder, value, *shape);
+ };
+
+ // expm1(x) = exp(x) - 1 = u - 1.
+ // We have to handle it carefully when x is near 0, i.e. u ~= 1,
+ // and when the input is ~= -inf, i.e. u - 1 ~= -1.
+ Value cstOne = bcast(f32Cst(builder, 1.0f));
+ Value cstNegOne = bcast(f32Cst(builder, -1.0f));
+ Value x = op.getOperand();
+ Value u = builder.create<math::ExpOp>(x);
+ Value uEqOneOrNaN =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
+ Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne);
+ Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
+ arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
+ // logU = log(u) ~= x
+ Value logU = builder.create<math::LogOp>(u);
+
+ // Detect exp(x) = +inf; written this way to avoid having to form +inf.
+ Value isInf =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);
+
+ // (u - 1) * (x / ~x)
+ Value expm1 = builder.create<arith::MulFOp>(
+ uMinusOne, builder.create<arith::DivFOp>(x, logU));
+ expm1 = builder.create<arith::SelectOp>(isInf, u, expm1);
+ Value approximation = builder.create<arith::SelectOp>(
+ uEqOneOrNaN, x,
+ builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
+ rewriter.replaceOp(op, approximation);
+
+ return mlir::success();
+}
+
static void populateMathApproximationPatterns(RewritePatternSet &patterns,
ArrayRef<std::string> oplist) {
for (const std::string &op : oplist) {
- if (op == "exp" || op == "all")
+ if (op == "all") {
+ patterns.add<EigenExpApproximation, EigenExpM1Approximation>(
+ patterns.getContext());
+ } else if (op == "exp") {
patterns.add<EigenExpApproximation>(patterns.getContext());
+ } else if (op == "expm1") {
+ patterns.add<EigenExpM1Approximation>(patterns.getContext());
+ }
}
}
diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_math_approximation.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_math_approximation.mlir
index 2b145b4..b1c5da7 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_math_approximation.mlir
+++ b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_math_approximation.mlir
@@ -2,6 +2,8 @@
// RUN: | FileCheck %s
// RUN: tf-tfrt-opt %s -tf-jitrt-math-approximation="oplist=exp" \
// RUN: | FileCheck --check-prefix=EXP %s
+// RUN: tf-tfrt-opt %s -tf-jitrt-math-approximation="oplist=expm1" \
+// RUN: | FileCheck --check-prefix=EXPM1 %s
// RUN: tf-tfrt-opt %s -tf-jitrt-math-approximation \
// RUN: | FileCheck --check-prefix=NOOP %s
@@ -70,3 +72,24 @@
%0 = math.exp %arg0 : f32
func.return %0 : f32
}
+
+// CHECK-LABEL: func @expm1_scalar(
+// CHECK-NOT: math.exp
+// EXP: math.expm1
+// EXPM1-NOT: math.expm1
+// NOOP: math.expm1
+// CHECK: %[[VAL_38:.*]] = arith.cmpf ueq, %[[VAL_37:.*]], %cst
+// CHECK: %[[VAL_39:.*]] = arith.subf %[[VAL_37]], %cst
+// CHECK: %[[VAL_40:.*]] = arith.cmpf oeq, %[[VAL_39]], %cst_0
+// CHECK: %[[VAL_41:.*]] = math.log %[[VAL_37]]
+// CHECK: %[[VAL_42:.*]] = arith.cmpf oeq, %[[VAL_41]], %[[VAL_37]]
+// CHECK: %[[VAL_43:.*]] = arith.divf %arg0, %[[VAL_41]]
+// CHECK: %[[VAL_44:.*]] = arith.mulf %[[VAL_39]], %[[VAL_43]]
+// CHECK: %[[VAL_45:.*]] = arith.select %[[VAL_42]], %[[VAL_37]], %[[VAL_44]]
+// CHECK: %[[VAL_46:.*]] = arith.select %[[VAL_40]], %cst_0, %[[VAL_45]]
+// CHECK: %[[VAL_47:.*]] = arith.select %[[VAL_38]], %arg0, %[[VAL_46]]
+// CHECK: }
+func.func @expm1_scalar(%arg0 : f32) -> f32 {
+ %0 = math.expm1 %arg0: f32
+ func.return %0 : f32
+}