`TPUReshardVariables` is now inserted even when there is outside compilation. Before this change, any model containing a `tf_device.parallel_execute` was skipped, because `TPUReshardVariables` does not work with model parallelism. Model parallelism detection is now precise: it is detected when there is more than one `TPUExecuteAndUpdateVariables` op.
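
For reference, a minimal sketch of the new detection logic, reconstructed
from the C++ change below (simplified; `launch` is a shortened name, not
verbatim from the patch). Detection now uses `replicate.walk` rather than
iterating only the immediate replicate body, so execute launches nested
inside a `tf_device.parallel_execute` are also counted:

    // Find the unique tf_device.launch wrapping a single
    // TPUExecuteAndUpdateVariables op. Zero matches means there is no
    // TPU computation; more than one means model parallelism.
    tf_device::LaunchOp execute_launch;
    replicate.walk([&](tf_device::LaunchOp launch) {
      if (!launch.WrapsSingleOp() ||
          !llvm::isa<TF::TPUExecuteAndUpdateVariablesOp>(
              launch.GetBody().front()))
        return WalkResult::advance();
      if (execute_launch == nullptr) {
        execute_launch = launch;  // First match: remember it.
        return WalkResult::advance();
      }
      execute_launch = nullptr;   // Second match: model parallelism.
      return WalkResult::interrupt();
    });
    if (!execute_launch) return;  // Skip resharding.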
PiperOrigin-RevId: 463647960
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir
index 4d85456..5594918 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir
@@ -160,7 +160,8 @@
// -----

// Tests that the pass does not format variables when model parallelism is
-// present.
+// present. Model parallelism is present when there is more than 1
+// TPUExecuteAndUpdateVariables in a parallel_execute.

!tf_res_f32 = tensor<*x!tf_type.resource<tensor<f32>>>
!tf_res_md_f32 = tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>> // Multi-dim f32
@@ -168,7 +169,8 @@
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
// CHECK-LABEL: func @main
// CHECK-NOT: TPUReshardVariables
- func.func @main(%arg0: !tf_res_f32 {tf.device = "/device:TPU:0"},
+ func.func @main(
+ %arg0: !tf_res_f32 {tf.device = "/device:TPU:0"},
%arg1: !tf_res_f32 {tf.device = "/device:TPU:1"},
%arg2: !tf_res_md_f32 {tf.device = "/device:TPU:0"},
%arg3: !tf_res_md_f32 {tf.device = "/device:TPU:1"}) {
@@ -197,7 +199,8 @@
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf_type.string>) -> ()
tf_device.return
}) {device = "/device:CPU:0"} : () -> ()
- %rep:2 = tf_device.replicate([%arg0, %arg1] as %arg30: tensor<*x!tf_type.resource<tensor<f32>>>,
+ %rep:2 = tf_device.replicate(
+ [%arg0, %arg1] as %arg30: tensor<*x!tf_type.resource<tensor<f32>>>,
[%arg2, %arg3] as %arg31: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>)
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
%id = "tf.Identity"(%arg30) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<*x!tf_type.resource<tensor<f32>>>
@@ -210,13 +213,120 @@
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
tf_device.return
}, {
+ "tf_device.launch"() ({
+ "tf.TPUExecuteAndUpdateVariables"(%compile#1)
+ {device_var_reads_indices = [], device_var_updates_indices = []}
+ : (tensor<2x!tf_type.string>) -> ()
+ tf_device.return
+ }) {device = "TPU_REPLICATED_CORE_1"} : () -> ()
+ tf_device.return
+ }) {} : () -> ()
+ %ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+ tf_device.return %ret : tensor<i32>
+ }
+ "tf.Yield"(%b1) : (tensor<i32>) -> ()
+ }) {device = "", is_stateless = false} : (tensor<i32>) -> (tensor<i32>)
+ func.return
+ }
+}
+
+// -----
+
+// Tests that the pass formats variables when there is a parallel_execute but
+// only 1 TPUExecuteAndUpdateVariables. This can happen when there is outside
+// compilation and no model parallelism.
+
+!tf_res_f32 = tensor<*x!tf_type.resource<tensor<f32>>>
+!tf_res_md_f32 = tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>> // Multi-dim f32
+
+module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
+ // CHECK-LABEL: func @main
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<*x!tf_type.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
+ // CHECK-SAME: %[[ARG1:.*]]: tensor<*x!tf_type.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
+ // CHECK-SAME: %[[ARG2:.*]]: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
+ // CHECK-SAME: %[[ARG3:.*]]: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
+ func.func @main(%arg0: !tf_res_f32 {tf.device = "/device:TPU:0"},
+ %arg1: !tf_res_f32 {tf.device = "/device:TPU:1"},
+ %arg2: !tf_res_md_f32 {tf.device = "/device:TPU:0"},
+ %arg3: !tf_res_md_f32 {tf.device = "/device:TPU:1"}) {
+
+ %0 = "tf.Const"() {value = dense<100> : tensor<i32>} : () -> tensor<i32>
+ // CHECK: %[[STATE0:.*]] = "tf.VarHandleOp"()
+ // CHECK-SAME: device = "/device:TPU:0"
+ // CHECK: %[[STATE1:.*]] = "tf.VarHandleOp"()
+ // CHECK-SAME: device = "/device:TPU:1"
+ // CHECK: %[[WHILE:.*]] = "tf.WhileRegion"(
+ %1 = "tf.WhileRegion"(%0) ({
+ // Condition region
+ // CHECK: ^bb
+ // CHECK: "tf.Yield"
+ ^bb0(%carg0: tensor<i32>):
+ %c0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+ %c1 = "tf.GreaterEqual"(%carg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ "tf.Yield"(%c1) : (tensor<i1>) -> ()
+ }, {
+ // Body region
+ // CHECK: ^bb0
+ ^bb0(%barg0: tensor<i32>):
+ %b0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+ %b1 = "tf.AddV2"(%barg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ // CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
+ // CHECK-NEXT: "tf._TPUCompileMlir"()
+ %compile:2 = "tf_device.launch"() ({
+ %b2:2 = "tf._TPUCompileMlir"() {
+ NumDynamicShapes = 0 : i64,
+      // The metadata encodes 2 parameters and 2 return values.
+ metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
+ mlir_module = "..."} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
+ tf_device.return %b2#0, %b2#1 : tensor<!tf_type.string>, tensor<2x!tf_type.string>
+ }) {device = "/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
+ "tf_device.launch"() ({
+ "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf_type.string>) -> ()
+ tf_device.return
+ }) {device = "/device:CPU:0"} : () -> ()
+ // CHECK: tf_device.replicate
+ // CHECK-SAME: [%[[ARG0]], %[[ARG1]]] as %[[R0:.*]]: tensor<*x!tf_type.resource<tensor<f32>>>,
+ // CHECK-SAME: [%[[ARG2]], %[[ARG3]]] as %[[R1:.*]]: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>,
+ // CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[R_STATE:.*]]: tensor<!tf_type.resource<tensor<2x!tf_type.string>>>
+ // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
+ %rep:2 = tf_device.replicate([%arg0, %arg1] as %arg30: tensor<*x!tf_type.resource<tensor<f32>>>,
+ [%arg2, %arg3] as %arg31: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>)
+ {_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
+ // CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]])
+ %id = "tf.Identity"(%arg30) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<*x!tf_type.resource<tensor<f32>>>
+ // CHECK: "tf_device.parallel_execute"
+ // CHECK: "tf_device.launch"
+ // CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]])
+ // CHECK-NEXT: tf_device.return
+ // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
+ // CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1)
+ "tf_device.parallel_execute"() ({
+ "tf_device.launch"() ({
+ "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
+ {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
+ : (tensor<*x!tf_type.resource<tensor<f32>>>, tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf_type.string>) -> ()
+ tf_device.return
+ }) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
+ tf_device.return
+ }, {
tf_device.return
}) {} : () -> ()
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
tf_device.return %ret : tensor<i32>
}
+ // CHECK: "tf.Yield"
"tf.Yield"(%b1) : (tensor<i32>) -> ()
}) {device = "", is_stateless = false} : (tensor<i32>) -> (tensor<i32>)
+ // CHECK: %[[DEFAULT:.*]] = "tf.Const"()
+ // CHECK: tf_device.replicate
+ // CHECK-SAME: as %[[V0:.*]]: tensor<*x!tf_type.resource<tensor<f32>>>,
+ // CHECK-SAME: as %[[V1:.*]]: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>,
+ // CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[STATE:.*]]: tensor<!tf_type.resource<tensor<2x!tf_type.string>>>
+ // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
+ // CHECK: "tf_device.launch"
+ // CHECK-NEXT: "tf.TPUReshardVariables"(%[[V0]], %[[V1]], %[[DEFAULT]], %[[STATE]])
+ // CHECK-NEXT: tf_device.return
+ // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
func.return
}
}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
index 08dd3e0..462b93e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
@@ -336,23 +336,26 @@
tf_device::ReplicateOp replicate) {
int64_t num_replicas = replicate.n();
if (num_replicas == 1) return;
+
+ // Set execute_launch when there is exactly one LaunchOp with a
+ // TPUExecuteAndUpdateVariablesOp. More than one means there is model
+ // parallelism, which is not supported with TPUReshardVariables. None
+ // means there is no TPU computation.
tf_device::LaunchOp execute_launch;
- for (auto execute_launch_op :
- replicate.GetBody().getOps<tf_device::LaunchOp>()) {
+ replicate.walk([&](tf_device::LaunchOp execute_launch_op) {
if (!execute_launch_op.WrapsSingleOp() ||
!llvm::isa<TF::TPUExecuteAndUpdateVariablesOp>(
execute_launch_op.GetBody().front()))
- continue;
-
+ return WalkResult::advance();
if (execute_launch == nullptr) {
execute_launch = execute_launch_op;
- } else {
- // We only support one execute op inside replicate.
- execute_launch = nullptr;
- break;
+ return WalkResult::advance();
}
- }
+ execute_launch = nullptr;
+ return WalkResult::interrupt();
+ });
if (!execute_launch) return;
+
auto execute = llvm::cast<TF::TPUExecuteAndUpdateVariablesOp>(
execute_launch.GetBody().front());
auto compile =
@@ -485,11 +488,7 @@
replicate = nullptr;
return WalkResult::interrupt();
});
- // Model parallelism is not supported, and can be detected when a
- // `tf_device.parallel_execute` op in the `tf_device.replicate` is present.
- if (replicate &&
- replicate.GetBody().getOps<tf_device::ParallelExecuteOp>().empty())
- HandleReplicateOp(while_op, replicate);
+ if (replicate) HandleReplicateOp(while_op, replicate);
});
}