Control dependencies flowing between devices should use a higher-priority virtual channel than data flows; otherwise, there is a risk of a control transfer getting stuck behind a large data transfer. The current channel device implementation is simple and doesn't model such prioritization.

Instead of adding complexity to the channel device, tf-sim could let control deps bypass the channel device and (magically) flow across devices in zero time. Control deps are small, latency-dominated transfers. Channel devices are only good at modeling bandwidth-bound transfers, so bypassing the channel device for control deps should not be so bad.

PiperOrigin-RevId: 351727352
Change-Id: Ibe20b9018427e30cffe9e7fc1cb3713ebe47510b
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index c533dc6..089c24e 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -498,7 +498,12 @@
       const string in_device = DeviceName(input_node);
       const auto input_node_port_num = NodePosition(input_node_name);
 
-      if (curr_node_device == in_device) {
+      // Control dependencies should be treated as high priority. Current
+      // Channel device doesn't model a separate virtual channel for control v/s
+      // data transfers. So in the interim, it may be okay to let control
+      // dependencies magically flow across devices bypassing the channel
+      // device.
+      if (curr_node_device == in_device || IsControlInput(input_node_name)) {
         // Same device: connect input_node and curr_node directly.
         curr_node_state.inputs.push_back(
             std::make_pair(input_node, input_node_port_num));
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index 2da1551..c22a6ed 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -2982,9 +2982,10 @@
   // Same number of _Send and _Recv.
   EXPECT_EQ(op_count.at(kSend), op_count.at(kRecv));
 
-  // Expect 4 Send and Recvs each: port 0, 1, and, 2, and control dependency.
-  EXPECT_EQ(op_count.at(kRecv), 4);
-  EXPECT_EQ(op_count.at(kSend), 4);
+  // Expect 3 Sends and Recvs each: ports 0, 1, and 2.
+  // Control dependency bypasses the channel.
+  EXPECT_EQ(op_count.at(kRecv), 3);
+  EXPECT_EQ(op_count.at(kSend), 3);
 
   // Helper lambda for extracting output Tensor size.
   auto get_output_size = [this, ops_executed](const string& name) -> int64 {
@@ -3006,9 +3007,6 @@
   EXPECT_EQ(get_output_size(send_op_names[1]), 4 * depth_in_);
   EXPECT_EQ(get_output_size(recv_op_names[2]), 4 * depth_in_);
   EXPECT_EQ(get_output_size(send_op_names[2]), 4 * depth_in_);
-  // Control dependency size is 4B.
-  EXPECT_EQ(get_output_size(recv_op_names[-1]), 4);
-  EXPECT_EQ(get_output_size(send_op_names[-1]), 4);
 }
 
 TEST_F(VirtualSchedulerTest, GraphWithSendRecv) {