[XLA] Modify layout assignment so that priorities can be set for propagating operands layout, and the layot of  dynamic update slice is controlled more by operand 0. Also added support for layout of computation parameters to be propagated when the computation traversal order has been reverted.

PiperOrigin-RevId: 377073732
Change-Id: Ic7fec101aa4131b0e9ba28ec5aa51ef5abf4170d
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 38d5841..1179256 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -69,8 +69,11 @@
 
 BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
                                                const LogicalBuffer& buffer,
-                                               bool mandatory, bool dfs)
-    : LayoutConstraint(mandatory, dfs), layout_(layout), buffer_(&buffer) {
+                                               bool mandatory, bool dfs,
+                                               int64 priority)
+    : LayoutConstraint(mandatory, dfs, priority),
+      layout_(layout),
+      buffer_(&buffer) {
   CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok());
 }
 
@@ -81,8 +84,8 @@
 
 OperandLayoutConstraint::OperandLayoutConstraint(
     const ShapeLayout& shape_layout, const HloInstruction* instruction,
-    int64 operand_no, bool mandatory, bool dfs)
-    : LayoutConstraint(mandatory, dfs),
+    int64 operand_no, bool mandatory, bool dfs, int64 priority)
+    : LayoutConstraint(mandatory, dfs, priority),
       shape_layout_(shape_layout),
       instruction_(instruction),
       operand_no_(operand_no) {
@@ -221,10 +224,21 @@
                                            const HloInstruction* instruction,
                                            int64 operand_no, bool mandatory,
                                            bool dfs) {
+  auto priority = LayoutConstraint::kDefaultPriority;
+  // The second and third operands (operand_no > 0) of a dynamic-update-slice
+  // operation typically have much smaller sizes than the first (operand_no==0)
+  // operand. It is necessary to downgrade the importance of the smaller
+  // operands, so that the overall layout choice of the operation is dicated by
+  // operand 0 when possible.
+  if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice &&
+      operand_no > 0 && !mandatory) {
+    dfs = false;
+    priority--;
+  }
   VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand "
           << operand_no << " : "
-          << ShapeUtil::HumanStringWithLayout(shape_with_layout);
-
+          << ShapeUtil::HumanStringWithLayout(shape_with_layout)
+          << " : priority = " << priority << "\n";
   const OperandLayoutConstraint* curr_shape_layout =
       GetOperandLayoutConstraint(instruction, operand_no);
   if (curr_shape_layout != nullptr) {
@@ -257,16 +271,28 @@
   auto iter = operand_constraints_.find(key);
   if (iter == operand_constraints_.end()) {
     auto pair = std::make_pair(
-        key, OperandLayoutConstraint(ShapeLayout(shape_with_layout),
-                                     instruction, operand_no, mandatory, dfs));
+        key,
+        OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction,
+                                operand_no, mandatory, dfs, priority));
     iter = operand_constraints_.insert(pair).first;
   } else {
     iter->second =
         OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction,
-                                operand_no, mandatory, dfs);
+                                operand_no, mandatory, dfs, priority);
   }
   added_constraints_.push_back(&iter->second);
-
+  // Move the new constraint from the end of the worklist forward if its
+  // priority is higher than those ahead of it in the worklist. Assuming the
+  // worklist is already sorted, this works like a bubble sort, where the lower
+  // priority constraints are always pushed to the tail of the worklist.
+  for (int64 i = added_constraints_.size() - 2; i >= 0; --i) {
+    if (added_constraints_[i]->priority() <
+        added_constraints_[i + 1]->priority()) {
+      std::swap(added_constraints_[i + 1], added_constraints_[i]);
+    } else {
+      break;
+    }
+  }
   return Status::OK();
 }
 
@@ -451,6 +477,34 @@
   return collective != nullptr && collective->constrain_layout();
 }
 
+Status PropagateParameterLayoutToUsers(const HloInstruction* instruction,
+                                       const Shape& shape,
+                                       LayoutConstraints* constraints) {
+  for (auto* user : instruction->users()) {
+    // Excluding tuple operations as they do not participate in layout
+    // propagations (they do not create or aliase buffers).
+    if (user->opcode() == HloOpcode::kTuple) {
+      continue;
+    }
+    VLOG(3) << "Setting  user layout : " << user->ToString();
+    if (user->opcode() == HloOpcode::kGetTupleElement) {
+      auto tuple_index = user->tuple_index();
+      CHECK(shape.IsTuple());
+      auto elem_shape = shape.tuple_shapes(tuple_index);
+      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
+          elem_shape, user, /*mandatory=*/false, /*dfs=*/false,
+          /*allow_alias=*/true));
+      TF_RETURN_IF_ERROR(
+          PropagateParameterLayoutToUsers(user, elem_shape, constraints));
+    } else {
+      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
+          shape, user, user->operand_index(instruction), /*mandatory=*/false,
+          /*dfs=*/false));
+    }
+  }
+  return Status::OK();
+}
+
 }  // namespace
 
 Status LayoutAssignment::AddMandatoryConstraints(
@@ -489,6 +543,10 @@
         // ComputationLayout, if there is one.
         TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
             parameter_layout.shape(), instruction));
+        if (reverse_computation_order_) {
+          TF_RETURN_IF_ERROR(PropagateParameterLayoutToUsers(
+              instruction, parameter_layout.shape(), constraints));
+        }
       }
     } else if (IsLayoutConstrainedCustomCall(instruction)) {
       const HloCustomCallInstruction* custom_call =
@@ -1213,6 +1271,8 @@
       if (constraint->dfs()) {
         worklist.push_front(constraint);
       } else {
+        VLOG(3) << "push back constraint for propagation : "
+                << constraint->ToString();
         worklist.push_back(constraint);
       }
     }
@@ -1223,7 +1283,8 @@
     const LayoutConstraint* layout_constraint = worklist.front();
     worklist.pop_front();
     VLOG(2) << "Propagating " << layout_constraint->ToString()
-            << " to its neighbors.";
+            << " to its neighbors with priority = "
+            << layout_constraint->priority() << "\n";
     if (auto* buffer_constraint =
             dynamic_cast<const BufferLayoutConstraint*>(layout_constraint)) {
       TF_RETURN_IF_ERROR(
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index 2cc998f..3861cad 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -50,8 +50,8 @@
 // gathered together in LayoutConstraints object.
 class LayoutConstraint {
  public:
-  LayoutConstraint(bool mandatory, bool dfs)
-      : mandatory_(mandatory), dfs_(dfs) {}
+  LayoutConstraint(bool mandatory, bool dfs, int64 priority = kDefaultPriority)
+      : mandatory_(mandatory), dfs_(dfs), priority_(priority) {}
   virtual ~LayoutConstraint() = default;
 
   virtual string ToString() const = 0;
@@ -62,9 +62,17 @@
   // When true, propagate in DFS. When false, constraint will propagate in BFS.
   bool dfs() const { return dfs_; }
 
+  // Return the priority of the current constraint. When conflicting constraints
+  // are encountered, the higher priority one should win.
+  int64 priority() const { return priority_; }
+
+  // The default priority of all constraints when not set explicitly.
+  static constexpr int64 kDefaultPriority = 1;
+
  private:
   bool mandatory_;
   bool dfs_;
+  int64 priority_;
 };
 
 std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint);
@@ -74,7 +82,8 @@
 class BufferLayoutConstraint : public LayoutConstraint {
  public:
   BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer,
-                         bool mandatory, bool dfs);
+                         bool mandatory, bool dfs,
+                         int64 priority = LayoutConstraint::kDefaultPriority);
 
   const LogicalBuffer& buffer() const { return *buffer_; }
   const Layout& layout() const { return layout_; }
@@ -95,7 +104,8 @@
  public:
   OperandLayoutConstraint(const ShapeLayout& shape_layout,
                           const HloInstruction* instruction, int64 operand_no,
-                          bool mandatory, bool dfs);
+                          bool mandatory, bool dfs,
+                          int64 priority = LayoutConstraint::kDefaultPriority);
 
   const ShapeLayout& shape_layout() const { return shape_layout_; }
   const HloInstruction* instruction() const { return instruction_; }
@@ -115,8 +125,9 @@
 // Constraint on the layout of the result of the entry computation.
 class ResultLayoutConstraint : public LayoutConstraint {
  public:
-  explicit ResultLayoutConstraint(const ShapeLayout& shape_layout,
-                                  bool dfs = false)
+  explicit ResultLayoutConstraint(
+      const ShapeLayout& shape_layout, bool dfs = false,
+      int64 priority = LayoutConstraint::kDefaultPriority)
       : LayoutConstraint(/*mandatory=*/true, dfs),
         shape_layout_(shape_layout) {}