Uses prefer_tf2xla fallback option to cleanly separate op legalization between MLIR and tf2xla. A list of MLIR-legalized ops is introduced; all other ops use tf2xla to legalize.

PiperOrigin-RevId: 375725177
Change-Id: Ie56b66d3cfe97c10ac8f36b5fec5b52a7b779168
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-prefer-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-prefer-tf2xla.mlir
new file mode 100644
index 0000000..0aa0377
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-prefer-tf2xla.mlir
@@ -0,0 +1,29 @@
+// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion device-type=XLA_CPU_JIT legalize-chlo=false use-tf2xla-fallback=true prefer-tf2xla=true" %s | FileCheck %s
+// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion device-type=XLA_CPU_JIT legalize-chlo=false prefer-tf2xla=true" %s | FileCheck --check-prefix NOFALLBACK %s
+
+module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
+
+// CHECK-LABEL: @abs
+func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  // CHECK:  "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: bessel_i0e
+func @bessel_i0e(%arg0: tensor<3xf16>, %arg1: tensor<3xf32>, %arg2: tensor<3xf64>) -> (tensor<3xf16>, tensor<3xf32>, tensor<3xf64>) {
+  // CHECK-NOT: tf.BesselI0e
+  %0 = "tf.BesselI0e"(%arg0) : (tensor<3xf16>) -> (tensor<3xf16>)
+  %1 = "tf.BesselI0e"(%arg1) : (tensor<3xf32>) -> (tensor<3xf32>)
+  %2 = "tf.BesselI0e"(%arg2) : (tensor<3xf64>) -> (tensor<3xf64>)
+  return %0, %1, %2 : tensor<3xf16>, tensor<3xf32>, tensor<3xf64>
+}
+
+// NOFALLBACK-LABEL: @xla_svd
+func @xla_svd(%arg0: tensor<1x1xf32>) -> (tensor<1xf32>, tensor<1x1xf32>, tensor<1x1xf32>) {
+  // NOFALLBACK: XlaSvd
+  %s, %u, %v = "tf.XlaSvd"(%arg0) {max_iter = 1, epsilon = 1.0E-09 : f32, precision_config = ""} : (tensor<1x1xf32>) -> (tensor<1xf32>, tensor<1x1xf32>, tensor<1x1xf32>)
+  return %s, %u, %v : tensor<1xf32>, tensor<1x1xf32>, tensor<1x1xf32>
+}
+
+}
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 793eea0..1bbe8da 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -6390,6 +6390,34 @@
   }
 }
 
+// Moves patterns from `from` into the returned set when their root op is in
+// `include_ops`. Patterns that do not declare a specific root op, patterns
+// whose root op is not registered, and all PDL patterns are always kept.
+// This is used to partition patterns by op so they can be cleanly migrated
+// from the old bridge to the MLIR bridge.
+OwningRewritePatternList PatternsIncludeOps(
+    OwningRewritePatternList &from,
+    const llvm::DenseSet<mlir::TypeID> &include_ops) {
+  OwningRewritePatternList to(from.getContext());
+
+  // Filter NativePatterns.
+  for (auto &pattern : from.getNativePatterns()) {
+    Optional<OperationName> pat_op_name = pattern->getRootKind();
+    // Include the pattern if it has no specific root op, if its root op is
+    // not registered (and so cannot appear in `include_ops`), or if its
+    // root op is explicitly listed in `include_ops`.
+    const auto *abstract_op =
+        pat_op_name ? pat_op_name->getAbstractOperation() : nullptr;
+    bool include = !abstract_op || include_ops.count(abstract_op->typeID);
+    if (include) to.add(std::move(pattern));
+  }
+
+  // PDL patterns are not tied to a single root op, so keep them all.
+  to.add(std::move(from.getPDLPatterns()));
+
+  return to;
+}
+
 }  // end namespace
 
 #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
@@ -6399,7 +6427,7 @@
                          llvm::Optional<StringRef> tf2xla_fallback_device_type,
                          bool prefer_tf2xla) {
   MLIRContext *context = op->getContext();
-  OwningRewritePatternList patterns(context);
+  OwningRewritePatternList legalize_lower_patterns(context);
   // Note that the `OperationConverter` orders patterns lexicographically by:
   // 1) Ascending legalization depth (i.e., minimum number of patterns necessary
   //    to arrive at conversion target). This requires relevant patterns to
@@ -6411,15 +6439,35 @@
   // 4) Order of patterns in `OwningRewritePatternList`.
 
   // Add TF->HLO legalization patterns.
-  PopulateLegalizeTfPatterns(context, &patterns);
+  PopulateLegalizeTfPatterns(context, &legalize_lower_patterns);
 
   // Add TF->TF lowering patterns.
-  TF::PopulateTFLoweringBeforeHLOPatterns(context, &patterns);
+  TF::PopulateTFLoweringBeforeHLOPatterns(context, &legalize_lower_patterns);
 
-  // Add TF->HLO legalization patterns via TF2XLA fallback.
-  if (tf2xla_fallback_device_type.hasValue()) {
+  if (tf2xla_fallback_device_type && prefer_tf2xla) {
+    VLOG(1) << "TF to XLA legalization patterns are partitioned by op into "
+               "either native MLIR legalization, or TF2XLA fallback "
+               "legalzation, with a preference toward TF2XLA.";
+  } else if (tf2xla_fallback_device_type) {
+    VLOG(1) << "TF to XLA legalization patterns include all native patterns "
+               "and TF2XLA fallback patterns.";
+  } else {
+    VLOG(1) << "TF to XLA legalization patterns are native patterns only.";
+  }
+
+  // Set patterns to legalize_lower_patterns, where in the prefer_tf2xla case
+  // only patterns whose ops are in the set MlirLegalizedUnderPreferTf2XlaSet
+  // are kept.
+  OwningRewritePatternList patterns =
+      (tf2xla_fallback_device_type && prefer_tf2xla)
+          ? PatternsIncludeOps(legalize_lower_patterns,
+                               MlirLegalizedUnderPreferTf2XlaSet())
+          : std::move(legalize_lower_patterns);
+
+  if (tf2xla_fallback_device_type) {
+    // Add TF->HLO legalization patterns via TF2XLA fallback.
     PopulateLegalizeTfWithTf2XlaPatterns(tf2xla_fallback_device_type.getValue(),
-                                         patterns, context);
+                                         patterns, context, prefer_tf2xla);
   }
 
   // Populate with CHLO->HLO lowerings to account for TF ops legalized to
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index c260195..b8f2120 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -280,6 +280,30 @@
   return ops.count(abstractOp->typeID);
 }
 
+// Returns the set of ops that should be legalized natively in MLIR when
+// prefer_tf2xla is set; every op outside this set falls back to the old
+// bridge's XlaOpKernel legalization.
+const llvm::DenseSet<mlir::TypeID>& MlirLegalizedUnderPreferTf2XlaSet() {
+  // Heap-allocate the set behind a static pointer and never delete it, so
+  // it is not destroyed upon thread termination.
+
+  // clang-format off
+  static const llvm::DenseSet<mlir::TypeID>* kMlirLegalizedOps =
+      new llvm::DenseSet<mlir::TypeID>{
+    // Ops that are legalized in the old bridge using MlirXlaOpKernel
+    TypeID::get<TF::AbsOp>(),
+  };
+  // clang-format on
+  return *kMlirLegalizedOps;
+}
+
+bool IsOpMlirLegalizedUnderPreferTf2Xla(Operation* op) {
+  const auto& mlir_ops = MlirLegalizedUnderPreferTf2XlaSet();  // avoid copy
+  auto* abstractOp = op->getAbstractOperation();
+  if (!abstractOp) return false;
+  return mlir_ops.count(abstractOp->typeID);
+}
+
 namespace {
 
 template <typename T, size_t N>
@@ -571,33 +595,42 @@
 class Tf2XlaRewritePattern : public RewritePattern {
  public:
   explicit Tf2XlaRewritePattern(MLIRContext* ctx,
-                                const std::string& device_type)
+                                const std::string& device_type,
+                                bool prefer_tf2xla)
       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx),
-        device_type_(device_type) {}
+        device_type_(device_type),
+        prefer_tf2xla_(prefer_tf2xla) {}
 
   LogicalResult matchAndRewrite(Operation* op,
                                 PatternRewriter& rewriter) const override {
-    if (!IsOpAllowedTf2XlaFallback(op)) return failure();
+    if (prefer_tf2xla_) {
+      if (IsOpMlirLegalizedUnderPreferTf2Xla(op)) return failure();
+    } else {
+      if (!IsOpAllowedTf2XlaFallback(op)) return failure();
+    }
     return Tf2XlaRewriter::RewriteOp(op, rewriter, device_type_);
   }
 
  private:
   std::string device_type_;
+  bool prefer_tf2xla_;
 };
 
 class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
  public:
   LegalizeTF() = default;
 
-  explicit LegalizeTF(llvm::StringRef device_type) {
+  explicit LegalizeTF(llvm::StringRef device_type, bool prefer_tf2xla) {
     device_type_ = device_type.str();
+    prefer_tf2xla_ = prefer_tf2xla;
   }
 
   LegalizeTF(const LegalizeTF&) {}
 
   void runOnFunction() override {
     OwningRewritePatternList patterns(&getContext());
-    patterns.insert<Tf2XlaRewritePattern>(&getContext(), device_type_);
+    patterns.insert<Tf2XlaRewritePattern>(&getContext(), device_type_,
+                                          prefer_tf2xla_);
     if (failed(
             applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
       signalPassFailure();
@@ -609,6 +642,12 @@
   Option<std::string> device_type_{
       *this, "device-type",
       llvm::cl::desc("XLA device type for execution of TensorFlow ops.")};
+  Option<bool> prefer_tf2xla_{
+      *this,
+      "prefer-tf2xla",
+      llvm::cl::desc("Enable legalization when it is not in the list of "
+                     "MLIR-legalized ops."),
+  };
 };
 
 static PassRegistration<LegalizeTF> pass(
@@ -619,13 +658,14 @@
 
 void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
                                           OwningRewritePatternList& patterns,
-                                          MLIRContext* ctx) {
-  patterns.insert<Tf2XlaRewritePattern>(ctx, device_type.str());
+                                          MLIRContext* ctx,
+                                          bool prefer_tf2xla) {
+  patterns.insert<Tf2XlaRewritePattern>(ctx, device_type.str(), prefer_tf2xla);
 }
 
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
-    llvm::StringRef device_type) {
-  return std::make_unique<LegalizeTF>(device_type);
+    llvm::StringRef device_type, bool prefer_tf2xla) {
+  return std::make_unique<LegalizeTF>(device_type, prefer_tf2xla);
 }
 
 }  // end namespace mhlo
diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h
index 5df4565..b4fadcd 100644
--- a/tensorflow/compiler/mlir/xla/transforms/passes.h
+++ b/tensorflow/compiler/mlir/xla/transforms/passes.h
@@ -50,22 +50,29 @@
 /// Lowers from TF dialect to HLO dialect using tf2xla op kernels for the
 /// specified device type.
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
-    llvm::StringRef device_type);
+    llvm::StringRef device_type, bool prefer_tf2xla = false);
 
 /// Replaces types that do not exist in MHLO with equivalent types that do
 /// exist.
 std::unique_ptr<OperationPass<void>> CreateLegalizeTfTypesPass();
 
 /// Adds the TF to XLA via TF2XLA rewrite patterns to the pattern list.
+/// `prefer_tf2xla` means an op will be included iff it is not in
+/// `MlirLegalizedUnderPreferTf2XlaSet`. `!prefer_tf2xla` means an op will be
+/// included iff it is allowed by `IsOpAllowedTf2XlaFallback`.
 void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
                                           OwningRewritePatternList& patterns,
-                                          MLIRContext* ctx);
+                                          MLIRContext* ctx,
+                                          bool prefer_tf2xla = false);
 
 /// Adds the TF to TF lowerings and TF to XLA rewrite patterns to the pattern
 /// list.
 void PopulateLegalizeTfPatterns(MLIRContext* context,
                                 OwningRewritePatternList* patterns);
 
+/// Ops that should be legalized using MLIR given the prefer_tf2xla option.
+const llvm::DenseSet<mlir::TypeID>& MlirLegalizedUnderPreferTf2XlaSet();
+
 /// Checks whether the op is supported by the Tf2Xla fallback for legalization.
 bool IsOpAllowedTf2XlaFallback(Operation* op);