VirtualMachine: Clean up when binder is gone

This CL includes the following changes:
  - Drop mVirtualMachine reference when virtmgr dies.
  - Only clean up mVirtualMachine in dropVm().
  - Remove code duplication for querying the latest VM status
    and updating internal states.
  - Merge the stop() and close() implementations as much as possible.

Bug: 421234619
Test: T/H, atest MicrodroidHostTestCases MicrodroidTestApp
Flag: EXEMPT bugfix
Change-Id: Icfba44e783a043c1478f77d48e5537114b37ba3e
diff --git a/libs/framework-virtualization/src/android/system/virtualmachine/VirtualMachine.java b/libs/framework-virtualization/src/android/system/virtualmachine/VirtualMachine.java
index 9194e98..800fcb3 100644
--- a/libs/framework-virtualization/src/android/system/virtualmachine/VirtualMachine.java
+++ b/libs/framework-virtualization/src/android/system/virtualmachine/VirtualMachine.java
@@ -118,7 +118,6 @@
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.LinkedBlockingQueue;
-import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Consumer;
 import java.util.zip.ZipFile;
 
@@ -240,6 +239,10 @@
     /** Name of the file with KEK used to setup encrypted store */
     private static final String ENCRYPTED_STORE_KEK_FILE = "encrypted_store_kek.bin";
 
+    // Both death recipients are linked to virtmgr binders, not virtualizationservice.
+    @NonNull private final IBinder.DeathRecipient mVSDeathRecipient;
+    @NonNull private final IBinder.DeathRecipient mVMDeathRecipient;
+
     /** The package which owns this VM. */
     @NonNull private final String mPackageName;
 
@@ -398,6 +401,9 @@
     @GuardedBy("mLock")
     private boolean mWasDeleted = false;
 
+    @GuardedBy("mLock")
+    private boolean mStopHandled = false;
+
     /** The registered callback */
     @GuardedBy("mCallbackLock")
     @Nullable
@@ -471,6 +477,18 @@
         } else {
             mMemoryManagementCallbacks = null;
         }
+
+        mVSDeathRecipient = () -> handleStopped(STOP_REASON_VIRTUALIZATION_SERVICE_DIED);
+        mVMDeathRecipient = () -> handleStopped(STOP_REASON_VIRTUALIZATION_SERVICE_DIED);
+        // Link the death recipient last so it can't fire before initialization completes.
+        try {
+            mVirtualizationService
+                    .getBinder()
+                    .asBinder()
+                    .linkToDeath(mVSDeathRecipient, /* flags= */ 0);
+        } catch (RemoteException e) {
+            throw e.rethrowAsRuntimeException();
+        }
     }
 
     /**
@@ -800,7 +818,7 @@
             // A VM can quite happily keep running if its backing files have been deleted.
             // But once it stops, it's gone forever.
             synchronized (mLock) {
-                dropVm();
+                mWasDeleted = true;
             }
             return STATUS_DELETED;
         }
@@ -824,23 +842,29 @@
     // Throw an appropriate exception if we have a running VM, or the VM has been deleted.
     @GuardedBy("mLock")
     private void checkStopped() throws VirtualMachineException {
-        if (mWasDeleted || !mVmRootPath.exists()) {
-            throw new VirtualMachineException("VM has been deleted");
-        }
-        if (mVirtualMachine == null) {
-            return;
-        }
-        try {
-            if (stateToStatus(mVirtualMachine.getState()) != STATUS_STOPPED) {
+        switch (getStatus()) {
+            case STATUS_STOPPED:
+                // no-op
+                return;
+            case STATUS_RUNNING:
                 throw new VirtualMachineException("VM is not in stopped state");
-            }
-        } catch (RemoteException e) {
-            throw e.rethrowAsRuntimeException();
+            case STATUS_DELETED:
+                throw new VirtualMachineException("VM has been deleted");
         }
-        // It's stopped, but we still have a reference to it - we can fix that.
-        dropVm();
     }
 
+    private void handleStopped(@VirtualMachineCallback.StopReason int reason) {
+        synchronized (mLock) {
+            if (mStopHandled) {
+                return;
+            }
+            mStopHandled = true;
+            dropVm();
+        }
+        executeCallback((cb) -> cb.onStopped(VirtualMachine.this, reason));
+    }
+
+
     /**
      * This should only be called when we know our VM has stopped; we no longer need to hold a
      * reference to it (this allows resources to be GC'd) and we no longer need to be informed of
@@ -855,25 +879,25 @@
         if (mMemoryManagementCallbacks != null) {
             mContext.unregisterComponentCallbacks(mMemoryManagementCallbacks);
         }
-        mVirtualMachine = null;
+        if (mVirtualMachine != null) {
+            mVirtualMachine.asBinder().unlinkToDeath(mVMDeathRecipient, /* flags= */ 0);
+            mVirtualMachine = null;
+        }
     }
 
     /** If we have an IVirtualMachine in the running state return it, otherwise throw. */
     @GuardedBy("mLock")
     private IVirtualMachine getRunningVm() throws VirtualMachineException {
-        try {
-            if (mVirtualMachine != null
-                    && stateToStatus(mVirtualMachine.getState()) == STATUS_RUNNING) {
+        switch (getStatus()) {
+            case STATUS_STOPPED:
+                throw new VirtualMachineException("VM is not in running state");
+            case STATUS_RUNNING:
                 return mVirtualMachine;
-            } else {
-                if (mWasDeleted || !mVmRootPath.exists()) {
-                    throw new VirtualMachineException("VM has been deleted");
-                } else {
-                    throw new VirtualMachineException("VM is not in running state");
-                }
-            }
-        } catch (RemoteException e) {
-            throw e.rethrowAsRuntimeException();
+            case STATUS_DELETED:
+                throw new VirtualMachineException("VM has been deleted");
+            default:
+                // Unreachable; only here to satisfy the compiler.
+                return null;
         }
     }
 
@@ -1525,6 +1549,8 @@
         synchronized (mLock) {
             checkStopped();
 
+            mStopHandled = false;
+
             try {
                 mIdsigFilePath.createNewFile();
                 for (ExtraApkSpec extraApk : mExtraApks) {
@@ -1625,6 +1651,7 @@
                         service.createVm(
                                 vmConfigParcel, consoleOutFd, consoleInFd, mLogWriter, null);
                 mVirtualMachine.registerCallback(new CallbackTranslator(this, service));
+                mVirtualMachine.asBinder().linkToDeath(mVMDeathRecipient, /* flags= */ 0);
                 if (mMemoryManagementCallbacks != null) {
                     mContext.registerComponentCallbacks(mMemoryManagementCallbacks);
                 }
@@ -1865,18 +1892,19 @@
     @WorkerThread
     public void stop() throws VirtualMachineException {
         synchronized (mLock) {
-            if (mVirtualMachine == null) {
-                throw new VirtualMachineException("VM is not running");
-            }
-            try {
-                mVirtualMachine.stop();
-                dropVm();
-            } catch (RemoteException e) {
-                throw e.rethrowAsRuntimeException();
-            } catch (ServiceSpecificException e) {
-                throw new VirtualMachineException(e);
+            if (getStatus() == STATUS_RUNNING) {
+                try {
+                    mVirtualMachine.stop();
+                    dropVm();
+                    return;
+                } catch (RemoteException e) {
+                    throw e.rethrowAsRuntimeException();
+                } catch (ServiceSpecificException e) {
+                    throw new VirtualMachineException(e);
+                }
             }
         }
+        throw new VirtualMachineException("VM is not running");
     }
 
     /** @hide */
@@ -1923,20 +1951,17 @@
     @WorkerThread
     @Override
     public void close() {
-        synchronized (mLock) {
-            if (mVirtualMachine == null) {
-                return;
-            }
-            try {
-                if (stateToStatus(mVirtualMachine.getState()) == STATUS_RUNNING) {
-                    mVirtualMachine.stop();
-                    dropVm();
-                }
-            } catch (RemoteException | ServiceSpecificException e) {
-                // Deliberately ignored; this almost certainly means the VM exited just as
-                // we tried to stop it.
-                Log.i(TAG, "Ignoring error on close()", e);
-            }
+        try {
+            stop();
+
+            mVirtualizationService
+                    .getBinder()
+                    .asBinder()
+                    .unlinkToDeath(mVSDeathRecipient, /* flags= */ 0);
+        } catch (Exception e) {
+            // Deliberately ignored; stop() throws when the VM is not running, e.g. because
+            // it already stopped or exited just as we tried to stop it.
+            Log.i(TAG, "Ignoring error on close()", e);
         }
     }
 
@@ -2391,16 +2416,11 @@
     private static class CallbackTranslator extends IVirtualMachineCallback.Stub {
         private final WeakReference<VirtualMachine> mVirtualMachine;
         private final WeakReference<IVirtualizationService> mService;
-        private final DeathRecipient mDeathRecipient;
 
-        // The VM should only be observed to die once
-        private final AtomicBoolean mOnDiedCalled = new AtomicBoolean(false);
-
-        public CallbackTranslator(VirtualMachine virtualMachine, IVirtualizationService service) throws RemoteException {
+        public CallbackTranslator(VirtualMachine virtualMachine, IVirtualizationService service)
+                throws RemoteException {
             this.mVirtualMachine = new WeakReference<>(virtualMachine);
             this.mService = new WeakReference<>(service);
-            this.mDeathRecipient = () -> reportStopped(STOP_REASON_VIRTUALIZATION_SERVICE_DIED);
-            service.asBinder().linkToDeath(mDeathRecipient, 0);
         }
 
         @Override
@@ -2439,19 +2459,9 @@
         @Override
         public void onDied(int cid, int reason) {
             int translatedReason = getTranslatedReason(reason);
-            reportStopped(translatedReason);
-            var service = mService.get();
-            if (service != null) {
-                service.asBinder().unlinkToDeath(mDeathRecipient, 0);
-            }
-        }
-
-        private void reportStopped(@VirtualMachineCallback.StopReason int reason) {
-            if (mOnDiedCalled.compareAndSet(false, true)) {
-                VirtualMachine vm = mVirtualMachine.get();
-                if (vm != null) {
-                    vm.executeCallback((cb) -> cb.onStopped(vm, reason));
-                }
+            VirtualMachine vm = mVirtualMachine.get();
+            if (vm != null) {
+                vm.handleStopped(translatedReason);
             }
         }
 
diff --git a/tests/libs/device/src/java/com/android/microdroid/test/device/MicrodroidDeviceTestBase.java b/tests/libs/device/src/java/com/android/microdroid/test/device/MicrodroidDeviceTestBase.java
index 7492a60..4d9707e 100644
--- a/tests/libs/device/src/java/com/android/microdroid/test/device/MicrodroidDeviceTestBase.java
+++ b/tests/libs/device/src/java/com/android/microdroid/test/device/MicrodroidDeviceTestBase.java
@@ -40,6 +40,7 @@
 import android.system.virtualmachine.VirtualMachineConfig;
 import android.system.virtualmachine.VirtualMachineException;
 import android.system.virtualmachine.VirtualMachineManager;
+import android.text.TextUtils;
 import android.util.Log;
 
 import androidx.annotation.CallSuper;
@@ -298,6 +299,15 @@
         assume().withMessage("Secretkeeper not supported").that(isUpdatableVmSupported()).isFalse();
     }
 
+    protected void assumeDebuggableBuild() {
+        // SystemProperties can't be used here because of a sepolicy denial.
+        Instrumentation instrumentation = InstrumentationRegistry.getInstrumentation();
+        UiAutomation uiAutomation = instrumentation.getUiAutomation();
+        assume().withMessage("Test requires debuggable build")
+                .that(runInShell(TAG, uiAutomation, "getprop ro.debuggable").trim())
+                .isEqualTo("1");
+    }
+
     protected boolean isUpdatableVmSupported() throws VirtualMachineException {
         // Pre-36 OS doesn't have VirtualMachineManager#isUpdatableVmSupported.
         if (Build.VERSION.SDK_INT >= 35) {
@@ -634,6 +644,28 @@
         }
     }
 
+    protected void kill(String tag, String processName) {
+        Instrumentation instrumentation = InstrumentationRegistry.getInstrumentation();
+        UiAutomation uiAutomation = instrumentation.getUiAutomation();
+        uiAutomation.adoptShellPermissionIdentity();
+        try {
+            String pid = runInShell(TAG, uiAutomation, "pidof " + processName).trim();
+            if (TextUtils.isEmpty(pid)) {
+                Log.i(tag, "Process " + processName + " isn't running. Skipping kill()");
+                return;
+            }
+
+            String res = runInShellWithStderr(TAG, uiAutomation, "su 0 kill -9 " + pid).trim();
+            if (TextUtils.isEmpty(res)) {
+                Log.i(tag, "Process " + processName + " (pid=" + pid + ") is killed");
+            } else {
+                throw new RuntimeException("Failed to kill process. Is this a debuggable build?");
+            }
+        } finally {
+            uiAutomation.dropShellPermissionIdentity();
+        }
+    }
+
     protected static class TestResults {
         public Exception mException;
         public Integer mAddInteger;
diff --git a/tests/testapk/src/java/com/android/microdroid/test/MicrodroidTests.java b/tests/testapk/src/java/com/android/microdroid/test/MicrodroidTests.java
index 640ac4d..6c6a2af 100644
--- a/tests/testapk/src/java/com/android/microdroid/test/MicrodroidTests.java
+++ b/tests/testapk/src/java/com/android/microdroid/test/MicrodroidTests.java
@@ -119,6 +119,7 @@
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionException;
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.stream.Stream;
@@ -483,6 +484,9 @@
                 () -> getVirtualMachineManager().delete("test_vm"), "not in stopped state");
 
         vm.stop();
+
+        assertThat(vm.getStatus()).isEqualTo(STATUS_STOPPED);
+
         getVirtualMachineManager().delete("test_vm");
         assertThat(vm.getStatus()).isEqualTo(STATUS_DELETED);
 
@@ -499,6 +503,74 @@
         assertThrowsVmException(() -> getVirtualMachineManager().delete("test_vm"));
     }
 
+    private void testCrashInternal(String vmName, String killProcessName) throws Exception {
+        assumeSupportedDevice();
+
+        // kill() requires a debuggable build
+        assumeDebuggableBuild();
+
+        VirtualMachineConfig config =
+                newVmConfigBuilderWithPayloadBinary("MicrodroidTestNativeLib.so")
+                        .setMemoryBytes(minMemoryRequired())
+                        .setDebugLevel(DEBUG_LEVEL_FULL)
+                        .build();
+
+        try (VirtualMachine vm = forceCreateNewVirtualMachine(vmName, config)) {
+            assertThat(vm.getStatus()).isEqualTo(STATUS_STOPPED);
+
+            CountDownLatch started = new CountDownLatch(1);
+            CountDownLatch stopped = new CountDownLatch(1);
+            vm.setCallback(
+                    Executors.newSingleThreadExecutor(),
+                    new VirtualMachineCallback() {
+                        @Override
+                        public void onPayloadStarted(VirtualMachine vm) {
+                            started.countDown();
+                        }
+
+                        @Override
+                        public void onPayloadReady(VirtualMachine vm) {}
+
+                        @Override
+                        public void onPayloadFinished(VirtualMachine vm, int exitCode) {}
+
+                        @Override
+                        public void onError(VirtualMachine vm, int errorCode, String message) {}
+
+                        @Override
+                        public void onStopped(VirtualMachine vm, int reason) {
+                            stopped.countDown();
+                        }
+                    });
+
+            vm.run();
+            assertThat(vm.getStatus()).isEqualTo(STATUS_RUNNING);
+
+            // Let globalTimeout handle the timeout.
+            started.await();
+
+            kill(TAG, killProcessName);
+
+            // Let globalTimeout handle the timeout.
+            stopped.await();
+
+            assertThat(vm.getStatus()).isEqualTo(STATUS_STOPPED);
+
+            assertThrowsVmExceptionContaining(() -> vm.stop(), "not running");
+        }
+    }
+
+    @Test
+    public void vm_crosvmCrash() throws Exception {
+        String vmConfigName = "test_vm_crosvmCrash";
+        testCrashInternal(vmConfigName, "crosvm_" + vmConfigName);
+    }
+
+    @Test
+    public void vm_virtmgrCrash() throws Exception {
+        testCrashInternal("test_vm_virtmgrCrash", "virtmgr");
+    }
+
     @Test
     @CddTest
     public void connectVsock() throws Exception {