Send last received message to recipient

Eliminate race condition of a message being sent prior
to the callback being registered.

Bug: 144163031
Test: Unit tests pass
Change-Id: Ic58154dda962e1eff463cacf1fe12e138c853d3f
diff --git a/connected-device-lib/src/com/android/car/connecteddevice/ConnectedDeviceManager.java b/connected-device-lib/src/com/android/car/connecteddevice/ConnectedDeviceManager.java
index c4e392e..0f57de2 100644
--- a/connected-device-lib/src/com/android/car/connecteddevice/ConnectedDeviceManager.java
+++ b/connected-device-lib/src/com/android/car/connecteddevice/ConnectedDeviceManager.java
@@ -93,6 +93,10 @@
     private final Map<String, InternalConnectedDevice> mConnectedDevices =
             new ConcurrentHashMap<>();
 
+    // recipientId -> (deviceId -> message bytes)
+    private final Map<UUID, Map<String, byte[]>> mRecipientMissedMessages =
+            new ConcurrentHashMap<>();
+
     // Recipient ids that received multiple callback registrations indicate that the recipient id
     // has been compromised. Another party now has access the messages intended for that recipient.
     // As a safeguard, that recipient id will be added to this list and blocked from further
@@ -313,6 +317,12 @@
         ThreadSafeCallbacks<DeviceCallback> newCallbacks = new ThreadSafeCallbacks<>();
         newCallbacks.add(callback, executor);
         recipientCallbacks.put(recipientId, newCallbacks);
+
+        byte[] message = popMissedMessage(recipientId, device.getDeviceId());
+        if (message != null) {
+            newCallbacks.invoke(deviceCallback ->
+                    deviceCallback.onMessageReceived(device, message));
+        }
     }
 
     private void notifyOfBlacklisting(@NonNull ConnectedDevice device, @NonNull UUID recipientId,
@@ -323,6 +333,32 @@
                 callback.onDeviceError(device, DEVICE_ERROR_INSECURE_RECIPIENT_ID_DETECTED));
     }
 
+    private void saveMissedMessage(@NonNull String deviceId, @NonNull UUID recipientId,
+            @NonNull byte[] message) {
+        // Store last message in case recipient registers callbacks in the future.
+        mRecipientMissedMessages.putIfAbsent(recipientId, new HashMap<>());
+        mRecipientMissedMessages.get(recipientId).putIfAbsent(deviceId, message);
+    }
+
+    /**
+     * Remove the last message sent for this device prior to a {@link DeviceCallback} being
+     * registered.
+     *
+     * @param recipientId Recipient's id
+     * @param deviceId Device id
+     * @return The last missed {@code byte[]} of the message, or {@code null} if no messages were
+     *         missed.
+     */
+    @Nullable
+    private byte[] popMissedMessage(@NonNull UUID recipientId, @NonNull String deviceId) {
+        Map<String, byte[]> missedMessages = mRecipientMissedMessages.get(recipientId);
+        if (missedMessages == null) {
+            return null;
+        }
+
+        return missedMessages.remove(deviceId);
+    }
+
     /**
      * Unregister callback from device events.
      *
@@ -509,14 +545,17 @@
                     + "recipient " + message.getRecipient() + ".");
             return;
         }
+        UUID recipientId = message.getRecipient();
         Map<UUID, ThreadSafeCallbacks<DeviceCallback>> deviceCallbacks =
                 mDeviceCallbacks.get(deviceId);
         if (deviceCallbacks == null) {
+            saveMissedMessage(deviceId, recipientId, message.getMessage());
             return;
         }
         ThreadSafeCallbacks<DeviceCallback> recipientCallbacks =
-                deviceCallbacks.get(message.getRecipient());
+                deviceCallbacks.get(recipientId);
         if (recipientCallbacks == null) {
+            saveMissedMessage(deviceId, recipientId, message.getMessage());
             return;
         }
 
diff --git a/connected-device-lib/tests/unit/src/com/android/car/connecteddevice/ConnectedDeviceManagerTest.java b/connected-device-lib/tests/unit/src/com/android/car/connecteddevice/ConnectedDeviceManagerTest.java
index 20520ef..f8a546e 100644
--- a/connected-device-lib/tests/unit/src/com/android/car/connecteddevice/ConnectedDeviceManagerTest.java
+++ b/connected-device-lib/tests/unit/src/com/android/car/connecteddevice/ConnectedDeviceManagerTest.java
@@ -423,6 +423,58 @@
         assertThat(tryAcquire(semaphore)).isFalse();
     }
 
+    @Test
+    public void registerDeviceCallback_sendsMissedMessageAfterRegistration()
+            throws InterruptedException {
+        Semaphore semaphore = new Semaphore(0);
+        connectNewDevice(mMockCentralManager);
+        ConnectedDevice connectedDevice =
+                mConnectedDeviceManager.getActiveUserConnectedDevices().get(0);
+        byte[] payload = ByteUtils.randomBytes(10);
+        DeviceMessage message = new DeviceMessage(mRecipientId, false, payload);
+        mConnectedDeviceManager.onMessageReceived(connectedDevice.getDeviceId(), message);
+        DeviceCallback deviceCallback = createDeviceCallback(semaphore);
+        mConnectedDeviceManager.registerDeviceCallback(connectedDevice, mRecipientId,
+                deviceCallback, mCallbackExecutor);
+        assertThat(tryAcquire(semaphore)).isTrue();
+        verify(deviceCallback).onMessageReceived(connectedDevice, payload);
+    }
+
+    @Test
+    public void registerDeviceCallback_doesNotSendMissedMessageForDifferentRecipient()
+            throws InterruptedException {
+        Semaphore semaphore = new Semaphore(0);
+        connectNewDevice(mMockCentralManager);
+        ConnectedDevice connectedDevice =
+                mConnectedDeviceManager.getActiveUserConnectedDevices().get(0);
+        byte[] payload = ByteUtils.randomBytes(10);
+        DeviceMessage message = new DeviceMessage(UUID.randomUUID(), false, payload);
+        mConnectedDeviceManager.onMessageReceived(connectedDevice.getDeviceId(), message);
+        DeviceCallback deviceCallback = createDeviceCallback(semaphore);
+        mConnectedDeviceManager.registerDeviceCallback(connectedDevice, mRecipientId,
+                deviceCallback, mCallbackExecutor);
+        assertThat(tryAcquire(semaphore)).isFalse();
+    }
+
+    @Test
+    public void registerDeviceCallback_doesNotSendMissedMessageForDifferentDevice()
+            throws InterruptedException {
+        Semaphore semaphore = new Semaphore(0);
+        connectNewDevice(mMockCentralManager);
+        connectNewDevice(mMockCentralManager);
+        List<ConnectedDevice> connectedDevices =
+                mConnectedDeviceManager.getActiveUserConnectedDevices();
+        ConnectedDevice connectedDevice = connectedDevices.get(0);
+        ConnectedDevice otherDevice = connectedDevices.get(1);
+        byte[] payload = ByteUtils.randomBytes(10);
+        DeviceMessage message = new DeviceMessage(mRecipientId, false, payload);
+        mConnectedDeviceManager.onMessageReceived(otherDevice.getDeviceId(), message);
+        DeviceCallback deviceCallback = createDeviceCallback(semaphore);
+        mConnectedDeviceManager.registerDeviceCallback(connectedDevice, mRecipientId,
+                deviceCallback, mCallbackExecutor);
+        assertThat(tryAcquire(semaphore)).isFalse();
+    }
+
     private boolean tryAcquire(Semaphore semaphore) throws InterruptedException {
         return semaphore.tryAcquire(100, TimeUnit.MILLISECONDS);
     }