[XLA] Simplify TupleSimplifier to visit instructions in post order instead of using a worklist.

PiperOrigin-RevId: 280709537
Change-Id: I4986e5fce746013c77920d431ca271f8f2614abe
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.cc b/tensorflow/compiler/xla/service/tuple_simplifier.cc
index 77bdcc9..e9c1d93 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier.cc
+++ b/tensorflow/compiler/xla/service/tuple_simplifier.cc
@@ -30,105 +30,92 @@
 
 namespace xla {
 
-TupleSimplifier::TupleSimplifier(bool exclude_entry_computation) :
-    exclude_entry_computation_(exclude_entry_computation) {}
+TupleSimplifier::TupleSimplifier(bool exclude_entry_computation)
+    : exclude_entry_computation_(exclude_entry_computation) {}
 
 StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
-  // Initially add all GTE and Tuple instructions to the worklist.
+  // Simplify the Tuple and GTE instructions of each computation in post order.
-  std::queue<HloInstruction*> worklist;
+  bool changed = false;
   for (auto* computation : module->computations()) {
     if (exclude_entry_computation_ &&
         computation == module->entry_computation()) {
       continue;
     }
-    for (auto* instruction : computation->instructions()) {
-      if (instruction->opcode() == HloOpcode::kTuple ||
-          instruction->opcode() == HloOpcode::kGetTupleElement) {
-        worklist.push(instruction);
-      }
-    }
-  }
-
-  bool changed = false;
-  while (!worklist.empty()) {
-    HloInstruction* instruction = worklist.front();
-    worklist.pop();
-
-    if (instruction->user_count() == 0 &&
-        instruction != instruction->parent()->root_instruction()) {
-      // Tuple simplification works by replacing users of optimized away
-      // instructions with a simpler form. If there is no user of the
-      // instruction (including being the root), then there is nothing to do.
-      continue;
-    }
-
-    if (instruction->opcode() == HloOpcode::kTuple) {
-      // Collapse the following structure into just 'Tuple-shaped Op':
-      //
-      //   Tuple-shaped Op
-      //         |
-      //   +-----+-----+
-      //   |     |     |
-      //  GTE   GTE   GTE
-      //   |     |     |
-      //   +-----+-----+
-      //         |
-      //       Tuple
-      //
-      HloInstruction* top_tuple = nullptr;
-      bool can_simplify = true;
-      for (int64 operand_number = 0;
-           operand_number < instruction->operand_count(); ++operand_number) {
-        HloInstruction* operand = instruction->mutable_operand(operand_number);
-        if (operand->opcode() != HloOpcode::kGetTupleElement ||
-            operand->tuple_index() != operand_number) {
-          can_simplify = false;
-          break;
-        }
-        if (top_tuple == nullptr) {
-          top_tuple = operand->mutable_operand(0);
-          if (!ShapeUtil::Compatible(top_tuple->shape(),
-                                     instruction->shape())) {
+    for (auto* instruction : computation->MakeInstructionPostOrder()) {
+      if (instruction->opcode() == HloOpcode::kTuple) {
+        // Collapse the following structure into just 'Tuple-shaped Op':
+        //
+        //   Tuple-shaped Op
+        //         |
+        //   +-----+-----+
+        //   |     |     |
+        //  GTE   GTE   GTE
+        //   |     |     |
+        //   +-----+-----+
+        //         |
+        //       Tuple
+        //
+        HloInstruction* top_tuple = nullptr;
+        bool can_simplify = true;
+        for (int64 operand_number = 0;
+             operand_number < instruction->operand_count(); ++operand_number) {
+          HloInstruction* operand =
+              instruction->mutable_operand(operand_number);
+          if (operand->opcode() != HloOpcode::kGetTupleElement ||
+              operand->tuple_index() != operand_number) {
             can_simplify = false;
             break;
           }
-        } else if (top_tuple != operand->operand(0)) {
-          can_simplify = false;
-          break;
-        }
-      }
-      if (can_simplify && top_tuple != nullptr) {
-        changed = true;
-        TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(top_tuple));
-        // No need to add anything to the worklist.
-      }
-    } else {
-      CHECK_EQ(instruction->opcode(), HloOpcode::kGetTupleElement);
-      // If possible replace a GTE with the operation which produces the
-      // element. For example, replace uses of GTE with below with just 'Op'
-      // (assuming 'Op' is at the index of the GTE instruction):
-      //
-      //     ...  Op ...
-      //       \  |   /
-      //        Tuple
-      //          |
-      //         GTE
-      if (instruction->operand(0)->opcode() == HloOpcode::kTuple) {
-        HloInstruction* element_source =
-            instruction->mutable_operand(0)->mutable_operand(
-                instruction->tuple_index());
-        changed = true;
-        TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source));
-        for (HloInstruction* user : element_source->users()) {
-          if (user->opcode() == HloOpcode::kTuple ||
-              user->opcode() == HloOpcode::kGetTupleElement) {
-            worklist.push(user);
+          if (top_tuple == nullptr) {
+            top_tuple = operand->mutable_operand(0);
+            if (!ShapeUtil::Compatible(top_tuple->shape(),
+                                       instruction->shape())) {
+              can_simplify = false;
+              break;
+            }
+          } else if (top_tuple != operand->operand(0)) {
+            can_simplify = false;
+            break;
           }
         }
+        if (can_simplify && top_tuple != nullptr) {
+          changed = true;
+          TF_RETURN_IF_ERROR(
+              computation->ReplaceInstruction(instruction, top_tuple));
+        }
+      } else {
+        auto ancestor = instruction->LatestNonGteAncestorAndIndex();
+        if (ancestor.first == instruction) {
+          continue;
+        }
+        // If possible, replace a chain of GTEs with the operation that produces
+        // the element. For example, replace uses of the bottom GTE below with
+        // just 'Op' (assuming 'Op' is at the index of the GTE instruction):
+        //
+        //     ...  Op ...
+        //       \  |   /
+        //        Tuple
+        //          |
+        //         GTE
+        //         ...
+        //          |
+        //         GTE
+        //          |
+        //         GTE
+        if (ShapeUtil::Compatible(ancestor.first->shape(),
+                                  instruction->shape())) {
+          changed = true;
+          TF_RETURN_IF_ERROR(
+              computation->ReplaceInstruction(instruction, ancestor.first));
+        } else if (ancestor.first->opcode() == HloOpcode::kTuple) {
+          changed = true;
+          TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
+              instruction,
+              ancestor.first->mutable_operand(ancestor.second[0])));
+        }
       }
     }
   }
-
   return changed;
 }