Update ReplicateToIslandPass to create a sink island only to pin control dependencies, instead of forwarding all replica data results.
A sink island that forwards results was too restrictive after TPUMergeVariablesWithExecutePass. This fixes a deadlock in the 3D UNet model.
PiperOrigin-RevId: 309784369
Change-Id: Iae2f9c254085ec67b376c55b600eced3d7bf37ae
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir
index cfbd112..c8b4ad2 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir
@@ -18,11 +18,10 @@
return
}
-// CHECK: %[[CT_0:[0-9]*]] = tf_executor.ControlTrigger
-// CHECK: %[[CT_1:[0-9]*]] = tf_executor.ControlTrigger
-// CHECK: %[[ISLAND_0:[a-z_0-9]*]] = tf_executor.island(%[[CT_0]], %[[CT_1]])
-// CHECK: %[[ISLAND_1:[a-z_0-9]*]] = tf_executor.island(%[[CT_0]], %[[CT_1]])
-// CHECK: %[[ISLAND_2:[a-z_0-9]*]] = tf_executor.island(%[[ISLAND_0]], %[[ISLAND_1]])
+// CHECK: %[[CT_0:.*]] = tf_executor.ControlTrigger
+// CHECK: %[[CT_1:.*]] = tf_executor.ControlTrigger
+// CHECK: %{{.*}} = tf_executor.island(%[[CT_0]], %[[CT_1]])
+// CHECK: %{{.*}} = tf_executor.island(%[[CT_0]], %[[CT_1]])
// Tests devices are not remapped if no devices were defined in replicate.
@@ -100,64 +99,23 @@
// CHECK: device = "/GPU:1"
-// Tests unused per replica island are added as a control dependency to the
-// island forwarding per replica results.
-// CHECK-LABEL: func @unused_replica_control
-// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>, %[[ARG_1:[a-z0-9]*]]: tensor<i1>)
-func @unused_replica_control(%arg0: tensor<i1>, %arg1: tensor<i1>) {
- %0 = tf_executor.graph {
- %1 = tf_executor.ControlTrigger {}
- %2:2 = tf_executor.island(%1) {
- %3:4 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>) {n = 2 : i32} {
- %4 = "tf.opA"(%ri) : (tensor<i1>) -> tensor<i1>
- %5 = "tf.opB"(%4) : (tensor<i1>) -> tensor<i1>
- tf_device.return %4, %5 : tensor<i1>, tensor<i1>
+// Tests that a replicate with a control dependency output has the control of
+// each expanded replica pinned to a sink island.
+// CHECK-LABEL: func @replicate_control
+func @replicate_control() {
+ tf_executor.graph {
+ %1 = tf_executor.island {
+ tf_device.replicate {n = 2 : i32} {
+ tf_device.return
}
- tf_executor.yield %3#0 : tensor<i1>
+ tf_executor.yield
}
- tf_executor.fetch %2#0 : tensor<i1>
+ tf_executor.fetch %1 : !tf_executor.control
}
return
}
-// CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger
-// CHECK: %[[ISLAND_0:[a-z_0-9]*]]:2, %{{.*}} = tf_executor.island(%[[CT]])
-// CHECK: %[[OP_A_0:[0-9]*]] = "tf.opA"(%[[ARG_0]])
-// CHECK: %[[OP_B_0:[0-9]*]] = "tf.opB"(%[[OP_A_0]])
-// CHECK: tf_executor.yield %[[OP_A_0]], %[[OP_B_0]]
-// CHECK: %[[ISLAND_1:[a-z_0-9]*]]:2, %[[ISLAND_1_control:[a-z_0-9]*]] = tf_executor.island(%[[CT]])
-// CHECK: %[[OP_A_1:[0-9]*]] = "tf.opA"(%[[ARG_1]])
-// CHECK: %[[OP_B_1:[0-9]*]] = "tf.opB"(%[[OP_A_1]])
-// CHECK: tf_executor.yield %[[OP_A_1]], %[[OP_B_1]]
-// CHECK: %[[ISLAND_2:.*]], %[[ISLAND_2_control:.*]] = tf_executor.island(%[[ISLAND_1_control]])
-// CHECK: tf_executor.yield %[[ISLAND_0]]#0
-// CHECK: tf_executor.fetch %[[ISLAND_2]]
-
-
-// Tests replicate with dynamic result shapes uses its inner ops to determine
-// types for sink island.
-// CHECK-LABEL: func @replicate_body_result_types
-func @replicate_body_result_types() {
- "tf_executor.graph"() ( {
- %0:3 = "tf_executor.island"() ( {
- %1:2 = "tf_device.replicate"() ( {
- ^bb0:
- %a = "tf.opA"() : () -> tensor<i1>
- "tf_device.return"(%a) : (tensor<i1>) -> ()
- }) {n = 2 : i32} : () -> (tensor<*xi1>, tensor<*xi1>)
- "tf_executor.yield"(%1#0, %1#1) : (tensor<*xi1>, tensor<*xi1>) -> ()
- }) : () -> (tensor<*xi1>, tensor<*xi1>, !tf_executor.control)
- "tf_executor.fetch"(%0#2) : (!tf_executor.control) -> ()
- }) : () -> ()
- return
-}
-
-// CHECK: %[[ISLAND_0:.*]], %{{.*}} = tf_executor.island
-// CHECK-NEXT: %[[OP_A_0:.*]] = "tf.opA"()
-// CHECK-NEXT: tf_executor.yield %[[OP_A_0]] : tensor<i1>
-// CHECK: %[[ISLAND_1:.*]], %{{.*}} = tf_executor.island
-// CHECK-NEXT: %[[OP_A_1:.*]] = "tf.opA"()
-// CHECK-NEXT: tf_executor.yield %[[OP_A_1]] : tensor<i1>
-// CHECK: %[[ISLAND_2:.*]]:2, %[[ISLAND_2_CTRL:.*]] = tf_executor.island
-// CHECK-NEXT: tf_executor.yield %[[ISLAND_0]], %[[ISLAND_1]] : tensor<i1>, tensor<i1>
-// CHECK: tf_executor.fetch %[[ISLAND_2_CTRL]] : !tf_executor.control
+// CHECK: %[[REPLICA_0:.*]] = tf_executor.island
+// CHECK: %[[REPLICA_1:.*]] = tf_executor.island
+// CHECK: %[[SINK:.*]] = tf_executor.island(%[[REPLICA_0]], %[[REPLICA_1]])
+// CHECK: tf_executor.fetch %[[SINK]]
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
index 30bc1a2..fe9283d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
@@ -19,6 +19,7 @@
#include <memory>
#include <utility>
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
@@ -107,10 +108,9 @@
// Creates islands per replica from `tf_device.replicate` region and remap
// replicate results with new island outputs. A single island is created to
-// forward results from each replica island. Control dependencies of individual
-// replicas are added to the single island if the single island does not emit
-// a result from the respective replica. Devices are remapped from aliased
-// devices to explicit devices, for `tf_device.launch` ops.
+// forward control dependencies if there is a control dependency output from the
+// replicate island. Devices are remapped from aliased devices to explicit
+// devices, for `tf_device.launch` ops.
//
// For example, the following:
//
@@ -156,9 +156,6 @@
// }) {device = "/DEVICE:3"} : () -> tensor<i1>
// tf_executor.yield %a1, %b1 : tensor<i1>, tensor<i1>
// }
-// %6:2 = tf_executor.island(%3#2) {
-// tf_executor.yield %0#0 : tensor<i1>
-// }
LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect,
tf_executor::IslandOp island_op,
tf_device::ReplicateOp replicate_op) {
@@ -181,28 +178,25 @@
replica_result_and_idx.value();
// Remap replicate results to per replica result.
- replicate_op.replaceAllUsesWith(replicas_outputs);
+ for (auto result : llvm::zip(island_op.outputs(), replicas_outputs))
+ std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
- // Collect per replica control dependency and add to island operand if replica
- // island has no uses.
- llvm::SmallVector<Value, 8> island_operands;
- for (auto& replica : replicas)
- if (replica.use_empty()) island_operands.push_back(replica.control());
+ // Add a sink island that pins all replicas as control dependencies if the
+ // control output of the original replicate island has any uses.
+ if (!island_op.control().use_empty()) {
+ llvm::SmallVector<Value, 8> island_operands;
+ for (auto& replica : replicas) island_operands.push_back(replica.control());
- // Create single island forwarding per replica result.
- builder.setInsertionPoint(island_op);
- auto island_sink = builder.create<tf_executor::IslandOp>(
- island_op.getLoc(),
- llvm::to_vector<8>(island_op.GetYield().fetches().getTypes()),
- tf_executor::ControlType::get(island_op.getContext()), island_operands);
- island_sink.body().push_back(new Block);
-
- // Move replicate island YieldOp over to new single island.
- island_op.GetYield().getOperation()->moveBefore(
- &island_sink.GetBody(), island_sink.GetBody().begin());
-
- // Remap island results.
- island_op.replaceAllUsesWith(island_sink);
+ builder.setInsertionPoint(island_op);
+ auto island_sink = builder.create<tf_executor::IslandOp>(
+ island_op.getLoc(), llvm::ArrayRef<Type>{},
+ tf_executor::ControlType::get(island_op.getContext()), island_operands);
+ island_sink.body().push_back(new Block);
+ builder.setInsertionPointToEnd(&island_sink.GetBody());
+ builder.create<tf_executor::YieldOp>(island_op.getLoc(),
+ llvm::ArrayRef<Value>{});
+ island_op.control().replaceAllUsesWith(island_sink.control());
+ }
island_op.erase();
return success();