Quick: Separate null check elimination and type inference.

Change-Id: I4566ae9354c91ca935481cb4f5b729bba05c1592
diff --git a/compiler/dex/bb_optimizations.h b/compiler/dex/bb_optimizations.h
index b2c348b..fce23bc 100644
--- a/compiler/dex/bb_optimizations.h
+++ b/compiler/dex/bb_optimizations.h
@@ -137,20 +137,20 @@
 };
 
 /**
- * @class NullCheckEliminationAndTypeInference
- * @brief Null check elimination and type inference.
+ * @class NullCheckElimination
+ * @brief Null check elimination pass.
  */
-class NullCheckEliminationAndTypeInference : public PassME {
+class NullCheckElimination : public PassME {
  public:
-  NullCheckEliminationAndTypeInference()
-    : PassME("NCE_TypeInference", kRepeatingTopologicalSortTraversal, "4_post_nce_cfg") {
+  NullCheckElimination()
+    : PassME("NCE", kRepeatingTopologicalSortTraversal, "3_post_nce_cfg") {
   }
 
-  void Start(PassDataHolder* data) const {
+  bool Gate(const PassDataHolder* data) const {
     DCHECK(data != nullptr);
-    CompilationUnit* c_unit = down_cast<PassMEDataHolder*>(data)->c_unit;
+    CompilationUnit* c_unit = down_cast<const PassMEDataHolder*>(data)->c_unit;
     DCHECK(c_unit != nullptr);
-    c_unit->mir_graph->EliminateNullChecksAndInferTypesStart();
+    return c_unit->mir_graph->EliminateNullChecksGate();
   }
 
   bool Worker(PassDataHolder* data) const {
@@ -160,14 +160,35 @@
     DCHECK(c_unit != nullptr);
     BasicBlock* bb = pass_me_data_holder->bb;
     DCHECK(bb != nullptr);
-    return c_unit->mir_graph->EliminateNullChecksAndInferTypes(bb);
+    return c_unit->mir_graph->EliminateNullChecks(bb);
   }
 
   void End(PassDataHolder* data) const {
     DCHECK(data != nullptr);
     CompilationUnit* c_unit = down_cast<PassMEDataHolder*>(data)->c_unit;
     DCHECK(c_unit != nullptr);
-    c_unit->mir_graph->EliminateNullChecksAndInferTypesEnd();
+    c_unit->mir_graph->EliminateNullChecksEnd();
+  }
+};
+
+/**
+ * @class TypeInference
+ * @brief Type inference pass.
+ */
+class TypeInference : public PassME {
+ public:
+  TypeInference()
+    : PassME("TypeInference", kRepeatingTopologicalSortTraversal, "4_post_type_cfg") {
+  }
+
+  bool Worker(PassDataHolder* data) const {
+    DCHECK(data != nullptr);
+    PassMEDataHolder* pass_me_data_holder = down_cast<PassMEDataHolder*>(data);
+    CompilationUnit* c_unit = pass_me_data_holder->c_unit;
+    DCHECK(c_unit != nullptr);
+    BasicBlock* bb = pass_me_data_holder->bb;
+    DCHECK(bb != nullptr);
+    return c_unit->mir_graph->InferTypes(bb);
   }
 };
 
diff --git a/compiler/dex/mir_graph.h b/compiler/dex/mir_graph.h
index fe6fb75..21b6914 100644
--- a/compiler/dex/mir_graph.h
+++ b/compiler/dex/mir_graph.h
@@ -1022,9 +1022,10 @@
   int SRegToVReg(int ssa_reg) const;
   void VerifyDataflow();
   void CheckForDominanceFrontier(BasicBlock* dom_bb, const BasicBlock* succ_bb);
-  void EliminateNullChecksAndInferTypesStart();
-  bool EliminateNullChecksAndInferTypes(BasicBlock* bb);
-  void EliminateNullChecksAndInferTypesEnd();
+  bool EliminateNullChecksGate();
+  bool EliminateNullChecks(BasicBlock* bb);
+  void EliminateNullChecksEnd();
+  bool InferTypes(BasicBlock* bb);
   bool EliminateClassInitChecksGate();
   bool EliminateClassInitChecks(BasicBlock* bb);
   void EliminateClassInitChecksEnd();
diff --git a/compiler/dex/mir_optimization.cc b/compiler/dex/mir_optimization.cc
index 35dae00..bf3d7df 100644
--- a/compiler/dex/mir_optimization.cc
+++ b/compiler/dex/mir_optimization.cc
@@ -819,96 +819,95 @@
   }
 }
 
-void MIRGraph::EliminateNullChecksAndInferTypesStart() {
-  if ((cu_->disable_opt & (1 << kNullCheckElimination)) == 0) {
-    if (kIsDebugBuild) {
-      AllNodesIterator iter(this);
-      for (BasicBlock* bb = iter.Next(); bb != nullptr; bb = iter.Next()) {
-        CHECK(bb->data_flow_info == nullptr || bb->data_flow_info->ending_check_v == nullptr);
-      }
-    }
-
-    DCHECK(temp_scoped_alloc_.get() == nullptr);
-    temp_scoped_alloc_.reset(ScopedArenaAllocator::Create(&cu_->arena_stack));
-    temp_bit_vector_size_ = GetNumSSARegs();
-    temp_bit_vector_ = new (temp_scoped_alloc_.get()) ArenaBitVector(
-        temp_scoped_alloc_.get(), temp_bit_vector_size_, false, kBitMapTempSSARegisterV);
+bool MIRGraph::EliminateNullChecksGate() {
+  if ((cu_->disable_opt & (1 << kNullCheckElimination)) != 0 ||
+      (merged_df_flags_ & DF_HAS_NULL_CHKS) == 0) {
+    return false;
   }
+
+  if (kIsDebugBuild) {
+    AllNodesIterator iter(this);
+    for (BasicBlock* bb = iter.Next(); bb != nullptr; bb = iter.Next()) {
+      CHECK(bb->data_flow_info == nullptr || bb->data_flow_info->ending_check_v == nullptr);
+    }
+  }
+
+  DCHECK(temp_scoped_alloc_.get() == nullptr);
+  temp_scoped_alloc_.reset(ScopedArenaAllocator::Create(&cu_->arena_stack));
+  temp_bit_vector_size_ = GetNumSSARegs();
+  temp_bit_vector_ = new (temp_scoped_alloc_.get()) ArenaBitVector(
+      temp_scoped_alloc_.get(), temp_bit_vector_size_, false, kBitMapTempSSARegisterV);
+  return true;
 }
 
 /*
- * Eliminate unnecessary null checks for a basic block.   Also, while we're doing
- * an iterative walk go ahead and perform type and size inference.
+ * Eliminate unnecessary null checks for a basic block.
  */
-bool MIRGraph::EliminateNullChecksAndInferTypes(BasicBlock* bb) {
-  if (bb->data_flow_info == NULL) return false;
-  bool infer_changed = false;
-  bool do_nce = ((cu_->disable_opt & (1 << kNullCheckElimination)) == 0);
+bool MIRGraph::EliminateNullChecks(BasicBlock* bb) {
+  if (bb->data_flow_info == nullptr) return false;
 
   ArenaBitVector* ssa_regs_to_check = temp_bit_vector_;
-  if (do_nce) {
-    /*
-     * Set initial state. Catch blocks don't need any special treatment.
-     */
-    if (bb->block_type == kEntryBlock) {
-      ssa_regs_to_check->ClearAllBits();
-      // Assume all ins are objects.
-      for (uint16_t in_reg = GetFirstInVR();
-           in_reg < GetNumOfCodeVRs(); in_reg++) {
-        ssa_regs_to_check->SetBit(in_reg);
-      }
-      if ((cu_->access_flags & kAccStatic) == 0) {
-        // If non-static method, mark "this" as non-null
-        int this_reg = GetFirstInVR();
-        ssa_regs_to_check->ClearBit(this_reg);
-      }
-    } else if (bb->predecessors.size() == 1) {
-      BasicBlock* pred_bb = GetBasicBlock(bb->predecessors[0]);
-      // pred_bb must have already been processed at least once.
-      DCHECK(pred_bb->data_flow_info->ending_check_v != nullptr);
-      ssa_regs_to_check->Copy(pred_bb->data_flow_info->ending_check_v);
-      if (pred_bb->block_type == kDalvikByteCode) {
-        // Check to see if predecessor had an explicit null-check.
-        MIR* last_insn = pred_bb->last_mir_insn;
-        if (last_insn != nullptr) {
-          Instruction::Code last_opcode = last_insn->dalvikInsn.opcode;
-          if (last_opcode == Instruction::IF_EQZ) {
-            if (pred_bb->fall_through == bb->id) {
-              // The fall-through of a block following a IF_EQZ, set the vA of the IF_EQZ to show that
-              // it can't be null.
-              ssa_regs_to_check->ClearBit(last_insn->ssa_rep->uses[0]);
-            }
-          } else if (last_opcode == Instruction::IF_NEZ) {
-            if (pred_bb->taken == bb->id) {
-              // The taken block following a IF_NEZ, set the vA of the IF_NEZ to show that it can't be
-              // null.
-              ssa_regs_to_check->ClearBit(last_insn->ssa_rep->uses[0]);
-            }
+  /*
+   * Set initial state. Catch blocks don't need any special treatment.
+   */
+  if (bb->block_type == kEntryBlock) {
+    ssa_regs_to_check->ClearAllBits();
+    // Assume all ins are objects.
+    for (uint16_t in_reg = GetFirstInVR();
+         in_reg < GetNumOfCodeVRs(); in_reg++) {
+      ssa_regs_to_check->SetBit(in_reg);
+    }
+    if ((cu_->access_flags & kAccStatic) == 0) {
+      // If non-static method, mark "this" as non-null
+      int this_reg = GetFirstInVR();
+      ssa_regs_to_check->ClearBit(this_reg);
+    }
+  } else if (bb->predecessors.size() == 1) {
+    BasicBlock* pred_bb = GetBasicBlock(bb->predecessors[0]);
+    // pred_bb must have already been processed at least once.
+    DCHECK(pred_bb->data_flow_info->ending_check_v != nullptr);
+    ssa_regs_to_check->Copy(pred_bb->data_flow_info->ending_check_v);
+    if (pred_bb->block_type == kDalvikByteCode) {
+      // Check to see if predecessor had an explicit null-check.
+      MIR* last_insn = pred_bb->last_mir_insn;
+      if (last_insn != nullptr) {
+        Instruction::Code last_opcode = last_insn->dalvikInsn.opcode;
+        if (last_opcode == Instruction::IF_EQZ) {
+          if (pred_bb->fall_through == bb->id) {
+            // The fall-through of a block following a IF_EQZ, set the vA of the IF_EQZ to show that
+            // it can't be null.
+            ssa_regs_to_check->ClearBit(last_insn->ssa_rep->uses[0]);
+          }
+        } else if (last_opcode == Instruction::IF_NEZ) {
+          if (pred_bb->taken == bb->id) {
+            // The taken block following a IF_NEZ, set the vA of the IF_NEZ to show that it can't be
+            // null.
+            ssa_regs_to_check->ClearBit(last_insn->ssa_rep->uses[0]);
           }
         }
       }
-    } else {
-      // Starting state is union of all incoming arcs
-      bool copied_first = false;
-      for (BasicBlockId pred_id : bb->predecessors) {
-        BasicBlock* pred_bb = GetBasicBlock(pred_id);
-        DCHECK(pred_bb != nullptr);
-        DCHECK(pred_bb->data_flow_info != nullptr);
-        if (pred_bb->data_flow_info->ending_check_v == nullptr) {
-          continue;
-        }
-        if (!copied_first) {
-          copied_first = true;
-          ssa_regs_to_check->Copy(pred_bb->data_flow_info->ending_check_v);
-        } else {
-          ssa_regs_to_check->Union(pred_bb->data_flow_info->ending_check_v);
-        }
-      }
-      DCHECK(copied_first);  // At least one predecessor must have been processed before this bb.
     }
-    // At this point, ssa_regs_to_check shows which sregs have an object definition with
-    // no intervening uses.
+  } else {
+    // Starting state is union of all incoming arcs
+    bool copied_first = false;
+    for (BasicBlockId pred_id : bb->predecessors) {
+      BasicBlock* pred_bb = GetBasicBlock(pred_id);
+      DCHECK(pred_bb != nullptr);
+      DCHECK(pred_bb->data_flow_info != nullptr);
+      if (pred_bb->data_flow_info->ending_check_v == nullptr) {
+        continue;
+      }
+      if (!copied_first) {
+        copied_first = true;
+        ssa_regs_to_check->Copy(pred_bb->data_flow_info->ending_check_v);
+      } else {
+        ssa_regs_to_check->Union(pred_bb->data_flow_info->ending_check_v);
+      }
+    }
+    DCHECK(copied_first);  // At least one predecessor must have been processed before this bb.
   }
+  // At this point, ssa_regs_to_check shows which sregs have an object definition with
+  // no intervening uses.
 
   // Walk through the instruction in the block, updating as necessary
   for (MIR* mir = bb->first_mir_insn; mir != NULL; mir = mir->next) {
@@ -916,12 +915,6 @@
         continue;
     }
 
-    // Propagate type info.
-    infer_changed = InferTypeAndSize(bb, mir, infer_changed);
-    if (!do_nce) {
-      continue;
-    }
-
     uint64_t df_attributes = GetDataFlowAttributes(mir);
 
     // Might need a null check?
@@ -1022,35 +1015,50 @@
 
   // Did anything change?
   bool nce_changed = false;
-  if (do_nce) {
-    if (bb->data_flow_info->ending_check_v == nullptr) {
-      DCHECK(temp_scoped_alloc_.get() != nullptr);
-      bb->data_flow_info->ending_check_v = new (temp_scoped_alloc_.get()) ArenaBitVector(
-          temp_scoped_alloc_.get(), temp_bit_vector_size_, false, kBitMapNullCheck);
-      nce_changed = ssa_regs_to_check->GetHighestBitSet() != -1;
-      bb->data_flow_info->ending_check_v->Copy(ssa_regs_to_check);
-    } else if (!ssa_regs_to_check->SameBitsSet(bb->data_flow_info->ending_check_v)) {
-      nce_changed = true;
-      bb->data_flow_info->ending_check_v->Copy(ssa_regs_to_check);
-    }
+  if (bb->data_flow_info->ending_check_v == nullptr) {
+    DCHECK(temp_scoped_alloc_.get() != nullptr);
+    bb->data_flow_info->ending_check_v = new (temp_scoped_alloc_.get()) ArenaBitVector(
+        temp_scoped_alloc_.get(), temp_bit_vector_size_, false, kBitMapNullCheck);
+    nce_changed = ssa_regs_to_check->GetHighestBitSet() != -1;
+    bb->data_flow_info->ending_check_v->Copy(ssa_regs_to_check);
+  } else if (!ssa_regs_to_check->SameBitsSet(bb->data_flow_info->ending_check_v)) {
+    nce_changed = true;
+    bb->data_flow_info->ending_check_v->Copy(ssa_regs_to_check);
   }
-  return infer_changed | nce_changed;
+  return nce_changed;
 }
 
-void MIRGraph::EliminateNullChecksAndInferTypesEnd() {
-  if ((cu_->disable_opt & (1 << kNullCheckElimination)) == 0) {
-    // Clean up temporaries.
-    temp_bit_vector_size_ = 0u;
-    temp_bit_vector_ = nullptr;
-    AllNodesIterator iter(this);
-    for (BasicBlock* bb = iter.Next(); bb != nullptr; bb = iter.Next()) {
-      if (bb->data_flow_info != nullptr) {
-        bb->data_flow_info->ending_check_v = nullptr;
-      }
+void MIRGraph::EliminateNullChecksEnd() {
+  // Clean up temporaries.
+  temp_bit_vector_size_ = 0u;
+  temp_bit_vector_ = nullptr;
+  AllNodesIterator iter(this);
+  for (BasicBlock* bb = iter.Next(); bb != nullptr; bb = iter.Next()) {
+    if (bb->data_flow_info != nullptr) {
+      bb->data_flow_info->ending_check_v = nullptr;
     }
-    DCHECK(temp_scoped_alloc_.get() != nullptr);
-    temp_scoped_alloc_.reset();
   }
+  DCHECK(temp_scoped_alloc_.get() != nullptr);
+  temp_scoped_alloc_.reset();
+}
+
+/*
+ * Perform type and size inference for a basic block.
+ */
+bool MIRGraph::InferTypes(BasicBlock* bb) {
+  if (bb->data_flow_info == nullptr) return false;
+
+  bool infer_changed = false;
+  for (MIR* mir = bb->first_mir_insn; mir != NULL; mir = mir->next) {
+    if (mir->ssa_rep == NULL) {
+        continue;
+    }
+
+    // Propagate type info.
+    infer_changed = InferTypeAndSize(bb, mir, infer_changed);
+  }
+
+  return infer_changed;
 }
 
 bool MIRGraph::EliminateClassInitChecksGate() {
diff --git a/compiler/dex/pass_driver_me_opts.cc b/compiler/dex/pass_driver_me_opts.cc
index 6281062..cd3ffd4 100644
--- a/compiler/dex/pass_driver_me_opts.cc
+++ b/compiler/dex/pass_driver_me_opts.cc
@@ -37,7 +37,8 @@
   GetPassInstance<CacheMethodLoweringInfo>(),
   GetPassInstance<SpecialMethodInliner>(),
   GetPassInstance<CodeLayout>(),
-  GetPassInstance<NullCheckEliminationAndTypeInference>(),
+  GetPassInstance<NullCheckElimination>(),
+  GetPassInstance<TypeInference>(),
   GetPassInstance<ClassInitCheckElimination>(),
   GetPassInstance<GlobalValueNumberingPass>(),
   GetPassInstance<BBCombine>(),