Dependency optimizer groups cross-host control edges to reduce RPC traffic.

PiperOrigin-RevId: 279344543
Change-Id: Iff504d1be03b02afaf9a945fc5b19b3a7ac57b5a
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index ace1308..80e6bcf 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -362,6 +362,8 @@
         "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
         "//tensorflow/core/grappler/utils:grappler_test",
         "//tensorflow/core/grappler/utils:topological_sort",
+        "//tensorflow/core/platform:test",
+        "@com_google_absl//absl/strings",
     ],
 )
 
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
index 0734c32..fed6003 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
@@ -19,6 +19,7 @@
 
 #include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/grappler/costs/graph_properties.h"
 #include "tensorflow/core/grappler/grappler_item.h"
@@ -613,11 +614,18 @@
 // We can reduce cross-device communication by introducing an intermediate
 // NoOp node C' on device X and rewriting the control edges to:
 // A->C', B->C', C' -> C
-void DependencyOptimizer::GroupCrossDeviceControlEdges() {
+void DependencyOptimizer::GroupCrossDeviceControlEdges(bool host_granularity) {
+  VLOG(1)
+      << "DependencyOptimizer::GroupCrossDeviceControlEdges host_granularity="
+      << host_granularity;
   const int num_nodes = optimized_graph_->node_size();
   for (int i = 0; i < num_nodes; ++i) {
     NodeDef* node = optimized_graph_->mutable_node(i);
     if (node->device().empty()) continue;
+    string rest, node_device = node->device();
+    if (host_granularity) {
+      DeviceNameUtils::SplitDeviceName(node->device(), &node_device, &rest);
+    }
 
     // Creates new noop nodes for devices on which multiple control inputs are
     // located.
@@ -630,11 +638,19 @@
     for (int j = 0; j < node->input_size(); ++j) {
       if (IsControlInput(node->input(j))) {
         const NodeDef* input = node_map_->GetNode(node->input(j));
-        if (input != nullptr && !input->device().empty() &&
-            input->device() != node->device()) {
-          auto emplace_result = noops.emplace(input->device(), nullptr);
+        if (input == nullptr || input->device().empty()) continue;
+        string input_device = input->device();
+        if (host_granularity) {
+          DeviceNameUtils::SplitDeviceName(input->device(), &input_device,
+                                           &rest);
+        }
+        if (input_device != node_device) {
+          VLOG(2) << "Cross-device " << node->name() << " " << input->device()
+                  << " -> " << node->device();
+          auto emplace_result = noops.emplace(input_device, nullptr);
           if (!emplace_result.second &&
               emplace_result.first->second == nullptr) {
+            VLOG(2) << "Duplicate input device from " << node->name();
             // This is the second cross-device control input from the same
             // device. Creates an intermediate noop node on that device.
             string group_name;
@@ -654,6 +670,8 @@
             noop->set_op("NoOp");
             node_map_->AddNode(noop->name(), noop);
             emplace_result.first->second = noop;
+            VLOG(1) << "GroupCrossDeviceControlEdges: Added "
+                    << SummarizeNodeDef(*noop);
           }
         }
       }
@@ -668,10 +686,16 @@
         if (input == nullptr) {
           ++pos;
         } else {
-          auto it = noops.find(input->device());
+          string input_device = input->device();
+          if (host_granularity) {
+            DeviceNameUtils::SplitDeviceName(input->device(), &input_device,
+                                             &rest);
+          }
+          auto it = noops.find(input_device);
           if (it == noops.end() || it->second == nullptr) {
             ++pos;
           } else {
+            VLOG(2) << "Rewriting input from " << input_name;
             node->mutable_input()->SwapElements(pos, node->input_size() - 1);
             node->mutable_input()->RemoveLast();
             it->second->add_input(AsControlDependency(*input));
@@ -725,7 +749,11 @@
     // Dedup control inputs.
     CleanControlInputs();
 
-    GroupCrossDeviceControlEdges();
+    // Merge multiple control edges from the same device.
+    GroupCrossDeviceControlEdges(/*host_granularity=*/false);
+
+    // Merge control edges from the same host to reduce RPC traffic.
+    GroupCrossDeviceControlEdges(/*host_granularity=*/true);
   }
 
   return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.h b/tensorflow/core/grappler/optimizers/dependency_optimizer.h
index feeddc2..bb34b25 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.h
@@ -71,8 +71,9 @@
   // Main driver of dependency optimizations.
   Status OptimizeDependencies();
   // Replaces multiple cross-device control edges from the same device with a
-  // single control edge.
-  void GroupCrossDeviceControlEdges();
+  // single control edge.  If `host_granularity` is true then group control
+  // edges from all devices on the same host.
+  void GroupCrossDeviceControlEdges(bool host_granularity);
 
   bool fetch_nodes_known_;
   std::unordered_set<string> nodes_to_preserve_;
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
index 074b0dc..8c8107c 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
@@ -14,6 +14,8 @@
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
+
+#include "absl/strings/match.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
@@ -888,6 +890,61 @@
   CompareGraphs(expected, output);
 }
 
+TEST_F(DependencyOptimizerTest, GroupCrossHostControlDeps) {
+  GrapplerItem item;
+  {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+    std::vector<Operation> ops;
+    Output a = ops::RandomUniform(s.WithOpName("a").WithDevice("/CPU:0"),
+                                  {1, 2}, DT_FLOAT);
+    for (int t = 0; t < 4; ++t) {
+      for (int c = 0; c < 8; ++c) {
+        string opname = absl::StrCat("t", t, "/c", c);
+        string device = absl::StrCat("/task:", t, "/device:TPU:", c);
+        Output output = ops::RandomUniform(
+            s.WithOpName(opname).WithDevice(device), {1, 2}, DT_FLOAT);
+        ops.push_back(output.op());
+      }
+    }
+    // Node with cross-device dependencies.
+    auto fetch = ops::Identity(
+        s.WithOpName("f").WithControlDependencies(ops).WithDevice("/CPU:0"),
+        {a});
+
+    TF_CHECK_OK(s.ToGraphDef(&item.graph));
+    item.fetch.push_back("f");
+  }
+
+  GraphDef expected;
+  {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+    TF_CHECK_OK(s.ToGraphDef(&expected));
+  }
+
+  DependencyOptimizer optimizer;
+  GraphDef output;
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  EXPECT_EQ(output.node_size(), item.graph.node_size() + 4);
+  std::set<string> tasks;
+  for (const auto& n : output.node()) {
+    if (n.op() == "NoOp") {
+      EXPECT_TRUE(absl::StartsWith(n.name(), "GroupCrossDeviceControlEdges"));
+      EXPECT_EQ(n.input_size(), 8);
+      tasks.insert(n.device());
+    }
+
+    if (n.name() == "f") {
+      EXPECT_EQ(n.input_size(), 5);
+      for (const auto& i : n.input()) {
+        EXPECT_TRUE(i == "a" ||
+                    absl::StartsWith(i, "^GroupCrossDeviceControlEdges"));
+      }
+    }
+  }
+  EXPECT_EQ(tasks.size(), 4);
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow