generate version 2 blockimgdiff files

Generate version 2 of the block_image_update transfer list format.
This reduces patch size by using a different strategy for dealing
with out-of-order transfers.  If transfer A must be done before
transfer B because B overwrites A's source, but we want to do B
before A, we resolve the conflict by:

  - before B is executed, we save ("stash") the overlapping region
    (i.e., the blocks that B will overwrite and that A wants to read)

  - when A is executed, it will read those parts of source data from
    the stash rather than from the image.

This reverses the ordering constraint; with these additions, B now
*must* go before A.  The implementation of the stash is left up to the
code that executes the transfer list to apply the patch; it could hold
stashed data in RAM or on a scratch disk such as /cache, if available.

The code retains the ability to build a version 1 block image patch;
it's needed for processing older target-files.

Change-Id: Ia9aa0bd45d5dc3ef7c5835e483b1b2ead10135fe
diff --git a/tools/releasetools/blockimgdiff.py b/tools/releasetools/blockimgdiff.py
index 216486c..cf7d7d9 100644
--- a/tools/releasetools/blockimgdiff.py
+++ b/tools/releasetools/blockimgdiff.py
@@ -16,6 +16,7 @@
 
 from collections import deque, OrderedDict
 from hashlib import sha1
+import heapq
 import itertools
 import multiprocessing
 import os
@@ -142,9 +143,16 @@
     self.goes_before = {}
     self.goes_after = {}
 
+    self.stash_before = []
+    self.use_stash = []
+
     self.id = len(by_id)
     by_id.append(self)
 
+  def NetStashChange(self):
+    return (sum(sr.size() for (_, sr) in self.stash_before) -
+            sum(sr.size() for (_, sr) in self.use_stash))
+
   def __str__(self):
     return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
             " to " + str(self.tgt_ranges) + ">")
@@ -182,11 +190,14 @@
 # original image.
 
 class BlockImageDiff(object):
-  def __init__(self, tgt, src=None, threads=None):
+  def __init__(self, tgt, src=None, threads=None, version=2):
     if threads is None:
       threads = multiprocessing.cpu_count() // 2
       if threads == 0: threads = 1
     self.threads = threads
+    self.version = version
+
+    assert version in (1, 2)
 
     self.tgt = tgt
     if src is None:
@@ -221,7 +232,12 @@
     self.FindVertexSequence()
     # Fix up the ordering dependencies that the sequence didn't
     # satisfy.
-    self.RemoveBackwardEdges()
+    if self.version == 1:
+      self.RemoveBackwardEdges()
+    else:
+      self.ReverseBackwardEdges()
+      self.ImproveVertexSequence()
+
     # Double-check our work.
     self.AssertSequenceGood()
 
@@ -231,18 +247,88 @@
   def WriteTransfers(self, prefix):
     out = []
 
-    out.append("1\n")   # format version number
+    out.append("%d\n" % (self.version,))   # format version number
     total = 0
     performs_read = False
 
+    stashes = {}
+    stashed_blocks = 0
+    max_stashed_blocks = 0
+
+    free_stash_ids = []
+    next_stash_id = 0
+
     for xf in self.transfers:
 
-      # zero [rangeset]
-      # new [rangeset]
-      # bsdiff patchstart patchlen [src rangeset] [tgt rangeset]
-      # imgdiff patchstart patchlen [src rangeset] [tgt rangeset]
-      # move [src rangeset] [tgt rangeset]
-      # erase [rangeset]
+      if self.version < 2:
+        assert not xf.stash_before
+        assert not xf.use_stash
+
+      for s, sr in xf.stash_before:
+        assert s not in stashes
+        if free_stash_ids:
+          sid = heapq.heappop(free_stash_ids)
+        else:
+          sid = next_stash_id
+          next_stash_id += 1
+        stashes[s] = sid
+        stashed_blocks += sr.size()
+        out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
+
+      if stashed_blocks > max_stashed_blocks:
+        max_stashed_blocks = stashed_blocks
+
+      if self.version == 1:
+        src_string = xf.src_ranges.to_string_raw()
+      elif self.version == 2:
+
+        #   <# blocks> <src ranges>
+        #     OR
+        #   <# blocks> <src ranges> <src locs> <stash refs...>
+        #     OR
+        #   <# blocks> - <stash refs...>
+
+        size = xf.src_ranges.size()
+        src_string = [str(size)]
+
+        unstashed_src_ranges = xf.src_ranges
+        mapped_stashes = []
+        for s, sr in xf.use_stash:
+          sid = stashes.pop(s)
+          stashed_blocks -= sr.size()
+          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
+          sr = xf.src_ranges.map_within(sr)
+          mapped_stashes.append(sr)
+          src_string.append("%d:%s" % (sid, sr.to_string_raw()))
+          heapq.heappush(free_stash_ids, sid)
+
+        if unstashed_src_ranges:
+          src_string.insert(1, unstashed_src_ranges.to_string_raw())
+          if xf.use_stash:
+            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
+            src_string.insert(2, mapped_unstashed.to_string_raw())
+            mapped_stashes.append(mapped_unstashed)
+            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
+        else:
+          src_string.insert(1, "-")
+          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
+
+        src_string = " ".join(src_string)
+
+      # both versions:
+      #   zero <rangeset>
+      #   new <rangeset>
+      #   erase <rangeset>
+      #
+      # version 1:
+      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
+      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
+      #   move <src rangeset> <tgt rangeset>
+      #
+      # version 2:
+      #   bsdiff patchstart patchlen <tgt rangeset> <src_string>
+      #   imgdiff patchstart patchlen <tgt rangeset> <src_string>
+      #   move <tgt rangeset> <src_string>
 
       tgt_size = xf.tgt_ranges.size()
 
@@ -255,17 +341,27 @@
         assert xf.tgt_ranges
         assert xf.src_ranges.size() == tgt_size
         if xf.src_ranges != xf.tgt_ranges:
-          out.append("%s %s %s\n" % (
-              xf.style,
-              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
+          if self.version == 1:
+            out.append("%s %s %s\n" % (
+                xf.style,
+                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
+          elif self.version == 2:
+            out.append("%s %s %s\n" % (
+                xf.style,
+                xf.tgt_ranges.to_string_raw(), src_string))
           total += tgt_size
       elif xf.style in ("bsdiff", "imgdiff"):
         performs_read = True
         assert xf.tgt_ranges
         assert xf.src_ranges
-        out.append("%s %d %d %s %s\n" % (
-            xf.style, xf.patch_start, xf.patch_len,
-            xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
+        if self.version == 1:
+          out.append("%s %d %d %s %s\n" % (
+              xf.style, xf.patch_start, xf.patch_len,
+              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
+        elif self.version == 2:
+          out.append("%s %d %d %s %s\n" % (
+              xf.style, xf.patch_start, xf.patch_len,
+              xf.tgt_ranges.to_string_raw(), src_string))
         total += tgt_size
       elif xf.style == "zero":
         assert xf.tgt_ranges
@@ -277,6 +373,15 @@
         raise ValueError, "unknown transfer style '%s'\n" % (xf.style,)
 
     out.insert(1, str(total) + "\n")
+    if self.version >= 2:
+      # version 2 only: after the total block count, we give the number
+      # of stash slots needed, and the maximum size needed (in blocks)
+      out.insert(2, str(next_stash_id) + "\n")
+      out.insert(3, str(max_stashed_blocks) + "\n")
+
+      # sanity check: abort if we're going to need more than 512 MB of
+      # stash space
+      assert max_stashed_blocks * self.tgt.blocksize < (512 << 20)
 
     all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
     if performs_read:
@@ -295,6 +400,10 @@
       for i in out:
         f.write(i)
 
+    if self.version >= 2:
+      print("max stashed blocks: %d  (%d bytes)\n" % (
+          max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize))
+
   def ComputePatches(self, prefix):
     print("Reticulating splines...")
     diff_q = []
@@ -409,7 +518,13 @@
     # Imagine processing the transfers in order.
     for xf in self.transfers:
       # Check that the input blocks for this transfer haven't yet been touched.
-      assert not touched.overlaps(xf.src_ranges)
+
+      x = xf.src_ranges
+      if self.version >= 2:
+        for _, sr in xf.use_stash:
+          x = x.subtract(sr)
+
+      assert not touched.overlaps(x)
       # Check that the output blocks for this transfer haven't yet been touched.
       assert not touched.overlaps(xf.tgt_ranges)
       # Touch all the blocks written by this transfer.
@@ -418,6 +533,47 @@
     # Check that we've written every target block.
     assert touched == self.tgt.care_map
 
+  def ImproveVertexSequence(self):
+    print("Improving vertex order...")
+
+    # At this point our digraph is acyclic; we reversed any edges that
+    # were backwards in the heuristically-generated sequence.  The
+    # previously-generated order is still acceptable, but we hope to
+    # find a better order that needs less memory for stashed data.
+    # Now we do a topological sort to generate a new vertex order,
+    # using a greedy algorithm to choose which vertex goes next
+    # whenever we have a choice.
+
+    # Make a copy of the edge set; this copy will get destroyed by the
+    # algorithm.
+    for xf in self.transfers:
+      xf.incoming = xf.goes_after.copy()
+      xf.outgoing = xf.goes_before.copy()
+
+    L = []   # the new vertex order
+
+    # S is the set of sources in the remaining graph; we always choose
+    # the one that leaves the least amount of stashed data after it's
+    # executed.
+    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
+         if not u.incoming]
+    heapq.heapify(S)
+
+    while S:
+      _, _, xf = heapq.heappop(S)
+      L.append(xf)
+      for u in xf.outgoing:
+        del u.incoming[xf]
+        if not u.incoming:
+          heapq.heappush(S, (u.NetStashChange(), u.order, u))
+
+    # if this fails then our graph had a cycle.
+    assert len(L) == len(self.transfers)
+
+    self.transfers = L
+    for i, xf in enumerate(L):
+      xf.order = i
+
   def RemoveBackwardEdges(self):
     print("Removing backward edges...")
     in_order = 0
@@ -425,19 +581,17 @@
     lost_source = 0
 
     for xf in self.transfers:
-      io = 0
-      ooo = 0
       lost = 0
       size = xf.src_ranges.size()
       for u in xf.goes_before:
         # xf should go before u
         if xf.order < u.order:
           # it does, hurray!
-          io += 1
+          in_order += 1
         else:
           # it doesn't, boo.  trim the blocks that u writes from xf's
           # source, so that xf can go after u.
-          ooo += 1
+          out_of_order += 1
           assert xf.src_ranges.overlaps(u.tgt_ranges)
           xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
           xf.intact = False
@@ -448,8 +602,6 @@
 
       lost = size - xf.src_ranges.size()
       lost_source += lost
-      in_order += io
-      out_of_order += ooo
 
     print(("  %d/%d dependencies (%.2f%%) were violated; "
            "%d source blocks removed.") %
@@ -458,6 +610,48 @@
            if (in_order + out_of_order) else 0.0,
            lost_source))
 
+  def ReverseBackwardEdges(self):
+    print("Reversing backward edges...")
+    in_order = 0
+    out_of_order = 0
+    stashes = 0
+    stash_size = 0
+
+    for xf in self.transfers:
+      lost = 0
+      size = xf.src_ranges.size()
+      for u in xf.goes_before.copy():
+        # xf should go before u
+        if xf.order < u.order:
+          # it does, hurray!
+          in_order += 1
+        else:
+          # it doesn't, boo.  modify u to stash the blocks that it
+          # writes that xf wants to read, and then require u to go
+          # before xf.
+          out_of_order += 1
+
+          overlap = xf.src_ranges.intersect(u.tgt_ranges)
+          assert overlap
+
+          u.stash_before.append((stashes, overlap))
+          xf.use_stash.append((stashes, overlap))
+          stashes += 1
+          stash_size += overlap.size()
+
+          # reverse the edge direction; now xf must go after u
+          del xf.goes_before[u]
+          del u.goes_after[xf]
+          xf.goes_after[u] = None    # value doesn't matter
+          u.goes_before[xf] = None
+
+    print(("  %d/%d dependencies (%.2f%%) were violated; "
+           "%d source blocks stashed.") %
+          (out_of_order, in_order + out_of_order,
+           (out_of_order * 100.0 / (in_order + out_of_order))
+           if (in_order + out_of_order) else 0.0,
+           stash_size))
+
   def FindVertexSequence(self):
     print("Finding vertex sequence...")
 
diff --git a/tools/releasetools/common.py b/tools/releasetools/common.py
index 815c76c..96075a9 100644
--- a/tools/releasetools/common.py
+++ b/tools/releasetools/common.py
@@ -1030,7 +1030,14 @@
     self.partition = partition
     self.check_first_block = check_first_block
 
-    b = blockimgdiff.BlockImageDiff(tgt, src, threads=OPTIONS.worker_threads)
+    version = 1
+    if OPTIONS.info_dict:
+      version = max(
+          int(i) for i in
+          OPTIONS.info_dict.get("blockimgdiff_versions", "1").split(","))
+
+    b = blockimgdiff.BlockImageDiff(tgt, src, threads=OPTIONS.worker_threads,
+                                    version=version)
     tmpdir = tempfile.mkdtemp()
     OPTIONS.tempfiles.append(tmpdir)
     self.path = os.path.join(tmpdir, partition)
diff --git a/tools/releasetools/rangelib.py b/tools/releasetools/rangelib.py
index 8a85d2d..b83396c 100644
--- a/tools/releasetools/rangelib.py
+++ b/tools/releasetools/rangelib.py
@@ -173,3 +173,39 @@
       else:
         total -= p
     return total
+
+  def map_within(self, other):
+    """'other' should be a subset of 'self'.  Returns a RangeSet
+    representing what 'other' would get translated to if the integers
+    of 'self' were translated down to be contiguous starting at zero.
+
+    >>> RangeSet.parse("0-9").map_within(RangeSet.parse("3-4")).to_string()
+    '3-4'
+    >>> RangeSet.parse("10-19").map_within(RangeSet.parse("13-14")).to_string()
+    '3-4'
+    >>> RangeSet.parse("10-19 30-39").map_within(
+    ...     RangeSet.parse("17-19 30-32")).to_string()
+    '7-12'
+    >>> RangeSet.parse("10-19 30-39").map_within(
+    ...     RangeSet.parse("12-13 17-19 30-32")).to_string()
+    '2-3 7-12'
+    """
+
+    out = []
+    offset = 0
+    start = None
+    for p, d in heapq.merge(zip(self.data, itertools.cycle((-5, +5))),
+                            zip(other.data, itertools.cycle((-1, +1)))):
+      if d == -5:
+        start = p
+      elif d == +5:
+        offset += p-start
+        start = None
+      else:
+        out.append(offset + p - start)
+    return RangeSet(data=out)
+
+
+if __name__ == "__main__":
+  import doctest
+  doctest.testmod()