Add option to GraphPruning pass to allow it to skip main function.

PiperOrigin-RevId: 285514262
Change-Id: I62e1bbc4763727d87ecfa88e9a89d7a465cbd939
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir
index 8585790..771ad5e 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir
@@ -167,3 +167,16 @@
   }
   return
 }
+
+// Check that @main function is pruned.
+// CHECK-LABEL: func @main
+func @main() {
+  tf_executor.graph {
+    // CHECK-NOT: tf_executor.island
+    %0 = tf_executor.island {
+      tf_executor.yield
+    }
+    tf_executor.fetch
+  }
+  return
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning_skip_main.mlir b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning_skip_main.mlir
new file mode 100644
index 0000000..86568cc
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning_skip_main.mlir
@@ -0,0 +1,14 @@
+// RUN: tf-opt %s -tf-executor-graph-pruning=skip-main-func | FileCheck %s --dump-input=fail
+
+// Check that @main function is skipped by default.
+// CHECK-LABEL: func @main
+func @main() {
+  tf_executor.graph {
+    // CHECKT: tf_executor.island
+    %0 = tf_executor.island {
+      tf_executor.yield
+    }
+    tf_executor.fetch
+  }
+  return
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc
index 882e769..23cdebc 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc
@@ -86,17 +86,36 @@
 // This transformation pass prunes a TF graph eliminating dead-nodes.
 struct GraphPruning : public FunctionPass<GraphPruning> {
   void runOnFunction() override {
-    getFunction().walk([](tf_executor::GraphOp graph) { PruneGraph(graph); });
+    FuncOp func = getFunction();
+    if (func.getName() == "main" && skip_main_func) return;
+    func.walk([](tf_executor::GraphOp graph) { PruneGraph(graph); });
   }
+
+  struct Options : public PassOptions<Options> {
+    Option<bool> skip_main_func{
+        *this, "skip-main-func",
+        llvm::cl::desc("skip graph pruning for main function"),
+        llvm::cl::init(false)};
+  };
+
+  explicit GraphPruning(bool skip_main_func)
+      : FunctionPass<GraphPruning>(), skip_main_func(skip_main_func) {}
+
+  explicit GraphPruning(const Options& option)
+      : GraphPruning(option.skip_main_func) {}
+
+ private:
+  bool skip_main_func;
 };
 
 }  // namespace
 
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass() {
-  return std::make_unique<GraphPruning>();
+std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass(
+    bool skip_main_func) {
+  return std::make_unique<GraphPruning>(skip_main_func);
 }
 
-static PassRegistration<GraphPruning> pass(
+static PassRegistration<GraphPruning, GraphPruning::Options> pass(
     "tf-executor-graph-pruning",
     "Prune unreachable nodes in a TensorFlow Graph.");
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index fca1c02..d890494 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -79,7 +79,8 @@
 std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorIslandCoarseningPass();
 
 // Create a pass to prune tf_executor.graph from dead nodes.
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass();
+std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass(
+    bool skip_main_func = false);
 
 // Prunes unreachable operations of a tf_executor.graph operation.
 void PruneGraph(GraphOp graph);