Change tfrt_graph_execution_state to use the new graph partitioning approach.

Instead of inserting Send/Recv ops directly into the top-level graph, the new
approach (InsertTransferOps from tensorflow/core/tfrt/utils/graph_partition.h)
wraps each per-device partition in a PartitionedCall function containing the
inserted _Send/_Recv ops and stitches the partitions together with a
StatefulPartitionedCall. MaybeInsertTransferOps accordingly gains the graph
function name and the graph's control outputs, and forwards them, along with
the CPU device and the input/output names, to the new entry point. The tests
are updated to verify the partitioned structure, and the AppendIdentityN test,
which exercised the old partitioning path, is removed.

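For reference, a minimal sketch of the new entry point's shape, reconstructed
from the updated call site in MaybeInsertTransferOps (the authoritative
declaration lives in tensorflow/core/tfrt/utils/graph_partition.h, which is
not part of this diff; parameter names here are assumptions):

    // Sketch only: reconstructed from the call site below. Partitions `graph`
    // across `device_set`, wraps each per-device partition in a
    // PartitionedCall function containing the inserted _Send/_Recv ops, and
    // returns the stitched top-level graph.
    StatusOr<std::unique_ptr<Graph>> InsertTransferOps(
        const std::string& graph_func_name, const DeviceSet& device_set,
        Device* cpu_device, const std::vector<std::string>& inputs,
        const std::vector<std::string>& outputs,
        const std::vector<std::string>& control_outputs,
        std::unique_ptr<Graph> graph);
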
PiperOrigin-RevId: 449084237
diff --git a/tensorflow/core/tfrt/utils/BUILD b/tensorflow/core/tfrt/utils/BUILD
index cb364dd..1f9dc2b 100644
--- a/tensorflow/core/tfrt/utils/BUILD
+++ b/tensorflow/core/tfrt/utils/BUILD
@@ -190,6 +190,7 @@
     srcs = ["tfrt_graph_execution_state.cc"],
     hdrs = ["tfrt_graph_execution_state.h"],
     deps = [
+        ":graph_partition",
         "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
         "//tensorflow/compiler/mlir/tensorflow:upgrade_graph",
         "//tensorflow/core:core_cpu_base",
diff --git a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc
index a896543..13e7140 100644
--- a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc
+++ b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc
@@ -32,7 +32,6 @@
 #include "tensorflow/core/common_runtime/graph_constructor.h"
 #include "tensorflow/core/common_runtime/lower_functional_ops.h"
 #include "tensorflow/core/common_runtime/optimization_registry.h"
-#include "tensorflow/core/common_runtime/partitioning_utils.h"
 #include "tensorflow/core/common_runtime/placer.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/function.h"
@@ -50,6 +49,7 @@
 #include "tensorflow/core/platform/statusor.h"
 #include "tensorflow/core/protobuf/config.pb.h"
 #include "tensorflow/core/tfrt/fallback/fallback_state.h"
+#include "tensorflow/core/tfrt/utils/graph_partition.h"
 #include "tensorflow/core/util/dump_graph.h"
 
 namespace tensorflow {
@@ -313,8 +313,11 @@
 // This is done by partitioning `graph` and adding Send/Recv ops on the
 // edges across devices.
 StatusOr<std::unique_ptr<Graph>> MaybeInsertTransferOps(
-    const FallbackState& fallback_state, const std::vector<std::string>& inputs,
-    const std::vector<std::string>& outputs, std::unique_ptr<Graph> graph) {
+    const std::string& graph_func_name, const FallbackState& fallback_state,
+    const std::vector<std::string>& inputs,
+    const std::vector<std::string>& outputs,
+    const std::vector<std::string>& control_outputs,
+    std::unique_ptr<Graph> graph) {
   // Skip inserting transfer ops if this is a TPU graph.
   // Our stack currently cannot run the old bridge on TPU graphs, as it will
   // generate ops that are not supported by the subsequent MLIR passes.
@@ -353,7 +356,9 @@
   // Insert send/recv ops to the graph.
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<Graph> new_graph,
-      InsertTransferOps(fallback_state.device_set(), std::move(graph)));
+      InsertTransferOps(graph_func_name, fallback_state.device_set(),
+                        cpu_device, inputs, outputs, control_outputs,
+                        std::move(graph)));
   if (VLOG_IS_ON(1)) {
     DumpGraphToFile("after_transfer_ops_insertion", *new_graph);
   }
@@ -429,10 +434,12 @@
   result.grappler_duration = absl::Now() - grappler_start_time;
 
   if (options_.enable_tfrt_gpu) {
-    TF_ASSIGN_OR_RETURN(result.graph,
-                        MaybeInsertTransferOps(fallback_state_, inputs,
-                                               graph_import_config.outputs,
-                                               std::move(result.graph)));
+    TF_ASSIGN_OR_RETURN(
+        result.graph,
+        MaybeInsertTransferOps(
+            graph_import_config.graph_func_name, fallback_state_, inputs,
+            graph_import_config.outputs, graph_import_config.control_outputs,
+            std::move(result.graph)));
 
     // Update `control_outputs` as there might be newly added Send ops.
     for (const Node* node : result.graph->nodes()) {
diff --git a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc
index 6920772..4fa47ad 100644
--- a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc
+++ b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc
@@ -42,7 +42,9 @@
 using ::testing::EqualsProto;
 using ::testing::HasSubstr;
 using ::testing::IsEmpty;
+using ::testing::NotNull;
 using ::testing::Pair;
+using ::testing::SizeIs;
 using ::testing::proto::IgnoringFieldPaths;
 using ::testing::proto::IgnoringRepeatedFieldOrdering;
 
@@ -742,6 +744,16 @@
   CompareGraphs(expected, *graph_execution_state->original_graph_def());
 }
 
+// An auxiliary struct to verify the graph after partitioning and inserting
+// transfer ops.
+struct GraphInfo {
+  NodeDef* input_node = nullptr;
+  NodeDef* output_node = nullptr;
+  NodeDef* stateful_partitioned_call_node = nullptr;
+  std::vector<NodeDef*> partitioned_call_nodes;
+  std::vector<FunctionDef> fdefs;
+};
+
 class InsertTransferOpsTest : public grappler::GrapplerTest {
  protected:
   void SetUp() override {
@@ -758,6 +770,36 @@
         std::make_unique<FallbackState>(options, std::move(devices), fdef_lib_);
   }
 
+  GraphInfo GetGraphInfo(const std::string& input, const std::string& output,
+                         GraphDef& graphdef) {
+    GraphInfo graph_info;
+    for (NodeDef& node : *graphdef.mutable_node()) {
+      if (node.op() == "PartitionedCall") {
+        graph_info.partitioned_call_nodes.push_back(&node);
+      } else if (node.op() == "StatefulPartitionedCall") {
+        graph_info.stateful_partitioned_call_node = &node;
+      } else if (node.name() == input) {
+        graph_info.input_node = &node;
+      } else if (node.name() == output) {
+        graph_info.output_node = &node;
+      }
+    }
+
+    // Find the functions called by the PartitionedCall nodes.
+    absl::flat_hash_map<std::string, FunctionDef> func_name_to_func;
+    for (const FunctionDef& fdef : graphdef.library().function()) {
+      func_name_to_func[fdef.signature().name()] = fdef;
+    }
+    for (NodeDef* node : graph_info.partitioned_call_nodes) {
+      CHECK(node->attr().contains("f"));
+      CHECK(func_name_to_func.contains(node->attr().at("f").func().name()));
+      const FunctionDef& fdef =
+          func_name_to_func.at(node->attr().at("f").func().name());
+      graph_info.fdefs.push_back(fdef);
+    }
+    return graph_info;
+  }
+
   std::unique_ptr<FallbackState> fallback_state_;
   Device* device0_ = nullptr;  // Not owned.
   Device* device1_ = nullptr;  // Not owned.
@@ -812,17 +854,29 @@
       auto optimized_graph,
       graph_execution_state->CreateOptimizedGraph(graph_import_config));
 
-  // Verify that two paris of Send/Recv nodes are added.
-  int send_count = 0, recv_count = 0;
-  for (const auto* op : optimized_graph.graph->op_nodes()) {
-    if (op->IsSend())
-      ++send_count;
-    else if (op->IsRecv())
-      ++recv_count;
+  GraphDef new_graphdef;
+  optimized_graph.graph->ToGraphDef(&new_graphdef);
+
+  GraphInfo graph_info =
+      GetGraphInfo(/*input=*/"a", /*output=*/"c", new_graphdef);
+
+  ASSERT_THAT(graph_info.input_node, NotNull());
+  ASSERT_THAT(graph_info.output_node, NotNull());
+  ASSERT_THAT(graph_info.partitioned_call_nodes, SizeIs(2));
+  ASSERT_THAT(graph_info.stateful_partitioned_call_node, NotNull());
+
+  // Verify that each partition contains a _Send op and a _Recv op.
+  for (const FunctionDef& fdef : graph_info.fdefs) {
+    int send_count = 0, recv_count = 0;
+    for (const NodeDef& node : fdef.node_def()) {
+      if (node.op() == "_Send")
+        ++send_count;
+      else if (node.op() == "_Recv")
+        ++recv_count;
+    }
+    EXPECT_EQ(send_count, 1);
+    EXPECT_EQ(recv_count, 1);
   }
-  EXPECT_EQ(optimized_graph.graph->num_op_nodes(), 7);
-  EXPECT_EQ(send_count, 2);
-  EXPECT_EQ(recv_count, 2);
 }
 
 TEST_F(InsertTransferOpsTest, InsertTransferOpsWithFunctionInlining) {
@@ -911,105 +965,29 @@
       auto optimized_graph,
       graph_execution_state->CreateOptimizedGraph(graph_import_config));
 
-  // Verify that the resultant graph has no PartitionedCall ops, function body
-  // is inlined into the main graph, and send/recv ops are added.
-  int partitioned_call_count = 0, mul_count = 0, send_count = 0, recv_count = 0;
-  for (const auto* op : optimized_graph.graph->op_nodes()) {
-    if (op->IsPartitionedCall())
-      ++partitioned_call_count;
-    else if (op->IsSend())
-      ++send_count;
-    else if (op->IsRecv())
-      ++recv_count;
-    else if (op->type_string() == "Mul")
-      ++mul_count;
-  }
+  GraphDef new_graphdef;
+  optimized_graph.graph->ToGraphDef(&new_graphdef);
 
-  EXPECT_EQ(partitioned_call_count, 0);
-  EXPECT_EQ(send_count, 2);
-  EXPECT_EQ(recv_count, 2);
-  EXPECT_EQ(mul_count, 1);
-}
+  GraphInfo graph_info =
+      GetGraphInfo(/*input=*/"a", /*output=*/"c", new_graphdef);
 
-TEST_F(InsertTransferOpsTest, AppendIdentityN) {
-  GraphDef graphdef;
-  {
-    Scope scope = Scope::NewRootScope();
-    Scope scope1 = scope.WithDevice(device0_->name());
-    Scope scope2 = scope.WithDevice(device1_->name());
+  ASSERT_THAT(graph_info.input_node, NotNull());
+  ASSERT_THAT(graph_info.output_node, NotNull());
+  ASSERT_THAT(graph_info.partitioned_call_nodes, SizeIs(2));
+  ASSERT_THAT(graph_info.stateful_partitioned_call_node, NotNull());
 
-    // A graph with two nodes assigned on different devices.
-    // a(Const, on device0) -> b(Abs, on device1)
-    Output a = ops::Const(scope1.WithOpName("a"), 2.0, {1, 1});
-    Output b = ops::Abs(scope2.WithOpName("b"), a);
-
-    TF_ASSERT_OK(scope.ToGraphDef(&graphdef));
-
-    // There is no IdentityN/Send/Recv nodes originally.
-    int identity_count = 0, abs_count = 0, const_count = 0, send_count = 0,
-        recv_count = 0;
-    for (const auto* op : scope.graph()->op_nodes()) {
-      if (op->type_string() == "IdentityN")
-        ++identity_count;
-      else if (op->IsConstant())
-        ++const_count;
-      else if (op->type_string() == "Abs")
-        ++abs_count;
-      else if (op->IsSend())
+  // Verify that each partition contains a _Send op and a _Recv op.
+  for (const FunctionDef& fdef : graph_info.fdefs) {
+    int send_count = 0, recv_count = 0;
+    for (const NodeDef& node : fdef.node_def()) {
+      if (node.op() == "_Send")
         ++send_count;
-      else if (op->IsRecv())
+      else if (node.op() == "_Recv")
         ++recv_count;
     }
-    ASSERT_EQ(scope.graph()->num_op_nodes(), 2);
-    ASSERT_EQ(identity_count, 0);
-    ASSERT_EQ(const_count, 1);
-    ASSERT_EQ(abs_count, 1);
-    ASSERT_EQ(send_count, 0);
-    ASSERT_EQ(recv_count, 0);
+    EXPECT_EQ(send_count, 1);
+    EXPECT_EQ(recv_count, 1);
   }
-  TfrtGraphExecutionState::Options options;
-  options.run_placer_grappler_on_functions = false;
-  options.enable_tfrt_gpu = true;
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto graph_execution_state,
-      TfrtGraphExecutionState::Create(options, graphdef, *fallback_state_));
-
-  tensorflow::GraphImportConfig graph_import_config;
-  graph_import_config.prune_unused_nodes = true;
-  graph_import_config.enable_shape_inference = false;
-  tensorflow::ArrayInfo array_info;
-  array_info.imported_dtype = DT_FLOAT;
-  array_info.shape.set_unknown_rank(true);
-  graph_import_config.inputs["a"] = array_info;
-  graph_import_config.outputs = {"b"};
-
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto optimized_graph,
-      graph_execution_state->CreateOptimizedGraph(graph_import_config));
-  GraphDef optimized_graphdef;
-  optimized_graph.graph->ToGraphDef(&optimized_graphdef);
-
-  // Verify that IdentityN/Send/Recv nodes are added.
-  int identity_count = 0, abs_count = 0, const_count = 0, send_count = 0,
-      recv_count = 0;
-  for (const auto* op : optimized_graph.graph->op_nodes()) {
-    if (op->type_string() == "IdentityN")
-      ++identity_count;
-    else if (op->IsConstant())
-      ++const_count;
-    else if (op->type_string() == "Abs")
-      ++abs_count;
-    else if (op->IsSend())
-      ++send_count;
-    else if (op->IsRecv())
-      ++recv_count;
-  }
-  EXPECT_EQ(optimized_graph.graph->num_op_nodes(), 7);
-  EXPECT_EQ(identity_count, 1);
-  EXPECT_EQ(const_count, 1);
-  EXPECT_EQ(abs_count, 1);
-  EXPECT_EQ(send_count, 2);
-  EXPECT_EQ(recv_count, 2);
 }
 
 }  // namespace