Change tfrt_graph_execution_state to use the new graph partitioning approach when inserting transfer ops.
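
This replaces the generic partitioning helper from partitioning_utils.h with
the TFRT-specific InsertTransferOps declared in the new graph_partition.h.
As the updated tests verify, the new approach wraps each device partition in
a PartitionedCall function orchestrated by a StatefulPartitionedCall, instead
of inserting Send/Recv nodes directly into the top-level graph, so the call
site now also threads through the graph function name, the host CPU device,
and the input/output/control-output names. A minimal before/after sketch of
the two call shapes, assuming fallback_state, cpu_device, and the name
vectors are in scope as in MaybeInsertTransferOps below:

    // Old call: the partitioning_utils.h variant, keyed only on the device set.
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<Graph> new_graph,
        InsertTransferOps(fallback_state.device_set(), std::move(graph)));

    // New call: the graph_partition.h variant, which also needs the function
    // name, the host CPU device, and the I/O and control-output names so it
    // can build the per-device PartitionedCall functions.
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<Graph> new_graph,
        InsertTransferOps(graph_func_name, fallback_state.device_set(),
                          cpu_device, inputs, outputs, control_outputs,
                          std::move(graph)));
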
PiperOrigin-RevId: 449084237
diff --git a/tensorflow/core/tfrt/utils/BUILD b/tensorflow/core/tfrt/utils/BUILD
index cb364dd..1f9dc2b 100644
--- a/tensorflow/core/tfrt/utils/BUILD
+++ b/tensorflow/core/tfrt/utils/BUILD
@@ -190,6 +190,7 @@
srcs = ["tfrt_graph_execution_state.cc"],
hdrs = ["tfrt_graph_execution_state.h"],
deps = [
+ ":graph_partition",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow:upgrade_graph",
"//tensorflow/core:core_cpu_base",
diff --git a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc
index a896543..13e7140 100644
--- a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc
+++ b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc
@@ -32,7 +32,6 @@
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
-#include "tensorflow/core/common_runtime/partitioning_utils.h"
#include "tensorflow/core/common_runtime/placer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
@@ -50,6 +49,7 @@
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/tfrt/fallback/fallback_state.h"
+#include "tensorflow/core/tfrt/utils/graph_partition.h"
#include "tensorflow/core/util/dump_graph.h"
namespace tensorflow {
@@ -313,8 +313,11 @@
// This is done by partitioning `graph` and adding Send/Recv ops on the edges
// across devices.
StatusOr<std::unique_ptr<Graph>> MaybeInsertTransferOps(
- const FallbackState& fallback_state, const std::vector<std::string>& inputs,
- const std::vector<std::string>& outputs, std::unique_ptr<Graph> graph) {
+ const std::string& graph_func_name, const FallbackState& fallback_state,
+ const std::vector<std::string>& inputs,
+ const std::vector<std::string>& outputs,
+ const std::vector<std::string>& control_outputs,
+ std::unique_ptr<Graph> graph) {
// Skip inserting transfer ops if this is a TPU graph.
// Our stack currently cannot run the old bridge on TPU graphs, as it will
// generate ops that are not supported by the subsequent MLIR passes.
@@ -353,7 +356,9 @@
// Insert send/recv ops to the graph.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Graph> new_graph,
- InsertTransferOps(fallback_state.device_set(), std::move(graph)));
+ InsertTransferOps(graph_func_name, fallback_state.device_set(),
+ cpu_device, inputs, outputs, control_outputs,
+ std::move(graph)));
if (VLOG_IS_ON(1)) {
DumpGraphToFile("after_transfer_ops_insertion", *new_graph);
}
@@ -429,10 +434,12 @@
result.grappler_duration = absl::Now() - grappler_start_time;
if (options_.enable_tfrt_gpu) {
- TF_ASSIGN_OR_RETURN(result.graph,
- MaybeInsertTransferOps(fallback_state_, inputs,
- graph_import_config.outputs,
- std::move(result.graph)));
+ TF_ASSIGN_OR_RETURN(
+ result.graph,
+ MaybeInsertTransferOps(
+ graph_import_config.graph_func_name, fallback_state_, inputs,
+ graph_import_config.outputs, graph_import_config.control_outputs,
+ std::move(result.graph)));
// Update `control_outputs` as there might be newly added Send ops.
for (const Node* node : result.graph->nodes()) {
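
The hunk above ends at the top of that loop; its body lies outside the diff
context. For illustration only, a hedged sketch of the kind of update the
comment describes: the container being appended to (a hypothetical
control_outputs list in scope) is an assumption, not something shown in this
diff.

    for (const Node* node : result.graph->nodes()) {
      // The _Send ops added by partitioning have no data outputs, so listing
      // them as control outputs is what keeps them (and the transfers they
      // perform) alive. `control_outputs` here is hypothetical.
      if (node->IsSend()) {
        control_outputs.push_back(node->name());
      }
    }
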
diff --git a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc
index 6920772..4fa47ad 100644
--- a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc
+++ b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc
@@ -42,7 +42,9 @@
using ::testing::EqualsProto;
using ::testing::HasSubstr;
using ::testing::IsEmpty;
+using ::testing::NotNull;
using ::testing::Pair;
+using ::testing::SizeIs;
using ::testing::proto::IgnoringFieldPaths;
using ::testing::proto::IgnoringRepeatedFieldOrdering;
@@ -742,6 +744,16 @@
CompareGraphs(expected, *graph_execution_state->original_graph_def());
}
+// An auxiliary struct to verify the graph after partitioning and inserting
+// transfer ops.
+struct GraphInfo {
+ NodeDef* input_node = nullptr;
+ NodeDef* output_node = nullptr;
+ NodeDef* stateful_partitioned_call_node = nullptr;
+ std::vector<NodeDef*> partitioned_call_nodes;
+ std::vector<FunctionDef> fdefs;
+};
+
class InsertTransferOpsTest : public grappler::GrapplerTest {
protected:
void SetUp() override {
@@ -758,6 +770,36 @@
std::make_unique<FallbackState>(options, std::move(devices), fdef_lib_);
}
+ GraphInfo GetGraphInfo(const std::string& input, const std::string& output,
+ GraphDef& graphdef) {
+ GraphInfo graph_info;
+ for (NodeDef& node : *graphdef.mutable_node()) {
+ if (node.op() == "PartitionedCall") {
+ graph_info.partitioned_call_nodes.push_back(&node);
+ } else if (node.op() == "StatefulPartitionedCall") {
+ graph_info.stateful_partitioned_call_node = &node;
+ } else if (node.name() == input) {
+ graph_info.input_node = &node;
+ } else if (node.name() == output) {
+ graph_info.output_node = &node;
+ }
+ }
+
+ // Find the functions called by the PartitionedCall nodes.
+ absl::flat_hash_map<std::string, FunctionDef> func_name_to_func;
+ for (const FunctionDef& fdef : graphdef.library().function()) {
+ func_name_to_func[fdef.signature().name()] = fdef;
+ }
+ for (NodeDef* node : graph_info.partitioned_call_nodes) {
+ CHECK(node->attr().contains("f"));
+ CHECK(func_name_to_func.contains(node->attr().at("f").func().name()));
+ const FunctionDef& fdef =
+ func_name_to_func.at(node->attr().at("f").func().name());
+ graph_info.fdefs.push_back(fdef);
+ }
+ return graph_info;
+ }
+
std::unique_ptr<FallbackState> fallback_state_;
Device* device0_ = nullptr; // Not owned.
Device* device1_ = nullptr; // Not owned.
@@ -812,17 +854,29 @@
auto optimized_graph,
graph_execution_state->CreateOptimizedGraph(graph_import_config));
- // Verify that two paris of Send/Recv nodes are added.
- int send_count = 0, recv_count = 0;
- for (const auto* op : optimized_graph.graph->op_nodes()) {
- if (op->IsSend())
- ++send_count;
- else if (op->IsRecv())
- ++recv_count;
+ GraphDef new_graphdef;
+ optimized_graph.graph->ToGraphDef(&new_graphdef);
+
+ GraphInfo graph_info =
+ GetGraphInfo(/*input=*/"a", /*output=*/"c", new_graphdef);
+
+ ASSERT_THAT(graph_info.input_node, NotNull());
+ ASSERT_THAT(graph_info.output_node, NotNull());
+ ASSERT_THAT(graph_info.partitioned_call_nodes, SizeIs(2));
+ ASSERT_THAT(graph_info.stateful_partitioned_call_node, NotNull());
+
+ // Verify that each partition contains a _Send op and a _Recv op.
+ for (const FunctionDef& fdef : graph_info.fdefs) {
+ int send_count = 0, recv_count = 0;
+ for (const NodeDef& node : fdef.node_def()) {
+ if (node.op() == "_Send")
+ ++send_count;
+ else if (node.op() == "_Recv")
+ ++recv_count;
+ }
+ EXPECT_EQ(send_count, 1);
+ EXPECT_EQ(recv_count, 1);
}
- EXPECT_EQ(optimized_graph.graph->num_op_nodes(), 7);
- EXPECT_EQ(send_count, 2);
- EXPECT_EQ(recv_count, 2);
}
TEST_F(InsertTransferOpsTest, InsertTransferOpsWithFunctionInlining) {
@@ -911,105 +965,29 @@
auto optimized_graph,
graph_execution_state->CreateOptimizedGraph(graph_import_config));
- // Verify that the resultant graph has no PartitionedCall ops, function body
- // is inlined into the main graph, and send/recv ops are added.
- int partitioned_call_count = 0, mul_count = 0, send_count = 0, recv_count = 0;
- for (const auto* op : optimized_graph.graph->op_nodes()) {
- if (op->IsPartitionedCall())
- ++partitioned_call_count;
- else if (op->IsSend())
- ++send_count;
- else if (op->IsRecv())
- ++recv_count;
- else if (op->type_string() == "Mul")
- ++mul_count;
- }
+ GraphDef new_graphdef;
+ optimized_graph.graph->ToGraphDef(&new_graphdef);
- EXPECT_EQ(partitioned_call_count, 0);
- EXPECT_EQ(send_count, 2);
- EXPECT_EQ(recv_count, 2);
- EXPECT_EQ(mul_count, 1);
-}
+ GraphInfo graph_info =
+ GetGraphInfo(/*input=*/"a", /*output=*/"c", new_graphdef);
-TEST_F(InsertTransferOpsTest, AppendIdentityN) {
- GraphDef graphdef;
- {
- Scope scope = Scope::NewRootScope();
- Scope scope1 = scope.WithDevice(device0_->name());
- Scope scope2 = scope.WithDevice(device1_->name());
+ ASSERT_THAT(graph_info.input_node, NotNull());
+ ASSERT_THAT(graph_info.output_node, NotNull());
+ ASSERT_THAT(graph_info.partitioned_call_nodes, SizeIs(2));
+ ASSERT_THAT(graph_info.stateful_partitioned_call_node, NotNull());
- // A graph with two nodes assigned on different devices.
- // a(Const, on device0) -> b(Abs, on device1)
- Output a = ops::Const(scope1.WithOpName("a"), 2.0, {1, 1});
- Output b = ops::Abs(scope2.WithOpName("b"), a);
-
- TF_ASSERT_OK(scope.ToGraphDef(&graphdef));
-
- // There is no IdentityN/Send/Recv nodes originally.
- int identity_count = 0, abs_count = 0, const_count = 0, send_count = 0,
- recv_count = 0;
- for (const auto* op : scope.graph()->op_nodes()) {
- if (op->type_string() == "IdentityN")
- ++identity_count;
- else if (op->IsConstant())
- ++const_count;
- else if (op->type_string() == "Abs")
- ++abs_count;
- else if (op->IsSend())
+ // Verify that each partition contains a _Send op and a _Recv op.
+ for (const FunctionDef& fdef : graph_info.fdefs) {
+ int send_count = 0, recv_count = 0;
+ for (const NodeDef& node : fdef.node_def()) {
+ if (node.op() == "_Send")
++send_count;
- else if (op->IsRecv())
+ else if (node.op() == "_Recv")
++recv_count;
}
- ASSERT_EQ(scope.graph()->num_op_nodes(), 2);
- ASSERT_EQ(identity_count, 0);
- ASSERT_EQ(const_count, 1);
- ASSERT_EQ(abs_count, 1);
- ASSERT_EQ(send_count, 0);
- ASSERT_EQ(recv_count, 0);
+ EXPECT_EQ(send_count, 1);
+ EXPECT_EQ(recv_count, 1);
}
- TfrtGraphExecutionState::Options options;
- options.run_placer_grappler_on_functions = false;
- options.enable_tfrt_gpu = true;
- TF_ASSERT_OK_AND_ASSIGN(
- auto graph_execution_state,
- TfrtGraphExecutionState::Create(options, graphdef, *fallback_state_));
-
- tensorflow::GraphImportConfig graph_import_config;
- graph_import_config.prune_unused_nodes = true;
- graph_import_config.enable_shape_inference = false;
- tensorflow::ArrayInfo array_info;
- array_info.imported_dtype = DT_FLOAT;
- array_info.shape.set_unknown_rank(true);
- graph_import_config.inputs["a"] = array_info;
- graph_import_config.outputs = {"b"};
-
- TF_ASSERT_OK_AND_ASSIGN(
- auto optimized_graph,
- graph_execution_state->CreateOptimizedGraph(graph_import_config));
- GraphDef optimized_graphdef;
- optimized_graph.graph->ToGraphDef(&optimized_graphdef);
-
- // Verify that IdentityN/Send/Recv nodes are added.
- int identity_count = 0, abs_count = 0, const_count = 0, send_count = 0,
- recv_count = 0;
- for (const auto* op : optimized_graph.graph->op_nodes()) {
- if (op->type_string() == "IdentityN")
- ++identity_count;
- else if (op->IsConstant())
- ++const_count;
- else if (op->type_string() == "Abs")
- ++abs_count;
- else if (op->IsSend())
- ++send_count;
- else if (op->IsRecv())
- ++recv_count;
- }
- EXPECT_EQ(optimized_graph.graph->num_op_nodes(), 7);
- EXPECT_EQ(identity_count, 1);
- EXPECT_EQ(const_count, 1);
- EXPECT_EQ(abs_count, 1);
- EXPECT_EQ(send_count, 2);
- EXPECT_EQ(recv_count, 2);
}
} // namespace