[XLA:LAYOUT_ASSIGNMENT] propagate reshape layouts depth first if there are unmodifed dimensions.

PiperOrigin-RevId: 389295338
Change-Id: I46da325fa422feb38bb1cba4c04d0701da793af7
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 5da8b80..fe2388b 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -1366,7 +1366,8 @@
 // A transpose or a reshape that only changes trivial dimensions have meaningful
 // layouts that are valuable to propagate in a depthfirst manner to avoid
 // unassigned layouts in the graph.
-bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) {
+bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo,
+                                          bool forward_propagation = true) {
   switch (hlo.opcode()) {
     case HloOpcode::kFusion:
       return hlo.IsCustomFusion();
@@ -1374,7 +1375,8 @@
       return true;
     case HloOpcode::kReshape:
       return hlo.operand(0)->shape().rank() == 1 ||
-             std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions());
+             (forward_propagation &&
+              std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()));
     case HloOpcode::kScatter:
     case HloOpcode::kTranspose:
       return true;
@@ -1557,7 +1559,8 @@
           if (layout != nullptr) {
             TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                 *layout, *buffer,
-                /*mandatory=*/user->opcode() == HloOpcode::kReduce));
+                /*mandatory=*/user->opcode() == HloOpcode::kReduce,
+                /*dfs=*/InstructionShouldPropagateDepthFirst(*user)));
           }
         }
         return Status::OK();
@@ -1617,7 +1620,8 @@
           TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
               *operand_layout, instruction, operand_no, /*mandatory=*/false,
               /*dfs=*/
-              InstructionShouldPropagateDepthFirst(*instruction)));
+              InstructionShouldPropagateDepthFirst(
+                  *instruction, /*forward_propagation=*/false)));
         }
       } else {
         VLOG(6) << "Operand already has a constraint "