Update ReplicateToIslandPass to only create a sink island pinning control dependencies instead of forwarding all replica data results.

Creating a sink island forwarding results was too restrictive after TPUMergeVariablesWithExecutePass. This fixes a deadlock in the 3D UNet model.

PiperOrigin-RevId: 309784369
Change-Id: Iae2f9c254085ec67b376c55b600eced3d7bf37ae
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir
index cfbd112..c8b4ad2 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir
@@ -18,11 +18,10 @@
   return
 }
 
-// CHECK: %[[CT_0:[0-9]*]] = tf_executor.ControlTrigger
-// CHECK: %[[CT_1:[0-9]*]] = tf_executor.ControlTrigger
-// CHECK: %[[ISLAND_0:[a-z_0-9]*]] = tf_executor.island(%[[CT_0]], %[[CT_1]])
-// CHECK: %[[ISLAND_1:[a-z_0-9]*]] = tf_executor.island(%[[CT_0]], %[[CT_1]])
-// CHECK: %[[ISLAND_2:[a-z_0-9]*]] = tf_executor.island(%[[ISLAND_0]], %[[ISLAND_1]])
+// CHECK: %[[CT_0:.*]] = tf_executor.ControlTrigger
+// CHECK: %[[CT_1:.*]] = tf_executor.ControlTrigger
+// CHECK: %{{.*}} = tf_executor.island(%[[CT_0]], %[[CT_1]])
+// CHECK: %{{.*}} = tf_executor.island(%[[CT_0]], %[[CT_1]])
 
 
 // Tests devices are not remapped if no devices were defined in replicate.
@@ -100,64 +99,23 @@
 // CHECK: device = "/GPU:1"
 
 
-// Tests unused per replica island are added as a control dependency to the
-// island forwarding per replica results.
-// CHECK-LABEL: func @unused_replica_control
-// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>, %[[ARG_1:[a-z0-9]*]]: tensor<i1>)
-func @unused_replica_control(%arg0: tensor<i1>, %arg1: tensor<i1>) {
-  %0 = tf_executor.graph {
-    %1 = tf_executor.ControlTrigger {}
-    %2:2 = tf_executor.island(%1) {
-      %3:4 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>) {n = 2 : i32} {
-        %4 = "tf.opA"(%ri) : (tensor<i1>) -> tensor<i1>
-        %5 = "tf.opB"(%4) : (tensor<i1>) -> tensor<i1>
-        tf_device.return %4, %5 : tensor<i1>, tensor<i1>
+// Tests that a replicate with a control dependency output has each expanded
+// replica's control pinned to a sink island.
+// CHECK-LABEL: func @replicate_control
+func @replicate_control() {
+  tf_executor.graph {
+    %1 = tf_executor.island {
+      tf_device.replicate {n = 2 : i32} {
+        tf_device.return
       }
-      tf_executor.yield %3#0 : tensor<i1>
+      tf_executor.yield
     }
-    tf_executor.fetch %2#0 : tensor<i1>
+    tf_executor.fetch %1 : !tf_executor.control
   }
   return
 }
 
-// CHECK:      %[[CT:[0-9]*]] = tf_executor.ControlTrigger
-// CHECK:      %[[ISLAND_0:[a-z_0-9]*]]:2, %{{.*}} = tf_executor.island(%[[CT]])
-// CHECK:        %[[OP_A_0:[0-9]*]] = "tf.opA"(%[[ARG_0]])
-// CHECK:        %[[OP_B_0:[0-9]*]] = "tf.opB"(%[[OP_A_0]])
-// CHECK:        tf_executor.yield %[[OP_A_0]], %[[OP_B_0]]
-// CHECK:      %[[ISLAND_1:[a-z_0-9]*]]:2, %[[ISLAND_1_control:[a-z_0-9]*]] = tf_executor.island(%[[CT]])
-// CHECK:        %[[OP_A_1:[0-9]*]] = "tf.opA"(%[[ARG_1]])
-// CHECK:        %[[OP_B_1:[0-9]*]] = "tf.opB"(%[[OP_A_1]])
-// CHECK:        tf_executor.yield %[[OP_A_1]], %[[OP_B_1]]
-// CHECK:      %[[ISLAND_2:.*]], %[[ISLAND_2_control:.*]] = tf_executor.island(%[[ISLAND_1_control]])
-// CHECK:        tf_executor.yield %[[ISLAND_0]]#0
-// CHECK:      tf_executor.fetch %[[ISLAND_2]]
-
-
-// Tests replicate with dynamic result shapes uses its inner ops to determine
-// types for sink island.
-// CHECK-LABEL: func @replicate_body_result_types
-func @replicate_body_result_types() {
-  "tf_executor.graph"() ( {
-    %0:3 = "tf_executor.island"() ( {
-      %1:2 = "tf_device.replicate"() ( {
-      ^bb0:
-        %a = "tf.opA"() : () -> tensor<i1>
-        "tf_device.return"(%a) : (tensor<i1>) -> ()
-      }) {n = 2 : i32} : () -> (tensor<*xi1>, tensor<*xi1>)
-      "tf_executor.yield"(%1#0, %1#1) : (tensor<*xi1>, tensor<*xi1>) -> ()
-    }) : () -> (tensor<*xi1>, tensor<*xi1>, !tf_executor.control)
-    "tf_executor.fetch"(%0#2) : (!tf_executor.control) -> ()
-  }) : () -> ()
-  return
-}
-
-// CHECK:      %[[ISLAND_0:.*]], %{{.*}} = tf_executor.island
-// CHECK-NEXT:   %[[OP_A_0:.*]] = "tf.opA"()
-// CHECK-NEXT:   tf_executor.yield %[[OP_A_0]] : tensor<i1>
-// CHECK:      %[[ISLAND_1:.*]], %{{.*}} = tf_executor.island
-// CHECK-NEXT:   %[[OP_A_1:.*]] = "tf.opA"()
-// CHECK-NEXT:   tf_executor.yield %[[OP_A_1]] : tensor<i1>
-// CHECK:      %[[ISLAND_2:.*]]:2, %[[ISLAND_2_CTRL:.*]] = tf_executor.island
-// CHECK-NEXT:   tf_executor.yield %[[ISLAND_0]], %[[ISLAND_1]] : tensor<i1>, tensor<i1>
-// CHECK:      tf_executor.fetch %[[ISLAND_2_CTRL]] : !tf_executor.control
+// CHECK: %[[REPLICA_0:.*]] = tf_executor.island
+// CHECK: %[[REPLICA_1:.*]] = tf_executor.island
+// CHECK: %[[SINK:.*]] = tf_executor.island(%[[REPLICA_0]], %[[REPLICA_1]])
+// CHECK: tf_executor.fetch %[[SINK]]
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
index 30bc1a2..fe9283d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
@@ -19,6 +19,7 @@
 #include <memory>
 #include <utility>
 
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Sequence.h"
@@ -107,10 +108,9 @@
 
 // Creates islands per replica from `tf_device.replicate` region and remap
 // replicate results with new island outputs. A single island is created to
-// forward results from each replica island. Control dependencies of individual
-// replicas are added to the single island if the single island does not emit
-// a result from the respective replica. Devices are remapped from aliased
-// devices to explicit devices, for `tf_device.launch` ops.
+// forward control dependencies if the control output of the replicate island
+// has uses. Devices are remapped from aliased devices to explicit devices,
+// for `tf_device.launch` ops.
 //
 // For example, the following:
 //
@@ -156,9 +156,6 @@
 //   }) {device = "/DEVICE:3"} : () -> tensor<i1>
 //   tf_executor.yield %a1, %b1 : tensor<i1>, tensor<i1>
 // }
-// %6:2 = tf_executor.island(%3#2) {
-//   tf_executor.yield %0#0 : tensor<i1>
-// }
 LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect,
                                          tf_executor::IslandOp island_op,
                                          tf_device::ReplicateOp replicate_op) {
@@ -181,28 +178,25 @@
           replica_result_and_idx.value();
 
   // Remap replicate results to per replica result.
-  replicate_op.replaceAllUsesWith(replicas_outputs);
+  for (auto result : llvm::zip(island_op.outputs(), replicas_outputs))
+    std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
 
-  // Collect per replica control dependency and add to island operand if replica
-  // island has no uses.
-  llvm::SmallVector<Value, 8> island_operands;
-  for (auto& replica : replicas)
-    if (replica.use_empty()) island_operands.push_back(replica.control());
+  // Add a sink island that depends on every replica's control output when the
+  // original replicate island's control output has uses.
+  if (!island_op.control().use_empty()) {
+    llvm::SmallVector<Value, 8> island_operands;
+    for (auto& replica : replicas) island_operands.push_back(replica.control());
 
-  // Create single island forwarding per replica result.
-  builder.setInsertionPoint(island_op);
-  auto island_sink = builder.create<tf_executor::IslandOp>(
-      island_op.getLoc(),
-      llvm::to_vector<8>(island_op.GetYield().fetches().getTypes()),
-      tf_executor::ControlType::get(island_op.getContext()), island_operands);
-  island_sink.body().push_back(new Block);
-
-  // Move replicate island YieldOp over to new single island.
-  island_op.GetYield().getOperation()->moveBefore(
-      &island_sink.GetBody(), island_sink.GetBody().begin());
-
-  // Remap island results.
-  island_op.replaceAllUsesWith(island_sink);
+    builder.setInsertionPoint(island_op);
+    auto island_sink = builder.create<tf_executor::IslandOp>(
+        island_op.getLoc(), llvm::ArrayRef<Type>{},
+        tf_executor::ControlType::get(island_op.getContext()), island_operands);
+    island_sink.body().push_back(new Block);
+    builder.setInsertionPointToEnd(&island_sink.GetBody());
+    builder.create<tf_executor::YieldOp>(island_op.getLoc(),
+                                         llvm::ArrayRef<Value>{});
+    island_op.control().replaceAllUsesWith(island_sink.control());
+  }
 
   island_op.erase();
   return success();