[XLA] Use the second iteration of layout assignment to propagate channel
constraints instead of trying to apply them at the end of each pass.
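
With the per-pass fix-up gone, ConstrainChannelLayouts() only records each
channel's layout; any mismatch is resolved when layout assignment runs a
second time with the recorded constraints. A minimal sketch of that two-run
flow (the driver loop and constructor signature here are approximations,
not the exact call sites):

  ChannelLayoutConstraints channel_constraints;
  for (int run = 0; run < 2; ++run) {
    // Layouts recorded via ConstrainChannel() in the first run become
    // hard constraints for the second run.
    LayoutAssignment layout_assignment(&computation_layout,
                                       &channel_constraints);
    TF_RETURN_IF_ERROR(layout_assignment.Run(module).status());
  }
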
PiperOrigin-RevId: 276819747
Change-Id: I41f7508f8bb49c6d6515eb2d16fce48cbc625c11
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 0eb912a..81a42de 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -1913,53 +1913,13 @@
for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
if (instruction->opcode() == HloOpcode::kSend) {
HloInstruction* operand = instruction->mutable_operand(0);
- const Layout* layout = get_channel_constraints(instruction)
- ->ConstrainChannel(*instruction->channel_id(),
- operand->shape().layout());
- if (layout != nullptr) {
- // We found an already constrained layout which does not match the one
- // the kSend wants to impose. Either add a new kCopy, or use the
- // existing one to marshal the correct shape.
- Shape shape = operand->shape();
- *shape.mutable_layout() = *layout;
- if (operand->opcode() != HloOpcode::kCopy) {
- HloInstruction* copy = operand->parent()->AddInstruction(
- HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand));
- RegisterAddedCopy(copy);
- SetupCopiedInstruction(*operand, copy, {});
- TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy));
- operand = copy;
- } else {
- *operand->mutable_shape() = shape;
- }
- Shape* send_shape =
- ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0});
- *send_shape = shape;
- }
+ get_channel_constraints(instruction)
+ ->ConstrainChannel(*instruction->channel_id(),
+ operand->shape().layout());
} else if (instruction->IsCrossModuleAllReduce()) {
- const Layout* layout =
- get_channel_constraints(instruction)
- ->ConstrainChannel(instruction->channel_id().value(),
- instruction->shape().layout());
- if (layout != nullptr) {
- // We found an already constrained layout which does not match the one
- // the channel wants to impose. Either add a new kCopy, or use the
- // existing one to marshal the correct shape.
- HloInstruction* operand = instruction->mutable_operand(0);
- Shape shape = operand->shape();
- *shape.mutable_layout() = *layout;
- if (operand->opcode() != HloOpcode::kCopy) {
- HloInstruction* copy = operand->parent()->AddInstruction(
- HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand));
- RegisterAddedCopy(copy);
- SetupCopiedInstruction(*operand, copy, {});
- TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy));
- operand = copy;
- } else {
- *operand->mutable_shape() = shape;
- }
- *instruction->mutable_shape() = shape;
- }
+ get_channel_constraints(instruction)
+ ->ConstrainChannel(instruction->channel_id().value(),
+ instruction->shape().layout());
}
}
return Status::OK();
@@ -2035,6 +1995,22 @@
TF_RETURN_IF_ERROR(Init());
call_graph_ = CallGraph::Build(module);
auto computations = module->computations();
+
+ // Add a copy for the operand of each Send instruction, since we cannot
+ // call SetOperandLayout on a Send: it aliases its input to its output.
+ //
+ // TODO(b/68493863): Remove this once we can call SetOperandLayout() on
+ // operand buffers that alias the output.
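+ //
+ // Illustrative HLO (hypothetical shapes, not from a real module): the
+ // inserted kCopy gives layout assignment a buffer it can re-layout
+ // without touching the Send's aliased input:
+ //   %copy = f32[2,2] copy(%operand)
+ //   %send = (f32[2,2], u32[], token[]) send(%copy, %token)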
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction :
+ computation->MakeInstructionPostOrder()) {
+ if (instruction->opcode() == HloOpcode::kSend) {
+ TF_RETURN_IF_ERROR(AddCopyForOperand(instruction, 0));
+ }
+ }
+ }
+
// Clone Conditional computations with multiple callsites.
for (HloComputation* computation : computations) {
CallGraphNode& node = call_graph_->GetNode(computation);
@@ -2274,7 +2250,6 @@
TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
TF_RETURN_IF_ERROR(dce.Run(module).status());
}
- ResetChannelConstraints();
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 7d5a3b6..f9d0adb 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -869,11 +869,8 @@
ChannelLayoutConstraints channel_constraints;
AssignLayouts(m.get(), &computation_layout, &channel_constraints);
- EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1));
- EXPECT_THAT(LayoutOf(m.get(), "root"), ElementsAre(1, 0));
- EXPECT_TRUE(ShapeUtil::Equal(
- ShapeUtil::GetSubshape(FindInstruction(m.get(), "send")->shape(), {0}),
- ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
+ EXPECT_TRUE(ShapeUtil::Equal(FindInstruction(m.get(), "send")->shape(),
+ FindInstruction(m.get(), "recv")->shape()));
}
TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) {