Filter out notif updates based on active media session

When an app is casting there can be two active media sessions, one local
and one remote. In this situation, there should only be a media object
for the remote session in QS. To achieve this, filter out updates for
the local session.

Bug: 158604810
Bug: 158813341
Bug: 158652134
Test: manual - cast from iHeartRadio. Toggle play/pause from media
object and check that device shown in output switcher isn't rapidly
changing.
Test: manual - cast from LivePhish. Check that only a single media
object appears in QS.

Change-Id: Ia7b8d2f3df3e68f63049e3443315fc26811efec4
diff --git a/packages/SystemUI/src/com/android/systemui/dagger/SystemServicesModule.java b/packages/SystemUI/src/com/android/systemui/dagger/SystemServicesModule.java
index 251ce13..4a5d142 100644
--- a/packages/SystemUI/src/com/android/systemui/dagger/SystemServicesModule.java
+++ b/packages/SystemUI/src/com/android/systemui/dagger/SystemServicesModule.java
@@ -43,6 +43,7 @@
 import android.hardware.display.DisplayManager;
 import android.media.AudioManager;
 import android.media.MediaRouter2Manager;
+import android.media.session.MediaSessionManager;
 import android.net.ConnectivityManager;
 import android.net.NetworkScoreManager;
 import android.net.wifi.WifiManager;
@@ -226,6 +227,11 @@
     }
 
     @Provides
+    static MediaSessionManager provideMediaSessionManager(Context context) {
+        return context.getSystemService(MediaSessionManager.class);
+    }
+
+    @Provides
     @Singleton
     static NetworkScoreManager provideNetworkScoreManager(Context context) {
         return context.getSystemService(NetworkScoreManager.class);
diff --git a/packages/SystemUI/src/com/android/systemui/media/MediaCarouselController.kt b/packages/SystemUI/src/com/android/systemui/media/MediaCarouselController.kt
index b993faa..b13be7b 100644
--- a/packages/SystemUI/src/com/android/systemui/media/MediaCarouselController.kt
+++ b/packages/SystemUI/src/com/android/systemui/media/MediaCarouselController.kt
@@ -42,7 +42,7 @@
     private val mediaHostStatesManager: MediaHostStatesManager,
     private val activityStarter: ActivityStarter,
     @Main executor: DelayableExecutor,
-    mediaManager: MediaDataFilter,
+    mediaManager: MediaDataManager,
     configurationController: ConfigurationController,
     falsingManager: FalsingManager
 ) {
diff --git a/packages/SystemUI/src/com/android/systemui/media/MediaDataCombineLatest.kt b/packages/SystemUI/src/com/android/systemui/media/MediaDataCombineLatest.kt
index d0642cc..aa3699e 100644
--- a/packages/SystemUI/src/com/android/systemui/media/MediaDataCombineLatest.kt
+++ b/packages/SystemUI/src/com/android/systemui/media/MediaDataCombineLatest.kt
@@ -17,65 +17,48 @@
 package com.android.systemui.media
 
 import javax.inject.Inject
-import javax.inject.Singleton
 
 /**
- * Combines updates from [MediaDataManager] with [MediaDeviceManager].
+ * Combines [MediaDataManager.Listener] events with [MediaDeviceManager.Listener] events.
  */
-@Singleton
-class MediaDataCombineLatest @Inject constructor(
-    private val dataSource: MediaDataManager,
-    private val deviceSource: MediaDeviceManager
-) {
+class MediaDataCombineLatest @Inject constructor() : MediaDataManager.Listener,
+        MediaDeviceManager.Listener {
+
     private val listeners: MutableSet<MediaDataManager.Listener> = mutableSetOf()
     private val entries: MutableMap<String, Pair<MediaData?, MediaDeviceData?>> = mutableMapOf()
 
-    init {
-        dataSource.addListener(object : MediaDataManager.Listener {
-            override fun onMediaDataLoaded(key: String, oldKey: String?, data: MediaData) {
-                if (oldKey != null && oldKey != key && entries.contains(oldKey)) {
-                    entries[key] = data to entries.remove(oldKey)?.second
-                    update(key, oldKey)
-                } else {
-                    entries[key] = data to entries[key]?.second
-                    update(key, key)
-                }
-            }
-            override fun onMediaDataRemoved(key: String) {
-                remove(key)
-            }
-        })
-        deviceSource.addListener(object : MediaDeviceManager.Listener {
-            override fun onMediaDeviceChanged(
-                key: String,
-                oldKey: String?,
-                data: MediaDeviceData?
-            ) {
-                if (oldKey != null && oldKey != key && entries.contains(oldKey)) {
-                    entries[key] = entries.remove(oldKey)?.first to data
-                    update(key, oldKey)
-                } else {
-                    entries[key] = entries[key]?.first to data
-                    update(key, key)
-                }
-            }
-            override fun onKeyRemoved(key: String) {
-                remove(key)
-            }
-        })
+    override fun onMediaDataLoaded(key: String, oldKey: String?, data: MediaData) {
+        if (oldKey != null && oldKey != key && entries.contains(oldKey)) {
+            entries[key] = data to entries.remove(oldKey)?.second
+            update(key, oldKey)
+        } else {
+            entries[key] = data to entries[key]?.second
+            update(key, key)
+        }
     }
 
-    /**
-     * Get a map of all non-null data entries
-     */
-    fun getData(): Map<String, MediaData> {
-        return entries.filter {
-            (key, pair) -> pair.first != null && pair.second != null
-        }.mapValues {
-            (key, pair) -> pair.first!!.copy(device = pair.second)
+    override fun onMediaDataRemoved(key: String) {
+        remove(key)
+    }
+
+    override fun onMediaDeviceChanged(
+        key: String,
+        oldKey: String?,
+        data: MediaDeviceData?
+    ) {
+        if (oldKey != null && oldKey != key && entries.contains(oldKey)) {
+            entries[key] = entries.remove(oldKey)?.first to data
+            update(key, oldKey)
+        } else {
+            entries[key] = entries[key]?.first to data
+            update(key, key)
         }
     }
 
+    override fun onKeyRemoved(key: String) {
+        remove(key)
+    }
+
     /**
      * Add a listener for [MediaData] changes that has been combined with latest [MediaDeviceData].
      */
diff --git a/packages/SystemUI/src/com/android/systemui/media/MediaDataFilter.kt b/packages/SystemUI/src/com/android/systemui/media/MediaDataFilter.kt
index 24ca970..0664a41 100644
--- a/packages/SystemUI/src/com/android/systemui/media/MediaDataFilter.kt
+++ b/packages/SystemUI/src/com/android/systemui/media/MediaDataFilter.kt
@@ -24,7 +24,6 @@
 import com.android.systemui.statusbar.NotificationLockscreenUserManager
 import java.util.concurrent.Executor
 import javax.inject.Inject
-import javax.inject.Singleton
 
 private const val TAG = "MediaDataFilter"
 private const val DEBUG = true
@@ -33,24 +32,24 @@
  * Filters data updates from [MediaDataCombineLatest] based on the current user ID, and handles user
  * switches (removing entries for the previous user, adding back entries for the current user)
  *
- * This is added downstream of [MediaDataManager] since we may still need to handle callbacks from
- * background users (e.g. timeouts) that UI classes should ignore.
- * Instead, UI classes should listen to this so they can stay in sync with the current user.
+ * This is added at the end of the pipeline since we may still need to handle callbacks from
+ * background users (e.g. timeouts).
  */
-@Singleton
 class MediaDataFilter @Inject constructor(
-    private val dataSource: MediaDataCombineLatest,
     private val broadcastDispatcher: BroadcastDispatcher,
     private val mediaResumeListener: MediaResumeListener,
-    private val mediaDataManager: MediaDataManager,
     private val lockscreenUserManager: NotificationLockscreenUserManager,
     @Main private val executor: Executor
 ) : MediaDataManager.Listener {
     private val userTracker: CurrentUserTracker
-    private val listeners: MutableSet<MediaDataManager.Listener> = mutableSetOf()
+    private val _listeners: MutableSet<MediaDataManager.Listener> = mutableSetOf()
+    internal val listeners: Set<MediaDataManager.Listener>
+        get() = _listeners.toSet()
+    internal lateinit var mediaDataManager: MediaDataManager
 
-    // The filtered mediaEntries, which will be a subset of all mediaEntries in MediaDataManager
-    private val mediaEntries: LinkedHashMap<String, MediaData> = LinkedHashMap()
+    private val allEntries: LinkedHashMap<String, MediaData> = LinkedHashMap()
+    // The filtered userEntries, which will be a subset of all userEntries in MediaDataManager
+    private val userEntries: LinkedHashMap<String, MediaData> = LinkedHashMap()
 
     init {
         userTracker = object : CurrentUserTracker(broadcastDispatcher) {
@@ -60,31 +59,34 @@
             }
         }
         userTracker.startTracking()
-        dataSource.addListener(this)
     }
 
     override fun onMediaDataLoaded(key: String, oldKey: String?, data: MediaData) {
+        if (oldKey != null && oldKey != key) {
+            allEntries.remove(oldKey)
+        }
+        allEntries.put(key, data)
+
         if (!lockscreenUserManager.isCurrentProfile(data.userId)) {
             return
         }
 
-        if (oldKey != null) {
-            mediaEntries.remove(oldKey)
+        if (oldKey != null && oldKey != key) {
+            userEntries.remove(oldKey)
         }
-        mediaEntries.put(key, data)
+        userEntries.put(key, data)
 
         // Notify listeners
-        val listenersCopy = listeners.toSet()
-        listenersCopy.forEach {
+        listeners.forEach {
             it.onMediaDataLoaded(key, oldKey, data)
         }
     }
 
     override fun onMediaDataRemoved(key: String) {
-        mediaEntries.remove(key)?.let {
+        allEntries.remove(key)
+        userEntries.remove(key)?.let {
             // Only notify listeners if something actually changed
-            val listenersCopy = listeners.toSet()
-            listenersCopy.forEach {
+            listeners.forEach {
                 it.onMediaDataRemoved(key)
             }
         }
@@ -93,11 +95,11 @@
     @VisibleForTesting
     internal fun handleUserSwitched(id: Int) {
         // If the user changes, remove all current MediaData objects and inform listeners
-        val listenersCopy = listeners.toSet()
-        val keyCopy = mediaEntries.keys.toMutableList()
+        val listenersCopy = listeners
+        val keyCopy = userEntries.keys.toMutableList()
         // Clear the list first, to make sure callbacks from listeners if we have any entries
         // are up to date
-        mediaEntries.clear()
+        userEntries.clear()
         keyCopy.forEach {
             if (DEBUG) Log.d(TAG, "Removing $it after user change")
             listenersCopy.forEach { listener ->
@@ -105,10 +107,10 @@
             }
         }
 
-        dataSource.getData().forEach { (key, data) ->
+        allEntries.forEach { (key, data) ->
             if (lockscreenUserManager.isCurrentProfile(data.userId)) {
                 if (DEBUG) Log.d(TAG, "Re-adding $key after user change")
-                mediaEntries.put(key, data)
+                userEntries.put(key, data)
                 listenersCopy.forEach { listener ->
                     listener.onMediaDataLoaded(key, null, data)
                 }
@@ -121,7 +123,7 @@
      */
     fun onSwipeToDismiss() {
         if (DEBUG) Log.d(TAG, "Media carousel swiped away")
-        val mediaKeys = mediaEntries.keys.toSet()
+        val mediaKeys = userEntries.keys.toSet()
         mediaKeys.forEach {
             mediaDataManager.setTimedOut(it, timedOut = true)
         }
@@ -130,7 +132,7 @@
     /**
      * Are there any media notifications active?
      */
-    fun hasActiveMedia() = mediaEntries.any { it.value.active }
+    fun hasActiveMedia() = userEntries.any { it.value.active }
 
     /**
      * Are there any media entries we should display?
@@ -138,7 +140,7 @@
      * If resumption is disabled, we only want to show active players
      */
     fun hasAnyMedia() = if (mediaResumeListener.isResumptionEnabled()) {
-        mediaEntries.isNotEmpty()
+        userEntries.isNotEmpty()
     } else {
         hasActiveMedia()
     }
@@ -146,10 +148,10 @@
     /**
      * Add a listener for filtered [MediaData] changes
      */
-    fun addListener(listener: MediaDataManager.Listener) = listeners.add(listener)
+    fun addListener(listener: MediaDataManager.Listener) = _listeners.add(listener)
 
     /**
      * Remove a listener that was registered with addListener
      */
-    fun removeListener(listener: MediaDataManager.Listener) = listeners.remove(listener)
-}
\ No newline at end of file
+    fun removeListener(listener: MediaDataManager.Listener) = _listeners.remove(listener)
+}
diff --git a/packages/SystemUI/src/com/android/systemui/media/MediaDataManager.kt b/packages/SystemUI/src/com/android/systemui/media/MediaDataManager.kt
index bff334e..e239ba9 100644
--- a/packages/SystemUI/src/com/android/systemui/media/MediaDataManager.kt
+++ b/packages/SystemUI/src/com/android/systemui/media/MediaDataManager.kt
@@ -101,12 +101,23 @@
     dumpManager: DumpManager,
     mediaTimeoutListener: MediaTimeoutListener,
     mediaResumeListener: MediaResumeListener,
+    mediaSessionBasedFilter: MediaSessionBasedFilter,
+    mediaDeviceManager: MediaDeviceManager,
+    mediaDataCombineLatest: MediaDataCombineLatest,
+    private val mediaDataFilter: MediaDataFilter,
     private val activityStarter: ActivityStarter,
     private var useMediaResumption: Boolean,
     private val useQsMediaPlayer: Boolean
 ) : Dumpable {
 
-    private val listeners: MutableSet<Listener> = mutableSetOf()
+    // Internal listeners are part of the internal pipeline. External listeners (those registered
+    // with [MediaDeviceManager.addListener]) receive events after they have propagated through
+    // the internal pipeline.
+    // Another way to think of the distinction between internal and external listeners is the
+    // following. Internal listeners are listeners that MediaDataManager depends on, and external
+    // listeners are listeners that depend on MediaDataManager.
+    // TODO(b/159539991#comment5): Move internal listeners to separate package.
+    private val internalListeners: MutableSet<Listener> = mutableSetOf()
     private val mediaEntries: LinkedHashMap<String, MediaData> = LinkedHashMap()
     internal var appsBlockedFromResume: MutableSet<String> = Utils.getBlockedMediaApps(context)
         set(value) {
@@ -130,9 +141,14 @@
         broadcastDispatcher: BroadcastDispatcher,
         mediaTimeoutListener: MediaTimeoutListener,
         mediaResumeListener: MediaResumeListener,
+        mediaSessionBasedFilter: MediaSessionBasedFilter,
+        mediaDeviceManager: MediaDeviceManager,
+        mediaDataCombineLatest: MediaDataCombineLatest,
+        mediaDataFilter: MediaDataFilter,
         activityStarter: ActivityStarter
     ) : this(context, backgroundExecutor, foregroundExecutor, mediaControllerFactory,
             broadcastDispatcher, dumpManager, mediaTimeoutListener, mediaResumeListener,
+            mediaSessionBasedFilter, mediaDeviceManager, mediaDataCombineLatest, mediaDataFilter,
             activityStarter, Utils.useMediaResumption(context), Utils.useQsMediaPlayer(context))
 
     private val appChangeReceiver = object : BroadcastReceiver() {
@@ -155,12 +171,26 @@
 
     init {
         dumpManager.registerDumpable(TAG, this)
+
+        // Initialize the internal processing pipeline. The listeners at the front of the pipeline
+        // are set as internal listeners so that they receive events. From there, events are
+        // propagated through the pipeline. The end of the pipeline is currently mediaDataFilter,
+        // so it is responsible for dispatching events to external listeners. To achieve this,
+        // external listeners that are registered with [MediaDataManager.addListener] are actually
+        // registered as listeners to mediaDataFilter.
+        addInternalListener(mediaTimeoutListener)
+        addInternalListener(mediaResumeListener)
+        addInternalListener(mediaSessionBasedFilter)
+        mediaSessionBasedFilter.addListener(mediaDeviceManager)
+        mediaSessionBasedFilter.addListener(mediaDataCombineLatest)
+        mediaDeviceManager.addListener(mediaDataCombineLatest)
+        mediaDataCombineLatest.addListener(mediaDataFilter)
+
+        // Set up links back into the pipeline for listeners that need to send events upstream.
         mediaTimeoutListener.timeoutCallback = { token: String, timedOut: Boolean ->
             setTimedOut(token, timedOut) }
-        addListener(mediaTimeoutListener)
-
         mediaResumeListener.setManager(this)
-        addListener(mediaResumeListener)
+        mediaDataFilter.mediaDataManager = this
 
         val suspendFilter = IntentFilter(Intent.ACTION_PACKAGES_SUSPENDED)
         broadcastDispatcher.registerReceiver(appChangeReceiver, suspendFilter, null, UserHandle.ALL)
@@ -198,10 +228,9 @@
 
     private fun removeAllForPackage(packageName: String) {
         Assert.isMainThread()
-        val listenersCopy = listeners.toSet()
         val toRemove = mediaEntries.filter { it.value.packageName == packageName }
         toRemove.forEach {
-            removeEntry(it.key, listenersCopy)
+            removeEntry(it.key)
         }
     }
 
@@ -261,12 +290,45 @@
     /**
      * Add a listener for changes in this class
      */
-    fun addListener(listener: Listener) = listeners.add(listener)
+    fun addListener(listener: Listener) {
+        // mediaDataFilter is the current end of the internal pipeline. Register external
+        // listeners as listeners to it.
+        mediaDataFilter.addListener(listener)
+    }
 
     /**
      * Remove a listener for changes in this class
      */
-    fun removeListener(listener: Listener) = listeners.remove(listener)
+    fun removeListener(listener: Listener) {
+        // Since mediaDataFilter is the current end of the internal pipelie, external listeners
+        // have been registered to it. So, they need to be removed from it too.
+        mediaDataFilter.removeListener(listener)
+    }
+
+    /**
+     * Add a listener for internal events.
+     */
+    private fun addInternalListener(listener: Listener) = internalListeners.add(listener)
+
+    /**
+     * Notify internal listeners of loaded event.
+     *
+     * External listeners registered with [addListener] will be notified after the event propagates
+     * through the internal listener pipeline.
+     */
+    private fun notifyMediaDataLoaded(key: String, oldKey: String?, info: MediaData) {
+        internalListeners.forEach { it.onMediaDataLoaded(key, oldKey, info) }
+    }
+
+    /**
+     * Notify internal listeners of removed event.
+     *
+     * External listeners registered with [addListener] will be notified after the event propagates
+     * through the internal listener pipeline.
+     */
+    private fun notifyMediaDataRemoved(key: String) {
+        internalListeners.forEach { it.onMediaDataRemoved(key) }
+    }
 
     /**
      * Called whenever the player has been paused or stopped for a while, or swiped from QQS.
@@ -284,16 +346,13 @@
         }
     }
 
-    private fun removeEntry(key: String, listenersCopy: Set<Listener>) {
+    private fun removeEntry(key: String) {
         mediaEntries.remove(key)
-        listenersCopy.forEach {
-            it.onMediaDataRemoved(key)
-        }
+        notifyMediaDataRemoved(key)
     }
 
     fun dismissMediaData(key: String, delay: Long) {
-        val listenersCopy = listeners.toSet()
-        foregroundExecutor.executeDelayed({ removeEntry(key, listenersCopy) }, delay)
+        foregroundExecutor.executeDelayed({ removeEntry(key) }, delay)
     }
 
     private fun loadMediaDataInBgForResumption(
@@ -422,7 +481,7 @@
                 val runnable = if (action.actionIntent != null) {
                     Runnable {
                         if (action.isAuthenticationRequired()) {
-                            activityStarter.dismissKeyguardThenExecute ({
+                            activityStarter.dismissKeyguardThenExecute({
                                 var result = sendPendingIntent(action.actionIntent)
                                 result
                             }, {}, true)
@@ -550,10 +609,7 @@
         if (mediaEntries.containsKey(key)) {
             // Otherwise this was removed already
             mediaEntries.put(key, data)
-            val listenersCopy = listeners.toSet()
-            listenersCopy.forEach {
-                it.onMediaDataLoaded(key, oldKey, data)
-            }
+            notifyMediaDataLoaded(key, oldKey, data)
         }
     }
 
@@ -570,31 +626,21 @@
             val pkg = removed?.packageName
             val migrate = mediaEntries.put(pkg, updated) == null
             // Notify listeners of "new" controls when migrating or removed and update when not
-            val listenersCopy = listeners.toSet()
             if (migrate) {
-                listenersCopy.forEach {
-                    it.onMediaDataLoaded(pkg, key, updated)
-                }
+                notifyMediaDataLoaded(pkg, key, updated)
             } else {
                 // Since packageName is used for the key of the resumption controls, it is
                 // possible that another notification has already been reused for the resumption
                 // controls of this package. In this case, rather than renaming this player as
                 // packageName, just remove it and then send a update to the existing resumption
                 // controls.
-                listenersCopy.forEach {
-                    it.onMediaDataRemoved(key)
-                }
-                listenersCopy.forEach {
-                    it.onMediaDataLoaded(pkg, pkg, updated)
-                }
+                notifyMediaDataRemoved(key)
+                notifyMediaDataLoaded(pkg, pkg, updated)
             }
             return
         }
         if (removed != null) {
-            val listenersCopy = listeners.toSet()
-            listenersCopy.forEach {
-                it.onMediaDataRemoved(key)
-            }
+            notifyMediaDataRemoved(key)
         }
     }
 
@@ -614,17 +660,31 @@
 
         if (!useMediaResumption) {
             // Remove any existing resume controls
-            val listenersCopy = listeners.toSet()
             val filtered = mediaEntries.filter { !it.value.active }
             filtered.forEach {
                 mediaEntries.remove(it.key)
-                listenersCopy.forEach { listener ->
-                    listener.onMediaDataRemoved(it.key)
-                }
+                notifyMediaDataRemoved(it.key)
             }
         }
     }
 
+    /**
+     * Invoked when the user has dismissed the media carousel
+     */
+    fun onSwipeToDismiss() = mediaDataFilter.onSwipeToDismiss()
+
+    /**
+     * Are there any media notifications active?
+     */
+    fun hasActiveMedia() = mediaDataFilter.hasActiveMedia()
+
+    /**
+     * Are there any media entries we should display?
+     * If resumption is enabled, this will include inactive players
+     * If resumption is disabled, we only want to show active players
+     */
+    fun hasAnyMedia() = mediaDataFilter.hasAnyMedia()
+
     interface Listener {
 
         /**
@@ -644,7 +704,8 @@
 
     override fun dump(fd: FileDescriptor, pw: PrintWriter, args: Array<out String>) {
         pw.apply {
-            println("listeners: $listeners")
+            println("internalListeners: $internalListeners")
+            println("externalListeners: ${mediaDataFilter.listeners}")
             println("mediaEntries: $mediaEntries")
             println("useMediaResumption: $useMediaResumption")
             println("appsBlockedFromResume: $appsBlockedFromResume")
diff --git a/packages/SystemUI/src/com/android/systemui/media/MediaDeviceManager.kt b/packages/SystemUI/src/com/android/systemui/media/MediaDeviceManager.kt
index ae7f66b..102a484 100644
--- a/packages/SystemUI/src/com/android/systemui/media/MediaDeviceManager.kt
+++ b/packages/SystemUI/src/com/android/systemui/media/MediaDeviceManager.kt
@@ -32,26 +32,23 @@
 import java.io.PrintWriter
 import java.util.concurrent.Executor
 import javax.inject.Inject
-import javax.inject.Singleton
 
 /**
  * Provides information about the route (ie. device) where playback is occurring.
  */
-@Singleton
 class MediaDeviceManager @Inject constructor(
     private val context: Context,
     private val localMediaManagerFactory: LocalMediaManagerFactory,
     private val mr2manager: MediaRouter2Manager,
     @Main private val fgExecutor: Executor,
     @Background private val bgExecutor: Executor,
-    private val mediaDataManager: MediaDataManager,
-    private val dumpManager: DumpManager
+    dumpManager: DumpManager
 ) : MediaDataManager.Listener, Dumpable {
+
     private val listeners: MutableSet<Listener> = mutableSetOf()
     private val entries: MutableMap<String, Entry> = mutableMapOf()
 
     init {
-        mediaDataManager.addListener(this)
         dumpManager.registerDumpable(javaClass.name, this)
     }
 
diff --git a/packages/SystemUI/src/com/android/systemui/media/MediaHost.kt b/packages/SystemUI/src/com/android/systemui/media/MediaHost.kt
index 3598719..ce184aa 100644
--- a/packages/SystemUI/src/com/android/systemui/media/MediaHost.kt
+++ b/packages/SystemUI/src/com/android/systemui/media/MediaHost.kt
@@ -14,7 +14,7 @@
 class MediaHost @Inject constructor(
     private val state: MediaHostStateHolder,
     private val mediaHierarchyManager: MediaHierarchyManager,
-    private val mediaDataFilter: MediaDataFilter,
+    private val mediaDataManager: MediaDataManager,
     private val mediaHostStatesManager: MediaHostStatesManager
 ) : MediaHostState by state {
     lateinit var hostView: UniqueObjectHostView
@@ -79,12 +79,12 @@
                 // be a delay until the views and the controllers are initialized, leaving us
                 // with either a blank view or the controllers not yet initialized and the
                 // measuring wrong
-                mediaDataFilter.addListener(listener)
+                mediaDataManager.addListener(listener)
                 updateViewVisibility()
             }
 
             override fun onViewDetachedFromWindow(v: View?) {
-                mediaDataFilter.removeListener(listener)
+                mediaDataManager.removeListener(listener)
             }
         })
 
@@ -113,9 +113,9 @@
 
     private fun updateViewVisibility() {
         visible = if (showsOnlyActiveMedia) {
-            mediaDataFilter.hasActiveMedia()
+            mediaDataManager.hasActiveMedia()
         } else {
-            mediaDataFilter.hasAnyMedia()
+            mediaDataManager.hasAnyMedia()
         }
         val newVisibility = if (visible) View.VISIBLE else View.GONE
         if (newVisibility != hostView.visibility) {
@@ -289,4 +289,4 @@
      * Get a copy of this view state, deepcopying all appropriate members
      */
     fun copy(): MediaHostState
-}
\ No newline at end of file
+}
diff --git a/packages/SystemUI/src/com/android/systemui/media/MediaSessionBasedFilter.kt b/packages/SystemUI/src/com/android/systemui/media/MediaSessionBasedFilter.kt
new file mode 100644
index 0000000..f01713f
--- /dev/null
+++ b/packages/SystemUI/src/com/android/systemui/media/MediaSessionBasedFilter.kt
@@ -0,0 +1,163 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.systemui.media
+
+import android.content.ComponentName
+import android.content.Context
+import android.media.session.MediaController
+import android.media.session.MediaController.PlaybackInfo
+import android.media.session.MediaSession
+import android.media.session.MediaSessionManager
+import android.util.Log
+import com.android.systemui.dagger.qualifiers.Background
+import com.android.systemui.dagger.qualifiers.Main
+import com.android.systemui.statusbar.phone.NotificationListenerWithPlugins
+import java.util.concurrent.Executor
+import javax.inject.Inject
+
+private const val TAG = "MediaSessionBasedFilter"
+
+/**
+ * Filters media loaded events for local media sessions while an app is casting.
+ *
+ * When an app is casting there can be one remote media sessions and potentially more local media
+ * sessions. In this situation, there should only be a media object for the remote session. To
+ * achieve this, update events for the local session need to be filtered.
+ */
+class MediaSessionBasedFilter @Inject constructor(
+    context: Context,
+    private val sessionManager: MediaSessionManager,
+    @Main private val foregroundExecutor: Executor,
+    @Background private val backgroundExecutor: Executor
+) : MediaDataManager.Listener {
+
+    private val listeners: MutableSet<MediaDataManager.Listener> = mutableSetOf()
+
+    // Keep track of MediaControllers for a given package to check if an app is casting and it
+    // filter loaded events for local sessions.
+    private val packageControllers: LinkedHashMap<String, MutableList<MediaController>> =
+            LinkedHashMap()
+
+    // Keep track of the key used for the session tokens. This information is used to know when
+    // dispatch a removed event so that a media object for a local session will be removed.
+    private val keyedTokens: MutableMap<String, MutableList<MediaSession.Token>> = mutableMapOf()
+
+    private val sessionListener = object : MediaSessionManager.OnActiveSessionsChangedListener {
+        override fun onActiveSessionsChanged(controllers: List<MediaController>) {
+            handleControllersChanged(controllers)
+        }
+    }
+
+    init {
+        backgroundExecutor.execute {
+            val name = ComponentName(context, NotificationListenerWithPlugins::class.java)
+            sessionManager.addOnActiveSessionsChangedListener(sessionListener, name)
+            handleControllersChanged(sessionManager.getActiveSessions(name))
+        }
+    }
+
+    /**
+     * Add a listener for filtered [MediaData] changes
+     */
+    fun addListener(listener: MediaDataManager.Listener) = listeners.add(listener)
+
+    /**
+     * Remove a listener that was registered with addListener
+     */
+    fun removeListener(listener: MediaDataManager.Listener) = listeners.remove(listener)
+
+    /**
+     * May filter loaded events by not passing them along to listeners.
+     *
+     * If an app has only one session with playback type PLAYBACK_TYPE_REMOTE, then assuming that
+     * the app is casting. Sometimes apps will send redundant updates to a local session with
+     * playback type PLAYBACK_TYPE_LOCAL. These updates should be filtered to improve the usability
+     * of the media controls.
+     */
+    override fun onMediaDataLoaded(key: String, oldKey: String?, info: MediaData) {
+        backgroundExecutor.execute {
+            val isMigration = oldKey != null && key != oldKey
+            if (isMigration) {
+                keyedTokens.remove(oldKey)?.let { removed -> keyedTokens.put(key, removed) }
+            }
+            if (info.token != null) {
+                keyedTokens.get(key)?.let {
+                    tokens ->
+                    tokens.add(info.token)
+                } ?: run {
+                    val tokens = mutableListOf(info.token)
+                    keyedTokens.put(key, tokens)
+                }
+            }
+            // Determine if an app is casting by checking if it has a session with playback type
+            // PLAYBACK_TYPE_REMOTE.
+            val remoteControllers = packageControllers.get(info.packageName)?.filter {
+                it.playbackInfo?.playbackType == PlaybackInfo.PLAYBACK_TYPE_REMOTE
+            }
+            // Limiting search to only apps with a single remote session.
+            val remote = if (remoteControllers?.size == 1) remoteControllers.firstOrNull() else null
+            if (isMigration || remote == null || remote.sessionToken == info.token) {
+                // Not filtering in this case. Passing the event along to listeners.
+                dispatchMediaDataLoaded(key, oldKey, info)
+            } else {
+                // Filtering this event because the app is casting and the loaded events is for a
+                // local session.
+                Log.d(TAG, "filtering key=$key local=${info.token} remote=${remote?.sessionToken}")
+                // If the local session uses a different notification key, then lets go a step
+                // farther and dismiss the media data so that media controls for the local session
+                // don't hang around while casting.
+                if (!keyedTokens.get(key)!!.contains(remote.sessionToken)) {
+                    dispatchMediaDataRemoved(key)
+                }
+            }
+        }
+    }
+
+    override fun onMediaDataRemoved(key: String) {
+        // Queue on background thread to ensure ordering of loaded and removed events is maintained.
+        backgroundExecutor.execute {
+            keyedTokens.remove(key)
+            dispatchMediaDataRemoved(key)
+        }
+    }
+
+    private fun dispatchMediaDataLoaded(key: String, oldKey: String?, info: MediaData) {
+        foregroundExecutor.execute {
+            listeners.toSet().forEach { it.onMediaDataLoaded(key, oldKey, info) }
+        }
+    }
+
+    private fun dispatchMediaDataRemoved(key: String) {
+        foregroundExecutor.execute {
+            listeners.toSet().forEach { it.onMediaDataRemoved(key) }
+        }
+    }
+
+    private fun handleControllersChanged(controllers: List<MediaController>) {
+        packageControllers.clear()
+        controllers.forEach {
+            controller ->
+            packageControllers.get(controller.packageName)?.let {
+                tokens ->
+                tokens.add(controller)
+            } ?: run {
+                val tokens = mutableListOf(controller)
+                packageControllers.put(controller.packageName, tokens)
+            }
+        }
+    }
+}
diff --git a/packages/SystemUI/tests/src/com/android/systemui/media/MediaDataCombineLatestTest.java b/packages/SystemUI/tests/src/com/android/systemui/media/MediaDataCombineLatestTest.java
index 2e794a4..89538ac 100644
--- a/packages/SystemUI/tests/src/com/android/systemui/media/MediaDataCombineLatestTest.java
+++ b/packages/SystemUI/tests/src/com/android/systemui/media/MediaDataCombineLatestTest.java
@@ -20,7 +20,6 @@
 
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.eq;
-import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.verify;
@@ -34,19 +33,23 @@
 import com.android.systemui.SysuiTestCase;
 
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
 
 import java.util.ArrayList;
-import java.util.Map;
 
 @SmallTest
 @RunWith(AndroidTestingRunner.class)
 @TestableLooper.RunWithLooper
 public class MediaDataCombineLatestTest extends SysuiTestCase {
 
+    @Rule public MockitoRule mockito = MockitoJUnit.rule();
+
     private static final String KEY = "TEST_KEY";
     private static final String OLD_KEY = "TEST_KEY_OLD";
     private static final String APP = "APP";
@@ -59,27 +62,14 @@
 
     private MediaDataCombineLatest mManager;
 
-    @Mock private MediaDataManager mDataSource;
-    @Mock private MediaDeviceManager mDeviceSource;
     @Mock private MediaDataManager.Listener mListener;
 
-    private MediaDataManager.Listener mDataListener;
-    private MediaDeviceManager.Listener mDeviceListener;
-
     private MediaData mMediaData;
     private MediaDeviceData mDeviceData;
 
     @Before
     public void setUp() {
-        mDataSource = mock(MediaDataManager.class);
-        mDeviceSource = mock(MediaDeviceManager.class);
-        mListener = mock(MediaDataManager.Listener.class);
-
-        mManager = new MediaDataCombineLatest(mDataSource, mDeviceSource);
-
-        mDataListener = captureDataListener();
-        mDeviceListener = captureDeviceListener();
-
+        mManager = new MediaDataCombineLatest();
         mManager.addListener(mListener);
 
         mMediaData = new MediaData(USER_ID, true, BG_COLOR, APP, null, ARTIST, TITLE, null,
@@ -91,7 +81,7 @@
     @Test
     public void eventNotEmittedWithoutDevice() {
         // WHEN data source emits an event without device data
-        mDataListener.onMediaDataLoaded(KEY, null, mMediaData);
+        mManager.onMediaDataLoaded(KEY, null, mMediaData);
         // THEN an event isn't emitted
         verify(mListener, never()).onMediaDataLoaded(eq(KEY), any(), any());
     }
@@ -99,7 +89,7 @@
     @Test
     public void eventNotEmittedWithoutMedia() {
         // WHEN device source emits an event without media data
-        mDeviceListener.onMediaDeviceChanged(KEY, null, mDeviceData);
+        mManager.onMediaDeviceChanged(KEY, null, mDeviceData);
         // THEN an event isn't emitted
         verify(mListener, never()).onMediaDataLoaded(eq(KEY), any(), any());
     }
@@ -107,9 +97,9 @@
     @Test
     public void emitEventAfterDeviceFirst() {
         // GIVEN that a device event has already been received
-        mDeviceListener.onMediaDeviceChanged(KEY, null, mDeviceData);
+        mManager.onMediaDeviceChanged(KEY, null, mDeviceData);
         // WHEN media event is received
-        mDataListener.onMediaDataLoaded(KEY, null, mMediaData);
+        mManager.onMediaDataLoaded(KEY, null, mMediaData);
         // THEN the listener receives a combined event
         ArgumentCaptor<MediaData> captor = ArgumentCaptor.forClass(MediaData.class);
         verify(mListener).onMediaDataLoaded(eq(KEY), any(), captor.capture());
@@ -119,9 +109,9 @@
     @Test
     public void emitEventAfterMediaFirst() {
         // GIVEN that media event has already been received
-        mDataListener.onMediaDataLoaded(KEY, null, mMediaData);
+        mManager.onMediaDataLoaded(KEY, null, mMediaData);
         // WHEN device event is received
-        mDeviceListener.onMediaDeviceChanged(KEY, null, mDeviceData);
+        mManager.onMediaDeviceChanged(KEY, null, mDeviceData);
         // THEN the listener receives a combined event
         ArgumentCaptor<MediaData> captor = ArgumentCaptor.forClass(MediaData.class);
         verify(mListener).onMediaDataLoaded(eq(KEY), any(), captor.capture());
@@ -131,11 +121,11 @@
     @Test
     public void migrateKeyMediaFirst() {
         // GIVEN that media and device info has already been received
-        mDataListener.onMediaDataLoaded(OLD_KEY, null, mMediaData);
-        mDeviceListener.onMediaDeviceChanged(OLD_KEY, null, mDeviceData);
+        mManager.onMediaDataLoaded(OLD_KEY, null, mMediaData);
+        mManager.onMediaDeviceChanged(OLD_KEY, null, mDeviceData);
         reset(mListener);
         // WHEN a key migration event is received
-        mDataListener.onMediaDataLoaded(KEY, OLD_KEY, mMediaData);
+        mManager.onMediaDataLoaded(KEY, OLD_KEY, mMediaData);
         // THEN the listener receives a combined event
         ArgumentCaptor<MediaData> captor = ArgumentCaptor.forClass(MediaData.class);
         verify(mListener).onMediaDataLoaded(eq(KEY), eq(OLD_KEY), captor.capture());
@@ -145,11 +135,11 @@
     @Test
     public void migrateKeyDeviceFirst() {
         // GIVEN that media and device info has already been received
-        mDataListener.onMediaDataLoaded(OLD_KEY, null, mMediaData);
-        mDeviceListener.onMediaDeviceChanged(OLD_KEY, null, mDeviceData);
+        mManager.onMediaDataLoaded(OLD_KEY, null, mMediaData);
+        mManager.onMediaDeviceChanged(OLD_KEY, null, mDeviceData);
         reset(mListener);
         // WHEN a key migration event is received
-        mDeviceListener.onMediaDeviceChanged(KEY, OLD_KEY, mDeviceData);
+        mManager.onMediaDeviceChanged(KEY, OLD_KEY, mDeviceData);
         // THEN the listener receives a combined event
         ArgumentCaptor<MediaData> captor = ArgumentCaptor.forClass(MediaData.class);
         verify(mListener).onMediaDataLoaded(eq(KEY), eq(OLD_KEY), captor.capture());
@@ -159,12 +149,12 @@
     @Test
     public void migrateKeyMediaAfter() {
         // GIVEN that media and device info has already been received
-        mDataListener.onMediaDataLoaded(OLD_KEY, null, mMediaData);
-        mDeviceListener.onMediaDeviceChanged(OLD_KEY, null, mDeviceData);
-        mDeviceListener.onMediaDeviceChanged(KEY, OLD_KEY, mDeviceData);
+        mManager.onMediaDataLoaded(OLD_KEY, null, mMediaData);
+        mManager.onMediaDeviceChanged(OLD_KEY, null, mDeviceData);
+        mManager.onMediaDeviceChanged(KEY, OLD_KEY, mDeviceData);
         reset(mListener);
         // WHEN a second key migration event is received for media
-        mDataListener.onMediaDataLoaded(KEY, OLD_KEY, mMediaData);
+        mManager.onMediaDataLoaded(KEY, OLD_KEY, mMediaData);
         // THEN the key has already been migrated
         ArgumentCaptor<MediaData> captor = ArgumentCaptor.forClass(MediaData.class);
         verify(mListener).onMediaDataLoaded(eq(KEY), eq(KEY), captor.capture());
@@ -174,12 +164,12 @@
     @Test
     public void migrateKeyDeviceAfter() {
         // GIVEN that media and device info has already been received
-        mDataListener.onMediaDataLoaded(OLD_KEY, null, mMediaData);
-        mDeviceListener.onMediaDeviceChanged(OLD_KEY, null, mDeviceData);
-        mDataListener.onMediaDataLoaded(KEY, OLD_KEY, mMediaData);
+        mManager.onMediaDataLoaded(OLD_KEY, null, mMediaData);
+        mManager.onMediaDeviceChanged(OLD_KEY, null, mDeviceData);
+        mManager.onMediaDataLoaded(KEY, OLD_KEY, mMediaData);
         reset(mListener);
         // WHEN a second key migration event is received for the device
-        mDeviceListener.onMediaDeviceChanged(KEY, OLD_KEY, mDeviceData);
+        mManager.onMediaDeviceChanged(KEY, OLD_KEY, mDeviceData);
         // THEN the key has already be migrated
         ArgumentCaptor<MediaData> captor = ArgumentCaptor.forClass(MediaData.class);
         verify(mListener).onMediaDataLoaded(eq(KEY), eq(KEY), captor.capture());
@@ -189,60 +179,34 @@
     @Test
     public void mediaDataRemoved() {
         // WHEN media data is removed without first receiving device or data
-        mDataListener.onMediaDataRemoved(KEY);
+        mManager.onMediaDataRemoved(KEY);
         // THEN a removed event isn't emitted
         verify(mListener, never()).onMediaDataRemoved(eq(KEY));
     }
 
     @Test
     public void mediaDataRemovedAfterMediaEvent() {
-        mDataListener.onMediaDataLoaded(KEY, null, mMediaData);
-        mDataListener.onMediaDataRemoved(KEY);
+        mManager.onMediaDataLoaded(KEY, null, mMediaData);
+        mManager.onMediaDataRemoved(KEY);
         verify(mListener).onMediaDataRemoved(eq(KEY));
     }
 
     @Test
     public void mediaDataRemovedAfterDeviceEvent() {
-        mDeviceListener.onMediaDeviceChanged(KEY, null, mDeviceData);
-        mDataListener.onMediaDataRemoved(KEY);
+        mManager.onMediaDeviceChanged(KEY, null, mDeviceData);
+        mManager.onMediaDataRemoved(KEY);
         verify(mListener).onMediaDataRemoved(eq(KEY));
     }
 
     @Test
     public void mediaDataKeyUpdated() {
         // GIVEN that device and media events have already been received
-        mDataListener.onMediaDataLoaded(KEY, null, mMediaData);
-        mDeviceListener.onMediaDeviceChanged(KEY, null, mDeviceData);
+        mManager.onMediaDataLoaded(KEY, null, mMediaData);
+        mManager.onMediaDeviceChanged(KEY, null, mDeviceData);
         // WHEN the key is changed
-        mDataListener.onMediaDataLoaded("NEW_KEY", KEY, mMediaData);
+        mManager.onMediaDataLoaded("NEW_KEY", KEY, mMediaData);
         // THEN the listener gets a load event with the correct keys
         ArgumentCaptor<MediaData> captor = ArgumentCaptor.forClass(MediaData.class);
         verify(mListener).onMediaDataLoaded(eq("NEW_KEY"), any(), captor.capture());
     }
-
-    @Test
-    public void getDataIncludesDevice() {
-        // GIVEN that device and media events have been received
-        mDeviceListener.onMediaDeviceChanged(KEY, null, mDeviceData);
-        mDataListener.onMediaDataLoaded(KEY, null, mMediaData);
-
-        // THEN the result of getData includes device info
-        Map<String, MediaData> results = mManager.getData();
-        assertThat(results.get(KEY)).isNotNull();
-        assertThat(results.get(KEY).getDevice()).isEqualTo(mDeviceData);
-    }
-
-    private MediaDataManager.Listener captureDataListener() {
-        ArgumentCaptor<MediaDataManager.Listener> captor = ArgumentCaptor.forClass(
-                MediaDataManager.Listener.class);
-        verify(mDataSource).addListener(captor.capture());
-        return captor.getValue();
-    }
-
-    private MediaDeviceManager.Listener captureDeviceListener() {
-        ArgumentCaptor<MediaDeviceManager.Listener> captor = ArgumentCaptor.forClass(
-                MediaDeviceManager.Listener.class);
-        verify(mDeviceSource).addListener(captor.capture());
-        return captor.getValue();
-    }
 }
diff --git a/packages/SystemUI/tests/src/com/android/systemui/media/MediaDataFilterTest.kt b/packages/SystemUI/tests/src/com/android/systemui/media/MediaDataFilterTest.kt
index afb64a7..36b6527 100644
--- a/packages/SystemUI/tests/src/com/android/systemui/media/MediaDataFilterTest.kt
+++ b/packages/SystemUI/tests/src/com/android/systemui/media/MediaDataFilterTest.kt
@@ -32,6 +32,7 @@
 import org.mockito.Mockito
 import org.mockito.Mockito.`when`
 import org.mockito.Mockito.never
+import org.mockito.Mockito.reset
 import org.mockito.Mockito.verify
 import org.mockito.MockitoAnnotations
 import java.util.concurrent.Executor
@@ -56,8 +57,6 @@
 class MediaDataFilterTest : SysuiTestCase() {
 
     @Mock
-    private lateinit var combineLatest: MediaDataCombineLatest
-    @Mock
     private lateinit var listener: MediaDataManager.Listener
     @Mock
     private lateinit var broadcastDispatcher: BroadcastDispatcher
@@ -78,8 +77,9 @@
     @Before
     fun setup() {
         MockitoAnnotations.initMocks(this)
-        mediaDataFilter = MediaDataFilter(combineLatest, broadcastDispatcher, mediaResumeListener,
-            mediaDataManager, lockscreenUserManager, executor)
+        mediaDataFilter = MediaDataFilter(broadcastDispatcher, mediaResumeListener,
+                lockscreenUserManager, executor)
+        mediaDataFilter.mediaDataManager = mediaDataManager
         mediaDataFilter.addListener(listener)
 
         // Start all tests as main user
@@ -152,8 +152,9 @@
     @Test
     fun testOnUserSwitched_addsNewUserControls() {
         // GIVEN that we had some media for both users
-        val dataMap = mapOf(KEY to dataMain, KEY_ALT to dataGuest)
-        `when`(combineLatest.getData()).thenReturn(dataMap)
+        mediaDataFilter.onMediaDataLoaded(KEY, null, dataMain)
+        mediaDataFilter.onMediaDataLoaded(KEY_ALT, null, dataGuest)
+        reset(listener)
 
         // and we switch to guest user
         setUser(USER_GUEST)
@@ -213,4 +214,4 @@
 
         verify(mediaDataManager).setTimedOut(eq(KEY), eq(true))
     }
-}
\ No newline at end of file
+}
diff --git a/packages/SystemUI/tests/src/com/android/systemui/media/MediaDataManagerTest.kt b/packages/SystemUI/tests/src/com/android/systemui/media/MediaDataManagerTest.kt
index 457d559..84c1bf9 100644
--- a/packages/SystemUI/tests/src/com/android/systemui/media/MediaDataManagerTest.kt
+++ b/packages/SystemUI/tests/src/com/android/systemui/media/MediaDataManagerTest.kt
@@ -16,6 +16,7 @@
 import com.android.systemui.plugins.ActivityStarter
 import com.android.systemui.statusbar.SbnBuilder
 import com.android.systemui.util.concurrency.FakeExecutor
+import com.android.systemui.util.mockito.capture
 import com.android.systemui.util.mockito.eq
 import com.android.systemui.util.time.FakeSystemClock
 import com.google.common.truth.Truth.assertThat
@@ -24,9 +25,13 @@
 import org.junit.Rule
 import org.junit.Test
 import org.junit.runner.RunWith
+import org.mockito.ArgumentCaptor
+import org.mockito.Captor
 import org.mockito.Mock
 import org.mockito.Mockito
 import org.mockito.Mockito.mock
+import org.mockito.Mockito.never
+import org.mockito.Mockito.reset
 import org.mockito.Mockito.verify
 import org.mockito.junit.MockitoJUnit
 import org.mockito.Mockito.`when` as whenever
@@ -48,6 +53,7 @@
 @RunWith(AndroidTestingRunner::class)
 class MediaDataManagerTest : SysuiTestCase() {
 
+    @JvmField @Rule val mockito = MockitoJUnit.rule()
     @Mock lateinit var mediaControllerFactory: MediaControllerFactory
     @Mock lateinit var controller: MediaController
     lateinit var session: MediaSession
@@ -58,20 +64,38 @@
     @Mock lateinit var broadcastDispatcher: BroadcastDispatcher
     @Mock lateinit var mediaTimeoutListener: MediaTimeoutListener
     @Mock lateinit var mediaResumeListener: MediaResumeListener
+    @Mock lateinit var mediaSessionBasedFilter: MediaSessionBasedFilter
+    @Mock lateinit var mediaDeviceManager: MediaDeviceManager
+    @Mock lateinit var mediaDataCombineLatest: MediaDataCombineLatest
+    @Mock lateinit var mediaDataFilter: MediaDataFilter
+    @Mock lateinit var listener: MediaDataManager.Listener
     @Mock lateinit var pendingIntent: PendingIntent
     @Mock lateinit var activityStarter: ActivityStarter
-    @JvmField @Rule val mockito = MockitoJUnit.rule()
     lateinit var mediaDataManager: MediaDataManager
     lateinit var mediaNotification: StatusBarNotification
+    @Captor lateinit var mediaDataCaptor: ArgumentCaptor<MediaData>
 
     @Before
     fun setup() {
         foregroundExecutor = FakeExecutor(FakeSystemClock())
         backgroundExecutor = FakeExecutor(FakeSystemClock())
-        mediaDataManager = MediaDataManager(context, backgroundExecutor, foregroundExecutor,
-                mediaControllerFactory, broadcastDispatcher, dumpManager,
-                mediaTimeoutListener, mediaResumeListener, activityStarter,
-                useMediaResumption = true, useQsMediaPlayer = true)
+        mediaDataManager = MediaDataManager(
+            context = context,
+            backgroundExecutor = backgroundExecutor,
+            foregroundExecutor = foregroundExecutor,
+            mediaControllerFactory = mediaControllerFactory,
+            broadcastDispatcher = broadcastDispatcher,
+            dumpManager = dumpManager,
+            mediaTimeoutListener = mediaTimeoutListener,
+            mediaResumeListener = mediaResumeListener,
+            mediaSessionBasedFilter = mediaSessionBasedFilter,
+            mediaDeviceManager = mediaDeviceManager,
+            mediaDataCombineLatest = mediaDataCombineLatest,
+            mediaDataFilter = mediaDataFilter,
+            activityStarter = activityStarter,
+            useMediaResumption = true,
+            useQsMediaPlayer = true
+        )
         session = MediaSession(context, "MediaDataManagerTestSession")
         mediaNotification = SbnBuilder().run {
             setPkg(PACKAGE_NAME)
@@ -86,6 +110,12 @@
             putString(MediaMetadata.METADATA_KEY_TITLE, SESSION_TITLE)
         }
         whenever(mediaControllerFactory.create(eq(session.sessionToken))).thenReturn(controller)
+
+        // This is an ugly hack for now. The mediaSessionBasedFilter is one of the internal
+        // listeners in the internal processing pipeline. It receives events, but ince it is a
+        // mock, it doesn't pass those events along the chain to the external listeners. So, just
+        // treat mediaSessionBasedFilter as a listener for testing.
+        listener = mediaSessionBasedFilter
     }
 
     @After
@@ -115,8 +145,6 @@
 
     @Test
     fun testOnMetaDataLoaded_callsListener() {
-        val listener = mock(MediaDataManager.Listener::class.java)
-        mediaDataManager.addListener(listener)
         mediaDataManager.onNotificationAdded(KEY, mediaNotification)
         mediaDataManager.onMediaDataLoaded(KEY, oldKey = null, data = mock(MediaData::class.java))
         verify(listener).onMediaDataLoaded(eq(KEY), eq(null), anyObject())
@@ -124,90 +152,81 @@
 
     @Test
     fun testOnMetaDataLoaded_conservesActiveFlag() {
-        val listener = TestListener()
         whenever(mediaControllerFactory.create(anyObject())).thenReturn(controller)
         whenever(controller.metadata).thenReturn(metadataBuilder.build())
         mediaDataManager.addListener(listener)
         mediaDataManager.onNotificationAdded(KEY, mediaNotification)
         assertThat(backgroundExecutor.runAllReady()).isEqualTo(1)
         assertThat(foregroundExecutor.runAllReady()).isEqualTo(1)
-        assertThat(listener.data!!.active).isTrue()
+        verify(listener).onMediaDataLoaded(eq(KEY), eq(null), capture(mediaDataCaptor))
+        assertThat(mediaDataCaptor.value!!.active).isTrue()
     }
 
     @Test
     fun testOnNotificationRemoved_callsListener() {
-        val listener = mock(MediaDataManager.Listener::class.java)
-        mediaDataManager.addListener(listener)
         mediaDataManager.onNotificationAdded(KEY, mediaNotification)
         mediaDataManager.onMediaDataLoaded(KEY, oldKey = null, data = mock(MediaData::class.java))
         mediaDataManager.onNotificationRemoved(KEY)
-
         verify(listener).onMediaDataRemoved(eq(KEY))
     }
 
     @Test
     fun testOnNotificationRemoved_withResumption() {
         // GIVEN that the manager has a notification with a resume action
-        val listener = TestListener()
-        mediaDataManager.addListener(listener)
         whenever(controller.metadata).thenReturn(metadataBuilder.build())
         mediaDataManager.onNotificationAdded(KEY, mediaNotification)
         assertThat(backgroundExecutor.runAllReady()).isEqualTo(1)
         assertThat(foregroundExecutor.runAllReady()).isEqualTo(1)
-        val data = listener.data!!
+        verify(listener).onMediaDataLoaded(eq(KEY), eq(null), capture(mediaDataCaptor))
+        val data = mediaDataCaptor.value
         assertThat(data.resumption).isFalse()
         mediaDataManager.onMediaDataLoaded(KEY, null, data.copy(resumeAction = Runnable {}))
         // WHEN the notification is removed
         mediaDataManager.onNotificationRemoved(KEY)
         // THEN the media data indicates that it is for resumption
-        assertThat(listener.data!!.resumption).isTrue()
-        // AND the new key is the package name
-        assertThat(listener.key!!).isEqualTo(PACKAGE_NAME)
-        assertThat(listener.oldKey!!).isEqualTo(KEY)
-        assertThat(listener.removedKey).isNull()
+        verify(listener).onMediaDataLoaded(eq(PACKAGE_NAME), eq(KEY), capture(mediaDataCaptor))
+        assertThat(mediaDataCaptor.value.resumption).isTrue()
     }
 
     @Test
     fun testOnNotificationRemoved_twoWithResumption() {
         // GIVEN that the manager has two notifications with resume actions
-        val listener = TestListener()
-        mediaDataManager.addListener(listener)
         whenever(controller.metadata).thenReturn(metadataBuilder.build())
         mediaDataManager.onNotificationAdded(KEY, mediaNotification)
         mediaDataManager.onNotificationAdded(KEY_2, mediaNotification)
         assertThat(backgroundExecutor.runAllReady()).isEqualTo(2)
         assertThat(foregroundExecutor.runAllReady()).isEqualTo(2)
-        val data = listener.data!!
+        verify(listener).onMediaDataLoaded(eq(KEY), eq(null), capture(mediaDataCaptor))
+        val data = mediaDataCaptor.value
         assertThat(data.resumption).isFalse()
         val resumableData = data.copy(resumeAction = Runnable {})
         mediaDataManager.onMediaDataLoaded(KEY, null, resumableData)
         mediaDataManager.onMediaDataLoaded(KEY_2, null, resumableData)
+        reset(listener)
         // WHEN the first is removed
         mediaDataManager.onNotificationRemoved(KEY)
         // THEN the data is for resumption and the key is migrated to the package name
-        assertThat(listener.data!!.resumption).isTrue()
-        assertThat(listener.key!!).isEqualTo(PACKAGE_NAME)
-        assertThat(listener.oldKey!!).isEqualTo(KEY)
-        assertThat(listener.removedKey).isNull()
+        verify(listener).onMediaDataLoaded(eq(PACKAGE_NAME), eq(KEY), capture(mediaDataCaptor))
+        assertThat(mediaDataCaptor.value.resumption).isTrue()
+        verify(listener, never()).onMediaDataRemoved(eq(KEY))
         // WHEN the second is removed
         mediaDataManager.onNotificationRemoved(KEY_2)
         // THEN the data is for resumption and the second key is removed
-        assertThat(listener.data!!.resumption).isTrue()
-        assertThat(listener.key!!).isEqualTo(PACKAGE_NAME)
-        assertThat(listener.oldKey!!).isEqualTo(PACKAGE_NAME)
-        assertThat(listener.removedKey!!).isEqualTo(KEY_2)
+        verify(listener).onMediaDataLoaded(eq(PACKAGE_NAME), eq(PACKAGE_NAME),
+                capture(mediaDataCaptor))
+        assertThat(mediaDataCaptor.value.resumption).isTrue()
+        verify(listener).onMediaDataRemoved(eq(KEY_2))
     }
 
     @Test
     fun testAppBlockedFromResumption() {
         // GIVEN that the manager has a notification with a resume action
-        val listener = TestListener()
-        mediaDataManager.addListener(listener)
         whenever(controller.metadata).thenReturn(metadataBuilder.build())
         mediaDataManager.onNotificationAdded(KEY, mediaNotification)
         assertThat(backgroundExecutor.runAllReady()).isEqualTo(1)
         assertThat(foregroundExecutor.runAllReady()).isEqualTo(1)
-        val data = listener.data!!
+        verify(listener).onMediaDataLoaded(eq(KEY), eq(null), capture(mediaDataCaptor))
+        val data = mediaDataCaptor.value
         assertThat(data.resumption).isFalse()
         mediaDataManager.onMediaDataLoaded(KEY, null, data.copy(resumeAction = Runnable {}))
 
@@ -219,7 +238,7 @@
         mediaDataManager.onNotificationRemoved(KEY)
 
         // THEN the media data is removed
-        assertThat(listener.removedKey!!).isEqualTo(KEY)
+        verify(listener).onMediaDataRemoved(eq(KEY))
     }
 
     @Test
@@ -229,13 +248,12 @@
         mediaDataManager.appsBlockedFromResume = blocked
 
         // and GIVEN that the manager has a notification from that app with a resume action
-        val listener = TestListener()
-        mediaDataManager.addListener(listener)
         whenever(controller.metadata).thenReturn(metadataBuilder.build())
         mediaDataManager.onNotificationAdded(KEY, mediaNotification)
         assertThat(backgroundExecutor.runAllReady()).isEqualTo(1)
         assertThat(foregroundExecutor.runAllReady()).isEqualTo(1)
-        val data = listener.data!!
+        verify(listener).onMediaDataLoaded(eq(KEY), eq(null), capture(mediaDataCaptor))
+        val data = mediaDataCaptor.value
         assertThat(data.resumption).isFalse()
         mediaDataManager.onMediaDataLoaded(KEY, null, data.copy(resumeAction = Runnable {}))
 
@@ -246,14 +264,11 @@
         mediaDataManager.onNotificationRemoved(KEY)
 
         // THEN the entry will stay as a resume control
-        assertThat(listener.key!!).isEqualTo(PACKAGE_NAME)
-        assertThat(listener.oldKey!!).isEqualTo(KEY)
+        verify(listener).onMediaDataLoaded(eq(PACKAGE_NAME), eq(KEY), capture(mediaDataCaptor))
     }
 
     @Test
     fun testAddResumptionControls() {
-        val listener = TestListener()
-        mediaDataManager.addListener(listener)
         // WHEN resumption controls are added`
         val desc = MediaDescription.Builder().run {
             setTitle(SESSION_TITLE)
@@ -264,7 +279,8 @@
         assertThat(backgroundExecutor.runAllReady()).isEqualTo(1)
         assertThat(foregroundExecutor.runAllReady()).isEqualTo(1)
         // THEN the media data indicates that it is for resumption
-        val data = listener.data!!
+        verify(listener).onMediaDataLoaded(eq(PACKAGE_NAME), eq(null), capture(mediaDataCaptor))
+        val data = mediaDataCaptor.value
         assertThat(data.resumption).isTrue()
         assertThat(data.song).isEqualTo(SESSION_TITLE)
         assertThat(data.app).isEqualTo(APP_NAME)
@@ -273,8 +289,6 @@
 
     @Test
     fun testDismissMedia_listenerCalled() {
-        val listener = mock(MediaDataManager.Listener::class.java)
-        mediaDataManager.addListener(listener)
         mediaDataManager.onNotificationAdded(KEY, mediaNotification)
         mediaDataManager.onMediaDataLoaded(KEY, oldKey = null, data = mock(MediaData::class.java))
         mediaDataManager.dismissMediaData(KEY, 0L)
@@ -284,26 +298,4 @@
 
         verify(listener).onMediaDataRemoved(eq(KEY))
     }
-
-    /**
-     * Simple implementation of [MediaDataManager.Listener] for the test.
-     *
-     * Giving up on trying to get a mock Listener and ArgumentCaptor to work.
-     */
-    private class TestListener : MediaDataManager.Listener {
-        var data: MediaData? = null
-        var key: String? = null
-        var oldKey: String? = null
-        var removedKey: String? = null
-
-        override fun onMediaDataLoaded(key: String, oldKey: String?, data: MediaData) {
-            this.key = key
-            this.oldKey = oldKey
-            this.data = data
-        }
-
-        override fun onMediaDataRemoved(key: String) {
-            removedKey = key
-        }
-    }
 }
diff --git a/packages/SystemUI/tests/src/com/android/systemui/media/MediaDeviceManagerTest.kt b/packages/SystemUI/tests/src/com/android/systemui/media/MediaDeviceManagerTest.kt
index 7bc15dd..fdb432c 100644
--- a/packages/SystemUI/tests/src/com/android/systemui/media/MediaDeviceManagerTest.kt
+++ b/packages/SystemUI/tests/src/com/android/systemui/media/MediaDeviceManagerTest.kt
@@ -68,7 +68,6 @@
 public class MediaDeviceManagerTest : SysuiTestCase() {
 
     private lateinit var manager: MediaDeviceManager
-    @Mock private lateinit var mediaDataManager: MediaDataManager
     @Mock private lateinit var lmmFactory: LocalMediaManagerFactory
     @Mock private lateinit var lmm: LocalMediaManager
     @Mock private lateinit var mr2: MediaRouter2Manager
@@ -91,7 +90,7 @@
         fakeFgExecutor = FakeExecutor(FakeSystemClock())
         fakeBgExecutor = FakeExecutor(FakeSystemClock())
         manager = MediaDeviceManager(context, lmmFactory, mr2, fakeFgExecutor, fakeBgExecutor,
-                mediaDataManager, dumpster)
+                dumpster)
         manager.addListener(listener)
 
         // Configure mocks.
diff --git a/packages/SystemUI/tests/src/com/android/systemui/media/MediaSessionBasedFilterTest.kt b/packages/SystemUI/tests/src/com/android/systemui/media/MediaSessionBasedFilterTest.kt
new file mode 100644
index 0000000..887cc77
--- /dev/null
+++ b/packages/SystemUI/tests/src/com/android/systemui/media/MediaSessionBasedFilterTest.kt
@@ -0,0 +1,383 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.systemui.media
+
+import android.graphics.Color
+import android.media.session.MediaController
+import android.media.session.MediaController.PlaybackInfo
+import android.media.session.MediaSession
+import android.media.session.MediaSessionManager
+import android.testing.AndroidTestingRunner
+import android.testing.TestableLooper
+import androidx.test.filters.SmallTest
+
+import com.android.systemui.SysuiTestCase
+import com.android.systemui.util.concurrency.FakeExecutor
+import com.android.systemui.util.mockito.eq
+import com.android.systemui.util.time.FakeSystemClock
+
+import org.junit.After
+import org.junit.Before
+import org.junit.Rule
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.ArgumentCaptor
+import org.mockito.Mock
+import org.mockito.Mockito
+import org.mockito.Mockito.any
+import org.mockito.Mockito.never
+import org.mockito.Mockito.reset
+import org.mockito.Mockito.verify
+import org.mockito.junit.MockitoJUnit
+import org.mockito.Mockito.`when` as whenever
+
+private const val PACKAGE = "PKG"
+private const val KEY = "TEST_KEY"
+private const val NOTIF_KEY = "TEST_KEY"
+private const val SESSION_ARTIST = "SESSION_ARTIST"
+private const val SESSION_TITLE = "SESSION_TITLE"
+private const val APP_NAME = "APP_NAME"
+private const val USER_ID = 0
+
+private val info = MediaData(
+    userId = USER_ID,
+    initialized = true,
+    backgroundColor = Color.DKGRAY,
+    app = APP_NAME,
+    appIcon = null,
+    artist = SESSION_ARTIST,
+    song = SESSION_TITLE,
+    artwork = null,
+    actions = emptyList(),
+    actionsToShowInCompact = emptyList(),
+    packageName = PACKAGE,
+    token = null,
+    clickIntent = null,
+    device = null,
+    active = true,
+    resumeAction = null,
+    resumption = false,
+    notificationKey = NOTIF_KEY,
+    hasCheckedForResume = false
+)
+
+private fun <T> eq(value: T): T = Mockito.eq(value) ?: value
+
+@SmallTest
+@RunWith(AndroidTestingRunner::class)
+@TestableLooper.RunWithLooper
+public class MediaSessionBasedFilterTest : SysuiTestCase() {
+
+    @JvmField @Rule val mockito = MockitoJUnit.rule()
+
+    // Unit to be tested
+    private lateinit var filter: MediaSessionBasedFilter
+
+    private lateinit var sessionListener: MediaSessionManager.OnActiveSessionsChangedListener
+    @Mock private lateinit var mediaListener: MediaDataManager.Listener
+
+    // MediaSessionBasedFilter dependencies
+    @Mock private lateinit var mediaSessionManager: MediaSessionManager
+    private lateinit var fgExecutor: FakeExecutor
+    private lateinit var bgExecutor: FakeExecutor
+
+    @Mock private lateinit var controller1: MediaController
+    @Mock private lateinit var controller2: MediaController
+    @Mock private lateinit var controller3: MediaController
+    @Mock private lateinit var controller4: MediaController
+
+    private lateinit var token1: MediaSession.Token
+    private lateinit var token2: MediaSession.Token
+    private lateinit var token3: MediaSession.Token
+    private lateinit var token4: MediaSession.Token
+
+    @Mock private lateinit var remotePlaybackInfo: PlaybackInfo
+    @Mock private lateinit var localPlaybackInfo: PlaybackInfo
+
+    private lateinit var session1: MediaSession
+    private lateinit var session2: MediaSession
+    private lateinit var session3: MediaSession
+    private lateinit var session4: MediaSession
+
+    private lateinit var mediaData1: MediaData
+    private lateinit var mediaData2: MediaData
+    private lateinit var mediaData3: MediaData
+    private lateinit var mediaData4: MediaData
+
+    @Before
+    fun setUp() {
+        fgExecutor = FakeExecutor(FakeSystemClock())
+        bgExecutor = FakeExecutor(FakeSystemClock())
+        filter = MediaSessionBasedFilter(context, mediaSessionManager, fgExecutor, bgExecutor)
+
+        // Configure mocks.
+        whenever(mediaSessionManager.getActiveSessions(any())).thenReturn(emptyList())
+
+        session1 = MediaSession(context, "MediaSessionBasedFilter1")
+        session2 = MediaSession(context, "MediaSessionBasedFilter2")
+        session3 = MediaSession(context, "MediaSessionBasedFilter3")
+        session4 = MediaSession(context, "MediaSessionBasedFilter4")
+
+        token1 = session1.sessionToken
+        token2 = session2.sessionToken
+        token3 = session3.sessionToken
+        token4 = session4.sessionToken
+
+        whenever(controller1.getSessionToken()).thenReturn(token1)
+        whenever(controller2.getSessionToken()).thenReturn(token2)
+        whenever(controller3.getSessionToken()).thenReturn(token3)
+        whenever(controller4.getSessionToken()).thenReturn(token4)
+
+        whenever(controller1.getPackageName()).thenReturn(PACKAGE)
+        whenever(controller2.getPackageName()).thenReturn(PACKAGE)
+        whenever(controller3.getPackageName()).thenReturn(PACKAGE)
+        whenever(controller4.getPackageName()).thenReturn(PACKAGE)
+
+        mediaData1 = info.copy(token = token1)
+        mediaData2 = info.copy(token = token2)
+        mediaData3 = info.copy(token = token3)
+        mediaData4 = info.copy(token = token4)
+
+        whenever(remotePlaybackInfo.getPlaybackType()).thenReturn(PlaybackInfo.PLAYBACK_TYPE_REMOTE)
+        whenever(localPlaybackInfo.getPlaybackType()).thenReturn(PlaybackInfo.PLAYBACK_TYPE_LOCAL)
+
+        whenever(controller1.getPlaybackInfo()).thenReturn(localPlaybackInfo)
+        whenever(controller2.getPlaybackInfo()).thenReturn(localPlaybackInfo)
+        whenever(controller3.getPlaybackInfo()).thenReturn(localPlaybackInfo)
+        whenever(controller4.getPlaybackInfo()).thenReturn(localPlaybackInfo)
+
+        // Capture listener
+        bgExecutor.runAllReady()
+        val listenerCaptor = ArgumentCaptor.forClass(
+                MediaSessionManager.OnActiveSessionsChangedListener::class.java)
+        verify(mediaSessionManager).addOnActiveSessionsChangedListener(
+                listenerCaptor.capture(), any())
+        sessionListener = listenerCaptor.value
+
+        filter.addListener(mediaListener)
+    }
+
+    @After
+    fun tearDown() {
+        session1.release()
+        session2.release()
+        session3.release()
+        session4.release()
+    }
+
+    @Test
+    fun noMediaSession_loadedEventNotFiltered() {
+        filter.onMediaDataLoaded(KEY, null, mediaData1)
+        bgExecutor.runAllReady()
+        fgExecutor.runAllReady()
+        verify(mediaListener).onMediaDataLoaded(eq(KEY), eq(null), eq(mediaData1))
+    }
+
+    @Test
+    fun noMediaSession_removedEventNotFiltered() {
+        filter.onMediaDataRemoved(KEY)
+        bgExecutor.runAllReady()
+        fgExecutor.runAllReady()
+        verify(mediaListener).onMediaDataRemoved(eq(KEY))
+    }
+
+    @Test
+    fun matchingMediaSession_loadedEventNotFiltered() {
+        // GIVEN an active session
+        val controllers = listOf(controller1)
+        whenever(mediaSessionManager.getActiveSessions(any())).thenReturn(controllers)
+        sessionListener.onActiveSessionsChanged(controllers)
+        // WHEN a loaded event is received that matches the session
+        filter.onMediaDataLoaded(KEY, null, mediaData1)
+        bgExecutor.runAllReady()
+        fgExecutor.runAllReady()
+        // THEN the event is not filtered
+        verify(mediaListener).onMediaDataLoaded(eq(KEY), eq(null), eq(mediaData1))
+    }
+
+    @Test
+    fun matchingMediaSession_removedEventNotFiltered() {
+        // GIVEN an active session
+        val controllers = listOf(controller1)
+        whenever(mediaSessionManager.getActiveSessions(any())).thenReturn(controllers)
+        sessionListener.onActiveSessionsChanged(controllers)
+        // WHEN a removed event is received
+        filter.onMediaDataRemoved(KEY)
+        bgExecutor.runAllReady()
+        fgExecutor.runAllReady()
+        // THEN the event is not filtered
+        verify(mediaListener).onMediaDataRemoved(eq(KEY))
+    }
+
+    @Test
+    fun remoteSession_loadedEventNotFiltered() {
+        // GIVEN a remove session
+        whenever(controller1.getPlaybackInfo()).thenReturn(remotePlaybackInfo)
+        val controllers = listOf(controller1)
+        whenever(mediaSessionManager.getActiveSessions(any())).thenReturn(controllers)
+        sessionListener.onActiveSessionsChanged(controllers)
+        // WHEN a loaded event is received that matche the session
+        filter.onMediaDataLoaded(KEY, null, mediaData1)
+        bgExecutor.runAllReady()
+        fgExecutor.runAllReady()
+        // THEN the event is not filtered
+        verify(mediaListener).onMediaDataLoaded(eq(KEY), eq(null), eq(mediaData1))
+    }
+
+    @Test
+    fun remoteAndLocalSessions_localLoadedEventFiltered() {
+        // GIVEN remote and local sessions
+        whenever(controller1.getPlaybackInfo()).thenReturn(remotePlaybackInfo)
+        val controllers = listOf(controller1, controller2)
+        whenever(mediaSessionManager.getActiveSessions(any())).thenReturn(controllers)
+        sessionListener.onActiveSessionsChanged(controllers)
+        // WHEN a loaded event is received that matches the remote session
+        filter.onMediaDataLoaded(KEY, null, mediaData1)
+        bgExecutor.runAllReady()
+        fgExecutor.runAllReady()
+        // THEN the event is not filtered
+        verify(mediaListener).onMediaDataLoaded(eq(KEY), eq(null), eq(mediaData1))
+        // WHEN a loaded event is received that matches the local session
+        filter.onMediaDataLoaded(KEY, null, mediaData2)
+        bgExecutor.runAllReady()
+        fgExecutor.runAllReady()
+        // THEN the event is filtered
+        verify(mediaListener, never()).onMediaDataLoaded(eq(KEY), eq(null), eq(mediaData2))
+    }
+
+    @Test
+    fun remoteAndLocalHaveDifferentKeys_localLoadedEventFiltered() {
+        // GIVEN remote and local sessions
+        val key1 = "KEY_1"
+        val key2 = "KEY_2"
+        whenever(controller1.getPlaybackInfo()).thenReturn(remotePlaybackInfo)
+        val controllers = listOf(controller1, controller2)
+        whenever(mediaSessionManager.getActiveSessions(any())).thenReturn(controllers)
+        sessionListener.onActiveSessionsChanged(controllers)
+        // WHEN a loaded event is received that matches the remote session
+        filter.onMediaDataLoaded(key1, null, mediaData1)
+        bgExecutor.runAllReady()
+        fgExecutor.runAllReady()
+        // THEN the event is not filtered
+        verify(mediaListener).onMediaDataLoaded(eq(key1), eq(null), eq(mediaData1))
+        // WHEN a loaded event is received that matches the local session
+        filter.onMediaDataLoaded(key2, null, mediaData2)
+        bgExecutor.runAllReady()
+        fgExecutor.runAllReady()
+        // THEN the event is filtered
+        verify(mediaListener, never()).onMediaDataLoaded(eq(key2), eq(null), eq(mediaData2))
+        // AND there should be a removed event for key2
+        verify(mediaListener).onMediaDataRemoved(eq(key2))
+    }
+
+    @Test
+    fun multipleRemoteSessions_loadedEventNotFiltered() {
+        // GIVEN two remote sessions
+        whenever(controller1.getPlaybackInfo()).thenReturn(remotePlaybackInfo)
+        whenever(controller2.getPlaybackInfo()).thenReturn(remotePlaybackInfo)
+        val controllers = listOf(controller1, controller2)
+        whenever(mediaSessionManager.getActiveSessions(any())).thenReturn(controllers)
+        sessionListener.onActiveSessionsChanged(controllers)
+        // WHEN a loaded event is received that matches the remote session
+        filter.onMediaDataLoaded(KEY, null, mediaData1)
+        bgExecutor.runAllReady()
+        fgExecutor.runAllReady()
+        // THEN the event is not filtered
+        verify(mediaListener).onMediaDataLoaded(eq(KEY), eq(null), eq(mediaData1))
+        // WHEN a loaded event is received that matches the local session
+        filter.onMediaDataLoaded(KEY, null, mediaData2)
+        bgExecutor.runAllReady()
+        fgExecutor.runAllReady()
+        // THEN the event is not filtered
+        verify(mediaListener).onMediaDataLoaded(eq(KEY), eq(null), eq(mediaData2))
+    }
+
+    @Test
+    fun multipleOtherSessions_loadedEventNotFiltered() {
+        // GIVEN multiple active sessions from other packages
+        val controllers = listOf(controller1, controller2, controller3, controller4)
+        whenever(controller1.getPackageName()).thenReturn("PKG_1")
+        whenever(controller2.getPackageName()).thenReturn("PKG_2")
+        whenever(controller3.getPackageName()).thenReturn("PKG_3")
+        whenever(controller4.getPackageName()).thenReturn("PKG_4")
+        whenever(mediaSessionManager.getActiveSessions(any())).thenReturn(controllers)
+        sessionListener.onActiveSessionsChanged(controllers)
+        // WHEN a loaded event is received
+        filter.onMediaDataLoaded(KEY, null, mediaData1)
+        bgExecutor.runAllReady()
+        fgExecutor.runAllReady()
+        // THEN the event is not filtered
+        verify(mediaListener).onMediaDataLoaded(eq(KEY), eq(null), eq(mediaData1))
+    }
+
+    @Test
+    fun doNotFilterDuringKeyMigration() {
+        val key1 = "KEY_1"
+        val key2 = "KEY_2"
+        // GIVEN a loaded event
+        filter.onMediaDataLoaded(key1, null, mediaData2)
+        bgExecutor.runAllReady()
+        fgExecutor.runAllReady()
+        reset(mediaListener)
+        // GIVEN remote and local sessions
+        whenever(controller1.getPlaybackInfo()).thenReturn(remotePlaybackInfo)
+        val controllers = listOf(controller1, controller2)
+        whenever(mediaSessionManager.getActiveSessions(any())).thenReturn(controllers)
+        sessionListener.onActiveSessionsChanged(controllers)
+        // WHEN a loaded event is received that matches the local session but it is a key migration
+        filter.onMediaDataLoaded(key2, key1, mediaData2)
+        bgExecutor.runAllReady()
+        fgExecutor.runAllReady()
+        // THEN the key migration event is fired
+        verify(mediaListener).onMediaDataLoaded(eq(key2), eq(key1), eq(mediaData2))
+    }
+
+    @Test
+    fun filterAfterKeyMigration() {
+        val key1 = "KEY_1"
+        val key2 = "KEY_2"
+        // GIVEN a loaded event
+        filter.onMediaDataLoaded(key1, null, mediaData1)
+        filter.onMediaDataLoaded(key1, null, mediaData2)
+        bgExecutor.runAllReady()
+        fgExecutor.runAllReady()
+        reset(mediaListener)
+        // GIVEN remote and local sessions
+        whenever(controller1.getPlaybackInfo()).thenReturn(remotePlaybackInfo)
+        val controllers = listOf(controller1, controller2)
+        whenever(mediaSessionManager.getActiveSessions(any())).thenReturn(controllers)
+        sessionListener.onActiveSessionsChanged(controllers)
+        // GIVEN that the keys have been migrated
+        filter.onMediaDataLoaded(key2, key1, mediaData1)
+        filter.onMediaDataLoaded(key2, key1, mediaData2)
+        bgExecutor.runAllReady()
+        fgExecutor.runAllReady()
+        reset(mediaListener)
+        // WHEN a loaded event is received that matches the local session
+        filter.onMediaDataLoaded(key2, null, mediaData2)
+        bgExecutor.runAllReady()
+        fgExecutor.runAllReady()
+        // THEN the key migration event is filtered
+        verify(mediaListener, never()).onMediaDataLoaded(eq(key2), eq(null), eq(mediaData2))
+        // WHEN a loaded event is received that matches the remote session
+        filter.onMediaDataLoaded(key2, null, mediaData1)
+        bgExecutor.runAllReady()
+        fgExecutor.runAllReady()
+        // THEN the key migration event is fired
+        verify(mediaListener).onMediaDataLoaded(eq(key2), eq(null), eq(mediaData1))
+    }
+}