Make sure IPC calls complete before unbinding isolated service

Bug: 377754732
Test: manually tested with large TrainingExamplesOutputParcel; atest
OdpExampleStoreServiceTests
Flag: EXEMPT (bug 337358613)

Change-Id: Idf4eebdcb45a8b91e9e3a9c93e1600bdf3109e4b
diff --git a/framework/java/com/android/ondevicepersonalization/internal/util/BaseOdpParceledListSlice.java b/framework/java/com/android/ondevicepersonalization/internal/util/BaseOdpParceledListSlice.java
index 267a4de..138d9b0 100644
--- a/framework/java/com/android/ondevicepersonalization/internal/util/BaseOdpParceledListSlice.java
+++ b/framework/java/com/android/ondevicepersonalization/internal/util/BaseOdpParceledListSlice.java
@@ -21,7 +21,6 @@
 import android.os.Parcel;
 import android.os.Parcelable;
 import android.os.RemoteException;
-import android.util.Log;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -39,13 +38,9 @@
  * See b/17671747.
  */
 abstract class BaseOdpParceledListSlice<T> implements Parcelable {
-    private static final String TAG = "OdpParceledListSlice";
+    private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger();
     private static final boolean DEBUG = false;
-
-    /*
-     * TODO get this number from somewhere else. For now set it to a quarter of
-     * the 1MB limit.
-     */
+    private static final String TAG = BaseOdpParceledListSlice.class.getSimpleName();
     private static final int MAX_IPC_SIZE = IBinder.getSuggestedMaxIpcSizeBytes();
 
     private final List<T> mList;
@@ -60,7 +55,7 @@
     BaseOdpParceledListSlice(Parcel p, ClassLoader loader) {
         final int numItems = p.readInt();
         mList = new ArrayList<T>(numItems);
-        if (DEBUG) Log.d(TAG, "Retrieving " + numItems + " items");
+        if (DEBUG) sLogger.d(TAG + ": Retrieving " + numItems + " items");
         if (numItems <= 0) {
             return;
         }
@@ -83,7 +78,7 @@
 
             mList.add(parcelable);
 
-            if (DEBUG) Log.d(TAG, "Read inline #" + i + ": " + mList.get(mList.size() - 1));
+            if (DEBUG) sLogger.d(TAG + ": Read inline #" + i + ": " + mList.get(mList.size() - 1));
             i++;
         }
         if (i >= numItems) {
@@ -92,8 +87,8 @@
         final IBinder retriever = p.readStrongBinder();
         while (i < numItems) {
             if (DEBUG) {
-                Log.d(TAG,
-                        "Reading more @" + i + " of " + numItems + ": retriever=" + retriever);
+                sLogger.d(TAG
+                        + ": Reading more @" + i + " of " + numItems + ": retriever=" + retriever);
             }
             Parcel data = Parcel.obtain();
             Parcel reply = Parcel.obtain();
@@ -101,7 +96,8 @@
             try {
                 retriever.transact(IBinder.FIRST_CALL_TRANSACTION, data, reply, 0);
             } catch (RemoteException e) {
-                Log.w(TAG, "Failure retrieving array; only received " + i + " of " + numItems, e);
+                sLogger.w(e, TAG + ": Failure retrieving array; only received " + i + " of "
+                        + numItems);
                 return;
             }
             while (i < numItems && reply.readInt() != 0) {
@@ -110,7 +106,9 @@
 
                 mList.add(parcelable);
 
-                if (DEBUG) Log.d(TAG, "Read extra #" + i + ": " + mList.get(mList.size() - 1));
+                if (DEBUG) {
+                    sLogger.d(TAG + ": Read extra #" + i + ": " + mList.get(mList.size() - 1));
+                }
                 i++;
             }
             reply.recycle();
@@ -157,7 +155,7 @@
         final int numItems = mList.size();
         final int callFlags = flags;
         dest.writeInt(numItems);
-        if (DEBUG) Log.d(TAG, "Writing " + numItems + " items");
+        if (DEBUG) sLogger.d(TAG + ": Writing " + numItems + " items");
         if (numItems > 0) {
             final Class<?> listElementClass = mList.get(0).getClass();
             writeParcelableCreator(mList.get(0), dest);
@@ -169,7 +167,7 @@
                 verifySameType(listElementClass, parcelable.getClass());
                 writeElement(parcelable, dest, callFlags);
 
-                if (DEBUG) Log.d(TAG, "Wrote inline #" + i + ": " + mList.get(i));
+                if (DEBUG) sLogger.d(TAG + ": Wrote inline #" + i + ": " + mList.get(i));
                 i++;
             }
             if (i < numItems) {
@@ -182,7 +180,7 @@
                             return super.onTransact(code, data, reply, flags);
                         }
                         int i = data.readInt();
-                        if (DEBUG) Log.d(TAG, "Writing more @" + i + " of " + numItems);
+                        if (DEBUG) sLogger.d(TAG + ": Writing more @" + i + " of " + numItems);
                         while (i < numItems && reply.dataSize() < MAX_IPC_SIZE) {
                             reply.writeInt(1);
 
@@ -190,19 +188,21 @@
                             verifySameType(listElementClass, parcelable.getClass());
                             writeElement(parcelable, reply, callFlags);
 
-                            if (DEBUG) Log.d(TAG, "Wrote extra #" + i + ": " + mList.get(i));
+                            if (DEBUG) {
+                                sLogger.d(TAG + ": Wrote extra #" + i + ": " + mList.get(i));
+                            }
                             i++;
                         }
                         if (i < numItems) {
-                            if (DEBUG) Log.d(TAG, "Breaking @" + i + " of " + numItems);
+                            if (DEBUG) sLogger.d(TAG + ": Breaking @" + i + " of " + numItems);
                             reply.writeInt(0);
                         }
                         return true;
                     }
                 };
                 if (DEBUG) {
-                    Log.d(TAG,
-                            "Breaking @" + i + " of " + numItems + ": retriever=" + retriever);
+                    sLogger.d(TAG
+                            + ": Breaking @" + i + " of " + numItems + ": retriever=" + retriever);
                 }
                 dest.writeStrongBinder(retriever);
             }
diff --git a/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreService.java b/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreService.java
index 2f7c56e..11d4419 100644
--- a/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreService.java
+++ b/src/com/android/ondevicepersonalization/services/federatedcompute/OdpExampleStoreService.java
@@ -60,6 +60,7 @@
 import com.google.common.util.concurrent.ListeningScheduledExecutorService;
 
 import java.util.Objects;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 
@@ -225,6 +226,7 @@
                                     TimeUnit.SECONDS,
                                     mInjector.getScheduledExecutor());
 
+            CountDownLatch latch = new CountDownLatch(1);
             Futures.addCallback(
                     resultFuture,
                     new FutureCallback<Bundle>() {
@@ -260,6 +262,7 @@
                                                             trainingExampleRecordList.getList()));
                                 }
                             } finally {
+                                latch.countDown();
                                 StatsUtils.writeServiceRequestMetrics(
                                         Constants.API_NAME_SERVICE_ON_TRAINING_EXAMPLE,
                                         packageName,
@@ -272,6 +275,7 @@
 
                         @Override
                         public void onFailure(Throwable t) {
+                            latch.countDown();
                             int status = Constants.STATUS_INTERNAL_ERROR;
                             if (t instanceof TimeoutException) {
                                 status = Constants.STATUS_TIMEOUT;
@@ -301,10 +305,17 @@
             var unused =
                     Futures.whenAllComplete(loadFuture, resultFuture)
                             .callAsync(
-                                    () ->
-                                            mInjector
-                                                    .getProcessRunner()
-                                                    .unloadIsolatedService(loadFuture.get()),
+                                    () -> {
+                                        try {
+                                            latch.await();
+                                        } catch (InterruptedException e) {
+                                            sLogger.e(e, "%s : Interrupted while "
+                                                    + "waiting for transaction complete", TAG);
+                                        }
+                                        return mInjector
+                                                .getProcessRunner()
+                                                .unloadIsolatedService(loadFuture.get());
+                                    },
                                     OnDevicePersonalizationExecutors.getBackgroundExecutor());
         } catch (Throwable e) {
             sLogger.e(e, "%s : Start query failed.", TAG);