Merge "Parallel calculation of v2 signatures"
am: 8ab05bf25c
Change-Id: Ibd655ee6b6e9bf6d5e07df1aa04e88f656860b2b
diff --git a/src/main/java/com/android/apksig/internal/apk/ApkSigningBlockUtils.java b/src/main/java/com/android/apksig/internal/apk/ApkSigningBlockUtils.java
index 556c643..9444702 100644
--- a/src/main/java/com/android/apksig/internal/apk/ApkSigningBlockUtils.java
+++ b/src/main/java/com/android/apksig/internal/apk/ApkSigningBlockUtils.java
@@ -57,12 +57,18 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ForkJoinPool;
+import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Function;
+import java.util.function.Supplier;
import java.util.stream.Collectors;
public class ApkSigningBlockUtils {
-    private static final char[] HEX_DIGITS = "01234567890abcdef".toCharArray();
+    private static final char[] HEX_DIGITS = "0123456789abcdef".toCharArray(); // 16 hex digits
- private static final int CONTENT_DIGESTED_CHUNK_MAX_SIZE_BYTES = 1024 * 1024;
+ private static final long CONTENT_DIGESTED_CHUNK_MAX_SIZE_BYTES = 1024 * 1024;
public static final int ANDROID_COMMON_PAGE_ALIGNMENT_BYTES = 4096;
public static final byte[] APK_SIGNING_BLOCK_MAGIC =
new byte[] {
@@ -409,7 +415,8 @@
.filter(a -> a == ContentDigestAlgorithm.CHUNKED_SHA256 ||
a == ContentDigestAlgorithm.CHUNKED_SHA512)
.collect(Collectors.toSet());
- computeOneMbChunkContentDigests(oneMbChunkBasedAlgorithm,
+ computeOneMbChunkContentDigestsMultithread(
+ oneMbChunkBasedAlgorithm,
new DataSource[] { beforeCentralDir, centralDir, eocd },
contentDigests);
@@ -419,7 +426,7 @@
return contentDigests;
}
- private static void computeOneMbChunkContentDigests(
+ static void computeOneMbChunkContentDigests(
Set<ContentDigestAlgorithm> digestAlgorithms,
DataSource[] contents,
Map<ContentDigestAlgorithm, byte[]> outputContentDigests)
@@ -522,6 +529,236 @@
}
}
+    /**
+     * Computes the 1 MB chunk-based content digests of {@code contents} for every algorithm in
+     * {@code digestAlgorithms} on the {@link ForkJoinPool#commonPool() common pool} and stores
+     * one result per algorithm in {@code outputContentDigests}.
+     *
+     * <p>One worker job is scheduled per unit of the common pool's parallelism, and this method
+     * blocks until all of them have completed.
+     *
+     * @throws NoSuchAlgorithmException if a digest algorithm is not available
+     * @throws DigestException if the input is split into more chunks than fit in an {@code int}
+     */
+    static void computeOneMbChunkContentDigestsMultithread(
+            Set<ContentDigestAlgorithm> digestAlgorithms,
+            DataSource[] contents,
+            Map<ContentDigestAlgorithm, byte[]> outputContentDigests)
+            throws NoSuchAlgorithmException, DigestException {
+        // The common pool is shared JVM-wide and cannot be shut down (shutdown() has no effect
+        // on it), so no cleanup is attempted after the computation.
+        ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
+        computeOneMbChunkContentDigestsMultithread(
+                digestAlgorithms,
+                contents,
+                outputContentDigests,
+                forkJoinPool::submit,
+                forkJoinPool.getParallelism());
+    }
+
+    /**
+     * Computes 1 MB chunk-based content digests using {@code jobCount} {@link ChunkDigester}
+     * workers, each scheduled through {@code jobRunner}, which pull chunks from one shared
+     * {@link ChunkSupplier} until every chunk has been handed out.
+     *
+     * <p>For each algorithm, the value stored in {@code outputContentDigests} is the digest of
+     * {@code 0x5a || uint32le(chunkCount) || digest(chunk_0) || ... || digest(chunk_last)}.
+     *
+     * @param digestAlgorithms algorithms to compute; if empty, no chunk data is read and nothing
+     *        is written to {@code outputContentDigests}
+     * @param contents data sources to digest; each one is chunked independently, in order
+     * @param outputContentDigests receives one entry per algorithm
+     * @param jobRunner schedules a worker and returns a future tracking its completion
+     * @param jobCount number of workers to schedule; surplus workers find no chunk and exit
+     * @throws DigestException if the total number of chunks does not fit in an {@code int}
+     * @throws NoSuchAlgorithmException if a digest algorithm is not available
+     * @throws RuntimeException if a worker fails (the worker's own unchecked exception is
+     *         rethrown as-is) or if the calling thread is interrupted while waiting
+     */
+    private static void computeOneMbChunkContentDigestsMultithread(
+            Set<ContentDigestAlgorithm> digestAlgorithms,
+            DataSource[] contents,
+            Map<ContentDigestAlgorithm, byte[]> outputContentDigests,
+            Function<Runnable, Future<?>> jobRunner,
+            int jobCount)
+            throws NoSuchAlgorithmException, DigestException {
+        long chunkCountLong = 0;
+        for (DataSource input : contents) {
+            chunkCountLong += getChunkCount(input.size(), CONTENT_DIGESTED_CHUNK_MAX_SIZE_BYTES);
+        }
+        if (chunkCountLong > Integer.MAX_VALUE) {
+            throw new DigestException("Input too long: " + chunkCountLong + " chunks");
+        }
+        if (digestAlgorithms.isEmpty()) {
+            // Nothing to compute: don't stream every chunk through an empty digest sink.
+            return;
+        }
+        int chunkCount = (int) chunkCountLong;
+
+        List<ChunkDigests> chunkDigestsList = new ArrayList<>(digestAlgorithms.size());
+        for (ContentDigestAlgorithm algorithm : digestAlgorithms) {
+            chunkDigestsList.add(new ChunkDigests(algorithm, chunkCount));
+        }
+
+        // All workers share one supplier, so each chunk index is handed out exactly once and each
+        // worker writes into disjoint regions of the shared ChunkDigests buffers.
+        ChunkSupplier chunkSupplier = new ChunkSupplier(contents);
+        List<Future<?>> jobs = new ArrayList<>(jobCount);
+        for (int i = 0; i < jobCount; i++) {
+            jobs.add(jobRunner.apply(new ChunkDigester(chunkSupplier, chunkDigestsList)));
+        }
+
+        // Future.get() also makes the workers' writes to the digest buffers visible here.
+        try {
+            for (Future<?> future : jobs) {
+                future.get();
+            }
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            throw new RuntimeException(e);
+        } catch (ExecutionException e) {
+            // Surface the worker's own failure instead of wrapping it a second time.
+            Throwable cause = e.getCause();
+            if (cause instanceof RuntimeException) {
+                throw (RuntimeException) cause;
+            }
+            if (cause instanceof Error) {
+                throw (Error) cause;
+            }
+            throw new RuntimeException(e);
+        }
+
+        // Compute and write out final digest for each algorithm.
+        for (ChunkDigests chunkDigests : chunkDigestsList) {
+            MessageDigest messageDigest = chunkDigests.createMessageDigest();
+            outputContentDigests.put(
+                    chunkDigests.algorithm,
+                    messageDigest.digest(chunkDigests.concatOfDigestsOfChunks));
+        }
+    }
+
+    /**
+     * Per-algorithm buffer that collects the digest of every chunk.
+     *
+     * <p>{@code concatOfDigestsOfChunks} holds {@code 0x5a}, then the chunk count as a 4-byte
+     * little-endian integer, then {@code chunkCount} digests of {@code digestOutputSize} bytes
+     * each; the digest of chunk {@code i} starts at {@code getOffset(i)}.
+     */
+    private static class ChunkDigests {
+        private final ContentDigestAlgorithm algorithm;
+        private final int digestOutputSize;
+        private final byte[] concatOfDigestsOfChunks;
+
+        private ChunkDigests(ContentDigestAlgorithm algorithm, int chunkCount) {
+            this.algorithm = algorithm;
+            digestOutputSize = this.algorithm.getChunkDigestOutputSizeBytes();
+            concatOfDigestsOfChunks = new byte[1 + 4 + chunkCount * digestOutputSize];
+
+            // Fill the initial values of the concatenated digests of chunks, which is
+            // {0x5a, 4-bytes-of-little-endian-chunk-count, digests*...}.
+            concatOfDigestsOfChunks[0] = 0x5a;
+            setUnsignedInt32LittleEndian(chunkCount, concatOfDigestsOfChunks, 1);
+        }
+
+        /** Creates a new JCA {@link MessageDigest} for this buffer's algorithm. */
+        private MessageDigest createMessageDigest() throws NoSuchAlgorithmException {
+            return MessageDigest.getInstance(algorithm.getJcaMessageDigestAlgorithm());
+        }
+
+        /**
+         * Returns the offset in {@code concatOfDigestsOfChunks} at which the digest of chunk
+         * {@code chunkIndex} is stored (after the 1-byte marker and 4-byte count header).
+         */
+        private int getOffset(int chunkIndex) {
+            return 1 + 4 + chunkIndex * digestOutputSize;
+        }
+    }
+
+    /**
+     * Worker that keeps taking the next chunk from a shared {@link ChunkSupplier} and stores that
+     * chunk's digest, for every algorithm, into the matching slot of each {@link ChunkDigests}.
+     *
+     * <p>Each instance owns its own {@link MessageDigest}s and is meant to run on a single thread;
+     * several instances may share one supplier and one list of {@link ChunkDigests}.
+     */
+    private static class ChunkDigester implements Runnable {
+        private final ChunkSupplier dataSupplier;
+        private final List<ChunkDigests> chunkDigests;
+        private final List<MessageDigest> messageDigests;
+        private final DataSink mdSink;
+
+        private ChunkDigester(ChunkSupplier dataSupplier, List<ChunkDigests> chunkDigests)
+                throws NoSuchAlgorithmException {
+            this.dataSupplier = dataSupplier;
+            this.chunkDigests = chunkDigests;
+            List<MessageDigest> digests = new ArrayList<>(chunkDigests.size());
+            for (ChunkDigests perAlgorithm : chunkDigests) {
+                digests.add(perAlgorithm.createMessageDigest());
+            }
+            this.messageDigests = digests;
+            this.mdSink = DataSinks.asDataSink(digests.toArray(new MessageDigest[0]));
+        }
+
+        @Override
+        public void run() {
+            // Per-chunk header: 0xa5 followed by the chunk length as a 32-bit little-endian int.
+            byte[] prefix = new byte[5];
+            prefix[0] = (byte) 0xa5;
+
+            try {
+                ChunkSupplier.Chunk chunk;
+                while ((chunk = dataSupplier.get()) != null) {
+                    digestChunk(chunk, prefix);
+                }
+            } catch (IOException | DigestException e) {
+                throw new RuntimeException(e);
+            }
+        }
+
+        /** Feeds one chunk (header, then data) through all digests and stores each result. */
+        private void digestChunk(ChunkSupplier.Chunk chunk, byte[] prefix)
+                throws IOException, DigestException {
+            long chunkSize = chunk.dataSource.size();
+            if (chunkSize > CONTENT_DIGESTED_CHUNK_MAX_SIZE_BYTES) {
+                throw new RuntimeException("Chunk size greater than expected: " + chunkSize);
+            }
+
+            setUnsignedInt32LittleEndian((int) chunkSize, prefix, 1);
+            mdSink.consume(prefix, 0, prefix.length);
+            chunk.dataSource.feed(0, chunkSize, mdSink);
+
+            // Finalizing each digest writes it straight into that chunk's slot and resets it.
+            int algorithmIndex = 0;
+            for (ChunkDigests target : chunkDigests) {
+                int written = messageDigests.get(algorithmIndex++).digest(
+                        target.concatOfDigestsOfChunks,
+                        target.getOffset(chunk.chunkIndex),
+                        target.digestOutputSize);
+                if (written != target.digestOutputSize) {
+                    throw new RuntimeException(
+                            "Unexpected output size of " + target.algorithm
+                                    + " digest: " + written);
+                }
+            }
+        }
+    }
+
+    /**
+     * Thread-safe supplier of chunks of at most 1 MB cut from an array of {@link DataSource}s.
+     *
+     * <p>A chunk never spans two data sources: the last chunk of a source may be shorter than
+     * 1 MB, and the following chunk starts at offset 0 of the next source. Chunk indices are
+     * global across all sources and are claimed from an atomic counter, so concurrent callers
+     * never receive the same chunk.
+     */
+    private static class ChunkSupplier implements Supplier<ChunkSupplier.Chunk> {
+        private final DataSource[] dataSources;
+        private final int[] chunkCounts;
+        private final int totalChunkCount;
+        private final AtomicInteger nextIndex;
+
+        private ChunkSupplier(DataSource[] dataSources) {
+            this.dataSources = dataSources;
+            this.chunkCounts = new int[dataSources.length];
+            int total = 0;
+            for (int sourceIndex = 0; sourceIndex < dataSources.length; sourceIndex++) {
+                long count = getChunkCount(
+                        dataSources[sourceIndex].size(), CONTENT_DIGESTED_CHUNK_MAX_SIZE_BYTES);
+                if (count > Integer.MAX_VALUE) {
+                    throw new RuntimeException(
+                            String.format(
+                                    "Number of chunks in dataSource[%d] is greater than max int.",
+                                    sourceIndex));
+                }
+                chunkCounts[sourceIndex] = (int) count;
+                total += chunkCounts[sourceIndex];
+            }
+            this.totalChunkCount = total;
+            this.nextIndex = new AtomicInteger();
+        }
+
+        /**
+         * Claims the next chunk, or returns {@code null} once every chunk has been handed out.
+         *
+         * <p>The claimed global index is translated into (data source, chunk within that source),
+         * and the result is a slice of that source covering at most 1 MB.
+         */
+        @Override
+        public ChunkSupplier.Chunk get() {
+            int index = nextIndex.getAndIncrement();
+            if (index < 0 || index >= totalChunkCount) {
+                return null;
+            }
+
+            int sourceIndex = 0;
+            int chunkInSource = index;
+            while (sourceIndex < dataSources.length
+                    && chunkInSource >= chunkCounts[sourceIndex]) {
+                chunkInSource -= chunkCounts[sourceIndex];
+                sourceIndex++;
+            }
+
+            DataSource source = dataSources[sourceIndex];
+            long start = chunkInSource * CONTENT_DIGESTED_CHUNK_MAX_SIZE_BYTES;
+            long length = Math.min(source.size() - start, CONTENT_DIGESTED_CHUNK_MAX_SIZE_BYTES);
+            // Note that slicing may involve its own locking. We may wish to reimplement the
+            // underlying mechanism to get rid of that lock (e.g. ByteBufferDataSource should
+            // probably get reimplemented to a delegate model, such that grabbing a slice
+            // doesn't incur a lock).
+            return new Chunk(source.slice(start, length), index);
+        }
+
+        /** A slice of one input data source, tagged with its global chunk index. */
+        static class Chunk {
+            private final int chunkIndex;
+            private final DataSource dataSource;
+
+            private Chunk(DataSource slice, int chunkIndex) {
+                this.chunkIndex = chunkIndex;
+                this.dataSource = slice;
+            }
+        }
+    }
+
private static void computeApkVerityDigest(DataSource beforeCentralDir, DataSource centralDir,
DataSource eocd, Map<ContentDigestAlgorithm, byte[]> outputContentDigests)
throws IOException, NoSuchAlgorithmException {
@@ -545,7 +782,7 @@
outputContentDigests.put(ContentDigestAlgorithm.VERITY_CHUNKED_SHA256, encoded.array());
}
- private static final long getChunkCount(long inputSize, int chunkSize) {
+    /**
+     * Returns how many {@code chunkSize}-byte chunks are needed to cover {@code inputSize} bytes,
+     * i.e. the ceiling of {@code inputSize / chunkSize}; the last chunk may be shorter.
+     * NOTE(review): assumes inputSize >= 0, chunkSize > 0 and that inputSize + chunkSize does not
+     * overflow a long -- the visible callers pass DataSource sizes and the 1 MB chunk constant.
+     */
+ private static final long getChunkCount(long inputSize, long chunkSize) {
return (inputSize + chunkSize - 1) / chunkSize;
}
diff --git a/src/test/java/com/android/apksig/internal/apk/ApkSigningBlockUtilsTest.java b/src/test/java/com/android/apksig/internal/apk/ApkSigningBlockUtilsTest.java
new file mode 100644
index 0000000..77a8dab
--- /dev/null
+++ b/src/test/java/com/android/apksig/internal/apk/ApkSigningBlockUtilsTest.java
@@ -0,0 +1,76 @@
+package com.android.apksig.internal.apk;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+import com.android.apksig.util.DataSource;
+import com.android.apksig.util.DataSources;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.RandomAccessFile;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class ApkSigningBlockUtilsTest {
+    @Rule public TemporaryFolder temporaryFolder = new TemporaryFolder();
+
+    /** Modulus of the generated byte pattern; intentionally not a power of 2 to test properly. */
+    private static final int BASE = 255;
+
+    /**
+     * Checks that the multithreaded 1 MB chunk digest computation yields exactly the digests of
+     * the single-threaded one, for both chunked algorithms, over a file-backed source of 81
+     * chunks (last one partial), a buffer spanning two chunks and a buffer smaller than a chunk.
+     */
+    @Test
+    public void testMultithreadVersionMatchesSinglethreaded() throws Exception {
+        Set<ContentDigestAlgorithm> algos = new HashSet<>(Arrays.asList(
+                ContentDigestAlgorithm.CHUNKED_SHA256,
+                ContentDigestAlgorithm.CHUNKED_SHA512));
+        Map<ContentDigestAlgorithm, byte[]> outputContentDigests = new HashMap<>();
+        Map<ContentDigestAlgorithm, byte[]> outputContentDigestsMultithread = new HashMap<>();
+
+        File dataFile = temporaryFolder.newFile("fake.apk");
+        try (FileOutputStream fos = new FileOutputStream(dataFile)) {
+            fos.write(createPatternedBytes(80 * 1024 * 1024 + 12345));
+        }
+        byte[] part2 = createPatternedBytes(1_500_000);
+        byte[] part3 = createPatternedBytes(30_000);
+
+        // Close the file even if a digest computation throws.
+        try (RandomAccessFile raf = new RandomAccessFile(dataFile, "r")) {
+            DataSource[] dataSource = {
+                DataSources.asDataSource(raf),
+                DataSources.asDataSource(ByteBuffer.wrap(part2)),
+                DataSources.asDataSource(ByteBuffer.wrap(part3)),
+            };
+
+            ApkSigningBlockUtils.computeOneMbChunkContentDigests(
+                    algos, dataSource, outputContentDigests);
+            ApkSigningBlockUtils.computeOneMbChunkContentDigestsMultithread(
+                    algos, dataSource, outputContentDigestsMultithread);
+        }
+
+        assertEquals(outputContentDigests.keySet(), outputContentDigestsMultithread.keySet());
+        for (ContentDigestAlgorithm algo : outputContentDigests.keySet()) {
+            byte[] expected = outputContentDigests.get(algo);
+            byte[] actual = outputContentDigestsMultithread.get(algo);
+            assertArrayEquals(expected, actual);
+        }
+    }
+
+    /** Returns a new array of {@code length} bytes in which byte {@code i} is {@code i % BASE}. */
+    private static byte[] createPatternedBytes(int length) {
+        byte[] result = new byte[length];
+        for (int i = 0; i < length; ++i) {
+            result[i] = (byte) (i % BASE);
+        }
+        return result;
+    }
+}