Use TF legalization patterns to determine whether an op should be marked for OutsideCompilation.

PiperOrigin-RevId: 325042393
Change-Id: I075db3bb540e9cbc682b699bdf468021ce5debdb
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index c6f0083..63908c8 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -827,6 +827,7 @@
         ":xla_sharding_util",
         "//tensorflow/compiler/mlir:op_or_arg_name_mapper",
         "//tensorflow/compiler/mlir/lite:validators",
+        "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla:xla_proto_cc",
         "//tensorflow/compiler/xla/client:sharding_builder",
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir
index 9b28b3b..afad117 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir
@@ -1,154 +1,159 @@
 // RUN: tf-opt %s -tf-mark-ops-for-outside-compilation | FILECHECK_OPTS="" FileCheck %s
 
-
-// CHECK-LABEL: func @op_string_result
-func @op_string_result() -> tensor<?xi32> {
+// CHECK-LABEL: func @unsupported_op
+func @unsupported_op() -> tensor<i32> {
   %0 = "tf_device.cluster"() ( {
-    // CHECK: "tf.A"
-    // CHECK-NOT: _xla_outside_compilation
-    // CHECK: "tf.B"
+    // CHECK: "tf.UnsupportedOp"
     // CHECK-SAME: _xla_outside_compilation
-    // CHECK: "tf.C"
+    // CHECK: "tf.Identity"
     // CHECK-NOT: _xla_outside_compilation
-    %1 = "tf.A"() : () -> tensor<?xi32>
-    %2 = "tf.B"(%1) : (tensor<?xi32>) -> tensor<!tf.string>
-    %3 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
-    tf_device.return %3 : tensor<?xi32>
-  }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<?xi32>
-  return %0 : tensor<?xi32>
+    %1 = "tf.UnsupportedOp"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
+    tf_device.return %2 : tensor<i32>
+  }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<i32>
+  return %0 : tensor<i32>
 }
 
-// CHECK-LABEL: func @op_string_operand
-func @op_string_operand(%arg0: tensor<!tf.string>) -> tensor<?xi32> {
+// CHECK-LABEL: func @op_string_result
+func @op_string_result() -> tensor<i32> {
   %0 = "tf_device.cluster"() ( {
-    // CHECK: "tf.A"
+    // CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
     // CHECK-NOT: _xla_outside_compilation
-    // CHECK: "tf.B"
+    // CHECK: "tf.Const"
     // CHECK-SAME: _xla_outside_compilation
-    // CHECK: "tf.C"
+    // CHECK-SAME: tf.string
+    // CHECK: "tf.Identity"
     // CHECK-NOT: _xla_outside_compilation
-    %1 = "tf.A"() : () -> tensor<?xi32>
-    %2 = "tf.B"(%arg0) : (tensor<!tf.string>) -> tensor<?xi32>
-    %3 = "tf.C"(%2) : (tensor<?xi32>) -> tensor<?xi32>
-    tf_device.return %3 : tensor<?xi32>
-  }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<?xi32>
-  return %0 : tensor<?xi32>
+    %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %2 = "tf.Const"() {value = dense<"x"> : tensor<!tf.string>} : () -> tensor<!tf.string>
+    %3 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
+    tf_device.return %3 : tensor<i32>
+  }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<i32>
+  return %0 : tensor<i32>
+}
+// CHECK-LABEL: func @op_string_operand
+func @op_string_operand(%arg0: tensor<!tf.string>) -> tensor<i32> {
+  %0 = "tf_device.cluster"() ( {
+    // CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
+    // CHECK-NOT: _xla_outside_compilation
+    // CHECK: "tf.StringToNumber"
+    // CHECK-SAME: _xla_outside_compilation
+    // CHECK-SAME: tf.string
+    // CHECK: "tf.Identity"
+    // CHECK-NOT: _xla_outside_compilation
+    %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %2 = "tf.StringToNumber"(%arg0) {out_type = f32} : (tensor<!tf.string>) -> tensor<f32>
+    %3 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
+    tf_device.return %3 : tensor<i32>
+  }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<i32>
+  return %0 : tensor<i32>
 }
 
 // CHECK-LABEL: func @op_string_operand_string_result
-func @op_string_operand_string_result(%arg0: tensor<!tf.string>) -> tensor<?xi32> {
+func @op_string_operand_string_result(%arg0: tensor<!tf.string>) -> tensor<i32> {
   %0 = "tf_device.cluster"() ( {
-    // CHECK: "tf.A"
+    // CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
     // CHECK-NOT: _xla_outside_compilation
-    // CHECK: "tf.B"
+    // CHECK: "tf.Identity"
     // CHECK-SAME: _xla_outside_compilation
-    // CHECK: "tf.C"
+    // CHECK-SAME: tf.string
+    // CHECK: "tf.Identity"
     // CHECK-NOT: _xla_outside_compilation
-    %1 = "tf.A"() : () -> tensor<?xi32>
-    %2 = "tf.B"(%arg0) : (tensor<!tf.string>) -> tensor<!tf.string>
-    %3 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
-    tf_device.return %3 : tensor<?xi32>
-  }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<?xi32>
-  return %0 : tensor<?xi32>
+    %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %2 = "tf.Identity"(%arg0)  : (tensor<!tf.string>) -> tensor<!tf.string>
+    %3 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
+    tf_device.return %3 : tensor<i32>
+  }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<i32>
+  return %0 : tensor<i32>
 }
 
-
 // Test that a tf.IfRegion op with a captured string operand is marked for outside compilation.
 // CHECK-LABEL: func @if_region_captured_string
-func @if_region_captured_string(%arg0: tensor<i1>, %arg1: tensor<!tf.string>) -> tensor<?xi32> {
+func @if_region_captured_string(%arg0: tensor<i1>, %arg1: tensor<!tf.string>) -> tensor<f32> {
   %0 = "tf_device.cluster"() ( {
-    // CHECK: "tf.A"
+    // CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
     // CHECK-NOT: _xla_outside_compilation
     // CHECK: "tf.IfRegion"
-    // CHECK: "tf.D"
-    // CHECK-SAME: _xla_outside_compilation
-    // CHECK: _xla_outside_compilation
-    // CHECK-SAME: is_stateless = true
-    %1 = "tf.A"() : () -> tensor<?xi32>
+    // CHECK: "tf.StringToNumber"
+    // CHECK: _xla_outside_compilation = "auto", is_stateless = true
+    %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
     %2 = "tf.IfRegion"(%arg0) ( {
-      %3 = "tf.D"(%arg1) : (tensor<!tf.string>) -> tensor<?xi32>
-      "tf.Yield"(%3) : (tensor<?xi32>) -> ()
+      %3 = "tf.StringToNumber"(%arg1) {out_type = f32} : (tensor<!tf.string>) -> tensor<f32>
+      "tf.Yield"(%3) : (tensor<f32>) -> ()
      },  {
-      %4 = "tf.H"() : () -> tensor<?xi32>
-      "tf.Yield"(%4) : (tensor<?xi32>) -> ()
-    }) {is_stateless = true} : (tensor<i1>) -> (tensor<?xi32>)
-    %5 = "tf.C"(%2) : (tensor<?xi32>) -> tensor<?xi32>
-    tf_device.return %5 : tensor<?xi32>
-  }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<?xi32>
-  return %0 : tensor<?xi32>
+      %4 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+      "tf.Yield"(%4) : (tensor<f32>) -> ()
+    }) {is_stateless = true} : (tensor<i1>) -> (tensor<f32>)
+    %5 = "tf.Identity"(%2) : (tensor<f32>) -> tensor<f32>
+    tf_device.return %5 : tensor<f32>
+  }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<f32>
+  return %0 : tensor<f32>
 }
 
-// Test that op with a string results/operands inside a tf.IfRegion branch is marked for outside compilation.
+// Test that ops with string results/operands inside a tf.IfRegion branch are marked for outside compilation.
 
 // CHECK-LABEL: func @if_region_string_op
-func @if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
+func @if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<f32> {
   %0 = "tf_device.cluster"() ( {
-    // CHECK: "tf.A"
+    // CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
     // CHECK-NOT: _xla_outside_compilation
     // CHECK: "tf.IfRegion"
     // CHECK-NOT: _xla_outside_compilation
-    %1 = "tf.A"() : () -> tensor<?xi32>
-    %2 = "tf.IfRegion"(%arg0)({
-      // CHECK: "tf.D"
-      // CHECK-NOT: _xla_outside_compilation
-      %3 = "tf.D"(%arg1) : (tensor<?xi32>) -> tensor<?xi32>
-      "tf.Yield"(%3) : (tensor<?xi32>) -> ()
-    },  {
-      // CHECK: "tf.F"
+    %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %2 = "tf.IfRegion"(%arg0) ( {
+      %3 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+      "tf.Yield"(%3) : (tensor<f32>) -> ()
+     },  {
+      // CHECK: "tf.Const"() {_xla_outside_compilation = "auto", value = dense<"1.0"> : tensor<!tf.string>}
+      // CHECK-NEXT: "tf.StringToNumber"
       // CHECK-SAME: _xla_outside_compilation
-      %4 = "tf.F"() : () -> tensor<!tf.string>
-      // CHECK: "tf.G"
-      // CHECK-SAME: _xla_outside_compilation
-      %5 = "tf.G"(%4) : (tensor<!tf.string>) -> tensor<?xi32>
-      %6 = "tf.H"() : () -> tensor<?xi32>
-      "tf.Yield"(%6) : (tensor<?xi32>) -> ()
-      }) {is_stateless = true} : (tensor<i1>) -> tensor<?xi32>
-    // CHECK: "tf.C"
-    // CHECK-NOT: _xla_outside_compilation
-    %7 = "tf.C"(%2) : (tensor<?xi32>) -> tensor<?xi32>
-    tf_device.return %7 : tensor<?xi32>
-  }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<?xi32>
-  return %0 : tensor<?xi32>
+      %4 = "tf.Const"() {value = dense<"1.0"> : tensor<!tf.string>} : () -> tensor<!tf.string>
+      %5 = "tf.StringToNumber"(%4) {out_type = f32} : (tensor<!tf.string>) -> tensor<f32>
+      "tf.Yield"(%5) : (tensor<f32>) -> ()
+    // CHECK: {is_stateless
+    }) {is_stateless = true} : (tensor<i1>) -> (tensor<f32>)
+    %6 = "tf.Identity"(%2) : (tensor<f32>) -> tensor<f32>
+    tf_device.return %6: tensor<f32>
+  }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<f32>
+  return %0 : tensor<f32>
 }
 
-// Test that op with a string results/operands inside a tf.IfRegion branch is marked for outside compilation.
+// Test that ops with string results/operands inside a nested tf.IfRegion branch are marked for outside compilation.
 
 // CHECK-LABEL: func @nested_if_region_string_op
-func @nested_if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
+func @nested_if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<f32> {
   %0 = "tf_device.cluster"() ( {
-    // CHECK: "tf.A"
+    // CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
     // CHECK-NOT: _xla_outside_compilation
     // CHECK: "tf.IfRegion"
     // CHECK-NOT: _xla_outside_compilation
-    %1 = "tf.A"() : () -> tensor<?xi32>
-    %2 = "tf.IfRegion"(%arg0)({
-      // CHECK: "tf.D"
-      // CHECK-NOT: _xla_outside_compilation
-      %3 = "tf.D"(%arg1) : (tensor<?xi32>) -> tensor<?xi32>
-      "tf.Yield"(%3) : (tensor<?xi32>) -> ()
-    },  {
-      %4 = "tf.E"() : () -> tensor<i1>
-      %5 = "tf.IfRegion"(%4)({
-        // CHECK: "tf.F"
-        // CHECK-NOT: _xla_outside_compilation
-        %6 = "tf.F"(%arg1) : (tensor<?xi32>) -> tensor<?xi32>
-        "tf.Yield"(%6) : (tensor<?xi32>) -> ()
+    %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %2 = "tf.IfRegion"(%arg0) ( {
+      %3 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+      "tf.Yield"(%3) : (tensor<f32>) -> ()
       },  {
-        // CHECK: "tf.G"
-        // CHECK-SAME: _xla_outside_compilation
-        %7 = "tf.G"() : () -> tensor<!tf.string>
-        // CHECK: "tf.H"
-        // CHECK-SAME: _xla_outside_compilation
-        %8 = "tf.H"(%7) : (tensor<!tf.string>) -> tensor<?xi32>
-        %9 = "tf.I"() : () -> tensor<?xi32>
-        "tf.Yield"(%9) : (tensor<?xi32>) -> ()
-      }) {is_stateless = true} : (tensor<i1>) -> tensor<?xi32>
-      "tf.Yield"(%5) : (tensor<?xi32>) -> ()
-    }) {is_stateless = true} : (tensor<i1>) -> tensor<?xi32>
-    // CHECK: "tf.C"
-    // CHECK-NOT: _xla_outside_compilation
-    %10 = "tf.C"(%2) : (tensor<?xi32>) -> tensor<?xi32>
-    tf_device.return %10 : tensor<?xi32>
-  }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<?xi32>
-  return %0 : tensor<?xi32>
+       // CHECK: "tf.Const"() {value = dense<true> : tensor<i1>}
+       // CHECK-NOT: _xla_outside_compilation
+       %4 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+       %5 = "tf.IfRegion"(%4)({
+         // CHECK: "tf.Const"() {_xla_outside_compilation = "auto", value = dense<"1.0"> : tensor<!tf.string>}
+         // CHECK-NEXT: "tf.StringToNumber"
+         // CHECK-SAME: _xla_outside_compilation
+         %6 = "tf.Const"() {value = dense<"1.0"> : tensor<!tf.string>} : () -> tensor<!tf.string>
+         %7 = "tf.StringToNumber"(%6) {out_type = f32} : (tensor<!tf.string>) -> tensor<f32>
+         "tf.Yield"(%7) : (tensor<f32>) -> ()
+       },  {
+         // CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
+         // CHECK-NOT: _xla_outside_compilation
+         %8 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+         "tf.Yield"(%8) : (tensor<f32>) -> ()
+       // CHECK: {is_stateless
+       }){is_stateless = true} : (tensor<i1>) -> (tensor<f32>)
+       "tf.Yield"(%5) : (tensor<f32>) -> ()
+    // CHECK: {is_stateless
+    }) {is_stateless = true} : (tensor<i1>) -> (tensor<f32>)
+    %9 = "tf.Identity"(%2) : (tensor<f32>) -> tensor<f32>
+    tf_device.return %9: tensor<f32>
+  }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<f32>
+  return %0 : tensor<f32>
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc
index 8f1f3ec..71146cf 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc
@@ -25,6 +25,7 @@
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
 
 namespace mlir {
 namespace TFDevice {
@@ -43,6 +44,14 @@
   void runOnOperation() override;
 };
 
+// TODO(b/159128666): Derive from the control flow legalization passes once
+// they are added; until then tf.IfRegion and tf.Yield are listed by hand.
+void AddSupportedControlFlowOps(MLIRContext* context,
+                                llvm::DenseSet<OperationName>* supported_ops) {
+  for (const char* op_name : {"tf.IfRegion", "tf.Yield"})
+    supported_ops->insert(OperationName(op_name, context));
+}
+
 bool HasStringOperand(Operation& op) {
   for (auto operand : op.getOperands()) {
     if (getElementTypeOrSelf(operand).isa<TF::StringType>()) return true;
@@ -57,12 +66,18 @@
   return false;
 }
 
+bool MatchesPattern(Operation& op,
+                    const llvm::DenseSet<OperationName>& supported_ops) {
+  return supported_ops.count(op.getName()) != 0;  // Lookup by op name only.
+}
+
 // Checks if the op is supported inside of a device cluster.
-bool IsSupportedOp(Operation& op) {
-  if (HasStringOperand(op) || HasStringResult(op)) {
-    return false;
-  }
-  return true;
+bool IsSupportedOp(Operation& op,
+                   const llvm::DenseSet<OperationName>& supported_ops) {
+  // TODO(b/161726307): Check the allowed ops list in LegalizeTfWithTf2XlaPass
+  // as well.
+  return !HasStringOperand(op) && !HasStringResult(op) &&
+         MatchesPattern(op, supported_ops);
 }
 
 bool HasCapturedStringOperand(TF::IfRegionOp* if_op) {
@@ -83,9 +98,10 @@
   return string_operand;
 }
 
-LogicalResult MarkUncompilableOps(Block* block) {
+LogicalResult MarkUncompilableOps(
+    Block* block, llvm::DenseSet<OperationName>& supported_ops) {
   block->walk([&](Operation* op) {
-    if (!IsSupportedOp(*op)) {
+    if (!IsSupportedOp(*op, supported_ops)) {
       op->setAttr(kXlaOutsideCompilationAttr,
                   StringAttr::get("auto", op->getContext()));
     }
@@ -101,9 +117,21 @@
 
 void MarkOpsForOutsideCompilation::runOnOperation() {
   auto module = getOperation();
+  OwningRewritePatternList patterns;
+  mhlo::PopulateLegalizeTfPatterns(module.getContext(), &patterns);
+
+  // `supported_ops` contains the names of all ops that can potentially be
+  // lowered into HLO on the device. Membership does not guarantee that an op
+  // will be lowered by a later pass, but an op that is not in this set cannot
+  // be lowered by any subsequent pass.
+  llvm::DenseSet<OperationName> supported_ops;
+  for (auto& pattern : patterns) {
+    supported_ops.insert(*pattern->getRootKind());
+  }
+  AddSupportedControlFlowOps(module.getContext(), &supported_ops);
 
   auto result = module.walk([&](tf_device::ClusterOp cluster) {
-    if (failed(MarkUncompilableOps(&cluster.GetBody())))
+    if (failed(MarkUncompilableOps(&cluster.GetBody(), supported_ops)))
       return WalkResult::interrupt();
 
     return WalkResult::advance();