Fix remapping of TPU cluster results to parallel_execute results when the TPU cluster results are not perfectly forwarded.

After TPUExtractHeadTailOutsideCompilation, some of the remaining results of the TPU cluster may not be used by an enclosing (optional) tf_device.replicate op when there is data parallelism, so simply forwarding all parallel_execute results to the replicate terminator is incorrect. Instead, each TPU cluster result should be remapped to the corresponding parallel_execute result wherever it is used outside of parallel_execute.
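
For illustration (a simplified, hypothetical IR shape modeled on the new
test below; %pe and %cluster are illustrative names, and "..." elides
unrelated ops): after the cluster is moved into the device region of
tf_device.parallel_execute, a user outside of parallel_execute, such as
the enclosing replicate's terminator, must read the matching
parallel_execute result, while uses inside parallel_execute keep
referring to the cluster results:

  %pe:2 = "tf_device.parallel_execute"() ( {
    ...  // host region running the outside compiled ops
    tf_device.return ...
  }, {
    %cluster:2 = "tf_device.cluster"() ( {
      ...
    }) : () -> (tensor<?xi32>, tensor<?xi32>)
    // Inside parallel_execute: still refers to the cluster results.
    tf_device.return %cluster#0, %cluster#1 : tensor<?xi32>, tensor<?xi32>
  }) : () -> (tensor<?xi32>, tensor<?xi32>)
  // Outside parallel_execute: remapped to the parallel_execute result.
  tf_device.return %pe#1 : tensor<?xi32>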

PiperOrigin-RevId: 316981114
Change-Id: I5529074857e06cfe26a7141c262a6229fe848be6
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir
index 6bb8e99..d88489f 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir
@@ -262,7 +262,6 @@
   return %1 : tensor<?xi32>
 }
 
-
 // Tests extraction of a single outside compiled cluster with multiple input/output.
 
 // CHECK-LABEL: func @multiple_outside_compiled_input_output_single_outside_compilation
@@ -439,3 +438,24 @@
 
   return %1 : tensor<?xi32>
 }
+
+// Tests that only TPU cluster results used outside of parallel_execute are
+// remapped to parallel_execute results.
+
+// CHECK-LABEL: func @remapped_results
+func @remapped_results(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+  // CHECK:   %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
+  // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]#1 : tensor<?xi32>
+  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+    %2:2 = "tf_device.cluster"() ( {
+      %3 = "tf.A"() : () -> (tensor<?xi32>)
+      %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
+      %5:2 = "tf.C"(%4) : (tensor<?xi32>) -> (tensor<?xi32>, tensor<?xi32>)
+      tf_device.return %5#0, %5#1 : tensor<?xi32>, tensor<?xi32>
+    }) {cluster_attr = "cluster_attr"} : () -> (tensor<?xi32>, tensor<?xi32>)
+    tf_device.return %2#1 : tensor<?xi32>
+  }
+  return %1 : tensor<?xi32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
index 54600fa..a2a1910 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
@@ -108,18 +108,6 @@
   return launch_op;
 }
 
-// Propagates the return from `parallel_execute_op` to parent replicate
-// op if it exists.
-void PropagateParallelExecuteReturnToReplicate(
-    tf_device::ParallelExecuteOp parallel_execute_op) {
-  // Update the return for the parallel_execute op parent.
-  auto replicate = llvm::dyn_cast_or_null<tf_device::ReplicateOp>(
-      parallel_execute_op.getParentOp());
-  if (replicate)
-    replicate.GetBody().getTerminator()->setOperands(
-        parallel_execute_op.execute_outputs());
-}
-
 // Extracts all externally provided operands of `cluster_ops`.
 llvm::SmallSetVector<Value, 4> GetExternalOperands(
     llvm::ArrayRef<Operation*> cluster_ops) {
@@ -305,7 +293,17 @@
   tpu_cluster.getOperation()->moveBefore(
       parallel_execute_tpu_block.getTerminator());
 
-  PropagateParallelExecuteReturnToReplicate(parallel_execute_op);
+  // Remap cluster results to parallel_execute results if the user is outside
+  // of parallel_execute.
+  for (auto result :
+       llvm::zip(tpu_cluster.getResults(), parallel_execute_op.getResults())) {
+    Value tpu_cluster_result = std::get<0>(result);
+    Value parallel_execute_result = std::get<1>(result);
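+    // The use list is mutated while iterating, so use make_early_inc_range.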
+    for (auto& use : llvm::make_early_inc_range(tpu_cluster_result.getUses()))
+      if (!parallel_execute_op.getOperation()->isProperAncestor(use.getOwner()))
+        use.set(parallel_execute_result);
+  }
 }
 
 void TPUExtractOutsideCompilation::runOnFunction() {