Fix remapping of TPU cluster results to parallel_execute results when the TPU cluster results are not perfectly forwarded.
After TPUExtractHeadTailOutsideCompilation, some of the TPU cluster's remaining results may not be forwarded to the terminator of an optional tf_device.replicate op when there is data parallelism. Instead of overwriting the replicate terminator's operands wholesale, each cluster result should be remapped to the corresponding parallel_execute result wherever it is used outside of parallel_execute.
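
In other words, the pass now walks each cluster result's use list and redirects every use that is not nested inside the parallel_execute. A minimal sketch of that loop, assuming the pass context from tpu_extract_outside_compilation.cc (the RemapOutsideUses wrapper is hypothetical; the actual change is inlined in the pass below):

  // Hypothetical helper wrapping the remapping added in this change.
  void RemapOutsideUses(tf_device::ClusterOp tpu_cluster,
                        tf_device::ParallelExecuteOp parallel_execute_op) {
    for (auto result : llvm::zip(tpu_cluster.getResults(),
                                 parallel_execute_op.getResults())) {
      Value old_result = std::get<0>(result);
      Value new_result = std::get<1>(result);
      // make_early_inc_range keeps iteration valid while use.set() mutates
      // the use list.
      for (auto& use : llvm::make_early_inc_range(old_result.getUses()))
        // Uses nested inside parallel_execute (e.g. its block terminators)
        // keep the original cluster result; any other use, including the
        // enclosing tf_device.replicate terminator, now consumes the
        // parallel_execute result.
        if (!parallel_execute_op.getOperation()->isProperAncestor(
                use.getOwner()))
          use.set(new_result);
    }
  }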
PiperOrigin-RevId: 316981114
Change-Id: I5529074857e06cfe26a7141c262a6229fe848be6
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir
index 6bb8e99..d88489f 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir
@@ -262,7 +262,6 @@
return %1 : tensor<?xi32>
}
-
// Tests extraction of a single outside compiled cluster with multiple input/output.
// CHECK-LABEL: func @multiple_outside_compiled_input_output_single_outside_compilation
@@ -439,3 +438,24 @@
return %1 : tensor<?xi32>
}
+
+// Tests that only directly used results of the tpu cluster are remapped with
+// parallel_execute results.
+
+// CHECK-LABEL: func @remapped_results
+func @remapped_results(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+ %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+ // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+ // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
+ // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]#1 : tensor<?xi32>
+ %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+ %2:2 = "tf_device.cluster"() ( {
+ %3 = "tf.A"() : () -> (tensor<?xi32>)
+ %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
+ %5:2 = "tf.C"(%4) : (tensor<?xi32>) -> (tensor<?xi32>, tensor<?xi32>)
+ tf_device.return %5#0, %5#1 : tensor<?xi32>, tensor<?xi32>
+ }) {cluster_attr = "cluster_attr"} : () -> (tensor<?xi32>, tensor<?xi32>)
+ tf_device.return %2#1 : tensor<?xi32>
+ }
+ return %1 : tensor<?xi32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
index 54600fa..a2a1910 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
@@ -108,18 +108,6 @@
return launch_op;
}
-// Propagates the return from `parallel_execute_op` to parent replicate
-// op if it exists.
-void PropagateParallelExecuteReturnToReplicate(
- tf_device::ParallelExecuteOp parallel_execute_op) {
- // Update the return for the parallel_execute op parent.
- auto replicate = llvm::dyn_cast_or_null<tf_device::ReplicateOp>(
- parallel_execute_op.getParentOp());
- if (replicate)
- replicate.GetBody().getTerminator()->setOperands(
- parallel_execute_op.execute_outputs());
-}
-
// Extracts all externally provided operands of `cluster_ops`.
llvm::SmallSetVector<Value, 4> GetExternalOperands(
llvm::ArrayRef<Operation*> cluster_ops) {
@@ -305,7 +293,16 @@
tpu_cluster.getOperation()->moveBefore(
parallel_execute_tpu_block.getTerminator());
- PropagateParallelExecuteReturnToReplicate(parallel_execute_op);
+ // Remap cluster results with parallel_execute results if the user is
+ // outside of parallel_execute.
+ for (auto result :
+ llvm::zip(tpu_cluster.getResults(), parallel_execute_op.getResults())) {
+ Value tpu_cluster_result = std::get<0>(result);
+ Value parallel_execute_result = std::get<1>(result);
+ for (auto& use : llvm::make_early_inc_range(tpu_cluster_result.getUses()))
+ if (!parallel_execute_op.getOperation()->isProperAncestor(use.getOwner()))
+ use.set(parallel_execute_result);
+ }
}
void TPUExtractOutsideCompilation::runOnFunction() {