Merge "Parallel calculation of v2 signatures"
am: 8ab05bf25c
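
The 1MB-chunk content digests for APK Signature Scheme v2 are now computed
concurrently on ForkJoinPool.commonPool(): a thread-safe ChunkSupplier hands
out chunks by atomic index, and each worker writes its chunk digests directly
into a pre-sized concatenation buffer, so results need no reordering.

A minimal caller sketch (mirrors the new ApkSigningBlockUtilsTest; apkBytes is
a placeholder for the caller's input bytes):

    Set<ContentDigestAlgorithm> algos =
            new HashSet<>(Arrays.asList(ContentDigestAlgorithm.CHUNKED_SHA256));
    Map<ContentDigestAlgorithm, byte[]> digests = new HashMap<>();
    DataSource[] sources = { DataSources.asDataSource(ByteBuffer.wrap(apkBytes)) };
    ApkSigningBlockUtils.computeOneMbChunkContentDigestsMultithread(
            algos, sources, digests);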

Change-Id: Ibd655ee6b6e9bf6d5e07df1aa04e88f656860b2b
diff --git a/src/main/java/com/android/apksig/internal/apk/ApkSigningBlockUtils.java b/src/main/java/com/android/apksig/internal/apk/ApkSigningBlockUtils.java
index 556c643..9444702 100644
--- a/src/main/java/com/android/apksig/internal/apk/ApkSigningBlockUtils.java
+++ b/src/main/java/com/android/apksig/internal/apk/ApkSigningBlockUtils.java
@@ -57,12 +57,18 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ForkJoinPool;
+import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Function;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
 public class ApkSigningBlockUtils {
 
     private static final char[] HEX_DIGITS = "0123456789abcdef".toCharArray();
-    private static final int CONTENT_DIGESTED_CHUNK_MAX_SIZE_BYTES = 1024 * 1024;
+    private static final long CONTENT_DIGESTED_CHUNK_MAX_SIZE_BYTES = 1024 * 1024;
     public static final int ANDROID_COMMON_PAGE_ALIGNMENT_BYTES = 4096;
     public static final byte[] APK_SIGNING_BLOCK_MAGIC =
           new byte[] {
@@ -409,7 +415,8 @@
                 .filter(a -> a == ContentDigestAlgorithm.CHUNKED_SHA256 ||
                              a == ContentDigestAlgorithm.CHUNKED_SHA512)
                 .collect(Collectors.toSet());
-        computeOneMbChunkContentDigests(oneMbChunkBasedAlgorithm,
+        computeOneMbChunkContentDigestsMultithread(
+                oneMbChunkBasedAlgorithm,
                 new DataSource[] { beforeCentralDir, centralDir, eocd },
                 contentDigests);
 
@@ -419,7 +426,7 @@
         return contentDigests;
     }
 
-    private static void computeOneMbChunkContentDigests(
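+    // Package-private so the test can compare this single-threaded implementation
+    // against the multithreaded one.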
+    static void computeOneMbChunkContentDigests(
             Set<ContentDigestAlgorithm> digestAlgorithms,
             DataSource[] contents,
             Map<ContentDigestAlgorithm, byte[]> outputContentDigests)
@@ -522,6 +529,236 @@
         }
     }
 
+    static void computeOneMbChunkContentDigestsMultithread(
+            Set<ContentDigestAlgorithm> digestAlgorithms,
+            DataSource[] contents,
+            Map<ContentDigestAlgorithm, byte[]> outputContentDigests)
+            throws NoSuchAlgorithmException, DigestException {
+        ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
+        computeOneMbChunkContentDigestsMultithread(
+                digestAlgorithms,
+                contents,
+                outputContentDigests,
+                forkJoinPool::submit,
+                forkJoinPool.getParallelism());
+    }
+
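+    /**
+     * Splits the work across {@code jobCount} jobs submitted via {@code jobRunner}.
+     * Each job repeatedly claims the next unprocessed 1MB chunk from a shared
+     * {@link ChunkSupplier} and writes that chunk's digest directly into the
+     * per-algorithm concatenation buffer at the offset determined by the chunk
+     * index, so no ordering or merging of per-job results is needed.
+     */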
+    private static void computeOneMbChunkContentDigestsMultithread(
+            Set<ContentDigestAlgorithm> digestAlgorithms,
+            DataSource[] contents,
+            Map<ContentDigestAlgorithm, byte[]> outputContentDigests,
+            Function<Runnable, Future<?>> jobRunner,
+            int jobCount)
+            throws NoSuchAlgorithmException, DigestException {
+        long chunkCountLong = 0;
+        for (DataSource input : contents) {
+            chunkCountLong +=
+                    getChunkCount(input.size(), CONTENT_DIGESTED_CHUNK_MAX_SIZE_BYTES);
+        }
+        if (chunkCountLong > Integer.MAX_VALUE) {
+            throw new DigestException("Input too long: " + chunkCountLong + " chunks");
+        }
+        int chunkCount = (int) chunkCountLong;
+
+        List<ChunkDigests> chunkDigestsList = new ArrayList<>(digestAlgorithms.size());
+        for (ContentDigestAlgorithm algorithm : digestAlgorithms) {
+            chunkDigestsList.add(new ChunkDigests(algorithm, chunkCount));
+        }
+
+        List<Future<?>> jobs = new ArrayList<>(jobCount);
+        ChunkSupplier chunkSupplier = new ChunkSupplier(contents);
+        for (int i = 0; i < jobCount; i++) {
+            jobs.add(jobRunner.apply(new ChunkDigester(chunkSupplier, chunkDigestsList)));
+        }
+
+        try {
+            for (Future<?> future : jobs) {
+                future.get();
+            }
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            throw new RuntimeException(e);
+        } catch (ExecutionException e) {
+            throw new RuntimeException(e);
+        }
+
+        // Compute and write out final digest for each algorithm.
+        for (ChunkDigests chunkDigests : chunkDigestsList) {
+            MessageDigest messageDigest = chunkDigests.createMessageDigest();
+            outputContentDigests.put(
+                    chunkDigests.algorithm,
+                    messageDigest.digest(chunkDigests.concatOfDigestsOfChunks));
+        }
+    }
+
+    private static class ChunkDigests {
+        private final ContentDigestAlgorithm algorithm;
+        private final int digestOutputSize;
+        private final byte[] concatOfDigestsOfChunks;
+
+        private ChunkDigests(ContentDigestAlgorithm algorithm, int chunkCount) {
+            this.algorithm = algorithm;
+            digestOutputSize = this.algorithm.getChunkDigestOutputSizeBytes();
+            concatOfDigestsOfChunks = new byte[1 + 4 + chunkCount * digestOutputSize];
+
+            // Fill the initial values of the concatenated digests of chunks, which is
+            // {0x5a, 4-bytes-of-little-endian-chunk-count, digests*...}.
+            concatOfDigestsOfChunks[0] = 0x5a;
+            setUnsignedInt32LittleEndian(chunkCount, concatOfDigestsOfChunks, 1);
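+            // For example, with 3 chunks and CHUNKED_SHA256 (32-byte chunk digests),
+            // the buffer is 1 + 4 + 3 * 32 = 101 bytes long.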
+        }
+
+        private MessageDigest createMessageDigest() throws NoSuchAlgorithmException {
+            return MessageDigest.getInstance(algorithm.getJcaMessageDigestAlgorithm());
+        }
+
+        private int getOffset(int chunkIndex) {
+            return 1 + 4 + chunkIndex * digestOutputSize;
+        }
+    }
+
+    /**
+     * A per-thread digest worker: repeatedly claims the next chunk from the shared
+     * {@link ChunkSupplier} and digests it with every requested algorithm.
+     */
+    private static class ChunkDigester implements Runnable {
+        private final ChunkSupplier dataSupplier;
+        private final List<ChunkDigests> chunkDigests;
+        private final List<MessageDigest> messageDigests;
+        private final DataSink mdSink;
+
+        private ChunkDigester(ChunkSupplier dataSupplier, List<ChunkDigests> chunkDigests)
+                throws NoSuchAlgorithmException {
+            this.dataSupplier = dataSupplier;
+            this.chunkDigests = chunkDigests;
+            messageDigests = new ArrayList<>(chunkDigests.size());
+            for (ChunkDigests chunkDigest : chunkDigests) {
+                messageDigests.add(chunkDigest.createMessageDigest());
+            }
+            mdSink = DataSinks.asDataSink(messageDigests.toArray(new MessageDigest[0]));
+        }
+
+        @Override
+        public void run() {
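+            // Per APK Signature Scheme v2, each chunk's digest covers the
+            // concatenation of the byte 0xa5, the chunk size in bytes
+            // (uint32, little-endian), and the chunk's contents.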
+            byte[] chunkContentPrefix = new byte[5];
+            chunkContentPrefix[0] = (byte) 0xa5;
+
+            try {
+                for (ChunkSupplier.Chunk chunk = dataSupplier.get();
+                     chunk != null;
+                     chunk = dataSupplier.get()) {
+                    long size = chunk.dataSource.size();
+                    if (size > CONTENT_DIGESTED_CHUNK_MAX_SIZE_BYTES) {
+                        throw new RuntimeException("Chunk size greater than expected: " + size);
+                    }
+
+                    // First update with the chunk prefix.
+                    setUnsignedInt32LittleEndian((int) size, chunkContentPrefix, 1);
+                    mdSink.consume(chunkContentPrefix, 0, chunkContentPrefix.length);
+
+                    // Then update with the chunk data.
+                    chunk.dataSource.feed(0, size, mdSink);
+
+                    // Now finalize chunk for all algorithms.
+                    for (int i = 0; i < chunkDigests.size(); i++) {
+                        ChunkDigests chunkDigest = chunkDigests.get(i);
+                        int actualDigestSize = messageDigests.get(i).digest(
+                                chunkDigest.concatOfDigestsOfChunks,
+                                chunkDigest.getOffset(chunk.chunkIndex),
+                                chunkDigest.digestOutputSize);
+                        if (actualDigestSize != chunkDigest.digestOutputSize) {
+                            throw new RuntimeException(
+                                    "Unexpected output size of " + chunkDigest.algorithm
+                                            + " digest: " + actualDigestSize);
+                        }
+                    }
+                }
+            } catch (IOException | DigestException e) {
+                throw new RuntimeException(e);
+            }
+        }
+    }
+
+    /**
+     * Thread-safe 1MB {@link DataSource} chunk supplier. When a supplied
+     * {@link DataSource} ends mid-chunk, data from the next {@link DataSource}
+     * is NOT concatenated into the same chunk; the next call to get() starts a
+     * fresh chunk at the beginning of the next {@link DataSource} in the input
+     * array.
+     */
+    private static class ChunkSupplier implements Supplier<ChunkSupplier.Chunk> {
+        private final DataSource[] dataSources;
+        private final int[] chunkCounts;
+        private final int totalChunkCount;
+        private final AtomicInteger nextIndex;
+
+        private ChunkSupplier(DataSource[] dataSources) {
+            this.dataSources = dataSources;
+            chunkCounts = new int[dataSources.length];
+            int totalChunkCount = 0;
+            for (int i = 0; i < dataSources.length; i++) {
+                long chunkCount = getChunkCount(dataSources[i].size(),
+                        CONTENT_DIGESTED_CHUNK_MAX_SIZE_BYTES);
+                if (chunkCount > Integer.MAX_VALUE) {
+                    throw new RuntimeException(
+                            String.format(
+                                    "Number of chunks in dataSource[%d] is greater than max int.",
+                                    i));
+                }
+                chunkCounts[i] = (int) chunkCount;
+                totalChunkCount += chunkCount;
+            }
+            this.totalChunkCount = totalChunkCount;
+            nextIndex = new AtomicInteger(0);
+        }
+
+        /**
+         * Maps a global chunk index to the corresponding 1MB chunk within the
+         * concatenation of the input dataSources. Note that a {@link Chunk} may be
+         * shorter than 1MB: the final chunk of each input {@link DataSource} is
+         * partial unless that source's size is an exact multiple of 1MB. Returns
+         * null once every chunk has been claimed.
+         */
+        @Override
+        public ChunkSupplier.Chunk get() {
+            int index = nextIndex.getAndIncrement();
+            if (index < 0 || index >= totalChunkCount) {
+                return null;
+            }
+
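+            // Translate the global chunk index into (dataSourceIndex, chunk offset
+            // within that source) by walking the per-source chunk counts; e.g. with
+            // counts {81, 2, 1}, index 82 maps to chunk 1 of dataSources[1].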
+            int dataSourceIndex = 0;
+            int dataSourceChunkOffset = index;
+            for (; dataSourceIndex < dataSources.length; dataSourceIndex++) {
+                if (dataSourceChunkOffset < chunkCounts[dataSourceIndex]) {
+                    break;
+                }
+                dataSourceChunkOffset -= chunkCounts[dataSourceIndex];
+            }
+
+            long remainingSize = Math.min(
+                    dataSources[dataSourceIndex].size() -
+                            dataSourceChunkOffset * CONTENT_DIGESTED_CHUNK_MAX_SIZE_BYTES,
+                    CONTENT_DIGESTED_CHUNK_MAX_SIZE_BYTES);
+            // Note that slicing may involve its own locking. We may wish to reimplement the
+            // underlying mechanism to get rid of that lock (e.g. ByteBufferDataSource should
+            // probably get reimplemented to a delegate model, such that grabbing a slice
+            // doesn't incur a lock).
+            return new Chunk(
+                    dataSources[dataSourceIndex].slice(
+                            dataSourceChunkOffset * CONTENT_DIGESTED_CHUNK_MAX_SIZE_BYTES,
+                            remainingSize),
+                    index);
+        }
+
+        static class Chunk {
+            private final int chunkIndex;
+            private final DataSource dataSource;
+
+            private Chunk(DataSource parentSource, int chunkIndex) {
+                this.chunkIndex = chunkIndex;
+                dataSource = parentSource;
+            }
+        }
+    }
+
     private static void computeApkVerityDigest(DataSource beforeCentralDir, DataSource centralDir,
             DataSource eocd, Map<ContentDigestAlgorithm, byte[]> outputContentDigests)
             throws IOException, NoSuchAlgorithmException {
@@ -545,7 +782,7 @@
         outputContentDigests.put(ContentDigestAlgorithm.VERITY_CHUNKED_SHA256, encoded.array());
     }
 
-    private static final long getChunkCount(long inputSize, int chunkSize) {
+    private static final long getChunkCount(long inputSize, long chunkSize) {
         return (inputSize + chunkSize - 1) / chunkSize;
     }
 
diff --git a/src/test/java/com/android/apksig/internal/apk/ApkSigningBlockUtilsTest.java b/src/test/java/com/android/apksig/internal/apk/ApkSigningBlockUtilsTest.java
new file mode 100644
index 0000000..77a8dab
--- /dev/null
+++ b/src/test/java/com/android/apksig/internal/apk/ApkSigningBlockUtilsTest.java
@@ -0,0 +1,76 @@
+package com.android.apksig.internal.apk;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+import com.android.apksig.util.DataSource;
+import com.android.apksig.util.DataSources;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.RandomAccessFile;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class ApkSigningBlockUtilsTest {
+    @Rule public TemporaryFolder temporaryFolder = new TemporaryFolder();
+
+    // Deliberately not a power of two, so the repeating byte pattern does not
+    // align with power-of-two chunk boundaries.
+    private static final int BASE = 255;
+
+    @Test
+    public void testMultithreadVersionMatchesSinglethreaded() throws Exception {
+        Set<ContentDigestAlgorithm> algos = new HashSet<>(Arrays
+                .asList(ContentDigestAlgorithm.CHUNKED_SHA512));
+        Map<ContentDigestAlgorithm, byte[]> outputContentDigests = new HashMap<>();
+        Map<ContentDigestAlgorithm, byte[]> outputContentDigestsMultithread = new HashMap<>();
+
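+        // Sizes are deliberately not multiples of 1 MiB so the final chunk of each
+        // data source is partial, exercising the partial-chunk handling in both
+        // implementations.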
+        byte[] part1 = new byte[80 * 1024 * 1024 + 12345];
+        for (int i = 0; i < part1.length; ++i) {
+            part1[i] = (byte)(i % BASE);
+        }
+
+        File dataFile = temporaryFolder.newFile("fake.apk");
+
+        try (FileOutputStream fos = new FileOutputStream(dataFile)) {
+            fos.write(part1);
+        }
+        RandomAccessFile raf = new RandomAccessFile(dataFile, "r");
+
+        byte[] part2 = new byte[1_500_000];
+        for (int i = 0; i < part2.length; ++i) {
+            part2[i] = (byte)(i % BASE);
+        }
+        byte[] part3 = new byte[30_000];
+        for (int i = 0; i < part3.length; ++i) {
+            part3[i] = (byte)(i % BASE);
+        }
+
+        DataSource[] dataSources = {
+                DataSources.asDataSource(raf),
+                DataSources.asDataSource(ByteBuffer.wrap(part2)),
+                DataSources.asDataSource(ByteBuffer.wrap(part3)),
+        };
+
+        ApkSigningBlockUtils.computeOneMbChunkContentDigests(
+                algos, dataSources, outputContentDigests);
+
+        ApkSigningBlockUtils.computeOneMbChunkContentDigestsMultithread(
+                algos, dataSources, outputContentDigestsMultithread);
+
+        assertEquals(outputContentDigestsMultithread.keySet(), outputContentDigests.keySet());
+        for (ContentDigestAlgorithm algo : outputContentDigests.keySet()) {
+            byte[] digest1 = outputContentDigestsMultithread.get(algo);
+            byte[] digest2 = outputContentDigests.get(algo);
+            assertArrayEquals(digest1, digest2);
+        }
+    }
+}