Merge pull request #2443 from digit-google/critical-path-with-topo-sort

ComputeCriticalPath: Use topological sort to speed up function.
diff --git a/src/build.cc b/src/build.cc
index fca119c..e52345b 100644
--- a/src/build.cc
+++ b/src/build.cc
@@ -463,49 +463,87 @@
 
 void Plan::ComputeCriticalPath() {
   METRIC_RECORD("ComputeCriticalPath");
-  // Remove duplicate targets
-  std::unordered_set<const Node*> unique_targets(targets_.begin(),
-                                                 targets_.end());
 
-  // Use backflow algorithm to compute the critical path for all
-  // nodes, starting from the destination nodes.
-  // XXX: ignores pools
-  std::queue<Edge*> work_queue;        // Queue, for breadth-first traversal
-  // The set of edges currently in work_queue, to avoid duplicates.
-  std::unordered_set<const Edge*> active_edges;
-
-  for (const Node* target : unique_targets) {
-    if (Edge* in = target->in_edge()) {
-      int64_t edge_weight = EdgeWeightHeuristic(in);
-      in->set_critical_path_weight(
-          std::max<int64_t>(edge_weight, in->critical_path_weight()));
-      if (active_edges.insert(in).second) {
-        work_queue.push(in);
-      }
+  // Convenience class to perform a topological sort of all edges
+  // reachable from a set of unique targets. Usage is:
+  //
+  // 1) Create instance.
+  //
+  // 2) Call VisitTarget() as many times as necessary.
+  //    Note that duplicate targets are properly ignored.
+  //
+  // 3) Call result() to get a sorted list of edges,
+  //    where each edge appears _after_ its parents,
+  //    i.e. the edges producing its inputs, in the list.
+  //
+  struct TopoSort {
+    void VisitTarget(const Node* target) {
+      Edge* producer = target->in_edge();
+      if (producer)
+        Visit(producer);
     }
+
+    const std::vector<Edge*>& result() const { return sorted_edges_; }
+
+   private:
+    // Implementation note:
+    //
+    // This is the regular depth-first-search algorithm described
+    // at https://en.wikipedia.org/wiki/Topological_sorting, except
+    // that:
+    //
+    // - Edges are appended to the end of the list, for performance
+    //   reasons. Hence the order used in result().
+    //
+    // - Since the graph cannot have any cycles, temporary marks
+    //   are not necessary, and a simple set is used to record
+    //   which edges have already been visited.
+    //
+    void Visit(Edge* edge) {
+      auto insertion = visited_set_.emplace(edge);
+      if (!insertion.second)
+        return;
+
+      for (const Node* input : edge->inputs_) {
+        Edge* producer = input->in_edge();
+        if (producer)
+          Visit(producer);
+      }
+      sorted_edges_.push_back(edge);
+    }
+
+    std::unordered_set<Edge*> visited_set_;
+    std::vector<Edge*> sorted_edges_;
+  };
+
+  TopoSort topo_sort;
+  for (const Node* target : targets_) {
+    topo_sort.VisitTarget(target);
   }
 
-  while (!work_queue.empty()) {
-    Edge* e = work_queue.front();
-    work_queue.pop();
-    // If the critical path of any dependent edges is updated, this
-    // edge may need to be processed again. So re-allow insertion.
-    active_edges.erase(e);
+  const auto& sorted_edges = topo_sort.result();
 
-    for (const Node* input : e->inputs_) {
-      Edge* in = input->in_edge();
-      if (!in) {
+  // First, reset each edge's weight to its own heuristic weight.
+  for (Edge* edge : sorted_edges)
+    edge->set_critical_path_weight(EdgeWeightHeuristic(edge));
+
+  // Second, propagate / increment weights from
+  // children to parents. Scan the list
+  // in reverse order to do so.
+  for (auto reverse_it = sorted_edges.rbegin();
+       reverse_it != sorted_edges.rend(); ++reverse_it) {
+    Edge* edge = *reverse_it;
+    int64_t edge_weight = edge->critical_path_weight();
+
+    for (const Node* input : edge->inputs_) {
+      Edge* producer = input->in_edge();
+      if (!producer)
         continue;
-      }
-      // Only process edge if this node offers a higher weighted path
-      const int64_t edge_weight = EdgeWeightHeuristic(in);
-      const int64_t proposed_weight = e->critical_path_weight() + edge_weight;
-      if (proposed_weight > in->critical_path_weight()) {
-        in->set_critical_path_weight(proposed_weight);
-        if (active_edges.insert(in).second) {
-          work_queue.push(in);
-        }
-      }
+
+      int64_t producer_weight = producer->critical_path_weight();
+      int64_t candidate_weight = edge_weight + EdgeWeightHeuristic(producer);
+      if (candidate_weight > producer_weight)
+        producer->set_critical_path_weight(candidate_weight);
     }
   }
 }