[XLA:LAYOUT_ASSIGNMENT] propagate reshape layouts depth first if there are unmodifed dimensions.
PiperOrigin-RevId: 389295338
Change-Id: I46da325fa422feb38bb1cba4c04d0701da793af7
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 5da8b80..fe2388b 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -1366,7 +1366,8 @@
// A transpose or a reshape that only changes trivial dimensions have meaningful
// layouts that are valuable to propagate in a depthfirst manner to avoid
// unassigned layouts in the graph.
-bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) {
+bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo,
+ bool forward_propagation = true) {
switch (hlo.opcode()) {
case HloOpcode::kFusion:
return hlo.IsCustomFusion();
@@ -1374,7 +1375,8 @@
return true;
case HloOpcode::kReshape:
return hlo.operand(0)->shape().rank() == 1 ||
- std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions());
+ (forward_propagation &&
+ std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()));
case HloOpcode::kScatter:
case HloOpcode::kTranspose:
return true;
@@ -1557,7 +1559,8 @@
if (layout != nullptr) {
TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
*layout, *buffer,
- /*mandatory=*/user->opcode() == HloOpcode::kReduce));
+ /*mandatory=*/user->opcode() == HloOpcode::kReduce,
+ /*dfs=*/InstructionShouldPropagateDepthFirst(*user)));
}
}
return Status::OK();
@@ -1617,7 +1620,8 @@
TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
*operand_layout, instruction, operand_no, /*mandatory=*/false,
/*dfs=*/
- InstructionShouldPropagateDepthFirst(*instruction)));
+ InstructionShouldPropagateDepthFirst(
+ *instruction, /*forward_propagation=*/false)));
}
} else {
VLOG(6) << "Operand already has a constraint "