[XLA] Create explicit phi graph optimization during dataflow analysis.

Previously, HLO dataflow analysis tried to create and remove phi values
at the same time during propagation, which led to several cases of
deadlock, some of them hard to fix. This CL splits the process into two phases:

1. During value propagation, dataflow analysis always creates phi
values once it sees multiple inputs merging at the same point. It then
records those phi values, as well as their inputs, in a phi graph.

2. After value propagation, dataflow analysis runs certain
optimizations on the phi graph to prune unnecessary phi nodes.

Both phases are guaranteed to terminate, so deadlocks are avoided.
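
As a sketch of the flow (the value names `a`, `b`, and `phi_value` here
are illustrative; the real call sites are HloDataflowAnalysis::Phi and
HloDataflowAnalysis::OptimizePhiValues below):

  // Phase 1: a merge point reached by values `a` and `b` always gets a phi.
  phi_graph_.RegisterPhi(*phi_value, {&a, &b});
  // Phase 2: after propagation finishes, prune redundant phis.
  phi_graph_.Optimize();
  HloValue::Id replacement = phi_graph_.FindOptimizedValue(phi_value->id());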

PiperOrigin-RevId: 301449515
Change-Id: I85f545ed9935ad5aee85b3f5bc05c2ba19da074a
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 98851fd..925afd6 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2763,6 +2763,7 @@
         ":call_graph",
         ":hlo",
         ":hlo_casting_utils",
+        ":hlo_phi_graph",
         ":hlo_value",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status",
@@ -2771,10 +2772,13 @@
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ],
 )
@@ -2783,6 +2787,7 @@
     name = "hlo_dataflow_analysis_test",
     srcs = ["hlo_dataflow_analysis_test.cc"],
     deps = [
+        ":flatten_call_graph",
         ":hlo",
         ":hlo_creation_utils",
         ":hlo_dataflow_analysis",
@@ -2804,6 +2809,53 @@
 )
 
 cc_library(
+    name = "hlo_phi_graph",
+    srcs = ["hlo_phi_graph.cc"],
+    hdrs = ["hlo_phi_graph.h"],
+    deps = [
+        ":call_graph",
+        ":hlo",
+        ":hlo_casting_utils",
+        ":hlo_value",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:lib",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+tf_cc_test(
+    name = "hlo_phi_graph_test",
+    srcs = ["hlo_phi_graph_test.cc"],
+    deps = [
+        ":hlo",
+        ":hlo_dataflow_analysis",
+        ":hlo_graph_dumper",
+        ":hlo_matchers",
+        ":hlo_ordering",
+        ":hlo_phi_graph",
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:test_helpers",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+    ],
+)
+
+cc_library(
     name = "hlo_replication_analysis",
     srcs = ["hlo_replication_analysis.cc"],
     hdrs = ["hlo_replication_analysis.h"],
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 36da176..6a0b9e5 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -17,18 +17,23 @@
 
 #include <algorithm>
 #include <queue>
+#include <string>
 #include <vector>
 
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
+#include "absl/types/optional.h"
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_value.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -118,10 +123,11 @@
 }
 
 void HloDataflowAnalysis::DeleteMarkedValues() {
-#ifndef NDEBUG
-  // Verify that no marked-for-deletion values are in any of the value sets.
+  // Use a set to prevent deleting an id twice.
   absl::flat_hash_set<HloValue::Id> id_set(value_ids_to_delete_.begin(),
                                            value_ids_to_delete_.end());
+#ifndef NDEBUG
+  // Verify that no marked-for-deletion values are in any of the value sets.
   for (const auto& pair : value_sets_) {
     const HloInstruction* instruction = pair.first;
     const InstructionValueSet& instruction_value_set = pair.second;
@@ -138,7 +144,7 @@
   }
 #endif
 
-  for (HloValue::Id value_id : value_ids_to_delete_) {
+  for (HloValue::Id value_id : id_set) {
     values_.erase(value_id);
   }
   value_ids_to_delete_.clear();
@@ -216,22 +222,13 @@
     const HloValue* current_value =
         value_set.values().size() == 1 ? value_set.values()[0] : nullptr;
 
-    // Construct a vector of unique value IDs of the inputs.
-    // Don't add value ids where the input is equal to the definition.
+    // Construct a vector of value IDs of the inputs.
     std::vector<HloValue::Id> input_value_ids;
     for (const InstructionValueSet* input : inputs) {
       for (const HloValue* value : input->element(index).values()) {
-        if (value->defining_instruction() == instruction &&
-            value->defining_index() == index) {
-          continue;
-        }
         input_value_ids.push_back(value->id());
       }
     }
-    absl::c_sort(input_value_ids);
-    input_value_ids.erase(
-        std::unique(input_value_ids.begin(), input_value_ids.end()),
-        input_value_ids.end());
 
     // Remove the existing phi value (if it exists). The phi can be its own
     // input, for example, in while body parameters where the body passes
@@ -240,14 +237,7 @@
         (current_value != nullptr &&
          current_value->defining_instruction() == instruction &&
          current_value->defining_index() == index);
-    if (current_value_defined_here) {
-      VLOG(5) << "current_value_defined_here: " << current_value->ToString();
-      CHECK(current_value->is_phi());
-      auto it = absl::c_find(input_value_ids, current_value->id());
-      if (it != input_value_ids.end()) {
-        input_value_ids.erase(it);
-      }
-    }
+
     VLOG(5) << "after input_value_ids.size = " << input_value_ids.size();
     if (input_value_ids.empty()) {
       // A value set which has at least one element should never have its value
@@ -277,11 +267,33 @@
       // Multiple distinct values reach this point. A phi value is
       // necessary.
       CHECK_GT(input_value_ids.size(), 1);
-      if (current_value == nullptr ||
-          !(current_value->is_phi() && current_value_defined_here)) {
+      bool phi_defined_here =
+          current_value_defined_here && current_value->is_phi();
+      if (current_value == nullptr || !phi_defined_here) {
         value_set.Clear();
         value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));
+
+        std::vector<HloValue*> inputs;
+        inputs.reserve(input_value_ids.size());
+        for (HloValue::Id id : input_value_ids) {
+          inputs.push_back(&GetValue(id));
+        }
+        // Register the phi into phi graph.
+        phi_graph_.RegisterPhi(*value_set.values()[0], inputs);
         changed = true;
+      } else if (phi_defined_here) {
+        std::vector<HloValue*> new_inputs;
+        new_inputs.reserve(input_value_ids.size());
+        for (HloValue::Id id : input_value_ids) {
+          new_inputs.push_back(&GetValue(id));
+        }
+
+        if (!phi_graph_.InputsEqualTo(*current_value, new_inputs)) {
+          VLOG(1) << current_value->ToShortString() << " has new phi inputs: ";
+          // Update phi inputs.
+          phi_graph_.RegisterPhi(*current_value, new_inputs);
+          changed = true;
+        }
       }
     }
   }
@@ -564,9 +576,9 @@
       CHECK_EQ(parameter->parameter_number(), 0);
       inputs.push_back(
           &GetInstructionValueSet(callsite.instruction()->operand(0)));
-      // If the parameter *is* the root, then don't consider it's current state
-      // (InstructionValueSet) as we are recomputing its current
-      // state. Otherwise, the parameter state would never be updated.
+      // If the parameter *is not* the root, its state will be updated
+      // by the root; otherwise don't consider its current state
+      // (InstructionValueSet) as we are recomputing that state.
       if (parameter !=
           callsite.instruction()->while_body()->root_instruction()) {
         inputs.push_back(&GetInstructionValueSet(
@@ -599,7 +611,6 @@
                     "called from call, while, or conditional instructions";
     }
   }
-
   if (ssa_form_ && need_phi) {
     return Phi(parameter, inputs);
   } else {
@@ -722,10 +733,18 @@
       add_to_worklist(instruction);
     }
   }
+  VLOG(1) << "SSA_FORM_: " << ssa_form_;
 
   while (!worklist.empty()) {
     HloInstruction* instruction = worklist.front();
+    auto add_to_worklist = [&](HloInstruction* todo) {
+      if (workset.insert(todo).second) {
+        VLOG(1) << "  Adding todo : " << todo->name();
+        worklist.push(todo);
+      }
+    };
     worklist.pop();
+
     workset.erase(workset.find(instruction));
 
     VLOG(3) << "Worklist top: " << instruction->name();
@@ -913,6 +932,43 @@
   return Status::OK();
 }
 
+void HloDataflowAnalysis::OptimizePhiValues() {
+  // Only applicable to SSA form where phis are defined.
+  if (!ssa_form_) {
+    return;
+  }
+
+  VLOG(1) << "Before phi graph optimization";
+  XLA_VLOG_LINES(1, phi_graph_.ToString());
+  phi_graph_.Optimize();
+  VLOG(1) << "After phi graph optimization";
+  XLA_VLOG_LINES(1, phi_graph_.ToString());
+
+  for (const HloComputation* computation : module_.computations()) {
+    for (HloInstruction* instruction : computation->instructions()) {
+      InstructionValueSet& instruction_value_set =
+          GetInstructionValueSet(instruction);
+      VLOG(1) << "inst: " << instruction->name();
+      VLOG(1) << instruction_value_set.ToString();
+      instruction_value_set.ForEachMutableElement(
+          [&](const xla::ShapeIndex& index, HloValueSet* value_set) {
+            auto values = value_set->values();
+            if (!(values.size() == 1 && values[0]->is_phi())) {
+              return;
+            }
+            HloValue::Id phi_id = values[0]->id();
+            HloValue::Id new_id = phi_graph_.FindOptimizedValue(phi_id);
+            if (new_id != phi_id) {
+              value_set->Clear();
+              const HloValue& new_value = GetValue(new_id);
+              value_set->AddValue(&new_value);
+              MarkValueForDeletion(phi_id);
+            }
+          });
+    }
+  }
+}
+
 /* static */
 StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
     const HloModule& module, bool ssa_form, bool bitcast_defines_value,
@@ -925,6 +981,7 @@
 
   TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
   dataflow_analysis->Propagate();
+  dataflow_analysis->OptimizePhiValues();
 
   // Delete all values marked for deletion.
   dataflow_analysis->DeleteMarkedValues();
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index 294ffea..75bcf7e 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -20,15 +20,19 @@
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
 
+#include <iterator>
 #include <memory>
 #include <string>
 #include <unordered_map>
 #include <vector>
 
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/service/call_graph.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_phi_graph.h"
 #include "tensorflow/compiler/xla/service/hlo_value.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status.h"
@@ -60,7 +64,8 @@
   //     SSA form is minimal in that a new phi value is defined only if the
   //     merge point is reachable by multiple different values. The SSA form is
   //     also in loop-closed form in that no values defined inside of a loop
-  //     (while body) is used outside of the loop.
+  //     (while body) is used outside of the loop. An example use of this
+  //     ssa_form mode is to reason about live-range interference of buffers.
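+  //
+  //     For example, in ssa_form mode a while-body parameter reached by
+  //     both the init value and the backedge value is assigned a new phi
+  //     HloValue, instead of a value set containing both.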
   //
   //     If ssa_form is false, then merge points do not define new
   //     values. Rather, the HloValueSet for the merge point contains the union
@@ -138,8 +143,8 @@
   // Returns true if 'user' cannot possibly use the buffer at 'index' in
   // 'operand'. Returns false otherwise.
   //
-  // 'operand' does not have to be an operand of 'user'. This can be the case
-  // with indirect uses.
+  // 'operand' does not have to be an operand of 'user'. This can be the
+  // case with indirect uses.
   bool DoesNotUseOperandBuffer(const HloInstruction* operand,
                                const ShapeIndex& index,
                                const HloInstruction* user) const;
@@ -160,9 +165,22 @@
                       bool bitcast_defines_value = false,
                       const CanShareBuffer& can_share_buffer = nullptr);
 
+  // 1. During value propagation (the Propagate function), dataflow analysis
+  // always creates phi values once it sees multiple inputs merging at the
+  // same point. It then records those phi values, as well as their inputs,
+  // in a phi graph.
+  //
+  // 2. After value propagation, dataflow analysis runs certain optimizations
+  // (OptimizePhiValues) on the phi graph to prune unnecessary phi nodes.
+  //
+  // Note that this only applies in SSA form, and both of the phases are
+  // guaranteed to terminate.
+  //
+  void OptimizePhiValues();
+
   // Returns a new HloValue defined at the given instruction and shape index.
   HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
-                        bool is_phi = false);
+                        bool is_phi);
 
   // Marks the HloValue with the given ID for deletion.
   void MarkValueForDeletion(HloValue::Id value_id);
@@ -248,6 +266,9 @@
   // The Id to use for the next HloValue.
   HloValue::Id next_value_id_ = 0;
 
+  // An explicit graph holding phi values and edges.
+  PhiGraph phi_graph_;
+
   // Backend specific function that decides whether an instruction can share
   // a buffer with its operand.
   CanShareBuffer can_share_buffer_ = nullptr;
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 074d14f..1bbbb24 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
 
 #include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
@@ -50,6 +51,8 @@
   // reference to the generated analysis stored in analysis_.
   const HloDataflowAnalysis& RunAnalysis(bool ssa_form,
                                          bool bitcast_defines_value = false) {
+    FlattenCallGraph flatten;
+    EXPECT_TRUE(flatten.Run(module_.get()).ok());
     analysis_ =
         HloDataflowAnalysis::Run(*module_, ssa_form, bitcast_defines_value)
             .ConsumeValueOrDie();
@@ -299,102 +302,6 @@
   EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
 }
 
-TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) {
-  // Test a subcomputation which is called twice with identical values.
-  auto subbuilder = HloComputation::Builder("Subcomputation");
-  auto subparam0 = subbuilder.AddInstruction(
-      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
-  auto subparam1 = subbuilder.AddInstruction(
-      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
-  auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary(
-      scalar_shape_, HloOpcode::kAdd, subparam0, subparam1));
-  HloComputation* called_computation =
-      module_->AddEmbeddedComputation(subbuilder.Build());
-
-  auto builder = HloComputation::Builder(TestName());
-  auto constant1 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
-  auto constant2 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
-  auto call1 = builder.AddInstruction(HloInstruction::CreateCall(
-      scalar_shape_, {constant1, constant2}, called_computation));
-  auto call2 = builder.AddInstruction(HloInstruction::CreateCall(
-      scalar_shape_, {constant1, constant2}, called_computation));
-  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
-      scalar_shape_, HloOpcode::kSubtract, call1, call2));
-  module_->AddEntryComputation(builder.Build());
-  SCOPED_TRACE(module_->ToString());
-
-  bool ssa_form = GetParam();
-  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
-
-  EXPECT_EQ(analysis.values().size(), 4);
-
-  // Definitions should be identical to the single callsite case.
-  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
-  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
-  EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam0));
-  EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam1));
-  EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
-  EXPECT_FALSE(analysis.ValueIsDefinedAt(call1));
-  EXPECT_FALSE(analysis.ValueIsDefinedAt(call2));
-  EXPECT_TRUE(analysis.ValueIsDefinedAt(sub));
-
-  EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
-              UnorderedElementsAre(HloUse{call1, 0, {}}, HloUse{call2, 0, {}},
-                                   HloUse{add, 0, {}}));
-  EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
-              UnorderedElementsAre(HloUse{call1, 1, {}}, HloUse{call2, 1, {}},
-                                   HloUse{add, 1, {}}));
-  // The Add from the subcomputation is used as both operands of the Subtract.
-  EXPECT_THAT(analysis.GetValueDefinedAt(add).uses(),
-              UnorderedElementsAre(HloUse{sub, 0, {}}, HloUse{sub, 1, {}}));
-
-  EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module());
-  EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_module());
-}
-
-TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) {
-  // Test a subcomputation which is called twice with different argument values.
-  auto subbuilder = HloComputation::Builder("Subcomputation");
-  auto subparam0 = subbuilder.AddInstruction(
-      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
-  auto subparam1 = subbuilder.AddInstruction(
-      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
-  auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary(
-      scalar_shape_, HloOpcode::kAdd, subparam0, subparam1));
-  HloComputation* called_computation =
-      module_->AddEmbeddedComputation(subbuilder.Build());
-
-  auto builder = HloComputation::Builder(TestName());
-  auto constant1 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
-  auto constant2 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
-  auto call1 = builder.AddInstruction(HloInstruction::CreateCall(
-      scalar_shape_, {constant1, constant2}, called_computation));
-  auto call2 = builder.AddInstruction(HloInstruction::CreateCall(
-      scalar_shape_, {call1, constant2}, called_computation));
-  module_->AddEntryComputation(builder.Build());
-  SCOPED_TRACE(module_->ToString());
-
-  bool ssa_form = GetParam();
-  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
-
-  EXPECT_FALSE(analysis.ValueIsDefinedAt(call1));
-  EXPECT_FALSE(analysis.ValueIsDefinedAt(call2));
-
-  EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam0));
-
-  EXPECT_THAT(HloValuesAt(subparam0),
-              UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
-                                   analysis.GetValueDefinedAt(add)));
-  EXPECT_THAT(HloValuesAt(subparam1),
-              UnorderedElementsAre(analysis.GetValueDefinedAt(constant2)));
-
-  EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
-}
-
 TEST_P(HloDataflowAnalysisTest, NestedCalls) {
   // Test a module with nested computations. HLO is:
   //
@@ -637,6 +544,100 @@
   EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
 }
 
+TEST_P(HloDataflowAnalysisTest, MultiLevelNestedWhile) {
+  // Test nested while instructions. The level0 body (the innermost while)
+  // and the level1 body pass the parameter through, while the level2 body
+  // (the outermost while) modifies it.
+  //
+  // level0_body((F32[]) %tuple_param):
+  //   return Tuple(%tuple_param{0})
+  //
+  // level1_body((F32[]) %tuple_param):
+  //   return While(%tuple_param), body=level0
+  //
+  // level2_body((F32[]) %tuple_param):
+  //   %while = While(%tuple_param), body=level1
+  //   return Tuple(negate(%while{0}))
+  //
+  // entry:
+  //   %constant = Constant(1.0)
+  //   %tuple = Tuple(%constant)
+  //   return While(%tuple), body=level2
+  //
+  const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_});
+  auto cond_builder = HloComputation::Builder("condition");
+  cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "param"));
+  cond_builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+  HloComputation* condition =
+      module_->AddEmbeddedComputation(cond_builder.Build());
+
+  // The level 0 body passes the tuple element through transparently.
+  auto level0_builder = HloComputation::Builder("level0_body");
+  auto level0_param = level0_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "param"));
+  auto level0_element_0 = level0_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, level0_param, 0));
+  auto level0_root = level0_builder.AddInstruction(
+      HloInstruction::CreateTuple({level0_element_0}));
+  HloComputation* level0_body =
+      module_->AddEmbeddedComputation(level0_builder.Build());
+
+  // The level 1 body passes the tuple element through via the inner while.
+  auto level1_builder = HloComputation::Builder("level1_body");
+  auto level1_param = level1_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "param"));
+  auto level1_root = level1_builder.AddInstruction(HloInstruction::CreateWhile(
+      tuple_shape, condition, level0_body, level1_param));
+  HloComputation* level1_body =
+      module_->AddEmbeddedComputation(level1_builder.Build());
+
+  // The level 2 body negates the tuple element before returning it.
+  auto level2_builder = HloComputation::Builder("level2_body");
+  auto level2_param = level2_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "param"));
+  auto level2_while = level2_builder.AddInstruction(HloInstruction::CreateWhile(
+      tuple_shape, condition, level1_body, level2_param));
+  auto level2_element_0 = level2_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, level2_while, 0));
+  auto negate = level2_builder.AddInstruction(HloInstruction::CreateUnary(
+      scalar_shape_, HloOpcode::kNegate, level2_element_0));
+  level2_builder.AddInstruction(HloInstruction::CreateTuple({negate}));
+  HloComputation* level2_body =
+      module_->AddEmbeddedComputation(level2_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
+  builder.AddInstruction(
+      HloInstruction::CreateWhile(tuple_shape, condition, level2_body, tuple));
+  module_->AddEntryComputation(builder.Build());
+  SCOPED_TRACE(module_->ToString());
+
+  bool ssa_form = GetParam();
+  if (!ssa_form) {
+    return;
+  }
+  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
+
+  // Phi nodes on the inner parameters and roots should have been eliminated.
+  EXPECT_FALSE(analysis.ValueIsDefinedAt(level1_param, /*index=*/{0}));
+  EXPECT_FALSE(analysis.ValueIsDefinedAt(level0_param, /*index=*/{0}));
+  EXPECT_FALSE(analysis.ValueIsDefinedAt(level1_root, /*index=*/{0}));
+  EXPECT_FALSE(analysis.ValueIsDefinedAt(level0_root, /*index=*/{0}));
+  EXPECT_TRUE(analysis.ValueIsDefinedAt(level2_param, /*index=*/{0}));
+  EXPECT_EQ(HloValuesAt(level1_param, /*index=*/{0}),
+            HloValuesAt(level2_param, /*index=*/{0}));
+  EXPECT_EQ(HloValuesAt(level0_param, /*index=*/{0}),
+            HloValuesAt(level2_param, /*index=*/{0}));
+  EXPECT_EQ(HloValuesAt(level1_root, /*index=*/{0}),
+            HloValuesAt(level2_param, /*index=*/{0}));
+  EXPECT_EQ(HloValuesAt(level0_root, /*index=*/{0}),
+            HloValuesAt(level2_param, /*index=*/{0}));
+}
+
 TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
   // Test nested while instructions. The inner body passes through element 0 of
   // its parameter, and the outer body passes through element 1.  HLO:
@@ -757,6 +758,58 @@
   }
 }
 
+TEST_P(HloDataflowAnalysisTest, SwizzlingWhileSharedInput) {
+  // Test a while instruction with a body which permutes its tuple parameter
+  // elements, where both elements are fed by the same constant. HLO:
+  //
+  // body((F32[], F32[]) %tuple_param):
+  //   return Tuple(%tuple_param{1}, %tuple_param{0})
+  //
+  // condition((F32[], F32[]) %tuple_param):
+  //   return Constant(false)
+  //
+  // entry:
+  //   %constant1 = Constant(1.0)
+  //   %tuple = Tuple(%constant1, %constant1)
+  //   return While(%tuple, body, condition)
+  //
+  const Shape tuple_shape =
+      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+  auto body_builder = HloComputation::Builder("body");
+  auto body_param = body_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "param"));
+  auto body_element_0 = body_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
+  auto body_element_1 = body_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
+  body_builder.AddInstruction(
+      HloInstruction::CreateTuple({body_element_1, body_element_0}));
+  HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
+
+  auto cond_builder = HloComputation::Builder("condition");
+  cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "param"));
+  cond_builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+  HloComputation* condition =
+      module_->AddEmbeddedComputation(cond_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+  auto tuple = builder.AddInstruction(
+      HloInstruction::CreateTuple({constant1, constant1}));
+  builder.AddInstruction(
+      HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
+  module_->AddEntryComputation(builder.Build());
+  SCOPED_TRACE(module_->ToString());
+
+  bool ssa_form = GetParam();
+  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
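+  // Both tuple elements are fed by the same constant, so any phis created at
+  // the swizzling merge points are optimized away and no new value is
+  // defined at the parameter.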
+  EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
+}
+
 TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) {
   // Test a while instruction with a body which permutes it's tuple parameter
   // elements. HLO:
@@ -1621,8 +1674,8 @@
 
   DependencyHloOrdering ordering(module_.get());
 
-  // Exp only use is the call so it should not interfere with values inside the
-  // embedded computation.
+  // Exp's only use is the call, so it should not interfere with values
+  // inside the embedded computation.
   EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, embedded_log));
 
   // Negate is live across the call and should interfere with values in the
@@ -2134,8 +2187,8 @@
   // The fusion instruction never uses tuple element 0, but does use element 1.
   EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion));
   EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion));
-  // The same holds for the parameter tuple, except that the tuple elements are
-  // swapped in 'tuple'.
+  // The same holds for the parameter tuple, except that the tuple elements
+  // are swapped in 'tuple'.
   EXPECT_TRUE(
       dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {1}, fusion));
   EXPECT_FALSE(
diff --git a/tensorflow/compiler/xla/service/hlo_phi_graph.cc b/tensorflow/compiler/xla/service/hlo_phi_graph.cc
new file mode 100644
index 0000000..9b69771
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_phi_graph.cc
@@ -0,0 +1,233 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_phi_graph.h"
+
+#include <queue>
+
+namespace xla {
+HloValue::Id PhiGraph::GetOptimizedId(const HloValue& value) {
+  Node* node = value_id_to_node_[value.id()];
+  return node->value_id;
+}
+
+// Returns true if the inputs to an HLO value are the same as `inputs`.
+bool PhiGraph::InputsEqualTo(const HloValue& value,
+                             absl::Span<const HloValue* const> inputs) {
+  auto iter = value_id_to_node_.find(value.id());
+  CHECK(iter != value_id_to_node_.end());
+  absl::flat_hash_set<HloValue::Id> existing_set;
+  for (Node* operand : iter->second->operands) {
+    existing_set.insert(operand->value_id);
+  }
+  absl::flat_hash_set<HloValue::Id> new_set;
+  for (const HloValue* input : inputs) {
+    new_set.insert(input->id());
+  }
+  return existing_set == new_set;
+}
+
+HloValue::Id PhiGraph::FindOptimizedValue(const HloValue::Id id) {
+  auto iter = value_id_to_node_.find(id);
+  CHECK(iter != value_id_to_node_.end());
+  return iter->second->value_id;
+}
+
+PhiGraph::Node* PhiGraph::CreateOrReuseNode(const HloValue& value) {
+  auto iter = value_id_to_node_.find(value.id());
+  if (iter == value_id_to_node_.end()) {
+    node_storage_.emplace_back(absl::make_unique<Node>());
+    Node* node = node_storage_.back().get();
+    node->value_id = value.id();
+    value_id_to_node_[value.id()] = node;
+    node_to_value_id_[node].push_back(value.id());
+    return node;
+  } else {
+    // A node is already registered with this value; check that the value_id
+    // is the same as previously registered.
+    CHECK_NE(iter->second, nullptr);
+    CHECK_EQ(iter->second->value_id, value.id());
+    return iter->second;
+  }
+}
+
+void PhiGraph::ReplaceNodeWith(PhiGraph::Node* node, PhiGraph::Node* replace) {
+  // Update users.
+  CHECK(node->is_phi);
+  for (Node* user : node->users) {
+    absl::c_replace(user->operands, node, replace);
+  }
+
+  // Update the operands' users.
+  for (Node* operand : node->operands) {
+    absl::c_replace(operand->users, node, replace);
+  }
+  for (HloValue::Id value_id : node_to_value_id_[node]) {
+    CHECK(value_id_to_node_.contains(value_id));
+    value_id_to_node_[value_id] = replace;
+  }
+  // Update mappings to HloValue::Id.
+  absl::c_copy(node_to_value_id_[node],
+               std::back_inserter(node_to_value_id_[replace]));
+  node_to_value_id_[node].clear();
+  node->mark_as_dead = true;
+}
+
+void PhiGraph::RegisterPhi(const HloValue& value,
+                           absl::Span<const HloValue* const> inputs) {
+  Node* node = CreateOrReuseNode(value);
+  CHECK(value.is_phi());
+  node->is_phi = true;
+  node->operands.clear();
+  for (auto input : inputs) {
+    CHECK(input != nullptr);
+    Node* input_node = CreateOrReuseNode(*input);
+    node->operands.push_back(input_node);
+  }
+}
+
+std::string PhiGraph::ToString() {
+  std::string out = "PhiGraph: \n";
+  for (auto& node : node_storage_) {
+    std::string is_phi = node->is_phi ? ", phi" : "";
+    std::string is_optimized = node->mark_as_dead ? ", dead" : "";
+    absl::StrAppend(&out, node->value_id);
+    absl::StrAppend(&out, is_phi);
+    absl::StrAppend(&out, is_optimized, ":\n");
+    for (Node* input : node->operands) {
+      absl::StrAppend(&out, "  ", input->value_id);
+      absl::StrAppend(&out, "\n");
+    }
+  }
+  return out;
+}
+
+void PhiGraph::Optimize() {
+  // Set up users for each node.
+  for (auto& node : node_storage_) {
+    for (Node* input : node->operands) {
+      input->users.push_back(node.get());
+    }
+  }
+
+  bool changed = true;
+
+  // Run the optimization to a fixed point.
+  while (changed) {
+    changed = false;
+    absl::flat_hash_set<Node*> checked_for_closure;
+    for (auto& node : node_storage_) {
+      // Only optimize phi nodes.
+      if (!node->is_phi) {
+        continue;
+      }
+      // Skip dead nodes.
+      if (node->mark_as_dead) {
+        continue;
+      }
+
+      Node* node_ptr = node.get();
+
+      CHECK_GE(node_ptr->operands.size(), 1);
+
+      // Remove self-referencing ids from users and operands.
+      auto it = absl::c_find(node_ptr->operands, node_ptr);
+      while (it != node_ptr->operands.end()) {
+        node_ptr->operands.erase(it);
+        it = absl::c_find(node_ptr->operands, node_ptr);
+      }
+
+      it = absl::c_find(node_ptr->users, node_ptr);
+      while (it != node_ptr->users.end()) {
+        node_ptr->users.erase(it);
+        it = absl::c_find(node_ptr->users, node_ptr);
+      }
+
+      // If all inputs to the phi (after self-referencing ids are removed) are
+      // the same value, replace the phi with that value.
+      //
+      // phi(A, A, ... A) => A
+      // phi(A, self) = phi(A) => A
+      CHECK_GE(node_ptr->operands.size(), 1);
+      bool all_inputs_are_same = absl::c_all_of(
+          node_ptr->operands,
+          [&](Node* elem) { return elem == node_ptr->operands[0]; });
+
+      if (all_inputs_are_same) {
+        ReplaceNodeWith(node_ptr, node_ptr->operands[0]);
+        changed = true;
+        continue;
+      }
+
+      // Find a closure of inter-connected phis and one non-phi node. Replace
+      // all phis with that non-phi node.
+      //
+      // def A = phi(B, C)
+      // def B = phi(C, D)
+      // def C = phi(A, B)
+      // def D = non-phi
+      // Replace A, B, and C with D:
+      // A = phi(B, C) => D
+      // B = phi(C, D) => D
+      // C = phi(A, B) => D
+      if (checked_for_closure.contains(node_ptr)) {
+        continue;
+      }
+      // Keeps track of nodes in the current closure being tested.
+      absl::flat_hash_set<Node*> workset;
+      std::queue<Node*> worklist;
+      Node* non_phi = nullptr;
+      worklist.push(node_ptr);
+      while (!worklist.empty()) {
+        Node* todo = worklist.front();
+        worklist.pop();
+        if (workset.contains(todo)) {
+          continue;
+        }
+        checked_for_closure.insert(todo);
+        workset.insert(todo);
+        for (Node* operand : todo->operands) {
+          worklist.push(operand);
+        }
+        if (!todo->is_phi) {
+          if (non_phi != nullptr && non_phi != todo) {
+            // We have seen two distinct non-phi nodes in the closure, so the
+            // optimization can't be applied.
+            non_phi = nullptr;
+            // Break out of the worklist loop with non_phi set to nullptr,
+            // signaling that the optimization can't be applied.
+            break;
+          } else {
+            // This is the only non-phi node seen so far.
+            non_phi = todo;
+          }
+        }
+      }
+      if (non_phi != nullptr) {
+        // Replace all phi nodes in the closure/workset with the non_phi node.
+        for (Node* node : workset) {
+          if (!node->is_phi) {
+            CHECK_EQ(node, non_phi);
+            continue;
+          }
+          ReplaceNodeWith(node, non_phi);
+          changed = true;
+        }
+      }
+    }
+  }
+}
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_phi_graph.h b/tensorflow/compiler/xla/service/hlo_phi_graph.h
new file mode 100644
index 0000000..a0eb994
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_phi_graph.h
@@ -0,0 +1,100 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PHI_GRAPH_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PHI_GRAPH_H_
+
+#include <iterator>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_value.h"
+
+namespace xla {
+// A phi graph is a graph that contains and connects phi nodes built on top
+// of HloValues with explicit edges, as well as the non-phi nodes that are
+// direct inputs to the phi nodes. The graph can be viewed as an 'overlay' on
+// top of HloValues, with extra edges that connect them together. After
+// optimization, some phi nodes will be simplified away and the user can then
+// ask two useful questions:
+//
+// 1. Which HloValue should a phi node be replaced with?
+// 2. TODO(yunxing): What is the set of aliased HloValues that are connected
+// to the same phi (must-aliasing)?
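+//
+// Typical usage (a sketch; `phi`, `a`, and `b` stand for HloValues, with
+// `phi` created with is_phi=true):
+//
+//   PhiGraph graph;
+//   graph.RegisterPhi(phi, {&a, &b});  // During value propagation.
+//   graph.Optimize();                  // After propagation finishes.
+//   HloValue::Id replacement = graph.FindOptimizedValue(phi.id());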
+class PhiGraph {
+ public:
+  // Registers a phi HloValue together with its inputs into the graph.
+  void RegisterPhi(const HloValue& value,
+                   absl::Span<const HloValue* const> inputs);
+
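+  // Returns the id of the node that `value` currently maps to; after
+  // Optimize() this is the id of the value `value` should be replaced with.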
+  HloValue::Id GetOptimizedId(const HloValue& value);
+
+  // Returns true if the inputs to an HLO value are the same as `inputs`.
+  bool InputsEqualTo(const HloValue& value,
+                     absl::Span<const HloValue* const> inputs);
+
+  // Given `id`, returns the new id that `id` should be replaced with. If the
+  // node is not optimized, returns the same value.
+  HloValue::Id FindOptimizedValue(const HloValue::Id id);
+
+  // Optimizes the entire graph to a fixed point: a phi whose (non-self)
+  // inputs are all the same value is replaced by that value, and a closure
+  // of phis reachable from a single non-phi value is replaced by that value.
+  void Optimize();
+
+  std::string ToString();
+
+ private:
+  struct Node {
+    bool is_phi;
+    // Users of this node.
+    std::vector<Node*> users;
+    // Operands of this node. A non-phi node has no operands.
+    std::vector<Node*> operands;
+
+    // The value that the node is originally registered with.
+    HloValue::Id value_id;
+
+    // mark_as_dead is set to true when a phi node is simplified away.
+    //
+    // Precondition: node is a phi.
+    bool mark_as_dead = false;
+  };
+
+  Node* CreateOrReuseNode(const HloValue& value);
+
+  // Replaces `node` with `replace`. Redirects all users of `node` to
+  // `replace` and all HloValues pointing to `node` to `replace`. Also marks
+  // `node` as dead.
+  //
+  // Precondition: node is a phi -- It's only possible to simplify away a
+  // phi node.
+  void ReplaceNodeWith(Node* node, Node* replace);
+
+  // A reverse mapping from a node in the phi graph to all HloValues pointing
+  // to that node.
+  absl::flat_hash_map<Node*, std::vector<HloValue::Id>> node_to_value_id_;
+
+  // A mapping from an HloValue to its node in the phi graph.
+  absl::flat_hash_map<HloValue::Id, Node*> value_id_to_node_;
+  std::vector<std::unique_ptr<Node>> node_storage_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PHI_GRAPH_H_
diff --git a/tensorflow/compiler/xla/service/hlo_phi_graph_test.cc b/tensorflow/compiler/xla/service/hlo_phi_graph_test.cc
new file mode 100644
index 0000000..41f0454
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_phi_graph_test.cc
@@ -0,0 +1,86 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_phi_graph.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+class PhiGraphTest : public ::testing::Test {
+ protected:
+  HloValue NewHloValue(bool is_phi) {
+    static int64 id = 0;
+    return HloValue(id++, dummy_inst_.get(), {}, is_phi);
+  }
+
+  void SetUp() override {
+    dummy_inst_ = HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f));
+  }
+
+  // Dummy instruction used to fill the instruction argument when creating
+  // an HloValue.
+  std::unique_ptr<HloInstruction> dummy_inst_;
+};
+
+TEST_F(PhiGraphTest, SelfReferencingPhi) {
+  // Def A = non-phi
+  // Def B = phi(B, A)
+  //
+  // Optimize B into A.
+  PhiGraph phi_graph;
+  HloValue A = NewHloValue(false);
+  HloValue B = NewHloValue(true);
+  phi_graph.RegisterPhi(B, {&A, &B});
+  phi_graph.Optimize();
+  EXPECT_EQ(A.id(), phi_graph.FindOptimizedValue(B.id()));
+}
+
+TEST_F(PhiGraphTest, PhiWithSameInputs) {
+  // Def A = non-phi
+  // Def B = phi(A, A)
+  //
+  // Optimize B into A.
+  PhiGraph phi_graph;
+  HloValue A = NewHloValue(false);
+  HloValue B = NewHloValue(true);
+  phi_graph.RegisterPhi(B, {&A, &A});
+  phi_graph.Optimize();
+  EXPECT_EQ(A.id(), phi_graph.FindOptimizedValue(B.id()));
+}
+
+TEST_F(PhiGraphTest, CircularPhi) {
+  // def A = phi(B, C)
+  // def B = phi(C, D)
+  // def C = phi(A, B)
+  // def D = non-phi
+  // Replace A, B, and C with D:
+  PhiGraph phi_graph;
+  HloValue A = NewHloValue(true);
+  HloValue B = NewHloValue(true);
+  HloValue C = NewHloValue(true);
+  HloValue D = NewHloValue(false);
+  phi_graph.RegisterPhi(A, {&B, &C});
+  phi_graph.RegisterPhi(B, {&D, &C});
+  phi_graph.RegisterPhi(C, {&A, &B});
+  phi_graph.Optimize();
+  EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(A.id()));
+  EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(B.id()));
+  EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(C.id()));
+}
+
+}  // namespace
+}  // namespace xla