Limit the number of parallel preview loadings in the image preview loader.

With ag/23192997 the preview cache size effectively increased the total
number of parallel image requests to a large enough value to DDoS Files
app's content provider i.e. when sharing a large number of images, cache
pre-population requests overlap with the actual preview loadings and
causing some of them to fail to load.

Update ImagePreviewImageLoader to use a Semaphore to limit the total
number of parallel preview loadings.

Fix: 283000541
Test: manual testing
Change-Id: I6152f6e589a8b36a4810d617633017b72202e66f
diff --git a/java/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoader.kt b/java/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoader.kt
index 89b79a0..22dd112 100644
--- a/java/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoader.kt
+++ b/java/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoader.kt
@@ -26,28 +26,42 @@
 import androidx.collection.LruCache
 import androidx.lifecycle.Lifecycle
 import androidx.lifecycle.coroutineScope
+import java.util.function.Consumer
 import kotlinx.coroutines.CancellationException
 import kotlinx.coroutines.CompletableDeferred
 import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.Deferred
 import kotlinx.coroutines.isActive
 import kotlinx.coroutines.launch
-import java.util.function.Consumer
+import kotlinx.coroutines.sync.Semaphore
 
 private const val TAG = "ImagePreviewImageLoader"
 
 /**
- * Implements preview image loading for the content preview UI. Provides requests deduplication and
- * image caching.
+ * Implements preview image loading for the content preview UI. Provides requests deduplication,
+ * image caching, and a limit on the number of parallel loadings.
  */
 @VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
-class ImagePreviewImageLoader(
+class ImagePreviewImageLoader
+@VisibleForTesting
+constructor(
     private val scope: CoroutineScope,
     thumbnailSize: Int,
     private val contentResolver: ContentResolver,
     cacheSize: Int,
+    // TODO: consider providing a scope with the dispatcher configured with
+    //  [CoroutineDispatcher#limitedParallelism] instead
+    private val contentResolverSemaphore: Semaphore,
 ) : ImageLoader {
 
+    constructor(
+        scope: CoroutineScope,
+        thumbnailSize: Int,
+        contentResolver: ContentResolver,
+        cacheSize: Int,
+        maxSimultaneousRequests: Int = 4
+    ) : this(scope, thumbnailSize, contentResolver, cacheSize, Semaphore(maxSimultaneousRequests))
+
     private val thumbnailSize: Size = Size(thumbnailSize, thumbnailSize)
 
     private val lock = Any()
@@ -103,13 +117,16 @@
             }
     }
 
-    private fun RequestRecord.loadBitmap() {
+    private suspend fun RequestRecord.loadBitmap() {
+        contentResolverSemaphore.acquire()
         val bitmap =
             try {
                 contentResolver.loadThumbnail(uri, thumbnailSize, null)
             } catch (t: Throwable) {
                 Log.d(TAG, "failed to load $uri preview", t)
                 null
+            } finally {
+                contentResolverSemaphore.release()
             }
         complete(bitmap)
     }
@@ -136,4 +153,4 @@
         val deferred: CompletableDeferred<Bitmap?>,
         @GuardedBy("lock") var caching: Boolean
     )
-}
\ No newline at end of file
+}
diff --git a/java/tests/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoaderTest.kt b/java/tests/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoaderTest.kt
index 184401a..6e57c28 100644
--- a/java/tests/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoaderTest.kt
+++ b/java/tests/src/com/android/intentresolver/contentpreview/ImagePreviewImageLoaderTest.kt
@@ -27,20 +27,33 @@
 import com.android.intentresolver.anyOrNull
 import com.android.intentresolver.mock
 import com.android.intentresolver.whenever
+import com.google.common.truth.Truth.assertThat
+import java.util.ArrayDeque
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.TimeUnit.MILLISECONDS
+import java.util.concurrent.TimeUnit.SECONDS
+import java.util.concurrent.atomic.AtomicInteger
+import kotlin.coroutines.CoroutineContext
 import kotlinx.coroutines.CancellationException
+import kotlinx.coroutines.CompletableDeferred
+import kotlinx.coroutines.CoroutineDispatcher
+import kotlinx.coroutines.CoroutineName
 import kotlinx.coroutines.CoroutineStart.UNDISPATCHED
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.ExperimentalCoroutinesApi
+import kotlinx.coroutines.Runnable
 import kotlinx.coroutines.async
 import kotlinx.coroutines.coroutineScope
 import kotlinx.coroutines.launch
 import kotlinx.coroutines.plus
+import kotlinx.coroutines.sync.Semaphore
 import kotlinx.coroutines.test.StandardTestDispatcher
 import kotlinx.coroutines.test.TestCoroutineScheduler
 import kotlinx.coroutines.test.UnconfinedTestDispatcher
 import kotlinx.coroutines.test.resetMain
 import kotlinx.coroutines.test.runTest
 import kotlinx.coroutines.test.setMain
+import kotlinx.coroutines.yield
 import org.junit.After
 import org.junit.Before
 import org.junit.Test
@@ -72,7 +85,7 @@
                 lifecycleOwner.lifecycle.coroutineScope + dispatcher,
                 imageSize.width,
                 contentResolver,
-                1,
+                cacheSize = 1,
             )
     }
 
@@ -118,7 +131,7 @@
                 lifecycleOwner.lifecycle.coroutineScope + dispatcher,
                 imageSize.width,
                 contentResolver,
-                1,
+                cacheSize = 1,
             )
         coroutineScope {
             launch(start = UNDISPATCHED) { testSubject(uriOne, false) }
@@ -164,7 +177,7 @@
                 lifecycleOwner.lifecycle.coroutineScope + dispatcher,
                 imageSize.width,
                 contentResolver,
-                1
+                cacheSize = 1,
             )
         coroutineScope {
             val deferred = async(start = UNDISPATCHED) { testSubject(uriOne, false) }
@@ -183,7 +196,7 @@
                 lifecycleOwner.lifecycle.coroutineScope + dispatcher,
                 imageSize.width,
                 contentResolver,
-                1
+                cacheSize = 1,
             )
         coroutineScope {
             launch(start = UNDISPATCHED) { testSubject(uriOne, false) }
@@ -194,4 +207,160 @@
 
         verify(contentResolver, times(1)).loadThumbnail(uriOne, imageSize, null)
     }
+
+    @Test
+    fun invoke_semaphoreGuardsContentResolverCalls() = runTest {
+        val contentResolver =
+            mock<ContentResolver> {
+                whenever(loadThumbnail(any(), any(), anyOrNull()))
+                    .thenThrow(SecurityException("test"))
+            }
+        val acquireCount = AtomicInteger()
+        val releaseCount = AtomicInteger()
+        val testSemaphore =
+            object : Semaphore {
+                override val availablePermits: Int
+                    get() = error("Unexpected invocation")
+
+                override suspend fun acquire() {
+                    acquireCount.getAndIncrement()
+                }
+
+                override fun tryAcquire(): Boolean {
+                    error("Unexpected invocation")
+                }
+
+                override fun release() {
+                    releaseCount.getAndIncrement()
+                }
+            }
+
+        val testSubject =
+            ImagePreviewImageLoader(
+                lifecycleOwner.lifecycle.coroutineScope + dispatcher,
+                imageSize.width,
+                contentResolver,
+                cacheSize = 1,
+                testSemaphore,
+            )
+        testSubject(uriOne, false)
+
+        verify(contentResolver, times(1)).loadThumbnail(uriOne, imageSize, null)
+        assertThat(acquireCount.get()).isEqualTo(1)
+        assertThat(releaseCount.get()).isEqualTo(1)
+    }
+
+    @Test
+    fun invoke_semaphoreIsReleasedAfterContentResolverFailure() = runTest {
+        val semaphoreDeferred = CompletableDeferred<Unit>()
+        val releaseCount = AtomicInteger()
+        val testSemaphore =
+            object : Semaphore {
+                override val availablePermits: Int
+                    get() = error("Unexpected invocation")
+
+                override suspend fun acquire() {
+                    semaphoreDeferred.await()
+                }
+
+                override fun tryAcquire(): Boolean {
+                    error("Unexpected invocation")
+                }
+
+                override fun release() {
+                    releaseCount.getAndIncrement()
+                }
+            }
+
+        val testSubject =
+            ImagePreviewImageLoader(
+                lifecycleOwner.lifecycle.coroutineScope + dispatcher,
+                imageSize.width,
+                contentResolver,
+                cacheSize = 1,
+                testSemaphore,
+            )
+        launch(start = UNDISPATCHED) { testSubject(uriOne, false) }
+
+        verify(contentResolver, never()).loadThumbnail(any(), any(), anyOrNull())
+
+        semaphoreDeferred.complete(Unit)
+
+        verify(contentResolver, times(1)).loadThumbnail(uriOne, imageSize, null)
+        assertThat(releaseCount.get()).isEqualTo(1)
+    }
+
+    @Test
+    fun invoke_multipleSimultaneousCalls_limitOnNumberOfSimultaneousOutgoingCallsIsRespected() {
+        val requestCount = 4
+        val thumbnailCallsCdl = CountDownLatch(requestCount)
+        val pendingThumbnailCalls = ArrayDeque<CountDownLatch>()
+        val contentResolver =
+            mock<ContentResolver> {
+                whenever(loadThumbnail(any(), any(), anyOrNull())).thenAnswer {
+                    val latch = CountDownLatch(1)
+                    synchronized(pendingThumbnailCalls) { pendingThumbnailCalls.offer(latch) }
+                    thumbnailCallsCdl.countDown()
+                    latch.await()
+                    bitmap
+                }
+            }
+        val name = "LoadImage"
+        val maxSimultaneousRequests = 2
+        val threadsStartedCdl = CountDownLatch(requestCount)
+        val dispatcher = NewThreadDispatcher(name) { threadsStartedCdl.countDown() }
+        val testSubject =
+            ImagePreviewImageLoader(
+                lifecycleOwner.lifecycle.coroutineScope + dispatcher + CoroutineName(name),
+                imageSize.width,
+                contentResolver,
+                cacheSize = 1,
+                maxSimultaneousRequests,
+            )
+        runTest {
+            repeat(requestCount) {
+                launch { testSubject(Uri.parse("content://org.pkg.app/image-$it.png")) }
+            }
+            yield()
+            // wait for all requests to be dispatched
+            assertThat(threadsStartedCdl.await(5, SECONDS)).isTrue()
+
+            assertThat(thumbnailCallsCdl.await(100, MILLISECONDS)).isFalse()
+            synchronized(pendingThumbnailCalls) {
+                assertThat(pendingThumbnailCalls.size).isEqualTo(maxSimultaneousRequests)
+            }
+
+            pendingThumbnailCalls.poll()?.countDown()
+            assertThat(thumbnailCallsCdl.await(100, MILLISECONDS)).isFalse()
+            synchronized(pendingThumbnailCalls) {
+                assertThat(pendingThumbnailCalls.size).isEqualTo(maxSimultaneousRequests)
+            }
+
+            pendingThumbnailCalls.poll()?.countDown()
+            assertThat(thumbnailCallsCdl.await(100, MILLISECONDS)).isTrue()
+            synchronized(pendingThumbnailCalls) {
+                assertThat(pendingThumbnailCalls.size).isEqualTo(maxSimultaneousRequests)
+            }
+            for (cdl in pendingThumbnailCalls) {
+                cdl.countDown()
+            }
+        }
+    }
+}
+
+private class NewThreadDispatcher(
+    private val coroutineName: String,
+    private val launchedCallback: () -> Unit
+) : CoroutineDispatcher() {
+    override fun isDispatchNeeded(context: CoroutineContext): Boolean = true
+
+    override fun dispatch(context: CoroutineContext, block: Runnable) {
+        Thread {
+                if (coroutineName == context[CoroutineName.Key]?.name) {
+                    launchedCallback()
+                }
+                block.run()
+            }
+            .start()
+    }
 }