Snap for 8512216 from 775e966e07fb11a55afff5ab93b79128c29a84ac to tm-frc-art-release

Change-Id: I4b4d86e8d13757ad6739622d74483044a74a6540
diff --git a/OWNERS b/OWNERS
index 81cfdb8..46bd5b1 100644
--- a/OWNERS
+++ b/OWNERS
@@ -2,6 +2,6 @@
 # Please update this list if you find better candidates.
 tonymak@google.com
 toki@google.com
-zilka@google.com
-mns@google.com
-jalt@google.com
+licha@google.com
+joannechung@google.com
+lpeter@google.com
\ No newline at end of file
diff --git a/TEST_MAPPING b/TEST_MAPPING
index 72e022b..370acd6 100644
--- a/TEST_MAPPING
+++ b/TEST_MAPPING
@@ -21,6 +21,25 @@
       "name": "TCSModelDownloaderIntegrationTest"
     }
   ],
+  "hwasan-postsubmit": [
+    {
+      "name": "TextClassifierServiceTest",
+      "options": [
+        {
+          "exclude-annotation": "androidx.test.filters.FlakyTest"
+        }
+      ]
+    },
+    {
+      "name": "libtextclassifier_tests"
+    },
+    {
+      "name": "libtextclassifier_java_tests"
+    },
+    {
+      "name": "TextClassifierNotificationTests"
+    }
+  ],
   "mainline-presubmit": [
     {
       "name": "TextClassifierNotificationTests[com.google.android.extservices.apex]"
diff --git a/java/src/com/android/textclassifier/ExtrasUtils.java b/java/src/com/android/textclassifier/ExtrasUtils.java
index fd64581..bde3898 100644
--- a/java/src/com/android/textclassifier/ExtrasUtils.java
+++ b/java/src/com/android/textclassifier/ExtrasUtils.java
@@ -87,7 +87,11 @@
     return classification.getExtras().getBundle(FOREIGN_LANGUAGE);
   }
 
-  /** @see #getTopLanguage(Intent) */
+  /**
+   * Puts at most three language scores from {@code languageScores} into {@code container}.
+   *
+   * @see #getTopLanguage(Intent)
+   */
   static void putTopLanguageScores(Bundle container, EntityConfidence languageScores) {
     final int maxSize = Math.min(3, languageScores.getEntities().size());
     final String[] languages =
diff --git a/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java b/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java
index 1ae79ce..9bdfb5e 100644
--- a/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java
+++ b/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java
@@ -195,7 +195,7 @@
   @Override
   public void onDownloadCompleted(
       ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload) {
-    TcLog.v(TAG, "Start to clean up models and update model lookup cache...");
+    TcLog.d(TAG, "Start to clean up models and update model lookup cache...");
     // Step 1: Clean up ManifestEnrollment table
     List<ManifestEnrollment> allManifestEnrollments = db.dao().queryAllManifestEnrollments();
     List<ManifestEnrollment> manifestEnrollmentsToDelete = new ArrayList<>();
@@ -286,7 +286,7 @@
   // Clear the cache table and rebuild the cache based on ModelView table
   private void updateCache() {
     synchronized (cacheLock) {
-      TcLog.v(TAG, "Updating model lookup cache...");
+      TcLog.d(TAG, "Updating model lookup cache...");
       for (String modelType : ModelType.values()) {
         modelLookupCache.get(modelType).clear();
       }
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java b/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java
index b125f13..af33e21 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java
@@ -44,6 +44,7 @@
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Enums;
 import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
 import com.google.common.hash.Hashing;
 import com.google.common.util.concurrent.FutureCallback;
 import com.google.common.util.concurrent.Futures;
@@ -54,6 +55,7 @@
 import java.util.List;
 import java.util.Locale;
 import java.util.UUID;
+import java.util.concurrent.Callable;
 import javax.annotation.Nullable;
 
 /** Manager to listen to config update and download latest models. */
@@ -64,6 +66,7 @@
 
   private final Context appContext;
   private final Class<? extends ListenableWorker> modelDownloadWorkerClass;
+  private final Callable<WorkManager> workManagerSupplier;
   private final DownloadedModelManager downloadedModelManager;
   private final TextClassifierSettings settings;
   private final ListeningExecutorService executorService;
@@ -84,6 +87,7 @@
     this(
         appContext,
         ModelDownloadWorker.class,
+        () -> WorkManager.getInstance(appContext),
         DownloadedModelManagerImpl.getInstance(appContext),
         settings,
         executorService);
@@ -93,11 +97,13 @@
   public ModelDownloadManager(
       Context appContext,
       Class<? extends ListenableWorker> modelDownloadWorkerClass,
+      Callable<WorkManager> workManagerSupplier,
       DownloadedModelManager downloadedModelManager,
       TextClassifierSettings settings,
       ListeningExecutorService executorService) {
     this.appContext = Preconditions.checkNotNull(appContext);
     this.modelDownloadWorkerClass = Preconditions.checkNotNull(modelDownloadWorkerClass);
+    this.workManagerSupplier = Preconditions.checkNotNull(workManagerSupplier);
     this.downloadedModelManager = Preconditions.checkNotNull(downloadedModelManager);
     this.settings = Preconditions.checkNotNull(settings);
     this.executorService = Preconditions.checkNotNull(executorService);
@@ -121,22 +127,31 @@
-  /** Returns the downlaoded models for the given modelType. */
+  /** Returns the downloaded models for the given modelType. */
   @Nullable
   public List<File> listDownloadedModels(@ModelTypeDef String modelType) {
-    return downloadedModelManager.listModels(modelType);
+    try {
+      return downloadedModelManager.listModels(modelType);
+    } catch (Throwable t) {
+      TcLog.e(TAG, "Failed to list downloaded models", t);
+      return ImmutableList.of();
+    }
   }
 
-  /** Notifies the model downlaoder that the text classifier service is created. */
+  /** Notifies the model downloader that the text classifier service is created. */
   public void onTextClassifierServiceCreated() {
-    DeviceConfig.addOnPropertiesChangedListener(
-        DeviceConfig.NAMESPACE_TEXTCLASSIFIER, executorService, deviceConfigListener);
-    appContext.registerReceiver(
-        localeChangedReceiver, new IntentFilter(Intent.ACTION_LOCALE_CHANGED));
-    TcLog.d(TAG, "DeviceConfig listener and locale change listener are registered.");
-    if (!settings.isModelDownloadManagerEnabled()) {
-      return;
+    try {
+      DeviceConfig.addOnPropertiesChangedListener(
+          DeviceConfig.NAMESPACE_TEXTCLASSIFIER, executorService, deviceConfigListener);
+      appContext.registerReceiver(
+          localeChangedReceiver, new IntentFilter(Intent.ACTION_LOCALE_CHANGED));
+      TcLog.d(TAG, "DeviceConfig listener and locale change listener are registered.");
+      if (!settings.isModelDownloadManagerEnabled()) {
+        return;
+      }
+      maybeOverrideLocaleListForTesting();
+      TcLog.d(TAG, "Try to schedule model download work because TextClassifierService started.");
+      scheduleDownloadWork(REASON_TO_SCHEDULE_TCS_STARTED);
+    } catch (Throwable t) {
+      TcLog.e(TAG, "Failed inside onTextClassifierServiceCreated", t);
     }
-    maybeOverrideLocaleListForTesting();
-    TcLog.v(TAG, "Try to schedule model download work because TextClassifierService started.");
-    scheduleDownloadWork(REASON_TO_SCHEDULE_TCS_STARTED);
   }
 
   // TODO(licha): Make this private. Let the constructor accept a receiver to enable testing.
@@ -146,8 +161,12 @@
     if (!settings.isModelDownloadManagerEnabled()) {
       return;
     }
-    TcLog.v(TAG, "Try to schedule model download work because of system locale changes.");
-    scheduleDownloadWork(REASON_TO_SCHEDULE_LOCALE_SETTINGS_CHANGED);
+    TcLog.d(TAG, "Try to schedule model download work because of system locale changes.");
+    try {
+      scheduleDownloadWork(REASON_TO_SCHEDULE_LOCALE_SETTINGS_CHANGED);
+    } catch (Throwable t) {
+      TcLog.e(TAG, "Failed inside onLocaleChanged", t);
+    }
   }
 
   // TODO(licha): Make this private. Let the constructor accept a receiver to enable testing.
@@ -157,16 +176,30 @@
     if (!settings.isModelDownloadManagerEnabled()) {
       return;
     }
-    maybeOverrideLocaleListForTesting();
-    TcLog.v(TAG, "Try to schedule model download work because of device config changes.");
-    scheduleDownloadWork(REASON_TO_SCHEDULE_DEVICE_CONFIG_UPDATED);
+    TcLog.d(TAG, "Try to schedule model download work because of device config changes.");
+    try {
+      maybeOverrideLocaleListForTesting();
+      scheduleDownloadWork(REASON_TO_SCHEDULE_DEVICE_CONFIG_UPDATED);
+    } catch (Throwable t) {
+      TcLog.e(TAG, "Failed inside onTextClassifierDeviceConfigChanged", t);
+    }
   }
 
   /** Clean up internal states on destroying. */
   public void destroy() {
-    DeviceConfig.removeOnPropertiesChangedListener(deviceConfigListener);
-    appContext.unregisterReceiver(localeChangedReceiver);
-    TcLog.d(TAG, "DeviceConfig and Locale listener unregistered by ModelDownloadeManager");
+    // Unregister each listener separately so one failure cannot skip the other.
+    try {
+      DeviceConfig.removeOnPropertiesChangedListener(deviceConfigListener);
+      TcLog.d(TAG, "DeviceConfig listener unregistered by ModelDownloadManager");
+    } catch (Throwable t) {
+      TcLog.e(TAG, "Failed to unregister DeviceConfig listener", t);
+    }
+    try {
+      appContext.unregisterReceiver(localeChangedReceiver);
+      TcLog.d(TAG, "Locale listener unregistered by ModelDownloadManager");
+    } catch (Throwable t) {
+      TcLog.e(TAG, "Failed to unregister locale listener", t);
+    }
   }
 
   /**
@@ -178,10 +211,18 @@
     if (!settings.isModelDownloadManagerEnabled()) {
       return;
     }
-    printWriter.println("ModelDownloadManager:");
-    printWriter.increaseIndent();
-    downloadedModelManager.dump(printWriter);
-    printWriter.decreaseIndent();
+    try {
+      printWriter.println("ModelDownloadManager:");
+      printWriter.increaseIndent();
+      try {
+        downloadedModelManager.dump(printWriter);
+      } finally {
+        // Always undo increaseIndent() so a failed dump does not leave later output indented.
+        printWriter.decreaseIndent();
+      }
+    } catch (Throwable t) {
+      TcLog.e(TAG, "Failed to dump ModelDownloadManager", t);
+    }
   }
 
   /**
@@ -193,54 +234,62 @@
   private void scheduleDownloadWork(int reasonToSchedule) {
     long workId =
         Hashing.farmHashFingerprint64().hashUnencodedChars(UUID.randomUUID().toString()).asLong();
-    NetworkType networkType =
-        Enums.getIfPresent(NetworkType.class, settings.getManifestDownloadRequiredNetworkType())
-            .or(NetworkType.UNMETERED);
-    OneTimeWorkRequest downloadRequest =
-        new OneTimeWorkRequest.Builder(modelDownloadWorkerClass)
-            .setConstraints(
-                new Constraints.Builder()
-                    .setRequiredNetworkType(networkType)
-                    .setRequiresBatteryNotLow(true)
-                    .setRequiresStorageNotLow(true)
-                    .setRequiresDeviceIdle(settings.getManifestDownloadRequiresDeviceIdle())
-                    .setRequiresCharging(settings.getManifestDownloadRequiresCharging())
-                    .build())
-            .setBackoffCriteria(
-                BackoffPolicy.EXPONENTIAL,
-                settings.getModelDownloadBackoffDelayInMillis(),
-                MILLISECONDS)
-            .setInputData(
-                new Data.Builder()
-                    .putLong(ModelDownloadWorker.INPUT_DATA_KEY_WORK_ID, workId)
-                    .putLong(
-                        ModelDownloadWorker.INPUT_DATA_KEY_SCHEDULED_TIMESTAMP,
-                        Instant.now().toEpochMilli())
-                    .build())
-            .build();
-    ListenableFuture<Operation.State.SUCCESS> enqueueResultFuture =
-        WorkManager.getInstance(appContext)
-            .enqueueUniqueWork(
-                UNIQUE_QUEUE_NAME, ExistingWorkPolicy.APPEND_OR_REPLACE, downloadRequest)
-            .getResult();
-    Futures.addCallback(
-        enqueueResultFuture,
-        new FutureCallback<Operation.State.SUCCESS>() {
-          @Override
-          public void onSuccess(Operation.State.SUCCESS unused) {
-            TcLog.v(TAG, "Download work scheduled.");
-            TextClassifierDownloadLogger.downloadWorkScheduled(
-                workId, reasonToSchedule, /* failedToSchedule= */ false);
-          }
+    try {
+      NetworkType networkType =
+          Enums.getIfPresent(NetworkType.class, settings.getManifestDownloadRequiredNetworkType())
+              .or(NetworkType.UNMETERED);
+      OneTimeWorkRequest downloadRequest =
+          new OneTimeWorkRequest.Builder(modelDownloadWorkerClass)
+              .setConstraints(
+                  new Constraints.Builder()
+                      .setRequiredNetworkType(networkType)
+                      .setRequiresBatteryNotLow(true)
+                      .setRequiresStorageNotLow(true)
+                      .setRequiresDeviceIdle(settings.getManifestDownloadRequiresDeviceIdle())
+                      .setRequiresCharging(settings.getManifestDownloadRequiresCharging())
+                      .build())
+              .setBackoffCriteria(
+                  BackoffPolicy.EXPONENTIAL,
+                  settings.getModelDownloadBackoffDelayInMillis(),
+                  MILLISECONDS)
+              .setInputData(
+                  new Data.Builder()
+                      .putLong(ModelDownloadWorker.INPUT_DATA_KEY_WORK_ID, workId)
+                      .putLong(
+                          ModelDownloadWorker.INPUT_DATA_KEY_SCHEDULED_TIMESTAMP,
+                          Instant.now().toEpochMilli())
+                      .build())
+              .build();
+      ListenableFuture<Operation.State.SUCCESS> enqueueResultFuture =
+          workManagerSupplier
+              .call()
+              .enqueueUniqueWork(
+                  UNIQUE_QUEUE_NAME, ExistingWorkPolicy.APPEND_OR_REPLACE, downloadRequest)
+              .getResult();
+      Futures.addCallback(
+          enqueueResultFuture,
+          new FutureCallback<Operation.State.SUCCESS>() {
+            @Override
+            public void onSuccess(Operation.State.SUCCESS unused) {
+              TcLog.d(TAG, "Download work scheduled.");
+              TextClassifierDownloadLogger.downloadWorkScheduled(
+                  workId, reasonToSchedule, /* failedToSchedule= */ false);
+            }
 
-          @Override
-          public void onFailure(Throwable t) {
-            TcLog.e(TAG, "Failed to schedule download work: ", t);
-            TextClassifierDownloadLogger.downloadWorkScheduled(
-                workId, reasonToSchedule, /* failedToSchedule= */ true);
-          }
-        },
-        executorService);
+            @Override
+            public void onFailure(Throwable t) {
+              TcLog.e(TAG, "Failed to schedule download work: ", t);
+              TextClassifierDownloadLogger.downloadWorkScheduled(
+                  workId, reasonToSchedule, /* failedToSchedule= */ true);
+            }
+          },
+          executorService);
+    } catch (Throwable t) {
+      // TODO(licha): This is a temporary fix. Refactor the try-catch in the future.
+      TcLog.e(TAG, "Failed to schedule download work: ", t);
+      TextClassifierDownloadLogger.downloadWorkScheduled(
+          workId, reasonToSchedule, /* failedToSchedule= */ true);
+    }
   }
 
   private void maybeOverrideLocaleListForTesting() {
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java b/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java
index 6e04e16..3db0815 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java
@@ -113,6 +113,7 @@
 
   @Override
   public final ListenableFuture<ListenableWorker.Result> startWork() {
+    TcLog.d(TAG, "Start download work...");
     workStartedTimeMillis = getCurrentTimeMillis();
     // Notice: startWork() is invoked on the main thread
     if (!settings.isModelDownloadManagerEnabled()) {
@@ -121,7 +122,6 @@
           TextClassifierDownloadLogger.WORK_RESULT_FAILURE_MODEL_DOWNLOADER_DISABLED);
       return Futures.immediateFuture(ListenableWorker.Result.failure());
     }
-    TcLog.v(TAG, "Start download work...");
     if (getRunAttemptCount() >= settings.getModelDownloadWorkerMaxAttempts()) {
       TcLog.d(TAG, "Max attempt reached. Abort download work.");
       logDownloadWorkCompleted(
@@ -134,7 +134,7 @@
             downloadResult -> {
               Preconditions.checkNotNull(manifestsToDownload);
               downloadedModelManager.onDownloadCompleted(manifestsToDownload);
-              TcLog.v(TAG, "Download work completed: " + downloadResult);
+              TcLog.d(TAG, "Download work completed: " + downloadResult);
               if (downloadResult.failureCount() == 0) {
                 logDownloadWorkCompleted(
                     downloadResult.successCount() > 0
@@ -239,7 +239,7 @@
     return Futures.whenAllComplete(downloadResultFutures)
         .call(
             () -> {
-              TcLog.v(TAG, "All Download Tasks Completed");
+              TcLog.d(TAG, "All Download Tasks Completed");
               int successCount = 0;
               int failureCount = 0;
               for (ListenableFuture<Boolean> downloadResultFuture : downloadResultFutures) {
@@ -333,7 +333,7 @@
       Manifest downloadedManifest = downloadedModelManager.getManifest(manifestUrl);
       if (downloadedManifest != null
           && downloadedManifest.getStatus() == Manifest.STATUS_SUCCEEDED) {
-        TcLog.v(TAG, "Manifest already downloaded: " + manifestUrl);
+        TcLog.d(TAG, "Manifest already downloaded: " + manifestUrl);
         return Futures.immediateVoidFuture();
       }
       if (pendingDownloads.containsKey(manifestUrl)) {
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java
index 2244e9a..0b76f22 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java
@@ -99,7 +99,7 @@
         new FutureCallback<File>() {
           @Override
           public void onSuccess(File pendingModelFile) {
-            TcLog.v(TAG, "Download model successfully: " + pendingModelFile.getAbsolutePath());
+            TcLog.d(TAG, "Download model successfully: " + pendingModelFile.getAbsolutePath());
           }
 
           @Override
@@ -170,11 +170,11 @@
     } catch (IOException e) {
       throw new ModelDownloadException(ModelDownloadException.FAILED_TO_VALIDATE_MODEL, e);
     }
-    TcLog.v(TAG, "Pending model file passed validation.");
+    TcLog.d(TAG, "Pending model file passed validation.");
   }
 
   private ListenableFuture<IModelDownloaderService> connect(DownloaderServiceConnection conn) {
-    TcLog.v(TAG, "Starting a new connection to ModelDownloaderService");
+    TcLog.d(TAG, "Starting a new connection to ModelDownloaderService");
     return CallbackToFutureAdapter.getFuture(
         completer -> {
           conn.attachCompleter(completer);
@@ -197,7 +197,7 @@
-  // restult future will hang there until time out. (WorkManager forces a 10-min running time.)
+  // result future will hang there until time out. (WorkManager forces a 10-min running time.)
   private static ListenableFuture<Long> scheduleDownload(
       IModelDownloaderService service, URI uri, File targetFile) {
-    TcLog.v(TAG, "Scheduling a new download task with ModelDownloaderService");
+    TcLog.d(TAG, "Scheduling a new download task with ModelDownloaderService");
     return CallbackToFutureAdapter.getFuture(
         completer -> {
           service.download(
@@ -236,7 +236,7 @@
 
     @Override
     public void onServiceConnected(ComponentName componentName, IBinder iBinder) {
-      TcLog.v(TAG, "DownloaderService connected");
+      TcLog.d(TAG, "DownloaderService connected");
       completer.set(Preconditions.checkNotNull(IModelDownloaderService.Stub.asInterface(iBinder)));
     }
 
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java
index e4ebbfa..6d7e47e 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java
@@ -39,7 +39,7 @@
 
   @Override
   public IBinder onBind(Intent intent) {
-    TcLog.v(TAG, "Binding to ModelDownloadService");
+    TcLog.d(TAG, "Binding to ModelDownloadService");
     return iBinder;
   }
 }
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java
index 439588b..47e6f19 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java
@@ -91,7 +91,7 @@
 
   @Override
   public void download(String uri, String targetFilePath, IModelDownloaderCallback callback) {
-    TcLog.v(TAG, "Download request received: " + uri);
+    TcLog.d(TAG, "Download request received: " + uri);
     try {
       File targetFile = new File(targetFilePath);
       File tempMetadataFile = getMetadataFile(targetFile);
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
index 71f9a4f..ddab8bd 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
@@ -17,7 +17,10 @@
 package com.android.textclassifier;
 
 import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 import android.content.Context;
 import android.os.CancellationSignal;
@@ -39,6 +42,7 @@
 import com.android.os.AtomsProto.TextClassifierApiUsageReported;
 import com.android.os.AtomsProto.TextClassifierApiUsageReported.ApiType;
 import com.android.os.AtomsProto.TextClassifierApiUsageReported.ResultType;
+import com.android.textclassifier.common.ModelType;
 import com.android.textclassifier.common.TextClassifierSettings;
 import com.android.textclassifier.common.statsd.StatsdTestUtils;
 import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger;
@@ -47,6 +51,7 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.util.concurrent.ListeningExecutorService;
 import com.google.common.util.concurrent.MoreExecutors;
+import java.io.IOException;
 import java.util.List;
 import java.util.concurrent.Executor;
 import java.util.stream.Collectors;
@@ -81,13 +86,21 @@
   @Mock private TextClassifierService.Callback<TextLinks> textLinksCallback;
   @Mock private TextClassifierService.Callback<ConversationActions> conversationActionsCallback;
   @Mock private TextClassifierService.Callback<TextLanguage> textLanguageCallback;
+  @Mock private ModelFileManager testModelFileManager;
 
   @Before
-  public void setup() {
-
-    testInjector = new TestInjector(ApplicationProvider.getApplicationContext());
+  public void setup() throws IOException {
+    // Stub before onCreate() so the service is created against the stubbed model files.
+    when(testModelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+        .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped());
+    when(testModelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any()))
+        .thenReturn(TestDataUtils.getLangIdModelFileWrapped());
+    when(testModelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any()))
+        .thenReturn(TestDataUtils.getTestActionsModelFileWrapped());
+    testInjector =
+        new TestInjector(ApplicationProvider.getApplicationContext(), testModelFileManager);
     defaultTextClassifierService = new DefaultTextClassifierService(testInjector);
     defaultTextClassifierService.onCreate();
   }
 
   @Before
@@ -211,11 +224,8 @@
 
   @Test
   public void missingModelFile_onFailureShouldBeCalled() throws Exception {
-    testInjector.setModelFileManager(
-        new ModelFileManagerImpl(
-            ApplicationProvider.getApplicationContext(),
-            ImmutableList.of(),
-            testInjector.createTextClassifierSettings()));
+    when(testModelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+        .thenReturn(null);
     defaultTextClassifierService.onCreate();
 
     TextClassification.Request request = new TextClassification.Request.Builder("hi", 0, 2).build();
@@ -251,12 +261,9 @@
     private final Context context;
-    private ModelFileManager modelFileManager;
+    private final ModelFileManager modelFileManager;
 
-    private TestInjector(Context context) {
+    private TestInjector(Context context, ModelFileManager modelFileManager) {
       this.context = Preconditions.checkNotNull(context);
-    }
-
-    private void setModelFileManager(ModelFileManager modelFileManager) {
-      this.modelFileManager = modelFileManager;
+      this.modelFileManager = Preconditions.checkNotNull(modelFileManager);
     }
 
     @Override
@@ -267,9 +274,6 @@
     @Override
     public ModelFileManager createModelFileManager(
         TextClassifierSettings settings, ModelDownloadManager modelDownloadManager) {
-      if (modelFileManager == null) {
-        return TestDataUtils.createModelFileManagerForTesting(context);
-      }
       return modelFileManager;
     }
 
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java
index 20ae592..0e40515 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java
@@ -25,6 +25,7 @@
 import androidx.test.core.app.ApplicationProvider;
 import androidx.test.ext.junit.runners.AndroidJUnit4;
 import androidx.test.filters.SmallTest;
+import androidx.work.WorkManager;
 import com.android.textclassifier.ModelFileManagerImpl.DownloaderModelsLister;
 import com.android.textclassifier.ModelFileManagerImpl.RegularFileFullMatchLister;
 import com.android.textclassifier.ModelFileManagerImpl.RegularFilePatternMatchLister;
@@ -87,6 +88,7 @@
         new ModelDownloadManager(
             context,
             ModelDownloadWorker.class,
+            () -> WorkManager.getInstance(context),
             downloadedModelManager,
             settings,
             MoreExecutors.newDirectExecutorService());
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
index bac4fa1..a19e3ff 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
@@ -16,12 +16,10 @@
 
 package com.android.textclassifier;
 
-import android.content.Context;
-import com.android.textclassifier.ModelFileManagerImpl.RegularFileFullMatchLister;
+import com.android.textclassifier.common.ModelFile;
 import com.android.textclassifier.common.ModelType;
-import com.android.textclassifier.common.TextClassifierSettings;
-import com.google.common.collect.ImmutableList;
 import java.io.File;
+import java.io.IOException;
 
 /** Utils to access test data files. */
 public final class TestDataUtils {
@@ -30,7 +28,7 @@
   private static final String TEST_LANGID_MODEL_PATH = "testdata/langid.model";
 
   /** Returns the root folder that contains the test data. */
-  public static File getTestDataFolder() {
+  private static File getTestDataFolder() {
     return new File("/data/local/tmp/TextClassifierServiceTest/");
   }
 
@@ -38,24 +36,28 @@
     return new File(getTestDataFolder(), TEST_ANNOTATOR_MODEL_PATH);
   }
 
+  /** Returns the test annotator model wrapped in a {@link ModelFile}. */
+  public static ModelFile getTestAnnotatorModelFileWrapped() throws IOException {
+    return ModelFile.createFromRegularFile(getTestAnnotatorModelFile(), ModelType.ANNOTATOR);
+  }
+
   public static File getTestActionsModelFile() {
     return new File(getTestDataFolder(), TEST_ACTIONS_MODEL_PATH);
   }
 
+  /** Returns the test actions suggestions model wrapped in a {@link ModelFile}. */
+  public static ModelFile getTestActionsModelFileWrapped() throws IOException {
+    return ModelFile.createFromRegularFile(
+        getTestActionsModelFile(), ModelType.ACTIONS_SUGGESTIONS);
+  }
+
   public static File getLangIdModelFile() {
     return new File(getTestDataFolder(), TEST_LANGID_MODEL_PATH);
   }
 
-  public static ModelFileManager createModelFileManagerForTesting(Context context) {
-    return new ModelFileManagerImpl(
-        context,
-        ImmutableList.of(
-            new RegularFileFullMatchLister(
-                ModelType.ANNOTATOR, getTestAnnotatorModelFile(), () -> true),
-            new RegularFileFullMatchLister(
-                ModelType.ACTIONS_SUGGESTIONS, getTestActionsModelFile(), () -> true),
-            new RegularFileFullMatchLister(ModelType.LANG_ID, getLangIdModelFile(), () -> true)),
-        new TextClassifierSettings());
+  /** Returns the test LangID model wrapped in a {@link ModelFile}. */
+  public static ModelFile getLangIdModelFileWrapped() throws IOException {
+    return ModelFile.createFromRegularFile(getLangIdModelFile(), ModelType.LANG_ID);
   }
 
   private TestDataUtils() {}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
index 42177e6..e7bf90c 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
@@ -56,6 +56,10 @@
 
   @Before
   public void setup() {
+    extServicesTextClassifierRule.enableVerboseLogging();
+    // Verbose logging only takes effect after restarting ExtServices, hence the force-stop.
+    extServicesTextClassifierRule.forceStopExtServices();
+
     textClassifier = extServicesTextClassifierRule.getTextClassifier();
   }
 
@@ -81,8 +85,8 @@
 
   @Test
   public void classifyText() {
-    String text = "Contact me at droid@android.com";
-    String classifiedText = "droid@android.com";
+    String text = "Contact me at http://www.android.com";
+    String classifiedText = "http://www.android.com";
     int startIndex = text.indexOf(classifiedText);
     int endIndex = startIndex + classifiedText.length();
     TextClassification.Request request =
@@ -90,7 +94,7 @@
 
     TextClassification classification = textClassifier.classifyText(request);
     assertThat(classification.getEntityCount()).isGreaterThan(0);
-    assertThat(classification.getEntity(0)).isEqualTo(TextClassifier.TYPE_EMAIL);
+    assertThat(classification.getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
     assertThat(classification.getText()).isEqualTo(classifiedText);
     assertThat(classification.getActions()).isNotEmpty();
   }
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
index fb1aea8..c20ec8a 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
@@ -22,6 +22,9 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 import static org.testng.Assert.expectThrows;
 
@@ -74,30 +77,34 @@
   private static final LocaleList LOCALES = LocaleList.forLanguageTags("en-US");
   private static final String NO_TYPE = null;
 
-  @Mock private ModelFileManagerImpl.ModelFileLister mockModelFileLister;
+  @Mock private ModelFileManager modelFileManager;
 
-  private TextClassifierSettings settings;
   private Context context;
   private TestingDeviceConfig deviceConfig;
+  private TextClassifierSettings settings;
+  private LruCache<ModelFile, AnnotatorModel> annotatorModelCache;
   private TextClassifierImpl classifier;
 
-  private final ModelFileManager modelFileManager =
-      TestDataUtils.createModelFileManagerForTesting(ApplicationProvider.getApplicationContext());
-  private LruCache<ModelFile, AnnotatorModel> annotatorModelCache;
-
   @Before
-  public void setup() {
+  public void setup() throws IOException {
     MockitoAnnotations.initMocks(this);
-    deviceConfig = new TestingDeviceConfig();
-    Context context =
+    this.context =
         new FakeContextBuilder()
             .setAllIntentComponent(FakeContextBuilder.DEFAULT_COMPONENT)
             .setAppLabel(FakeContextBuilder.DEFAULT_COMPONENT.getPackageName(), "Test app")
             .build();
-    this.context = context;
-    settings = new TextClassifierSettings(deviceConfig);
-    // TODO(veronikanikina): consider using a testing constructor here.
-    classifier = new TextClassifierImpl(context, settings, modelFileManager);
+    this.deviceConfig = new TestingDeviceConfig();
+    this.settings = new TextClassifierSettings(deviceConfig);
+    this.annotatorModelCache = new LruCache<>(2);
+    this.classifier =
+        new TextClassifierImpl(context, settings, modelFileManager, annotatorModelCache);
+
+    when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+        .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped());
+    when(modelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any()))
+        .thenReturn(TestDataUtils.getLangIdModelFileWrapped());
+    when(modelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any()))
+        .thenReturn(TestDataUtils.getTestActionsModelFileWrapped());
   }
 
   @Test
@@ -110,9 +117,7 @@
     int smartStartIndex = text.indexOf(suggested);
     int smartEndIndex = smartStartIndex + suggested.length();
     TextSelection.Request request =
-        new TextSelection.Request.Builder(text, startIndex, endIndex)
-            .setDefaultLocales(LOCALES)
-            .build();
+        new TextSelection.Request.Builder(text, startIndex, endIndex).build();
 
     TextSelection selection = classifier.suggestSelection(null, null, request);
     assertThat(
@@ -120,6 +125,24 @@
   }
 
   @Test
+  public void testSuggestSelection_localePreferenceIsPassedToModelFileManager() throws IOException {
+    String text = "Contact me at droid@android.com";
+    String selected = "droid";
+    String suggested = "droid@android.com";
+    int startIndex = text.indexOf(selected);
+    int endIndex = startIndex + selected.length();
+    int smartStartIndex = text.indexOf(suggested);
+    int smartEndIndex = smartStartIndex + suggested.length();
+    TextSelection.Request request =
+        new TextSelection.Request.Builder(text, startIndex, endIndex)
+            .setDefaultLocales(LOCALES)
+            .build();
+
+    classifier.suggestSelection(null, null, request);
+    verify(modelFileManager).findBestModelFile(eq(ModelType.ANNOTATOR), eq(LOCALES), any());
+  }
+
+  @Test
   public void testSuggestSelection_url() throws IOException {
     String text = "Visit http://www.android.com for more information";
     String selected = "http";
@@ -129,9 +152,7 @@
     int smartStartIndex = text.indexOf(suggested);
     int smartEndIndex = smartStartIndex + suggested.length();
     TextSelection.Request request =
-        new TextSelection.Request.Builder(text, startIndex, endIndex)
-            .setDefaultLocales(LOCALES)
-            .build();
+        new TextSelection.Request.Builder(text, startIndex, endIndex).build();
 
     TextSelection selection = classifier.suggestSelection(null, null, request);
     assertThat(selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL));
@@ -144,9 +165,7 @@
     int startIndex = text.indexOf(selected);
     int endIndex = startIndex + selected.length();
     TextSelection.Request request =
-        new TextSelection.Request.Builder(text, startIndex, endIndex)
-            .setDefaultLocales(LOCALES)
-            .build();
+        new TextSelection.Request.Builder(text, startIndex, endIndex).build();
 
     TextSelection selection = classifier.suggestSelection(null, null, request);
     assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE));
@@ -160,7 +179,6 @@
     int startIndex = text.indexOf(suggested);
     TextSelection.Request request =
         new TextSelection.Request.Builder(text, startIndex, /*endIndex=*/ startIndex + 1)
-            .setDefaultLocales(LOCALES)
             .setIncludeTextClassification(true)
             .build();
 
@@ -178,7 +196,6 @@
     String text = "Visit http://www.android.com for more information";
     TextSelection.Request request =
         new TextSelection.Request.Builder(text, /*startIndex=*/ 0, /*endIndex=*/ 4)
-            .setDefaultLocales(LOCALES)
             .setIncludeTextClassification(false)
             .build();
 
@@ -194,9 +211,7 @@
     int startIndex = text.indexOf(classifiedText);
     int endIndex = startIndex + classifiedText.length();
     TextClassification.Request request =
-        new TextClassification.Request.Builder(text, startIndex, endIndex)
-            .setDefaultLocales(LOCALES)
-            .build();
+        new TextClassification.Request.Builder(text, startIndex, endIndex).build();
 
     TextClassification classification =
         classifier.classifyText(/* sessionId= */ null, null, request);
@@ -210,9 +225,7 @@
     int startIndex = text.indexOf(classifiedText);
     int endIndex = startIndex + classifiedText.length();
     TextClassification.Request request =
-        new TextClassification.Request.Builder(text, startIndex, endIndex)
-            .setDefaultLocales(LOCALES)
-            .build();
+        new TextClassification.Request.Builder(text, startIndex, endIndex).build();
 
     TextClassification classification = classifier.classifyText(null, null, request);
     assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
@@ -223,9 +236,7 @@
   public void testClassifyText_address() throws IOException {
     String text = "Brandschenkestrasse 110, Zürich, Switzerland";
     TextClassification.Request request =
-        new TextClassification.Request.Builder(text, 0, text.length())
-            .setDefaultLocales(LOCALES)
-            .build();
+        new TextClassification.Request.Builder(text, 0, text.length()).build();
 
     TextClassification classification = classifier.classifyText(null, null, request);
     assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS));
@@ -238,9 +249,7 @@
     int startIndex = text.indexOf(classifiedText);
     int endIndex = startIndex + classifiedText.length();
     TextClassification.Request request =
-        new TextClassification.Request.Builder(text, startIndex, endIndex)
-            .setDefaultLocales(LOCALES)
-            .build();
+        new TextClassification.Request.Builder(text, startIndex, endIndex).build();
 
     TextClassification classification = classifier.classifyText(null, null, request);
     assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
@@ -254,9 +263,7 @@
     int startIndex = text.indexOf(classifiedText);
     int endIndex = startIndex + classifiedText.length();
     TextClassification.Request request =
-        new TextClassification.Request.Builder(text, startIndex, endIndex)
-            .setDefaultLocales(LOCALES)
-            .build();
+        new TextClassification.Request.Builder(text, startIndex, endIndex).build();
 
     TextClassification classification = classifier.classifyText(null, null, request);
     assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE));
@@ -275,9 +282,7 @@
     int startIndex = text.indexOf(classifiedText);
     int endIndex = startIndex + classifiedText.length();
     TextClassification.Request request =
-        new TextClassification.Request.Builder(text, startIndex, endIndex)
-            .setDefaultLocales(LOCALES)
-            .build();
+        new TextClassification.Request.Builder(text, startIndex, endIndex).build();
 
     TextClassification classification = classifier.classifyText(null, null, request);
     assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME));
@@ -289,14 +294,12 @@
     LocaleList.setDefault(LocaleList.forLanguageTags("en"));
     String japaneseText = "これは日本語のテキストです";
     TextClassification.Request request =
-        new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length())
-            .setDefaultLocales(LOCALES)
-            .build();
+        new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length()).build();
 
     TextClassification classification = classifier.classifyText(null, null, request);
     RemoteAction translateAction = classification.getActions().get(0);
     assertEquals(1, classification.getActions().size());
-    assertEquals("Translate", translateAction.getTitle().toString());
+    assertEquals(Intent.ACTION_TRANSLATE, classification.getIntent().getAction());
 
     assertEquals(translateAction, ExtrasUtils.findTranslateAction(classification));
     Intent intent = ExtrasUtils.getActionsIntents(classification).get(0);
@@ -323,18 +326,17 @@
 
   @Test
   public void testGenerateLinks_exclude() throws IOException {
-    String text = "You want apple@banana.com. See you tonight!";
+    String text = "The number is +12122537077. See you tonight!";
     List<String> hints = ImmutableList.of();
     List<String> included = ImmutableList.of();
-    List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL);
+    List<String> excluded = Arrays.asList(TextClassifier.TYPE_PHONE);
     TextLinks.Request request =
         new TextLinks.Request.Builder(text)
             .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
-            .setDefaultLocales(LOCALES)
             .build();
     assertThat(
         classifier.generateLinks(null, null, request),
-        not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL)));
+        not(isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE)));
   }
 
   @Test
@@ -344,7 +346,6 @@
     TextLinks.Request request =
         new TextLinks.Request.Builder(text)
             .setEntityConfig(TextClassifier.EntityConfig.createWithExplicitEntityList(explicit))
-            .setDefaultLocales(LOCALES)
             .build();
     assertThat(
         classifier.generateLinks(null, null, request),
@@ -361,7 +362,6 @@
     TextLinks.Request request =
         new TextLinks.Request.Builder(text)
             .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
-            .setDefaultLocales(LOCALES)
             .build();
     assertThat(
         classifier.generateLinks(null, null, request),
@@ -573,29 +573,16 @@
         new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false);
     ModelFile annotatorModelB =
         new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false);
-    String langIdFilePath = TestDataUtils.getLangIdModelFile().getPath();
-    ModelFile langIdModel = new ModelFile(ModelType.LANG_ID, langIdFilePath, 1, "*", false);
 
-    annotatorModelCache = new LruCache<>(2);
-    ModelFileManager modelFileManagerCached =
-        new ModelFileManagerImpl(context, ImmutableList.of(mockModelFileLister), settings);
-    TextClassifierImpl textClassifierImpl =
-        new TextClassifierImpl(context, settings, modelFileManagerCached, annotatorModelCache);
-
-    LocaleList.setDefault(LocaleList.forLanguageTags("en"));
     String englishText = "You can reach me on +12122537077.";
     String classifiedText = "+12122537077";
     TextClassification.Request request =
-        new TextClassification.Request.Builder(englishText, 0, englishText.length())
-            .setDefaultLocales(LOCALES)
-            .build();
-
-    when(mockModelFileLister.list(ModelType.LANG_ID)).thenReturn(ImmutableList.of(langIdModel));
+        new TextClassification.Request.Builder(englishText, 0, englishText.length()).build();
 
     // Check modelFileA v701
-    when(mockModelFileLister.list(ModelType.ANNOTATOR))
-        .thenReturn(ImmutableList.of(annotatorModelA));
-    TextClassification classificationA = textClassifierImpl.classifyText(null, null, request);
+    when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+        .thenReturn(annotatorModelA);
+    TextClassification classificationA = classifier.classifyText(null, null, request);
 
     assertThat(classificationA.getId()).contains("v701");
     assertThat(classificationA.getText()).contains(classifiedText);
@@ -609,9 +596,9 @@
         });
 
     // Check modelFileB v801
-    when(mockModelFileLister.list(ModelType.ANNOTATOR))
-        .thenReturn(ImmutableList.of(annotatorModelB));
-    TextClassification classificationB = textClassifierImpl.classifyText(null, null, request);
+    when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+        .thenReturn(annotatorModelB);
+    TextClassification classificationB = classifier.classifyText(null, null, request);
 
     assertThat(classificationB.getId()).contains("v801");
     assertThat(classificationB.getText()).contains(classifiedText);
@@ -625,9 +612,9 @@
         });
 
     // Reload modelFileA v701
-    when(mockModelFileLister.list(ModelType.ANNOTATOR))
-        .thenReturn(ImmutableList.of(annotatorModelA));
-    TextClassification classificationAcached = textClassifierImpl.classifyText(null, null, request);
+    when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+        .thenReturn(annotatorModelA);
+    TextClassification classificationAcached = classifier.classifyText(null, null, request);
 
     assertThat(classificationAcached.getId()).contains("v701");
     assertThat(classificationAcached.getText()).contains(classifiedText);
@@ -651,28 +638,16 @@
         new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false);
     ModelFile annotatorModelB =
         new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false);
-    String langIdFilePath = TestDataUtils.getLangIdModelFile().getPath();
-    ModelFile langIdModel = new ModelFile(ModelType.LANG_ID, langIdFilePath, 1, "*", false);
 
-    annotatorModelCache = new LruCache<>(settings.getMultiAnnotatorCacheSize());
-    ModelFileManager modelFileManagerCached =
-        new ModelFileManagerImpl(context, ImmutableList.of(mockModelFileLister), settings);
-    TextClassifierImpl textClassifierImpl =
-        new TextClassifierImpl(context, settings, modelFileManagerCached, annotatorModelCache);
-    LocaleList.setDefault(LocaleList.forLanguageTags("en"));
     String englishText = "You can reach me on +12122537077.";
     String classifiedText = "+12122537077";
     TextClassification.Request request =
-        new TextClassification.Request.Builder(englishText, 0, englishText.length())
-            .setDefaultLocales(LOCALES)
-            .build();
-
-    when(mockModelFileLister.list(ModelType.LANG_ID)).thenReturn(ImmutableList.of(langIdModel));
+        new TextClassification.Request.Builder(englishText, 0, englishText.length()).build();
 
     // Check modelFileA v701
-    when(mockModelFileLister.list(ModelType.ANNOTATOR))
-        .thenReturn(ImmutableList.of(annotatorModelA));
-    TextClassification classification = textClassifierImpl.classifyText(null, null, request);
+    when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+        .thenReturn(annotatorModelA);
+    TextClassification classification = classifier.classifyText(null, null, request);
 
     assertThat(classification.getId()).contains("v701");
     assertThat(classification.getText()).contains(classifiedText);
@@ -686,9 +661,9 @@
         });
 
     // Check modelFileB v801
-    when(mockModelFileLister.list(ModelType.ANNOTATOR))
-        .thenReturn(ImmutableList.of(annotatorModelB));
-    TextClassification classificationB = textClassifierImpl.classifyText(null, null, request);
+    when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+        .thenReturn(annotatorModelB);
+    TextClassification classificationB = classifier.classifyText(null, null, request);
 
     assertThat(classificationB.getId()).contains("v801");
     assertThat(classificationB.getText()).contains(classifiedText);
@@ -702,9 +677,9 @@
         });
 
     // Reload modelFileA v701
-    when(mockModelFileLister.list(ModelType.ANNOTATOR))
-        .thenReturn(ImmutableList.of(annotatorModelA));
-    TextClassification classificationAcached = textClassifierImpl.classifyText(null, null, request);
+    when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+        .thenReturn(annotatorModelA);
+    TextClassification classificationAcached = classifier.classifyText(null, null, request);
 
     assertThat(classificationAcached.getId()).contains("v701");
     assertThat(classificationAcached.getText()).contains(classifiedText);
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
index 394b7ad..9e11c09 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
@@ -67,6 +67,7 @@
   private TestingDeviceConfig deviceConfig;
   private WorkManager workManager;
   private ModelDownloadManager downloadManager;
+  private ModelDownloadManager downloadManagerWithBadWorkManager;
   @Mock DownloadedModelManager downloadedModelManager;
 
   @Before
@@ -80,6 +81,17 @@
         new ModelDownloadManager(
             context,
             ModelDownloadWorker.class,
+            () -> workManager,
+            downloadedModelManager,
+            new TextClassifierSettings(deviceConfig),
+            MoreExecutors.newDirectExecutorService());
+    this.downloadManagerWithBadWorkManager =
+        new ModelDownloadManager(
+            context,
+            ModelDownloadWorker.class,
+            () -> {
+              throw new IllegalStateException("WorkManager may fail!");
+            },
             downloadedModelManager,
             new TextClassifierSettings(deviceConfig),
             MoreExecutors.newDirectExecutorService());
@@ -96,7 +108,20 @@
   }
 
   @Test
+  public void onTextClassifierServiceCreated_workManagerCrashed() throws Exception {
+    assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty();
+    downloadManagerWithBadWorkManager.onTextClassifierServiceCreated();
+
+    // Assertion below is flaky: the DeviceConfig listener may be triggered by the OS mid-test
+    TextClassifierDownloadWorkScheduled atom =
+        Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
+    assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.TCS_STARTED);
+    assertThat(atom.getFailedToSchedule()).isTrue();
+  }
+
+  @Test
   public void onTextClassifierServiceCreated_requestEnqueued() throws Exception {
+    assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty();
     downloadManager.onTextClassifierServiceCreated();
 
     WorkInfo workInfo =
@@ -104,21 +129,34 @@
             DownloaderTestUtils.queryWorkInfos(
                 workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME));
     assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+    // Assertion below is flaky: the DeviceConfig listener may be triggered by the OS mid-test
     verifyWorkScheduledLogging(ReasonToSchedule.TCS_STARTED);
   }
 
   @Test
   public void onTextClassifierServiceCreated_localeListOverridden() throws Exception {
+    assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty();
     deviceConfig.setConfig(TextClassifierSettings.TESTING_LOCALE_LIST_OVERRIDE, "zh,fr");
     downloadManager.onTextClassifierServiceCreated();
 
     assertThat(Locale.getDefault()).isEqualTo(Locale.forLanguageTag("zh"));
     assertThat(LocaleList.getDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr"));
     assertThat(LocaleList.getAdjustedDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr"));
+    // Assertion below is flaky: the DeviceConfig listener may be triggered by the OS mid-test
     verifyWorkScheduledLogging(ReasonToSchedule.TCS_STARTED);
   }
 
   @Test
+  public void onLocaleChanged_workManagerCrashed() throws Exception {
+    downloadManagerWithBadWorkManager.onLocaleChanged();
+
+    TextClassifierDownloadWorkScheduled atom =
+        Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
+    assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.LOCALE_SETTINGS_CHANGED);
+    assertThat(atom.getFailedToSchedule()).isTrue();
+  }
+
+  @Test
   public void onLocaleChanged_requestEnqueued() throws Exception {
     downloadManager.onLocaleChanged();
 
@@ -131,6 +169,16 @@
   }
 
   @Test
+  public void onTextClassifierDeviceConfigChanged_workManagerCrashed() throws Exception {
+    downloadManagerWithBadWorkManager.onTextClassifierDeviceConfigChanged();
+
+    TextClassifierDownloadWorkScheduled atom =
+        Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
+    assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.DEVICE_CONFIG_UPDATED);
+    assertThat(atom.getFailedToSchedule()).isTrue();
+  }
+
+  @Test
   public void onTextClassifierDeviceConfigChanged_requestEnqueued() throws Exception {
     downloadManager.onTextClassifierDeviceConfigChanged();
 
@@ -188,6 +236,13 @@
     assertThat(downloadManager.listDownloadedModels(MODEL_TYPE)).containsExactly(modelFile);
   }
 
+  @Test
+  public void listDownloadedModels_doNotCrashOnError() throws Exception {
+    when(downloadedModelManager.listModels(MODEL_TYPE)).thenThrow(new IllegalStateException());
+
+    assertThat(downloadManager.listDownloadedModels(MODEL_TYPE)).isEmpty();
+  }
+
   private void verifyWorkScheduledLogging(ReasonToSchedule reasonToSchedule) throws Exception {
     TextClassifierDownloadWorkScheduled atom =
         Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java
index e4360c6..e261158 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java
@@ -18,16 +18,10 @@
 
 import static com.google.common.truth.Truth.assertThat;
 
-import android.app.Instrumentation;
-import android.app.UiAutomation;
 import android.util.Log;
 import android.view.textclassifier.TextClassification;
 import android.view.textclassifier.TextClassification.Request;
-import android.view.textclassifier.TextClassifier;
-import androidx.test.platform.app.InstrumentationRegistry;
 import com.android.textclassifier.testing.ExtServicesTextClassifierRule;
-import com.android.textclassifier.testing.TestingLocaleListOverrideRule;
-import java.io.IOException;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Rule;
@@ -48,171 +42,133 @@
   private static final String V804_EN_TAG = "en_v804";
   private static final String V804_RU_TAG = "ru_v804";
   private static final String FACTORY_MODEL_TAG = "*";
-
-  @Rule
-  public final TestingLocaleListOverrideRule testingLocaleListOverrideRule =
-      new TestingLocaleListOverrideRule();
+  private static final int ASSERT_MAX_ATTEMPTS = 20;
+  private static final int ASSERT_SLEEP_BEFORE_RETRY_MS = 1000;
 
   @Rule
   public final ExtServicesTextClassifierRule extServicesTextClassifierRule =
       new ExtServicesTextClassifierRule();
 
-  private TextClassifier textClassifier;
-
   @Before
   public void setup() throws Exception {
-    // Flag overrides below can be overridden by Phenotype sync, which makes this test flaky
-    runShellCommand("device_config put textclassifier config_updater_model_enabled false");
-    runShellCommand("device_config put textclassifier model_download_manager_enabled true");
-    runShellCommand("device_config put textclassifier model_download_backoff_delay_in_millis 5");
+    extServicesTextClassifierRule.addDeviceConfigOverride("config_updater_model_enabled", "false");
+    extServicesTextClassifierRule.addDeviceConfigOverride("model_download_manager_enabled", "true");
+    extServicesTextClassifierRule.addDeviceConfigOverride(
+        "model_download_backoff_delay_in_millis", "5");
+    extServicesTextClassifierRule.addDeviceConfigOverride("testing_locale_list_override", "en-US");
+    extServicesTextClassifierRule.overrideDeviceConfig();
 
-    textClassifier = extServicesTextClassifierRule.getTextClassifier();
-    startExtservicesProcess();
+    extServicesTextClassifierRule.enableVerboseLogging();
+    // Verbose logging only takes effect after restarting ExtServices
+    extServicesTextClassifierRule.forceStopExtServices();
   }
 
   @After
   public void tearDown() throws Exception {
-    runShellCommand("device_config delete textclassifier manifest_url_annotator_en");
-    runShellCommand("device_config delete textclassifier manifest_url_annotator_ru");
-    runShellCommand("device_config put textclassifier config_updater_model_enabled true");
-    runShellCommand("device_config delete textclassifier multi_language_support_enabled");
-    runShellCommand(
-        "device_config put textclassifier model_download_backoff_delay_in_millis 3600000");
+    // Force-stop ExtServices so its verbose-logging and locale-override state is reset.
+    extServicesTextClassifierRule.forceStopExtServices();
   }
 
   @Test
-  public void smokeTest() throws IOException, InterruptedException {
-    runShellCommand(
-        "device_config put textclassifier manifest_url_annotator_en "
-            + V804_EN_ANNOTATOR_MANIFEST_URL);
+  public void smokeTest() throws Exception {
+    extServicesTextClassifierRule.addDeviceConfigOverride(
+        "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL);
 
-    assertWithRetries(
-        /* maxAttempts= */ 10, /* sleepMs= */ 1000, () -> verifyActiveEnglishModel(V804_EN_TAG));
+    assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG));
   }
 
   @Test
-  public void downgradeModel() throws IOException, InterruptedException {
+  public void downgradeModel() throws Exception {
     // Download an experimental model.
-    {
-      runShellCommand(
-          "device_config put textclassifier manifest_url_annotator_en "
-              + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
+    extServicesTextClassifierRule.addDeviceConfigOverride(
+        "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
 
-      assertWithRetries(
-          /* maxAttempts= */ 10,
-          /* sleepMs= */ 1000,
-          () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
-    }
+    assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
 
     // Downgrade to an older model.
-    {
-      runShellCommand(
-          "device_config put textclassifier manifest_url_annotator_en "
-              + V804_EN_ANNOTATOR_MANIFEST_URL);
+    extServicesTextClassifierRule.addDeviceConfigOverride(
+        "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL);
 
-      assertWithRetries(
-          /* maxAttempts= */ 10, /* sleepMs= */ 1000, () -> verifyActiveEnglishModel(V804_EN_TAG));
-    }
+    assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG));
   }
 
   @Test
-  public void upgradeModel() throws IOException, InterruptedException {
+  public void upgradeModel() throws Exception {
     // Download a model.
-    {
-      runShellCommand(
-          "device_config put textclassifier manifest_url_annotator_en "
-              + V804_EN_ANNOTATOR_MANIFEST_URL);
+    extServicesTextClassifierRule.addDeviceConfigOverride(
+        "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL);
 
-      assertWithRetries(
-          /* maxAttempts= */ 10, /* sleepMs= */ 1000, () -> verifyActiveEnglishModel(V804_EN_TAG));
-    }
+    assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG));
 
     // Upgrade to an experimental model.
-    {
-      runShellCommand(
-          "device_config put textclassifier manifest_url_annotator_en "
-              + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
+    extServicesTextClassifierRule.addDeviceConfigOverride(
+        "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
 
-      assertWithRetries(
-          /* maxAttempts= */ 10,
-          /* sleepMs= */ 1000,
-          () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
-    }
+    assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
   }
 
   @Test
-  public void clearFlag() throws IOException, InterruptedException {
+  public void clearFlag() throws Exception {
     // Download a new model.
-    {
-      runShellCommand(
-          "device_config put textclassifier manifest_url_annotator_en "
-              + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
+    extServicesTextClassifierRule.addDeviceConfigOverride(
+        "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
 
-      assertWithRetries(
-          /* maxAttempts= */ 10,
-          /* sleepMs= */ 1000,
-          () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
-    }
+    assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
 
     // Revert the flag.
-    {
-      runShellCommand("device_config delete textclassifier manifest_url_annotator_en");
-      // Fallback to use the universal model.
-      assertWithRetries(
-          /* maxAttempts= */ 10,
-          /* sleepMs= */ 1000,
-          () -> verifyActiveModel(/* text= */ "abc", /* expectedVersion= */ FACTORY_MODEL_TAG));
-    }
+    extServicesTextClassifierRule.addDeviceConfigOverride("manifest_url_annotator_en", "");
+    // Fallback to use the universal model.
+    assertWithRetries(
+        () -> verifyActiveModel(/* text= */ "abc", /* expectedVersion= */ FACTORY_MODEL_TAG));
   }
 
   @Test
-  public void modelsForMultipleLanguagesDownloaded() throws IOException, InterruptedException {
-    runShellCommand("device_config put textclassifier multi_language_support_enabled true");
-    testingLocaleListOverrideRule.set("en-US", "ru-RU");
+  public void modelsForMultipleLanguagesDownloaded() throws Exception {
+    extServicesTextClassifierRule.addDeviceConfigOverride("multi_language_support_enabled", "true");
+    extServicesTextClassifierRule.addDeviceConfigOverride(
+        "testing_locale_list_override", "en-US,ru-RU");
 
     // download en model
-    runShellCommand(
-        "device_config put textclassifier manifest_url_annotator_en "
-            + EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
+    extServicesTextClassifierRule.addDeviceConfigOverride(
+        "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
 
     // download ru model
-    runShellCommand(
-        "device_config put textclassifier manifest_url_annotator_ru "
-            + V804_RU_ANNOTATOR_MANIFEST_URL);
-    assertWithRetries(
-        /* maxAttempts= */ 10,
-        /* sleepMs= */ 1000,
-        () -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
+    extServicesTextClassifierRule.addDeviceConfigOverride(
+        "manifest_url_annotator_ru", V804_RU_ANNOTATOR_MANIFEST_URL);
+    assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
 
-    assertWithRetries(/* maxAttempts= */ 10, /* sleepMs= */ 1000, this::verifyActiveRussianModel);
+    assertWithRetries(this::verifyActiveRussianModel);
 
     assertWithRetries(
-        /* maxAttempts= */ 10,
-        /* sleepMs= */ 1000,
         () -> verifyActiveModel(/* text= */ "français", /* expectedVersion= */ FACTORY_MODEL_TAG));
   }
 
-  private void assertWithRetries(int maxAttempts, int sleepMs, Runnable assertRunnable)
-      throws InterruptedException {
-    for (int i = 0; i < maxAttempts; i++) {
+  private void assertWithRetries(Runnable assertRunnable) throws Exception {
+    for (int i = 0; i < ASSERT_MAX_ATTEMPTS; i++) {
       try {
+        extServicesTextClassifierRule.overrideDeviceConfig();
         assertRunnable.run();
         break; // success. Bail out.
       } catch (AssertionError ex) {
-        if (i == maxAttempts - 1) { // last attempt, give up.
+        if (i == ASSERT_MAX_ATTEMPTS - 1) { // last attempt, give up.
+          extServicesTextClassifierRule.dumpDefaultTextClassifierService();
           throw ex;
         } else {
-          Thread.sleep(sleepMs);
+          Thread.sleep(ASSERT_SLEEP_BEFORE_RETRY_MS);
         }
+      } catch (Exception unknownException) {
+        throw unknownException;
       }
     }
   }
 
   private void verifyActiveModel(String text, String expectedVersion) {
     TextClassification textClassification =
-        textClassifier.classifyText(new Request.Builder(text, 0, text.length()).build());
-    Log.d(TAG, "verifyActiveModel. TextClassification ID: " + textClassification.getId());
+        extServicesTextClassifierRule
+            .getTextClassifier()
+            .classifyText(new Request.Builder(text, 0, text.length()).build());
     // The result id contains the name of the just used model.
+    Log.d(TAG, "verifyActiveModel. TextClassification ID: " + textClassification.getId());
     assertThat(textClassification.getId()).contains(expectedVersion);
   }
 
@@ -223,16 +179,4 @@
   private void verifyActiveRussianModel() {
     verifyActiveModel("привет", V804_RU_TAG);
   }
-
-  private void startExtservicesProcess() {
-    // Start the process of ExtServices by sending it a text classifier request.
-    textClassifier.classifyText(new TextClassification.Request.Builder("abc", 0, 3).build());
-  }
-
-  private static void runShellCommand(String cmd) {
-    Log.v(TAG, "run shell command: " + cmd);
-    Instrumentation instrumentation = InstrumentationRegistry.getInstrumentation();
-    UiAutomation uiAutomation = instrumentation.getUiAutomation();
-    uiAutomation.executeShellCommand(cmd);
-  }
 }
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java
index 3ceb47b..5f8247d 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java
@@ -20,53 +20,125 @@
 import android.content.pm.PackageManager;
 import android.content.pm.PackageManager.NameNotFoundException;
+import android.os.ParcelFileDescriptor;
 import android.provider.DeviceConfig;
+import android.util.Log;
 import android.view.textclassifier.TextClassificationManager;
 import android.view.textclassifier.TextClassifier;
 import androidx.test.core.app.ApplicationProvider;
 import androidx.test.platform.app.InstrumentationRegistry;
+import com.google.common.io.ByteStreams;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
 import org.junit.rules.ExternalResource;
 
 /** A rule that manages a text classifier that is backed by the ExtServices. */
 public final class ExtServicesTextClassifierRule extends ExternalResource {
+  private static final String TAG = "androidtc";
   private static final String CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE =
       "textclassifier_service_package_override";
   private static final String PKG_NAME_GOOGLE_EXTSERVICES = "com.google.android.ext.services";
   private static final String PKG_NAME_AOSP_EXTSERVICES = "android.ext.services";
 
-  private String textClassifierServiceOverrideFlagOldValue;
+  private UiAutomation uiAutomation;
+  private DeviceConfig.Properties originalProperties;
+  private DeviceConfig.Properties.Builder newPropertiesBuilder;
 
   @Override
-  protected void before() {
-    UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
+  protected void before() throws Exception {
+    uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
+    uiAutomation.adoptShellPermissionIdentity();
+    try {
+      originalProperties = DeviceConfig.getProperties(DeviceConfig.NAMESPACE_TEXTCLASSIFIER);
+      newPropertiesBuilder =
+          new DeviceConfig.Properties.Builder(DeviceConfig.NAMESPACE_TEXTCLASSIFIER)
+              .setString(
+                  CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE, getExtServicesPackageName());
+      overrideDeviceConfig();
+    } catch (Exception e) {
+      // ExternalResource does not call after() when before() throws, so release the shell
+      // permission identity here instead of leaking it into later tests.
+      uiAutomation.dropShellPermissionIdentity();
+      throw e;
+    }
+  }
+
+  @Override
+  protected void after() {
     try {
-      uiAutomation.adoptShellPermissionIdentity();
-      textClassifierServiceOverrideFlagOldValue =
-          DeviceConfig.getString(
-              DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
-              CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
-              null);
-      DeviceConfig.setProperty(
-          DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
-          CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
-          getExtServicesPackageName(),
-          /* makeDefault= */ false);
+      DeviceConfig.setProperties(originalProperties);
+    } catch (Throwable t) {
+      Log.e(TAG, "Failed to reset DeviceConfig", t);
     } finally {
       uiAutomation.dropShellPermissionIdentity();
     }
   }
 
-  @Override
-  protected void after() {
-    UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
-    try {
-      uiAutomation.adoptShellPermissionIdentity();
-      DeviceConfig.setProperty(
-          DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
-          CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
-          textClassifierServiceOverrideFlagOldValue,
-          /* makeDefault= */ false);
-    } finally {
-      uiAutomation.dropShellPermissionIdentity();
+  /**
+   * Adds a TextClassifier DeviceConfig flag to apply on top of the ExtServices package override.
+   *
+   * <p>The flag only takes effect on the next call to {@link #overrideDeviceConfig()}. Must be
+   * called after this rule's {@code before()} has run.
+   */
+  public void addDeviceConfigOverride(String name, String value) {
+    newPropertiesBuilder.setString(name, value);
+  }
+
+  /**
+   * Overrides the TextClassifier DeviceConfig manually.
+   *
+   * <p>This will clean up all device configs not in newPropertiesBuilder.
+   *
+   * <p>This needs to be called every time before testing, because DeviceConfig can be synced in
+   * the background at any time. DeviceConfig#setSyncDisabledMode can disable syncing, but it is a
+   * hidden API.
+   */
+  public void overrideDeviceConfig() throws Exception {
+    DeviceConfig.setProperties(newPropertiesBuilder.build());
+  }
+
+  /** Force stop ExtServices. Force-stop-and-start can be helpful to reload some states. */
+  public void forceStopExtServices() {
+    runShellCommand("am force-stop " + PKG_NAME_GOOGLE_EXTSERVICES);
+    runShellCommand("am force-stop " + PKG_NAME_AOSP_EXTSERVICES);
+  }
+
+  /** Returns the default TextClassifier, with any app-level TextClassifier override cleared. */
+  public TextClassifier getTextClassifier() {
+    TextClassificationManager textClassificationManager =
+        ApplicationProvider.getApplicationContext()
+            .getSystemService(TextClassificationManager.class);
+    textClassificationManager.setTextClassifier(null); // Reset TC overrides
+    return textClassificationManager.getTextClassifier();
+  }
+
+  /** Logs the dumpsys state of DefaultTextClassifierService and the textclassifier flags. */
+  public void dumpDefaultTextClassifierService() {
+    runShellCommand(
+        "dumpsys activity service "
+            + getExtServicesPackageName()
+            + "/com.android.textclassifier.DefaultTextClassifierService");
+    runShellCommand("cmd device_config list textclassifier");
+  }
+
+  /** Enables verbose logging for this rule's log tag ({@code androidtc}). */
+  public void enableVerboseLogging() {
+    runShellCommand("setprop log.tag." + TAG + " VERBOSE");
+  }
+
+  /** Runs {@code cmd} through UiAutomation and logs its output, if any. */
+  private void runShellCommand(String cmd) {
+    Log.v(TAG, "run shell command: " + cmd);
+    // AutoCloseInputStream also closes the ParcelFileDescriptor returned by UiAutomation;
+    // wrapping only its getFileDescriptor() would leak the descriptor on every call.
+    try (FileInputStream output =
+        new ParcelFileDescriptor.AutoCloseInputStream(uiAutomation.executeShellCommand(cmd))) {
+      String cmdOutput = new String(ByteStreams.toByteArray(output), StandardCharsets.UTF_8);
+      if (!cmdOutput.isEmpty()) {
+        Log.d(TAG, "cmd output: " + cmdOutput);
+      }
+    } catch (IOException ioe) {
+      Log.w(TAG, "failed to get cmd output", ioe);
     }
   }
 
@@ -79,12 +151,4 @@
       return PKG_NAME_AOSP_EXTSERVICES;
     }
   }
-
-  public TextClassifier getTextClassifier() {
-    TextClassificationManager textClassificationManager =
-        ApplicationProvider.getApplicationContext()
-            .getSystemService(TextClassificationManager.class);
-    textClassificationManager.setTextClassifier(null); // Reset TC overrides
-    return textClassificationManager.getTextClassifier();
-  }
 }
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java
deleted file mode 100644
index 7d46e97..0000000
--- a/java/tests/instrumentation/src/com/android/textclassifier/testing/TestingLocaleListOverrideRule.java
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.testing;
-
-import android.app.UiAutomation;
-import android.os.LocaleList;
-import android.util.Log;
-import androidx.test.platform.app.InstrumentationRegistry;
-import org.junit.rules.ExternalResource;
-
-/** class for overriding testing_locale_list_override from {@link TextClassifierSettings} */
-public final class TestingLocaleListOverrideRule extends ExternalResource {
-  private static final String TAG = "TestingLocaleListOverrideRule";
-
-  private LocaleList originalLocaleList;
-
-  @Override
-  protected void before() {
-    originalLocaleList = LocaleList.getDefault();
-  }
-
-  public void set(String... localeTags) {
-    if (localeTags.length == 0) {
-      return;
-    }
-    runShellCommand(
-        "device_config put textclassifier testing_locale_list_override "
-            + String.join(",", localeTags));
-  }
-
-  @Override
-  protected void after() {
-    runShellCommand(
-        "device_config put textclassifier testing_locale_list_override "
-            + originalLocaleList.toLanguageTags());
-    runShellCommand("device_config delete textclassifier testing_locale_list_override");
-  }
-
-  private static void runShellCommand(String cmd) {
-    Log.v(TAG, "run shell command: " + cmd);
-    UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
-    uiAutomation.executeShellCommand(cmd);
-  }
-}
diff --git a/native/utils/tokenfree/byte_encoder_test.cc b/native/utils/tokenfree/byte_encoder_test.cc
index d4d119e..964e316 100644
--- a/native/utils/tokenfree/byte_encoder_test.cc
+++ b/native/utils/tokenfree/byte_encoder_test.cc
@@ -29,7 +29,7 @@
 
 using testing::ElementsAre;
 
-TEST(EncoderTest, SimpleTokenization) {
+TEST(ByteEncoderTest, SimpleTokenization) {
   const ByteEncoder encoder;
   {
     std::vector<int64_t> encoded_text;
@@ -39,7 +39,7 @@
   }
 }
 
-TEST(EncoderTest, SimpleTokenization2) {
+TEST(ByteEncoderTest, SimpleTokenization2) {
   const ByteEncoder encoder;
   {
     std::vector<int64_t> encoded_text;