Limit the number of parallel preview loadings in the image preview loader.
With ag/23192997 the preview cache size effectively increased the total
number of parallel image requests to a large enough value to DDoS Files
app's content provider i.e. when sharing a large number of images, cache
pre-population requests overlap with the actual preview loadings and
causing some of them to fail to load.
Update ImagePreviewImageLoader to use a Semaphore to limit the total
number of parallel preview loadings.
Fix: 283000541
Test: manual testing
Change-Id: I6152f6e589a8b36a4810d617633017b72202e66f
diff --git a/java/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoader.kt b/java/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoader.kt
index 89b79a0..22dd112 100644
--- a/java/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoader.kt
+++ b/java/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoader.kt
@@ -26,28 +26,42 @@
import androidx.collection.LruCache
import androidx.lifecycle.Lifecycle
import androidx.lifecycle.coroutineScope
+import java.util.function.Consumer
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
-import java.util.function.Consumer
+import kotlinx.coroutines.sync.Semaphore
private const val TAG = "ImagePreviewImageLoader"
/**
- * Implements preview image loading for the content preview UI. Provides requests deduplication and
- * image caching.
+ * Implements preview image loading for the content preview UI. Provides requests deduplication,
+ * image caching, and a limit on the number of parallel loadings.
*/
@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
-class ImagePreviewImageLoader(
+class ImagePreviewImageLoader
+@VisibleForTesting
+constructor(
private val scope: CoroutineScope,
thumbnailSize: Int,
private val contentResolver: ContentResolver,
cacheSize: Int,
+ // TODO: consider providing a scope with the dispatcher configured with
+ // [CoroutineDispatcher#limitedParallelism] instead
+ private val contentResolverSemaphore: Semaphore,
) : ImageLoader {
+ constructor(
+ scope: CoroutineScope,
+ thumbnailSize: Int,
+ contentResolver: ContentResolver,
+ cacheSize: Int,
+ maxSimultaneousRequests: Int = 4
+ ) : this(scope, thumbnailSize, contentResolver, cacheSize, Semaphore(maxSimultaneousRequests))
+
private val thumbnailSize: Size = Size(thumbnailSize, thumbnailSize)
private val lock = Any()
@@ -103,13 +117,16 @@
}
}
- private fun RequestRecord.loadBitmap() {
+ private suspend fun RequestRecord.loadBitmap() {
+ contentResolverSemaphore.acquire()
val bitmap =
try {
contentResolver.loadThumbnail(uri, thumbnailSize, null)
} catch (t: Throwable) {
Log.d(TAG, "failed to load $uri preview", t)
null
+ } finally {
+ contentResolverSemaphore.release()
}
complete(bitmap)
}
@@ -136,4 +153,4 @@
val deferred: CompletableDeferred<Bitmap?>,
@GuardedBy("lock") var caching: Boolean
)
-}
\ No newline at end of file
+}
diff --git a/java/tests/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoaderTest.kt b/java/tests/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoaderTest.kt
index 184401a..6e57c28 100644
--- a/java/tests/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoaderTest.kt
+++ b/java/tests/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoaderTest.kt
@@ -27,20 +27,33 @@
import com.android.intentresolver.anyOrNull
import com.android.intentresolver.mock
import com.android.intentresolver.whenever
+import com.google.common.truth.Truth.assertThat
+import java.util.ArrayDeque
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.TimeUnit.MILLISECONDS
+import java.util.concurrent.TimeUnit.SECONDS
+import java.util.concurrent.atomic.AtomicInteger
+import kotlin.coroutines.CoroutineContext
import kotlinx.coroutines.CancellationException
+import kotlinx.coroutines.CompletableDeferred
+import kotlinx.coroutines.CoroutineDispatcher
+import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.CoroutineStart.UNDISPATCHED
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
+import kotlinx.coroutines.Runnable
import kotlinx.coroutines.async
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.plus
+import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.test.StandardTestDispatcher
import kotlinx.coroutines.test.TestCoroutineScheduler
import kotlinx.coroutines.test.UnconfinedTestDispatcher
import kotlinx.coroutines.test.resetMain
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.test.setMain
+import kotlinx.coroutines.yield
import org.junit.After
import org.junit.Before
import org.junit.Test
@@ -72,7 +85,7 @@
lifecycleOwner.lifecycle.coroutineScope + dispatcher,
imageSize.width,
contentResolver,
- 1,
+ cacheSize = 1,
)
}
@@ -118,7 +131,7 @@
lifecycleOwner.lifecycle.coroutineScope + dispatcher,
imageSize.width,
contentResolver,
- 1,
+ cacheSize = 1,
)
coroutineScope {
launch(start = UNDISPATCHED) { testSubject(uriOne, false) }
@@ -164,7 +177,7 @@
lifecycleOwner.lifecycle.coroutineScope + dispatcher,
imageSize.width,
contentResolver,
- 1
+ cacheSize = 1,
)
coroutineScope {
val deferred = async(start = UNDISPATCHED) { testSubject(uriOne, false) }
@@ -183,7 +196,7 @@
lifecycleOwner.lifecycle.coroutineScope + dispatcher,
imageSize.width,
contentResolver,
- 1
+ cacheSize = 1,
)
coroutineScope {
launch(start = UNDISPATCHED) { testSubject(uriOne, false) }
@@ -194,4 +207,160 @@
verify(contentResolver, times(1)).loadThumbnail(uriOne, imageSize, null)
}
+
+ @Test
+ fun invoke_semaphoreGuardsContentResolverCalls() = runTest {
+ val contentResolver =
+ mock<ContentResolver> {
+ whenever(loadThumbnail(any(), any(), anyOrNull()))
+ .thenThrow(SecurityException("test"))
+ }
+ val acquireCount = AtomicInteger()
+ val releaseCount = AtomicInteger()
+ val testSemaphore =
+ object : Semaphore {
+ override val availablePermits: Int
+ get() = error("Unexpected invocation")
+
+ override suspend fun acquire() {
+ acquireCount.getAndIncrement()
+ }
+
+ override fun tryAcquire(): Boolean {
+ error("Unexpected invocation")
+ }
+
+ override fun release() {
+ releaseCount.getAndIncrement()
+ }
+ }
+
+ val testSubject =
+ ImagePreviewImageLoader(
+ lifecycleOwner.lifecycle.coroutineScope + dispatcher,
+ imageSize.width,
+ contentResolver,
+ cacheSize = 1,
+ testSemaphore,
+ )
+ testSubject(uriOne, false)
+
+ verify(contentResolver, times(1)).loadThumbnail(uriOne, imageSize, null)
+ assertThat(acquireCount.get()).isEqualTo(1)
+ assertThat(releaseCount.get()).isEqualTo(1)
+ }
+
+ @Test
+ fun invoke_semaphoreIsReleasedAfterContentResolverFailure() = runTest {
+ val semaphoreDeferred = CompletableDeferred<Unit>()
+ val releaseCount = AtomicInteger()
+ val testSemaphore =
+ object : Semaphore {
+ override val availablePermits: Int
+ get() = error("Unexpected invocation")
+
+ override suspend fun acquire() {
+ semaphoreDeferred.await()
+ }
+
+ override fun tryAcquire(): Boolean {
+ error("Unexpected invocation")
+ }
+
+ override fun release() {
+ releaseCount.getAndIncrement()
+ }
+ }
+
+ val testSubject =
+ ImagePreviewImageLoader(
+ lifecycleOwner.lifecycle.coroutineScope + dispatcher,
+ imageSize.width,
+ contentResolver,
+ cacheSize = 1,
+ testSemaphore,
+ )
+ launch(start = UNDISPATCHED) { testSubject(uriOne, false) }
+
+ verify(contentResolver, never()).loadThumbnail(any(), any(), anyOrNull())
+
+ semaphoreDeferred.complete(Unit)
+
+ verify(contentResolver, times(1)).loadThumbnail(uriOne, imageSize, null)
+ assertThat(releaseCount.get()).isEqualTo(1)
+ }
+
+ @Test
+ fun invoke_multipleSimultaneousCalls_limitOnNumberOfSimultaneousOutgoingCallsIsRespected() {
+ val requestCount = 4
+ val thumbnailCallsCdl = CountDownLatch(requestCount)
+ val pendingThumbnailCalls = ArrayDeque<CountDownLatch>()
+ val contentResolver =
+ mock<ContentResolver> {
+ whenever(loadThumbnail(any(), any(), anyOrNull())).thenAnswer {
+ val latch = CountDownLatch(1)
+ synchronized(pendingThumbnailCalls) { pendingThumbnailCalls.offer(latch) }
+ thumbnailCallsCdl.countDown()
+ latch.await()
+ bitmap
+ }
+ }
+ val name = "LoadImage"
+ val maxSimultaneousRequests = 2
+ val threadsStartedCdl = CountDownLatch(requestCount)
+ val dispatcher = NewThreadDispatcher(name) { threadsStartedCdl.countDown() }
+ val testSubject =
+ ImagePreviewImageLoader(
+ lifecycleOwner.lifecycle.coroutineScope + dispatcher + CoroutineName(name),
+ imageSize.width,
+ contentResolver,
+ cacheSize = 1,
+ maxSimultaneousRequests,
+ )
+ runTest {
+ repeat(requestCount) {
+ launch { testSubject(Uri.parse("content://org.pkg.app/image-$it.png")) }
+ }
+ yield()
+ // wait for all requests to be dispatched
+ assertThat(threadsStartedCdl.await(5, SECONDS)).isTrue()
+
+ assertThat(thumbnailCallsCdl.await(100, MILLISECONDS)).isFalse()
+ synchronized(pendingThumbnailCalls) {
+ assertThat(pendingThumbnailCalls.size).isEqualTo(maxSimultaneousRequests)
+ }
+
+ pendingThumbnailCalls.poll()?.countDown()
+ assertThat(thumbnailCallsCdl.await(100, MILLISECONDS)).isFalse()
+ synchronized(pendingThumbnailCalls) {
+ assertThat(pendingThumbnailCalls.size).isEqualTo(maxSimultaneousRequests)
+ }
+
+ pendingThumbnailCalls.poll()?.countDown()
+ assertThat(thumbnailCallsCdl.await(100, MILLISECONDS)).isTrue()
+ synchronized(pendingThumbnailCalls) {
+ assertThat(pendingThumbnailCalls.size).isEqualTo(maxSimultaneousRequests)
+ }
+ for (cdl in pendingThumbnailCalls) {
+ cdl.countDown()
+ }
+ }
+ }
+}
+
+private class NewThreadDispatcher(
+ private val coroutineName: String,
+ private val launchedCallback: () -> Unit
+) : CoroutineDispatcher() {
+ override fun isDispatchNeeded(context: CoroutineContext): Boolean = true
+
+ override fun dispatch(context: CoroutineContext, block: Runnable) {
+ Thread {
+ if (coroutineName == context[CoroutineName.Key]?.name) {
+ launchedCallback()
+ }
+ block.run()
+ }
+ .start()
+ }
}