Inform apps of sandbox death through callbacks

Apps can use SdkSandboxLifecycleCallback to receive notifications about
the death of the sandbox after registering the callback via
registerSdkSandboxLifecycleCallback(). If the sandbox has not been
created yet (for e.g. if the app tries to register the callback before
the sandbox is created), the callback is queued until the sandbox is
created.

Bug: 238761711
Bug: 240437810
Test: atest SdkSandboxManagerServiceUnitTest
Test: atest SdkSandboxManagerTest
Change-Id: I5c6d177f643d167a84dca1ab454e967dbcb50992
diff --git a/sdksandbox/framework/api/current.txt b/sdksandbox/framework/api/current.txt
index d724448..792a5e4 100644
--- a/sdksandbox/framework/api/current.txt
+++ b/sdksandbox/framework/api/current.txt
@@ -48,9 +48,11 @@
   }
 
   public final class SdkSandboxManager {
+    method public void addSdkSandboxLifecycleCallback(@NonNull java.util.concurrent.Executor, @NonNull android.app.sdksandbox.SdkSandboxManager.SdkSandboxLifecycleCallback);
     method @NonNull public java.util.List<android.content.pm.SharedLibraryInfo> getLoadedSdkLibrariesInfo();
     method public static int getSdkSandboxState();
     method public void loadSdk(@NonNull String, @NonNull android.os.Bundle, @NonNull java.util.concurrent.Executor, @NonNull android.os.OutcomeReceiver<android.app.sdksandbox.LoadSdkResponse,android.app.sdksandbox.LoadSdkException>);
+    method public void removeSdkSandboxLifecycleCallback(@NonNull android.app.sdksandbox.SdkSandboxManager.SdkSandboxLifecycleCallback);
     method public void requestSurfacePackage(@NonNull String, int, int, int, @NonNull android.os.Bundle, @NonNull java.util.concurrent.Executor, @NonNull android.os.OutcomeReceiver<android.app.sdksandbox.RequestSurfacePackageResponse,android.app.sdksandbox.RequestSurfacePackageException>);
     method public void sendData(@NonNull String, @NonNull android.os.Bundle, @NonNull java.util.concurrent.Executor, @NonNull android.os.OutcomeReceiver<android.app.sdksandbox.SendDataResponse,android.app.sdksandbox.SendDataException>);
     method public void unloadSdk(@NonNull String);
@@ -64,6 +66,10 @@
     field public static final int SEND_DATA_INTERNAL_ERROR = 800; // 0x320
   }
 
+  public static interface SdkSandboxManager.SdkSandboxLifecycleCallback {
+    method public void onSdkSandboxDied();
+  }
+
   public final class SendDataException extends java.lang.Exception {
     ctor public SendDataException(int, @Nullable String);
     ctor public SendDataException(int, @Nullable String, @Nullable Throwable);
diff --git a/sdksandbox/framework/java/android/app/sdksandbox/ISdkSandboxLifecycleCallback.aidl b/sdksandbox/framework/java/android/app/sdksandbox/ISdkSandboxLifecycleCallback.aidl
new file mode 100644
index 0000000..e5fb1d5
--- /dev/null
+++ b/sdksandbox/framework/java/android/app/sdksandbox/ISdkSandboxLifecycleCallback.aidl
@@ -0,0 +1,22 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.app.sdksandbox;
+
+/** @hide */
+oneway interface ISdkSandboxLifecycleCallback {
+    void onSdkSandboxDied();
+}
\ No newline at end of file
diff --git a/sdksandbox/framework/java/android/app/sdksandbox/ISdkSandboxManager.aidl b/sdksandbox/framework/java/android/app/sdksandbox/ISdkSandboxManager.aidl
index 30fdf05..8732cf0 100644
--- a/sdksandbox/framework/java/android/app/sdksandbox/ISdkSandboxManager.aidl
+++ b/sdksandbox/framework/java/android/app/sdksandbox/ISdkSandboxManager.aidl
@@ -21,11 +21,14 @@
 
 import android.app.sdksandbox.ILoadSdkCallback;
 import android.app.sdksandbox.IRequestSurfacePackageCallback;
+import android.app.sdksandbox.ISdkSandboxLifecycleCallback;
 import android.app.sdksandbox.ISendDataCallback;
 import android.content.pm.SharedLibraryInfo;
 
 /** @hide */
 interface ISdkSandboxManager {
+    void addSdkSandboxLifecycleCallback(in String callingPackageName, in ISdkSandboxLifecycleCallback callback);
+    void removeSdkSandboxLifecycleCallback(in String callingPackageName, in ISdkSandboxLifecycleCallback callback);
     void loadSdk(in String callingPackageName, in String sdkName, in Bundle params, in ILoadSdkCallback callback);
     void unloadSdk(in String callingPackageName, in String sdkName);
     void requestSurfacePackage(in String callingPackageName, in String sdkName, in IBinder hostToken, int displayId, in int width, in int height, in Bundle params, IRequestSurfacePackageCallback callback);
diff --git a/sdksandbox/framework/java/android/app/sdksandbox/SdkSandboxManager.java b/sdksandbox/framework/java/android/app/sdksandbox/SdkSandboxManager.java
index e31ebd2..88cdf58 100644
--- a/sdksandbox/framework/java/android/app/sdksandbox/SdkSandboxManager.java
+++ b/sdksandbox/framework/java/android/app/sdksandbox/SdkSandboxManager.java
@@ -32,8 +32,11 @@
 import android.os.RemoteException;
 import android.view.SurfaceControlViewHost.SurfacePackage;
 
+import com.android.internal.annotations.GuardedBy;
+
 import java.lang.annotation.Retention;
 import java.lang.annotation.RetentionPolicy;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.Executor;
 
@@ -150,6 +153,10 @@
 
     private final Context mContext;
 
+    @GuardedBy("mLifecycleCallbacks")
+    private final ArrayList<SdkSandboxLifecycleCallbackProxy> mLifecycleCallbacks =
+            new ArrayList<>();
+
     /** @hide */
     public SdkSandboxManager(@NonNull Context context, @NonNull ISdkSandboxManager binder) {
         mContext = context;
@@ -180,6 +187,65 @@
     }
 
     /**
+     * Add a callback which gets registered for sdk sandbox lifecycle events, such as sdk sandbox
+     * death. If the sandbox has not yet been created when this is called, the request will be
+     * stored until a sandbox is created, at which point it is activated for that sandbox. Multiple
+     * callbacks can be added to detect death.
+     *
+     * @param callbackExecutor the {@link Executor} on which to invoke the callback
+     * @param callback the {@link SdkSandboxLifecycleCallback} which will receive sdk sandbox
+     *     lifecycle events.
+     */
+    public void addSdkSandboxLifecycleCallback(
+            @NonNull @CallbackExecutor Executor callbackExecutor,
+            @NonNull SdkSandboxLifecycleCallback callback) {
+        if (callbackExecutor == null) {
+            throw new IllegalArgumentException("executor cannot be null");
+        }
+        if (callback == null) {
+            throw new IllegalArgumentException("callback cannot be null");
+        }
+
+        synchronized (mLifecycleCallbacks) {
+            final SdkSandboxLifecycleCallbackProxy callbackProxy =
+                    new SdkSandboxLifecycleCallbackProxy(callbackExecutor, callback);
+            try {
+                mService.addSdkSandboxLifecycleCallback(
+                        mContext.getPackageName(), callbackProxy);
+            } catch (RemoteException e) {
+                throw e.rethrowFromSystemServer();
+            }
+            mLifecycleCallbacks.add(callbackProxy);
+        }
+    }
+
+    /**
+     * Remove an {@link SdkSandboxLifecycleCallback} that was previously added using {@link
+     * SdkSandboxManager#addSdkSandboxLifecycleCallback(Executor, SdkSandboxLifecycleCallback)}
+     *
+     * @param callback the {@link SdkSandboxLifecycleCallback} which was previously added using
+     *     {@link SdkSandboxManager#addSdkSandboxLifecycleCallback(Executor,
+     *     SdkSandboxLifecycleCallback)}
+     */
+    public void removeSdkSandboxLifecycleCallback(
+            @NonNull SdkSandboxLifecycleCallback callback) {
+        synchronized (mLifecycleCallbacks) {
+            for (int i = mLifecycleCallbacks.size() - 1; i >= 0; i--) {
+                final SdkSandboxLifecycleCallbackProxy callbackProxy = mLifecycleCallbacks.get(i);
+                if (callbackProxy.callback == callback) {
+                    try {
+                        mService.removeSdkSandboxLifecycleCallback(
+                                mContext.getPackageName(), callbackProxy);
+                    } catch (RemoteException e) {
+                        throw e.rethrowFromSystemServer();
+                    }
+                    mLifecycleCallbacks.remove(i);
+                }
+            }
+        }
+    }
+
+    /**
      * Load SDK in a SDK sandbox java process.
      *
      * <p>It loads SDK library with {@code sdkName} to a sandbox process asynchronously, caller
@@ -317,6 +383,46 @@
         }
     }
 
+    /**
+     * A callback for tracking events SDK sandbox death.
+     *
+     * <p>The callback can be added using {@link
+     * SdkSandboxManager#addSdkSandboxLifecycleCallback(Executor, SdkSandboxLifecycleCallback)}
+     * and removed using {@link
+     * SdkSandboxManager#removeSdkSandboxLifecycleCallback(SdkSandboxLifecycleCallback)}
+     */
+    public interface SdkSandboxLifecycleCallback {
+        /**
+         * Notifies the client application that the SDK sandbox has died. The sandbox could die for
+         * various reasons, for example, due to memory pressure on the system, or a crash in the
+         * sandbox.
+         *
+         * The system will automatically restart the sandbox process if it died due to a crash.
+         * However, the state of the sandbox will be lost - so any SDKs that were loaded previously
+         * would have to be loaded again, using {@link SdkSandboxManager#loadSdk(String, Bundle,
+         * Executor, OutcomeReceiver)} to continue using them.
+         */
+        void onSdkSandboxDied();
+    }
+
+    /** @hide */
+    private static class SdkSandboxLifecycleCallbackProxy
+            extends ISdkSandboxLifecycleCallback.Stub {
+        private final Executor mExecutor;
+        public final SdkSandboxLifecycleCallback callback;
+
+        SdkSandboxLifecycleCallbackProxy(
+                Executor executor, SdkSandboxLifecycleCallback lifecycleCallback) {
+            mExecutor = executor;
+            callback = lifecycleCallback;
+        }
+
+        @Override
+        public void onSdkSandboxDied() {
+            mExecutor.execute(() -> callback.onSdkSandboxDied());
+        }
+    }
+
     /** @hide */
     private static class LoadSdkReceiverProxy extends ILoadSdkCallback.Stub {
         private final Executor mExecutor;
diff --git a/sdksandbox/service/java/com/android/server/sdksandbox/SdkSandboxManagerService.java b/sdksandbox/service/java/com/android/server/sdksandbox/SdkSandboxManagerService.java
index 065ca21..018caeb 100644
--- a/sdksandbox/service/java/com/android/server/sdksandbox/SdkSandboxManagerService.java
+++ b/sdksandbox/service/java/com/android/server/sdksandbox/SdkSandboxManagerService.java
@@ -26,6 +26,7 @@
 import android.app.ActivityManager;
 import android.app.sdksandbox.ILoadSdkCallback;
 import android.app.sdksandbox.IRequestSurfacePackageCallback;
+import android.app.sdksandbox.ISdkSandboxLifecycleCallback;
 import android.app.sdksandbox.ISdkSandboxManager;
 import android.app.sdksandbox.ISendDataCallback;
 import android.app.sdksandbox.SdkSandboxManager;
@@ -46,6 +47,7 @@
 import android.os.IBinder;
 import android.os.ParcelFileDescriptor;
 import android.os.Process;
+import android.os.RemoteCallbackList;
 import android.os.RemoteException;
 import android.os.UserHandle;
 import android.text.TextUtils;
@@ -109,6 +111,10 @@
     @GuardedBy("mLock")
     private final Set<CallingInfo> mRunningInstrumentations = new ArraySet<>();
 
+    @GuardedBy("mLock")
+    private final ArrayMap<CallingInfo, RemoteCallbackList<ISdkSandboxLifecycleCallback>>
+            mSandboxLifecycleCallbacks = new ArrayMap<>();
+
     private final SdkSandboxManagerLocal mLocalManager;
 
     private final String mAdServicesPackageName;
@@ -172,6 +178,41 @@
     }
 
     @Override
+    public void addSdkSandboxLifecycleCallback(
+            String callingPackageName, ISdkSandboxLifecycleCallback callback) {
+        final int callingUid = Binder.getCallingUid();
+        final CallingInfo callingInfo = new CallingInfo(callingUid, callingPackageName);
+        enforceCallingPackageBelongsToUid(callingInfo);
+
+        synchronized (mLock) {
+            if (mSandboxLifecycleCallbacks.containsKey(callingInfo)) {
+                mSandboxLifecycleCallbacks.get(callingInfo).register(callback);
+            } else {
+                RemoteCallbackList<ISdkSandboxLifecycleCallback> sandboxLifecycleCallbacks =
+                        new RemoteCallbackList<>();
+                sandboxLifecycleCallbacks.register(callback);
+                mSandboxLifecycleCallbacks.put(callingInfo, sandboxLifecycleCallbacks);
+            }
+        }
+    }
+
+    @Override
+    public void removeSdkSandboxLifecycleCallback(
+            String callingPackageName, ISdkSandboxLifecycleCallback callback) {
+        final int callingUid = Binder.getCallingUid();
+        final CallingInfo callingInfo = new CallingInfo(callingUid, callingPackageName);
+        enforceCallingPackageBelongsToUid(callingInfo);
+
+        synchronized (mLock) {
+            RemoteCallbackList<ISdkSandboxLifecycleCallback> sandboxLifecycleCallbacks =
+                    mSandboxLifecycleCallbacks.get(callingInfo);
+            if (sandboxLifecycleCallbacks != null) {
+                sandboxLifecycleCallbacks.unregister(callback);
+            }
+        }
+    }
+
+    @Override
     public void loadSdk(
             String callingPackageName, String sdkName, Bundle params, ILoadSdkCallback callback) {
         final int callingUid = Binder.getCallingUid();
@@ -231,9 +272,6 @@
             return;
         }
 
-        // TODO(b/204991850): ensure requested code is included in the AndroidManifest.xml
-        invokeSdkSandboxServiceToLoadSdk(callingInfo, sdkToken, sdkProviderInfo, params, link);
-
         // Register a death recipient to clean up sdkToken and unbind its service after app dies.
         try {
             synchronized (mLock) {
@@ -245,7 +283,10 @@
         } catch (RemoteException re) {
             // App has already died, cleanup sdk token and link, and unbind its service
             onAppDeath(callingInfo);
+            return;
         }
+
+        invokeSdkSandboxServiceToLoadSdk(callingInfo, sdkToken, sdkProviderInfo, params, link);
     }
 
     @Override
@@ -299,6 +340,7 @@
 
     private void onAppDeath(CallingInfo callingInfo) {
         synchronized (mLock) {
+            mSandboxLifecycleCallbacks.remove(callingInfo);
             mCallingInfosWithDeathRecipients.remove(callingInfo);
             removeAllSdkTokensAndLinks(callingInfo);
             stopSdkSandboxService(callingInfo, "Caller " + callingInfo + " has died");
@@ -509,6 +551,15 @@
                 new SandboxBindingCallback() {
                     @Override
                     public void onBindingSuccessful(ISdkSandboxService service) {
+                        try {
+                            service.asBinder()
+                                    .linkToDeath(
+                                            () -> handleSandboxLifecycleCallbacks(callingInfo),
+                                            0);
+                        } catch (RemoteException re) {
+                            handleSandboxLifecycleCallbacks(callingInfo);
+                            return;
+                        }
                         loadSdkForService(callingInfo, sdkToken, info, params, link, service);
                     }
 
@@ -522,6 +573,29 @@
                 });
     }
 
+    private void handleSandboxLifecycleCallbacks(CallingInfo callingInfo) {
+        RemoteCallbackList<ISdkSandboxLifecycleCallback> sandboxLifecycleCallbacks;
+        synchronized (mLock) {
+            sandboxLifecycleCallbacks = mSandboxLifecycleCallbacks.get(callingInfo);
+
+            if (sandboxLifecycleCallbacks == null) {
+                return;
+            }
+
+            int size = sandboxLifecycleCallbacks.beginBroadcast();
+            for (int i = 0; i < size; ++i) {
+                try {
+                    sandboxLifecycleCallbacks.getBroadcastItem(i).onSdkSandboxDied();
+                } catch (RemoteException e) {
+                    Log.w(TAG, "Unable to send sdk sandbox death event to app", e);
+                }
+            }
+            sandboxLifecycleCallbacks.finishBroadcast();
+
+            mSandboxLifecycleCallbacks.remove(callingInfo);
+        }
+    }
+
     @Override
     public void stopSdkSandbox(String callingPackageName) {
         final int callingUid = Binder.getCallingUid();
diff --git a/sdksandbox/tests/cts/endtoendtests/src/com/android/tests/sdksandbox/endtoend/SdkSandboxManagerTest.java b/sdksandbox/tests/cts/endtoendtests/src/com/android/tests/sdksandbox/endtoend/SdkSandboxManagerTest.java
index c7cc7ce..52983ca 100644
--- a/sdksandbox/tests/cts/endtoendtests/src/com/android/tests/sdksandbox/endtoend/SdkSandboxManagerTest.java
+++ b/sdksandbox/tests/cts/endtoendtests/src/com/android/tests/sdksandbox/endtoend/SdkSandboxManagerTest.java
@@ -23,6 +23,7 @@
 import android.app.sdksandbox.SdkSandboxManager;
 import android.app.sdksandbox.testutils.FakeLoadSdkCallback;
 import android.app.sdksandbox.testutils.FakeRequestSurfacePackageCallback;
+import android.app.sdksandbox.testutils.FakeSdkSandboxLifecycleCallback;
 import android.app.sdksandbox.testutils.FakeSendDataCallback;
 import android.content.Context;
 import android.content.pm.SharedLibraryInfo;
@@ -89,7 +90,7 @@
     }
 
     @Test
-    public void loadSdkWithInternalErrorShouldFail() {
+    public void loadSdkWithInternalErrorShouldFail() throws Exception {
         final String sdkName = "com.android.loadSdkWithInternalErrorSdkProvider";
         final FakeLoadSdkCallback callback = new FakeLoadSdkCallback();
         mSdkSandboxManager.loadSdk(sdkName, new Bundle(), Runnable::run, callback);
@@ -164,16 +165,101 @@
 
     @Test
     public void testReloadingSdkAfterKillingSandboxIsSuccessful() throws Exception {
+        // Kill the sandbox if it already exists from previous tests
+        killSandboxIfExists();
+
         // Killing the sandbox and loading the same SDKs again multiple times should work
         for (int i = 0; i < 3; ++i) {
-            // Kill the sandbox if it already exists from previous tests/loop
-            killSandboxAndWaitForDeath();
+            FakeSdkSandboxLifecycleCallback callback = new FakeSdkSandboxLifecycleCallback();
+            mSdkSandboxManager.addSdkSandboxLifecycleCallback(Runnable::run, callback);
+
             // The same SDKs should be able to be loaded again after sandbox death
             loadMultipleSdks();
-        }
 
-        // Clean up before running other tests
-        killSandboxAndWaitForDeath();
+            killSandbox();
+            assertThat(callback.isSdkSandboxDeathDetected()).isTrue();
+        }
+    }
+
+    @Test
+    public void testAddSdkSandboxLifecycleCallback_BeforeStartingSandbox() throws Exception {
+        // Kill the sandbox if it already exists from previous tests
+        killSandboxIfExists();
+
+        // Add a sandbox lifecycle callback before starting the sandbox
+        FakeSdkSandboxLifecycleCallback lifecycleCallback = new FakeSdkSandboxLifecycleCallback();
+        mSdkSandboxManager.addSdkSandboxLifecycleCallback(Runnable::run, lifecycleCallback);
+
+        // Bring up the sandbox
+        final String sdkName = "com.android.loadSdkSuccessfullySdkProvider";
+        final FakeLoadSdkCallback callback = new FakeLoadSdkCallback();
+        mSdkSandboxManager.loadSdk(sdkName, new Bundle(), Runnable::run, callback);
+        assertThat(callback.isLoadSdkSuccessful()).isTrue();
+
+        killSandbox();
+        assertThat(lifecycleCallback.isSdkSandboxDeathDetected()).isTrue();
+    }
+
+    @Test
+    public void testAddSdkSandboxLifecycleCallback_AfterStartingSandbox() throws Exception {
+        // Bring up the sandbox
+        final String sdkName = "com.android.loadSdkSuccessfullySdkProvider";
+        final FakeLoadSdkCallback callback = new FakeLoadSdkCallback();
+        mSdkSandboxManager.loadSdk(sdkName, new Bundle(), Runnable::run, callback);
+        assertThat(callback.isLoadSdkSuccessful(/*ignoreSdkAlreadyLoadedError=*/ true)).isTrue();
+
+        // Add a sandbox lifecycle callback before starting the sandbox
+        FakeSdkSandboxLifecycleCallback lifecycleCallback = new FakeSdkSandboxLifecycleCallback();
+        mSdkSandboxManager.addSdkSandboxLifecycleCallback(Runnable::run, lifecycleCallback);
+
+        killSandbox();
+        assertThat(lifecycleCallback.isSdkSandboxDeathDetected()).isTrue();
+    }
+
+    @Test
+    public void testRegisterMultipleSdkSandboxLifecycleCallbacks() throws Exception {
+        // Kill the sandbox if it already exists from previous tests
+        killSandboxIfExists();
+
+        // Add a sandbox lifecycle callback before starting the sandbox
+        FakeSdkSandboxLifecycleCallback lifecycleCallback1 = new FakeSdkSandboxLifecycleCallback();
+        mSdkSandboxManager.addSdkSandboxLifecycleCallback(Runnable::run, lifecycleCallback1);
+
+        // Bring up the sandbox
+        final String sdkName = "com.android.loadSdkSuccessfullySdkProvider";
+        final FakeLoadSdkCallback callback = new FakeLoadSdkCallback();
+        mSdkSandboxManager.loadSdk(sdkName, new Bundle(), Runnable::run, callback);
+        assertThat(callback.isLoadSdkSuccessful()).isTrue();
+
+        // Add another sandbox lifecycle callback after starting it
+        FakeSdkSandboxLifecycleCallback lifecycleCallback2 = new FakeSdkSandboxLifecycleCallback();
+        mSdkSandboxManager.addSdkSandboxLifecycleCallback(Runnable::run, lifecycleCallback2);
+
+        killSandbox();
+        assertThat(lifecycleCallback1.isSdkSandboxDeathDetected()).isTrue();
+        assertThat(lifecycleCallback2.isSdkSandboxDeathDetected()).isTrue();
+    }
+
+    @Test
+    public void testRemoveSdkSandboxLifecycleCallback() throws Exception {
+        // Bring up the sandbox
+        final String sdkName = "com.android.loadSdkSuccessfullySdkProvider";
+        final FakeLoadSdkCallback callback = new FakeLoadSdkCallback();
+        mSdkSandboxManager.loadSdk(sdkName, new Bundle(), Runnable::run, callback);
+        assertThat(callback.isLoadSdkSuccessful(/*ignoreSdkAlreadyLoadedError=*/ true)).isTrue();
+
+        // Add and remove a sandbox lifecycle callback
+        FakeSdkSandboxLifecycleCallback lifecycleCallback1 = new FakeSdkSandboxLifecycleCallback();
+        mSdkSandboxManager.addSdkSandboxLifecycleCallback(Runnable::run, lifecycleCallback1);
+        mSdkSandboxManager.removeSdkSandboxLifecycleCallback(lifecycleCallback1);
+
+        // Add a lifecycle callback but don't remove it
+        FakeSdkSandboxLifecycleCallback lifecycleCallback2 = new FakeSdkSandboxLifecycleCallback();
+        mSdkSandboxManager.addSdkSandboxLifecycleCallback(Runnable::run, lifecycleCallback2);
+
+        killSandbox();
+        assertThat(lifecycleCallback1.isSdkSandboxDeathDetected()).isFalse();
+        assertThat(lifecycleCallback2.isSdkSandboxDeathDetected()).isTrue();
     }
 
     @Test
@@ -289,11 +375,19 @@
         assertThat(callback2.isLoadSdkSuccessful()).isTrue();
     }
 
-    private void killSandboxAndWaitForDeath() throws Exception {
+    // Returns true if the sandbox was already likely existing, false otherwise.
+    private boolean killSandboxIfExists() throws Exception {
+        FakeSdkSandboxLifecycleCallback callback = new FakeSdkSandboxLifecycleCallback();
+        mSdkSandboxManager.addSdkSandboxLifecycleCallback(Runnable::run, callback);
+        killSandbox();
+
+        return callback.isSdkSandboxDeathDetected();
+    }
+
+    private void killSandbox() throws Exception {
         // TODO(b/241542162): Avoid using reflection as a workaround once test apis can be run
         //  without issue.
         mSdkSandboxManager.getClass().getMethod("stopSdkSandbox").invoke(mSdkSandboxManager);
-        Thread.sleep(1000);
     }
 }
 
diff --git a/sdksandbox/tests/testutils/src/android/app/sdksandbox/testutils/FakeSdkSandboxLifecycleCallback.java b/sdksandbox/tests/testutils/src/android/app/sdksandbox/testutils/FakeSdkSandboxLifecycleCallback.java
new file mode 100644
index 0000000..ff19788
--- /dev/null
+++ b/sdksandbox/tests/testutils/src/android/app/sdksandbox/testutils/FakeSdkSandboxLifecycleCallback.java
@@ -0,0 +1,39 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.app.sdksandbox.testutils;
+
+import android.app.sdksandbox.SdkSandboxManager;
+
+public class FakeSdkSandboxLifecycleCallback
+        implements SdkSandboxManager.SdkSandboxLifecycleCallback {
+    public volatile boolean sandboxDeathDetected;
+
+    public FakeSdkSandboxLifecycleCallback() {
+        sandboxDeathDetected = false;
+    }
+
+    @Override
+    public void onSdkSandboxDied() {
+        sandboxDeathDetected = true;
+    }
+
+    public boolean isSdkSandboxDeathDetected() throws InterruptedException {
+        // Wait 5 seconds to determine whether sandbox death is ever detected.
+        Thread.sleep(5000);
+        return sandboxDeathDetected;
+    }
+}
diff --git a/sdksandbox/tests/testutils/src/android/app/sdksandbox/testutils/FakeSdkSandboxLifecycleCallbackBinder.java b/sdksandbox/tests/testutils/src/android/app/sdksandbox/testutils/FakeSdkSandboxLifecycleCallbackBinder.java
new file mode 100644
index 0000000..652c67e
--- /dev/null
+++ b/sdksandbox/tests/testutils/src/android/app/sdksandbox/testutils/FakeSdkSandboxLifecycleCallbackBinder.java
@@ -0,0 +1,37 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.app.sdksandbox.testutils;
+
+import android.app.sdksandbox.ISdkSandboxLifecycleCallback;
+import android.os.RemoteException;
+
+public class FakeSdkSandboxLifecycleCallbackBinder extends ISdkSandboxLifecycleCallback.Stub {
+    private final FakeSdkSandboxLifecycleCallback mFakeSdkSandboxLifecycleCallback;
+
+    public FakeSdkSandboxLifecycleCallbackBinder() {
+        mFakeSdkSandboxLifecycleCallback = new FakeSdkSandboxLifecycleCallback();
+    }
+
+    @Override
+    public void onSdkSandboxDied() throws RemoteException {
+        mFakeSdkSandboxLifecycleCallback.onSdkSandboxDied();
+    }
+
+    public boolean isSdkSandboxDeathDetected() throws InterruptedException {
+        return mFakeSdkSandboxLifecycleCallback.isSdkSandboxDeathDetected();
+    }
+}
diff --git a/sdksandbox/tests/testutils/src/android/app/sdksandbox/testutils/StubSdkSandboxManagerService.java b/sdksandbox/tests/testutils/src/android/app/sdksandbox/testutils/StubSdkSandboxManagerService.java
index 2a84dad..d275ccf 100644
--- a/sdksandbox/tests/testutils/src/android/app/sdksandbox/testutils/StubSdkSandboxManagerService.java
+++ b/sdksandbox/tests/testutils/src/android/app/sdksandbox/testutils/StubSdkSandboxManagerService.java
@@ -18,6 +18,7 @@
 
 import android.app.sdksandbox.ILoadSdkCallback;
 import android.app.sdksandbox.IRequestSurfacePackageCallback;
+import android.app.sdksandbox.ISdkSandboxLifecycleCallback;
 import android.app.sdksandbox.ISdkSandboxManager;
 import android.app.sdksandbox.ISendDataCallback;
 import android.content.pm.SharedLibraryInfo;
@@ -67,4 +68,12 @@
 
     @Override
     public void syncDataFromClient(String callingPackageName, Bundle data) {}
+
+    @Override
+    public void addSdkSandboxLifecycleCallback(
+            String callingPackageName, ISdkSandboxLifecycleCallback callback) {}
+
+    @Override
+    public void removeSdkSandboxLifecycleCallback(
+            String callingPackageName, ISdkSandboxLifecycleCallback callback) {}
 }
diff --git a/sdksandbox/tests/unittest/src/com/android/server/sdksandbox/SdkSandboxManagerServiceUnitTest.java b/sdksandbox/tests/unittest/src/com/android/server/sdksandbox/SdkSandboxManagerServiceUnitTest.java
index 4e86d9f..ab987f5 100644
--- a/sdksandbox/tests/unittest/src/com/android/server/sdksandbox/SdkSandboxManagerServiceUnitTest.java
+++ b/sdksandbox/tests/unittest/src/com/android/server/sdksandbox/SdkSandboxManagerServiceUnitTest.java
@@ -28,6 +28,7 @@
 import android.app.sdksandbox.SdkSandboxManager;
 import android.app.sdksandbox.testutils.FakeLoadSdkCallbackBinder;
 import android.app.sdksandbox.testutils.FakeRequestSurfacePackageCallbackBinder;
+import android.app.sdksandbox.testutils.FakeSdkSandboxLifecycleCallbackBinder;
 import android.content.ComponentName;
 import android.content.Context;
 import android.content.Intent;
@@ -108,7 +109,7 @@
         // Required to access <sdk-library> information.
         InstrumentationRegistry.getInstrumentation().getUiAutomation().adoptShellPermissionIdentity(
                 Manifest.permission.ACCESS_SHARED_LIBRARIES, Manifest.permission.INSTALL_PACKAGES);
-        mSdkSandboxService = new FakeSdkSandboxService();
+        mSdkSandboxService = Mockito.spy(FakeSdkSandboxService.class);
         mProvider = new FakeSdkSandboxProvider(mSdkSandboxService);
 
         // Populate LocalManagerRegistry
@@ -395,6 +396,129 @@
         assertThat(thrown).hasMessageThat().contains("Sdk " + SDK_NAME + " is not loaded");
     }
 
+    @Test
+    public void testAddSdkSandboxLifecycleCallback_BeforeStartingSandbox() throws Exception {
+        disableNetworkPermissionChecks();
+
+        // Register for sandbox death event
+        FakeSdkSandboxLifecycleCallbackBinder lifecycleCallback =
+                new FakeSdkSandboxLifecycleCallbackBinder();
+        mService.addSdkSandboxLifecycleCallback(TEST_PACKAGE, lifecycleCallback);
+
+        // Load SDK and start the sandbox
+        FakeLoadSdkCallbackBinder callback = new FakeLoadSdkCallbackBinder();
+        mService.loadSdk(TEST_PACKAGE, SDK_NAME, new Bundle(), callback);
+        mSdkSandboxService.sendLoadCodeSuccessful();
+        assertThat(callback.isLoadSdkSuccessful()).isTrue();
+
+        // Kill the sandbox
+        ArgumentCaptor<IBinder.DeathRecipient> deathRecipientCaptor =
+                ArgumentCaptor.forClass(IBinder.DeathRecipient.class);
+        Mockito.verify(mSdkSandboxService.asBinder(), Mockito.atLeastOnce())
+                .linkToDeath(deathRecipientCaptor.capture(), ArgumentMatchers.eq(0));
+        IBinder.DeathRecipient deathRecipient = deathRecipientCaptor.getValue();
+        deathRecipient.binderDied();
+
+        // Check that death is recorded correctly
+        assertThat(lifecycleCallback.isSdkSandboxDeathDetected()).isTrue();
+    }
+
+    @Test
+    public void testAddSdkSandboxLifecycleCallback_AfterStartingSandbox() throws Exception {
+        disableNetworkPermissionChecks();
+
+        // Load SDK and start the sandbox
+        FakeLoadSdkCallbackBinder callback = new FakeLoadSdkCallbackBinder();
+        mService.loadSdk(TEST_PACKAGE, SDK_NAME, new Bundle(), callback);
+        mSdkSandboxService.sendLoadCodeSuccessful();
+        assertThat(callback.isLoadSdkSuccessful()).isTrue();
+
+        // Register for sandbox death event
+        FakeSdkSandboxLifecycleCallbackBinder lifecycleCallback =
+                new FakeSdkSandboxLifecycleCallbackBinder();
+        mService.addSdkSandboxLifecycleCallback(TEST_PACKAGE, lifecycleCallback);
+
+        // Kill the sandbox
+        ArgumentCaptor<IBinder.DeathRecipient> deathRecipientCaptor =
+                ArgumentCaptor.forClass(IBinder.DeathRecipient.class);
+        Mockito.verify(mSdkSandboxService.asBinder(), Mockito.atLeastOnce())
+                .linkToDeath(deathRecipientCaptor.capture(), ArgumentMatchers.eq(0));
+        IBinder.DeathRecipient deathRecipient = deathRecipientCaptor.getValue();
+        deathRecipient.binderDied();
+
+        // Check that death is recorded correctly
+        assertThat(lifecycleCallback.isSdkSandboxDeathDetected()).isTrue();
+    }
+
+    @Test
+    public void testMultipleAddSdkSandboxLifecycleCallbacks() throws Exception {
+        disableNetworkPermissionChecks();
+
+        // Register for sandbox death event
+        FakeSdkSandboxLifecycleCallbackBinder lifecycleCallback1 =
+                new FakeSdkSandboxLifecycleCallbackBinder();
+        mService.addSdkSandboxLifecycleCallback(TEST_PACKAGE, lifecycleCallback1);
+
+        // Load SDK and start the sandbox
+        FakeLoadSdkCallbackBinder callback = new FakeLoadSdkCallbackBinder();
+        mService.loadSdk(TEST_PACKAGE, SDK_NAME, new Bundle(), callback);
+        mSdkSandboxService.sendLoadCodeSuccessful();
+        assertThat(callback.isLoadSdkSuccessful()).isTrue();
+
+        // Register for sandbox death event again
+        FakeSdkSandboxLifecycleCallbackBinder lifecycleCallback2 =
+                new FakeSdkSandboxLifecycleCallbackBinder();
+        mService.addSdkSandboxLifecycleCallback(TEST_PACKAGE, lifecycleCallback2);
+
+        // Kill the sandbox
+        ArgumentCaptor<IBinder.DeathRecipient> deathRecipientCaptor =
+                ArgumentCaptor.forClass(IBinder.DeathRecipient.class);
+        Mockito.verify(mSdkSandboxService.asBinder(), Mockito.atLeastOnce())
+                .linkToDeath(deathRecipientCaptor.capture(), ArgumentMatchers.eq(0));
+        IBinder.DeathRecipient deathRecipient = deathRecipientCaptor.getValue();
+        deathRecipient.binderDied();
+
+        // Check that death is recorded correctly
+        assertThat(lifecycleCallback1.isSdkSandboxDeathDetected()).isTrue();
+        assertThat(lifecycleCallback2.isSdkSandboxDeathDetected()).isTrue();
+    }
+
+    @Test
+    public void testRemoveSdkSandboxLifecycleCallback() throws Exception {
+        disableNetworkPermissionChecks();
+
+        // Load SDK and start the sandbox
+        FakeLoadSdkCallbackBinder callback = new FakeLoadSdkCallbackBinder();
+        mService.loadSdk(TEST_PACKAGE, SDK_NAME, new Bundle(), callback);
+        mSdkSandboxService.sendLoadCodeSuccessful();
+        assertThat(callback.isLoadSdkSuccessful()).isTrue();
+
+        // Register for sandbox death event
+        FakeSdkSandboxLifecycleCallbackBinder lifecycleCallback1 =
+                new FakeSdkSandboxLifecycleCallbackBinder();
+        mService.addSdkSandboxLifecycleCallback(TEST_PACKAGE, lifecycleCallback1);
+
+        // Register for sandbox death event again
+        FakeSdkSandboxLifecycleCallbackBinder lifecycleCallback2 =
+                new FakeSdkSandboxLifecycleCallbackBinder();
+        mService.addSdkSandboxLifecycleCallback(TEST_PACKAGE, lifecycleCallback2);
+
+        // Unregister one of the lifecycle callbacks
+        mService.removeSdkSandboxLifecycleCallback(TEST_PACKAGE, lifecycleCallback1);
+
+        // Kill the sandbox
+        ArgumentCaptor<IBinder.DeathRecipient> deathRecipientCaptor =
+                ArgumentCaptor.forClass(IBinder.DeathRecipient.class);
+        Mockito.verify(mSdkSandboxService.asBinder(), Mockito.atLeastOnce())
+                .linkToDeath(deathRecipientCaptor.capture(), ArgumentMatchers.eq(0));
+        IBinder.DeathRecipient deathRecipient = deathRecipientCaptor.getValue();
+        deathRecipient.binderDied();
+
+        // Check that death is recorded correctly
+        assertThat(lifecycleCallback1.isSdkSandboxDeathDetected()).isFalse();
+        assertThat(lifecycleCallback2.isSdkSandboxDeathDetected()).isTrue();
+    }
+
     @Test(expected = SecurityException.class)
     public void testDumpWithoutPermission() {
         mService.dump(new FileDescriptor(), new PrintWriter(new StringWriter()), new String[0]);