Stream the output to disk.

Now that update_engine no longer uses minor version 1, the old file and
the new file are always different on the device, so we can write the
output to disk right away to save memory.
The old behavior is still kept for paycheck.py usage, where the new file
may overlap with the old file.

Test: bsdiff_unittest
Test: apply a 3M patch overwriting old file, uses 25M memory.
Test: apply a 3M patch not overwriting old file, uses 9M memory.
Bug: 25773600

Change-Id: Ic289c9bcc6f1810d0f222a9a4241c05756084b1c
diff --git a/Android.mk b/Android.mk
index 2591267..bf24879 100644
--- a/Android.mk
+++ b/Android.mk
@@ -39,11 +39,13 @@
     bspatch.cc \
     extents.cc \
     extents_file.cc \
-    file.cc
+    file.cc \
+    memory_file.cc
 
 # Unit test files.
 bsdiff_common_unittests := \
     bsdiff_unittest.cc \
+    bspatch_unittest.cc \
     extents_file_unittest.cc \
     extents_unittest.cc \
     test_utils.cc
diff --git a/Makefile b/Makefile
index d887364..fbfe56a 100644
--- a/Makefile
+++ b/Makefile
@@ -34,11 +34,13 @@
   bspatch.o \
   extents.o \
   extents_file.o \
-  file.o
+  file.o \
+  memory_file.o
 
-UNITTEST_LIBS = -lgmock -lgtest
+UNITTEST_LIBS = -lgmock -lgtest -lpthread
 UNITTEST_OBJS = \
   bsdiff_unittest.o \
+  bspatch_unittest.o \
   extents_file_unittest.o \
   extents_unittest.o \
   test_utils.o \
@@ -63,6 +65,7 @@
 bspatch.o: bspatch.cc bspatch.h extents.h extents_file.h file_interface.h \
  file.h
 bspatch_main.o: bspatch_main.cc bspatch.h
+bspatch_unittest.o: bspatch_unittest.cc bspatch.h test_utils.h
 extents.o: extents.cc extents.h extents_file.h file_interface.h
 extents_file.o: extents_file.cc extents_file.h file_interface.h
 extents_file_unittest.o: extents_file_unittest.cc extents_file.h \
@@ -70,6 +73,7 @@
 extents_unittest.o: extents_unittest.cc extents.h extents_file.h \
  file_interface.h
 file.o: file.cc file.h file_interface.h
+memory_file.o: memory_file.cc memory_file.h file_interface.h
 testrunner.o: testrunner.cc
 test_utils.o: test_utils.cc test_utils.h
 
diff --git a/bspatch.cc b/bspatch.cc
index aa484cd..aae6118 100644
--- a/bspatch.cc
+++ b/bspatch.cc
@@ -32,21 +32,25 @@
 
 #include <bzlib.h>
 #include <err.h>
+#include <errno.h>
 #include <fcntl.h>
 #include <inttypes.h>
 #include <stdlib.h>
 #include <string.h>
 #include <unistd.h>
+#include <sys/stat.h>
 #include <sys/types.h>
 
 #include <algorithm>
 #include <memory>
 #include <limits>
+#include <vector>
 
 #include "extents.h"
 #include "extents_file.h"
 #include "file.h"
 #include "file_interface.h"
+#include "memory_file.h"
 
 namespace {
 
@@ -75,22 +79,84 @@
   return y;
 }
 
+// Reads exactly |size| decompressed bytes from |pfbz2| into |data|. Returns
+// false on a short read or on any bzip2 status other than OK / STREAM_END.
+bool ReadBZ2(BZFILE* pfbz2, uint8_t* data, size_t size) {
+  int status;
+  const size_t got = BZ2_bzRead(&status, pfbz2, data, size);
+  return got >= size && (status == BZ_OK || status == BZ_STREAM_END);
+}
+
+// Streams |size| decompressed bytes from |pfbz2| to |file|, staging them
+// through |buf| in pieces of at most |buf_size| bytes.
+bool ReadBZ2AndWriteAll(const std::unique_ptr<bsdiff::FileInterface>& file,
+                        BZFILE* pfbz2,
+                        size_t size,
+                        uint8_t* buf,
+                        size_t buf_size) {
+  for (size_t remaining = size; remaining > 0;) {
+    const size_t piece = std::min(remaining, buf_size);
+    if (!ReadBZ2(pfbz2, buf, piece) || !WriteAll(file, buf, piece))
+      return false;
+    remaining -= piece;
+  }
+  return true;
+}
+
 }  // namespace
 
 namespace bsdiff {
 
+// Writes all |size| bytes of |data| to |file|, resuming after short writes.
+// A Write() that fails or makes no progress aborts, so the loop cannot spin.
+bool WriteAll(const std::unique_ptr<FileInterface>& file,
+              const uint8_t* data,
+              size_t size) {
+  for (size_t offset = 0, written = 0; offset < size; offset += written) {
+    if (!file->Write(data + offset, size - offset, &written) || written == 0)
+      return false;
+  }
+  return true;
+}
+
+// True iff writing the new file could clobber old data not yet read: both
+// paths are the same device/inode and either no extents were given or an old
+// extent intersects a new one. A missing new file never overlaps; any other
+bool IsOverlapping(const char* old_filename,
+                   const char* new_filename,
+                   const std::vector<ex_t>& old_extents,
+                   const std::vector<ex_t>& new_extents) {
+  struct stat new_st, old_st;  // stat() failure (other than ENOENT) is fatal.
+  if (stat(new_filename, &new_st) == -1) {
+    if (errno != ENOENT)
+      err(1, "Error stat the new filename %s", new_filename);
+    return false;
+  }
+  if (stat(old_filename, &old_st) == -1)
+    err(1, "Error stat the old filename %s", old_filename);
+  if (old_st.st_dev != new_st.st_dev || old_st.st_ino != new_st.st_ino)
+    return false;
+  if (old_extents.empty() && new_extents.empty())
+    return true;
+
+  // Two ranges intersect iff each one starts before the other one ends.
+  for (const ex_t& from : old_extents)
+    for (const ex_t& to : new_extents)
+      if (static_cast<uint64_t>(from.off) < to.off + to.len &&
+          static_cast<uint64_t>(to.off) < from.off + from.len)
+        return true;
+  return false;
+}
+
 int bspatch(
     const char* old_filename, const char* new_filename,
     const char* patch_filename,
     const char* old_extents, const char* new_extents) {
   FILE* f, *cpf, *dpf, *epf;
   BZFILE* cpfbz2, *dpfbz2, *epfbz2;
-  int cbz2err, dbz2err, ebz2err;
+  int bz2err;
   ssize_t bzctrllen, bzdatalen;
   u_char header[32], buf[8];
-  u_char* new_buf;
   off_t ctrl[3];
-  off_t lenread;
 
   int using_extents = (old_extents != NULL || new_extents != NULL);
 
@@ -137,29 +203,29 @@
     err(1, "fopen(%s)", patch_filename);
   if (fseek(cpf, 32, SEEK_SET))
     err(1, "fseeko(%s, %lld)", patch_filename, (long long)32);
-  if ((cpfbz2 = BZ2_bzReadOpen(&cbz2err, cpf, 0, 0, NULL, 0)) == NULL)
-    errx(1, "BZ2_bzReadOpen, bz2err = %d", cbz2err);
+  if ((cpfbz2 = BZ2_bzReadOpen(&bz2err, cpf, 0, 0, NULL, 0)) == NULL)
+    errx(1, "BZ2_bzReadOpen, bz2err = %d", bz2err);
   if ((dpf = fopen(patch_filename, "r")) == NULL)
     err(1, "fopen(%s)", patch_filename);
   if (fseek(dpf, 32 + bzctrllen, SEEK_SET))
     err(1, "fseeko(%s, %lld)", patch_filename, (long long)(32 + bzctrllen));
-  if ((dpfbz2 = BZ2_bzReadOpen(&dbz2err, dpf, 0, 0, NULL, 0)) == NULL)
-    errx(1, "BZ2_bzReadOpen, bz2err = %d", dbz2err);
+  if ((dpfbz2 = BZ2_bzReadOpen(&bz2err, dpf, 0, 0, NULL, 0)) == NULL)
+    errx(1, "BZ2_bzReadOpen, bz2err = %d", bz2err);
   if ((epf = fopen(patch_filename, "r")) == NULL)
     err(1, "fopen(%s)", patch_filename);
   if (fseek(epf, 32 + bzctrllen + bzdatalen, SEEK_SET))
     err(1, "fseeko(%s, %lld)", patch_filename,
         (long long)(32 + bzctrllen + bzdatalen));
-  if ((epfbz2 = BZ2_bzReadOpen(&ebz2err, epf, 0, 0, NULL, 0)) == NULL)
-    errx(1, "BZ2_bzReadOpen, bz2err = %d", ebz2err);
+  if ((epfbz2 = BZ2_bzReadOpen(&bz2err, epf, 0, 0, NULL, 0)) == NULL)
+    errx(1, "BZ2_bzReadOpen, bz2err = %d", bz2err);
 
   // Open input file for reading.
   std::unique_ptr<FileInterface> old_file = File::FOpen(old_filename, O_RDONLY);
   if (!old_file)
     err(1, "Error opening the old filename");
 
+  std::vector<ex_t> parsed_old_extents;
   if (using_extents) {
-    std::vector<ex_t> parsed_old_extents;
     if (!ParseExtentStr(old_extents, &parsed_old_extents))
       errx(1, "Error parsing the old extents");
     old_file.reset(new ExtentsFile(std::move(old_file), parsed_old_extents));
@@ -169,22 +235,38 @@
     err(1, "cannot obtain the size of %s", old_filename);
   uint64_t old_file_pos = 0;
 
-  if ((new_buf = static_cast<u_char*>(malloc(newsize + 1))) == NULL)
-    err(1, NULL);
+  // Open output file for writing.
+  std::unique_ptr<FileInterface> new_file =
+      File::FOpen(new_filename, O_CREAT | O_WRONLY);
+  if (!new_file)
+    err(1, "Error opening the new filename %s", new_filename);
+
+  std::vector<ex_t> parsed_new_extents;
+  if (using_extents) {
+    if (!ParseExtentStr(new_extents, &parsed_new_extents))
+      errx(1, "Error parsing the new extents");
+    new_file.reset(new ExtentsFile(std::move(new_file), parsed_new_extents));
+  }
+
+  if (IsOverlapping(old_filename, new_filename, parsed_old_extents,
+                    parsed_new_extents)) {
+    // The new and old files overlap, so we cannot stream output to the new
+    // file; cache it in memory and write it to the file at the end.
+    new_file.reset(new MemoryFile(std::move(new_file), newsize));
+  }
 
   // The oldpos can be negative, but the new pos is only incremented linearly.
   int64_t oldpos = 0;
   uint64_t newpos = 0;
-  std::vector<u_char> old_buf(1024 * 1024);
+  std::vector<uint8_t> old_buf(1024 * 1024), new_buf(1024 * 1024);
   while (newpos < newsize) {
-    int64_t i, j;
+    int64_t i;
     // Read control data.
     for (i = 0; i <= 2; i++) {
-      lenread = BZ2_bzRead(&cbz2err, cpfbz2, buf, 8);
-      if ((lenread < 8) || ((cbz2err != BZ_OK) && (cbz2err != BZ_STREAM_END)))
+      if (!ReadBZ2(cpfbz2, buf, 8))
         errx(1, "Corrupt patch\n");
       ctrl[i] = ParseInt64(buf);
-    };
+    }
 
     // Sanity-check.
     if (ctrl[0] < 0 || ctrl[1] < 0)
@@ -194,37 +276,40 @@
     if (newpos + ctrl[0] > newsize)
       errx(1, "Corrupt patch\n");
 
-    // Read diff string.
-    lenread = BZ2_bzRead(&dbz2err, dpfbz2, new_buf + newpos, ctrl[0]);
-    if ((lenread < ctrl[0]) ||
-        ((dbz2err != BZ_OK) && (dbz2err != BZ_STREAM_END)))
-      errx(1, "Corrupt patch\n");
-
     // Add old data to diff string. It is enough to fseek once, at
     // the beginning of the sequence, to avoid unnecessary overhead.
-    j = newpos;
     if ((i = oldpos) < 0) {
-      j -= i;
+      // Write diff block directly to new file without adding old data,
+      // because we will skip part where |oldpos| < 0.
+      if (!ReadBZ2AndWriteAll(new_file, dpfbz2, -i, new_buf.data(),
+                              new_buf.size()))
+        errx(1, "Error during ReadBZ2AndWriteAll()");
+
       i = 0;
     }
+
     // We just checked that |i| is not negative.
     if (static_cast<uint64_t>(i) != old_file_pos && !old_file->Seek(i))
       err(1, "error seeking input file to offset %" PRId64, i);
     if ((old_file_pos = oldpos + ctrl[0]) > oldsize)
       old_file_pos = oldsize;
 
-    uint64_t chunk_size = old_file_pos - i;
+    size_t chunk_size = old_file_pos - i;
     while (chunk_size > 0) {
       size_t read_bytes;
-      size_t bytes_to_read =
-          std::min(chunk_size, static_cast<uint64_t>(old_buf.size()));
+      size_t bytes_to_read = std::min(chunk_size, old_buf.size());
       if (!old_file->Read(old_buf.data(), bytes_to_read, &read_bytes))
         err(1, "error reading from input file");
       if (!read_bytes)
         errx(1, "EOF reached while reading from input file");
+      // Read the same number of bytes from the diff block.
+      if (!ReadBZ2(dpfbz2, new_buf.data(), read_bytes))
+        errx(1, "Corrupt patch\n");
       // new_buf already has data from diff block, adds old data to it.
       for (size_t k = 0; k < read_bytes; k++)
-        new_buf[j++] += old_buf[k];
+        new_buf[k] += old_buf[k];
+      if (!WriteAll(new_file, new_buf.data(), read_bytes))
+        err(1, "Error writing new file.");
       chunk_size -= read_bytes;
     }
 
@@ -232,58 +317,41 @@
     newpos += ctrl[0];
     oldpos += ctrl[0];
 
+    if (oldpos > static_cast<int64_t>(oldsize)) {
+      // Write diff block directly to new file without adding old data,
+      // because we skipped part where |oldpos| > oldsize.
+      if (!ReadBZ2AndWriteAll(new_file, dpfbz2, oldpos - oldsize,
+                              new_buf.data(), new_buf.size()))
+        errx(1, "Error during ReadBZ2AndWriteAll()");
+    }
+
     // Sanity-check.
     if (newpos + ctrl[1] > newsize)
       errx(1, "Corrupt patch\n");
 
-    // Read extra string.
-    lenread = BZ2_bzRead(&ebz2err, epfbz2, new_buf + newpos, ctrl[1]);
-    if ((lenread < ctrl[1]) ||
-        ((ebz2err != BZ_OK) && (ebz2err != BZ_STREAM_END)))
-      errx(1, "Corrupt patch\n");
+    // Read extra block.
+    if (!ReadBZ2AndWriteAll(new_file, epfbz2, ctrl[1], new_buf.data(),
+                            new_buf.size()))
+      errx(1, "Error during ReadBZ2AndWriteAll()");
 
     // Adjust pointers.
     newpos += ctrl[1];
     oldpos += ctrl[2];
-  };
+  }
 
   // Close input file.
   old_file->Close();
 
   // Clean up the bzip2 reads.
-  BZ2_bzReadClose(&cbz2err, cpfbz2);
-  BZ2_bzReadClose(&dbz2err, dpfbz2);
-  BZ2_bzReadClose(&ebz2err, epfbz2);
+  BZ2_bzReadClose(&bz2err, cpfbz2);
+  BZ2_bzReadClose(&bz2err, dpfbz2);
+  BZ2_bzReadClose(&bz2err, epfbz2);
   if (fclose(cpf) || fclose(dpf) || fclose(epf))
     err(1, "fclose(%s)", patch_filename);
 
-  // Write the new file.
-  std::unique_ptr<FileInterface> new_file =
-      File::FOpen(new_filename, O_CREAT | O_WRONLY);
-  if (!new_file)
-    err(1, "Error opening the new filename %s", new_filename);
-
-  if (using_extents) {
-    std::vector<ex_t> parsed_new_extents;
-    if (!ParseExtentStr(new_extents, &parsed_new_extents))
-      errx(1, "Error parsing the new extents");
-    new_file.reset(new ExtentsFile(std::move(new_file), parsed_new_extents));
-  }
-
-  u_char* temp_new_buf = new_buf;   // new_buf needed for free()
-  while (newsize > 0) {
-    size_t bytes_written;
-    if (!new_file->Write(temp_new_buf, newsize, &bytes_written))
-      err(1, "Error writing new file %s", new_filename);
-    newsize -= bytes_written;
-    temp_new_buf += bytes_written;
-  }
-
   if (!new_file->Close())
     err(1, "Error closing new file %s", new_filename);
 
-  free(new_buf);
-
   return 0;
 }
 
diff --git a/bspatch.h b/bspatch.h
index 269074e..3de2874 100644
--- a/bspatch.h
+++ b/bspatch.h
@@ -5,6 +5,11 @@
 #ifndef _BSDIFF_BSPATCH_H_
 #define _BSDIFF_BSPATCH_H_
 
+#include <memory>
+#include <vector>
+
+#include "extents_file.h"
+
 namespace bsdiff {
 
 int bspatch(const char* old_filename,
@@ -13,6 +18,15 @@
             const char* old_extents,
             const char* new_extents);
 
+bool WriteAll(const std::unique_ptr<FileInterface>& file,
+              const uint8_t* data,
+              size_t size);
+
+bool IsOverlapping(const char* old_filename,
+                   const char* new_filename,
+                   const std::vector<ex_t>& old_extents,
+                   const std::vector<ex_t>& new_extents);
+
 }  // namespace bsdiff
 
 #endif  // _BSDIFF_BSPATCH_H_
diff --git a/bspatch_unittest.cc b/bspatch_unittest.cc
new file mode 100644
index 0000000..04ec666
--- /dev/null
+++ b/bspatch_unittest.cc
@@ -0,0 +1,42 @@
+// Copyright 2016 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "bspatch.h"
+
+#include <unistd.h>
+
+#include <gtest/gtest.h>
+
+#include "test_utils.h"
+
+namespace bsdiff {
+
+class BspatchTest : public testing::Test {
+ protected:
+  BspatchTest()
+      : old_file_("bsdiff_oldfile.XXXXXX"),
+        new_file_("bsdiff_newfile.XXXXXX") {}
+
+  test_utils::ScopedTempFile old_file_;  // Path used as the old file.
+  test_utils::ScopedTempFile new_file_;  // Path used as the new file.
+};
+
+TEST_F(BspatchTest, IsOverlapping) {
+  const char* old_filename = old_file_.c_str();
+  const char* new_filename = new_file_.c_str();
+  EXPECT_FALSE(IsOverlapping(old_filename, "does-not-exist", {}, {}));
+  EXPECT_FALSE(IsOverlapping(old_filename, new_filename, {}, {}));
+  EXPECT_EQ(0, unlink(new_filename));  // Make room for the symlink below.
+  EXPECT_EQ(0, symlink(old_filename, new_filename));  // new -> old
+  EXPECT_TRUE(IsOverlapping(old_filename, new_filename, {}, {}));  // symlinked
+  EXPECT_TRUE(IsOverlapping(old_filename, old_filename, {}, {}));  // same path
+  EXPECT_FALSE(IsOverlapping(old_filename, old_filename, {{0, 1}}, {{1, 1}}));
+  EXPECT_FALSE(IsOverlapping(old_filename, old_filename, {{2, 1}}, {{1, 1}}));
+  EXPECT_TRUE(IsOverlapping(old_filename, old_filename, {{0, 1}}, {{0, 1}}));
+  EXPECT_TRUE(IsOverlapping(old_filename, old_filename, {{0, 4}}, {{2, 1}}));
+  EXPECT_TRUE(IsOverlapping(old_filename, old_filename, {{1, 1}}, {{0, 2}}));
+  EXPECT_TRUE(IsOverlapping(old_filename, old_filename, {{3, 2}}, {{2, 2}}));
+}
+
+}  // namespace bsdiff
diff --git a/memory_file.cc b/memory_file.cc
new file mode 100644
index 0000000..996a10e
--- /dev/null
+++ b/memory_file.cc
@@ -0,0 +1,46 @@
+// Copyright 2016 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "memory_file.h"
+
+#include "bspatch.h"
+
+namespace bsdiff {
+
+MemoryFile::MemoryFile(std::unique_ptr<FileInterface> file, size_t size)
+    : file_(std::move(file)) {
+  buffer_.reserve(size);  // |size| only pre-sizes the cache's capacity.
+}
+
+MemoryFile::~MemoryFile() {
+  Close();  // Flush cached writes to |file_|; a failure here goes unreported.
+}
+
+bool MemoryFile::Read(void* buf, size_t count, size_t* bytes_read) {
+  return false;  // Reading is not supported; |bytes_read| is left untouched.
+}
+
+bool MemoryFile::Write(const void* buf, size_t count, size_t* bytes_written) {
+  const uint8_t* data = static_cast<const uint8_t*>(buf);
+  buffer_.insert(buffer_.end(), data, data + count);  // Cache only, no I/O.
+  *bytes_written = count;  // The whole request is always accepted.
+  return true;
+}
+
+bool MemoryFile::Seek(off_t pos) {
+  return false;  // Seeking is not supported; writes are append-only.
+}
+
+bool MemoryFile::Close() {
+  if (!file_) return true;  // Already closed, e.g. Close() before destructor.
+  std::unique_ptr<FileInterface> file = std::move(file_);  // Flush only once.
+  return WriteAll(file, buffer_.data(), buffer_.size()) && file->Close();
+}
+
+bool MemoryFile::GetSize(uint64_t* size) {
+  *size = buffer_.size();  // Bytes cached so far, not the size of |file_|.
+  return true;
+}
+
+}  // namespace bsdiff
diff --git a/memory_file.h b/memory_file.h
new file mode 100644
index 0000000..3e80b8a
--- /dev/null
+++ b/memory_file.h
@@ -0,0 +1,42 @@
+// Copyright 2016 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef _BSDIFF_MEMORY_FILE_H_
+#define _BSDIFF_MEMORY_FILE_H_
+
+#include <memory>
+#include <vector>
+
+#include "file_interface.h"
+
+namespace bsdiff {
+
+class MemoryFile : public FileInterface {
+ public:
+  // Creates a MemoryFile on top of the underlying |file| passed. The MemoryFile
+  // caches every write in memory and writes the cached data to |file| when it
+  // is closed. MemoryFile does not support read and seek.
+  // |size| should be the estimated total file size; it is used to reserve
+  // buffer space.
+  MemoryFile(std::unique_ptr<FileInterface> file, size_t size);
+
+  ~MemoryFile() override;
+
+  // FileInterface overrides.
+  bool Read(void* buf, size_t count, size_t* bytes_read) override;
+  bool Write(const void* buf, size_t count, size_t* bytes_written) override;
+  bool Seek(off_t pos) override;
+  bool Close() override;
+  bool GetSize(uint64_t* size) override;
+
+ private:
+  // The underlying FileInterface instance.
+  std::unique_ptr<FileInterface> file_;
+  // Data appended by Write(), held until Close() writes it to |file_|.
+  std::vector<uint8_t> buffer_;
+};
+
+}  // namespace bsdiff
+
+#endif  // _BSDIFF_MEMORY_FILE_H_