Refactor deps log loading, fix 2 bugs

There were two bugs in DepsLog::Load's main parsing pass:

 * Previously, with an invalid log file, the main pass could initialize
   dep_index[output_id] with the index of a deps record located past the
   point where the log is later truncated, e.g.:

    - Chunk 1: path record for node #0
    - Chunk 1: invalid record
    - Chunk 2: path record for node #1
    - Chunk 2: deps record outputting node #0, needs node #1

   The result of the parse could depend on chunk boundaries (e.g. on how
   many threads the machine has), and the parser could crash if the later
   deps record had source IDs referring to path records beyond the
   truncation point.

   Fix the problem by moving dep_index initialization to a later pass. The
   validation and truncation work is factored out into a ValidateDepsLog
   function.
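
   In sketch, the new pass ordering looks like this (condensed from
   DepsLog::Load in the diff below):

        std::vector<DepsLogWordSpan> chunk_words =
            SplitDepsLog(log, thread_pool.get());
        std::vector<DepsLogNodeSpan> chunk_nodes =
            FindInitialNodeSpans(log, chunk_words, thread_pool.get());
        // Truncates both span vectors just before the first invalid
        // record and discards all chunks after it.
        ValidateDepsLog(log, &chunk_words, &chunk_nodes, thread_pool.get());
        // dep_index is sized from the validated node count and filled in
        // afterwards, so it can never point past the truncation point.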

 * Fix node ID validation of deps record inputs. The existing code to do
   this had no effect:

        if (output_id < 0 || output_id >= next_node_id) break;
        for (size_t i = 0; i < record.u.deps.deps_count; ++i) {
          int input_id = record.u.deps.deps[i];
          if (input_id < 0 || input_id >= next_node_id) break;

   The outer break exited the for-each-record loop early, signaling that
   parsing had failed. The nested break exited only the for-each-input
   loop, which merely skipped validation of the remaining input IDs; the
   invalid record was still accepted and indexed. Replace the break
   statements with "return false" in IsValidRecord.
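
   With the fix, both checks live in IsValidRecord and reject the whole
   record, which truncates the log at that point:

        case InputRecord::DepsRecord:
          // Verify that input/output node IDs are valid.
          if (!is_valid_id(record.u.deps.output_id)) {
            return false;
          }
          for (size_t i = 0; i < record.u.deps.deps_count; ++i) {
            if (!is_valid_id(record.u.deps.deps[i])) {
              return false;
            }
          }
          break;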

These two changes regressed ".ninja_deps load" run-time by about 10ms on
a 580MB .ninja_deps file (e.g. from about 140ms to 150ms). I suspect the
compiler may have been optimizing out the ineffective source ID checking.

Test: ninja_test
Change-Id: I13c3a314cfa7d2bf15c724962a9ec35f55176779
diff --git a/src/deps_log.cc b/src/deps_log.cc
index f7b6476..7ebe15c 100644
--- a/src/deps_log.cc
+++ b/src/deps_log.cc
@@ -407,6 +407,121 @@
   return true;
 }
 
+template <typename Func>
+static bool ForEachRecord(DepsLogData log, DepsLogWordSpan chunk,
+                          Func&& callback) {
+  InputRecord record {};
+  for (size_t index = chunk.begin; index < chunk.end; index += record.size) {
+    record = ParseRecord(log, index);
+    if (!callback(static_cast<const InputRecord&>(record))) return false;
+    if (record.kind == InputRecord::InvalidHeader) break;
+  }
+  return true;
+}
+
+struct DepsLogNodeSpan {
+  int begin; // starting node ID of span
+  int end; // stopping node ID of span
+};
+
+/// Determine the range of node IDs for each chunk of the deps log. The range
+/// for a chunk will only be used if the preceding chunks are valid.
+static std::vector<DepsLogNodeSpan>
+FindInitialNodeSpans(DepsLogData log,
+                     const std::vector<DepsLogWordSpan>& chunk_words,
+                     ThreadPool* thread_pool) {
+  // First count the number of path records (nodes) in each chunk.
+  std::vector<int> chunk_node_counts = ParallelMap(thread_pool, chunk_words,
+      [log](DepsLogWordSpan chunk) {
+    int node_count = 0;
+    ForEachRecord(log, chunk, [&node_count](const InputRecord& record) {
+      if (record.kind == InputRecord::PathRecord) {
+        ++node_count;
+      }
+      return true;
+    });
+    return node_count;
+  });
+
+  // Compute an initial [begin, end) node ID range for each chunk.
+  std::vector<DepsLogNodeSpan> result;
+  int next_node_id = 0;
+  for (int num_nodes : chunk_node_counts) {
+    result.push_back({ next_node_id, next_node_id + num_nodes });
+    next_node_id += num_nodes;
+  }
+  return result;
+}
+
+static bool IsValidRecord(const InputRecord& record, int next_node_id) {
+  auto is_valid_id = [next_node_id](int id) {
+    return id >= 0 && id < next_node_id;
+  };
+
+  switch (record.kind) {
+  case InputRecord::InvalidHeader:
+    return false;
+  case InputRecord::PathRecord:
+    // Validate the path's checksum.
+    if (record.u.path.checksum != ~next_node_id) {
+      return false;
+    }
+    break;
+  case InputRecord::DepsRecord:
+    // Verify that input/output node IDs are valid.
+    if (!is_valid_id(record.u.deps.output_id)) {
+      return false;
+    }
+    for (size_t i = 0; i < record.u.deps.deps_count; ++i) {
+      if (!is_valid_id(record.u.deps.deps[i])) {
+        return false;
+      }
+    }
+    break;
+  }
+  return true;
+}
+
+/// Validate the deps log. If there is an invalid record, the function truncates
+/// the word+node span vectors just before the invalid record.
+static void ValidateDepsLog(DepsLogData log,
+                            std::vector<DepsLogWordSpan>* chunk_words,
+                            std::vector<DepsLogNodeSpan>* chunk_nodes,
+                            ThreadPool* thread_pool) {
+  std::atomic<size_t> num_valid_chunks { chunk_words->size() };
+
+  ParallelMap(thread_pool, IntegralRange<size_t>(0, num_valid_chunks),
+      [&](size_t chunk_index) {
+
+    size_t next_word_idx = (*chunk_words)[chunk_index].begin;
+    int next_node_id = (*chunk_nodes)[chunk_index].begin;
+
+    bool success = ForEachRecord(log, (*chunk_words)[chunk_index],
+        [&](const InputRecord& record) {
+      if (!IsValidRecord(record, next_node_id)) {
+        return false;
+      }
+      next_word_idx += record.size;
+      if (record.kind == InputRecord::PathRecord) {
+        ++next_node_id;
+      }
+      return true;
+    });
+
+    if (success) {
+      assert(next_word_idx == (*chunk_words)[chunk_index].end);
+      assert(next_node_id == (*chunk_nodes)[chunk_index].end);
+    } else {
+      (*chunk_words)[chunk_index].end = next_word_idx;
+      (*chunk_nodes)[chunk_index].end = next_node_id;
+      AtomicUpdateMinimum(&num_valid_chunks, chunk_index + 1);
+    }
+  });
+
+  chunk_words->resize(num_valid_chunks);
+  chunk_nodes->resize(num_valid_chunks);
+}
+
 bool DepsLog::Load(const string& path, State* state, string* err) {
   METRIC_RECORD(".ninja_deps load");
 
@@ -418,168 +533,114 @@
 
   DepsLogData log = log_file.data;
 
-  struct NINJA_ALIGNAS_CACHE_LINE Chunk {
-    size_t start = 0;
-    size_t stop = 0;
-    int first_node_id = 0;
-    int initial_node_count = 0;
-    int final_node_count = 0;
-    size_t deps_count = 0;
-    bool parse_error = false;
-  };
-
   std::unique_ptr<ThreadPool> thread_pool = CreateThreadPool();
 
-  std::vector<Chunk> chunks;
-  for (DepsLogWordSpan span : SplitDepsLog(log, thread_pool.get())) {
-    Chunk chunk {};
-    chunk.start = span.begin;
-    chunk.stop = span.end;
-    chunks.push_back(chunk);
-  }
+  std::vector<DepsLogWordSpan> chunk_words = SplitDepsLog(log, thread_pool.get());
+  std::vector<DepsLogNodeSpan> chunk_nodes =
+      FindInitialNodeSpans(log, chunk_words, thread_pool.get());
 
-  // Compute the starting node ID for each chunk. The result is correct as long as
-  // preceding chunks are parsed successfully. If there is a parsing error in a
-  // chunk, then following chunks are discarded after the validation pass.
-  ParallelMap(thread_pool.get(), chunks, [log](Chunk& chunk) {
-    size_t index = chunk.start;
-    while (index < chunk.stop) {
-      InputRecord record = ParseRecord(log, index);
-      if (record.kind == InputRecord::InvalidHeader) return;
-      if (record.kind == InputRecord::PathRecord) {
-        ++chunk.initial_node_count;
-      }
-      index += record.size;
-    }
-  });
-  int initial_node_count = 0;
-  for (size_t i = 0; i < chunks.size(); ++i) {
-    Chunk& chunk = chunks[i];
-    chunk.first_node_id = initial_node_count;
-    initial_node_count += chunk.initial_node_count;
-  }
+  // Validate the log and truncate the vectors after an invalid record.
+  ValidateDepsLog(log, &chunk_words, &chunk_nodes, thread_pool.get());
+  assert(chunk_words.size() == chunk_nodes.size());
+
+  const size_t chunk_count = chunk_words.size();
+  const int node_count = chunk_nodes.empty() ? 0 : chunk_nodes.back().end;
+
+  // The state path hash table doesn't automatically resize, so make sure that
+  // it has at least one bucket for each node in this deps log.
+  state->paths_.reserve(node_count);
+
+  nodes_.resize(node_count);
 
   // A map from a node ID to the final file index of the deps record outputting
   // the given node ID.
-  std::vector<std::atomic<ssize_t>> dep_index(initial_node_count);
+  std::vector<std::atomic<ssize_t>> dep_index(node_count);
   for (auto& index : dep_index) {
     // Write a value of -1 to indicate that no deps record outputs this ID. We
     // don't need these stores to be synchronized with other threads, so use
     // relaxed stores, which are much faster.
     index.store(-1, std::memory_order_relaxed);
   }
-  // A map from a node ID to the file index of that node.
-  std::vector<size_t> node_index(initial_node_count);
 
-  // The main parsing pass. Validate each chunk's entries and, for each node ID,
-  // record the location of its node and deps records. If there is parser error,
-  // truncate the log just before the problem record.
-  ParallelMap(thread_pool.get(), chunks,
-      [log, &dep_index, &node_index](Chunk& chunk) {
-    size_t index = chunk.start;
-    int next_node_id = chunk.first_node_id;
-    while (index < chunk.stop) {
-      InputRecord record = ParseRecord(log, index);
-      if (record.kind == InputRecord::InvalidHeader) break;
-      if (record.kind == InputRecord::DepsRecord) {
-        // Verify that input/output node IDs are valid.
-        int output_id = record.u.deps.output_id;
-        if (output_id < 0 || output_id >= next_node_id) break;
-        for (size_t i = 0; i < record.u.deps.deps_count; ++i) {
-          int input_id = record.u.deps.deps[i];
-          if (input_id < 0 || input_id >= next_node_id) break;
-        }
-        AtomicUpdateMaximum(&dep_index[output_id], static_cast<ssize_t>(index));
-        ++chunk.deps_count;
-      } else {
-        // Validate the path's checksum.
-        if (record.u.path.checksum != ~next_node_id) break;
-        node_index[next_node_id] = index;
-        ++next_node_id;
-        ++chunk.final_node_count;
+  // Add the nodes into the build graph, find the last deps record
+  // outputting each node, and count the total number of deps records.
+  const std::vector<size_t> dep_record_counts = ParallelMap(thread_pool.get(),
+      IntegralRange<size_t>(0, chunk_count),
+      [log, state, chunk_words, chunk_nodes, &dep_index,
+        this](size_t chunk_index) {
+    size_t next_word_idx = chunk_words[chunk_index].begin;
+    int next_node_id = chunk_nodes[chunk_index].begin;
+    int stop_node_id = chunk_nodes[chunk_index].end;
+    (void)stop_node_id; // only read by assert(); silence unused-variable warning
+    size_t dep_record_count = 0;
+
+    ForEachRecord(log, chunk_words[chunk_index],
+        [&, this](const InputRecord& record) {
+      assert(record.kind != InputRecord::InvalidHeader);
+      if (record.kind == InputRecord::PathRecord) {
+        int node_id = next_node_id++;
+        assert(node_id < stop_node_id);
+        assert(record.u.path.checksum == ~node_id);
+
+        // It is not necessary to pass in a correct slash_bits here. It will
+        // either be a Node that's in the manifest (in which case it will
+        // already have a correct slash_bits that GetNode will look up), or it
+        // is an implicit dependency from a .d which does not affect the build
+        // command (and so need not have its slashes maintained).
+        Node* node = state->GetNode(record.u.path.path, 0);
+        assert(node->id() < 0);
+        node->set_id(node_id);
+        nodes_[node_id] = node;
+      } else if (record.kind == InputRecord::DepsRecord) {
+        const int output_id = record.u.deps.output_id;
+        assert(static_cast<size_t>(output_id) < dep_index.size());
+        AtomicUpdateMaximum(&dep_index[output_id],
+                            static_cast<ssize_t>(next_word_idx));
+        ++dep_record_count;
       }
-      index += record.size;
-    }
-    // We'll exit early on a parser error.
-    if (index < chunk.stop) {
-      chunk.stop = index;
-      chunk.parse_error = true;
-    }
+      next_word_idx += record.size;
+      return true;
+    });
+    assert(next_node_id == stop_node_id);
+    return dep_record_count;
   });
-  int node_count = 0;
-  size_t total_dep_record_count = 0;
-  for (size_t i = 0; i < chunks.size(); ++i) {
-    Chunk& chunk = chunks[i];
-    assert(chunk.first_node_id == node_count);
-    total_dep_record_count += chunk.deps_count;
-    node_count += chunk.final_node_count;
-    if (chunk.parse_error) {
-      // Part of this chunk may have been parsed successfully, so keep it, but
-      // discard all later chunks.
-      chunks.resize(i + 1);
-      break;
+
+  // Count the number of total and unique deps records.
+  const size_t total_dep_record_count =
+      std::accumulate(dep_record_counts.begin(), dep_record_counts.end(),
+                      static_cast<size_t>(0));
+  size_t unique_dep_record_count = 0;
+  for (auto& index : dep_index) {
+    if (index.load(std::memory_order_relaxed) != -1) {
+      ++unique_dep_record_count;
     }
   }
 
-  // The final node count could be smaller than the initial count if there was a
-  // parser error.
-  assert(node_count <= initial_node_count);
-
-  // The log is valid. Commit the nodes into the state graph. First make sure
-  // that the hash table has at least one bucket for each node in this deps log.
-  state->paths_.reserve(node_count);
-  nodes_.resize(node_count);
-  ParallelMap(thread_pool.get(), IntegralRange<int>(0, node_count),
-      [this, state, log, &node_index](int node_id) {
-    size_t index = node_index[node_id];
-
-    InputRecord record = ParseRecord(log, index);
-    assert(record.kind == InputRecord::PathRecord);
-    assert(record.u.path.checksum == ~node_id);
-
-    // It is not necessary to pass in a correct slash_bits here. It will
-    // either be a Node that's in the manifest (in which case it will
-    // already have a correct slash_bits that GetNode will look up), or it
-    // is an implicit dependency from a .d which does not affect the build
-    // command (and so need not have its slashes maintained).
-    Node* node = state->GetNode(record.u.path.path, 0);
-    assert(node->id() < 0);
-    node->set_id(node_id);
-    nodes_[node_id] = node;
-  });
-
   // Add the deps records.
   deps_.resize(node_count);
-  std::vector<size_t> unique_counts = ParallelMap(thread_pool.get(),
-      SplitByThreads(node_count),
-      [this, log, &dep_index](std::pair<int, int> node_chunk) {
-    size_t unique_count = 0;
-    for (int node_id = node_chunk.first; node_id < node_chunk.second; ++node_id) {
-      ssize_t index = dep_index[node_id];
-      if (index == -1) continue;
-      ++unique_count;
+  ParallelMap(thread_pool.get(), IntegralRange<int>(0, node_count),
+      [this, log, &dep_index](int node_id) {
+    ssize_t index = dep_index[node_id];
+    if (index == -1) return;
 
-      InputRecord record = ParseRecord(log, index);
-      assert(record.kind == InputRecord::DepsRecord);
-      assert(record.u.deps.output_id == node_id);
+    InputRecord record = ParseRecord(log, index);
+    assert(record.kind == InputRecord::DepsRecord);
+    assert(record.u.deps.output_id == node_id);
 
-      Deps* deps = new Deps(record.u.deps.mtime, record.u.deps.deps_count);
-      for (size_t i = 0; i < record.u.deps.deps_count; ++i) {
-        int input_id = record.u.deps.deps[i];
-        Node* node = nodes_[input_id];
-        assert(node != nullptr);
-        deps->nodes[i] = node;
-      }
-      deps_[node_id] = deps;
+    Deps* deps = new Deps(record.u.deps.mtime, record.u.deps.deps_count);
+    for (size_t i = 0; i < record.u.deps.deps_count; ++i) {
+      const int input_id = record.u.deps.deps[i];
+      assert(static_cast<size_t>(input_id) < nodes_.size());
+      Node* node = nodes_[input_id];
+      assert(node != nullptr);
+      deps->nodes[i] = node;
     }
-    return unique_count;
+    deps_[node_id] = deps;
   });
-  size_t unique_dep_record_count = std::accumulate(unique_counts.begin(),
-                                                   unique_counts.end(), 0);
 
   const size_t actual_file_size = log_file.file->content().size();
   const size_t parsed_file_size = kFileHeaderSize +
-      (chunks.empty() ? 0 : chunks.back().stop) * sizeof(uint32_t);
+      (chunk_words.empty() ? 0 : chunk_words.back().end) * sizeof(uint32_t);
   assert(parsed_file_size <= actual_file_size);
   if (parsed_file_size < actual_file_size) {
     // An error occurred while loading; try to recover by truncating the file to
diff --git a/src/deps_log_test.cc b/src/deps_log_test.cc
index f39f7da..c43529e 100644
--- a/src/deps_log_test.cc
+++ b/src/deps_log_test.cc
@@ -498,6 +498,114 @@
   }
 }
 
+template <typename Func>
+static void DoLoadInvalidLogTest(Func&& func) {
+  State state;
+  DepsLog log;
+  std::string err;
+  ASSERT_TRUE(log.Load(kTestFilename, &state, &err));
+  ASSERT_EQ("premature end of file; recovering", err);
+  func(&state, &log);
+}
+
+TEST_F(DepsLogTest, LoadInvalidLog) {
+  struct Item {
+    Item(int num) : is_num(true), num(num) {}
+    Item(const char* str) : is_num(false), str(str) {}
+
+    bool is_num;
+    uint32_t num;
+    const char* str;
+  };
+
+  auto write_file = [](std::vector<Item> items) {
+    FILE* fp = fopen(kTestFilename, "wb");
+    for (const Item& item : items) {
+      if (item.is_num) {
+        ASSERT_EQ(1, fwrite(&item.num, sizeof(item.num), 1, fp));
+      } else {
+        ASSERT_EQ(strlen(item.str), fwrite(item.str, 1, strlen(item.str), fp));
+      }
+    }
+    fclose(fp);
+  };
+
+  const int kCurrentVersion = 4;
+  auto path_hdr = [](int path_len) -> int {
+    return RoundUp(path_len, 4) + 4;
+  };
+  auto deps_hdr = [](int deps_cnt) -> int {
+    return 0x80000000 | ((3 * sizeof(uint32_t)) + (deps_cnt * 4));
+  };
+
+  write_file({
+    "# ninjadeps\n", kCurrentVersion,
+    path_hdr(4), "foo0", ~0, // node #0
+    path_hdr(4), "foo1", ~2, // invalid path ID
+  });
+  DoLoadInvalidLogTest([](State* state, DepsLog* log) {
+    ASSERT_EQ(0, state->LookupNode("foo0")->id());
+    ASSERT_EQ(nullptr, state->LookupNode("foo1"));
+  });
+
+  write_file({
+    "# ninjadeps\n", kCurrentVersion,
+    path_hdr(4), "foo0", ~0, // node #0
+    deps_hdr(1), /*node*/0, /*mtime*/5, 0, /*node*/1, // invalid src ID
+    path_hdr(4), "foo1", ~1, // node #1
+  });
+  DoLoadInvalidLogTest([](State* state, DepsLog* log) {
+    ASSERT_EQ(0, state->LookupNode("foo0")->id());
+    ASSERT_EQ(nullptr, log->GetDeps(state->LookupNode("foo0")));
+    ASSERT_EQ(nullptr, state->LookupNode("foo1"));
+  });
+
+  write_file({
+    "# ninjadeps\n", kCurrentVersion,
+    path_hdr(4), "foo0", ~0, // node #0
+    deps_hdr(1), /*node*/1, /*mtime*/5, 0, /*node*/0, // invalid out ID
+    path_hdr(4), "foo1", ~1, // node #1
+  });
+  DoLoadInvalidLogTest([](State* state, DepsLog* log) {
+    ASSERT_EQ(0, state->LookupNode("foo0")->id());
+    ASSERT_EQ(nullptr, state->LookupNode("foo1"));
+  });
+
+  write_file({
+    "# ninjadeps\n", kCurrentVersion,
+    path_hdr(4), "foo0", ~0, // node #0
+    path_hdr(4), "foo1", ~1, // node #1
+    path_hdr(4), "foo2", ~2, // node #2
+    deps_hdr(1), /*node*/2, /*mtime*/5, 0, /*node*/1,
+    deps_hdr(1), /*node*/2, /*mtime*/6, 0, /*node*/3, // invalid src ID
+
+    // No records after the invalid record are parsed.
+    path_hdr(4), "foo3", ~3, // node #3
+    deps_hdr(1), /*node*/3, /*mtime*/7, 0, /*node*/0,
+    path_hdr(4), "foo4", ~4, // node #4
+    deps_hdr(1), /*node*/4, /*mtime*/8, 0, /*node*/0,
+
+    // Truncation must be handled before looking for the last deps record
+    // that outputs a given node.
+    deps_hdr(1), /*node*/2, /*mtime*/9, 0, /*node*/0,
+    deps_hdr(1), /*node*/2, /*mtime*/9, 0, /*node*/3,
+  });
+  DoLoadInvalidLogTest([](State* state, DepsLog* log) {
+    ASSERT_EQ(0, state->LookupNode("foo0")->id());
+    ASSERT_EQ(1, state->LookupNode("foo1")->id());
+    ASSERT_EQ(2, state->LookupNode("foo2")->id());
+    ASSERT_EQ(nullptr, state->LookupNode("foo3"));
+    ASSERT_EQ(nullptr, state->LookupNode("foo4"));
+
+    ASSERT_EQ(nullptr, log->GetDeps(state->LookupNode("foo1")));
+
+    DepsLog::Deps* deps = log->GetDeps(state->LookupNode("foo2"));
+    ASSERT_EQ(5, deps->mtime);
+    ASSERT_EQ(1, deps->node_count);
+    ASSERT_EQ(1, deps->nodes[0]->id());
+  });
+}
+
 TEST_F(DepsLogTest, MustBeDepsRecordHeader) {
   // Mark a word as a candidate.
   static constexpr uint64_t kCandidate = 0x100000000;