Virtualize the AddCopiesOnConditional function on copy insertion.

PiperOrigin-RevId: 355189766
Change-Id: Ib2891e1bf22312b0a8735df1612de0f9a46b1ba4
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 80511dd..89cb7a4 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -332,35 +332,6 @@
   return Status::OK();
 }
 
-// We add copies for all non-phi indices of the true and false computation
-// roots, in order to resolve interference. We later rely on
-// RemoveUnnecessaryCopies to drop the unnecessary ones.
-Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
-                               HloInstruction* conditional) {
-  VLOG(2) << "Adding copies for kConditional instruction "
-          << conditional->name();
-  ShapeTree<bool> indices_to_copy(conditional->shape());
-  TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);
-  if (!IndicesToCopyForConditional(alias_analysis.dataflow_analysis(),
-                                   conditional, &indices_to_copy)) {
-    VLOG(2) << "No copies necessary for kWhile instruction "
-            << conditional->name();
-    return Status::OK();
-  }
-  for (HloComputation* computation : conditional->branch_computations()) {
-    HloInstruction* root = computation->root_instruction();
-    std::vector<HloInstruction*> users = root->users();
-    TF_ASSIGN_OR_RETURN(
-        HloInstruction * deep_copy,
-        computation->DeepCopyInstruction(root, &indices_to_copy));
-    for (HloInstruction* user : users) {
-      TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy));
-    }
-    computation->set_root_instruction(deep_copy);
-  }
-  return Status::OK();
-}
-
 // Add copies for the operands of in-place operations. RemoveUnnecessaryCopies
 // will remove the unnecessary copies.
 Status AddCopiesForInPlaceOperation(const HloAliasAnalysis& alias_analysis,
@@ -1006,6 +977,36 @@
 
 }  // namespace
 
+// We add copies for all non-phi indices of the true and false computation
+// roots, in order to resolve interference. We later rely on
+// RemoveUnnecessaryCopies to drop the unnecessary ones.
+Status CopyInsertion::AddCopiesForConditional(
+    const HloAliasAnalysis& alias_analysis, HloInstruction* conditional) {
+  VLOG(2) << "Adding copies for kConditional instruction "
+          << conditional->name();
+  ShapeTree<bool> indices_to_copy(conditional->shape());
+  TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);
+  if (!IndicesToCopyForConditional(alias_analysis.dataflow_analysis(),
+                                   conditional, &indices_to_copy)) {
+    VLOG(2) << "No copies necessary for kWhile instruction "
+            << conditional->name();
+    return Status::OK();
+  }
+
+  for (HloComputation* computation : conditional->branch_computations()) {
+    HloInstruction* root = computation->root_instruction();
+    std::vector<HloInstruction*> users = root->users();
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * deep_copy,
+        computation->DeepCopyInstruction(root, &indices_to_copy));
+    for (HloInstruction* user : users) {
+      TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy));
+    }
+    computation->set_root_instruction(deep_copy);
+  }
+  return Status::OK();
+}
+
 // Add kCopy instructions to the given module to guarantee there is no
 // live-range interference. Generally interference can only occur around kWhile
 // instructions which have update-in-place semantics.
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
index 70394266..22dffbf 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -83,6 +83,10 @@
   virtual Status AddSpecialCaseCopies(const CallGraph& call_graph,
                                       HloModule* module);
 
+  // Add copies for conditional instructions.
+  virtual Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
+                                         HloInstruction* conditional);
+
   // Backend specific function that decides whether an instruction can share
   // buffer with its operand.
   HloDataflowAnalysis::CanShareBuffer can_share_buffer_;
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
index 8ebb522..e984c72 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher.h
+++ b/tensorflow/compiler/xla/service/pattern_matcher.h
@@ -2166,6 +2166,7 @@
 // already-bad compile errors even worse.
 XLA_VARIADIC_OP_PATTERN(AfterAll);
 XLA_VARIADIC_OP_PATTERN(Concatenate);
+XLA_VARIADIC_OP_PATTERN(Conditional);
 XLA_VARIADIC_OP_PATTERN(CustomCall);
 XLA_VARIADIC_OP_PATTERN(DynamicSlice)
 XLA_VARIADIC_OP_PATTERN(Fusion);