Make UserSwitchObserver.onBeforeUserSwitching oneway but still blocking.

Bug: 371536480
Test: atest UserControllerTest
Test: atest UserTrackerImplTest
Flag: EXEMPT bugfix
(cherry picked from commit a3bd1e2befed1ccc7d3bb07fc43421dfa809d2b3)
(cherry picked from https://googleplex-android-review.googlesource.com/q/commit:f7efa779da5c59085b38cb73da61ef0d83b672b6)
Merged-In: I6fd04b00ab768533df01a3eca613f388bf70e42a
Change-Id: I6fd04b00ab768533df01a3eca613f388bf70e42a
diff --git a/core/java/android/app/IUserSwitchObserver.aidl b/core/java/android/app/IUserSwitchObserver.aidl
index 1ff7a17..d71ee7c 100644
--- a/core/java/android/app/IUserSwitchObserver.aidl
+++ b/core/java/android/app/IUserSwitchObserver.aidl
@@ -19,10 +19,10 @@
 import android.os.IRemoteCallback;
 
 /** {@hide} */
-interface IUserSwitchObserver {
-    void onBeforeUserSwitching(int newUserId);
-    oneway void onUserSwitching(int newUserId, IRemoteCallback reply);
-    oneway void onUserSwitchComplete(int newUserId);
-    oneway void onForegroundProfileSwitch(int newProfileId);
-    oneway void onLockedBootComplete(int newUserId);
+oneway interface IUserSwitchObserver {
+    void onBeforeUserSwitching(int newUserId, IRemoteCallback reply);
+    void onUserSwitching(int newUserId, IRemoteCallback reply);
+    void onUserSwitchComplete(int newUserId);
+    void onForegroundProfileSwitch(int newProfileId);
+    void onLockedBootComplete(int newUserId);
 }
diff --git a/core/java/android/app/UserSwitchObserver.java b/core/java/android/app/UserSwitchObserver.java
index 727799a1..1664cfb 100644
--- a/core/java/android/app/UserSwitchObserver.java
+++ b/core/java/android/app/UserSwitchObserver.java
@@ -30,7 +30,11 @@
     }
 
     @Override
-    public void onBeforeUserSwitching(int newUserId) throws RemoteException {}
+    public void onBeforeUserSwitching(int newUserId, IRemoteCallback reply) throws RemoteException {
+        if (reply != null) {
+            reply.sendResult(null);
+        }
+    }
 
     @Override
     public void onUserSwitching(int newUserId, IRemoteCallback reply) throws RemoteException {
diff --git a/packages/SystemUI/src/com/android/systemui/settings/UserTracker.kt b/packages/SystemUI/src/com/android/systemui/settings/UserTracker.kt
index 1a997a7..eab8a13 100644
--- a/packages/SystemUI/src/com/android/systemui/settings/UserTracker.kt
+++ b/packages/SystemUI/src/com/android/systemui/settings/UserTracker.kt
@@ -72,10 +72,17 @@
     @WeaklyReferencedCallback
     interface Callback {
         /**
-         * Notifies that the current user will be changed.
+         * Same as {@link onBeforeUserSwitching(Int, Runnable)} but the callback will be called
+         * automatically after the completion of this method.
          */
         fun onBeforeUserSwitching(newUser: Int) {}
 
+        /** Notifies that the current user will be changed. */
+        fun onBeforeUserSwitching(newUser: Int, resultCallback: Runnable) {
+            onBeforeUserSwitching(newUser)
+            resultCallback.run()
+        }
+
         /**
          * Same as {@link onUserChanging(Int, Context, Runnable)} but the callback will be
          * called automatically after the completion of this method.
diff --git a/packages/SystemUI/src/com/android/systemui/settings/UserTrackerImpl.kt b/packages/SystemUI/src/com/android/systemui/settings/UserTrackerImpl.kt
index 0a1f649..82eff461 100644
--- a/packages/SystemUI/src/com/android/systemui/settings/UserTrackerImpl.kt
+++ b/packages/SystemUI/src/com/android/systemui/settings/UserTrackerImpl.kt
@@ -192,8 +192,9 @@
 
     private fun registerUserSwitchObserver() {
         iActivityManager.registerUserSwitchObserver(object : UserSwitchObserver() {
-            override fun onBeforeUserSwitching(newUserId: Int) {
+            override fun onBeforeUserSwitching(newUserId: Int, reply: IRemoteCallback?) {
                 handleBeforeUserSwitching(newUserId)
+                reply?.sendResult(null)
             }
 
             override fun onUserSwitching(newUserId: Int, reply: IRemoteCallback?) {
@@ -228,8 +229,7 @@
         setUserIdInternal(newUserId)
 
         notifySubscribers { callback, resultCallback ->
-            callback.onBeforeUserSwitching(newUserId)
-            resultCallback.run()
+            callback.onBeforeUserSwitching(newUserId, resultCallback)
         }.await()
     }
 
diff --git a/packages/SystemUI/tests/src/com/android/systemui/settings/UserTrackerImplTest.kt b/packages/SystemUI/tests/src/com/android/systemui/settings/UserTrackerImplTest.kt
index 774aa51..57b82e3 100644
--- a/packages/SystemUI/tests/src/com/android/systemui/settings/UserTrackerImplTest.kt
+++ b/packages/SystemUI/tests/src/com/android/systemui/settings/UserTrackerImplTest.kt
@@ -81,6 +81,9 @@
     private lateinit var iActivityManager: IActivityManager
 
     @Mock
+    private lateinit var beforeUserSwitchingReply: IRemoteCallback
+
+    @Mock
     private lateinit var userSwitchingReply: IRemoteCallback
 
     @Mock(stubOnly = true)
@@ -216,9 +219,10 @@
 
         val captor = ArgumentCaptor.forClass(IUserSwitchObserver::class.java)
         verify(iActivityManager).registerUserSwitchObserver(capture(captor), anyString())
-        captor.value.onBeforeUserSwitching(newID)
+        captor.value.onBeforeUserSwitching(newID, beforeUserSwitchingReply)
         captor.value.onUserSwitching(newID, userSwitchingReply)
         runCurrent()
+        verify(beforeUserSwitchingReply).sendResult(any())
         verify(userSwitchingReply).sendResult(any())
 
         verify(userManager).getProfiles(newID)
@@ -343,10 +347,11 @@
 
         val captor = ArgumentCaptor.forClass(IUserSwitchObserver::class.java)
         verify(iActivityManager).registerUserSwitchObserver(capture(captor), anyString())
-        captor.value.onBeforeUserSwitching(newID)
+        captor.value.onBeforeUserSwitching(newID, beforeUserSwitchingReply)
         captor.value.onUserSwitching(newID, userSwitchingReply)
         runCurrent()
 
+        verify(beforeUserSwitchingReply).sendResult(any())
         verify(userSwitchingReply).sendResult(any())
         assertThat(callback.calledOnUserChanging).isEqualTo(1)
         assertThat(callback.lastUser).isEqualTo(newID)
@@ -395,7 +400,7 @@
 
         val captor = ArgumentCaptor.forClass(IUserSwitchObserver::class.java)
         verify(iActivityManager).registerUserSwitchObserver(capture(captor), anyString())
-        captor.value.onBeforeUserSwitching(newID)
+        captor.value.onBeforeUserSwitching(newID, any())
         captor.value.onUserSwitchComplete(newID)
         runCurrent()
 
@@ -449,8 +454,10 @@
 
         val captor = ArgumentCaptor.forClass(IUserSwitchObserver::class.java)
         verify(iActivityManager).registerUserSwitchObserver(capture(captor), anyString())
+        captor.value.onBeforeUserSwitching(newID, beforeUserSwitchingReply)
         captor.value.onUserSwitching(newID, userSwitchingReply)
         runCurrent()
+        verify(beforeUserSwitchingReply).sendResult(any())
         verify(userSwitchingReply).sendResult(any())
         captor.value.onUserSwitchComplete(newID)
 
@@ -465,6 +472,7 @@
     }
 
     private class TestCallback : UserTracker.Callback {
+        var calledOnBeforeUserChanging = 0
         var calledOnUserChanging = 0
         var calledOnUserChanged = 0
         var calledOnProfilesChanged = 0
@@ -472,6 +480,11 @@
         var lastUserContext: Context? = null
         var lastUserProfiles = emptyList<UserInfo>()
 
+        override fun onBeforeUserSwitching(newUser: Int) {
+            calledOnBeforeUserChanging++
+            lastUser = newUser
+        }
+
         override fun onUserChanging(newUser: Int, userContext: Context) {
             calledOnUserChanging++
             lastUser = newUser
diff --git a/services/core/java/com/android/server/am/UserController.java b/services/core/java/com/android/server/am/UserController.java
index b94d175..f755ebc 100644
--- a/services/core/java/com/android/server/am/UserController.java
+++ b/services/core/java/com/android/server/am/UserController.java
@@ -159,6 +159,7 @@
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 
 /**
@@ -1867,8 +1868,14 @@
                 return false;
             }
 
-            mHandler.post(() -> startUserInternalOnHandler(userId, oldUserId, userStartMode,
-                    unlockListener, callingUid, callingPid));
+            final Runnable continueStartUserInternal = () -> continueStartUserInternal(userInfo,
+                    oldUserId, userStartMode, unlockListener, callingUid, callingPid);
+            if (foreground) {
+                mHandler.post(() -> dispatchOnBeforeUserSwitching(userId, () ->
+                        mHandler.post(continueStartUserInternal)));
+            } else {
+                continueStartUserInternal.run();
+            }
         } finally {
             Binder.restoreCallingIdentity(ident);
         }
@@ -1876,11 +1883,11 @@
         return true;
     }
 
-    private void startUserInternalOnHandler(int userId, int oldUserId, int userStartMode,
+    private void continueStartUserInternal(UserInfo userInfo, int oldUserId, int userStartMode,
             IProgressListener unlockListener, int callingUid, int callingPid) {
         final TimingsTraceAndSlog t = new TimingsTraceAndSlog();
         final boolean foreground = userStartMode == USER_START_MODE_FOREGROUND;
-        final UserInfo userInfo = getUserInfo(userId);
+        final int userId = userInfo.id;
 
         boolean needStart = false;
         boolean updateUmState = false;
@@ -1940,7 +1947,6 @@
             // it should be moved outside, but for now it's not as there are many calls to
             // external components here afterwards
             updateProfileRelatedCaches();
-            dispatchOnBeforeUserSwitching(userId);
             mInjector.getWindowManager().setCurrentUser(userId);
             mInjector.reportCurWakefulnessUsageEvent();
             // Once the internal notion of the active user has switched, we lock the device
@@ -2240,25 +2246,42 @@
         mUserSwitchObservers.finishBroadcast();
     }
 
-    private void dispatchOnBeforeUserSwitching(@UserIdInt int newUserId) {
+    private void dispatchOnBeforeUserSwitching(@UserIdInt int newUserId, Runnable onComplete) {
         final TimingsTraceAndSlog t = new TimingsTraceAndSlog();
         t.traceBegin("dispatchOnBeforeUserSwitching-" + newUserId);
-        final int observerCount = mUserSwitchObservers.beginBroadcast();
-        for (int i = 0; i < observerCount; i++) {
-            final String name = "#" + i + " " + mUserSwitchObservers.getBroadcastCookie(i);
-            t.traceBegin("onBeforeUserSwitching-" + name);
+        final AtomicBoolean isFirst = new AtomicBoolean(true);
+        startTimeoutForOnBeforeUserSwitching(isFirst, onComplete);
+        informUserSwitchObservers((observer, callback) -> {
             try {
-                mUserSwitchObservers.getBroadcastItem(i).onBeforeUserSwitching(newUserId);
+                observer.onBeforeUserSwitching(newUserId, callback);
             } catch (RemoteException e) {
-                // Ignore
-            } finally {
-                t.traceEnd();
+                // ignore
             }
-        }
-        mUserSwitchObservers.finishBroadcast();
+        }, () -> {
+            if (isFirst.getAndSet(false)) {
+                onComplete.run();
+            }
+        }, "onBeforeUserSwitching");
         t.traceEnd();
     }
 
+    private void startTimeoutForOnBeforeUserSwitching(AtomicBoolean isFirst,
+            Runnable onComplete) {
+        final long timeout = getUserSwitchTimeoutMs();
+        mHandler.postDelayed(() -> {
+            if (isFirst.getAndSet(false)) {
+                String unresponsiveObservers;
+                synchronized (mLock) {
+                    unresponsiveObservers = String.join(", ", mCurWaitingUserSwitchCallbacks);
+                }
+                Slogf.e(TAG, "Timeout on dispatchOnBeforeUserSwitching. These UserSwitchObservers "
+                        + "did not respond in " + timeout + "ms: " + unresponsiveObservers + ".");
+                onComplete.run();
+            }
+        }, timeout);
+    }
+
+
     /** Called on handler thread */
     @VisibleForTesting
     void dispatchUserSwitchComplete(@UserIdInt int oldUserId, @UserIdInt int newUserId) {
@@ -2438,70 +2461,76 @@
         t.traceBegin("dispatchUserSwitch-" + oldUserId + "-to-" + newUserId);
 
         EventLog.writeEvent(EventLogTags.UC_DISPATCH_USER_SWITCH, oldUserId, newUserId);
-
-        final int observerCount = mUserSwitchObservers.beginBroadcast();
-        if (observerCount > 0) {
-            final ArraySet<String> curWaitingUserSwitchCallbacks = new ArraySet<>();
-            synchronized (mLock) {
-                uss.switching = true;
-                mCurWaitingUserSwitchCallbacks = curWaitingUserSwitchCallbacks;
+        uss.switching = true;
+        informUserSwitchObservers((observer, callback) -> {
+            try {
+                observer.onUserSwitching(newUserId, callback);
+            } catch (RemoteException e) {
+                // ignore
             }
-            final AtomicInteger waitingCallbacksCount = new AtomicInteger(observerCount);
-            final long userSwitchTimeoutMs = getUserSwitchTimeoutMs();
-            final long dispatchStartedTime = SystemClock.elapsedRealtime();
-            for (int i = 0; i < observerCount; i++) {
-                final long dispatchStartedTimeForObserver = SystemClock.elapsedRealtime();
-                try {
-                    // Prepend with unique prefix to guarantee that keys are unique
-                    final String name = "#" + i + " " + mUserSwitchObservers.getBroadcastCookie(i);
-                    synchronized (mLock) {
-                        curWaitingUserSwitchCallbacks.add(name);
-                    }
-                    final IRemoteCallback callback = new IRemoteCallback.Stub() {
-                        @Override
-                        public void sendResult(Bundle data) throws RemoteException {
-                            asyncTraceEnd("onUserSwitching-" + name, newUserId);
-                            synchronized (mLock) {
-                                long delayForObserver = SystemClock.elapsedRealtime()
-                                        - dispatchStartedTimeForObserver;
-                                if (delayForObserver > LONG_USER_SWITCH_OBSERVER_WARNING_TIME_MS) {
-                                    Slogf.w(TAG, "User switch slowed down by observer " + name
-                                            + ": result took " + delayForObserver
-                                            + " ms to process.");
-                                }
-
-                                long totalDelay = SystemClock.elapsedRealtime()
-                                        - dispatchStartedTime;
-                                if (totalDelay > userSwitchTimeoutMs) {
-                                    Slogf.e(TAG, "User switch timeout: observer " + name
-                                            + "'s result was received " + totalDelay
-                                            + " ms after dispatchUserSwitch.");
-                                }
-
-                                curWaitingUserSwitchCallbacks.remove(name);
-                                // Continue switching if all callbacks have been notified and
-                                // user switching session is still valid
-                                if (waitingCallbacksCount.decrementAndGet() == 0
-                                        && (curWaitingUserSwitchCallbacks
-                                        == mCurWaitingUserSwitchCallbacks)) {
-                                    sendContinueUserSwitchLU(uss, oldUserId, newUserId);
-                                }
-                            }
-                        }
-                    };
-                    asyncTraceBegin("onUserSwitching-" + name, newUserId);
-                    mUserSwitchObservers.getBroadcastItem(i).onUserSwitching(newUserId, callback);
-                } catch (RemoteException e) {
-                    // Ignore
-                }
-            }
-        } else {
+        }, () -> {
             synchronized (mLock) {
                 sendContinueUserSwitchLU(uss, oldUserId, newUserId);
             }
+        }, "onUserSwitching");
+        t.traceEnd();
+    }
+
+    void informUserSwitchObservers(BiConsumer<IUserSwitchObserver, IRemoteCallback> consumer,
+            final Runnable onComplete, String trace) {
+        final int observerCount = mUserSwitchObservers.beginBroadcast();
+        if (observerCount == 0) {
+            onComplete.run();
+            mUserSwitchObservers.finishBroadcast();
+            return;
+        }
+        final ArraySet<String> curWaitingUserSwitchCallbacks = new ArraySet<>();
+        synchronized (mLock) {
+            mCurWaitingUserSwitchCallbacks = curWaitingUserSwitchCallbacks;
+        }
+        final AtomicInteger waitingCallbacksCount = new AtomicInteger(observerCount);
+        final long userSwitchTimeoutMs = getUserSwitchTimeoutMs();
+        final long dispatchStartedTime = SystemClock.elapsedRealtime();
+        for (int i = 0; i < observerCount; i++) {
+            final long dispatchStartedTimeForObserver = SystemClock.elapsedRealtime();
+            // Prepend with unique prefix to guarantee that keys are unique
+            final String name = "#" + i + " " + mUserSwitchObservers.getBroadcastCookie(i);
+            synchronized (mLock) {
+                curWaitingUserSwitchCallbacks.add(name);
+            }
+            final IRemoteCallback callback = new IRemoteCallback.Stub() {
+                @Override
+                public void sendResult(Bundle data) throws RemoteException {
+                    asyncTraceEnd(trace + "-" + name, 0);
+                    synchronized (mLock) {
+                        long delayForObserver = SystemClock.elapsedRealtime()
+                                - dispatchStartedTimeForObserver;
+                        if (delayForObserver > LONG_USER_SWITCH_OBSERVER_WARNING_TIME_MS) {
+                            Slogf.w(TAG, "User switch slowed down by observer " + name
+                                    + ": result took " + delayForObserver
+                                    + " ms to process. " + trace);
+                        }
+                        long totalDelay = SystemClock.elapsedRealtime() - dispatchStartedTime;
+                        if (totalDelay > userSwitchTimeoutMs) {
+                            Slogf.e(TAG, "User switch timeout: observer " + name
+                                    + "'s result was received " + totalDelay
+                                    + " ms after dispatchUserSwitch. " + trace);
+                        }
+                        curWaitingUserSwitchCallbacks.remove(name);
+                        // Continue switching if all callbacks have been notified and
+                        // user switching session is still valid
+                        if (waitingCallbacksCount.decrementAndGet() == 0
+                                && (curWaitingUserSwitchCallbacks
+                                == mCurWaitingUserSwitchCallbacks)) {
+                            onComplete.run();
+                        }
+                    }
+                }
+            };
+            asyncTraceBegin(trace + "-" + name, 0);
+            consumer.accept(mUserSwitchObservers.getBroadcastItem(i), callback);
         }
         mUserSwitchObservers.finishBroadcast();
-        t.traceEnd(); // end dispatchUserSwitch-
     }
 
     @GuardedBy("mLock")
diff --git a/services/tests/servicestests/src/com/android/server/am/UserControllerTest.java b/services/tests/servicestests/src/com/android/server/am/UserControllerTest.java
index 9743dc9..45bd4b3 100644
--- a/services/tests/servicestests/src/com/android/server/am/UserControllerTest.java
+++ b/services/tests/servicestests/src/com/android/server/am/UserControllerTest.java
@@ -90,6 +90,7 @@
 import android.os.Message;
 import android.os.PowerManagerInternal;
 import android.os.RemoteException;
+import android.os.SystemClock;
 import android.os.UserHandle;
 import android.os.UserManager;
 import android.os.storage.IStorageManager;
@@ -176,14 +177,12 @@
             Intent.ACTION_USER_STARTING);
 
     private static final Set<Integer> START_FOREGROUND_USER_MESSAGE_CODES = newHashSet(
-            0, // for startUserInternalOnHandler
             REPORT_USER_SWITCH_MSG,
             USER_SWITCH_TIMEOUT_MSG,
             USER_START_MSG,
             USER_CURRENT_MSG);
 
     private static final Set<Integer> START_BACKGROUND_USER_MESSAGE_CODES = newHashSet(
-            0, // for startUserInternalOnHandler
             USER_START_MSG,
             REPORT_LOCKED_BOOT_COMPLETE_MSG);
 
@@ -367,7 +366,7 @@
         // and the cascade effect goes on...). In fact, a better approach would to not assert the
         // binder calls, but their side effects (in this case, that the user is stopped right away)
         assertWithMessage("wrong binder message calls").that(mInjector.mHandler.getMessageCodes())
-                .containsExactly(/* for startUserInternalOnHandler */ 0, USER_START_MSG);
+                .containsExactly(USER_START_MSG);
     }
 
     private void startUserAssertions(
@@ -410,17 +409,12 @@
     @Test
     public void testDispatchUserSwitch() throws RemoteException {
         // Prepare mock observer and register it
-        IUserSwitchObserver observer = mock(IUserSwitchObserver.class);
-        when(observer.asBinder()).thenReturn(new Binder());
-        doAnswer(invocation -> {
-            IRemoteCallback callback = (IRemoteCallback) invocation.getArguments()[1];
-            callback.sendResult(null);
-            return null;
-        }).when(observer).onUserSwitching(anyInt(), any());
-        mUserController.registerUserSwitchObserver(observer, "mock");
+        IUserSwitchObserver observer = registerUserSwitchObserver(
+                /* replyToOnBeforeUserSwitchingCallback= */ true,
+                /* replyToOnUserSwitchingCallback= */ true);
         // Start user -- this will update state of mUserController
         mUserController.startUser(TEST_USER_ID, USER_START_MODE_FOREGROUND);
-        verify(observer, times(1)).onBeforeUserSwitching(eq(TEST_USER_ID));
+        verify(observer, times(1)).onBeforeUserSwitching(eq(TEST_USER_ID), any());
         Message reportMsg = mInjector.mHandler.getMessageForCode(REPORT_USER_SWITCH_MSG);
         assertNotNull(reportMsg);
         UserState userState = (UserState) reportMsg.obj;
@@ -445,13 +439,13 @@
 
     @Test
     public void testDispatchUserSwitchBadReceiver() throws RemoteException {
-        // Prepare mock observer which doesn't notify the callback and register it
-        IUserSwitchObserver observer = mock(IUserSwitchObserver.class);
-        when(observer.asBinder()).thenReturn(new Binder());
-        mUserController.registerUserSwitchObserver(observer, "mock");
+        // Prepare mock observer which doesn't notify the onUserSwitching callback and register it
+        IUserSwitchObserver observer = registerUserSwitchObserver(
+                /* replyToOnBeforeUserSwitchingCallback= */ true,
+                /* replyToOnUserSwitchingCallback= */ false);
         // Start user -- this will update state of mUserController
         mUserController.startUser(TEST_USER_ID, USER_START_MODE_FOREGROUND);
-        verify(observer, times(1)).onBeforeUserSwitching(eq(TEST_USER_ID));
+        verify(observer, times(1)).onBeforeUserSwitching(eq(TEST_USER_ID), any());
         Message reportMsg = mInjector.mHandler.getMessageForCode(REPORT_USER_SWITCH_MSG);
         assertNotNull(reportMsg);
         UserState userState = (UserState) reportMsg.obj;
@@ -542,7 +536,6 @@
         expectedCodes.add(REPORT_USER_SWITCH_COMPLETE_MSG);
         if (backgroundUserStopping) {
             expectedCodes.add(CLEAR_USER_JOURNEY_SESSION_MSG);
-            expectedCodes.add(0); // this is for directly posting in stopping.
         }
         if (expectScheduleBackgroundUserStopping) {
             expectedCodes.add(SCHEDULED_STOP_BACKGROUND_USER_MSG);
@@ -558,9 +551,9 @@
     @Test
     public void testDispatchUserSwitchComplete() throws RemoteException {
         // Prepare mock observer and register it
-        IUserSwitchObserver observer = mock(IUserSwitchObserver.class);
-        when(observer.asBinder()).thenReturn(new Binder());
-        mUserController.registerUserSwitchObserver(observer, "mock");
+        IUserSwitchObserver observer = registerUserSwitchObserver(
+                /* replyToOnBeforeUserSwitchingCallback= */ true,
+                /* replyToOnUserSwitchingCallback= */ true);
         // Start user -- this will update state of mUserController
         mUserController.startUser(TEST_USER_ID, USER_START_MODE_FOREGROUND);
         Message reportMsg = mInjector.mHandler.getMessageForCode(REPORT_USER_SWITCH_MSG);
@@ -1596,6 +1589,29 @@
         verify(mInjector, never()).onSystemUserVisibilityChanged(anyBoolean());
     }
 
+    private IUserSwitchObserver registerUserSwitchObserver(
+            boolean replyToOnBeforeUserSwitchingCallback, boolean replyToOnUserSwitchingCallback)
+            throws RemoteException {
+        IUserSwitchObserver observer = mock(IUserSwitchObserver.class);
+        when(observer.asBinder()).thenReturn(new Binder());
+        if (replyToOnBeforeUserSwitchingCallback) {
+            doAnswer(invocation -> {
+                IRemoteCallback callback = (IRemoteCallback) invocation.getArguments()[1];
+                callback.sendResult(null);
+                return null;
+            }).when(observer).onBeforeUserSwitching(anyInt(), any());
+        }
+        if (replyToOnUserSwitchingCallback) {
+            doAnswer(invocation -> {
+                IRemoteCallback callback = (IRemoteCallback) invocation.getArguments()[1];
+                callback.sendResult(null);
+                return null;
+            }).when(observer).onUserSwitching(anyInt(), any());
+        }
+        mUserController.registerUserSwitchObserver(observer, "mock");
+        return observer;
+    }
+
     // Should be public to allow mocking
     private static class TestInjector extends UserController.Injector {
         public final TestHandler mHandler;
@@ -1787,6 +1803,7 @@
 
     private static class TestHandler extends Handler {
         private final List<Message> mMessages = new ArrayList<>();
+        private final List<Runnable> mPendingCallbacks = new ArrayList<>();
 
         TestHandler(Looper looper) {
             super(looper);
@@ -1819,14 +1836,24 @@
 
         @Override
         public boolean sendMessageAtTime(Message msg, long uptimeMillis) {
-            Message copy = new Message();
-            copy.copyFrom(msg);
-            mMessages.add(copy);
-            if (msg.getCallback() != null) {
-                msg.getCallback().run();
+            if (msg.getCallback() == null) {
+                Message copy = new Message();
+                copy.copyFrom(msg);
+                mMessages.add(copy);
+            } else {
+                if (SystemClock.uptimeMillis() >= uptimeMillis) {
+                    msg.getCallback().run();
+                } else {
+                    mPendingCallbacks.add(msg.getCallback());
+                }
                 msg.setCallback(null);
             }
             return super.sendMessageAtTime(msg, uptimeMillis);
         }
+
+        private void runPendingCallbacks() {
+            mPendingCallbacks.forEach(Runnable::run);
+            mPendingCallbacks.clear();
+        }
     }
 }