Use IBinder instead of Proxy object for callback equivalence

When canceling a callback, the object is passed into our binder
implementation. We were comparing the callbacks directly, but there
is no guarantee that the same remote binder will be comparable via
the callback types. Instead, call asBinder to get the underlying
binder. This is guaranteed to be the same when there are multiple
calls into our backend from the same client.

Bug: 275653768
Test: RkpdAppUnitTests RkpdAppIntegrationTests RkpdAppHostTests
Test: RemoteProvisioningServiceTests keystore2_test
Change-Id: I5d67068ffb65498ba290e7e2d9cda0ceb94be0fc
diff --git a/app/src/com/android/rkpdapp/service/RegistrationBinder.java b/app/src/com/android/rkpdapp/service/RegistrationBinder.java
index 9402d63..4631c57 100644
--- a/app/src/com/android/rkpdapp/service/RegistrationBinder.java
+++ b/app/src/com/android/rkpdapp/service/RegistrationBinder.java
@@ -17,6 +17,7 @@
 package com.android.rkpdapp.service;
 
 import android.content.Context;
+import android.os.IBinder;
 import android.os.RemoteException;
 import android.util.Log;
 
@@ -67,7 +68,7 @@
     private final ExecutorService mThreadPool;
     private final Object mTasksLock = new Object();
     @GuardedBy("mTasksLock")
-    private final HashMap<IGetKeyCallback, Future<?>> mTasks = new HashMap<>();
+    private final HashMap<IBinder, Future<?>> mTasks = new HashMap<>();
 
     public RegistrationBinder(Context context, int clientUid, SystemInterface systemInterface,
             ProvisionedKeyDao provisionedKeyDao, ServerInterface rkpServer,
@@ -84,7 +85,8 @@
     private void getKeyWorker(int keyId, IGetKeyCallback callback)
             throws CborException, InterruptedException, RkpdException {
         Log.i(TAG, "Key requested for : " + mSystemInterface.getServiceName() + ", clientUid: "
-                + mClientUid + ", keyId: " + keyId + ", callback: " + callback.hashCode());
+                + mClientUid + ", keyId: " + keyId + ", callback: "
+                + callback.asBinder().hashCode());
         // Use reduced look-ahead to get rid of soon-to-be expired keys, because the periodic
         // provisioner should be ensuring that old keys are already expired. However, in the
         // edge case that periodic provisioning didn't work, we want to allow slightly "more stale"
@@ -212,12 +214,13 @@
     @Override
     public void getKey(int keyId, IGetKeyCallback callback) {
         synchronized (mTasksLock) {
-            if (mTasks.containsKey(callback)) {
-                throw new IllegalArgumentException("Callback " + callback.hashCode()
+            if (mTasks.containsKey(callback.asBinder())) {
+                throw new IllegalArgumentException("Callback " + callback.asBinder().hashCode()
                         + " is already associated with a getKey operation that is in-progress");
             }
 
-            mTasks.put(callback, mThreadPool.submit(() -> getKeyThreadWorker(keyId, callback)));
+            mTasks.put(callback.asBinder(),
+                    mThreadPool.submit(() -> getKeyThreadWorker(keyId, callback)));
         }
     }
 
@@ -250,7 +253,7 @@
         } finally {
             metric.close();
             synchronized (mTasksLock) {
-                mTasks.remove(callback);
+                mTasks.remove(callback.asBinder());
             }
         }
     }
@@ -281,11 +284,11 @@
 
     @Override
     public void cancelGetKey(IGetKeyCallback callback) throws RemoteException {
-        Log.i(TAG, "cancelGetKey(" + callback.hashCode() + ")");
+        Log.i(TAG, "cancelGetKey(" + callback.asBinder().hashCode() + ")");
         synchronized (mTasksLock) {
             try (RkpdClientOperation metric = RkpdClientOperation.cancelGetKey(mClientUid,
                     mSystemInterface.getServiceName())) {
-                Future<?> task = mTasks.get(callback);
+                Future<?> task = mTasks.get(callback.asBinder());
 
                 if (task == null) {
                     Log.w(TAG, "callback not found, task may have already completed");
diff --git a/app/tests/e2e/src/com/android/rkpdapp/e2etest/KeystoreIntegrationTest.java b/app/tests/e2e/src/com/android/rkpdapp/e2etest/KeystoreIntegrationTest.java
index 6c3987e..5a8e728 100644
--- a/app/tests/e2e/src/com/android/rkpdapp/e2etest/KeystoreIntegrationTest.java
+++ b/app/tests/e2e/src/com/android/rkpdapp/e2etest/KeystoreIntegrationTest.java
@@ -23,6 +23,8 @@
 import static com.google.common.truth.Truth.assertWithMessage;
 import static com.google.common.truth.TruthJUnit.assume;
 
+import static org.junit.Assert.assertThrows;
+
 import android.content.Context;
 import android.hardware.security.keymint.IRemotelyProvisionedComponent;
 import android.os.Process;
@@ -324,6 +326,35 @@
         }
     }
 
+    @Test
+    public void testCancelDueToServiceTimeout() throws Exception {
+        FakeRkpServer.RequestHandler blocksForOneMinute = (session, bodySize) -> {
+            session.getInputStream().readNBytes(bodySize);
+            try {
+                Thread.sleep(60 * 1000);
+            } catch (InterruptedException e) {
+                assertWithMessage("sleep failed", e).fail();
+            }
+            return null;
+        };
+
+        try (SystemPropertySetter ignored = SystemPropertySetter.setRkpOnly(mInstanceName);
+             FakeRkpServer server = new FakeRkpServer(blocksForOneMinute)) {
+            Settings.setDeviceConfig(sContext, 1, Duration.ofDays(1), server.getUrl());
+
+            // keystore will time out well before a minute has passed
+            ProviderException e = assertThrows(ProviderException.class, this::createKeystoreKey);
+
+            assertThat(e).hasCauseThat().isInstanceOf(KeyStoreException.class);
+            KeyStoreException keyStoreException = (KeyStoreException) e.getCause();
+            assertThat(keyStoreException.getErrorCode())
+                    .isEqualTo(ResponseCode.OUT_OF_KEYS_TRANSIENT_ERROR);
+            assertThat(keyStoreException.getRetryPolicy())
+                    .isEqualTo(KeyStoreException.RETRY_WITH_EXPONENTIAL_BACKOFF);
+            assertThat(keyStoreException.isTransientFailure()).isTrue();
+        }
+    }
+
     private void provisionFreshKeys() {
         PeriodicProvisioner provisioner = TestWorkerBuilder.from(
                 sContext,
diff --git a/app/tests/unit/src/com/android/rkpdapp/unittest/RegistrationBinderTest.java b/app/tests/unit/src/com/android/rkpdapp/unittest/RegistrationBinderTest.java
index 1e2d70e..538c22d 100644
--- a/app/tests/unit/src/com/android/rkpdapp/unittest/RegistrationBinderTest.java
+++ b/app/tests/unit/src/com/android/rkpdapp/unittest/RegistrationBinderTest.java
@@ -26,6 +26,7 @@
 import static org.mockito.ArgumentMatchers.notNull;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.argThat;
+import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.contains;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doReturn;
@@ -39,6 +40,8 @@
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 
 import android.content.Context;
+import android.os.Binder;
+import android.os.IBinder;
 
 import androidx.test.core.app.ApplicationProvider;
 import androidx.test.ext.junit.runners.AndroidJUnit4;
@@ -152,8 +155,10 @@
                 .getKeyForClientAndIrpc(IRPC_HAL, CLIENT_UID, KEY_ID);
 
         IGetKeyCallback callback = mock(IGetKeyCallback.class);
+        doReturn(new Binder()).when(callback).asBinder();
         mRegistration.getKey(KEY_ID, callback);
         completeAllTasks();
+        verify(callback, atLeastOnce()).asBinder();
         verify(callback).onSuccess(matches(FAKE_KEY));
         verifyNoMoreInteractions(callback);
     }
@@ -165,8 +170,10 @@
                 .getOrAssignKey(eq(IRPC_HAL), notNull(), eq(CLIENT_UID), eq(KEY_ID));
 
         IGetKeyCallback callback = mock(IGetKeyCallback.class);
+        doReturn(new Binder()).when(callback).asBinder();
         mRegistration.getKey(KEY_ID, callback);
         completeAllTasks();
+        verify(callback, atLeastOnce()).asBinder();
         verify(callback).onSuccess(matches(FAKE_KEY));
         verifyNoMoreInteractions(callback);
     }
@@ -178,7 +185,9 @@
                 .getOrAssignKey(eq(IRPC_HAL), notNull(), eq(CLIENT_UID), eq(KEY_ID));
 
         Instant minExpiry = Instant.now().plus(Settings.getExpiringBy(mContext));
-        mRegistration.getKey(KEY_ID, mock(IGetKeyCallback.class));
+        IGetKeyCallback callback = mock(IGetKeyCallback.class);
+        doReturn(new Binder()).when(callback).asBinder();
+        mRegistration.getKey(KEY_ID, callback);
         completeAllTasks();
         Instant maxExpiry = Instant.now().plus(Settings.getExpiringBy(mContext));
 
@@ -194,7 +203,9 @@
 
         Instant minExpiry = Instant.now().plus(Settings.getExpiringBy(mContext));
         Instant minFallbackExpiry = Instant.now().plus(RegistrationBinder.MIN_KEY_LIFETIME);
-        mRegistration.getKey(KEY_ID, mock(IGetKeyCallback.class));
+        IGetKeyCallback callback = mock(IGetKeyCallback.class);
+        doReturn(new Binder()).when(callback).asBinder();
+        mRegistration.getKey(KEY_ID, callback);
         completeAllTasks();
         Instant maxExpiry = Instant.now().plus(Settings.getExpiringBy(mContext));
         Instant maxFallbackExpiry = Instant.now().plus(RegistrationBinder.MIN_KEY_LIFETIME);
@@ -215,11 +226,13 @@
                 .getOrAssignKey(eq(IRPC_HAL), notNull(), eq(CLIENT_UID), eq(KEY_ID));
 
         IGetKeyCallback callback = mock(IGetKeyCallback.class);
+        doReturn(new Binder()).when(callback).asBinder();
         mRegistration.getKey(KEY_ID, callback);
         completeAllTasks();
         verify(callback).onSuccess(matches(FAKE_KEY));
         verify(callback).onProvisioningNeeded();
         verify(mMockProvisioner).provisionKeys(any(), any(), same(mFakeGeekResponse));
+        verify(callback, atLeastOnce()).asBinder();
         verifyNoMoreInteractions(callback);
     }
 
@@ -230,10 +243,12 @@
                 .provisionKeys(any(), any(), any());
 
         IGetKeyCallback callback = mock(IGetKeyCallback.class);
+        doReturn(new Binder()).when(callback).asBinder();
         mRegistration.getKey(KEY_ID, callback);
         completeAllTasks();
         verify(callback).onError(IGetKeyCallback.Error.ERROR_UNKNOWN, "PROVISIONING FAIL");
         verify(callback).onProvisioningNeeded();
+        verify(callback, atLeastOnce()).asBinder();
         verifyNoMoreInteractions(callback);
     }
 
@@ -244,6 +259,7 @@
                 .provisionKeys(any(), any(), any());
 
         IGetKeyCallback callback = mock(IGetKeyCallback.class);
+        doReturn(new Binder()).when(callback).asBinder();
         mRegistration.getKey(KEY_ID, callback);
         completeAllTasks();
         verify(callback).onError(IGetKeyCallback.Error.ERROR_UNKNOWN, "FAIL");
@@ -256,6 +272,7 @@
                 .provisionKeys(any(), any(), any());
 
         IGetKeyCallback callback = mock(IGetKeyCallback.class);
+        doReturn(new Binder()).when(callback).asBinder();
         mRegistration.getKey(KEY_ID, callback);
         completeAllTasks();
         verify(callback).onError(IGetKeyCallback.Error.ERROR_PENDING_INTERNET_CONNECTIVITY,
@@ -286,12 +303,14 @@
                     .provisionKeys(any(), any(), any());
 
             IGetKeyCallback callback = mock(IGetKeyCallback.class);
+            doReturn(new Binder()).when(callback).asBinder();
             mRegistration.getKey(KEY_ID, callback);
             // We cannot use completeAllTasks here because that shuts down the thread pool,
             // so use a timeout on verifying the callback instead.
             verify(callback, timeout(MAX_TIMEOUT.toMillis()))
                     .onError(getExpectedGetKeyError(errorCode), errorCode.toString());
             verify(callback).onProvisioningNeeded();
+            verify(callback, atLeastOnce()).asBinder();
             verifyNoMoreInteractions(callback);
         }
     }
@@ -303,7 +322,9 @@
                 .getKeyForClientAndIrpc(IRPC_HAL, CLIENT_UID, KEY_ID);
 
         Instant minExpiry = Instant.now().plus(RegistrationBinder.MIN_KEY_LIFETIME);
-        mRegistration.getKey(KEY_ID, mock(IGetKeyCallback.class));
+        IGetKeyCallback callback = mock(IGetKeyCallback.class);
+        doReturn(new Binder()).when(callback).asBinder();
+        mRegistration.getKey(KEY_ID, callback);
         completeAllTasks();
         Instant maxExpiry = Instant.now().plus(RegistrationBinder.MIN_KEY_LIFETIME);
 
@@ -315,11 +336,13 @@
         // This test ensures that getKey will handle the case in which provisioner doesn't error
         // out, but it also does not actually provision any keys.
         IGetKeyCallback callback = mock(IGetKeyCallback.class);
+        doReturn(new Binder()).when(callback).asBinder();
         mRegistration.getKey(KEY_ID, callback);
         completeAllTasks();
         verify(callback).onError(IGetKeyCallback.Error.ERROR_UNKNOWN,
                 "Provisioning failed, no keys available");
         verify(callback).onProvisioningNeeded();
+        verify(callback, atLeastOnce()).asBinder();
         verify(mMockProvisioner).provisionKeys(any(), any(), any());
         verify(mRkpServer).fetchGeekAndUpdate(any());
         verifyNoMoreInteractions(callback);
@@ -329,11 +352,13 @@
     public void getKeyDisableProvisioningIsHonored() throws Exception {
         mFakeGeekResponse.numExtraAttestationKeys = 0;
         IGetKeyCallback callback = mock(IGetKeyCallback.class);
+        doReturn(new Binder()).when(callback).asBinder();
         mRegistration.getKey(KEY_ID, callback);
         completeAllTasks();
         verify(callback).onError(IGetKeyCallback.Error.ERROR_UNKNOWN,
                 "Provisioning failed, no keys available");
         verify(callback).onProvisioningNeeded();
+        verify(callback, atLeastOnce()).asBinder();
         verify(mRkpServer).fetchGeekAndUpdate(any());
         verify(mMockProvisioner, never()).provisionKeys(any(), any(), any());
         verifyNoMoreInteractions(callback);
@@ -349,6 +374,7 @@
                 .isProvisioningNeeded(notNull(), eq(IRPC_HAL));
 
         IGetKeyCallback callback = mock(IGetKeyCallback.class);
+        doReturn(new Binder()).when(callback).asBinder();
         mRegistration.getKey(KEY_ID, callback);
 
         // We cannot complete all tasks until after the get key worker task completes, because
@@ -371,6 +397,7 @@
                 .isProvisioningNeeded(notNull(), eq(IRPC_HAL));
 
         IGetKeyCallback callback = mock(IGetKeyCallback.class);
+        doReturn(new Binder()).when(callback).asBinder();
         mRegistration.getKey(KEY_ID, callback);
 
         // We cannot complete all tasks until after the get key worker task completes, because
@@ -384,12 +411,19 @@
 
     @Test
     public void getKeyHandlesCancelBeforeProvisioning() throws Exception {
+        final IBinder theBinder = new Binder();
         IGetKeyCallback callback = mock(IGetKeyCallback.class);
+        doReturn(theBinder).when(callback).asBinder();
         AtomicBoolean allowCancel = new AtomicBoolean(true);
         doAnswer(
                 answer((hal, minExpiry, uid, keyId) -> {
                     if (allowCancel.getAndSet(false)) {
-                        mRegistration.cancelGetKey(callback);
+                        // Use a different callback object that wraps the same binder to ensure
+                        // that the underlying code is matching based on binder, not the callback.
+                        IGetKeyCallback differentCallback = mock(IGetKeyCallback.class);
+                        doReturn(theBinder).when(differentCallback).asBinder();
+                        mRegistration.cancelGetKey(differentCallback);
+                        verify(differentCallback, atLeastOnce()).asBinder();
                     }
                     return null;
                 }))
@@ -399,6 +433,7 @@
 
         completeAllTasks();
         verify(callback).onCancel();
+        verify(callback, atLeastOnce()).asBinder();
         verifyNoMoreInteractions(mMockProvisioner);
         verifyNoMoreInteractions(callback);
     }
@@ -409,24 +444,29 @@
         doAnswer(answerVoid((hal, dao, metrics) -> mRegistration.cancelGetKey(callback)))
                 .when(mMockProvisioner)
                 .provisionKeys(any(), any(), any());
+        doReturn(new Binder()).when(callback).asBinder();
         mRegistration.getKey(KEY_ID, callback);
 
         completeAllTasks();
         verify(callback).onCancel();
         verify(callback).onProvisioningNeeded();
+        verify(callback, atLeastOnce()).asBinder();
         verifyNoMoreInteractions(callback);
     }
 
     @Test
     public void getKeyHandlesCancelOfInvalidCallback() throws Exception {
         IGetKeyCallback callback = mock(IGetKeyCallback.class);
+        doReturn(new Binder()).when(callback).asBinder();
         mRegistration.cancelGetKey(callback);
+        verify(callback, atLeastOnce()).asBinder();
         verifyNoMoreInteractions(callback);
     }
 
     @Test
     public void getKeyHandlesInterruptedException() throws Exception {
         IGetKeyCallback callback = mock(IGetKeyCallback.class);
+        doReturn(new Binder()).when(callback).asBinder();
         doThrow(new InterruptedException())
                 .when(mMockProvisioner)
                 .provisionKeys(any(), any(), any());
@@ -435,6 +475,7 @@
         completeAllTasks();
         verify(callback).onCancel();
         verify(callback).onProvisioningNeeded();
+        verify(callback, atLeastOnce()).asBinder();
         verifyNoMoreInteractions(callback);
     }
 
@@ -452,6 +493,7 @@
                 .getKeyForClientAndIrpc(IRPC_HAL, CLIENT_UID, KEY_ID);
 
         IGetKeyCallback callback = mock(IGetKeyCallback.class);
+        doReturn(new Binder()).when(callback).asBinder();
         mRegistration.getKey(KEY_ID, callback);
         assertThrows(IllegalArgumentException.class, () -> mRegistration.getKey(KEY_ID, callback));
         getKeyBlocker.countDown();
@@ -471,9 +513,11 @@
                 .getKeyForClientAndIrpc(IRPC_HAL, CLIENT_UID, KEY_ID);
 
         IGetKeyCallback successfulCallback = mock(IGetKeyCallback.class);
+        doReturn(new Binder()).when(successfulCallback).asBinder();
         mRegistration.getKey(KEY_ID, successfulCallback);
 
         IGetKeyCallback cancelMe = mock(IGetKeyCallback.class);
+        doReturn(new Binder()).when(cancelMe).asBinder();
         mRegistration.getKey(KEY_ID, cancelMe);
 
         assertThat(getKeyEnteredTwice.await(MAX_TIMEOUT.toMillis(), TimeUnit.MILLISECONDS))
@@ -484,9 +528,11 @@
 
         completeAllTasks();
         verify(successfulCallback).onSuccess(matches(FAKE_KEY));
+        verify(successfulCallback, atLeastOnce()).asBinder();
         verifyNoMoreInteractions(successfulCallback);
 
         verify(cancelMe).onCancel();
+        verify(cancelMe, atLeastOnce()).asBinder();
         verifyNoMoreInteractions(cancelMe);
     }