Virtualize the AddCopiesOnConditional function on copy insertion.
PiperOrigin-RevId: 355189766
Change-Id: Ib2891e1bf22312b0a8735df1612de0f9a46b1ba4
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 80511dd..89cb7a4 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -332,35 +332,6 @@
return Status::OK();
}
-// We add copies for all non-phi indices of the true and false computation
-// roots, in order to resolve interference. We later rely on
-// RemoveUnnecessaryCopies to drop the unnecessary ones.
-Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
- HloInstruction* conditional) {
- VLOG(2) << "Adding copies for kConditional instruction "
- << conditional->name();
- ShapeTree<bool> indices_to_copy(conditional->shape());
- TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);
- if (!IndicesToCopyForConditional(alias_analysis.dataflow_analysis(),
- conditional, &indices_to_copy)) {
- VLOG(2) << "No copies necessary for kWhile instruction "
- << conditional->name();
- return Status::OK();
- }
- for (HloComputation* computation : conditional->branch_computations()) {
- HloInstruction* root = computation->root_instruction();
- std::vector<HloInstruction*> users = root->users();
- TF_ASSIGN_OR_RETURN(
- HloInstruction * deep_copy,
- computation->DeepCopyInstruction(root, &indices_to_copy));
- for (HloInstruction* user : users) {
- TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy));
- }
- computation->set_root_instruction(deep_copy);
- }
- return Status::OK();
-}
-
// Add copies for the operands of in-place operations. RemoveUnnecessaryCopies
// will remove the unnecessary copies.
Status AddCopiesForInPlaceOperation(const HloAliasAnalysis& alias_analysis,
@@ -1006,6 +977,36 @@
} // namespace
+// We add copies for all non-phi indices of the true and false computation
+// roots, in order to resolve interference. We later rely on
+// RemoveUnnecessaryCopies to drop the unnecessary ones.
+Status CopyInsertion::AddCopiesForConditional(
+ const HloAliasAnalysis& alias_analysis, HloInstruction* conditional) {
+ VLOG(2) << "Adding copies for kConditional instruction "
+ << conditional->name();
+ ShapeTree<bool> indices_to_copy(conditional->shape());
+ TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);
+ if (!IndicesToCopyForConditional(alias_analysis.dataflow_analysis(),
+ conditional, &indices_to_copy)) {
+ VLOG(2) << "No copies necessary for kWhile instruction "
+ << conditional->name();
+ return Status::OK();
+ }
+
+ for (HloComputation* computation : conditional->branch_computations()) {
+ HloInstruction* root = computation->root_instruction();
+ std::vector<HloInstruction*> users = root->users();
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * deep_copy,
+ computation->DeepCopyInstruction(root, &indices_to_copy));
+ for (HloInstruction* user : users) {
+ TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy));
+ }
+ computation->set_root_instruction(deep_copy);
+ }
+ return Status::OK();
+}
+
// Add kCopy instructions to the given module to guarantee there is no
// live-range interference. Generally interference can only occur around kWhile
// instructions which have update-in-place semantics.
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
index 70394266..22dffbf 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -83,6 +83,10 @@
virtual Status AddSpecialCaseCopies(const CallGraph& call_graph,
HloModule* module);
+ // Add copies for conditional instructions.
+ virtual Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
+ HloInstruction* conditional);
+
// Backend specific function that decides whether an instruction can share
// buffer with its operand.
HloDataflowAnalysis::CanShareBuffer can_share_buffer_;
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
index 8ebb522..e984c72 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher.h
+++ b/tensorflow/compiler/xla/service/pattern_matcher.h
@@ -2166,6 +2166,7 @@
// already-bad compile errors even worse.
XLA_VARIADIC_OP_PATTERN(AfterAll);
XLA_VARIADIC_OP_PATTERN(Concatenate);
+XLA_VARIADIC_OP_PATTERN(Conditional);
XLA_VARIADIC_OP_PATTERN(CustomCall);
XLA_VARIADIC_OP_PATTERN(DynamicSlice)
XLA_VARIADIC_OP_PATTERN(Fusion);