[XLA] Use the second iteration of layout assignment to propagate channel
constraints instead of trying to apply them at the end of each pass.
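
For reference, the copy-insertion logic removed below keyed off
ConstrainChannel's return value: it records the requested layout for an
unconstrained channel and returns the previously recorded layout when the
two conflict. A minimal sketch of that contract (the constraints_ member
below is illustrative, not the actual data structure):

  // Illustrative sketch of ChannelLayoutConstraints::ConstrainChannel.
  // Records `layout` for an unconstrained channel_id; for an already
  // constrained channel it returns the stored layout when it differs from
  // the requested one (nullptr means no conflict).
  const Layout* ConstrainChannel(int64 channel_id, const Layout& layout) {
    // constraints_: assumed std::map<int64, Layout> member.
    auto it = constraints_.find(channel_id);
    if (it == constraints_.end()) {
      constraints_.emplace(channel_id, layout);
      return nullptr;
    }
    return LayoutUtil::Equal(it->second, layout) ? nullptr : &it->second;
  }

With this change the kSend/all-reduce handling no longer acts on a non-null
return: the constraint is only recorded, and the second iteration of layout
assignment propagates it. Copies for Send operands are instead added up
front via AddCopyForOperand, since SetOperandLayout cannot be used on Send
(its operand aliases the output).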

PiperOrigin-RevId: 276819747
Change-Id: I41f7508f8bb49c6d6515eb2d16fce48cbc625c11
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 0eb912a..81a42de 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -1913,53 +1913,13 @@
   for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
     if (instruction->opcode() == HloOpcode::kSend) {
       HloInstruction* operand = instruction->mutable_operand(0);
-      const Layout* layout = get_channel_constraints(instruction)
-                                 ->ConstrainChannel(*instruction->channel_id(),
-                                                    operand->shape().layout());
-      if (layout != nullptr) {
-        // We found an already constrained layout which does not match the one
-        // the kSend wants to impose. Either add a new kCopy, or use the
-        // existing one to marshal the correct shape.
-        Shape shape = operand->shape();
-        *shape.mutable_layout() = *layout;
-        if (operand->opcode() != HloOpcode::kCopy) {
-          HloInstruction* copy = operand->parent()->AddInstruction(
-              HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand));
-          RegisterAddedCopy(copy);
-          SetupCopiedInstruction(*operand, copy, {});
-          TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy));
-          operand = copy;
-        } else {
-          *operand->mutable_shape() = shape;
-        }
-        Shape* send_shape =
-            ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0});
-        *send_shape = shape;
-      }
+      get_channel_constraints(instruction)
+          ->ConstrainChannel(*instruction->channel_id(),
+                             operand->shape().layout());
     } else if (instruction->IsCrossModuleAllReduce()) {
-      const Layout* layout =
-          get_channel_constraints(instruction)
-              ->ConstrainChannel(instruction->channel_id().value(),
-                                 instruction->shape().layout());
-      if (layout != nullptr) {
-        // We found an already constrained layout which does not match the one
-        // the channel wants to impose. Either add a new kCopy, or use the
-        // existing one to marshal the correct shape.
-        HloInstruction* operand = instruction->mutable_operand(0);
-        Shape shape = operand->shape();
-        *shape.mutable_layout() = *layout;
-        if (operand->opcode() != HloOpcode::kCopy) {
-          HloInstruction* copy = operand->parent()->AddInstruction(
-              HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand));
-          RegisterAddedCopy(copy);
-          SetupCopiedInstruction(*operand, copy, {});
-          TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy));
-          operand = copy;
-        } else {
-          *operand->mutable_shape() = shape;
-        }
-        *instruction->mutable_shape() = shape;
-      }
+      get_channel_constraints(instruction)
+          ->ConstrainChannel(instruction->channel_id().value(),
+                             instruction->shape().layout());
     }
   }
   return Status::OK();
@@ -2035,6 +1995,22 @@
   TF_RETURN_IF_ERROR(Init());
   call_graph_ = CallGraph::Build(module);
   auto computations = module->computations();
+
+  // Add a copy to the operand of Send instructions, since we cannot call
+  // SetOperandLayout on Send instructions as they alias their input to the
+  // output.
+  //
+  // TODO(b/68493863): Remove this once we can call SetOperandLayout() on the
+  // operand buffers that alias with the output.
+  for (HloComputation* computation : module->computations()) {
+    for (HloInstruction* instruction :
+         computation->MakeInstructionPostOrder()) {
+      if (instruction->opcode() == HloOpcode::kSend) {
+        TF_RETURN_IF_ERROR(AddCopyForOperand(instruction, 0));
+      }
+    }
+  }
+
   // Clone Conditional computations with multiple callsites.
   for (HloComputation* computation : computations) {
     CallGraphNode& node = call_graph_->GetNode(computation);
@@ -2274,7 +2250,6 @@
     TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
     TF_RETURN_IF_ERROR(dce.Run(module).status());
   }
-  ResetChannelConstraints();
   return Status::OK();
 }
 
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 7d5a3b6..f9d0adb 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -869,11 +869,8 @@
   ChannelLayoutConstraints channel_constraints;
   AssignLayouts(m.get(), &computation_layout, &channel_constraints);
 
-  EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1));
-  EXPECT_THAT(LayoutOf(m.get(), "root"), ElementsAre(1, 0));
-  EXPECT_TRUE(ShapeUtil::Equal(
-      ShapeUtil::GetSubshape(FindInstruction(m.get(), "send")->shape(), {0}),
-      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
+  EXPECT_TRUE(ShapeUtil::Equal(FindInstruction(m.get(), "send")->shape(),
+                               FindInstruction(m.get(), "recv")->shape()));
 }
 
 TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) {