Add option to GraphPruning pass to allow it to skip main function.
PiperOrigin-RevId: 285514262
Change-Id: I62e1bbc4763727d87ecfa88e9a89d7a465cbd939
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir
index 8585790..771ad5e 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir
@@ -167,3 +167,16 @@
}
return
}
+
+// Check that @main function is pruned.
+// CHECK-LABEL: func @main
+func @main() {
+ tf_executor.graph {
+ // CHECK-NOT: tf_executor.island
+ %0 = tf_executor.island {
+ tf_executor.yield
+ }
+ tf_executor.fetch
+ }
+ return
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning_skip_main.mlir b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning_skip_main.mlir
new file mode 100644
index 0000000..86568cc
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning_skip_main.mlir
@@ -0,0 +1,14 @@
+// RUN: tf-opt %s -tf-executor-graph-pruning=skip-main-func | FileCheck %s --dump-input=fail
+
+// Check that @main function is skipped by default.
+// CHECK-LABEL: func @main
+func @main() {
+ tf_executor.graph {
+ // CHECKT: tf_executor.island
+ %0 = tf_executor.island {
+ tf_executor.yield
+ }
+ tf_executor.fetch
+ }
+ return
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc
index 882e769..23cdebc 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc
@@ -86,17 +86,36 @@
// This transformation pass prunes a TF graph eliminating dead-nodes.
struct GraphPruning : public FunctionPass<GraphPruning> {
void runOnFunction() override {
- getFunction().walk([](tf_executor::GraphOp graph) { PruneGraph(graph); });
+ FuncOp func = getFunction();
+ if (func.getName() == "main" && skip_main_func) return;
+ func.walk([](tf_executor::GraphOp graph) { PruneGraph(graph); });
}
+
+ struct Options : public PassOptions<Options> {
+ Option<bool> skip_main_func{
+ *this, "skip-main-func",
+ llvm::cl::desc("skip graph pruning for main function"),
+ llvm::cl::init(false)};
+ };
+
+ explicit GraphPruning(bool skip_main_func)
+ : FunctionPass<GraphPruning>(), skip_main_func(skip_main_func) {}
+
+ explicit GraphPruning(const Options& option)
+ : GraphPruning(option.skip_main_func) {}
+
+ private:
+ bool skip_main_func;
};
} // namespace
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass() {
- return std::make_unique<GraphPruning>();
+std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass(
+ bool skip_main_func) {
+ return std::make_unique<GraphPruning>(skip_main_func);
}
-static PassRegistration<GraphPruning> pass(
+static PassRegistration<GraphPruning, GraphPruning::Options> pass(
"tf-executor-graph-pruning",
"Prune unreachable nodes in a TensorFlow Graph.");
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index fca1c02..d890494 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -79,7 +79,8 @@
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorIslandCoarseningPass();
// Create a pass to prune tf_executor.graph from dead nodes.
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass();
+std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass(
+ bool skip_main_func = false);
// Prunes unreachable operations of a tf_executor.graph operation.
void PruneGraph(GraphOp graph);