[tfrt:jitrt] Approximate ExpM1 with Eigen Exp approximation

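This change adds a rewrite pattern that expands math.expm1 into math.exp,
math.log and arith ops. It is enabled with oplist=expm1 or oplist=all; with
oplist=all the Eigen Exp approximation pattern is registered alongside it.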

PiperOrigin-RevId: 466993181
diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_math_approximation.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_math_approximation.cc
index dc5b4ca..05359fe 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_math_approximation.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_math_approximation.cc
@@ -245,11 +245,67 @@
   return mlir::success();
 }
 
+struct EigenExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+  // Replaces math.expm1 with an expansion built from exp, log and arith ops.
+  LogicalResult matchAndRewrite(math::ExpM1Op op,
+                                PatternRewriter &rewriter) const final;
+};
+
+LogicalResult EigenExpM1Approximation::matchAndRewrite(
+    math::ExpM1Op op, PatternRewriter &rewriter) const {
+  auto shape = vectorShape(op.getOperand().getType(), isF32);
+  if (!shape.hasValue())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+  // All new ops are created at the location of the op being replaced.
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, *shape);
+  };
+
+  // expm1(x) = exp(x) - 1 = u - 1, where u = exp(x).
+  // Two input ranges need care and each gets its own select below: x near 0,
+  // where u ~= 1, and x ~= -inf, where u - 1 ~= -1.
+  Value cstOne = bcast(f32Cst(builder, 1.0f));
+  Value cstNegOne = bcast(f32Cst(builder, -1.0f));
+  Value x = op.getOperand();
+  Value u = builder.create<math::ExpOp>(x);
+  Value uEqOneOrNaN =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
+  Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne);
+  Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
+      arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
+  // logU = log(u) ~= x; used below as the divisor in x / logU.
+  Value logU = builder.create<math::LogOp>(u);
+
+  // Detects u = exp(x) = +inf via log(u) == u, which avoids forming +inf.
+  Value isInf =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);
+
+  // General case: (u - 1) * (x / logU), where logU ~= x.
+  Value expm1 = builder.create<arith::MulFOp>(
+      uMinusOne, builder.create<arith::DivFOp>(x, logU));
+  expm1 = builder.create<arith::SelectOp>(isInf, u, expm1);
+  Value approximation = builder.create<arith::SelectOp>(
+      uEqOneOrNaN, x,  // UEQ also holds for NaN u; x is returned as-is.
+      builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
+  rewriter.replaceOp(op, approximation);
+
+  return mlir::success();
+}
+
 static void populateMathApproximationPatterns(RewritePatternSet &patterns,
                                               ArrayRef<std::string> oplist) {
   for (const std::string &op : oplist) {
-    if (op == "exp" || op == "all")
+    if (op == "all") {  // "all" registers both approximation patterns.
+      patterns.add<EigenExpApproximation, EigenExpM1Approximation>(
+          patterns.getContext());
+    } else if (op == "exp") {
       patterns.add<EigenExpApproximation>(patterns.getContext());
+    } else if (op == "expm1") {
+      patterns.add<EigenExpM1Approximation>(patterns.getContext());
+    }  // Any other oplist entry adds no pattern.
   }
 }
 
diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_math_approximation.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_math_approximation.mlir
index 2b145b4..b1c5da7 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_math_approximation.mlir
+++ b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_math_approximation.mlir
@@ -2,6 +2,8 @@
 // RUN: | FileCheck %s
 // RUN: tf-tfrt-opt %s -tf-jitrt-math-approximation="oplist=exp"               \
 // RUN: | FileCheck --check-prefix=EXP %s
+// RUN: tf-tfrt-opt %s -tf-jitrt-math-approximation="oplist=expm1"             \
+// RUN: | FileCheck --check-prefix=EXPM1 %s
 // RUN: tf-tfrt-opt %s -tf-jitrt-math-approximation                            \
 // RUN: | FileCheck --check-prefix=NOOP %s
 
@@ -70,3 +72,24 @@
   %0 = math.exp %arg0 : f32
   func.return %0 : f32
 }
+
+// CHECK-LABEL: func @expm1_scalar(
+// CHECK-NOT: math.exp
+// EXP: math.expm1
+// EXPM1-NOT: math.expm1
+// NOOP: math.expm1
+// CHECK:    %[[VAL_38:.*]] = arith.cmpf ueq, %[[VAL_37:.*]], %cst
+// CHECK:    %[[VAL_39:.*]] = arith.subf %[[VAL_37]], %cst
+// CHECK:    %[[VAL_40:.*]] = arith.cmpf oeq, %[[VAL_39]], %cst_0
+// CHECK:    %[[VAL_41:.*]] = math.log %[[VAL_37]]
+// CHECK:    %[[VAL_42:.*]] = arith.cmpf oeq, %[[VAL_41]], %[[VAL_37]]
+// CHECK:    %[[VAL_43:.*]] = arith.divf %arg0, %[[VAL_41]]
+// CHECK:    %[[VAL_44:.*]] = arith.mulf %[[VAL_39]], %[[VAL_43]]
+// CHECK:    %[[VAL_45:.*]] = arith.select %[[VAL_42]], %[[VAL_37]], %[[VAL_44]]
+// CHECK:    %[[VAL_46:.*]] = arith.select %[[VAL_40]], %cst_0, %[[VAL_45]]
+// CHECK:    %[[VAL_47:.*]] = arith.select %[[VAL_38]], %arg0, %[[VAL_46]]
+// CHECK: }
+func.func @expm1_scalar(%arg0 : f32) -> f32 {
+  %0 = math.expm1 %arg0: f32  // Op under test: kept by the oplist=exp and no-oplist runs.
+  func.return %0 : f32
+}