Merge tm-dev-plus-aosp-without-vendor@8763363
Bug: 236760014
Merged-In: I14f950e4101027d1a6596fd1c459c4c4440b8379
Change-Id: I6bcc86898be9ef3daf6fd4cfa6a86b162ba0b9da
diff --git a/TEST_MAPPING b/TEST_MAPPING
index 6b3b9d9..370acd6 100644
--- a/TEST_MAPPING
+++ b/TEST_MAPPING
@@ -16,6 +16,9 @@
},
{
"name": "TextClassifierNotificationTests"
+ },
+ {
+ "name": "TCSModelDownloaderIntegrationTest"
}
],
"hwasan-postsubmit": [
@@ -51,4 +54,4 @@
"name": "libtextclassifier_java_tests[com.google.android.extservices.apex]"
}
]
-}
\ No newline at end of file
+}
diff --git a/java/Android.bp b/java/Android.bp
index ca34a66..5948a17 100644
--- a/java/Android.bp
+++ b/java/Android.bp
@@ -52,8 +52,15 @@
// Similar to TextClassifierServiceLib, but without the AndroidManifest.
android_library {
name: "TextClassifierServiceLibNoManifest",
- srcs: ["src/**/*.java"],
+ srcs: [
+ "src/**/*.java",
+ "src/**/*.aidl",
+ ],
manifest: "LibNoManifest_AndroidManifest.xml",
+ plugins: [
+ "auto_value_plugin",
+ "androidx.room_room-compiler-plugin",
+ ],
static_libs: [
"androidx.core_core",
"libtextclassifier-java",
@@ -61,6 +68,13 @@
"guava",
"textclassifier-statsd",
"error_prone_annotations",
+ "androidx.work_work-runtime",
+ "android_downloader_lib",
+ "textclassifier-statsd",
+ "textclassifier-java-proto-lite",
+ "androidx.concurrent_concurrent-futures",
+ "auto_value_annotations",
+ "androidx.room_room-runtime",
],
sdk_version: "system_current",
min_sdk_version: "30",
diff --git a/java/AndroidManifest.xml b/java/AndroidManifest.xml
index 8ef323c..26983c0 100644
--- a/java/AndroidManifest.xml
+++ b/java/AndroidManifest.xml
@@ -32,8 +32,25 @@
<uses-permission android:name="android.permission.QUERY_ALL_PACKAGES" />
<uses-permission android:name="android.permission.ACCESS_COARSE_LOCATION" />
+ <uses-permission android:name="android.permission.RECEIVE_BOOT_COMPLETED" />
+ <uses-permission android:name="android.permission.ACCESS_NETWORK_STATE"/>
+ <!-- The INTERNET permission is restricted to the modelDownloaderServiceProcess -->
+ <uses-permission android:name="android.permission.INTERNET"/>
<application>
+ <processes>
+ <deny-permission android:name="android.permission.INTERNET" />
+ <process />
+ <process android:process=":modelDownloaderServiceProcess">
+ <allow-permission android:name="android.permission.INTERNET" />
+ </process>
+ </processes>
+
+ <service
+ android:exported="false"
+ android:name=".downloader.ModelDownloaderService"
+ android:process=":modelDownloaderServiceProcess">
+ </service>
<service
android:exported="true"
diff --git a/java/assets/textclassifier/annotator.universal.model b/java/assets/textclassifier/annotator.universal.model
index 09f1e0b..c290f76 100755
--- a/java/assets/textclassifier/annotator.universal.model
+++ b/java/assets/textclassifier/annotator.universal.model
Binary files differ
diff --git a/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java b/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
index beb155b..3b09673 100644
--- a/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
+++ b/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
@@ -27,7 +27,7 @@
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.ConversationActions;
import android.view.textclassifier.ConversationActions.Message;
-import com.android.textclassifier.common.ModelFileManager.ModelFile;
+import com.android.textclassifier.common.ModelFile;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.intent.LabeledIntent;
import com.android.textclassifier.common.intent.TemplateIntentFactory;
diff --git a/java/src/com/android/textclassifier/DefaultTextClassifierService.java b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
index 1f1e958..d57af5e 100644
--- a/java/src/com/android/textclassifier/DefaultTextClassifierService.java
+++ b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
@@ -30,11 +30,11 @@
import android.view.textclassifier.TextSelection;
import androidx.annotation.NonNull;
import androidx.collection.LruCache;
-import com.android.textclassifier.common.ModelFileManager;
import com.android.textclassifier.common.TextClassifierServiceExecutors;
import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger;
+import com.android.textclassifier.downloader.ModelDownloadManager;
import com.android.textclassifier.utils.IndentingPrintWriter;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
@@ -47,6 +47,7 @@
import java.io.PrintWriter;
import java.util.Map;
import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import javax.annotation.Nullable;
@@ -58,6 +59,9 @@
// TODO: Figure out do we need more concurrency.
private ListeningExecutorService normPriorityExecutor;
private ListeningExecutorService lowPriorityExecutor;
+
+ @Nullable private ModelDownloadManager modelDownloadManager;
+
private TextClassifierImpl textClassifier;
private TextClassifierSettings settings;
private ModelFileManager modelFileManager;
@@ -77,9 +81,14 @@
@Override
public void onCreate() {
super.onCreate();
-
settings = injector.createTextClassifierSettings();
- modelFileManager = injector.createModelFileManager(settings);
+ modelDownloadManager =
+ new ModelDownloadManager(
+ injector.getContext().getApplicationContext(),
+ settings,
+ TextClassifierServiceExecutors.getDownloaderExecutor());
+ modelDownloadManager.onTextClassifierServiceCreated();
+ modelFileManager = injector.createModelFileManager(settings, modelDownloadManager);
normPriorityExecutor = injector.createNormPriorityExecutor();
lowPriorityExecutor = injector.createLowPriorityExecutor();
textClassifier = injector.createTextClassifierImpl(settings, modelFileManager);
@@ -91,6 +100,7 @@
@Override
public void onDestroy() {
super.onDestroy();
+ modelDownloadManager.destroy();
}
@Override
@@ -197,11 +207,21 @@
@Override
protected void dump(FileDescriptor fd, PrintWriter writer, String[] args) {
- IndentingPrintWriter indentingPrintWriter = new IndentingPrintWriter(writer);
- // TODO(licha): Also dump ModelDownloadManager for debugging
- textClassifier.dump(indentingPrintWriter);
- dumpImpl(indentingPrintWriter);
- indentingPrintWriter.flush();
+ // Dump in a background thread because we may need to query Room db (e.g. to init model cache)
+ try {
+ TextClassifierServiceExecutors.getLowPriorityExecutor()
+ .submit(
+ () -> {
+ IndentingPrintWriter indentingPrintWriter = new IndentingPrintWriter(writer);
+ textClassifier.dump(indentingPrintWriter);
+ modelDownloadManager.dump(indentingPrintWriter);
+ dumpImpl(indentingPrintWriter);
+ indentingPrintWriter.flush();
+ })
+ .get();
+ } catch (ExecutionException | InterruptedException e) {
+ TcLog.e(TAG, "Failed to dump Default TextClassifierService", e);
+ }
}
private void dumpImpl(IndentingPrintWriter printWriter) {
@@ -289,8 +309,9 @@
}
@Override
- public ModelFileManager createModelFileManager(TextClassifierSettings settings) {
- return new ModelFileManager(context, settings);
+ public ModelFileManager createModelFileManager(
+ TextClassifierSettings settings, ModelDownloadManager modelDownloadManager) {
+ return new ModelFileManagerImpl(context, modelDownloadManager, settings);
}
@Override
@@ -329,7 +350,8 @@
interface Injector {
Context getContext();
- ModelFileManager createModelFileManager(TextClassifierSettings settings);
+ ModelFileManager createModelFileManager(
+ TextClassifierSettings settings, ModelDownloadManager modelDownloadManager);
TextClassifierSettings createTextClassifierSettings();
diff --git a/java/src/com/android/textclassifier/ExtrasUtils.java b/java/src/com/android/textclassifier/ExtrasUtils.java
index fd64581..bde3898 100644
--- a/java/src/com/android/textclassifier/ExtrasUtils.java
+++ b/java/src/com/android/textclassifier/ExtrasUtils.java
@@ -87,7 +87,9 @@
return classification.getExtras().getBundle(FOREIGN_LANGUAGE);
}
- /** @see #getTopLanguage(Intent) */
+ /**
+ * @see #getTopLanguage(Intent)
+ */
static void putTopLanguageScores(Bundle container, EntityConfidence languageScores) {
final int maxSize = Math.min(3, languageScores.getEntities().size());
final String[] languages =
diff --git a/java/src/com/android/textclassifier/ModelFileManager.java b/java/src/com/android/textclassifier/ModelFileManager.java
new file mode 100644
index 0000000..1a03b4a
--- /dev/null
+++ b/java/src/com/android/textclassifier/ModelFileManager.java
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import android.os.LocaleList;
+import com.android.textclassifier.common.ModelFile;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.utils.IndentingPrintWriter;
+import javax.annotation.Nullable;
+
+/**
+ * Interface to find the best model file for a use case and to dump internal state for debugging.
+ */
+interface ModelFileManager {
+
+ /**
+ * Returns the best model file for the given locale lists, or {@code null} if nothing is found.
+ *
+ * @param modelType the type of model to look up (e.g. annotator, lang_id, etc.)
+ * @param localePreferences an ordered list of user preferences for locales, use {@code null} if
+ * there is no preference.
+ * @param detectedLocales an ordered list of locales detected from the Tcs request text, use
+ * {@code null} if no detected locales are provided
+ */
+ @Nullable
+ ModelFile findBestModelFile(
+ @ModelTypeDef String modelType,
+ @Nullable LocaleList localePreferences,
+ @Nullable LocaleList detectedLocales);
+
+ /**
+ * Dumps the internal state for debugging.
+ *
+ * @param printWriter writer to write dumped states
+ */
+ void dump(IndentingPrintWriter printWriter);
+}
diff --git a/java/src/com/android/textclassifier/ModelFileManagerImpl.java b/java/src/com/android/textclassifier/ModelFileManagerImpl.java
new file mode 100644
index 0000000..45426d0
--- /dev/null
+++ b/java/src/com/android/textclassifier/ModelFileManagerImpl.java
@@ -0,0 +1,453 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static java.lang.Math.min;
+
+import android.content.Context;
+import android.content.res.AssetManager;
+import android.os.LocaleList;
+import androidx.annotation.GuardedBy;
+import androidx.collection.ArrayMap;
+import com.android.textclassifier.common.ModelFile;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.downloader.ModelDownloadManager;
+import com.android.textclassifier.utils.IndentingPrintWriter;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Splitter;
+import com.google.common.base.Supplier;
+import com.google.common.collect.ImmutableList;
+import java.io.File;
+import java.io.IOException;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import javax.annotation.Nullable;
+
+// TODO(licha): Consider making this a singleton class
+// TODO(licha): Check whether this is thread-safe
+/**
+ * Manages all model files in storage. {@link TextClassifierImpl} depends on this class to get the
+ * model files to load.
+ */
+final class ModelFileManagerImpl implements ModelFileManager {
+
+ private static final String TAG = "ModelFileManagerImpl";
+
+ private static final File CONFIG_UPDATER_DIR = new File("/data/misc/textclassifier/");
+ private static final String ASSETS_DIR = "textclassifier";
+
+ private ImmutableList<ModelFileLister> modelFileListers;
+
+ private final TextClassifierSettings settings;
+
+ public ModelFileManagerImpl(
+ Context context, ModelDownloadManager modelDownloadManager, TextClassifierSettings settings) {
+
+ Preconditions.checkNotNull(context);
+ Preconditions.checkNotNull(modelDownloadManager);
+
+ this.settings = Preconditions.checkNotNull(settings);
+
+ AssetManager assetManager = context.getAssets();
+ modelFileListers =
+ ImmutableList.of(
+ // Annotator models.
+ new RegularFileFullMatchLister(
+ ModelType.ANNOTATOR,
+ new File(CONFIG_UPDATER_DIR, "textclassifier.model"),
+ /* isEnabled= */ () -> settings.isConfigUpdaterModelEnabled()),
+ new AssetFilePatternMatchLister(
+ assetManager,
+ ModelType.ANNOTATOR,
+ ASSETS_DIR,
+ "annotator\\.(.*)\\.model",
+ /* isEnabled= */ () -> true),
+ // Actions models.
+ new RegularFileFullMatchLister(
+ ModelType.ACTIONS_SUGGESTIONS,
+ new File(CONFIG_UPDATER_DIR, "actions_suggestions.model"),
+ /* isEnabled= */ () -> settings.isConfigUpdaterModelEnabled()),
+ new AssetFilePatternMatchLister(
+ assetManager,
+ ModelType.ACTIONS_SUGGESTIONS,
+ ASSETS_DIR,
+ "actions_suggestions\\.(.*)\\.model",
+ /* isEnabled= */ () -> true),
+ // LangID models.
+ new RegularFileFullMatchLister(
+ ModelType.LANG_ID,
+ new File(CONFIG_UPDATER_DIR, "lang_id.model"),
+ /* isEnabled= */ () -> settings.isConfigUpdaterModelEnabled()),
+ new AssetFilePatternMatchLister(
+ assetManager,
+ ModelType.LANG_ID,
+ ASSETS_DIR,
+ "lang_id.model",
+ /* isEnabled= */ () -> true),
+ new DownloaderModelsLister(modelDownloadManager, settings));
+ }
+
+ @VisibleForTesting
+ public ModelFileManagerImpl(
+ Context context, List<ModelFileLister> modelFileListers, TextClassifierSettings settings) {
+ this.modelFileListers = ImmutableList.copyOf(modelFileListers);
+ this.settings = settings;
+ }
+
+ public ImmutableList<ModelFile> listModelFiles(@ModelTypeDef String modelType) {
+ Preconditions.checkNotNull(modelType);
+
+ ImmutableList.Builder<ModelFile> modelFiles = new ImmutableList.Builder<>();
+ for (ModelFileLister modelFileLister : modelFileListers) {
+ modelFiles.addAll(modelFileLister.list(modelType));
+ }
+ return modelFiles.build();
+ }
+
+ /** Lists model files. */
+ @FunctionalInterface
+ public interface ModelFileLister {
+ List<ModelFile> list(@ModelTypeDef String modelType);
+ }
+
+ /** Lists Downloader models */
+ public static class DownloaderModelsLister implements ModelFileLister {
+
+ private final ModelDownloadManager modelDownloadManager;
+ private final TextClassifierSettings settings;
+
+ /**
+ * @param modelDownloadManager manager of downloaded models
+ * @param settings current settings
+ */
+ public DownloaderModelsLister(
+ ModelDownloadManager modelDownloadManager, TextClassifierSettings settings) {
+ this.modelDownloadManager = Preconditions.checkNotNull(modelDownloadManager);
+ this.settings = Preconditions.checkNotNull(settings);
+ }
+
+ @Override
+ public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
+ ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
+ if (settings.isModelDownloadManagerEnabled()) {
+ for (File modelFile : modelDownloadManager.listDownloadedModels(modelType)) {
+ try {
+ // TODO(licha): Construct downloader model files with locale tag in our internal
+ // database
+ modelFilesBuilder.add(ModelFile.createFromRegularFile(modelFile, modelType));
+ } catch (IOException e) {
+ TcLog.e(TAG, "Failed to create ModelFile: " + modelFile.getAbsolutePath(), e);
+ }
+ }
+ }
+ return modelFilesBuilder.build();
+ }
+ }
+
+ /** Lists model files by performing full match on file path. */
+ public static class RegularFileFullMatchLister implements ModelFileLister {
+ private final String modelType;
+ private final File targetFile;
+ private final Supplier<Boolean> isEnabled;
+
+ /**
+ * @param modelType the type of the model
+ * @param targetFile the expected model file
+ * @param isEnabled whether this lister is enabled
+ */
+ public RegularFileFullMatchLister(
+ @ModelTypeDef String modelType, File targetFile, Supplier<Boolean> isEnabled) {
+ this.modelType = Preconditions.checkNotNull(modelType);
+ this.targetFile = Preconditions.checkNotNull(targetFile);
+ this.isEnabled = Preconditions.checkNotNull(isEnabled);
+ }
+
+ @Override
+ public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
+ if (!this.modelType.equals(modelType)) {
+ return ImmutableList.of();
+ }
+ if (!isEnabled.get()) {
+ return ImmutableList.of();
+ }
+ if (!targetFile.exists()) {
+ return ImmutableList.of();
+ }
+ try {
+ return ImmutableList.of(ModelFile.createFromRegularFile(targetFile, modelType));
+ } catch (IOException e) {
+ TcLog.e(
+ TAG, "Failed to call createFromRegularFile with: " + targetFile.getAbsolutePath(), e);
+ }
+ return ImmutableList.of();
+ }
+ }
+
+ /** Lists model files in a specified folder by doing pattern matching on file names. */
+ public static class RegularFilePatternMatchLister implements ModelFileLister {
+ private final String modelType;
+ private final File folder;
+ private final Pattern fileNamePattern;
+ private final Supplier<Boolean> isEnabled;
+
+ /**
+ * @param modelType the type of the model
+ * @param folder the folder to list files
+ * @param fileNameRegex the regex to match the file name in the specified folder
+ * @param isEnabled whether the lister is enabled
+ */
+ public RegularFilePatternMatchLister(
+ @ModelTypeDef String modelType,
+ File folder,
+ String fileNameRegex,
+ Supplier<Boolean> isEnabled) {
+ this.modelType = Preconditions.checkNotNull(modelType);
+ this.folder = Preconditions.checkNotNull(folder);
+ this.fileNamePattern = Pattern.compile(Preconditions.checkNotNull(fileNameRegex));
+ this.isEnabled = Preconditions.checkNotNull(isEnabled);
+ }
+
+ @Override
+ public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
+ if (!this.modelType.equals(modelType)) {
+ return ImmutableList.of();
+ }
+ if (!isEnabled.get()) {
+ return ImmutableList.of();
+ }
+ if (!folder.isDirectory()) {
+ return ImmutableList.of();
+ }
+ File[] files = folder.listFiles();
+ if (files == null) {
+ return ImmutableList.of();
+ }
+ ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
+ for (File file : files) {
+ final Matcher matcher = fileNamePattern.matcher(file.getName());
+ if (!matcher.matches() || !file.isFile()) {
+ continue;
+ }
+ try {
+ modelFilesBuilder.add(ModelFile.createFromRegularFile(file, modelType));
+ } catch (IOException e) {
+ TcLog.w(TAG, "Failed to call createFromRegularFile with: " + file.getAbsolutePath());
+ }
+ }
+ return modelFilesBuilder.build();
+ }
+ }
+
+ /** Lists the model files preloaded in the APK file. */
+ public static class AssetFilePatternMatchLister implements ModelFileLister {
+ private final AssetManager assetManager;
+ private final String modelType;
+ private final String pathToList;
+ private final Pattern fileNamePattern;
+ private final Supplier<Boolean> isEnabled;
+ private final Object lock = new Object();
+ // Assets won't change without updating the app, so cache the result for performance reasons.
+ @GuardedBy("lock")
+ private final Map<String, ImmutableList<ModelFile>> resultCache;
+
+ /**
+ * @param modelType the type of the model.
+ * @param pathToList the folder to list files
+ * @param fileNameRegex the regex to match the file name in the specified folder
+ * @param isEnabled whether this lister is enabled
+ */
+ public AssetFilePatternMatchLister(
+ AssetManager assetManager,
+ @ModelTypeDef String modelType,
+ String pathToList,
+ String fileNameRegex,
+ Supplier<Boolean> isEnabled) {
+ this.assetManager = Preconditions.checkNotNull(assetManager);
+ this.modelType = Preconditions.checkNotNull(modelType);
+ this.pathToList = Preconditions.checkNotNull(pathToList);
+ this.fileNamePattern = Pattern.compile(Preconditions.checkNotNull(fileNameRegex));
+ this.isEnabled = Preconditions.checkNotNull(isEnabled);
+ resultCache = new ArrayMap<>();
+ }
+
+ @Override
+ public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
+ if (!this.modelType.equals(modelType)) {
+ return ImmutableList.of();
+ }
+ if (!isEnabled.get()) {
+ return ImmutableList.of();
+ }
+ synchronized (lock) {
+ if (resultCache.get(modelType) != null) {
+ return resultCache.get(modelType);
+ }
+ String[] fileNames = null;
+ try {
+ fileNames = assetManager.list(pathToList);
+ } catch (IOException e) {
+ TcLog.e(TAG, "Failed to list assets", e);
+ }
+ if (fileNames == null) {
+ return ImmutableList.of();
+ }
+ ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
+ for (String fileName : fileNames) {
+ final Matcher matcher = fileNamePattern.matcher(fileName);
+ if (!matcher.matches()) {
+ continue;
+ }
+ String absolutePath =
+ new StringBuilder(pathToList).append('/').append(fileName).toString();
+ try {
+ modelFilesBuilder.add(ModelFile.createFromAsset(assetManager, absolutePath, modelType));
+ } catch (IOException e) {
+ TcLog.e(TAG, "Failed to call createFromAsset with: " + absolutePath, e);
+ }
+ }
+ ImmutableList<ModelFile> result = modelFilesBuilder.build();
+ resultCache.put(modelType, result);
+ return result;
+ }
+ }
+ }
+
+ /**
+ * Returns the best locale matching the given detected locales and the device's default locale
+ * list. The default locale is returned if no matching locale is found.
+ *
+ * @param localePreferences list of optional locale preferences. Used if request contains
+ * preference and multi_language_support is disabled.
+ * @param detectedLocales ordered list of locales detected from Tcs request text, use {@code null}
+ * if no detected locales provided.
+ */
+ public Locale findBestModelLocale(
+ @Nullable LocaleList localePreferences, @Nullable LocaleList detectedLocales) {
+ if (!settings.isMultiLanguageSupportEnabled() || isEmptyLocaleList(detectedLocales)) {
+ return isEmptyLocaleList(localePreferences) ? Locale.getDefault() : localePreferences.get(0);
+ }
+ Locale bestLocale = Locale.getDefault();
+ LocaleList adjustedLocales = LocaleList.getAdjustedDefault();
+ // We only intersect detected locales with locales for which we have pre-downloaded models.
+ // The number of downloaded locale models is determined by a flag in tcs settings.
+ int numberOfActiveModels = min(adjustedLocales.size(), settings.getMultiLanguageModelsLimit());
+ List<String> filteredDeviceLocales =
+ Splitter.on(",")
+ .splitToList(adjustedLocales.toLanguageTags())
+ .subList(0, numberOfActiveModels);
+ LocaleList filteredDeviceLocaleList =
+ LocaleList.forLanguageTags(String.join(",", filteredDeviceLocales));
+ List<Locale.LanguageRange> deviceLanguageRange =
+ Locale.LanguageRange.parse(filteredDeviceLocaleList.toLanguageTags());
+ for (int i = 0; i < detectedLocales.size(); i++) {
+ if (Locale.lookupTag(
+ deviceLanguageRange, ImmutableList.of(detectedLocales.get(i).getLanguage()))
+ != null) {
+ bestLocale = detectedLocales.get(i);
+ break;
+ }
+ }
+ return bestLocale;
+ }
+
+ @Nullable
+ @Override
+ public ModelFile findBestModelFile(
+ @ModelTypeDef String modelType,
+ @Nullable LocaleList localePreferences,
+ @Nullable LocaleList detectedLocales) {
+ Locale targetLocale = findBestModelLocale(localePreferences, detectedLocales);
+ // detectedLocales usually only contains 2-char language (e.g. en), while locale in
+ // localePreferences is usually complete (e.g. en_US). Log only if targetLocale is not a prefix.
+ if (!isEmptyLocaleList(localePreferences)
+ && !localePreferences.get(0).toString().startsWith(targetLocale.toString())) {
+ TcLog.d(
+ TAG,
+ String.format(
+ Locale.US,
+ "localePreference and targetLocale mismatch: preference: %s, target: %s",
+ localePreferences.get(0),
+ targetLocale));
+ }
+ return findBestModelFile(modelType, targetLocale);
+ }
+
+ /**
+ * Returns the best model file for the given locale, {@code null} if nothing is found.
+ *
+ * @param modelType the type of model to look up (e.g. annotator, lang_id, etc.)
+ * @param targetLocale the preferred locale, taken from the locale preferences or the detected
+ * locales; the default locale if none is given or detected.
+ */
+ @Nullable
+ private ModelFile findBestModelFile(@ModelTypeDef String modelType, Locale targetLocale) {
+ List<Locale.LanguageRange> deviceLanguageRanges =
+ Locale.LanguageRange.parse(LocaleList.getDefault().toLanguageTags());
+ boolean languageIndependentModelOnly = false;
+ if (Locale.lookupTag(deviceLanguageRanges, ImmutableList.of(targetLocale.getLanguage()))
+ == null) {
+ // If the targetLocale's language is not in device locale list, we don't match it to avoid
+ // leaking user language profile to the callers.
+ languageIndependentModelOnly = true;
+ }
+ List<Locale.LanguageRange> targetLanguageRanges =
+ Locale.LanguageRange.parse(targetLocale.toLanguageTag());
+ ModelFile bestModel = null;
+ for (ModelFile model : listModelFiles(modelType)) {
+ if (languageIndependentModelOnly && !model.languageIndependent) {
+ continue;
+ }
+ if (model.isAnyLanguageSupported(targetLanguageRanges)) {
+ if (model.isPreferredTo(bestModel)) {
+ bestModel = model;
+ }
+ }
+ }
+ return bestModel;
+ }
+
+ /**
+ * Helper function to check whether a LocaleList is null or empty.
+ *
+ * @param localeList locale list to be checked
+ */
+ private static boolean isEmptyLocaleList(@Nullable LocaleList localeList) {
+ return localeList == null || localeList.isEmpty();
+ }
+
+ @Override
+ public void dump(IndentingPrintWriter printWriter) {
+ printWriter.println("ModelFileManagerImpl:");
+ printWriter.increaseIndent();
+ for (@ModelTypeDef String modelType : ModelType.values()) {
+ printWriter.println(modelType + " model file(s):");
+ printWriter.increaseIndent();
+ for (ModelFile modelFile : listModelFiles(modelType)) {
+ printWriter.println(modelFile.toString());
+ }
+ printWriter.decreaseIndent();
+ }
+ printWriter.decreaseIndent();
+ }
+}
diff --git a/java/src/com/android/textclassifier/TextClassifierImpl.java b/java/src/com/android/textclassifier/TextClassifierImpl.java
index bf326fb..2b6c396 100644
--- a/java/src/com/android/textclassifier/TextClassifierImpl.java
+++ b/java/src/com/android/textclassifier/TextClassifierImpl.java
@@ -43,11 +43,12 @@
import android.view.textclassifier.TextSelection;
import androidx.annotation.GuardedBy;
import androidx.annotation.WorkerThread;
+import androidx.collection.LruCache;
import androidx.core.util.Pair;
-import com.android.textclassifier.common.ModelFileManager;
-import com.android.textclassifier.common.ModelFileManager.ModelFile;
+import com.android.textclassifier.common.ModelFile;
import com.android.textclassifier.common.ModelType;
import com.android.textclassifier.common.TextClassifierSettings;
+import com.android.textclassifier.common.TextSelectionCompat;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.intent.LabeledIntent;
import com.android.textclassifier.common.intent.TemplateIntentFactory;
@@ -63,6 +64,7 @@
import com.google.android.textclassifier.ActionsSuggestionsModel.ActionSuggestions;
import com.google.android.textclassifier.AnnotatorModel;
import com.google.android.textclassifier.LangIdModel;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.collect.FluentIterable;
@@ -95,23 +97,26 @@
private final Object lock = new Object();
@GuardedBy("lock")
- private ModelFileManager.ModelFile annotatorModelInUse;
+ private ModelFile annotatorModelInUse;
@GuardedBy("lock")
private AnnotatorModel annotatorImpl;
@GuardedBy("lock")
- private ModelFileManager.ModelFile langIdModelInUse;
+ private ModelFile langIdModelInUse;
@GuardedBy("lock")
private LangIdModel langIdImpl;
@GuardedBy("lock")
- private ModelFileManager.ModelFile actionModelInUse;
+ private ModelFile actionModelInUse;
@GuardedBy("lock")
private ActionsSuggestionsModel actionsImpl;
+ @GuardedBy("lock")
+ private final LruCache<ModelFile, AnnotatorModel> annotatorModelCache;
+
private final TextClassifierEventLogger textClassifierEventLogger =
new TextClassifierEventLogger();
@@ -121,10 +126,20 @@
TextClassifierImpl(
Context context, TextClassifierSettings settings, ModelFileManager modelFileManager) {
+ this(
+ context, settings, modelFileManager, new LruCache<>(settings.getMultiAnnotatorCacheSize()));
+ }
+
+ @VisibleForTesting
+ public TextClassifierImpl(
+ Context context,
+ TextClassifierSettings settings,
+ ModelFileManager modelFileManager,
+ LruCache<ModelFile, AnnotatorModel> annotatorModelCache) {
this.context = Preconditions.checkNotNull(context);
this.settings = Preconditions.checkNotNull(settings);
this.modelFileManager = Preconditions.checkNotNull(modelFileManager);
-
+ this.annotatorModelCache = annotatorModelCache;
generateLinksLogger = new GenerateLinksLogger(this.settings.getGenerateLinksLogSampleRate());
templateIntentFactory = new TemplateIntentFactory();
}
@@ -147,7 +162,10 @@
final String detectLanguageTags =
String.join(",", detectLanguageTags(langIdModel, request.getText()));
final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault());
- final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
+ final LocaleList detectedLocaleList = LocaleList.forLanguageTags(detectLanguageTags);
+ final ModelFile annotatorModelInUse =
+ getAnnotatorModelFile(request.getDefaultLocales(), detectedLocaleList);
+ final AnnotatorModel annotatorImpl = loadAnnotatorModelFile(annotatorModelInUse);
final int[] startEnd =
annotatorImpl.suggestSelection(
string,
@@ -167,6 +185,8 @@
throw new IllegalArgumentException("Got bad indices for input text. Ignoring result.");
}
final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
+ final boolean shouldIncludeTextClassification =
+ TextSelectionCompat.shouldIncludeTextClassification(request);
final AnnotatorModel.ClassificationResult[] results =
annotatorImpl.classifyText(
string,
@@ -177,18 +197,24 @@
.setReferenceTimezone(refTime.getZone().getId())
.setLocales(localesString)
.setDetectedTextLanguageTags(detectLanguageTags)
+ .setAnnotationUsecase(AnnotatorModel.AnnotationUsecase.SMART.getValue())
.setUserFamiliarLanguageTags(LocaleList.getDefault().toLanguageTags())
.build(),
- // Passing null here to suppress intent generation
+ // Passing a null context here suppresses intent generation.
// TODO: Use an explicit flag to suppress it.
- /* appContext */ null,
- /* deviceLocales */ null);
+ shouldIncludeTextClassification ? context : null,
+ getResourceLocalesString());
final int size = results.length;
for (int i = 0; i < size; i++) {
tsBuilder.setEntityType(results[i].getCollection(), results[i].getScore());
}
final String resultId =
createAnnotatorId(string, request.getStartIndex(), request.getEndIndex());
+ if (shouldIncludeTextClassification) {
+ TextClassification textClassification =
+ createClassificationResult(results, string, start, end, langIdModel);
+ TextSelectionCompat.setTextClassification(tsBuilder, textClassification);
+ }
return tsBuilder.setId(resultId).build();
}
@@ -213,8 +239,10 @@
request.getReferenceTime() != null
? request.getReferenceTime()
: ZonedDateTime.now(ZoneId.systemDefault());
+ final LocaleList detectedLocaleList =
+ LocaleList.forLanguageTags(String.join(",", detectLanguageTags));
final AnnotatorModel.ClassificationResult[] results =
- getAnnotatorImpl(request.getDefaultLocales())
+ getAnnotatorImpl(request.getDefaultLocales(), detectedLocaleList)
.classifyText(
string,
request.getStartIndex(),
@@ -264,7 +292,10 @@
final String localesString = concatenateLocales(request.getDefaultLocales());
LangIdModel langId = getLangIdImpl();
ImmutableList<String> detectLanguageTags = detectLanguageTags(langId, request.getText());
- final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
+ final LocaleList detectedLocaleList =
+ LocaleList.forLanguageTags(String.join(",", detectLanguageTags));
+ final AnnotatorModel annotatorImpl =
+ getAnnotatorImpl(request.getDefaultLocales(), detectedLocaleList);
final boolean isSerializedEntityDataEnabled =
ExtrasUtils.isSerializedEntityDataEnabled(request);
final AnnotatorModel.AnnotatedSpan[] annotations =
@@ -394,7 +425,7 @@
null,
context,
getResourceLocalesString(),
- getAnnotatorImpl(LocaleList.getDefault()));
+ getAnnotatorImpl(LocaleList.getDefault(), /* detectedLocaleList= */ null));
return createConversationActionResult(request, nativeSuggestions);
}
@@ -460,33 +491,57 @@
return request.getTypeConfig().resolveEntityListModifications(defaultActionTypes);
}
- private AnnotatorModel getAnnotatorImpl(LocaleList localeList) throws IOException {
+ private ModelFile getAnnotatorModelFile(
+ LocaleList requestLocaleList, LocaleList detectedLocaleList) throws IOException {
+ final ModelFile bestModel =
+ modelFileManager.findBestModelFile(
+ ModelType.ANNOTATOR, requestLocaleList, detectedLocaleList);
+ if (bestModel == null) {
+ throw new IllegalStateException("Failed to find the best annotator model");
+ }
+ return bestModel;
+ }
+
+ private AnnotatorModel loadAnnotatorModelFile(ModelFile annotatorModelFile) throws IOException {
synchronized (lock) {
- localeList = localeList == null ? LocaleList.getDefault() : localeList;
- final ModelFileManager.ModelFile bestModel =
- modelFileManager.findBestModelFile(ModelType.ANNOTATOR, localeList);
- if (bestModel == null) {
- throw new IllegalStateException("Failed to find the best annotator model");
+ if (settings.getMultiAnnotatorCacheEnabled()
+ && !Objects.equals(annotatorModelInUse, annotatorModelFile)) {
+ TcLog.v(TAG, "Attempting to reload cached annotator model....");
+ annotatorImpl = annotatorModelCache.get(annotatorModelFile);
+ if (annotatorImpl != null) {
+ annotatorModelInUse = annotatorModelFile;
+ TcLog.v(TAG, "Successfully reloaded cached annotator model: " + annotatorModelFile);
+ }
}
- if (annotatorImpl == null || !Objects.equals(annotatorModelInUse, bestModel)) {
- TcLog.d(TAG, "Loading " + bestModel);
+ if (annotatorImpl == null || !Objects.equals(annotatorModelInUse, annotatorModelFile)) {
+ TcLog.d(TAG, "Loading " + annotatorModelFile);
// The current annotator model may be still used by another thread / model.
// Do not call close() here, and let the GC to clean it up when no one else
// is using it.
- try (AssetFileDescriptor afd = bestModel.open(context.getAssets())) {
+ try (AssetFileDescriptor afd = annotatorModelFile.open(context.getAssets())) {
annotatorImpl = new AnnotatorModel(afd);
annotatorImpl.setLangIdModel(getLangIdImpl());
- annotatorModelInUse = bestModel;
+ annotatorModelInUse = annotatorModelFile;
+ if (settings.getMultiAnnotatorCacheEnabled()) {
+ annotatorModelCache.put(annotatorModelFile, annotatorImpl);
+ }
}
}
return annotatorImpl;
}
}
+ private AnnotatorModel getAnnotatorImpl(
+ LocaleList requestLocaleList, LocaleList detectedLocaleList) throws IOException {
+ ModelFile annotatorModelFile = getAnnotatorModelFile(requestLocaleList, detectedLocaleList);
+ return loadAnnotatorModelFile(annotatorModelFile);
+ }
+
private LangIdModel getLangIdImpl() throws IOException {
synchronized (lock) {
- final ModelFileManager.ModelFile bestModel =
- modelFileManager.findBestModelFile(ModelType.LANG_ID, /* localePreferences= */ null);
+ final ModelFile bestModel =
+ modelFileManager.findBestModelFile(
+ ModelType.LANG_ID, /* localePreferences= */ null, /* detectedLocales= */ null);
if (bestModel == null) {
throw new IllegalStateException("Failed to find the best LangID model.");
}
@@ -504,9 +559,9 @@
private ActionsSuggestionsModel getActionsImpl() throws IOException {
synchronized (lock) {
// TODO: Use LangID to determine the locale we should use here?
- final ModelFileManager.ModelFile bestModel =
+ final ModelFile bestModel =
modelFileManager.findBestModelFile(
- ModelType.ACTIONS_SUGGESTIONS, LocaleList.getDefault());
+ ModelType.ACTIONS_SUGGESTIONS, LocaleList.getDefault(), /* detectedLocales= */ null);
if (bestModel == null) {
throw new IllegalStateException("Failed to find the best actions model");
}
diff --git a/java/src/com/android/textclassifier/common/ModelFile.java b/java/src/com/android/textclassifier/common/ModelFile.java
new file mode 100644
index 0000000..28240ab
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/ModelFile.java
@@ -0,0 +1,241 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common;
+
+import android.content.res.AssetFileDescriptor;
+import android.content.res.AssetManager;
+import android.os.LocaleList;
+import android.os.ParcelFileDescriptor;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
+import com.google.android.textclassifier.ActionsSuggestionsModel;
+import com.google.android.textclassifier.AnnotatorModel;
+import com.google.android.textclassifier.LangIdModel;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Function;
+import com.google.common.base.Optional;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Locale;
+import java.util.Objects;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+
+/** Describes TextClassifier model files on disk. */
+public class ModelFile {
+ public static final String LANGUAGE_INDEPENDENT = "*";
+
+ @ModelTypeDef public final String modelType;
+ public final String absolutePath;
+ public final int version;
+ public final LocaleList supportedLocales;
+ public final boolean languageIndependent;
+ public final boolean isAsset;
+
+ public static ModelFile createFromRegularFile(File file, @ModelTypeDef String modelType)
+ throws IOException {
+ ParcelFileDescriptor pfd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
+ try (AssetFileDescriptor afd = new AssetFileDescriptor(pfd, 0, file.length())) {
+ return createFromAssetFileDescriptor(
+ file.getAbsolutePath(), modelType, afd, /* isAsset= */ false);
+ }
+ }
+
+ public static ModelFile createFromAsset(
+ AssetManager assetManager, String absolutePath, @ModelTypeDef String modelType)
+ throws IOException {
+ try (AssetFileDescriptor assetFileDescriptor = assetManager.openFd(absolutePath)) {
+ return createFromAssetFileDescriptor(
+ absolutePath, modelType, assetFileDescriptor, /* isAsset= */ true);
+ }
+ }
+
+ private static ModelFile createFromAssetFileDescriptor(
+ String absolutePath,
+ @ModelTypeDef String modelType,
+ AssetFileDescriptor assetFileDescriptor,
+ boolean isAsset) {
+ ModelInfoFetcher modelInfoFetcher = ModelInfoFetcher.create(modelType);
+ return new ModelFile(
+ modelType,
+ absolutePath,
+ modelInfoFetcher.getVersion(assetFileDescriptor),
+ modelInfoFetcher.getSupportedLocales(assetFileDescriptor),
+ isAsset);
+ }
+
+ @VisibleForTesting
+ public ModelFile(
+ @ModelTypeDef String modelType,
+ String absolutePath,
+ int version,
+ String supportedLocaleTags,
+ boolean isAsset) {
+ this.modelType = modelType;
+ this.absolutePath = absolutePath;
+ this.version = version;
+ this.languageIndependent = LANGUAGE_INDEPENDENT.equals(supportedLocaleTags);
+ this.supportedLocales =
+ languageIndependent
+ ? LocaleList.getEmptyLocaleList()
+ : LocaleList.forLanguageTags(supportedLocaleTags);
+ this.isAsset = isAsset;
+ }
+
+ /** Returns if this model file is preferred to the given one. */
+ public boolean isPreferredTo(@Nullable ModelFile model) {
+ // A model is preferred to no model.
+ if (model == null) {
+ return true;
+ }
+
+ // A language-specific model is preferred to a language independent
+ // model.
+ if (!languageIndependent && model.languageIndependent) {
+ return true;
+ }
+ if (languageIndependent && !model.languageIndependent) {
+ return false;
+ }
+
+ // A higher-version model is preferred.
+ if (version > model.version) {
+ return true;
+ }
+ return false;
+ }
+
+ /** Returns whether this model supports any language in the given ranges. */
+ public boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
+ Preconditions.checkNotNull(languageRanges);
+ if (languageIndependent) {
+ return true;
+ }
+ List<String> supportedLocaleTags = Arrays.asList(supportedLocales.toLanguageTags().split(","));
+ return Locale.lookupTag(languageRanges, supportedLocaleTags) != null;
+ }
+
+ public AssetFileDescriptor open(AssetManager assetManager) throws IOException {
+ if (isAsset) {
+ return assetManager.openFd(absolutePath);
+ }
+ File file = new File(absolutePath);
+ ParcelFileDescriptor parcelFileDescriptor =
+ ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
+ return new AssetFileDescriptor(parcelFileDescriptor, 0, file.length());
+ }
+
+ public boolean canWrite() {
+ if (isAsset) {
+ return false;
+ }
+ return new File(absolutePath).canWrite();
+ }
+
+ public boolean delete() {
+ if (isAsset) {
+ throw new IllegalStateException("asset is read-only, deleting it is not allowed.");
+ }
+ return new File(absolutePath).delete();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (!(o instanceof ModelFile)) {
+ return false;
+ }
+ ModelFile modelFile = (ModelFile) o;
+ return version == modelFile.version
+ && languageIndependent == modelFile.languageIndependent
+ && isAsset == modelFile.isAsset
+ && Objects.equals(modelType, modelFile.modelType)
+ && Objects.equals(absolutePath, modelFile.absolutePath)
+ && Objects.equals(supportedLocales, modelFile.supportedLocales);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(
+ modelType, absolutePath, version, supportedLocales, languageIndependent, isAsset);
+ }
+
+ public ModelInfo toModelInfo() {
+ return new ModelInfo(
+ version, languageIndependent ? LANGUAGE_INDEPENDENT : supportedLocales.toLanguageTags());
+ }
+
+ @Override
+ public String toString() {
+ return String.format(
+ Locale.US,
+ "ModelFile { type=%s path=%s version=%d locales=%s isAsset=%b}",
+ modelType,
+ absolutePath,
+ version,
+ languageIndependent ? LANGUAGE_INDEPENDENT : supportedLocales.toLanguageTags(),
+ isAsset);
+ }
+
+ public static ImmutableList<Optional<ModelInfo>> toModelInfos(Optional<ModelFile>... modelFiles) {
+ return Arrays.stream(modelFiles)
+ .map(modelFile -> modelFile.transform(ModelFile::toModelInfo))
+ .collect(Collectors.collectingAndThen(Collectors.toList(), ImmutableList::copyOf));
+ }
+
+ /** Fetches metadata (version and supported locales) of a model file. */
+ private static class ModelInfoFetcher {
+ private final Function<AssetFileDescriptor, Integer> versionFetcher;
+ private final Function<AssetFileDescriptor, String> supportedLocalesFetcher;
+
+ private ModelInfoFetcher(
+ Function<AssetFileDescriptor, Integer> versionFetcher,
+ Function<AssetFileDescriptor, String> supportedLocalesFetcher) {
+ this.versionFetcher = versionFetcher;
+ this.supportedLocalesFetcher = supportedLocalesFetcher;
+ }
+
+ int getVersion(AssetFileDescriptor assetFileDescriptor) {
+ return versionFetcher.apply(assetFileDescriptor);
+ }
+
+ String getSupportedLocales(AssetFileDescriptor assetFileDescriptor) {
+ return supportedLocalesFetcher.apply(assetFileDescriptor);
+ }
+
+ static ModelInfoFetcher create(@ModelTypeDef String modelType) {
+ switch (modelType) {
+ case ModelType.ANNOTATOR:
+ return new ModelInfoFetcher(AnnotatorModel::getVersion, AnnotatorModel::getLocales);
+ case ModelType.ACTIONS_SUGGESTIONS:
+ return new ModelInfoFetcher(
+ ActionsSuggestionsModel::getVersion, ActionsSuggestionsModel::getLocales);
+ case ModelType.LANG_ID:
+ return new ModelInfoFetcher(
+ LangIdModel::getVersion, afd -> ModelFile.LANGUAGE_INDEPENDENT);
+ default: // fall out
+ }
+ throw new IllegalStateException("Unsupported model types");
+ }
+ }
+}
diff --git a/java/src/com/android/textclassifier/common/ModelFileManager.java b/java/src/com/android/textclassifier/common/ModelFileManager.java
deleted file mode 100644
index 406a889..0000000
--- a/java/src/com/android/textclassifier/common/ModelFileManager.java
+++ /dev/null
@@ -1,603 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.common;
-
-import android.content.Context;
-import android.content.res.AssetFileDescriptor;
-import android.content.res.AssetManager;
-import android.os.LocaleList;
-import android.os.ParcelFileDescriptor;
-import android.util.ArraySet;
-import androidx.annotation.GuardedBy;
-import androidx.collection.ArrayMap;
-import com.android.textclassifier.common.ModelType.ModelTypeDef;
-import com.android.textclassifier.common.base.TcLog;
-import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
-import com.android.textclassifier.utils.IndentingPrintWriter;
-import com.google.android.textclassifier.ActionsSuggestionsModel;
-import com.google.android.textclassifier.AnnotatorModel;
-import com.google.android.textclassifier.LangIdModel;
-import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Function;
-import com.google.common.base.Optional;
-import com.google.common.base.Preconditions;
-import com.google.common.base.Supplier;
-import com.google.common.collect.ImmutableList;
-import java.io.File;
-import java.io.IOException;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Locale;
-import java.util.Map;
-import java.util.Objects;
-import java.util.regex.Matcher;
-import java.util.regex.Pattern;
-import java.util.stream.Collectors;
-import javax.annotation.Nullable;
-
-// TODO(licha): Consider making this a singleton class
-// TODO(licha): Check whether this is thread-safe
-/**
- * Manages all model files in storage. {@link TextClassifierImpl} depends on this class to get the
- * model files to load.
- */
-public final class ModelFileManager {
-
- private static final String TAG = "ModelFileManager";
-
- private static final String DOWNLOAD_SUB_DIR_NAME = "textclassifier/downloads/models/";
- private static final File CONFIG_UPDATER_DIR = new File("/data/misc/textclassifier/");
- private static final String ASSETS_DIR = "textclassifier";
-
- private final List<ModelFileLister> modelFileListers;
- private final File modelDownloaderDir;
-
- public ModelFileManager(Context context, TextClassifierSettings settings) {
- Preconditions.checkNotNull(context);
- Preconditions.checkNotNull(settings);
-
- AssetManager assetManager = context.getAssets();
- this.modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
- modelFileListers =
- ImmutableList.of(
- // Annotator models.
- new RegularFilePatternMatchLister(
- ModelType.ANNOTATOR,
- this.modelDownloaderDir,
- "annotator\\.(.*)\\.model",
- settings::isModelDownloadManagerEnabled),
- new RegularFileFullMatchLister(
- ModelType.ANNOTATOR,
- new File(CONFIG_UPDATER_DIR, "textclassifier.model"),
- /* isEnabled= */ () -> true),
- new AssetFilePatternMatchLister(
- assetManager,
- ModelType.ANNOTATOR,
- ASSETS_DIR,
- "annotator\\.(.*)\\.model",
- /* isEnabled= */ () -> true),
- // Actions models.
- new RegularFilePatternMatchLister(
- ModelType.ACTIONS_SUGGESTIONS,
- this.modelDownloaderDir,
- "actions_suggestions\\.(.*)\\.model",
- settings::isModelDownloadManagerEnabled),
- new RegularFileFullMatchLister(
- ModelType.ACTIONS_SUGGESTIONS,
- new File(CONFIG_UPDATER_DIR, "actions_suggestions.model"),
- /* isEnabled= */ () -> true),
- new AssetFilePatternMatchLister(
- assetManager,
- ModelType.ACTIONS_SUGGESTIONS,
- ASSETS_DIR,
- "actions_suggestions\\.(.*)\\.model",
- /* isEnabled= */ () -> true),
- // LangID models.
- new RegularFilePatternMatchLister(
- ModelType.LANG_ID,
- this.modelDownloaderDir,
- "lang_id\\.(.*)\\.model",
- settings::isModelDownloadManagerEnabled),
- new RegularFileFullMatchLister(
- ModelType.LANG_ID,
- new File(CONFIG_UPDATER_DIR, "lang_id.model"),
- /* isEnabled= */ () -> true),
- new AssetFilePatternMatchLister(
- assetManager,
- ModelType.LANG_ID,
- ASSETS_DIR,
- "lang_id.model",
- /* isEnabled= */ () -> true));
- }
-
- @VisibleForTesting
- public ModelFileManager(Context context, List<ModelFileLister> modelFileListers) {
- this.modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
- this.modelFileListers = ImmutableList.copyOf(modelFileListers);
- }
-
- /**
- * Returns an immutable list of model files listed by the given model files supplier.
- *
- * @param modelType which type of model files to look for
- */
- public ImmutableList<ModelFile> listModelFiles(@ModelTypeDef String modelType) {
- Preconditions.checkNotNull(modelType);
-
- ImmutableList.Builder<ModelFile> modelFiles = new ImmutableList.Builder<>();
- for (ModelFileLister modelFileLister : modelFileListers) {
- modelFiles.addAll(modelFileLister.list(modelType));
- }
- return modelFiles.build();
- }
-
- /** Lists model files. */
- public interface ModelFileLister {
- List<ModelFile> list(@ModelTypeDef String modelType);
- }
-
- /** Lists model files by performing full match on file path. */
- public static class RegularFileFullMatchLister implements ModelFileLister {
- private final String modelType;
- private final File targetFile;
- private final Supplier<Boolean> isEnabled;
-
- /**
- * @param modelType the type of the model
- * @param targetFile the expected model file
- * @param isEnabled whether this lister is enabled
- */
- public RegularFileFullMatchLister(
- @ModelTypeDef String modelType, File targetFile, Supplier<Boolean> isEnabled) {
- this.modelType = Preconditions.checkNotNull(modelType);
- this.targetFile = Preconditions.checkNotNull(targetFile);
- this.isEnabled = Preconditions.checkNotNull(isEnabled);
- }
-
- @Override
- public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
- if (!this.modelType.equals(modelType)) {
- return ImmutableList.of();
- }
- if (!isEnabled.get()) {
- return ImmutableList.of();
- }
- if (!targetFile.exists()) {
- return ImmutableList.of();
- }
- try {
- return ImmutableList.of(ModelFile.createFromRegularFile(targetFile, modelType));
- } catch (IOException e) {
- TcLog.e(
- TAG, "Failed to call createFromRegularFile with: " + targetFile.getAbsolutePath(), e);
- }
- return ImmutableList.of();
- }
- }
-
- /** Lists model file in a specified folder by doing pattern matching on file names. */
- public static class RegularFilePatternMatchLister implements ModelFileLister {
- private final String modelType;
- private final File folder;
- private final Pattern fileNamePattern;
- private final Supplier<Boolean> isEnabled;
-
- /**
- * @param modelType the type of the model
- * @param folder the folder to list files
- * @param fileNameRegex the regex to match the file name in the specified folder
- * @param isEnabled whether the lister is enabled
- */
- public RegularFilePatternMatchLister(
- @ModelTypeDef String modelType,
- File folder,
- String fileNameRegex,
- Supplier<Boolean> isEnabled) {
- this.modelType = Preconditions.checkNotNull(modelType);
- this.folder = Preconditions.checkNotNull(folder);
- this.fileNamePattern = Pattern.compile(Preconditions.checkNotNull(fileNameRegex));
- this.isEnabled = Preconditions.checkNotNull(isEnabled);
- }
-
- @Override
- public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
- if (!this.modelType.equals(modelType)) {
- return ImmutableList.of();
- }
- if (!isEnabled.get()) {
- return ImmutableList.of();
- }
- if (!folder.isDirectory()) {
- return ImmutableList.of();
- }
- File[] files = folder.listFiles();
- if (files == null) {
- return ImmutableList.of();
- }
- ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
- for (File file : files) {
- final Matcher matcher = fileNamePattern.matcher(file.getName());
- if (!matcher.matches() || !file.isFile()) {
- continue;
- }
- try {
- modelFilesBuilder.add(ModelFile.createFromRegularFile(file, modelType));
- } catch (IOException e) {
- TcLog.w(TAG, "Failed to call createFromRegularFile with: " + file.getAbsolutePath());
- }
- }
- return modelFilesBuilder.build();
- }
- }
-
- /** Lists the model files preloaded in the APK file. */
- public static class AssetFilePatternMatchLister implements ModelFileLister {
- private final AssetManager assetManager;
- private final String modelType;
- private final String pathToList;
- private final Pattern fileNamePattern;
- private final Supplier<Boolean> isEnabled;
- private final Object lock = new Object();
- // Assets won't change without updating the app, so cache the result for performance reason.
- @GuardedBy("lock")
- private final Map<String, ImmutableList<ModelFile>> resultCache;
-
- /**
- * @param modelType the type of the model.
- * @param pathToList the folder to list files
- * @param fileNameRegex the regex to match the file name in the specified folder
- * @param isEnabled whether this lister is enabled
- */
- public AssetFilePatternMatchLister(
- AssetManager assetManager,
- @ModelTypeDef String modelType,
- String pathToList,
- String fileNameRegex,
- Supplier<Boolean> isEnabled) {
- this.assetManager = Preconditions.checkNotNull(assetManager);
- this.modelType = Preconditions.checkNotNull(modelType);
- this.pathToList = Preconditions.checkNotNull(pathToList);
- this.fileNamePattern = Pattern.compile(Preconditions.checkNotNull(fileNameRegex));
- this.isEnabled = Preconditions.checkNotNull(isEnabled);
- resultCache = new ArrayMap<>();
- }
-
- @Override
- public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
- if (!this.modelType.equals(modelType)) {
- return ImmutableList.of();
- }
- if (!isEnabled.get()) {
- return ImmutableList.of();
- }
- synchronized (lock) {
- if (resultCache.get(modelType) != null) {
- return resultCache.get(modelType);
- }
- String[] fileNames = null;
- try {
- fileNames = assetManager.list(pathToList);
- } catch (IOException e) {
- TcLog.e(TAG, "Failed to list assets", e);
- }
- if (fileNames == null) {
- return ImmutableList.of();
- }
- ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
- for (String fileName : fileNames) {
- final Matcher matcher = fileNamePattern.matcher(fileName);
- if (!matcher.matches()) {
- continue;
- }
- String absolutePath =
- new StringBuilder(pathToList).append('/').append(fileName).toString();
- try {
- modelFilesBuilder.add(ModelFile.createFromAsset(assetManager, absolutePath, modelType));
- } catch (IOException e) {
- TcLog.w(TAG, "Failed to call createFromAsset with: " + absolutePath);
- }
- }
- ImmutableList<ModelFile> result = modelFilesBuilder.build();
- resultCache.put(modelType, result);
- return result;
- }
- }
- }
-
- /**
- * Returns the best model file for the given localelist, {@code null} if nothing is found.
- *
- * @param modelType the type of model to look up (e.g. annotator, lang_id, etc.)
- * @param localePreferences an ordered list of user preferences for locales, use {@code null} if
- * there is no preference.
- */
- @Nullable
- public ModelFile findBestModelFile(
- @ModelTypeDef String modelType, @Nullable LocaleList localePreferences) {
- final String languages =
- localePreferences == null || localePreferences.isEmpty()
- ? LocaleList.getDefault().toLanguageTags()
- : localePreferences.toLanguageTags();
- final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
-
- ModelFile bestModel = null;
- for (ModelFile model : listModelFiles(modelType)) {
- // TODO(licha): update this when we want to support multiple languages
- if (model.isAnyLanguageSupported(languageRangeList)) {
- if (model.isPreferredTo(bestModel)) {
- bestModel = model;
- }
- }
- }
- return bestModel;
- }
-
- /**
- * Deletes model files that are not preferred for any locales in user's preference.
- *
- * <p>This method will be invoked as a clean-up after we download a new model successfully. Race
- * conditions are hard to avoid because we do not hold locks for files. But it should rarely cause
- * any issues since it's safe to delete a model file in use (b/c we mmap it to memory).
- */
- public void deleteUnusedModelFiles() {
- TcLog.d(TAG, "Start to delete unused model files.");
- LocaleList localeList = LocaleList.getDefault();
- for (@ModelTypeDef String modelType : ModelType.values()) {
- ArraySet<ModelFile> allModelFiles = new ArraySet<>(listModelFiles(modelType));
- for (int i = 0; i < localeList.size(); i++) {
- // If a model file is preferred for any local in locale list, then keep it
- ModelFile bestModel = findBestModelFile(modelType, new LocaleList(localeList.get(i)));
- allModelFiles.remove(bestModel);
- }
- for (ModelFile modelFile : allModelFiles) {
- if (modelFile.canWrite()) {
- TcLog.d(TAG, "Deleting model: " + modelFile);
- if (!modelFile.delete()) {
- TcLog.w(TAG, "Failed to delete model: " + modelFile);
- }
- }
- }
- }
- }
-
- /** Returns the directory containing models downloaded by the downloader. */
- public File getModelDownloaderDir() {
- return modelDownloaderDir;
- }
-
- /**
- * Dumps the internal state for debugging.
- *
- * @param printWriter writer to write dumped states
- */
- public void dump(IndentingPrintWriter printWriter) {
- printWriter.println("ModelFileManager:");
- printWriter.increaseIndent();
- for (@ModelTypeDef String modelType : ModelType.values()) {
- printWriter.println(modelType + " model file(s):");
- printWriter.increaseIndent();
- for (ModelFile modelFile : listModelFiles(modelType)) {
- printWriter.println(modelFile.toString());
- }
- printWriter.decreaseIndent();
- }
- printWriter.decreaseIndent();
- }
-
- /** Fetch metadata of a model file. */
- private static class ModelInfoFetcher {
- private final Function<AssetFileDescriptor, Integer> versionFetcher;
- private final Function<AssetFileDescriptor, String> supportedLocalesFetcher;
-
- private ModelInfoFetcher(
- Function<AssetFileDescriptor, Integer> versionFetcher,
- Function<AssetFileDescriptor, String> supportedLocalesFetcher) {
- this.versionFetcher = versionFetcher;
- this.supportedLocalesFetcher = supportedLocalesFetcher;
- }
-
- int getVersion(AssetFileDescriptor assetFileDescriptor) {
- return versionFetcher.apply(assetFileDescriptor);
- }
-
- String getSupportedLocales(AssetFileDescriptor assetFileDescriptor) {
- return supportedLocalesFetcher.apply(assetFileDescriptor);
- }
-
- static ModelInfoFetcher create(@ModelTypeDef String modelType) {
- switch (modelType) {
- case ModelType.ANNOTATOR:
- return new ModelInfoFetcher(AnnotatorModel::getVersion, AnnotatorModel::getLocales);
- case ModelType.ACTIONS_SUGGESTIONS:
- return new ModelInfoFetcher(
- ActionsSuggestionsModel::getVersion, ActionsSuggestionsModel::getLocales);
- case ModelType.LANG_ID:
- return new ModelInfoFetcher(
- LangIdModel::getVersion, afd -> ModelFile.LANGUAGE_INDEPENDENT);
- default: // fall out
- }
- throw new IllegalStateException("Unsupported model types");
- }
- }
-
- /** Describes TextClassifier model files on disk. */
- public static class ModelFile {
- @VisibleForTesting static final String LANGUAGE_INDEPENDENT = "*";
-
- @ModelTypeDef public final String modelType;
- public final String absolutePath;
- public final int version;
- public final LocaleList supportedLocales;
- public final boolean languageIndependent;
- public final boolean isAsset;
-
- public static ModelFile createFromRegularFile(File file, @ModelTypeDef String modelType)
- throws IOException {
- ParcelFileDescriptor pfd =
- ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
- try (AssetFileDescriptor afd = new AssetFileDescriptor(pfd, 0, file.length())) {
- return createFromAssetFileDescriptor(
- file.getAbsolutePath(), modelType, afd, /* isAsset= */ false);
- }
- }
-
- public static ModelFile createFromAsset(
- AssetManager assetManager, String absolutePath, @ModelTypeDef String modelType)
- throws IOException {
- try (AssetFileDescriptor assetFileDescriptor = assetManager.openFd(absolutePath)) {
- return createFromAssetFileDescriptor(
- absolutePath, modelType, assetFileDescriptor, /* isAsset= */ true);
- }
- }
-
- private static ModelFile createFromAssetFileDescriptor(
- String absolutePath,
- @ModelTypeDef String modelType,
- AssetFileDescriptor assetFileDescriptor,
- boolean isAsset) {
- ModelInfoFetcher modelInfoFetcher = ModelInfoFetcher.create(modelType);
- return new ModelFile(
- modelType,
- absolutePath,
- modelInfoFetcher.getVersion(assetFileDescriptor),
- modelInfoFetcher.getSupportedLocales(assetFileDescriptor),
- isAsset);
- }
-
- @VisibleForTesting
- ModelFile(
- @ModelTypeDef String modelType,
- String absolutePath,
- int version,
- String supportedLocaleTags,
- boolean isAsset) {
- this.modelType = modelType;
- this.absolutePath = absolutePath;
- this.version = version;
- this.languageIndependent = LANGUAGE_INDEPENDENT.equals(supportedLocaleTags);
- this.supportedLocales =
- languageIndependent
- ? LocaleList.getEmptyLocaleList()
- : LocaleList.forLanguageTags(supportedLocaleTags);
- this.isAsset = isAsset;
- }
-
- /** Returns if this model file is preferred to the given one. */
- public boolean isPreferredTo(@Nullable ModelFile model) {
- // A model is preferred to no model.
- if (model == null) {
- return true;
- }
-
- // A language-specific model is preferred to a language independent
- // model.
- if (!languageIndependent && model.languageIndependent) {
- return true;
- }
- if (languageIndependent && !model.languageIndependent) {
- return false;
- }
-
- // A higher-version model is preferred.
- if (version > model.version) {
- return true;
- }
- return false;
- }
-
- /** Returns whether the language supports any language in the given ranges. */
- public boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
- Preconditions.checkNotNull(languageRanges);
- if (languageIndependent) {
- return true;
- }
- List<String> supportedLocaleTags =
- Arrays.asList(supportedLocales.toLanguageTags().split(","));
- return Locale.lookupTag(languageRanges, supportedLocaleTags) != null;
- }
-
- public AssetFileDescriptor open(AssetManager assetManager) throws IOException {
- if (isAsset) {
- return assetManager.openFd(absolutePath);
- }
- File file = new File(absolutePath);
- ParcelFileDescriptor parcelFileDescriptor =
- ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
- return new AssetFileDescriptor(parcelFileDescriptor, 0, file.length());
- }
-
- public boolean canWrite() {
- if (isAsset) {
- return false;
- }
- return new File(absolutePath).canWrite();
- }
-
- public boolean delete() {
- if (isAsset) {
- throw new IllegalStateException("asset is read-only, deleting it is not allowed.");
- }
- return new File(absolutePath).delete();
- }
-
- @Override
- public boolean equals(Object o) {
- if (this == o) {
- return true;
- }
- if (!(o instanceof ModelFile)) {
- return false;
- }
- ModelFile modelFile = (ModelFile) o;
- return version == modelFile.version
- && languageIndependent == modelFile.languageIndependent
- && isAsset == modelFile.isAsset
- && Objects.equals(modelType, modelFile.modelType)
- && Objects.equals(absolutePath, modelFile.absolutePath)
- && Objects.equals(supportedLocales, modelFile.supportedLocales);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(
- modelType, absolutePath, version, supportedLocales, languageIndependent, isAsset);
- }
-
- public ModelInfo toModelInfo() {
- return new ModelInfo(version, supportedLocales.toLanguageTags());
- }
-
- @Override
- public String toString() {
- return String.format(
- Locale.US,
- "ModelFile { type=%s path=%s version=%d locales=%s isAsset=%b}",
- modelType,
- absolutePath,
- version,
- languageIndependent ? LANGUAGE_INDEPENDENT : supportedLocales.toLanguageTags(),
- isAsset);
- }
-
- public static ImmutableList<Optional<ModelInfo>> toModelInfos(
- Optional<ModelFileManager.ModelFile>... modelFiles) {
- return Arrays.stream(modelFiles)
- .map(modelFile -> modelFile.transform(ModelFileManager.ModelFile::toModelInfo))
- .collect(Collectors.collectingAndThen(Collectors.toList(), ImmutableList::copyOf));
- }
- }
-}
diff --git a/java/src/com/android/textclassifier/common/TextClassifierServiceExecutors.java b/java/src/com/android/textclassifier/common/TextClassifierServiceExecutors.java
index 43164e0..011ed4f 100644
--- a/java/src/com/android/textclassifier/common/TextClassifierServiceExecutors.java
+++ b/java/src/com/android/textclassifier/common/TextClassifierServiceExecutors.java
@@ -41,6 +41,21 @@
return LowPriorityExecutorHolder.lowPriorityExecutor;
}
+ /**
+ * Returns a single-thread executor with min priority. Used for downloader background processing.
+ */
+ public static ListeningExecutorService getDownloaderExecutor() {
+ return DownloaderExecutorHolder.downloaderExecutor;
+ }
+
+ /**
+ * Returns a single-thread executor with min priority for network IO ops. Currently only used by
+ * model downloader service.
+ */
+ public static ListeningExecutorService getNetworkIOExecutor() {
+ return NetworkIOExecutorHolder.networkIOExecutor;
+ }
+
private static class NormPriorityExecutorHolder {
static final ListeningExecutorService normPriorityExecutor =
init("tcs-norm-prio-executor-%d", Thread.NORM_PRIORITY, /* corePoolSize= */ 2);
@@ -51,6 +66,16 @@
init("tcs-low-prio-executor-%d", Thread.NORM_PRIORITY - 1, /* corePoolSize= */ 1);
}
+ private static class DownloaderExecutorHolder {
+ static final ListeningExecutorService downloaderExecutor =
+ init("tcs-download-executor-%d", Thread.MIN_PRIORITY, /* corePoolSize= */ 1);
+ }
+
+ private static class NetworkIOExecutorHolder {
+ static final ListeningExecutorService networkIOExecutor =
+ init("tcs-network-io-executor-%d", Thread.MIN_PRIORITY, /* corePoolSize= */ 1);
+ }
+
private static ListeningExecutorService init(String nameFormat, int priority, int corePoolSize) {
TcLog.v(TAG, "Creating executor: " + nameFormat);
return MoreExecutors.listeningDecorator(
diff --git a/java/src/com/android/textclassifier/common/TextClassifierSettings.java b/java/src/com/android/textclassifier/common/TextClassifierSettings.java
index fdf259e..205680d 100644
--- a/java/src/com/android/textclassifier/common/TextClassifierSettings.java
+++ b/java/src/com/android/textclassifier/common/TextClassifierSettings.java
@@ -20,6 +20,7 @@
import android.provider.DeviceConfig;
import android.provider.DeviceConfig.Properties;
+import android.text.TextUtils;
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.TextClassifier;
import androidx.annotation.NonNull;
@@ -27,6 +28,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@@ -108,7 +110,8 @@
*/
private static final String DETECT_LANGUAGES_FROM_TEXT_ENABLED =
"detect_languages_from_text_enabled";
-
+ /** Whether to use models downloaded by config updater. */
+ private static final String CONFIG_UPDATER_MODEL_ENABLED = "config_updater_model_enabled";
/** Whether to enable model downloading with ModelDownloadManager */
@VisibleForTesting
public static final String MODEL_DOWNLOAD_MANAGER_ENABLED = "model_download_manager_enabled";
@@ -117,13 +120,46 @@
"manifest_download_required_network_type";
/** Max attempts allowed for a single ModelDownloader downloading task. */
@VisibleForTesting
- static final String MODEL_DOWNLOAD_MAX_ATTEMPTS = "model_download_max_attempts";
+ static final String MODEL_DOWNLOAD_WORKER_MAX_ATTEMPTS = "model_download_worker_max_attempts";
+ /** Max attempts allowed for a certain manifest url. */
+ @VisibleForTesting
+ public static final String MANIFEST_DOWNLOAD_MAX_ATTEMPTS = "manifest_download_max_attempts";
@VisibleForTesting
static final String MODEL_DOWNLOAD_BACKOFF_DELAY_IN_MILLIS =
"model_download_backoff_delay_in_millis";
+
+ private static final String MANIFEST_DOWNLOAD_REQUIRES_CHARGING =
+ "manifest_download_requires_charging";
+ private static final String MANIFEST_DOWNLOAD_REQUIRES_DEVICE_IDLE =
+ "manifest_download_requires_device_idle";
+
/** Flag name for manifest url is dynamically formatted based on model type and model language. */
@VisibleForTesting public static final String MANIFEST_URL_TEMPLATE = "manifest_url_%s_%s";
+
+ @VisibleForTesting public static final String MODEL_URL_BLOCKLIST = "model_url_blocklist";
+ @VisibleForTesting public static final String MODEL_URL_BLOCKLIST_SEPARATOR = ",";
+
+ /** Flags to control multi-language support settings. */
+ @VisibleForTesting
+ public static final String MULTI_LANGUAGE_SUPPORT_ENABLED = "multi_language_support_enabled";
+
+ @VisibleForTesting
+ public static final String MULTI_LANGUAGE_MODELS_LIMIT = "multi_language_models_limit";
+
+ @VisibleForTesting
+ public static final String ENABLED_MODEL_TYPES_FOR_MULTI_LANGUAGE_SUPPORT =
+ "enabled_model_types_for_multi_language_support";
+
+ @VisibleForTesting
+ public static final String MULTI_ANNOTATOR_CACHE_ENABLED = "multi_annotator_cache_enabled";
+
+ private static final String MULTI_ANNOTATOR_CACHE_SIZE = "multi_annotator_cache_size";
+
+ /** List of locale tags to override LocaleList for TextClassifier. Testing/debugging only. */
+ @VisibleForTesting
+ public static final String TESTING_LOCALE_LIST_OVERRIDE = "testing_locale_list_override";
+
/** Sampling rate for TextClassifier API logging. */
static final String TEXTCLASSIFIER_API_LOG_SAMPLE_RATE = "textclassifier_api_log_sample_rate";
@@ -193,11 +229,23 @@
private static final boolean TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT = true;
private static final boolean TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT = true;
private static final boolean DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT = true;
+ private static final boolean CONFIG_UPDATER_MODEL_ENABLED_DEFAULT = true;
private static final boolean MODEL_DOWNLOAD_MANAGER_ENABLED_DEFAULT = false;
private static final String MANIFEST_DOWNLOAD_REQUIRED_NETWORK_TYPE_DEFAULT = "UNMETERED";
- private static final int MODEL_DOWNLOAD_MAX_ATTEMPTS_DEFAULT = 5;
+ private static final int MODEL_DOWNLOAD_WORKER_MAX_ATTEMPTS_DEFAULT = 5;
+ private static final int MANIFEST_DOWNLOAD_MAX_ATTEMPTS_DEFAULT = 3;
private static final long MODEL_DOWNLOAD_BACKOFF_DELAY_IN_MILLIS_DEFAULT = HOURS.toMillis(1);
+ private static final boolean MANIFEST_DOWNLOAD_REQUIRES_DEVICE_IDLE_DEFAULT = false;
+ private static final boolean MANIFEST_DOWNLOAD_REQUIRES_CHARGING_DEFAULT = false;
+ private static final boolean MULTI_LANGUAGE_SUPPORT_ENABLED_DEFAULT = false;
+ private static final int MULTI_LANGUAGE_MODELS_LIMIT_DEFAULT = 2;
+ private static final ImmutableList<String>
+ ENABLED_MODEL_TYPES_FOR_MULTI_LANGUAGE_SUPPORT_DEFAULT =
+ ImmutableList.of(ModelType.ANNOTATOR);
+ private static final boolean MULTI_ANNOTATOR_CACHE_ENABLED_DEFAULT = false;
+ private static final int MULTI_ANNOTATOR_CACHE_SIZE_DEFAULT = 2;
private static final String MANIFEST_URL_DEFAULT = "";
+ private static final String TESTING_LOCALE_LIST_OVERRIDE_DEFAULT = "";
private static final float[] LANG_ID_CONTEXT_SETTINGS_DEFAULT = new float[] {20f, 1.0f, 0.4f};
/**
* Sampling rate for API logging. For example, 100 means there is a 0.01 chance that the API call
@@ -367,6 +415,11 @@
return getDeviceConfigFloatArray(LANG_ID_CONTEXT_SETTINGS, LANG_ID_CONTEXT_SETTINGS_DEFAULT);
}
+ public boolean isConfigUpdaterModelEnabled() {
+ return deviceConfig.getBoolean(
+ NAMESPACE, CONFIG_UPDATER_MODEL_ENABLED, CONFIG_UPDATER_MODEL_ENABLED_DEFAULT);
+ }
+
public boolean isModelDownloadManagerEnabled() {
return deviceConfig.getBoolean(
NAMESPACE, MODEL_DOWNLOAD_MANAGER_ENABLED, MODEL_DOWNLOAD_MANAGER_ENABLED_DEFAULT);
@@ -380,9 +433,14 @@
MANIFEST_DOWNLOAD_REQUIRED_NETWORK_TYPE_DEFAULT);
}
- public int getModelDownloadMaxAttempts() {
+ public int getModelDownloadWorkerMaxAttempts() {
return deviceConfig.getInt(
- NAMESPACE, MODEL_DOWNLOAD_MAX_ATTEMPTS, MODEL_DOWNLOAD_MAX_ATTEMPTS_DEFAULT);
+ NAMESPACE, MODEL_DOWNLOAD_WORKER_MAX_ATTEMPTS, MODEL_DOWNLOAD_WORKER_MAX_ATTEMPTS_DEFAULT);
+ }
+
+ public int getManifestDownloadMaxAttempts() {
+ return deviceConfig.getInt(
+ NAMESPACE, MANIFEST_DOWNLOAD_MAX_ATTEMPTS, MANIFEST_DOWNLOAD_MAX_ATTEMPTS_DEFAULT);
}
public long getModelDownloadBackoffDelayInMillis() {
@@ -392,39 +450,87 @@
MODEL_DOWNLOAD_BACKOFF_DELAY_IN_MILLIS_DEFAULT);
}
- /**
- * Get model's manifest url for given model type and language.
- *
- * @param modelType the type of model for the target url
- * @param modelLanguageTag the language tag for the model (e.g. en), but can also be "universal"
- * @return DeviceConfig configured url or empty string if not set
- */
- public String getManifestURL(@ModelType.ModelTypeDef String modelType, String modelLanguageTag) {
- // E.g: manifest_url_annotator_zh, manifest_url_lang_id_universal,
- // manifest_url_actions_suggestions_en
- String urlFlagName = String.format(MANIFEST_URL_TEMPLATE, modelType, modelLanguageTag);
- return deviceConfig.getString(NAMESPACE, urlFlagName, MANIFEST_URL_DEFAULT);
+ public boolean getManifestDownloadRequiresDeviceIdle() {
+ return deviceConfig.getBoolean(
+ NAMESPACE,
+ MANIFEST_DOWNLOAD_REQUIRES_DEVICE_IDLE,
+ MANIFEST_DOWNLOAD_REQUIRES_DEVICE_IDLE_DEFAULT);
}
+ public boolean getManifestDownloadRequiresCharging() {
+ return deviceConfig.getBoolean(
+ NAMESPACE,
+ MANIFEST_DOWNLOAD_REQUIRES_CHARGING,
+ MANIFEST_DOWNLOAD_REQUIRES_CHARGING_DEFAULT);
+ }
+
+ /** Gets a list of model urls that should not be used. Usually used for a quick rollback. */
+ public ImmutableList<String> getModelUrlBlocklist() {
+ return ImmutableList.copyOf(
+ Splitter.on(MODEL_URL_BLOCKLIST_SEPARATOR)
+ .split(deviceConfig.getString(NAMESPACE, MODEL_URL_BLOCKLIST, "")));
+ }
+
+ public boolean isMultiLanguageSupportEnabled() {
+ return deviceConfig.getBoolean(
+ NAMESPACE, MULTI_LANGUAGE_SUPPORT_ENABLED, MULTI_LANGUAGE_SUPPORT_ENABLED_DEFAULT);
+ }
+
+ public int getMultiLanguageModelsLimit() {
+ return deviceConfig.getInt(
+ NAMESPACE, MULTI_LANGUAGE_MODELS_LIMIT, MULTI_LANGUAGE_MODELS_LIMIT_DEFAULT);
+ }
+
+ public List<String> getEnabledModelTypesForMultiLanguageSupport() {
+ return getDeviceConfigStringList(
+ ENABLED_MODEL_TYPES_FOR_MULTI_LANGUAGE_SUPPORT,
+ ENABLED_MODEL_TYPES_FOR_MULTI_LANGUAGE_SUPPORT_DEFAULT);
+ }
+
+ public boolean getMultiAnnotatorCacheEnabled() {
+ return deviceConfig.getBoolean(
+ NAMESPACE, MULTI_ANNOTATOR_CACHE_ENABLED, MULTI_ANNOTATOR_CACHE_ENABLED_DEFAULT);
+ }
+
+ public int getMultiAnnotatorCacheSize() {
+ return deviceConfig.getInt(
+ NAMESPACE, MULTI_ANNOTATOR_CACHE_SIZE, MULTI_ANNOTATOR_CACHE_SIZE_DEFAULT);
+ }
/**
- * Gets all language variants configured for a specific ModelType.
+ * Gets all language variants and associated manifest url configured for a specific ModelType.
*
* <p>For a specific language, there can be many variants: de-CH, de-LI, zh-Hans, zh-Hant. There
* is no easy way to hardcode the list in client. Therefore, we parse all configured flag's name
* in DeviceConfig, and let the client to choose the best variant to download.
+ *
+ * <p>If one flag's value is empty, it will be ignored.
+ *
+ * @param modelType the type of model for the target url
+ * @return a map from locale tag to the manifest url flag value.
*/
- public ImmutableList<String> getLanguageTagsForManifestURL(
+ public ImmutableMap<String, String> getLanguageTagAndManifestUrlMap(
@ModelType.ModelTypeDef String modelType) {
String urlFlagBaseName = String.format(MANIFEST_URL_TEMPLATE, modelType, /* language */ "");
Properties properties = deviceConfig.getProperties(NAMESPACE);
- ImmutableList.Builder<String> variantsBuilder = ImmutableList.builder();
+ ImmutableMap.Builder<String, String> variantsMapBuilder = ImmutableMap.builder();
for (String name : properties.getKeyset()) {
- if (name.startsWith(urlFlagBaseName)
- && properties.getString(name, /* defaultValue= */ null) != null) {
- variantsBuilder.add(name.substring(urlFlagBaseName.length()));
+ if (!name.startsWith(urlFlagBaseName)) {
+ continue;
+ }
+ String value = properties.getString(name, /* defaultValue= */ null);
+ if (!TextUtils.isEmpty(value)) {
+ String modelLanguageTag = name.substring(urlFlagBaseName.length());
+ String urlFlagName = String.format(MANIFEST_URL_TEMPLATE, modelType, modelLanguageTag);
+ String urlFlagValue = deviceConfig.getString(NAMESPACE, urlFlagName, MANIFEST_URL_DEFAULT);
+ variantsMapBuilder.put(modelLanguageTag, urlFlagValue);
}
}
- return variantsBuilder.build();
+ return variantsMapBuilder.build();
+ }
+
+ public String getTestingLocaleListOverride() {
+ return deviceConfig.getString(
+ NAMESPACE, TESTING_LOCALE_LIST_OVERRIDE, TESTING_LOCALE_LIST_OVERRIDE_DEFAULT);
}
public int getTextClassifierApiLogSampleRate() {
@@ -457,8 +563,21 @@
pw.printPair(USER_LANGUAGE_PROFILE_ENABLED, isUserLanguageProfileEnabled());
pw.printPair(TEMPLATE_INTENT_FACTORY_ENABLED, isTemplateIntentFactoryEnabled());
pw.printPair(TRANSLATE_IN_CLASSIFICATION_ENABLED, isTranslateInClassificationEnabled());
+ pw.printPair(CONFIG_UPDATER_MODEL_ENABLED, isConfigUpdaterModelEnabled());
pw.printPair(MODEL_DOWNLOAD_MANAGER_ENABLED, isModelDownloadManagerEnabled());
- pw.printPair(MODEL_DOWNLOAD_MAX_ATTEMPTS, getModelDownloadMaxAttempts());
+ pw.printPair(MULTI_LANGUAGE_SUPPORT_ENABLED, isMultiLanguageSupportEnabled());
+ pw.printPair(MULTI_LANGUAGE_MODELS_LIMIT, getMultiLanguageModelsLimit());
+ pw.printPair(
+ ENABLED_MODEL_TYPES_FOR_MULTI_LANGUAGE_SUPPORT,
+ getEnabledModelTypesForMultiLanguageSupport());
+ pw.printPair(MULTI_ANNOTATOR_CACHE_ENABLED, getMultiAnnotatorCacheEnabled());
+ pw.printPair(MULTI_ANNOTATOR_CACHE_SIZE, getMultiAnnotatorCacheSize());
+ pw.printPair(MANIFEST_DOWNLOAD_REQUIRED_NETWORK_TYPE, getManifestDownloadRequiredNetworkType());
+ pw.printPair(MODEL_DOWNLOAD_WORKER_MAX_ATTEMPTS, getModelDownloadWorkerMaxAttempts());
+ pw.printPair(MANIFEST_DOWNLOAD_MAX_ATTEMPTS, getManifestDownloadMaxAttempts());
+ pw.printPair(MANIFEST_DOWNLOAD_REQUIRES_CHARGING, getManifestDownloadRequiresCharging());
+ pw.printPair(MANIFEST_DOWNLOAD_REQUIRES_DEVICE_IDLE, getManifestDownloadRequiresDeviceIdle());
+ pw.printPair(TESTING_LOCALE_LIST_OVERRIDE, getTestingLocaleListOverride());
pw.decreaseIndent();
pw.printPair(TEXTCLASSIFIER_API_LOG_SAMPLE_RATE, getTextClassifierApiLogSampleRate());
pw.printPair(SESSION_ID_TO_CONTEXT_CACHE_SIZE, getSessionIdToContextCacheSize());
diff --git a/java/src/com/android/textclassifier/common/TextSelectionCompat.java b/java/src/com/android/textclassifier/common/TextSelectionCompat.java
new file mode 100644
index 0000000..325bc5d
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/TextSelectionCompat.java
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common;
+
+import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextSelection;
+import androidx.annotation.RequiresApi;
+import androidx.core.os.BuildCompat;
+import javax.annotation.Nullable;
+
+/** Compatibility methods for {@link TextSelection}. */
+public final class TextSelectionCompat {
+
+ public static boolean shouldIncludeTextClassification(TextSelection.Request request) {
+ if (BuildCompat.isAtLeastS()) {
+ return Api31Impl.shouldIncludeTextClassification(request);
+ }
+ return Api30Impl.shouldIncludeTextClassification(request);
+ }
+
+ public static void setTextClassification(
+ TextSelection.Builder builder, @Nullable TextClassification textClassification) {
+ if (BuildCompat.isAtLeastS()) {
+ Api31Impl.setTextClassification(builder, textClassification);
+ }
+ }
+
+ private static final class Api30Impl {
+
+ private Api30Impl() {}
+
+ public static boolean shouldIncludeTextClassification(TextSelection.Request request) {
+ return false;
+ }
+ }
+
+ @RequiresApi(31)
+ private static final class Api31Impl {
+
+ private Api31Impl() {}
+
+ public static boolean shouldIncludeTextClassification(TextSelection.Request request) {
+ return request.shouldIncludeTextClassification();
+ }
+
+ public static void setTextClassification(
+ TextSelection.Builder builder, @Nullable TextClassification textClassification) {
+ builder.setTextClassification(textClassification);
+ }
+ }
+
+ private TextSelectionCompat() {}
+}
diff --git a/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java b/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java
index dae0442..67e300d 100644
--- a/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java
+++ b/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java
@@ -66,8 +66,8 @@
}
/** Returns if the result id was generated from the default text classifier. */
- public static boolean isFromDefaultTextClassifier(String resultId) {
- return resultId.startsWith(CLASSIFIER_ID + '|');
+ public static boolean isFromDefaultTextClassifier(@Nullable String resultId) {
+ return resultId != null && resultId.startsWith(CLASSIFIER_ID + '|');
}
/** Returns all the model names encoded in the signature. */
diff --git a/java/src/com/android/textclassifier/downloader/DownloadFileType.java b/java/src/com/android/textclassifier/downloader/DownloadFileType.java
new file mode 100644
index 0000000..c070102
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/DownloadFileType.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import androidx.annotation.IntDef;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+
+/** Effectively an enum class to represent types of files to be downloaded. */
+final class DownloadFileType {
+ /** File types to be downloaded for TextClassifier. */
+ @Retention(RetentionPolicy.SOURCE)
+ @IntDef({UNKNOWN, MANIFEST, MODEL})
+ public @interface DownloadFileTypeDef {}
+
+ public static final int UNKNOWN = 0;
+ public static final int MANIFEST = 1;
+ public static final int MODEL = 2;
+
+ private DownloadFileType() {}
+}
diff --git a/java/src/com/android/textclassifier/downloader/DownloadedModelDatabase.java b/java/src/com/android/textclassifier/downloader/DownloadedModelDatabase.java
new file mode 100644
index 0000000..37430d8
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/DownloadedModelDatabase.java
@@ -0,0 +1,373 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import androidx.annotation.IntDef;
+import androidx.annotation.NonNull;
+import androidx.room.ColumnInfo;
+import androidx.room.Dao;
+import androidx.room.Database;
+import androidx.room.DatabaseView;
+import androidx.room.Delete;
+import androidx.room.Embedded;
+import androidx.room.Entity;
+import androidx.room.ForeignKey;
+import androidx.room.Index;
+import androidx.room.Insert;
+import androidx.room.OnConflictStrategy;
+import androidx.room.Query;
+import androidx.room.RoomDatabase;
+import androidx.room.Transaction;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.utils.IndentingPrintWriter;
+import com.google.auto.value.AutoValue;
+import com.google.common.collect.Iterables;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.util.List;
+import java.util.concurrent.ExecutorService;
+
+/** Database storing info about models downloaded by model downloader */
+@Database(
+ entities = {
+ DownloadedModelDatabase.Model.class,
+ DownloadedModelDatabase.Manifest.class,
+ DownloadedModelDatabase.ManifestModelCrossRef.class,
+ DownloadedModelDatabase.ManifestEnrollment.class
+ },
+ views = {DownloadedModelDatabase.ModelView.class},
+ version = 1,
+ exportSchema = true)
+abstract class DownloadedModelDatabase extends RoomDatabase {
+ public static final String TAG = "DownloadedModelDatabase";
+
+ /** Represents a downloaded model file. */
+ @AutoValue
+ @Entity(
+ tableName = "model",
+ primaryKeys = {"model_url"})
+ abstract static class Model {
+ @AutoValue.CopyAnnotations
+ @ColumnInfo(name = "model_url")
+ @NonNull
+ public abstract String getModelUrl();
+
+ @AutoValue.CopyAnnotations
+ @ColumnInfo(name = "model_path")
+ @NonNull
+ public abstract String getModelPath();
+
+ public static Model create(String modelUrl, String modelPath) {
+ return new AutoValue_DownloadedModelDatabase_Model(modelUrl, modelPath);
+ }
+ }
+
+ /** Represents a manifest we processed. */
+ @AutoValue
+ @Entity(
+ tableName = "manifest",
+ primaryKeys = {"manifest_url"})
+ abstract static class Manifest {
+ // TODO(licha): Consider using Enum here
+ @Retention(RetentionPolicy.SOURCE)
+ @IntDef({STATUS_UNKNOWN, STATUS_FAILED, STATUS_SUCCEEDED})
+ @interface StatusDef {}
+
+ public static final int STATUS_UNKNOWN = 0;
+ /** Failed to download this manifest. Could be retried in the future. */
+ public static final int STATUS_FAILED = 1;
+ /** Downloaded this manifest successfully and it's currently in storage. */
+ public static final int STATUS_SUCCEEDED = 2;
+
+ @AutoValue.CopyAnnotations
+ @ColumnInfo(name = "manifest_url")
+ @NonNull
+ public abstract String getManifestUrl();
+
+ @AutoValue.CopyAnnotations
+ @ColumnInfo(name = "status")
+ @StatusDef
+ public abstract int getStatus();
+
+ @AutoValue.CopyAnnotations
+ @ColumnInfo(name = "failure_counts")
+ public abstract int getFailureCounts();
+
+ public static Manifest create(String manifestUrl, @StatusDef int status, int failureCounts) {
+ return new AutoValue_DownloadedModelDatabase_Manifest(manifestUrl, status, failureCounts);
+ }
+ }
+
+ /**
+ * Represents the relationship between manifests and downloaded models.
+ *
+ * <p>A manifest can include multiple models, a model can also be included in multiple manifests.
+ * In different manifests, a model may have different configurations (e.g. primary model in
+ * manifest A but dark model in B).
+ */
+ @AutoValue
+ @Entity(
+ tableName = "manifest_model_cross_ref",
+ primaryKeys = {"manifest_url", "model_url"},
+ foreignKeys = {
+ @ForeignKey(
+ entity = Manifest.class,
+ parentColumns = "manifest_url",
+ childColumns = "manifest_url",
+ onDelete = ForeignKey.CASCADE),
+ @ForeignKey(
+ entity = Model.class,
+ parentColumns = "model_url",
+ childColumns = "model_url",
+ onDelete = ForeignKey.CASCADE),
+ },
+ indices = {
+ @Index(value = {"manifest_url"}),
+ @Index(value = {"model_url"}),
+ })
+ abstract static class ManifestModelCrossRef {
+ @AutoValue.CopyAnnotations
+ @ColumnInfo(name = "manifest_url")
+ @NonNull
+ public abstract String getManifestUrl();
+
+ @AutoValue.CopyAnnotations
+ @ColumnInfo(name = "model_url")
+ @NonNull
+ public abstract String getModelUrl();
+
+ public static ManifestModelCrossRef create(String manifestUrl, String modelUrl) {
+ return new AutoValue_DownloadedModelDatabase_ManifestModelCrossRef(manifestUrl, modelUrl);
+ }
+ }
+
+ /**
+ * Represents the relationship between user scenarios and manifests.
+ *
+ * <p>For each unique user scenario (i.e. modelType + localeTag), we store the manifest we should
+ * use. The same manifest can be used for different scenarios.
+ */
+ @AutoValue
+ @Entity(
+ tableName = "manifest_enrollment",
+ primaryKeys = {"model_type", "locale_tag"},
+ foreignKeys = {
+ @ForeignKey(
+ entity = Manifest.class,
+ parentColumns = "manifest_url",
+ childColumns = "manifest_url",
+ onDelete = ForeignKey.CASCADE)
+ },
+ indices = {@Index(value = {"manifest_url"})})
+ abstract static class ManifestEnrollment {
+ @AutoValue.CopyAnnotations
+ @ColumnInfo(name = "model_type")
+ @NonNull
+ @ModelTypeDef
+ public abstract String getModelType();
+
+ @AutoValue.CopyAnnotations
+ @ColumnInfo(name = "locale_tag")
+ @NonNull
+ public abstract String getLocaleTag();
+
+ @AutoValue.CopyAnnotations
+ @ColumnInfo(name = "manifest_url")
+ @NonNull
+ public abstract String getManifestUrl();
+
+ public static ManifestEnrollment create(
+ @ModelTypeDef String modelType, String localeTag, String manifestUrl) {
+ return new AutoValue_DownloadedModelDatabase_ManifestEnrollment(
+ modelType, localeTag, manifestUrl);
+ }
+ }
+
+ /** Represents the mapping from manifest enrollments to models. */
+ @AutoValue
+ @DatabaseView(
+ value =
+ "SELECT manifest_enrollment.*, model.* "
+ + "FROM manifest_enrollment "
+ + "INNER JOIN manifest_model_cross_ref "
+ + "ON manifest_enrollment.manifest_url = manifest_model_cross_ref.manifest_url "
+ + "INNER JOIN model "
+ + "ON manifest_model_cross_ref.model_url = model.model_url",
+ viewName = "model_view")
+ abstract static class ModelView {
+ @AutoValue.CopyAnnotations
+ @Embedded
+ @NonNull
+ public abstract ManifestEnrollment getManifestEnrollment();
+
+ @AutoValue.CopyAnnotations
+ @Embedded
+ @NonNull
+ public abstract Model getModel();
+
+ public static ModelView create(ManifestEnrollment manifestEnrollment, Model model) {
+ return new AutoValue_DownloadedModelDatabase_ModelView(manifestEnrollment, model);
+ }
+ }
+
+ @Dao
+ abstract static class DownloadedModelDatabaseDao {
+ // Full table scan
+ @Query("SELECT * FROM model")
+ abstract List<Model> queryAllModels();
+
+ @Query("SELECT * FROM manifest")
+ abstract List<Manifest> queryAllManifests();
+
+ @Query("SELECT * FROM manifest_model_cross_ref")
+ abstract List<ManifestModelCrossRef> queryAllManifestModelCrossRefs();
+
+ @Query("SELECT * FROM manifest_enrollment")
+ abstract List<ManifestEnrollment> queryAllManifestEnrollments();
+
+ @Query("SELECT * FROM model_view")
+ abstract List<ModelView> queryAllModelViews();
+
+ // Single table query
+ @Query("SELECT * FROM model WHERE model_url = :modelUrl")
+ abstract List<Model> queryModelWithModelUrl(String modelUrl);
+
+ @Query("SELECT * FROM manifest WHERE manifest_url = :manifestUrl")
+ abstract List<Manifest> queryManifestWithManifestUrl(String manifestUrl);
+
+ @Query(
+ "SELECT * FROM manifest_enrollment WHERE model_type = :modelType "
+ + "AND locale_tag = :localeTag")
+ abstract List<ManifestEnrollment> queryManifestEnrollmentWithModelTypeAndLocaleTag(
+ @ModelTypeDef String modelType, String localeTag);
+
+ // Helpers for clean up
+ @Query(
+ "SELECT manifest.* FROM manifest "
+ + "LEFT JOIN model_view "
+ + "ON manifest.manifest_url = model_view.manifest_url "
+ + "WHERE model_view.manifest_url IS NULL "
+ + "AND manifest.status = 2")
+ abstract List<Manifest> queryUnusedManifests();
+
+ @Query(
+ "SELECT * FROM manifest WHERE manifest.status = 1 "
+ + "AND manifest.manifest_url NOT IN (:manifestUrlsToKeep)")
+ abstract List<Manifest> queryUnusedManifestFailureRecords(List<String> manifestUrlsToKeep);
+
+ @Query(
+ "SELECT model.* FROM model LEFT JOIN model_view "
+ + "ON model.model_url = model_view.model_url "
+ + "WHERE model_view.model_url IS NULL")
+ abstract List<Model> queryUnusedModels();
+
+ // Insertion
+ @Insert(onConflict = OnConflictStrategy.REPLACE)
+ abstract void insert(Model model);
+
+ @Insert(onConflict = OnConflictStrategy.REPLACE)
+ abstract void insert(Manifest manifest);
+
+ @Insert(onConflict = OnConflictStrategy.REPLACE)
+ abstract void insert(ManifestModelCrossRef manifestModelCrossRef);
+
+ @Insert(onConflict = OnConflictStrategy.REPLACE)
+ abstract void insert(ManifestEnrollment manifestEnrollment);
+
+ @Transaction
+ void insertManifestAndModelCrossRef(String manifestUrl, String modelUrl) {
+ insert(Manifest.create(manifestUrl, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0));
+ insert(ManifestModelCrossRef.create(manifestUrl, modelUrl));
+ }
+
+ @Transaction
+ void increaseManifestFailureCounts(String manifestUrl) {
+ List<Manifest> manifests = queryManifestWithManifestUrl(manifestUrl);
+ if (manifests.isEmpty()) {
+ insert(Manifest.create(manifestUrl, Manifest.STATUS_FAILED, /* failureCounts= */ 1));
+ } else {
+ Manifest prevManifest = Iterables.getOnlyElement(manifests);
+ insert(
+ Manifest.create(
+ manifestUrl, Manifest.STATUS_FAILED, prevManifest.getFailureCounts() + 1));
+ }
+ }
+
+ // Deletion
+ @Delete
+ abstract void deleteModels(List<Model> models);
+
+ @Delete
+ abstract void deleteManifests(List<Manifest> manifests);
+
+ @Delete
+ abstract void deleteManifestModelCrossRefs(List<ManifestModelCrossRef> manifestModelCrossRefs);
+
+ @Delete
+ abstract void deleteManifestEnrollments(List<ManifestEnrollment> manifestEnrollments);
+
+ @Transaction
+ void deleteUnusedManifestsAndModels() {
+ // Because Manifest table is the parent table of ManifestModelCrossRef table, related cross
+ // ref row in that table will be deleted automatically
+ deleteManifests(queryUnusedManifests());
+ deleteModels(queryUnusedModels());
+ }
+
+ @Transaction
+ void deleteUnusedManifestFailureRecords(List<String> manifestUrlsToKeep) {
+ deleteManifests(queryUnusedManifestFailureRecords(manifestUrlsToKeep));
+ }
+ }
+
+ abstract DownloadedModelDatabaseDao dao();
+
+ /** Dump the database for debugging. */
+ void dump(IndentingPrintWriter printWriter, ExecutorService executorService) {
+ printWriter.println("DownloadedModelDatabase");
+ printWriter.increaseIndent();
+ printWriter.println("Model Table:");
+ printWriter.increaseIndent();
+ List<Model> models = dao().queryAllModels();
+ for (Model model : models) {
+ printWriter.println(model.toString());
+ }
+ printWriter.decreaseIndent();
+ printWriter.println("Manifest Table:");
+ printWriter.increaseIndent();
+ List<Manifest> manifests = dao().queryAllManifests();
+ for (Manifest manifest : manifests) {
+ printWriter.println(manifest.toString());
+ }
+ printWriter.decreaseIndent();
+ printWriter.println("ManifestModelCrossRef Table:");
+ printWriter.increaseIndent();
+ List<ManifestModelCrossRef> manifestModelCrossRefs = dao().queryAllManifestModelCrossRefs();
+ for (ManifestModelCrossRef manifestModelCrossRef : manifestModelCrossRefs) {
+ printWriter.println(manifestModelCrossRef.toString());
+ }
+ printWriter.decreaseIndent();
+ printWriter.println("ManifestEnrollment Table:");
+ printWriter.increaseIndent();
+ List<ManifestEnrollment> manifestEnrollments = dao().queryAllManifestEnrollments();
+ for (ManifestEnrollment manifestEnrollment : manifestEnrollments) {
+ printWriter.println(manifestEnrollment.toString());
+ }
+ printWriter.decreaseIndent();
+ printWriter.decreaseIndent();
+ }
+}
diff --git a/java/src/com/android/textclassifier/downloader/DownloadedModelManager.java b/java/src/com/android/textclassifier/downloader/DownloadedModelManager.java
new file mode 100644
index 0000000..84440d0
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/DownloadedModelManager.java
@@ -0,0 +1,136 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Model;
+import com.android.textclassifier.utils.IndentingPrintWriter;
+import com.google.common.collect.ImmutableMap;
+import java.io.File;
+import java.util.List;
+import javax.annotation.Nullable;
+
+// TODO(licha): Let Worker access DB class directly, then we can make this a lister interface
+/** An interface to provide easy access to DownloadedModelDatabase. */
+public interface DownloadedModelManager {
+
+ /** Returns the directory containing models downloaded by the downloader. */
+ File getModelDownloaderDir();
+
+ /**
+ * Returns all downloaded model files for the given modelType.
+ *
+ * <p>This method should return quickly as it may be on the critical path of serving requests.
+ *
+ * @param modelType the type of the model
+ * @return the model files. Empty if no suitable model found
+ */
+ @Nullable
+ List<File> listModels(@ModelTypeDef String modelType);
+
+ /**
+ * Returns the model entry if the model represented by the url is in our database.
+ *
+ * @param modelUrl the model url
+ * @return model entry from internal database, null if not exist
+ */
+ @Nullable
+ Model getModel(String modelUrl);
+
+ /**
+ * Returns the manifest entry if the manifest represented by the url is in our database.
+ *
+ * @param manifestUrl the manifest url
+ * @return manifest entry from internal database, null if not exist
+ */
+ @Nullable
+ Manifest getManifest(String manifestUrl);
+
+ /**
+ * Returns the manifest enrollment entry if a manifest is registered for the given type and
+ * locale.
+ *
+ * @param modelType the model type of the enrollment
+ * @param localeTag the locale tag of the enrollment
+ * @return manifest enrollment entry from internal database, null if not exist
+ */
+ @Nullable
+ ManifestEnrollment getManifestEnrollment(@ModelTypeDef String modelType, String localeTag);
+
+ /**
+ * Add a newly downloaded model to the internal database.
+ *
+ * <p>The model must be linked to a manifest via #registerManifest(). Otherwise it will be cleaned
+ * up automatically later.
+ *
+ * @param modelUrl the url where we downloaded model from
+ * @param modelPath the path where we store the downloaded model
+ */
+ void registerModel(String modelUrl, String modelPath);
+
+ /**
+ * Add a newly downloaded manifest to the internal database.
+ *
+ * <p>The manifest must be linked to a specific use case via #registerManifestEnrollment().
+ * Otherwise it will be cleaned up automatically later. Currently there is only one model in one
+ * manifest.
+ *
+ * @param manifestUrl the url where we downloaded manifest
+ * @param modelUrl the url where we downloaded the only model inside the manifest
+ */
+ void registerManifest(String manifestUrl, String modelUrl);
+
+ /**
+ * Adds a failure record for the given manifest url.
+ *
+ * <p>If the manifest failed before, then increase the prevFailureCounts by one. We skip manifest
+ * if it failed too many times before.
+ *
+ * @param manifestUrl the failed manifest url
+ */
+ void registerManifestDownloadFailure(String manifestUrl);
+
+ /**
+ * Link a manifest to a specific (modelType, localeTag) use case.
+ *
+ * <p>After this registration, we will start to use this model file for all requests for the given
+ * locale and the specified model type.
+ *
+ * @param modelType the model type
+ * @param localeTag the tag of the locale on user's device that this manifest should be used for
+ * @param manifestUrl the url of the manifest
+ */
+ void registerManifestEnrollment(
+ @ModelTypeDef String modelType, String localeTag, String manifestUrl);
+
+ /**
+ * Clean up unused downloaded models and update other internal states.
+ *
+ * @param manifestsToDownload Map<modelType, ManifestsToDownloadByType> that the worker tried to
+ * download
+ */
+ void onDownloadCompleted(ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload);
+
+ /**
+ * Dumps the internal state for debugging.
+ *
+ * @param printWriter writer to write dumped states
+ */
+ void dump(IndentingPrintWriter printWriter);
+}
diff --git a/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java b/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java
new file mode 100644
index 0000000..9bdfb5e
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java
@@ -0,0 +1,301 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import android.content.Context;
+import android.util.ArrayMap;
+import androidx.annotation.GuardedBy;
+import androidx.room.Room;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.TextClassifierServiceExecutors;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Model;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ModelView;
+import com.android.textclassifier.utils.IndentingPrintWriter;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import java.io.File;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+
+/** A singleton implementation of DownloadedModelManager. */
+public final class DownloadedModelManagerImpl implements DownloadedModelManager {
+ private static final String TAG = "DownloadedModelManagerImpl";
+ private static final String DOWNLOAD_SUB_DIR_NAME = "textclassifier/downloads/models";
+ private static final String DOWNLOADED_MODEL_DATABASE_NAME = "tcs-downloaded-model-db";
+
+ private static final Object staticLock = new Object();
+
+ @GuardedBy("staticLock")
+ private static DownloadedModelManagerImpl instance;
+
+ private final File modelDownloaderDir;
+ private final DownloadedModelDatabase db;
+ private final TextClassifierSettings settings;
+
+ private final Object cacheLock = new Object();
+
+ // modeltype -> downloaded model files
+ @GuardedBy("cacheLock")
+ private final ArrayMap<String, List<Model>> modelLookupCache;
+
+ @GuardedBy("cacheLock")
+ private boolean cacheInitialized;
+
+ @Nullable
+ public static DownloadedModelManager getInstance(Context context) {
+ synchronized (staticLock) {
+ if (instance == null) {
+ DownloadedModelDatabase db =
+ Room.databaseBuilder(
+ context, DownloadedModelDatabase.class, DOWNLOADED_MODEL_DATABASE_NAME)
+ .build();
+ File modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
+ instance =
+ new DownloadedModelManagerImpl(db, modelDownloaderDir, new TextClassifierSettings());
+ }
+ return instance;
+ }
+ }
+
+ @VisibleForTesting
+ static DownloadedModelManagerImpl getInstanceForTesting(
+ DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings) {
+ return new DownloadedModelManagerImpl(db, modelDownloaderDir, settings);
+ }
+
+ private DownloadedModelManagerImpl(
+ DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings) {
+ this.db = db;
+ this.modelDownloaderDir = modelDownloaderDir;
+ this.modelLookupCache = new ArrayMap<>();
+ for (String modelType : ModelType.values()) {
+ this.modelLookupCache.put(modelType, new ArrayList<>());
+ }
+ this.settings = settings;
+ this.cacheInitialized = false;
+ }
+
+ @Override
+ public File getModelDownloaderDir() {
+ if (!modelDownloaderDir.exists()) {
+ modelDownloaderDir.mkdirs();
+ }
+ return modelDownloaderDir;
+ }
+
+  @Override
+  @Nullable
+  public ImmutableList<File> listModels(@ModelTypeDef String modelType) {
+    synchronized (cacheLock) {
+      if (!cacheInitialized) {
+        updateCache(); // First call: populate the lookup cache before reading from it.
+      }
+      ImmutableList.Builder<File> modelFiles = ImmutableList.builder();
+      ImmutableList<String> blocklist = settings.getModelUrlBlocklist();
+      for (Model candidate : modelLookupCache.get(modelType)) {
+        if (!blocklist.contains(candidate.getModelUrl())) {
+          modelFiles.add(new File(candidate.getModelPath()));
+        } else {
+          TcLog.d(TAG, "Model is blocklisted: " + candidate);
+        }
+      }
+      return modelFiles.build();
+    }
+  }
+
+ @Override
+ @Nullable
+ public Model getModel(String modelUrl) {
+ List<Model> models = db.dao().queryModelWithModelUrl(modelUrl);
+ return Iterables.getFirst(models, null);
+ }
+
+ @Override
+ @Nullable
+ public Manifest getManifest(String manifestUrl) {
+ List<Manifest> manifests = db.dao().queryManifestWithManifestUrl(manifestUrl);
+ return Iterables.getFirst(manifests, null);
+ }
+
+ @Override
+ @Nullable
+ public ManifestEnrollment getManifestEnrollment(
+ @ModelTypeDef String modelType, String localeTag) {
+ List<ManifestEnrollment> manifestEnrollments =
+ db.dao().queryManifestEnrollmentWithModelTypeAndLocaleTag(modelType, localeTag);
+ return Iterables.getFirst(manifestEnrollments, null);
+ }
+
+ @Override
+ public void registerModel(String modelUrl, String modelPath) {
+ db.dao().insert(Model.create(modelUrl, modelPath));
+ }
+
+ @Override
+ public void registerManifest(String manifestUrl, String modelUrl) {
+ db.dao().insertManifestAndModelCrossRef(manifestUrl, modelUrl);
+ }
+
+ @Override
+ public void registerManifestDownloadFailure(String manifestUrl) {
+ db.dao().increaseManifestFailureCounts(manifestUrl);
+ }
+
+ @Override
+ public void registerManifestEnrollment(
+ @ModelTypeDef String modelType, String localeTag, String manifestUrl) {
+ db.dao().insert(ManifestEnrollment.create(modelType, localeTag, manifestUrl));
+ }
+
+ @Override
+ public void dump(IndentingPrintWriter printWriter) {
+ printWriter.println("DownloadedModelManagerImpl:");
+ printWriter.increaseIndent();
+ db.dump(printWriter, TextClassifierServiceExecutors.getDownloaderExecutor());
+ printWriter.println("ModelLookupCache:");
+ synchronized (cacheLock) {
+ for (Map.Entry<String, List<Model>> entry : modelLookupCache.entrySet()) {
+ printWriter.println(entry.getKey());
+ printWriter.increaseIndent();
+ for (Model model : entry.getValue()) {
+ printWriter.println(model.toString());
+ }
+ printWriter.decreaseIndent();
+ }
+ }
+ printWriter.decreaseIndent();
+ }
+
+  /** Prunes stale database entries, refreshes the lookup cache and deletes orphaned files. */
+  @Override
+  public void onDownloadCompleted(
+      ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload) {
+    TcLog.d(TAG, "Start to clean up models and update model lookup cache...");
+    // Step 1: Clean up ManifestEnrollment table
+    List<ManifestEnrollment> allManifestEnrollments = db.dao().queryAllManifestEnrollments();
+    List<ManifestEnrollment> manifestEnrollmentsToDelete = new ArrayList<>();
+    for (String modelType : ModelType.values()) {
+      List<ManifestEnrollment> manifestEnrollmentsByType =
+          allManifestEnrollments.stream()
+              .filter(modelEnrollment -> modelEnrollment.getModelType().equals(modelType))
+              .collect(Collectors.toList());
+      ManifestsToDownloadByType manifestsToDownloadByType = manifestsToDownload.get(modelType);
+
+      if (manifestsToDownloadByType == null) {
+        // No suitable manifests configured for this model type. Delete everything.
+        manifestEnrollmentsToDelete.addAll(manifestEnrollmentsByType);
+        continue;
+      }
+      ImmutableMap<String, String> localeTagToManifestUrl =
+          manifestsToDownloadByType.localeTagToManifestUrl();
+
+      boolean allModelsDownloaded = true;
+      for (Map.Entry<String, String> entry : localeTagToManifestUrl.entrySet()) {
+        String localeTag = entry.getKey();
+        String manifestUrl = entry.getValue();
+        Optional<ManifestEnrollment> manifestEnrollmentForLocaleTagAndManifestUrl =
+            manifestEnrollmentsByType.stream()
+                .filter(
+                    manifestEnrollment ->
+                        manifestEnrollment.getLocaleTag().equals(localeTag)
+                            && manifestEnrollment.getManifestUrl().equals(manifestUrl))
+                .findAny();
+        if (!manifestEnrollmentForLocaleTagAndManifestUrl.isPresent()) {
+          // The desired manifest failed to be downloaded.
+          TcLog.w(
+              TAG,
+              String.format(
+                  "Desired manifest is missing on download completed: %s, %s, %s",
+                  modelType, localeTag, manifestUrl));
+          allModelsDownloaded = false;
+        }
+      }
+      if (allModelsDownloaded) {
+        // Delete unused manifest enrollments.
+        manifestEnrollmentsToDelete.addAll(
+            manifestEnrollmentsByType.stream()
+                .filter(
+                    manifestEnrollment ->
+                        !manifestEnrollment
+                            .getManifestUrl()
+                            .equals(localeTagToManifestUrl.get(manifestEnrollment.getLocaleTag())))
+                .collect(Collectors.toList()));
+      } else {
+        // TODO(licha): We may still need to delete models here. E.g. we are switching from en to
+        // zh. Although we fail to download zh model, we still want to delete en models.
+        TcLog.w(
+            TAG, "Unused models were not deleted because downloading of at least one model failed");
+      }
+    }
+    db.dao().deleteManifestEnrollments(manifestEnrollmentsToDelete);
+    // Step 2: Clean up Manifests and Models that are not linked to any ManifestEnrollment
+    db.dao().deleteUnusedManifestsAndModels();
+    // Step 3: Clean up Manifest failure records
+    // We only keep a failure record if the worker still tries to download it.
+    // We restrict the deletion to failure records only because although some manifest urls are not
+    // in allAttemptedManifestUrls, they can still be useful (e.g. current manifest is v901, and we
+    // failed to download v902. v901 will not be in the map, but it should be kept.)
+    List<String> allAttemptedManifestUrls =
+        manifestsToDownload.values().stream()
+            .flatMap(byType -> byType.localeTagToManifestUrl().values().stream())
+            .collect(Collectors.toList());
+    db.dao().deleteUnusedManifestFailureRecords(allAttemptedManifestUrls);
+    // Step 4: Update lookup cache
+    updateCache();
+    // Step 5: Clean up unused model files.
+    Set<String> modelPathsToKeep =
+        db.dao().queryAllModels().stream().map(Model::getModelPath).collect(Collectors.toSet());
+    // File#listFiles() returns null if the directory cannot be listed; skip the cleanup then.
+    File[] modelFiles = getModelDownloaderDir().listFiles();
+    for (File modelFile : modelFiles == null ? new File[0] : modelFiles) {
+      if (!modelPathsToKeep.contains(modelFile.getAbsolutePath())) {
+        TcLog.d(TAG, "Delete model file: " + modelFile.getAbsolutePath());
+        if (!modelFile.delete()) {
+          TcLog.e(TAG, "Failed to delete model file: " + modelFile.getAbsolutePath());
+        }
+      }
+    }
+  }
+
+  // Rebuilds the lookup cache: empties every per-type list, then refills it from ModelView rows.
+  private void updateCache() {
+    synchronized (cacheLock) {
+      TcLog.d(TAG, "Updating model lookup cache...");
+      for (String type : ModelType.values()) {
+        modelLookupCache.get(type).clear();
+      }
+      for (ModelView view : db.dao().queryAllModelViews()) {
+        String enrolledType = view.getManifestEnrollment().getModelType();
+        List<Model> modelsOfType = modelLookupCache.get(enrolledType);
+        modelsOfType.add(view.getModel());
+      }
+      cacheInitialized = true;
+    }
+  }
+}
diff --git a/java/src/com/android/textclassifier/downloader/IModelDownloaderCallback.aidl b/java/src/com/android/textclassifier/downloader/IModelDownloaderCallback.aidl
new file mode 100644
index 0000000..2ea744a
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/IModelDownloaderCallback.aidl
@@ -0,0 +1,28 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+/**
+ * Callback for download requests from ModelDownloaderImpl to
+ * ModelDownloaderService.
+ */
+oneway interface IModelDownloaderCallback {
+
+ void onSuccess(long bytesWritten);
+
+ void onFailure(int errorCode, String errorMsg);
+}
\ No newline at end of file
diff --git a/java/src/com/android/textclassifier/downloader/IModelDownloaderService.aidl b/java/src/com/android/textclassifier/downloader/IModelDownloaderService.aidl
new file mode 100644
index 0000000..007fcbc
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/IModelDownloaderService.aidl
@@ -0,0 +1,33 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import com.android.textclassifier.downloader.IModelDownloaderCallback;
+
+/**
+ * ModelDownloaderService binder interface.
+ */
+oneway interface IModelDownloaderService {
+
+ /**
+ * @param url the full url to download model from
+ * @param targetFilePath the absolute file path for the destination file
+ * @param callback callback to notify the caller of the download result
+ */
+ void download(
+ String url, String targetFilePath, IModelDownloaderCallback callback);
+}
\ No newline at end of file
diff --git a/java/src/com/android/textclassifier/downloader/LocaleUtils.java b/java/src/com/android/textclassifier/downloader/LocaleUtils.java
new file mode 100644
index 0000000..79bc529
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/LocaleUtils.java
@@ -0,0 +1,75 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import android.text.TextUtils;
+import android.util.Pair;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableMap;
+import java.util.Collection;
+import java.util.Locale;
+import javax.annotation.Nullable;
+
+/** Utilities for locale matching. */
+final class LocaleUtils {
+ @VisibleForTesting static final String UNIVERSAL_LOCALE_TAG = "universal";
+
+ /**
+ * Finds the best locale tag as well as the configured manifest url from device config.
+ *
+ * @param modelType the model type
+ * @param targetLocale target locale
+ * @param settings TextClassifierSettings to check device config
+ * @return a pair of <bestLocaleTag, manifestUrl>. Null if not found.
+ */
+ @Nullable
+ static Pair<String, String> lookupBestLocaleTagAndManifestUrl(
+ @ModelTypeDef String modelType, Locale targetLocale, TextClassifierSettings settings) {
+ ImmutableMap<String, String> localeTagUrlMap =
+ settings.getLanguageTagAndManifestUrlMap(modelType);
+ Collection<String> allLocaleTags = localeTagUrlMap.keySet();
+ String bestLocaleTag = lookupBestLocaleTag(targetLocale, allLocaleTags);
+ if (bestLocaleTag == null) {
+ return null;
+ }
+ String manifestUrl = localeTagUrlMap.get(bestLocaleTag);
+ if (TextUtils.isEmpty(manifestUrl)) {
+ return null;
+ }
+ return Pair.create(bestLocaleTag, manifestUrl);
+ }
+ /** Find the best locale tag for the target locale. Return null if no one is suitable. */
+ @Nullable
+ static String lookupBestLocaleTag(Locale targetLocale, Collection<String> availableTags) {
+ // Notice: this lookup API just tries to match the longest prefix for the target locale tag.
+ // Its implementation looks inefficient and the behavior may not be 100% desired. E.g. if the
+ // target locale is en, and we only have en-uk in available tags, the current API returns null.
+ String bestTag =
+ Locale.lookupTag(Locale.LanguageRange.parse(targetLocale.toLanguageTag()), availableTags);
+ if (bestTag != null) {
+ return bestTag;
+ }
+ if (availableTags.contains(UNIVERSAL_LOCALE_TAG)) {
+ return UNIVERSAL_LOCALE_TAG;
+ }
+ return null;
+ }
+
+ private LocaleUtils() {}
+}
diff --git a/java/src/com/android/textclassifier/downloader/ManifestsToDownloadByType.java b/java/src/com/android/textclassifier/downloader/ManifestsToDownloadByType.java
new file mode 100644
index 0000000..076dfb0
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/ManifestsToDownloadByType.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import com.google.auto.value.AutoValue;
+import com.google.common.collect.ImmutableMap;
+
+ /** Stores the manifests to be downloaded for a given model type. */
+@AutoValue
+public abstract class ManifestsToDownloadByType {
+ public static ManifestsToDownloadByType create(
+ ImmutableMap<String, String> localeTagToManifestUrl) {
+ return new AutoValue_ManifestsToDownloadByType(localeTagToManifestUrl);
+ }
+
+ public abstract ImmutableMap<String, String> localeTagToManifestUrl();
+}
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloadException.java b/java/src/com/android/textclassifier/downloader/ModelDownloadException.java
new file mode 100644
index 0000000..99d91b8
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloadException.java
@@ -0,0 +1,85 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static java.lang.annotation.RetentionPolicy.SOURCE;
+
+import androidx.annotation.IntDef;
+import java.lang.annotation.Retention;
+
+// TODO(licha): Consider making this a checked exception
+/** Exception thrown when downloading a model. */
+final class ModelDownloadException extends RuntimeException {
+
+ // Consistent with TextClassifierDownloadReported.failure_reason. [1, 8, 9] reserved
+ public static final int UNKNOWN_FAILURE_REASON = 0;
+ public static final int FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN = 2;
+ public static final int FAILED_TO_DOWNLOAD_404_ERROR = 3;
+ public static final int FAILED_TO_DOWNLOAD_OTHER = 4;
+ public static final int DOWNLOADED_FILE_MISSING = 5;
+ public static final int FAILED_TO_PARSE_MANIFEST = 6;
+ public static final int FAILED_TO_VALIDATE_MODEL = 7;
+
+ /** Error code for a failed download task. */
+ @Retention(SOURCE)
+ @IntDef({
+ UNKNOWN_FAILURE_REASON,
+ FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN,
+ FAILED_TO_DOWNLOAD_404_ERROR,
+ FAILED_TO_DOWNLOAD_OTHER,
+ DOWNLOADED_FILE_MISSING,
+ FAILED_TO_PARSE_MANIFEST,
+ FAILED_TO_VALIDATE_MODEL
+ })
+ public @interface ErrorCode {}
+
+ public static final int DEFAULT_DOWNLOADER_LIB_ERROR_CODE = -1;
+
+ private final int errorCode;
+
+ private final int downloaderLibErrorCode;
+
+ public ModelDownloadException(@ErrorCode int errorCode, Throwable cause) {
+ super(cause);
+ this.errorCode = errorCode;
+ this.downloaderLibErrorCode = DEFAULT_DOWNLOADER_LIB_ERROR_CODE;
+ }
+
+ public ModelDownloadException(@ErrorCode int errorCode, String message) {
+ super(message);
+ this.errorCode = errorCode;
+ this.downloaderLibErrorCode = DEFAULT_DOWNLOADER_LIB_ERROR_CODE;
+ }
+
+ public ModelDownloadException(
+ @ErrorCode int errorCode, int downloaderLibErrorCode, String message) {
+ super(message);
+ this.errorCode = errorCode;
+ this.downloaderLibErrorCode = downloaderLibErrorCode;
+ }
+
+ /** Returns the error code from Model Downloader itself. */
+ @ErrorCode
+ public int getErrorCode() {
+ return errorCode;
+ }
+
+ /** Returns the error code from internal HTTP stack. */
+ public int getDownloaderLibErrorCode() {
+ return downloaderLibErrorCode;
+ }
+}
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java b/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java
new file mode 100644
index 0000000..af33e21
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java
@@ -0,0 +1,299 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static com.android.textclassifier.downloader.TextClassifierDownloadLogger.REASON_TO_SCHEDULE_DEVICE_CONFIG_UPDATED;
+import static com.android.textclassifier.downloader.TextClassifierDownloadLogger.REASON_TO_SCHEDULE_LOCALE_SETTINGS_CHANGED;
+import static com.android.textclassifier.downloader.TextClassifierDownloadLogger.REASON_TO_SCHEDULE_TCS_STARTED;
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+
+import android.content.BroadcastReceiver;
+import android.content.Context;
+import android.content.Intent;
+import android.content.IntentFilter;
+import android.os.LocaleList;
+import android.provider.DeviceConfig;
+import android.text.TextUtils;
+import androidx.work.BackoffPolicy;
+import androidx.work.Constraints;
+import androidx.work.Data;
+import androidx.work.ExistingWorkPolicy;
+import androidx.work.ListenableWorker;
+import androidx.work.NetworkType;
+import androidx.work.OneTimeWorkRequest;
+import androidx.work.Operation;
+import androidx.work.WorkManager;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.utils.IndentingPrintWriter;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Enums;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.hash.Hashing;
+import com.google.common.util.concurrent.FutureCallback;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import java.io.File;
+import java.time.Instant;
+import java.util.List;
+import java.util.Locale;
+import java.util.UUID;
+import java.util.concurrent.Callable;
+import javax.annotation.Nullable;
+
+/** Manager to listen to config update and download latest models. */
+public final class ModelDownloadManager {
+ private static final String TAG = "ModelDownloadManager";
+
+ @VisibleForTesting static final String UNIQUE_QUEUE_NAME = "ModelDownloadWorkManagerQueue";
+
+ private final Context appContext;
+ private final Class<? extends ListenableWorker> modelDownloadWorkerClass;
+ private final Callable<WorkManager> workManagerSupplier;
+ private final DownloadedModelManager downloadedModelManager;
+ private final TextClassifierSettings settings;
+ private final ListeningExecutorService executorService;
+ private final DeviceConfig.OnPropertiesChangedListener deviceConfigListener;
+ private final BroadcastReceiver localeChangedReceiver;
+
+ /**
+ * Constructor for ModelDownloadManager.
+ *
+ * <p>Delegates to the testing constructor with the production collaborators: {@link
+ * ModelDownloadWorker} as the worker class, a supplier that calls {@code
+ * WorkManager.getInstance(appContext)} only when it is invoked, and the instance returned by
+ * {@code DownloadedModelManagerImpl.getInstance(appContext)}.
+ *
+ * @param appContext the context of this application
+ * @param settings TextClassifierSettings to access DeviceConfig and other settings
+ * @param executorService background executor service
+ */
+ public ModelDownloadManager(
+ Context appContext,
+ TextClassifierSettings settings,
+ ListeningExecutorService executorService) {
+ this(
+ appContext,
+ ModelDownloadWorker.class,
+ () -> WorkManager.getInstance(appContext),
+ DownloadedModelManagerImpl.getInstance(appContext),
+ settings,
+ executorService);
+ }
+
+ /**
+ * Constructor that lets tests inject the worker class, the WorkManager supplier and the
+ * downloaded model manager. Every argument must be non-null.
+ */
+ @VisibleForTesting
+ public ModelDownloadManager(
+ Context appContext,
+ Class<? extends ListenableWorker> modelDownloadWorkerClass,
+ Callable<WorkManager> workManagerSupplier,
+ DownloadedModelManager downloadedModelManager,
+ TextClassifierSettings settings,
+ ListeningExecutorService executorService) {
+ this.appContext = Preconditions.checkNotNull(appContext);
+ this.modelDownloadWorkerClass = Preconditions.checkNotNull(modelDownloadWorkerClass);
+ this.workManagerSupplier = Preconditions.checkNotNull(workManagerSupplier);
+ this.downloadedModelManager = Preconditions.checkNotNull(downloadedModelManager);
+ this.settings = Preconditions.checkNotNull(settings);
+ this.executorService = Preconditions.checkNotNull(executorService);
+
+ // The listeners are only created here. They are registered in
+ // onTextClassifierServiceCreated() and unregistered in destroy().
+ this.deviceConfigListener = unused -> onTextClassifierDeviceConfigChanged();
+ this.localeChangedReceiver =
+ new BroadcastReceiver() {
+ @Override
+ public void onReceive(Context context, Intent intent) {
+ onLocaleChanged();
+ }
+ };
+ }
+
+ /**
+ * Returns the downloaded models for the given modelType. If listing them fails, the failure is
+ * logged and an empty list is returned instead of propagating the error.
+ */
+ @Nullable
+ public List<File> listDownloadedModels(@ModelTypeDef String modelType) {
+ List<File> modelFiles;
+ try {
+ modelFiles = downloadedModelManager.listModels(modelType);
+ } catch (Throwable t) {
+ TcLog.e(TAG, "Failed to list downloaded models", t);
+ modelFiles = ImmutableList.of();
+ }
+ return modelFiles;
+ }
+
+ /**
+ * Notifies the model downloader that the text classifier service is created.
+ *
+ * <p>Registers the DeviceConfig listener and the locale-changed receiver, then schedules a
+ * download work if the model download manager is enabled. Any failure is logged and swallowed.
+ */
+ public void onTextClassifierServiceCreated() {
+ try {
+ // The listeners are registered before the enabled check, i.e. even while the feature is off.
+ DeviceConfig.addOnPropertiesChangedListener(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER, executorService, deviceConfigListener);
+ appContext.registerReceiver(
+ localeChangedReceiver, new IntentFilter(Intent.ACTION_LOCALE_CHANGED));
+ TcLog.d(TAG, "DeviceConfig listener and locale change listener are registered.");
+ if (!settings.isModelDownloadManagerEnabled()) {
+ return;
+ }
+ maybeOverrideLocaleListForTesting();
+ TcLog.d(TAG, "Try to schedule model download work because TextClassifierService started.");
+ scheduleDownloadWork(REASON_TO_SCHEDULE_TCS_STARTED);
+ } catch (Throwable t) {
+ TcLog.e(TAG, "Failed inside onTextClassifierServiceCreated", t);
+ }
+ }
+
+ // TODO(licha): Make this private. Let the constructor accept a receiver to enable testing.
+ /**
+ * Notifies the model downloader that the system locale setting is changed.
+ *
+ * <p>Does nothing when the model download manager is disabled; otherwise schedules a download
+ * work. Failures while scheduling are logged and swallowed.
+ */
+ @VisibleForTesting
+ void onLocaleChanged() {
+ if (!settings.isModelDownloadManagerEnabled()) {
+ return;
+ }
+ TcLog.d(TAG, "Try to schedule model download work because of system locale changes.");
+ try {
+ scheduleDownloadWork(REASON_TO_SCHEDULE_LOCALE_SETTINGS_CHANGED);
+ } catch (Throwable t) {
+ TcLog.e(TAG, "Failed inside onLocaleChanged", t);
+ }
+ }
+
+ // TODO(licha): Make this private. Let the constructor accept a receiver to enable testing.
+ /**
+ * Notifies the model downloader that the device config for textclassifier is changed.
+ *
+ * <p>Does nothing when the model download manager is disabled; otherwise re-applies the testing
+ * locale override (if any) and schedules a download work. Failures are logged and swallowed.
+ */
+ @VisibleForTesting
+ void onTextClassifierDeviceConfigChanged() {
+ if (!settings.isModelDownloadManagerEnabled()) {
+ return;
+ }
+ TcLog.d(TAG, "Try to schedule model download work because of device config changes.");
+ try {
+ maybeOverrideLocaleListForTesting();
+ scheduleDownloadWork(REASON_TO_SCHEDULE_DEVICE_CONFIG_UPDATED);
+ } catch (Throwable t) {
+ TcLog.e(TAG, "Failed inside onTextClassifierDeviceConfigChanged", t);
+ }
+ }
+
+ /**
+ * Cleans up internal state on destroying: removes the DeviceConfig listener and unregisters the
+ * locale-changed receiver. Failures are logged and swallowed.
+ */
+ public void destroy() {
+ try {
+ try {
+ DeviceConfig.removeOnPropertiesChangedListener(deviceConfigListener);
+ } finally {
+ // Always attempt to unregister the receiver, even if removing the DeviceConfig listener
+ // threw; otherwise the receiver would stay registered after destroy().
+ appContext.unregisterReceiver(localeChangedReceiver);
+ }
+ TcLog.d(TAG, "DeviceConfig and Locale listener unregistered by ModelDownloadeManager");
+ } catch (Throwable t) {
+ TcLog.e(TAG, "Failed to destroy ModelDownloadManager", t);
+ }
+ }
+
+ /**
+ * Dumps the internal state for debugging. Prints nothing when the model download manager is
+ * disabled. Failures are logged and swallowed.
+ *
+ * @param printWriter writer to write dumped states
+ */
+ public void dump(IndentingPrintWriter printWriter) {
+ if (!settings.isModelDownloadManagerEnabled()) {
+ return;
+ }
+ try {
+ printWriter.println("ModelDownloadManager:");
+ printWriter.increaseIndent();
+ try {
+ downloadedModelManager.dump(printWriter);
+ } finally {
+ // Restore the caller's indent level even if the nested dump throws, so that whatever is
+ // printed to the same writer afterwards is not shifted.
+ printWriter.decreaseIndent();
+ }
+ } catch (Throwable t) {
+ TcLog.e(TAG, "Failed to dump ModelDownloadManager", t);
+ }
+ }
+
+ /**
+ * Enqueue an idempotent work to check device configs and download model files if necessary.
+ *
+ * <p>At any time there will only be at most one work running. If a work is already pending or
+ * running, the newly scheduled work will be appended as a child of that work.
+ *
+ * <p>The outcome of the enqueue operation (success or failure) is reported through {@code
+ * TextClassifierDownloadLogger.downloadWorkScheduled}; this method itself never throws.
+ *
+ * @param reasonToSchedule reason code forwarded to the download logger
+ */
+ private void scheduleDownloadWork(int reasonToSchedule) {
+ // The work id is a 64-bit fingerprint of a freshly generated random UUID.
+ long workId =
+ Hashing.farmHashFingerprint64().hashUnencodedChars(UUID.randomUUID().toString()).asLong();
+ try {
+ // Fall back to UNMETERED when the configured name is not a NetworkType constant.
+ NetworkType networkType =
+ Enums.getIfPresent(NetworkType.class, settings.getManifestDownloadRequiredNetworkType())
+ .or(NetworkType.UNMETERED);
+ OneTimeWorkRequest downloadRequest =
+ new OneTimeWorkRequest.Builder(modelDownloadWorkerClass)
+ .setConstraints(
+ new Constraints.Builder()
+ .setRequiredNetworkType(networkType)
+ .setRequiresBatteryNotLow(true)
+ .setRequiresStorageNotLow(true)
+ .setRequiresDeviceIdle(settings.getManifestDownloadRequiresDeviceIdle())
+ .setRequiresCharging(settings.getManifestDownloadRequiresCharging())
+ .build())
+ .setBackoffCriteria(
+ BackoffPolicy.EXPONENTIAL,
+ settings.getModelDownloadBackoffDelayInMillis(),
+ MILLISECONDS)
+ // The worker reads these two values back from its input data.
+ .setInputData(
+ new Data.Builder()
+ .putLong(ModelDownloadWorker.INPUT_DATA_KEY_WORK_ID, workId)
+ .putLong(
+ ModelDownloadWorker.INPUT_DATA_KEY_SCHEDULED_TIMESTAMP,
+ Instant.now().toEpochMilli())
+ .build())
+ .build();
+ ListenableFuture<Operation.State.SUCCESS> enqueueResultFuture =
+ workManagerSupplier
+ .call()
+ .enqueueUniqueWork(
+ UNIQUE_QUEUE_NAME, ExistingWorkPolicy.APPEND_OR_REPLACE, downloadRequest)
+ .getResult();
+ // The enqueue result is delivered asynchronously on the background executor.
+ Futures.addCallback(
+ enqueueResultFuture,
+ new FutureCallback<Operation.State.SUCCESS>() {
+ @Override
+ public void onSuccess(Operation.State.SUCCESS unused) {
+ TcLog.d(TAG, "Download work scheduled.");
+ TextClassifierDownloadLogger.downloadWorkScheduled(
+ workId, reasonToSchedule, /* failedToSchedule= */ false);
+ }
+
+ @Override
+ public void onFailure(Throwable t) {
+ TcLog.e(TAG, "Failed to schedule download work: ", t);
+ TextClassifierDownloadLogger.downloadWorkScheduled(
+ workId, reasonToSchedule, /* failedToSchedule= */ true);
+ }
+ },
+ executorService);
+ } catch (Throwable t) {
+ // Covers synchronous failures while building or enqueuing the request (before any callback
+ // could run), so they are logged as a failed schedule as well.
+ // TODO(licha): this is just for temporary fix. Refactor the try-catch in the future.
+ TcLog.e(TAG, "Failed to schedule download work: ", t);
+ TextClassifierDownloadLogger.downloadWorkScheduled(
+ workId, reasonToSchedule, /* failedToSchedule= */ true);
+ }
+ }
+
+ /**
+ * Replaces the default LocaleList with the testing override from the settings, if one is set.
+ * A null or empty override leaves the default LocaleList untouched.
+ */
+ private void maybeOverrideLocaleListForTesting() {
+ String overrideLanguageTags = settings.getTestingLocaleListOverride();
+ if (!TextUtils.isEmpty(overrideLanguageTags)) {
+ String currentLanguageTags = LocaleList.getAdjustedDefault().toLanguageTags();
+ TcLog.d(
+ TAG,
+ String.format(
+ Locale.US,
+ "Override LocaleList from %s to %s",
+ currentLanguageTags,
+ overrideLanguageTags));
+ LocaleList.setDefault(LocaleList.forLanguageTags(overrideLanguageTags));
+ }
+ }
+}
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java b/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java
new file mode 100644
index 0000000..3db0815
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloadWorker.java
@@ -0,0 +1,433 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static java.lang.Math.min;
+
+import android.content.Context;
+import android.os.LocaleList;
+import android.util.ArrayMap;
+import android.util.Pair;
+import androidx.work.ListenableWorker;
+import androidx.work.WorkerParameters;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.TextClassifierServiceExecutors;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Model;
+import com.google.auto.value.AutoValue;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.util.concurrent.FluentFuture;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import com.google.errorprone.annotations.concurrent.GuardedBy;
+import java.time.Clock;
+import java.util.ArrayList;
+import java.util.Locale;
+
+/** The WorkManager worker to download models for TextClassifierService. */
+public final class ModelDownloadWorker extends ListenableWorker {
+ private static final String TAG = "ModelDownloadWorker";
+
+ public static final String INPUT_DATA_KEY_WORK_ID = "ModelDownloadWorker_workId";
+ public static final String INPUT_DATA_KEY_SCHEDULED_TIMESTAMP =
+ "ModelDownloadWorker_scheduledTimestamp";
+
+ private final ListeningExecutorService executorService;
+ private final ModelDownloader downloader;
+ private final DownloadedModelManager downloadedModelManager;
+ private final TextClassifierSettings settings;
+
+ private final long workId;
+
+ private final Clock clock;
+ private final long workScheduledTimeMillis;
+
+ private final Object lock = new Object();
+
+ private long workStartedTimeMillis = 0;
+
+ @GuardedBy("lock")
+ private final ArrayMap<String, ListenableFuture<Void>> pendingDownloads;
+
+ private ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload;
+
+ /**
+ * Production constructor. Wires up the downloader executor, a {@link ModelDownloaderImpl}, the
+ * downloaded model manager and fresh settings, and reads the work id and the scheduled timestamp
+ * from the input data (both default to 0 when absent).
+ */
+ public ModelDownloadWorker(Context context, WorkerParameters workerParams) {
+ super(context, workerParams);
+ this.executorService = TextClassifierServiceExecutors.getDownloaderExecutor();
+ this.downloader = new ModelDownloaderImpl(context, executorService);
+ this.downloadedModelManager = DownloadedModelManagerImpl.getInstance(context);
+ this.settings = new TextClassifierSettings();
+ this.pendingDownloads = new ArrayMap<>();
+ // Populated by checkAndDownloadModels() once the work actually runs.
+ this.manifestsToDownload = null;
+
+ this.workId = workerParams.getInputData().getLong(INPUT_DATA_KEY_WORK_ID, 0);
+ this.workScheduledTimeMillis =
+ workerParams.getInputData().getLong(INPUT_DATA_KEY_SCHEDULED_TIMESTAMP, 0);
+ this.clock = Clock.systemUTC();
+ }
+
+ /**
+ * Constructor for tests: every collaborator, the work id, the clock and the scheduled timestamp
+ * are injected directly instead of being derived from the context and the input data.
+ */
+ @VisibleForTesting
+ ModelDownloadWorker(
+ Context context,
+ WorkerParameters workerParams,
+ ListeningExecutorService executorService,
+ ModelDownloader modelDownloader,
+ DownloadedModelManager downloadedModelManager,
+ TextClassifierSettings settings,
+ long workId,
+ Clock clock,
+ long workScheduledTimeMillis) {
+ super(context, workerParams);
+ this.executorService = executorService;
+ this.downloader = modelDownloader;
+ this.downloadedModelManager = downloadedModelManager;
+ this.settings = settings;
+ this.pendingDownloads = new ArrayMap<>();
+ // Populated by checkAndDownloadModels() once the work actually runs.
+ this.manifestsToDownload = null;
+ this.workId = workId;
+ this.clock = clock;
+ this.workScheduledTimeMillis = workScheduledTimeMillis;
+ }
+
+ /**
+ * Starts the download work.
+ *
+ * <p>Fails immediately (without retry) when the model downloader is disabled or the run attempt
+ * count has reached the configured maximum. Otherwise runs checkAndDownloadModels() on the
+ * executor and resolves to success when no download failed, or to retry when at least one
+ * download failed or an unexpected exception was thrown. Every outcome is reported through
+ * logDownloadWorkCompleted().
+ */
+ @Override
+ public final ListenableFuture<ListenableWorker.Result> startWork() {
+ TcLog.d(TAG, "Start download work...");
+ workStartedTimeMillis = getCurrentTimeMillis();
+ // Notice: startWork() is invoked on the main thread
+ if (!settings.isModelDownloadManagerEnabled()) {
+ TcLog.e(TAG, "Model Downloader is disabled. Abort the work.");
+ logDownloadWorkCompleted(
+ TextClassifierDownloadLogger.WORK_RESULT_FAILURE_MODEL_DOWNLOADER_DISABLED);
+ return Futures.immediateFuture(ListenableWorker.Result.failure());
+ }
+ if (getRunAttemptCount() >= settings.getModelDownloadWorkerMaxAttempts()) {
+ TcLog.d(TAG, "Max attempt reached. Abort download work.");
+ logDownloadWorkCompleted(
+ TextClassifierDownloadLogger.WORK_RESULT_FAILURE_MAX_RUN_ATTEMPT_REACHED);
+ return Futures.immediateFuture(ListenableWorker.Result.failure());
+ }
+
+ // The actual checking and downloading is moved off the calling thread onto the executor.
+ return FluentFuture.from(Futures.submitAsync(this::checkAndDownloadModels, executorService))
+ .transform(
+ downloadResult -> {
+ // manifestsToDownload is assigned by checkAndDownloadModels() before its future
+ // completes, so it must be non-null here.
+ Preconditions.checkNotNull(manifestsToDownload);
+ downloadedModelManager.onDownloadCompleted(manifestsToDownload);
+ TcLog.d(TAG, "Download work completed: " + downloadResult);
+ if (downloadResult.failureCount() == 0) {
+ logDownloadWorkCompleted(
+ downloadResult.successCount() > 0
+ ? TextClassifierDownloadLogger.WORK_RESULT_SUCCESS_MODEL_DOWNLOADED
+ : TextClassifierDownloadLogger.WORK_RESULT_SUCCESS_NO_UPDATE_AVAILABLE);
+ return ListenableWorker.Result.success();
+ } else {
+ logDownloadWorkCompleted(
+ TextClassifierDownloadLogger.WORK_RESULT_RETRY_MODEL_DOWNLOAD_FAILED);
+ return ListenableWorker.Result.retry();
+ }
+ },
+ executorService)
+ // Anything thrown above (including from the transform step) ends up as a retry.
+ .catching(
+ Throwable.class,
+ t -> {
+ TcLog.e(TAG, "Unexpected Exception during downloading: ", t);
+ logDownloadWorkCompleted(
+ TextClassifierDownloadLogger.WORK_RESULT_RETRY_RUNTIME_EXCEPTION);
+ return ListenableWorker.Result.retry();
+ },
+ executorService);
+ }
+
+ /**
+ * Builds the list of locales to download models for, based on the adjusted default LocaleList.
+ *
+ * <p>Without multi language support only the primary locale is returned. With it, the leading
+ * locales of the list are returned in their original order, capped by the configured multi
+ * language models limit.
+ */
+ private ImmutableList<Locale> getLocalesToDownload() {
+ LocaleList adjustedLocales = LocaleList.getAdjustedDefault();
+ if (!settings.isMultiLanguageSupportEnabled()) {
+ return ImmutableList.of(adjustedLocales.get(0));
+ }
+ int localeCount = min(settings.getMultiLanguageModelsLimit(), adjustedLocales.size());
+ ImmutableList.Builder<Locale> locales = ImmutableList.builder();
+ for (int index = 0; index < localeCount; index++) {
+ locales.add(adjustedLocales.get(index));
+ }
+ return locales.build();
+ }
+
+ /**
+ * Narrows {@code localeList} down for the given {@code modelType}: model types enabled for multi
+ * language support get the whole list, all others only get the default locale.
+ */
+ private ImmutableList<Locale> getLocalesToDownloadByType(
+ ImmutableList<Locale> localeList, @ModelTypeDef String modelType) {
+ boolean multiLanguageEnabledForType =
+ settings.getEnabledModelTypesForMultiLanguageSupport().contains(modelType);
+ return multiLanguageEnabledForType ? localeList : ImmutableList.of(Locale.getDefault());
+ }
+
+ /**
+ * Check device config and dispatch download tasks for all modelTypes.
+ *
+ * <p>For every model type, each locale to download is resolved to its best locale tag and
+ * manifest url. Each resolved locale tag is handled at most once per model type, so several
+ * device locales that resolve to the same tag produce a single map entry and a single download
+ * task. The resolved (locale tag, manifest url) pairs are stored in {@code manifestsToDownload}
+ * before the returned future completes.
+ *
+ * @return a future of the combined result, counting the download tasks that succeeded and the
+ * ones that failed
+ */
+ private ListenableFuture<DownloadResult> checkAndDownloadModels() {
+ ImmutableList<Locale> localesToDownload = getLocalesToDownload();
+ ArrayList<ListenableFuture<Boolean>> downloadResultFutures = new ArrayList<>();
+ ImmutableMap.Builder<String, ManifestsToDownloadByType> manifestsToDownloadBuilder =
+ ImmutableMap.builder();
+ for (String modelType : ModelType.values()) {
+ ImmutableList<Locale> localesToDownloadByType =
+ getLocalesToDownloadByType(localesToDownload, modelType);
+ ImmutableMap.Builder<String, String> localeTagToManifestUrlBuilder = ImmutableMap.builder();
+ // Locale tags already handled for this model type. The list is at most as long as the
+ // number of locales to download, so a linear contains() check is cheap.
+ ArrayList<String> handledLocaleTags = new ArrayList<>();
+ for (Locale locale : localesToDownloadByType) {
+ Pair<String, String> bestLocaleTagAndManifestUrl =
+ LocaleUtils.lookupBestLocaleTagAndManifestUrl(modelType, locale, settings);
+ if (bestLocaleTagAndManifestUrl == null) {
+ TcLog.w(
+ TAG,
+ String.format(
+ Locale.US, "No suitable manifest for %s, %s", modelType, locale.toLanguageTag()));
+ continue;
+ }
+ String bestLocaleTag = bestLocaleTagAndManifestUrl.first;
+ String manifestUrl = bestLocaleTagAndManifestUrl.second;
+ // Different locales may resolve to the same best locale tag. ImmutableMap.Builder#build()
+ // throws on duplicate keys, which would fail the whole work, so skip repeated tags.
+ if (handledLocaleTags.contains(bestLocaleTag)) {
+ TcLog.d(
+ TAG,
+ String.format(
+ Locale.US,
+ "model type: %s, current locale tag: %s, best locale tag %s already handled",
+ modelType,
+ locale.toLanguageTag(),
+ bestLocaleTag));
+ continue;
+ }
+ handledLocaleTags.add(bestLocaleTag);
+ localeTagToManifestUrlBuilder.put(bestLocaleTag, manifestUrl);
+ TcLog.d(
+ TAG,
+ String.format(
+ Locale.US,
+ "model type: %s, current locale tag: %s, best locale tag: %s, manifest url: %s",
+ modelType,
+ locale.toLanguageTag(),
+ bestLocaleTag,
+ manifestUrl));
+ if (!shouldDownloadManifest(modelType, bestLocaleTag, manifestUrl)) {
+ continue;
+ }
+ downloadResultFutures.add(
+ downloadManifestAndRegister(modelType, bestLocaleTag, manifestUrl));
+ }
+ manifestsToDownloadBuilder.put(
+ modelType, ManifestsToDownloadByType.create(localeTagToManifestUrlBuilder.build()));
+ }
+ manifestsToDownload = manifestsToDownloadBuilder.build();
+
+ return Futures.whenAllComplete(downloadResultFutures)
+ .call(
+ () -> {
+ TcLog.d(TAG, "All Download Tasks Completed");
+ int successCount = 0;
+ int failureCount = 0;
+ for (ListenableFuture<Boolean> downloadResultFuture : downloadResultFutures) {
+ // Each task resolves to true on success and false on failure.
+ if (Futures.getDone(downloadResultFuture)) {
+ successCount += 1;
+ } else {
+ failureCount += 1;
+ }
+ }
+ return DownloadResult.create(successCount, failureCount);
+ },
+ executorService);
+ }
+
+ /**
+ * Decides whether the manifest at {@code manifestUrl} has to be downloaded for the given model
+ * type and locale tag.
+ *
+ * <p>An unknown manifest is always downloaded. A failed manifest is retried until its failure
+ * count reaches the configured maximum. Otherwise a download is only needed when the current
+ * enrollment for (modelType, localeTag) is missing or points to a different manifest url.
+ */
+ private boolean shouldDownloadManifest(
+ @ModelTypeDef String modelType, String localeTag, String manifestUrl) {
+ Manifest knownManifest = downloadedModelManager.getManifest(manifestUrl);
+ if (knownManifest == null) {
+ return true;
+ }
+ if (knownManifest.getStatus() == Manifest.STATUS_FAILED) {
+ boolean canRetry =
+ knownManifest.getFailureCounts() < settings.getManifestDownloadMaxAttempts();
+ if (!canRetry) {
+ TcLog.w(
+ TAG,
+ String.format(
+ Locale.US,
+ "Manifest failed too many times, stop retrying: %s %d",
+ manifestUrl,
+ knownManifest.getFailureCounts()));
+ }
+ return canRetry;
+ }
+ ManifestEnrollment currentEnrollment =
+ downloadedModelManager.getManifestEnrollment(modelType, localeTag);
+ if (currentEnrollment == null) {
+ return true;
+ }
+ return !manifestUrl.equals(currentEnrollment.getManifestUrl());
+ }
+
+ /**
+ * Downloads a single manifest and models configured inside it, then registers the manifest
+ * enrollment for the given model type and locale tag.
+ *
+ * <p>The returned future resolves to {@code true} when the download and the registration
+ * succeeded. Exceptions from that chain are caught, recorded as a manifest download failure,
+ * logged and mapped to {@code false}.
+ */
+ private ListenableFuture<Boolean> downloadManifestAndRegister(
+ @ModelTypeDef String modelType, String localeTag, String manifestUrl) {
+ long downloadStartTimestamp = getCurrentTimeMillis();
+ return FluentFuture.from(downloadManifest(manifestUrl))
+ .transform(
+ unused -> {
+ downloadedModelManager.registerManifestEnrollment(modelType, localeTag, manifestUrl);
+ TextClassifierDownloadLogger.downloadSucceeded(
+ workId,
+ modelType,
+ manifestUrl,
+ getRunAttemptCount(),
+ getCurrentTimeMillis() - downloadStartTimestamp);
+ TcLog.d(TAG, "Manifest downloaded and registered: " + manifestUrl);
+ return true;
+ },
+ executorService)
+ .catching(
+ Throwable.class,
+ t -> {
+ downloadedModelManager.registerManifestDownloadFailure(manifestUrl);
+ // Defaults used when the failure is not a ModelDownloadException.
+ // NOTE(review): the lib error code falls back to 0 here, while ModelDownloadException
+ // uses DEFAULT_DOWNLOADER_LIB_ERROR_CODE (-1) - confirm this difference is intended.
+ int errorCode = ModelDownloadException.UNKNOWN_FAILURE_REASON;
+ int downloaderLibErrorCode = 0;
+ if (t instanceof ModelDownloadException) {
+ ModelDownloadException mde = (ModelDownloadException) t;
+ errorCode = mde.getErrorCode();
+ downloaderLibErrorCode = mde.getDownloaderLibErrorCode();
+ }
+ TcLog.e(TAG, "Failed to download manfiest: " + manifestUrl, t);
+ TextClassifierDownloadLogger.downloadFailed(
+ workId,
+ modelType,
+ manifestUrl,
+ errorCode,
+ getRunAttemptCount(),
+ downloaderLibErrorCode,
+ getCurrentTimeMillis() - downloadStartTimestamp);
+ return false;
+ },
+ executorService);
+ }
+
+ /**
+ * Downloads a manifest and the first model listed in it, and registers the manifest into the
+ * Manifest table.
+ *
+ * <p>Returns immediately when the manifest is already recorded as succeeded, and reuses the
+ * pending future when a download for the same url was already started by this worker.
+ */
+ private ListenableFuture<Void> downloadManifest(String manifestUrl) {
+ synchronized (lock) {
+ Manifest existingManifest = downloadedModelManager.getManifest(manifestUrl);
+ boolean alreadySucceeded =
+ existingManifest != null && existingManifest.getStatus() == Manifest.STATUS_SUCCEEDED;
+ if (alreadySucceeded) {
+ TcLog.d(TAG, "Manifest already downloaded: " + manifestUrl);
+ return Futures.immediateVoidFuture();
+ }
+ // Only non-null futures are ever stored, so a null lookup means "nothing pending".
+ ListenableFuture<Void> pendingFuture = pendingDownloads.get(manifestUrl);
+ if (pendingFuture != null) {
+ return pendingFuture;
+ }
+ ListenableFuture<Void> manifestDownloadFuture =
+ FluentFuture.from(downloader.downloadManifest(manifestUrl))
+ .transformAsync(
+ modelManifest -> {
+ ModelManifest.Model firstModel = modelManifest.getModels(0);
+ return Futures.transform(
+ downloadModel(firstModel), ignored -> firstModel, executorService);
+ },
+ executorService)
+ .transform(
+ downloadedModelInfo -> {
+ downloadedModelManager.registerManifest(
+ manifestUrl, downloadedModelInfo.getUrl());
+ return null;
+ },
+ executorService);
+ pendingDownloads.put(manifestUrl, manifestDownloadFuture);
+ return manifestDownloadFuture;
+ }
+ }
+ /**
+ * Downloads a model file and registers it into the Model table.
+ *
+ * <p>Returns immediately when the model url is already registered, and reuses the pending future
+ * when a download for the same url was already started by this worker.
+ */
+ private ListenableFuture<Void> downloadModel(ModelManifest.Model modelInfo) {
+ String modelUrl = modelInfo.getUrl();
+ synchronized (lock) {
+ Model existingModel = downloadedModelManager.getModel(modelUrl);
+ if (existingModel != null) {
+ TcLog.d(TAG, "Model file already exists: " + existingModel.getModelPath());
+ return Futures.immediateVoidFuture();
+ }
+ // Only non-null futures are ever stored, so a null lookup means "nothing pending".
+ ListenableFuture<Void> pendingFuture = pendingDownloads.get(modelUrl);
+ if (pendingFuture != null) {
+ return pendingFuture;
+ }
+ ListenableFuture<Void> modelDownloadFuture =
+ FluentFuture.from(
+ downloader.downloadModel(
+ downloadedModelManager.getModelDownloaderDir(), modelInfo))
+ .transform(
+ downloadedFile -> {
+ downloadedModelManager.registerModel(
+ modelUrl, downloadedFile.getAbsolutePath());
+ TcLog.d(TAG, "Model File downloaded: " + modelUrl);
+ return null;
+ },
+ executorService);
+ pendingDownloads.put(modelUrl, modelDownloadFuture);
+ return modelDownloadFuture;
+ }
+ }
+
+ /**
+ * This method will be called when our work gets interrupted by the system. The result future
+ * should have already been cancelled in that case. Unless it's because of the REPLACE policy of
+ * the WorkManager unique queue, the interrupted work will be rescheduled later.
+ *
+ * <p>Only logs the stop and reports the work as completed with the "stopped by OS" result.
+ */
+ @Override
+ public final void onStopped() {
+ TcLog.d(TAG, String.format(Locale.US, "Stop download. Attempt:%d", getRunAttemptCount()));
+ logDownloadWorkCompleted(TextClassifierDownloadLogger.WORK_RESULT_RETRY_STOPPED_BY_OS);
+ }
+
+ /** Returns the current time of the injected clock as milliseconds since the epoch. */
+ private long getCurrentTimeMillis() {
+ return clock.instant().toEpochMilli();
+ }
+
+ /**
+ * Reports the completion of this work together with how long it was queued (scheduled to
+ * started) and how long it ran (started to now).
+ *
+ * <p>A start time earlier than the scheduled time is logged as a warning and clamped to the
+ * scheduled time, so the reported queue delay is never negative.
+ */
+ private void logDownloadWorkCompleted(int workResult) {
+ if (workStartedTimeMillis < workScheduledTimeMillis) {
+ String warning =
+ String.format(
+ Locale.US,
+ "Bad workStartedTimeMillis: %d, workScheduledTimeMillis: %d",
+ workStartedTimeMillis,
+ workScheduledTimeMillis);
+ TcLog.w(TAG, warning);
+ workStartedTimeMillis = workScheduledTimeMillis;
+ }
+ long queuedMillis = workStartedTimeMillis - workScheduledTimeMillis;
+ long runningMillis = getCurrentTimeMillis() - workStartedTimeMillis;
+ TextClassifierDownloadLogger.downloadWorkCompleted(
+ workId, workResult, getRunAttemptCount(), queuedMillis, runningMillis);
+ }
+
+ /** Combined outcome of all download tasks dispatched by one run of the worker. */
+ @AutoValue
+ abstract static class DownloadResult {
+ /** Number of download tasks that resolved to true. */
+ public abstract int successCount();
+
+ /** Number of download tasks that resolved to false. */
+ public abstract int failureCount();
+
+ /** Creates a result with the given success and failure counts. */
+ public static DownloadResult create(int successCount, int failureCount) {
+ return new AutoValue_ModelDownloadWorker_DownloadResult(successCount, failureCount);
+ }
+ }
+}
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloader.java b/java/src/com/android/textclassifier/downloader/ModelDownloader.java
new file mode 100644
index 0000000..7e22d99
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloader.java
@@ -0,0 +1,47 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import com.google.common.util.concurrent.ListenableFuture;
+import java.io.File;
+
+/** Interface for downloading manifest and model files from a given URL. */
+interface ModelDownloader {
+
+ /**
+ * Downloads a manifest file from the given url, parses it and returns the proto.
+ *
+ * <p>The downloaded file should be deleted whether or not the download succeeds.
+ *
+ * @param manifestUrl url to download manifest file from
+ * @return listenable future of ModelManifest proto
+ */
+ ListenableFuture<ModelManifest> downloadManifest(String manifestUrl);
+
+ /**
+ * Downloads a model file and validates it based on the given model info.
+ *
+ * <p>The file should be in the target folder. Returns the File on success. If the download or
+ * validation fails, the unfinished model file should be cleaned up. Failures should be wrapped
+ * inside a {@link ModelDownloadException} and thrown.
+ *
+ * @param targetDir the target directory for the downloaded model
+ * @param modelInfo the model information in manifest used for downloading and validation
+ * @return listenable future of the downloaded model file
+ */
+ ListenableFuture<File> downloadModel(File targetDir, ModelManifest.Model modelInfo);
+}
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java
new file mode 100644
index 0000000..0b76f22
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java
@@ -0,0 +1,267 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static android.content.Context.BIND_AUTO_CREATE;
+import static android.content.Context.BIND_NOT_FOREGROUND;
+
+import android.content.ComponentName;
+import android.content.Context;
+import android.content.Intent;
+import android.content.ServiceConnection;
+import android.os.IBinder;
+import androidx.concurrent.futures.CallbackToFutureAdapter;
+import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.protobuf.ExtensionRegistryLite;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.hash.HashCode;
+import com.google.common.hash.Hashing;
+import com.google.common.io.Files;
+import com.google.common.util.concurrent.FutureCallback;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.net.URI;
+import java.util.concurrent.ExecutorService;
+
+/**
+ * ModelDownloader implementation that forwards requests to ModelDownloaderService. This is to
+ * restrict the INTERNET permission to the service process only (instead of the whole ExtServices).
+ *
+ * <p>Every download binds to the service, asks it to write the target file, and unbinds once the
+ * download future completes. Manifest parsing and model validation run on the background executor.
+ */
+final class ModelDownloaderImpl implements ModelDownloader {
+  private static final String TAG = "ModelDownloaderImpl";
+
+  private final Context context;
+  private final ExecutorService bgExecutorService;
+  private final Class<?> downloaderServiceClass;
+
+  /**
+   * Creates a downloader that binds to {@link ModelDownloaderService}.
+   *
+   * @param context any context; only its application context is retained
+   * @param bgExecutorService executor used for future transformations, callbacks and listeners
+   */
+  public ModelDownloaderImpl(Context context, ExecutorService bgExecutorService) {
+    this(context, bgExecutorService, ModelDownloaderService.class);
+  }
+
+  /** Same as the public constructor, but lets tests choose the service class to bind to. */
+  @VisibleForTesting
+  ModelDownloaderImpl(
+      Context context, ExecutorService bgExecutorService, Class<?> downloaderServiceClass) {
+    this.context = context.getApplicationContext();
+    this.bgExecutorService = bgExecutorService;
+    this.downloaderServiceClass = downloaderServiceClass;
+  }
+
+  /**
+   * Downloads the manifest into the cache directory, parses it and deletes the downloaded file.
+   *
+   * <p>The file name is derived from the url by replacing every non-alphanumeric character with
+   * an underscore. Any failure while reading or parsing the file is rethrown as a {@link
+   * ModelDownloadException} with {@code FAILED_TO_PARSE_MANIFEST}.
+   *
+   * @param manifestUrl url to download the manifest file from
+   * @return a future that completes with the parsed manifest
+   */
+  @Override
+  public ListenableFuture<ModelManifest> downloadManifest(String manifestUrl) {
+    File manifestFile =
+        new File(context.getCacheDir(), manifestUrl.replaceAll("[^A-Za-z0-9]", "_") + ".manifest");
+    return Futures.transform(
+        download(URI.create(manifestUrl), manifestFile),
+        bytesWritten -> {
+          // try-with-resources closes the stream before the finally block deletes the file, so
+          // no file descriptor is leaked per manifest download.
+          try (FileInputStream manifestStream = new FileInputStream(manifestFile)) {
+            return ModelManifest.parseFrom(
+                manifestStream, ExtensionRegistryLite.getEmptyRegistry());
+          } catch (Throwable t) {
+            throw new ModelDownloadException(ModelDownloadException.FAILED_TO_PARSE_MANIFEST, t);
+          } finally {
+            manifestFile.delete();
+          }
+        },
+        bgExecutorService);
+  }
+
+  /**
+   * Downloads the model into {@code targetDir} and validates its size and fingerprint.
+   *
+   * <p>If the download or the validation fails, the model file is deleted.
+   *
+   * @param targetDir the target directory for the downloaded model
+   * @param model the manifest entry providing the url, expected size and expected fingerprint
+   * @return a future that completes with the validated model file
+   */
+  @Override
+  public ListenableFuture<File> downloadModel(File targetDir, ModelManifest.Model model) {
+    File modelFile = new File(targetDir, model.getUrl().replaceAll("[^A-Za-z0-9]", "_") + ".model");
+    ListenableFuture<File> modelFileFuture =
+        Futures.transform(
+            download(URI.create(model.getUrl()), modelFile),
+            bytesWritten -> {
+              validateModel(modelFile, model.getSizeInBytes(), model.getFingerprint());
+              return modelFile;
+            },
+            bgExecutorService);
+    Futures.addCallback(
+        modelFileFuture,
+        new FutureCallback<File>() {
+          @Override
+          public void onSuccess(File pendingModelFile) {
+            TcLog.d(TAG, "Download model successfully: " + pendingModelFile.getAbsolutePath());
+          }
+
+          @Override
+          public void onFailure(Throwable t) {
+            // Covers both download failures and validation failures.
+            modelFile.delete();
+            TcLog.e(TAG, "Failed to download: " + modelFile.getAbsolutePath(), t);
+          }
+        },
+        bgExecutorService);
+    return modelFileFuture;
+  }
+
+  // TODO(licha): Make this visible for testing. So we can avoid some duplicated test cases.
+  /**
+   * Downloads the file from uri to the targetFile. If the targetFile already exists, it is deleted
+   * first. The returned future completes with the number of bytes written on success.
+   *
+   * <p>The service connection is unbound when the returned future completes, whether it succeeded
+   * or failed.
+   */
+  private ListenableFuture<Long> download(URI uri, File targetFile) {
+    if (targetFile.exists()) {
+      TcLog.w(
+          TAG,
+          "Target file already exists. Delete it before downloading: "
+              + targetFile.getAbsolutePath());
+      targetFile.delete();
+    }
+    DownloaderServiceConnection conn = new DownloaderServiceConnection();
+    ListenableFuture<IModelDownloaderService> downloaderServiceFuture = connect(conn);
+    ListenableFuture<Long> bytesWrittenFuture =
+        Futures.transformAsync(
+            downloaderServiceFuture,
+            service -> scheduleDownload(service, uri, targetFile),
+            bgExecutorService);
+    bytesWrittenFuture.addListener(
+        () -> {
+          try {
+            context.unbindService(conn);
+          } catch (IllegalArgumentException e) {
+            TcLog.e(TAG, "Error when unbind", e);
+          }
+        },
+        bgExecutorService);
+    return bytesWrittenFuture;
+  }
+
+  /**
+   * Model verification. Throws unchecked Exceptions if validation fails.
+   *
+   * <p>Checks, in order, that the file exists, that its length equals {@code sizeInBytes}, and
+   * that its SHA-384 hash equals {@code fingerprint}.
+   */
+  private static void validateModel(File pendingModelFile, long sizeInBytes, String fingerprint) {
+    if (!pendingModelFile.exists()) {
+      throw new ModelDownloadException(
+          ModelDownloadException.DOWNLOADED_FILE_MISSING, "PendingModelFile does not exist.");
+    }
+    if (pendingModelFile.length() != sizeInBytes) {
+      throw new ModelDownloadException(
+          ModelDownloadException.FAILED_TO_VALIDATE_MODEL,
+          String.format(
+              "PendingModelFile size does not match: expected [%d] actual [%d]",
+              sizeInBytes, pendingModelFile.length()));
+    }
+    try {
+      HashCode pendingModelFingerprint =
+          Files.asByteSource(pendingModelFile).hash(Hashing.sha384());
+      if (!pendingModelFingerprint.equals(HashCode.fromString(fingerprint))) {
+        throw new ModelDownloadException(
+            ModelDownloadException.FAILED_TO_VALIDATE_MODEL,
+            String.format(
+                "PendingModelFile fingerprint does not match: expected [%s] actual [%s]",
+                fingerprint, pendingModelFingerprint));
+      }
+    } catch (IOException e) {
+      throw new ModelDownloadException(ModelDownloadException.FAILED_TO_VALIDATE_MODEL, e);
+    }
+    TcLog.d(TAG, "Pending model file passed validation.");
+  }
+
+  /**
+   * Binds to the downloader service through {@code conn}.
+   *
+   * @return a future that completes with the service interface once connected, or fails with a
+   *     {@link ModelDownloadException} if the binding cannot be established
+   */
+  private ListenableFuture<IModelDownloaderService> connect(DownloaderServiceConnection conn) {
+    TcLog.d(TAG, "Starting a new connection to ModelDownloaderService");
+    return CallbackToFutureAdapter.getFuture(
+        completer -> {
+          // The completer must be attached before binding so connection callbacks can use it.
+          conn.attachCompleter(completer);
+          Intent intent = new Intent(context, downloaderServiceClass);
+          if (context.bindService(intent, conn, BIND_AUTO_CREATE | BIND_NOT_FOREGROUND)) {
+            return "Binding to service";
+          } else {
+            completer.setException(
+                new ModelDownloadException(
+                    ModelDownloadException.FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN,
+                    "Unable to bind to service"));
+            return "Binding failed";
+          }
+        });
+  }
+
+  // Here the returned download result future can be set by: 1) the service can invoke the callback
+  // and set the result/exception; 2) If the service crashed, the CallbackToFutureAdapter will try
+  // to fail the future when the callback is garbage collected. If somehow none of them worked, the
+  // result future will hang there until time out. (WorkManager forces a 10-min running time.)
+  private static ListenableFuture<Long> scheduleDownload(
+      IModelDownloaderService service, URI uri, File targetFile) {
+    TcLog.d(TAG, "Scheduling a new download task with ModelDownloaderService");
+    return CallbackToFutureAdapter.getFuture(
+        completer -> {
+          service.download(
+              uri.toString(),
+              targetFile.getAbsolutePath(),
+              new IModelDownloaderCallback.Stub() {
+                @Override
+                public void onSuccess(long bytesWritten) {
+                  completer.set(bytesWritten);
+                }
+
+                @Override
+                public void onFailure(int downloaderLibErrorCode, String errorMsg) {
+                  completer.setException(
+                      new ModelDownloadException(
+                          ModelDownloadException.FAILED_TO_DOWNLOAD_OTHER,
+                          downloaderLibErrorCode,
+                          errorMsg));
+                }
+              });
+          return "downlaoderService.download";
+        });
+  }
+
+  /**
+   * The implementation of {@link ServiceConnection} that handles changes in the connection.
+   *
+   * <p>Each connection event is translated into a result or an exception on the attached
+   * completer. {@link #attachCompleter} must be called before the connection is used for binding.
+   */
+  @VisibleForTesting
+  static class DownloaderServiceConnection implements ServiceConnection {
+    private static final String TAG = "ModelDownloaderImpl.DownloaderServiceConnection";
+
+    private CallbackToFutureAdapter.Completer<IModelDownloaderService> completer;
+
+    /** Sets the completer that receives the connected service or the connection failure. */
+    public void attachCompleter(
+        CallbackToFutureAdapter.Completer<IModelDownloaderService> completer) {
+      this.completer = completer;
+    }
+
+    @Override
+    public void onServiceConnected(ComponentName componentName, IBinder iBinder) {
+      TcLog.d(TAG, "DownloaderService connected");
+      completer.set(Preconditions.checkNotNull(IModelDownloaderService.Stub.asInterface(iBinder)));
+    }
+
+    @Override
+    public void onServiceDisconnected(ComponentName componentName) {
+      // If this is invoked after onServiceConnected, it will be ignored by the completer.
+      completer.setException(
+          new ModelDownloadException(
+              ModelDownloadException.FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN,
+              "Service disconnected"));
+    }
+
+    @Override
+    public void onBindingDied(ComponentName name) {
+      completer.setException(
+          new ModelDownloadException(
+              ModelDownloadException.FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN, "Binding died"));
+    }
+
+    @Override
+    public void onNullBinding(ComponentName name) {
+      completer.setException(
+          new ModelDownloadException(
+              ModelDownloadException.FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN,
+              "Unable to bind to DownloaderService"));
+    }
+  }
+}
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java
new file mode 100644
index 0000000..6d7e47e
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderService.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import android.app.Service;
+import android.content.Intent;
+import android.os.IBinder;
+import com.android.textclassifier.common.TextClassifierServiceExecutors;
+import com.android.textclassifier.common.base.TcLog;
+
+/** Android {@link Service} whose binder is an IModelDownloaderService implementation. */
+public final class ModelDownloaderService extends Service {
+  private static final String TAG = "ModelDownloaderService";
+
+  // Created once in onCreate() and handed to every client from onBind().
+  private IBinder downloaderBinder;
+
+  @Override
+  public void onCreate() {
+    super.onCreate();
+    downloaderBinder = createDownloaderBinder();
+  }
+
+  @Override
+  public IBinder onBind(Intent intent) {
+    TcLog.d(TAG, "Binding to ModelDownloadService");
+    return downloaderBinder;
+  }
+
+  /** Builds the binder on top of the downloader executor and the network IO executor. */
+  private static IBinder createDownloaderBinder() {
+    return new ModelDownloaderServiceImpl(
+        /* bgExecutorService= */ TextClassifierServiceExecutors.getDownloaderExecutor(),
+        /* transportExecutorService= */ TextClassifierServiceExecutors.getNetworkIOExecutor());
+  }
+}
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java
new file mode 100644
index 0000000..47e6f19
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderServiceImpl.java
@@ -0,0 +1,163 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static com.android.textclassifier.downloader.ModelDownloadException.DEFAULT_DOWNLOADER_LIB_ERROR_CODE;
+import static com.google.common.base.Predicates.instanceOf;
+import static com.google.common.base.Throwables.getCausalChain;
+
+import android.os.RemoteException;
+import com.android.textclassifier.common.base.TcLog;
+import com.google.android.downloader.AndroidDownloaderLogger;
+import com.google.android.downloader.ConnectivityHandler;
+import com.google.android.downloader.DownloadConstraints;
+import com.google.android.downloader.DownloadRequest;
+import com.google.android.downloader.DownloadResult;
+import com.google.android.downloader.Downloader;
+import com.google.android.downloader.PlatformUrlEngine;
+import com.google.android.downloader.RequestException;
+import com.google.android.downloader.SimpleFileDownloadDestination;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import com.google.common.util.concurrent.FutureCallback;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
+import java.io.File;
+import java.net.URI;
+import java.util.concurrent.ExecutorService;
+import javax.annotation.concurrent.ThreadSafe;
+
+/**
+ * IModelDownloaderService binder implementation backed by the Android Downloader library.
+ *
+ * <p>The outcome of each request is reported through the IModelDownloaderCallback passed to
+ * {@link #download}.
+ */
+@ThreadSafe
+final class ModelDownloaderServiceImpl extends IModelDownloaderService.Stub {
+  private static final String TAG = "ModelDownloaderServiceImpl";
+
+  // Connectivity constraints will be checked by WorkManager instead.
+  private static class NoOpConnectivityHandler implements ConnectivityHandler {
+    @Override
+    public ListenableFuture<Void> checkConnectivity(DownloadConstraints constraints) {
+      return Futures.immediateVoidFuture();
+    }
+  }
+
+  /** Cleans up files after one download finishes and forwards the outcome to the caller. */
+  private static final class DownloadOutcomeForwarder implements FutureCallback<Long> {
+    private final File targetFile;
+    private final File tempMetadataFile;
+    private final IModelDownloaderCallback callback;
+
+    DownloadOutcomeForwarder(
+        File targetFile, File tempMetadataFile, IModelDownloaderCallback callback) {
+      this.targetFile = targetFile;
+      this.tempMetadataFile = tempMetadataFile;
+      this.callback = callback;
+    }
+
+    @Override
+    public void onSuccess(Long bytesWritten) {
+      tempMetadataFile.delete();
+      dispatchOnSuccessSafely(callback, bytesWritten);
+    }
+
+    @Override
+    public void onFailure(Throwable t) {
+      TcLog.e(TAG, "onFailure", t);
+      // TODO(licha): We may be able to resume the download if we keep those files
+      targetFile.delete();
+      tempMetadataFile.delete();
+      dispatchOnFailureSafely(callback, inferDownloaderLibErrorCode(t), t);
+    }
+  }
+
+  private final ExecutorService bgExecutorService;
+  private final Downloader downloader;
+
+  /**
+   * Creates the service implementation with a freshly configured Downloader.
+   *
+   * @param bgExecutorService executor for downloader callbacks (not for network IO)
+   * @param transportExecutorService executor that performs the network transfer
+   */
+  public ModelDownloaderServiceImpl(
+      ExecutorService bgExecutorService, ListeningExecutorService transportExecutorService) {
+    this.bgExecutorService = bgExecutorService;
+    this.downloader =
+        new Downloader.Builder()
+            // This executor is for callbacks, not network IO. See discussions in cl/337156844
+            .withIOExecutor(bgExecutorService)
+            .withConnectivityHandler(new NoOpConnectivityHandler())
+            .addUrlEngine(
+                // clear text traffic won't actually work without a manifest change, so http link
+                // is still not supported on production builds.
+                // Adding "http" here only for testing purposes.
+                ImmutableList.of("https", "http"),
+                newPlatformUrlEngine(transportExecutorService))
+            .withLogger(new AndroidDownloaderLogger())
+            .build();
+  }
+
+  @VisibleForTesting
+  ModelDownloaderServiceImpl(ExecutorService bgExecutorService, Downloader downloader) {
+    this.bgExecutorService = Preconditions.checkNotNull(bgExecutorService);
+    this.downloader = Preconditions.checkNotNull(downloader);
+  }
+
+  /**
+   * Starts downloading {@code uri} into {@code targetFilePath} and reports the outcome through
+   * {@code callback}. Anything thrown while setting up the request is reported as a failure with
+   * the default downloader lib error code.
+   */
+  @Override
+  public void download(String uri, String targetFilePath, IModelDownloaderCallback callback) {
+    TcLog.d(TAG, "Download request received: " + uri);
+    try {
+      File targetFile = new File(targetFilePath);
+      File tempMetadataFile = getMetadataFile(targetFile);
+      SimpleFileDownloadDestination destination =
+          new SimpleFileDownloadDestination(targetFile, tempMetadataFile);
+      DownloadRequest request =
+          downloader.newRequestBuilder(URI.create(uri), destination).build();
+      downloader
+          .execute(request)
+          .transform(DownloadResult::bytesWritten, MoreExecutors.directExecutor())
+          .addCallback(
+              new DownloadOutcomeForwarder(targetFile, tempMetadataFile, callback),
+              bgExecutorService);
+    } catch (Throwable t) {
+      dispatchOnFailureSafely(callback, DEFAULT_DOWNLOADER_LIB_ERROR_CODE, t);
+    }
+  }
+
+  /** Returns the sibling "<name>.metadata" file used while downloading {@code targetFile}. */
+  @VisibleForTesting
+  static File getMetadataFile(File targetFile) {
+    String metadataFileName = targetFile.getName() + ".metadata";
+    return new File(targetFile.getParentFile(), metadataFileName);
+  }
+
+  /** Creates the URL engine; its executor handles network transportation and can stall for long. */
+  private static PlatformUrlEngine newPlatformUrlEngine(
+      ListeningExecutorService transportExecutorService) {
+    return new PlatformUrlEngine(
+        transportExecutorService,
+        /* connectTimeoutMs= */ 60 * 1000,
+        /* readTimeoutMs= */ 60 * 1000);
+  }
+
+  /**
+   * Tries to infer the failure reason: the HTTP status code of the first RequestException in the
+   * causal chain, or the default error code when there is none.
+   */
+  // TODO(b/181805039): Use error code once downloader lib supports it.
+  private static int inferDownloaderLibErrorCode(Throwable failure) {
+    Throwable requestFailure =
+        Iterables.find(
+            getCausalChain(failure), instanceOf(RequestException.class), /* defaultValue= */ null);
+    if (requestFailure == null) {
+      return DEFAULT_DOWNLOADER_LIB_ERROR_CODE;
+    }
+    return ((RequestException) requestFailure).getErrorDetails().getHttpStatusCode();
+  }
+
+  /** Notifies the caller of success; a RemoteException is logged instead of propagated. */
+  private static void dispatchOnSuccessSafely(
+      IModelDownloaderCallback callback, long bytesWritten) {
+    try {
+      callback.onSuccess(bytesWritten);
+    } catch (RemoteException e) {
+      TcLog.e(TAG, "Unable to notify successful download", e);
+    }
+  }
+
+  /** Notifies the caller of failure; a RemoteException is logged instead of propagated. */
+  private static void dispatchOnFailureSafely(
+      IModelDownloaderCallback callback, int downloaderLibErrorCode, Throwable throwable) {
+    try {
+      callback.onFailure(downloaderLibErrorCode, throwable.getMessage());
+    } catch (RemoteException e) {
+      TcLog.e(TAG, "Unable to notify failures in download", e);
+    }
+  }
+}
diff --git a/java/src/com/android/textclassifier/downloader/TextClassifierDownloadLogger.java b/java/src/com/android/textclassifier/downloader/TextClassifierDownloadLogger.java
new file mode 100644
index 0000000..7416b00
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/TextClassifierDownloadLogger.java
@@ -0,0 +1,257 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static com.android.textclassifier.downloader.ModelDownloadException.DEFAULT_DOWNLOADER_LIB_ERROR_CODE;
+
+import android.text.TextUtils;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.common.statsd.TextClassifierStatsLog;
+import com.android.textclassifier.downloader.ModelDownloadException.ErrorCode;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableMap;
+import java.util.Locale;
+
+/**
+ * Logs TextClassifier download events.
+ *
+ * <p>Each method writes one atom through {@link TextClassifierStatsLog} and, when
+ * {@code TcLog.ENABLE_FULL_LOGGING} is set, also prints the same values as a verbose log line.
+ */
+final class TextClassifierDownloadLogger {
+ private static final String TAG = "TextClassifierDownloadLogger";
+
+ // Values for TextClassifierDownloadReported.download_status
+ private static final int DOWNLOAD_STATUS_SUCCEEDED =
+ TextClassifierStatsLog.TEXT_CLASSIFIER_DOWNLOAD_REPORTED__DOWNLOAD_STATUS__SUCCEEDED;
+ private static final int DOWNLOAD_STATUS_FAILED_AND_RETRY =
+ TextClassifierStatsLog.TEXT_CLASSIFIER_DOWNLOAD_REPORTED__DOWNLOAD_STATUS__FAILED_AND_RETRY;
+
+ // Maps a ModelType string to the atom's model_type value; unmapped types use the default.
+ private static final int DEFAULT_MODEL_TYPE =
+ TextClassifierStatsLog.TEXT_CLASSIFIER_DOWNLOAD_REPORTED__MODEL_TYPE__UNKNOWN_MODEL_TYPE;
+ private static final ImmutableMap<String, Integer> MODEL_TYPE_MAP =
+ ImmutableMap.of(
+ ModelType.ANNOTATOR,
+ TextClassifierStatsLog.TEXT_CLASSIFIER_DOWNLOAD_REPORTED__MODEL_TYPE__ANNOTATOR,
+ ModelType.LANG_ID,
+ TextClassifierStatsLog.TEXT_CLASSIFIER_DOWNLOAD_REPORTED__MODEL_TYPE__LANG_ID,
+ ModelType.ACTIONS_SUGGESTIONS,
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_DOWNLOAD_REPORTED__MODEL_TYPE__ACTIONS_SUGGESTIONS);
+
+ // The file type is always reported as unknown by this logger.
+ private static final int DEFAULT_FILE_TYPE =
+ TextClassifierStatsLog.TEXT_CLASSIFIER_DOWNLOAD_REPORTED__FILE_TYPE__UNKNOWN_FILE_TYPE;
+
+ // Maps a ModelDownloadException error code to the atom's failure_reason value; unmapped codes
+ // use the default.
+ private static final int DEFAULT_FAILURE_REASON =
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_DOWNLOAD_REPORTED__FAILURE_REASON__UNKNOWN_FAILURE_REASON;
+ private static final ImmutableMap<Integer, Integer> FAILURE_REASON_MAP =
+ ImmutableMap.<Integer, Integer>builder()
+ .put(
+ ModelDownloadException.UNKNOWN_FAILURE_REASON,
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_DOWNLOAD_REPORTED__FAILURE_REASON__UNKNOWN_FAILURE_REASON)
+ .put(
+ ModelDownloadException.FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN,
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_DOWNLOAD_REPORTED__FAILURE_REASON__FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN)
+ .put(
+ ModelDownloadException.FAILED_TO_DOWNLOAD_404_ERROR,
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_DOWNLOAD_REPORTED__FAILURE_REASON__FAILED_TO_DOWNLOAD_404_ERROR)
+ .put(
+ ModelDownloadException.FAILED_TO_DOWNLOAD_OTHER,
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_DOWNLOAD_REPORTED__FAILURE_REASON__FAILED_TO_DOWNLOAD_OTHER)
+ .put(
+ ModelDownloadException.DOWNLOADED_FILE_MISSING,
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_DOWNLOAD_REPORTED__FAILURE_REASON__DOWNLOADED_FILE_MISSING)
+ .put(
+ ModelDownloadException.FAILED_TO_PARSE_MANIFEST,
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_DOWNLOAD_REPORTED__FAILURE_REASON__FAILED_TO_PARSE_MANIFEST)
+ .put(
+ ModelDownloadException.FAILED_TO_VALIDATE_MODEL,
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_DOWNLOAD_REPORTED__FAILURE_REASON__FAILED_TO_VALIDATE_MODEL)
+ .build();
+
+ // Reasons to schedule
+ public static final int REASON_TO_SCHEDULE_TCS_STARTED =
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_DOWNLOAD_WORK_SCHEDULED__REASON_TO_SCHEDULE__TCS_STARTED;
+ public static final int REASON_TO_SCHEDULE_LOCALE_SETTINGS_CHANGED =
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_DOWNLOAD_WORK_SCHEDULED__REASON_TO_SCHEDULE__LOCALE_SETTINGS_CHANGED;
+ public static final int REASON_TO_SCHEDULE_DEVICE_CONFIG_UPDATED =
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_DOWNLOAD_WORK_SCHEDULED__REASON_TO_SCHEDULE__DEVICE_CONFIG_UPDATED;
+
+ // Work results
+ public static final int WORK_RESULT_UNKNOWN_WORK_RESULT =
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_DOWNLOAD_WORK_COMPLETED__WORK_RESULT__UNKNOWN_WORK_RESULT;
+ public static final int WORK_RESULT_SUCCESS_MODEL_DOWNLOADED =
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_DOWNLOAD_WORK_COMPLETED__WORK_RESULT__SUCCESS_MODEL_DOWNLOADED;
+ public static final int WORK_RESULT_SUCCESS_NO_UPDATE_AVAILABLE =
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_DOWNLOAD_WORK_COMPLETED__WORK_RESULT__SUCCESS_NO_UPDATE_AVAILABLE;
+ public static final int WORK_RESULT_FAILURE_MODEL_DOWNLOADER_DISABLED =
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_DOWNLOAD_WORK_COMPLETED__WORK_RESULT__FAILURE_MODEL_DOWNLOADER_DISABLED;
+ public static final int WORK_RESULT_FAILURE_MAX_RUN_ATTEMPT_REACHED =
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_DOWNLOAD_WORK_COMPLETED__WORK_RESULT__FAILURE_MAX_RUN_ATTEMPT_REACHED;
+ public static final int WORK_RESULT_RETRY_MODEL_DOWNLOAD_FAILED =
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_DOWNLOAD_WORK_COMPLETED__WORK_RESULT__RETRY_MODEL_DOWNLOAD_FAILED;
+ public static final int WORK_RESULT_RETRY_RUNTIME_EXCEPTION =
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_DOWNLOAD_WORK_COMPLETED__WORK_RESULT__RETRY_RUNTIME_EXCEPTION;
+ public static final int WORK_RESULT_RETRY_STOPPED_BY_OS =
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_DOWNLOAD_WORK_COMPLETED__WORK_RESULT__RETRY_STOPPED_BY_OS;
+
+ /**
+ * Logs a succeeded download task.
+ *
+ * <p>The failure reason and the downloader lib error code are reported with their default values.
+ *
+ * @throws IllegalArgumentException if {@code url} is null or empty
+ */
+ public static void downloadSucceeded(
+ long workId,
+ @ModelTypeDef String modelType,
+ String url,
+ int runAttemptCount,
+ long downloadDurationMillis) {
+ Preconditions.checkArgument(!TextUtils.isEmpty(url), "url cannot be null/empty");
+ // Arguments are positional; keep them in the same order as the verbose log below.
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.TEXT_CLASSIFIER_DOWNLOAD_REPORTED,
+ MODEL_TYPE_MAP.getOrDefault(modelType, DEFAULT_MODEL_TYPE),
+ DEFAULT_FILE_TYPE,
+ DOWNLOAD_STATUS_SUCCEEDED,
+ url,
+ DEFAULT_FAILURE_REASON,
+ runAttemptCount,
+ DEFAULT_DOWNLOADER_LIB_ERROR_CODE,
+ downloadDurationMillis,
+ workId);
+ if (TcLog.ENABLE_FULL_LOGGING) {
+ TcLog.v(
+ TAG,
+ String.format(
+ Locale.US,
+ "Download Reported: modelType=%s, fileType=%d, status=%d, url=%s, "
+ + "failureReason=%d, runAttemptCount=%d, downloaderLibErrorCode=%d, "
+ + "downloadDurationMillis=%d, workId=%d",
+ MODEL_TYPE_MAP.getOrDefault(modelType, DEFAULT_MODEL_TYPE),
+ DEFAULT_FILE_TYPE,
+ DOWNLOAD_STATUS_SUCCEEDED,
+ url,
+ DEFAULT_FAILURE_REASON,
+ runAttemptCount,
+ DEFAULT_DOWNLOADER_LIB_ERROR_CODE,
+ downloadDurationMillis,
+ workId));
+ }
+ }
+
+ /**
+ * Logs a failed download task which will be retried later.
+ *
+ * <p>{@code errorCode} is translated through FAILURE_REASON_MAP before being reported.
+ *
+ * @throws IllegalArgumentException if {@code url} is null or empty
+ */
+ public static void downloadFailed(
+ long workId,
+ @ModelTypeDef String modelType,
+ String url,
+ @ErrorCode int errorCode,
+ int runAttemptCount,
+ int downloaderLibErrorCode,
+ long downloadDurationMillis) {
+ Preconditions.checkArgument(!TextUtils.isEmpty(url), "url cannot be null/empty");
+ // Arguments are positional; keep them in the same order as the verbose log below.
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.TEXT_CLASSIFIER_DOWNLOAD_REPORTED,
+ MODEL_TYPE_MAP.getOrDefault(modelType, DEFAULT_MODEL_TYPE),
+ DEFAULT_FILE_TYPE,
+ DOWNLOAD_STATUS_FAILED_AND_RETRY,
+ url,
+ FAILURE_REASON_MAP.getOrDefault(errorCode, DEFAULT_FAILURE_REASON),
+ runAttemptCount,
+ downloaderLibErrorCode,
+ downloadDurationMillis,
+ workId);
+ if (TcLog.ENABLE_FULL_LOGGING) {
+ TcLog.v(
+ TAG,
+ String.format(
+ Locale.US,
+ "Download Reported: modelType=%s, fileType=%d, status=%d, url=%s, "
+ + "failureReason=%d, runAttemptCount=%d, downloaderLibErrorCode=%d, "
+ + "downloadDurationMillis=%d, workId=%d",
+ MODEL_TYPE_MAP.getOrDefault(modelType, DEFAULT_MODEL_TYPE),
+ DEFAULT_FILE_TYPE,
+ DOWNLOAD_STATUS_FAILED_AND_RETRY,
+ url,
+ FAILURE_REASON_MAP.getOrDefault(errorCode, DEFAULT_FAILURE_REASON),
+ runAttemptCount,
+ downloaderLibErrorCode,
+ downloadDurationMillis,
+ workId));
+ }
+ }
+
+ /**
+ * Logs that a download work item was scheduled, or that scheduling it failed.
+ *
+ * @param workId id of the work item
+ * @param reasonToSchedule presumably one of the REASON_TO_SCHEDULE_* constants; passed through
+ * to the atom unchanged
+ * @param failedToSchedule value reported in the atom's failed-to-schedule field
+ */
+ public static void downloadWorkScheduled(
+ long workId, int reasonToSchedule, boolean failedToSchedule) {
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.TEXT_CLASSIFIER_DOWNLOAD_WORK_SCHEDULED,
+ workId,
+ reasonToSchedule,
+ failedToSchedule);
+ if (TcLog.ENABLE_FULL_LOGGING) {
+ TcLog.v(
+ TAG,
+ String.format(
+ Locale.US,
+ "Download Work Scheduled: workId=%d, reasonToSchedule=%d, failedToSchedule=%b",
+ workId,
+ reasonToSchedule,
+ failedToSchedule));
+ }
+ }
+
+ /**
+ * Logs the completion of a download work item.
+ *
+ * @param workId id of the work item
+ * @param workResult presumably one of the WORK_RESULT_* constants; passed through to the atom
+ * unchanged
+ * @param runAttemptCount run attempt count reported for this work item
+ * @param workScheduledToStartedDurationMillis reported scheduled-to-started duration
+ * @param workStartedToEndedDurationMillis reported started-to-ended duration
+ */
+ public static void downloadWorkCompleted(
+ long workId,
+ int workResult,
+ int runAttemptCount,
+ long workScheduledToStartedDurationMillis,
+ long workStartedToEndedDurationMillis) {
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.TEXT_CLASSIFIER_DOWNLOAD_WORK_COMPLETED,
+ workId,
+ workResult,
+ runAttemptCount,
+ workScheduledToStartedDurationMillis,
+ workStartedToEndedDurationMillis);
+ if (TcLog.ENABLE_FULL_LOGGING) {
+ TcLog.v(
+ TAG,
+ String.format(
+ Locale.US,
+ "Download Work Completed: workId=%d, result=%d, runAttemptCount=%d, "
+ + "workScheduledToStartedDurationMillis=%d, "
+ + "workStartedToEndedDurationMillis=%d",
+ workId,
+ workResult,
+ runAttemptCount,
+ workScheduledToStartedDurationMillis,
+ workStartedToEndedDurationMillis));
+ }
+ }
+
+ // Static utility class; not instantiable.
+ private TextClassifierDownloadLogger() {}
+}
diff --git a/java/tests/instrumentation/Android.bp b/java/tests/instrumentation/Android.bp
index 74261c1..775f9f9 100644
--- a/java/tests/instrumentation/Android.bp
+++ b/java/tests/instrumentation/Android.bp
@@ -23,6 +23,22 @@
default_applicable_licenses: ["external_libtextclassifier_license"],
}
+// Test helpers from src/com/android/textclassifier/testing, built as a library so that both
+// TextClassifierServiceTest and TCSModelDownloaderIntegrationTest can depend on them.
+java_library {
+ name: "TextClassifierServiceTestingLib",
+
+ srcs: [
+ "src/com/android/textclassifier/testing/*.java",
+ ],
+
+ static_libs: [
+ "androidx.test.ext.junit",
+ "androidx.test.rules",
+ "TextClassifierServiceLib",
+ "androidx.test.espresso.core",
+ "mockito-target-minus-junit4",
+ ],
+}
+
android_test {
name: "TextClassifierServiceTest",
@@ -32,9 +48,14 @@
"src/**/*.java",
],
+ exclude_srcs: [
+ "src/**/ModelDownloaderIntegrationTest.java",
+ "src/com/android/textclassifier/testing/*.java",
+ ],
+
+
static_libs: [
"androidx.test.ext.junit",
- "androidx.test.rules",
"androidx.test.espresso.core",
"androidx.test.ext.truth",
"mockito-target-minus-junit4",
@@ -45,7 +66,9 @@
"TextClassifierServiceLib",
"statsdprotolite",
"textclassifierprotoslite",
- "TextClassifierCoverageLib"
+ "TextClassifierCoverageLib",
+ "androidx.work_work-testing",
+ "TextClassifierServiceTestingLib",
],
jni_libs: [
@@ -66,4 +89,37 @@
instrumentation_for: "TextClassifierService",
data: ["testdata/*"],
+
+ test_config: "AndroidTest.xml",
+}
+
+// Instrumentation test built only from ModelDownloaderIntegrationTest.java, with its own
+// manifest and test config. That source file is excluded from TextClassifierServiceTest.
+android_test {
+ name: "TCSModelDownloaderIntegrationTest",
+
+ manifest: "AndroidManifest_TCSModelDownloaderIntegrationTest.xml",
+
+ srcs: [
+ "src/**/ModelDownloaderIntegrationTest.java",
+ ],
+
+ static_libs: [
+ "androidx.test.ext.junit",
+ "androidx.test.espresso.core",
+ "androidx.test.ext.truth",
+ "ub-uiautomator",
+ "TextClassifierServiceTestingLib",
+ ],
+
+ jni_libs: [
+ "libtextclassifier",
+ ],
+
+ test_suites: [
+ "general-tests"
+ ],
+
+ min_sdk_version: "30",
+ sdk_version: "system_current",
+
+ test_config: "AndroidTest_TCSModelDownloaderIntegrationTest.xml",
}
diff --git a/java/tests/instrumentation/AndroidManifest.xml b/java/tests/instrumentation/AndroidManifest.xml
index 3ee30da..b370cf7 100644
--- a/java/tests/instrumentation/AndroidManifest.xml
+++ b/java/tests/instrumentation/AndroidManifest.xml
@@ -8,6 +8,10 @@
<application>
<uses-library android:name="android.test.runner"/>
+ <service
+ android:exported="false"
+ android:name="com.android.textclassifier.downloader.TestModelDownloaderService">
+ </service>
</application>
<instrumentation
diff --git a/java/tests/instrumentation/AndroidManifest_TCSModelDownloaderIntegrationTest.xml b/java/tests/instrumentation/AndroidManifest_TCSModelDownloaderIntegrationTest.xml
new file mode 100644
index 0000000..ff6ab85
--- /dev/null
+++ b/java/tests/instrumentation/AndroidManifest_TCSModelDownloaderIntegrationTest.xml
@@ -0,0 +1,14 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="com.android.textclassifier.downloader.tests">
+
+ <uses-sdk android:minSdkVersion="30" android:targetSdkVersion="30"/>
+
+ <application>
+ <uses-library android:name="android.test.runner"/>
+ </application>
+
+ <instrumentation
+ android:name="androidx.test.runner.AndroidJUnitRunner"
+ android:targetPackage="com.android.textclassifier.downloader.tests"/>
+</manifest>
diff --git a/java/tests/instrumentation/AndroidTest_TCSModelDownloaderIntegrationTest.xml b/java/tests/instrumentation/AndroidTest_TCSModelDownloaderIntegrationTest.xml
new file mode 100644
index 0000000..424b0f5
--- /dev/null
+++ b/java/tests/instrumentation/AndroidTest_TCSModelDownloaderIntegrationTest.xml
@@ -0,0 +1,28 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!-- Copyright (C) 2020 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<configuration description="Runs TCSModelDownloaderIntegrationTest.">
+ <option name="test-suite-tag" value="apct" />
+ <option name="test-suite-tag" value="apct-instrumentation" />
+ <target_preparer class="com.android.tradefed.targetprep.suite.SuiteApkInstaller">
+ <option name="cleanup-apks" value="true" />
+ <option name="test-file-name" value="TCSModelDownloaderIntegrationTest.apk" />
+ </target_preparer>
+
+ <test class="com.android.tradefed.testtype.AndroidJUnitTest" >
+ <option name="package" value="com.android.textclassifier.downloader.tests" />
+ <option name="runner" value="androidx.test.runner.AndroidJUnitRunner" />
+ </test>
+</configuration>
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
index 746931b..ddab8bd 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
@@ -17,17 +17,17 @@
package com.android.textclassifier;
import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
import android.content.Context;
-import android.os.Binder;
import android.os.CancellationSignal;
-import android.os.Parcel;
import android.service.textclassifier.TextClassifierService;
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.ConversationActions;
import android.view.textclassifier.TextClassification;
-import android.view.textclassifier.TextClassificationSessionId;
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextLanguage;
import android.view.textclassifier.TextLinks;
@@ -42,29 +42,36 @@
import com.android.os.AtomsProto.TextClassifierApiUsageReported;
import com.android.os.AtomsProto.TextClassifierApiUsageReported.ApiType;
import com.android.os.AtomsProto.TextClassifierApiUsageReported.ResultType;
-import com.android.textclassifier.common.ModelFileManager;
+import com.android.textclassifier.common.ModelType;
import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.common.statsd.StatsdTestUtils;
import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger;
+import com.android.textclassifier.downloader.ModelDownloadManager;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
+import java.io.IOException;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.stream.Collectors;
import org.junit.After;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
-import org.mockito.MockitoAnnotations;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
@SmallTest
@RunWith(AndroidJUnit4.class)
public class DefaultTextClassifierServiceTest {
+
+ @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+
/** A statsd config ID, which is arbitrary. */
private static final long CONFIG_ID = 689777;
@@ -79,14 +86,21 @@
@Mock private TextClassifierService.Callback<TextLinks> textLinksCallback;
@Mock private TextClassifierService.Callback<ConversationActions> conversationActionsCallback;
@Mock private TextClassifierService.Callback<TextLanguage> textLanguageCallback;
+ @Mock private ModelFileManager testModelFileManager;
@Before
- public void setup() {
- MockitoAnnotations.initMocks(this);
-
- testInjector = new TestInjector(ApplicationProvider.getApplicationContext());
+ public void setup() throws IOException {
+ testInjector =
+ new TestInjector(ApplicationProvider.getApplicationContext(), testModelFileManager);
defaultTextClassifierService = new DefaultTextClassifierService(testInjector);
defaultTextClassifierService.onCreate();
+
+ when(testModelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped());
+ when(testModelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any()))
+ .thenReturn(TestDataUtils.getLangIdModelFileWrapped());
+ when(testModelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any()))
+ .thenReturn(TestDataUtils.getTestActionsModelFileWrapped());
}
@Before
@@ -113,7 +127,7 @@
new TextClassification.Request.Builder(text, 0, text.length()).build();
defaultTextClassifierService.onClassifyText(
- createTextClassificationSessionId(),
+ TestingUtils.createTextClassificationSessionId(SESSION_ID),
request,
new CancellationSignal(),
textClassificationCallback);
@@ -135,7 +149,7 @@
TextSelection.Request request = new TextSelection.Request.Builder(text, start, end).build();
defaultTextClassifierService.onSuggestSelection(
- createTextClassificationSessionId(),
+ TestingUtils.createTextClassificationSessionId(SESSION_ID),
request,
new CancellationSignal(),
textSelectionCallback);
@@ -153,7 +167,10 @@
TextLinks.Request request = new TextLinks.Request.Builder(text).build();
defaultTextClassifierService.onGenerateLinks(
- createTextClassificationSessionId(), request, new CancellationSignal(), textLinksCallback);
+ TestingUtils.createTextClassificationSessionId(SESSION_ID),
+ request,
+ new CancellationSignal(),
+ textLinksCallback);
ArgumentCaptor<TextLinks> captor = ArgumentCaptor.forClass(TextLinks.class);
verify(textLinksCallback).onSuccess(captor.capture());
@@ -170,7 +187,7 @@
TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
defaultTextClassifierService.onDetectLanguage(
- createTextClassificationSessionId(),
+ TestingUtils.createTextClassificationSessionId(SESSION_ID),
request,
new CancellationSignal(),
textLanguageCallback);
@@ -192,7 +209,7 @@
new ConversationActions.Request.Builder(ImmutableList.of(message)).build();
defaultTextClassifierService.onSuggestConversationActions(
- createTextClassificationSessionId(),
+ TestingUtils.createTextClassificationSessionId(SESSION_ID),
request,
new CancellationSignal(),
conversationActionsCallback);
@@ -207,13 +224,13 @@
@Test
public void missingModelFile_onFailureShouldBeCalled() throws Exception {
- testInjector.setModelFileManager(
- new ModelFileManager(ApplicationProvider.getApplicationContext(), ImmutableList.of()));
+ when(testModelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(null);
defaultTextClassifierService.onCreate();
TextClassification.Request request = new TextClassification.Request.Builder("hi", 0, 2).build();
defaultTextClassifierService.onClassifyText(
- createTextClassificationSessionId(),
+ TestingUtils.createTextClassificationSessionId(SESSION_ID),
request,
new CancellationSignal(),
textClassificationCallback);
@@ -240,25 +257,13 @@
assertThat(loggedEvent.getSessionId()).isEqualTo(SESSION_ID);
}
- private static TextClassificationSessionId createTextClassificationSessionId() {
- // Used a hack to create TextClassificationSessionId because its constructor is @hide.
- Parcel parcel = Parcel.obtain();
- parcel.writeString(SESSION_ID);
- parcel.writeStrongBinder(new Binder());
- parcel.setDataPosition(0);
- return TextClassificationSessionId.CREATOR.createFromParcel(parcel);
- }
-
private static final class TestInjector implements DefaultTextClassifierService.Injector {
private final Context context;
private ModelFileManager modelFileManager;
- private TestInjector(Context context) {
+ private TestInjector(Context context, ModelFileManager modelFileManager) {
this.context = Preconditions.checkNotNull(context);
- }
-
- private void setModelFileManager(ModelFileManager modelFileManager) {
- this.modelFileManager = modelFileManager;
+ this.modelFileManager = Preconditions.checkNotNull(modelFileManager);
}
@Override
@@ -267,10 +272,8 @@
}
@Override
- public ModelFileManager createModelFileManager(TextClassifierSettings settings) {
- if (modelFileManager == null) {
- return TestDataUtils.createModelFileManagerForTesting(context);
- }
+ public ModelFileManager createModelFileManager(
+ TextClassifierSettings settings, ModelDownloadManager modelDownloadManager) {
return modelFileManager;
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java
new file mode 100644
index 0000000..0e40515
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java
@@ -0,0 +1,434 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.android.textclassifier.common.ModelFile.LANGUAGE_INDEPENDENT;
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.when;
+
+import android.content.Context;
+import android.os.LocaleList;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import androidx.work.WorkManager;
+import com.android.textclassifier.ModelFileManagerImpl.DownloaderModelsLister;
+import com.android.textclassifier.ModelFileManagerImpl.RegularFileFullMatchLister;
+import com.android.textclassifier.ModelFileManagerImpl.RegularFilePatternMatchLister;
+import com.android.textclassifier.common.ModelFile;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.android.textclassifier.downloader.DownloadedModelManager;
+import com.android.textclassifier.downloader.ModelDownloadManager;
+import com.android.textclassifier.downloader.ModelDownloadWorker;
+import com.android.textclassifier.testing.SetDefaultLocalesRule;
+import com.android.textclassifier.testing.TestingDeviceConfig;
+import com.google.common.collect.ImmutableList;
+import com.google.common.io.Files;
+import com.google.common.util.concurrent.MoreExecutors;
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Locale;
+import java.util.stream.Collectors;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public final class ModelFileManagerImplTest {
+ private static final Locale DEFAULT_LOCALE = Locale.forLanguageTag("en-US");
+
+ @ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR;
+
+ private TestingDeviceConfig deviceConfig;
+
+ @Mock private DownloadedModelManager downloadedModelManager;
+
+ @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
+ @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+
+ private File rootTestDir;
+ private ModelFileManagerImpl modelFileManager;
+ private ModelDownloadManager modelDownloadManager;
+ private TextClassifierSettings settings;
+
+  /** Builds the manager under test on top of the mocked DownloadedModelManager. */
+  @Before
+  public void setup() {
+    Context context = ApplicationProvider.getApplicationContext();
+    rootTestDir = new File(context.getCacheDir(), "rootTestDir");
+    rootTestDir.mkdirs();
+    deviceConfig = new TestingDeviceConfig();
+    settings = new TextClassifierSettings(deviceConfig);
+    modelDownloadManager =
+        new ModelDownloadManager(
+            context,
+            ModelDownloadWorker.class,
+            () -> WorkManager.getInstance(context),
+            downloadedModelManager,
+            settings,
+            MoreExecutors.newDirectExecutorService());
+    modelFileManager = new ModelFileManagerImpl(context, modelDownloadManager, settings);
+    setDefaultLocalesRule.set(new LocaleList(DEFAULT_LOCALE));
+  }
+
+ @After
+ public void removeTestDir() {
+ recursiveDelete(rootTestDir);
+ }
+
+ @Test
+ public void annotatorModelPreloaded() {
+ verifyModelPreloadedAsAsset(ModelType.ANNOTATOR, "textclassifier/annotator.universal.model");
+ }
+
+ @Test
+ public void actionsModelPreloaded() {
+ verifyModelPreloadedAsAsset(
+ ModelType.ACTIONS_SUGGESTIONS, "textclassifier/actions_suggestions.universal.model");
+ }
+
+ @Test
+ public void langIdModelPreloaded() {
+ verifyModelPreloadedAsAsset(ModelType.LANG_ID, "textclassifier/lang_id.model");
+ }
+
+ private void verifyModelPreloadedAsAsset(
+ @ModelTypeDef String modelType, String expectedModelPath) {
+ List<ModelFile> modelFiles = modelFileManager.listModelFiles(modelType);
+ List<ModelFile> assetFiles =
+ modelFiles.stream().filter(modelFile -> modelFile.isAsset).collect(Collectors.toList());
+
+ assertThat(assetFiles).hasSize(1);
+ assertThat(assetFiles.get(0).absolutePath).isEqualTo(expectedModelPath);
+ }
+
+ @Test
+ public void findBestModel_versionCode() {
+ ModelFile olderModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+ ModelFile newerModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version */ 2);
+ ModelFileManager modelFileManager = createModelFileManager(olderModelFile, newerModelFile);
+
+ ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(
+ MODEL_TYPE, /* localePreferences= */ null, /*detectedLocales=*/ null);
+ assertThat(bestModelFile).isEqualTo(newerModelFile);
+ }
+
+ @Test
+ public void findBestModel_languageDependentModelIsPreferred() {
+ ModelFile languageIndependentModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+ ModelFile languageDependentModelFile =
+ createModelFile(DEFAULT_LOCALE.toLanguageTag(), /* version */ 1);
+ ModelFileManager modelFileManager =
+ createModelFileManager(languageIndependentModelFile, languageDependentModelFile);
+
+ ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(
+ MODEL_TYPE, new LocaleList(DEFAULT_LOCALE), /*detectedLocales=*/ null);
+ assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_noMatchedLanguageModel() {
+ ModelFile languageIndependentModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+ ModelFile languageDependentModelFile = createModelFile("zh-hk", /* version */ 1);
+ ModelFileManager modelFileManager =
+ createModelFileManager(languageIndependentModelFile, languageDependentModelFile);
+
+ ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(
+ MODEL_TYPE, new LocaleList(DEFAULT_LOCALE), /*detectedLocales=*/ null);
+ assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_languageIsMoreImportantThanVersion() {
+ ModelFile matchButOlderModel = createModelFile(DEFAULT_LOCALE.toLanguageTag(), /* version */ 1);
+ ModelFile mismatchButNewerModel = createModelFile("zh-hk", /* version */ 2);
+ ModelFileManager modelFileManager =
+ createModelFileManager(matchButOlderModel, mismatchButNewerModel);
+
+ ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(
+ MODEL_TYPE, new LocaleList(DEFAULT_LOCALE), /*detectedLocales=*/ null);
+ assertThat(bestModelFile).isEqualTo(matchButOlderModel);
+ }
+
+ @Test
+ public void findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_onlyCheckLanguage() {
+ setDefaultLocalesRule.set(LocaleList.forLanguageTags("zh"));
+ ModelFile languageIndependentModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+ ModelFile languageDependentModelFile = createModelFile("zh", /* version */ 1);
+ ModelFileManager modelFileManager =
+ createModelFileManager(languageIndependentModelFile, languageDependentModelFile);
+
+ ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(
+ MODEL_TYPE, LocaleList.forLanguageTags("zh-hk"), /*detectedLocales=*/ null);
+ assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_match() {
+ setDefaultLocalesRule.set(LocaleList.forLanguageTags("zh-hk"));
+ ModelFile languageIndependentModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+ ModelFile languageDependentModelFile = createModelFile("zh", /* version */ 1);
+ ModelFileManager modelFileManager =
+ createModelFileManager(languageIndependentModelFile, languageDependentModelFile);
+
+ ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(
+ MODEL_TYPE, LocaleList.forLanguageTags("zh"), /*detectedLocales=*/ null);
+ assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_doNotMatch() {
+ setDefaultLocalesRule.set(LocaleList.forLanguageTags("en"));
+ ModelFile languageIndependentModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+ ModelFile languageDependentModelFile = createModelFile("zh", /* version */ 1);
+ ModelFileManager modelFileManager =
+ createModelFileManager(languageIndependentModelFile, languageDependentModelFile);
+
+ ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(
+ MODEL_TYPE, LocaleList.forLanguageTags("zh"), /*detectedLocales=*/ null);
+ assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_onlyPrimaryLocaleConsidered_noLocalePreferencesProvided() {
+ setDefaultLocalesRule.set(
+ new LocaleList(Locale.forLanguageTag("en"), Locale.forLanguageTag("zh-hk")));
+ ModelFile languageIndependentModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+ ModelFile nonPrimaryLocaleModelFile = createModelFile("zh-hk", /* version */ 1);
+ ModelFileManager modelFileManager =
+ createModelFileManager(languageIndependentModelFile, nonPrimaryLocaleModelFile);
+
+ ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(
+ MODEL_TYPE, /* localePreferences= */ null, /* detectedLocales= */ null);
+ assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_onlyPrimaryLocaleConsidered_localePreferencesProvided() {
+ setDefaultLocalesRule.set(
+ new LocaleList(Locale.forLanguageTag("en"), Locale.forLanguageTag("zh-hk")));
+
+ ModelFile languageIndependentModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+ ModelFile nonPrimaryLocalePreferenceModelFile = createModelFile("zh-hk", /* version */ 1);
+ ModelFileManager modelFileManager =
+ createModelFileManager(languageIndependentModelFile, nonPrimaryLocalePreferenceModelFile);
+
+ ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(
+ MODEL_TYPE,
+ new LocaleList(Locale.forLanguageTag("en"), Locale.forLanguageTag("zh-hk")),
+ /*detectedLocales=*/ null);
+ assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
+ }
+
+  @Test
+  public void findBestModel_multiLanguageEnabled_noMatchedModel() {
+    setDefaultLocalesRule.set(LocaleList.forLanguageTags("en"));
+    deviceConfig.setConfig(TextClassifierSettings.MULTI_LANGUAGE_SUPPORT_ENABLED, true);
+
+    // Neither the requested (ja, fy) nor the detected (hr) locales have a model of their own.
+    LocaleList requestedLocales =
+        new LocaleList(Locale.forLanguageTag("ja"), Locale.forLanguageTag("fy"));
+    LocaleList detectedLocales = LocaleList.forLanguageTags("hr");
+    ModelFile primaryLocaleModelFile = createModelFile("en", /* version= */ 1);
+    ModelFile secondaryLocaleModelFile = createModelFile("zh-hk", /* version= */ 1);
+    ModelFileManager manager =
+        createModelFileManager(primaryLocaleModelFile, secondaryLocaleModelFile);
+
+    ModelFile bestModelFile =
+        manager.findBestModelFile(MODEL_TYPE, requestedLocales, detectedLocales);
+    // The test expects the model for the primary default locale ("en") to be chosen.
+    assertThat(bestModelFile).isEqualTo(primaryLocaleModelFile);
+  }
+
+ @Test
+ public void findBestModel_multiLanguageEnabled_matchDetected() {
+ setDefaultLocalesRule.set(
+ new LocaleList(Locale.forLanguageTag("en-GB"), Locale.forLanguageTag("zh-hk")));
+ deviceConfig.setConfig(TextClassifierSettings.MULTI_LANGUAGE_SUPPORT_ENABLED, true);
+
+ ModelFile localePreferenceModelFile = createModelFile("zh", /*version*/ 1);
+ ModelFileManager modelFileManager = createModelFileManager(localePreferenceModelFile);
+ final LocaleList requestLocalePreferences =
+ new LocaleList(Locale.forLanguageTag("ja"), Locale.forLanguageTag("zh"));
+ final LocaleList detectedLocalePreferences = LocaleList.forLanguageTags("zh");
+
+ ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(
+ MODEL_TYPE, requestLocalePreferences, detectedLocalePreferences);
+ assertThat(bestModelFile).isEqualTo(localePreferenceModelFile);
+ }
+
+  @Test
+  public void findBestModel_multiLanguageDisabled_matchDetected() {
+    setDefaultLocalesRule.set(
+        new LocaleList(Locale.forLanguageTag("en-GB"), Locale.forLanguageTag("zh-hk")));
+    deviceConfig.setConfig(TextClassifierSettings.MULTI_LANGUAGE_SUPPORT_ENABLED, false);
+
+    ModelFile nonLocalePreferenceModelFile = createModelFile("zh", /* version= */ 1);
+    ModelFileManager modelFileManager = createModelFileManager(nonLocalePreferenceModelFile);
+    final LocaleList requestLocalePreferences = new LocaleList(Locale.forLanguageTag("en"));
+    final LocaleList detectedLocalePreferences = LocaleList.getEmptyLocaleList();
+    // The only candidate supports "zh" while the request prefers "en": no model is expected.
+    ModelFile bestModelFile =
+        modelFileManager.findBestModelFile(
+            MODEL_TYPE, requestLocalePreferences, detectedLocalePreferences);
+    assertThat(bestModelFile).isNull();
+  }
+
+ @Test
+ public void downloaderModelsLister() throws IOException {
+ File annotatorFile = new File(rootTestDir, "annotator.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), annotatorFile);
+ File langIdFile = new File(rootTestDir, "langId.model");
+ Files.copy(TestDataUtils.getLangIdModelFile(), langIdFile);
+
+ deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
+
+ DownloaderModelsLister downloaderModelsLister =
+ new DownloaderModelsLister(modelDownloadManager, settings);
+
+ when(downloadedModelManager.listModels(MODEL_TYPE)).thenReturn(Arrays.asList(annotatorFile));
+ when(downloadedModelManager.listModels(ModelType.LANG_ID))
+ .thenReturn(Arrays.asList(langIdFile));
+ when(downloadedModelManager.listModels(ModelType.ACTIONS_SUGGESTIONS))
+ .thenReturn(new ArrayList<>());
+ assertThat(downloaderModelsLister.list(MODEL_TYPE))
+ .containsExactly(ModelFile.createFromRegularFile(annotatorFile, MODEL_TYPE));
+ assertThat(downloaderModelsLister.list(ModelType.LANG_ID))
+ .containsExactly(ModelFile.createFromRegularFile(langIdFile, ModelType.LANG_ID));
+ assertThat(downloaderModelsLister.list(ModelType.ACTIONS_SUGGESTIONS)).isEmpty();
+ }
+
+ @Test
+ public void downloaderModelsLister_checkModelFileManager() throws IOException {
+ File annotatorFile = new File(rootTestDir, "test.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), annotatorFile);
+
+ deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
+ when(downloadedModelManager.listModels(MODEL_TYPE)).thenReturn(Arrays.asList(annotatorFile));
+ assertThat(modelFileManager.listModelFiles(MODEL_TYPE))
+ .contains(ModelFile.createFromRegularFile(annotatorFile, MODEL_TYPE));
+ }
+
+ @Test
+ public void downloaderModelsLister_disabled() throws IOException {
+ File annotatorFile = new File(rootTestDir, "test.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), annotatorFile);
+
+ deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, false);
+ DownloaderModelsLister downloaderModelsLister =
+ new DownloaderModelsLister(modelDownloadManager, settings);
+ when(downloadedModelManager.listModels(MODEL_TYPE)).thenReturn(Arrays.asList(annotatorFile));
+ assertThat(downloaderModelsLister.list(MODEL_TYPE)).isEmpty();
+ }
+
+ @Test
+ public void regularFileFullMatchLister() throws IOException {
+ File modelFile = new File(rootTestDir, "test.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile);
+ File wrongFile = new File(rootTestDir, "wrong.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), wrongFile);
+
+ RegularFileFullMatchLister regularFileFullMatchLister =
+ new RegularFileFullMatchLister(MODEL_TYPE, modelFile, () -> true);
+ ImmutableList<ModelFile> listedModels = regularFileFullMatchLister.list(MODEL_TYPE);
+
+ assertThat(listedModels).hasSize(1);
+ assertThat(listedModels.get(0).absolutePath).isEqualTo(modelFile.getAbsolutePath());
+ assertThat(listedModels.get(0).isAsset).isFalse();
+ }
+
+ @Test
+ public void regularFilePatternMatchLister() throws IOException {
+ File modelFile1 = new File(rootTestDir, "annotator.en.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
+ File modelFile2 = new File(rootTestDir, "annotator.fr.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile2);
+ File mismatchedModelFile = new File(rootTestDir, "actions.en.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), mismatchedModelFile);
+
+ RegularFilePatternMatchLister regularFilePatternMatchLister =
+ new RegularFilePatternMatchLister(
+ MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> true);
+ ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
+
+ assertThat(listedModels).hasSize(2);
+ assertThat(listedModels.get(0).isAsset).isFalse();
+ assertThat(listedModels.get(1).isAsset).isFalse();
+ assertThat(ImmutableList.of(listedModels.get(0).absolutePath, listedModels.get(1).absolutePath))
+ .containsExactly(modelFile1.getAbsolutePath(), modelFile2.getAbsolutePath());
+ }
+
+ @Test
+ public void regularFilePatternMatchLister_disabled() throws IOException {
+ File modelFile1 = new File(rootTestDir, "annotator.en.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
+
+ RegularFilePatternMatchLister regularFilePatternMatchLister =
+ new RegularFilePatternMatchLister(
+ MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> false);
+ ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
+
+ assertThat(listedModels).isEmpty();
+ }
+
+ private ModelFileManager createModelFileManager(ModelFile... modelFiles) {
+ return new ModelFileManagerImpl(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.copyOf(modelFiles)),
+ settings);
+ }
+
+ private ModelFile createModelFile(String supportedLocaleTags, int version) {
+ return new ModelFile(
+ MODEL_TYPE,
+ new File(rootTestDir, String.format("%s-%d", supportedLocaleTags, version))
+ .getAbsolutePath(),
+ version,
+ supportedLocaleTags,
+ /* isAsset= */ false);
+ }
+
+  /** Deletes {@code f}; when it is a directory, its contents are deleted first. */
+  private static void recursiveDelete(File f) {
+    File[] innerFiles = f.listFiles(); // null when f is not a directory or cannot be listed
+    for (File innerFile : innerFiles == null ? new File[0] : innerFiles) {
+      recursiveDelete(innerFile);
+    }
+    f.delete();
+  }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
index 5c1d95e..a19e3ff 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
@@ -16,12 +16,10 @@
package com.android.textclassifier;
-import android.content.Context;
-import com.android.textclassifier.common.ModelFileManager;
-import com.android.textclassifier.common.ModelFileManager.RegularFileFullMatchLister;
+import com.android.textclassifier.common.ModelFile;
import com.android.textclassifier.common.ModelType;
-import com.google.common.collect.ImmutableList;
import java.io.File;
+import java.io.IOException;
/** Utils to access test data files. */
public final class TestDataUtils {
@@ -30,7 +28,7 @@
private static final String TEST_LANGID_MODEL_PATH = "testdata/langid.model";
/** Returns the root folder that contains the test data. */
- public static File getTestDataFolder() {
+ private static File getTestDataFolder() {
return new File("/data/local/tmp/TextClassifierServiceTest/");
}
@@ -38,23 +36,25 @@
return new File(getTestDataFolder(), TEST_ANNOTATOR_MODEL_PATH);
}
+ public static ModelFile getTestAnnotatorModelFileWrapped() throws IOException {
+ return ModelFile.createFromRegularFile(getTestAnnotatorModelFile(), ModelType.ANNOTATOR);
+ }
+
public static File getTestActionsModelFile() {
return new File(getTestDataFolder(), TEST_ACTIONS_MODEL_PATH);
}
+ public static ModelFile getTestActionsModelFileWrapped() throws IOException {
+ return ModelFile.createFromRegularFile(
+ getTestActionsModelFile(), ModelType.ACTIONS_SUGGESTIONS);
+ }
+
public static File getLangIdModelFile() {
return new File(getTestDataFolder(), TEST_LANGID_MODEL_PATH);
}
- public static ModelFileManager createModelFileManagerForTesting(Context context) {
- return new ModelFileManager(
- context,
- ImmutableList.of(
- new RegularFileFullMatchLister(
- ModelType.ANNOTATOR, getTestAnnotatorModelFile(), () -> true),
- new RegularFileFullMatchLister(
- ModelType.ACTIONS_SUGGESTIONS, getTestActionsModelFile(), () -> true),
- new RegularFileFullMatchLister(ModelType.LANG_ID, getLangIdModelFile(), () -> true)));
+ public static ModelFile getLangIdModelFileWrapped() throws IOException {
+ return ModelFile.createFromRegularFile(getLangIdModelFile(), ModelType.LANG_ID);
}
private TestDataUtils() {}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TestingUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/TestingUtils.java
new file mode 100644
index 0000000..12924fe
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TestingUtils.java
@@ -0,0 +1,36 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import android.os.Binder;
+import android.os.Parcel;
+import android.view.textclassifier.TextClassificationSessionId;
+
+/** Utils class for helper functions to use in tests. */
+public final class TestingUtils {
+
+ /** Builds a TextClassificationSessionId by unparcelling, because its constructor is @hide. */
+ public static TextClassificationSessionId createTextClassificationSessionId(String sessionId) {
+ Parcel parcel = Parcel.obtain(); // NOTE(review): never recycled here; confirm that is acceptable
+ parcel.writeString(sessionId);
+ parcel.writeStrongBinder(new Binder());
+ parcel.setDataPosition(0); // rewind so CREATOR reads the values written above
+ return TextClassificationSessionId.CREATOR.createFromParcel(parcel);
+ }
+
+ private TestingUtils() {}
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
index 27ea7f0..e7bf90c 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
@@ -18,31 +18,24 @@
import static com.google.common.truth.Truth.assertThat;
-import android.app.UiAutomation;
-import android.content.pm.PackageManager;
-import android.content.pm.PackageManager.NameNotFoundException;
import android.icu.util.ULocale;
-import android.provider.DeviceConfig;
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.ConversationActions;
import android.view.textclassifier.TextClassification;
-import android.view.textclassifier.TextClassificationManager;
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextLanguage;
import android.view.textclassifier.TextLinks;
import android.view.textclassifier.TextLinks.TextLink;
import android.view.textclassifier.TextSelection;
-import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
-import androidx.test.platform.app.InstrumentationRegistry;
+import com.android.textclassifier.testing.ExtServicesTextClassifierRule;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
-import org.junit.rules.ExternalResource;
import org.junit.runner.RunWith;
/**
@@ -63,6 +56,10 @@
@Before
public void setup() {
+ extServicesTextClassifierRule.enableVerboseLogging();
+ // Verbose logging only takes effect after ExtServices restarts, so force-stop it before use.
+ extServicesTextClassifierRule.forceStopExtServices();
+
textClassifier = extServicesTextClassifierRule.getTextClassifier();
}
@@ -88,8 +85,8 @@
@Test
public void classifyText() {
- String text = "Contact me at droid@android.com";
- String classifiedText = "droid@android.com";
+ String text = "Contact me at http://www.android.com";
+ String classifiedText = "http://www.android.com";
int startIndex = text.indexOf(classifiedText);
int endIndex = startIndex + classifiedText.length();
TextClassification.Request request =
@@ -97,7 +94,7 @@
TextClassification classification = textClassifier.classifyText(request);
assertThat(classification.getEntityCount()).isGreaterThan(0);
- assertThat(classification.getEntity(0)).isEqualTo(TextClassifier.TYPE_EMAIL);
+ assertThat(classification.getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
assertThat(classification.getText()).isEqualTo(classifiedText);
assertThat(classification.getActions()).isNotEmpty();
}
@@ -146,67 +143,4 @@
assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL);
assertThat(conversationAction.getAction()).isNotNull();
}
-
- /** A rule that manages a text classifier that is backed by the ExtServices. */
- private static class ExtServicesTextClassifierRule extends ExternalResource {
- private static final String CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE =
- "textclassifier_service_package_override";
- private static final String PKG_NAME_GOOGLE_EXTSERVICES = "com.google.android.ext.services";
- private static final String PKG_NAME_AOSP_EXTSERVICES = "android.ext.services";
-
- private String textClassifierServiceOverrideFlagOldValue;
-
- @Override
- protected void before() {
- UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
- try {
- uiAutomation.adoptShellPermissionIdentity();
- textClassifierServiceOverrideFlagOldValue =
- DeviceConfig.getString(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
- null);
- DeviceConfig.setProperty(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
- getExtServicesPackageName(),
- /* makeDefault= */ false);
- } finally {
- uiAutomation.dropShellPermissionIdentity();
- }
- }
-
- @Override
- protected void after() {
- UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
- try {
- uiAutomation.adoptShellPermissionIdentity();
- DeviceConfig.setProperty(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
- textClassifierServiceOverrideFlagOldValue,
- /* makeDefault= */ false);
- } finally {
- uiAutomation.dropShellPermissionIdentity();
- }
- }
-
- private static String getExtServicesPackageName() {
- PackageManager packageManager =
- ApplicationProvider.getApplicationContext().getPackageManager();
- try {
- packageManager.getApplicationInfo(PKG_NAME_GOOGLE_EXTSERVICES, /* flags= */ 0);
- return PKG_NAME_GOOGLE_EXTSERVICES;
- } catch (NameNotFoundException e) {
- return PKG_NAME_AOSP_EXTSERVICES;
- }
- }
-
- public TextClassifier getTextClassifier() {
- TextClassificationManager textClassificationManager =
- ApplicationProvider.getApplicationContext()
- .getSystemService(TextClassificationManager.class);
- return textClassificationManager.getTextClassifier();
- }
- }
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
index 81aa832..c20ec8a 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
@@ -18,9 +18,14 @@
import static com.google.common.truth.Truth.assertThat;
import static org.hamcrest.CoreMatchers.not;
+import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
import static org.testng.Assert.expectThrows;
import android.app.RemoteAction;
@@ -38,12 +43,17 @@
import android.view.textclassifier.TextLanguage;
import android.view.textclassifier.TextLinks;
import android.view.textclassifier.TextSelection;
+import androidx.collection.LruCache;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SdkSuppress;
import androidx.test.filters.SmallTest;
-import com.android.textclassifier.common.ModelFileManager;
+import com.android.textclassifier.common.ModelFile;
+import com.android.textclassifier.common.ModelType;
import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.testing.FakeContextBuilder;
+import com.android.textclassifier.testing.TestingDeviceConfig;
+import com.google.android.textclassifier.AnnotatorModel;
import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.util.ArrayList;
@@ -56,6 +66,8 @@
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
@SmallTest
@RunWith(AndroidJUnit4.class)
@@ -65,19 +77,34 @@
private static final LocaleList LOCALES = LocaleList.forLanguageTags("en-US");
private static final String NO_TYPE = null;
+ @Mock private ModelFileManager modelFileManager;
+
+ private Context context;
+ private TestingDeviceConfig deviceConfig;
+ private TextClassifierSettings settings;
+ private LruCache<ModelFile, AnnotatorModel> annotatorModelCache;
private TextClassifierImpl classifier;
- private final ModelFileManager modelFileManager =
- TestDataUtils.createModelFileManagerForTesting(ApplicationProvider.getApplicationContext());
@Before
- public void setup() {
- Context context =
+ public void setup() throws IOException {
+ MockitoAnnotations.initMocks(this); // Instantiates the @Mock modelFileManager field.
+ this.context =
new FakeContextBuilder()
.setAllIntentComponent(FakeContextBuilder.DEFAULT_COMPONENT)
.setAppLabel(FakeContextBuilder.DEFAULT_COMPONENT.getPackageName(), "Test app")
.build();
- TextClassifierSettings settings = new TextClassifierSettings();
- classifier = new TextClassifierImpl(context, settings, modelFileManager);
+ this.deviceConfig = new TestingDeviceConfig(); // Tests change flags through setConfig().
+ this.settings = new TextClassifierSettings(deviceConfig);
+ this.annotatorModelCache = new LruCache<>(2); // Cache is created with maxSize 2.
+ this.classifier =
+ new TextClassifierImpl(context, settings, modelFileManager, annotatorModelCache);
+ // Default stubs: each model type resolves to a test model file from TestDataUtils.
+ when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+ .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped());
+ when(modelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any()))
+ .thenReturn(TestDataUtils.getLangIdModelFileWrapped());
+ when(modelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any()))
+ .thenReturn(TestDataUtils.getTestActionsModelFileWrapped());
}
@Test
@@ -90,9 +117,7 @@
int smartStartIndex = text.indexOf(suggested);
int smartEndIndex = smartStartIndex + suggested.length();
TextSelection.Request request =
- new TextSelection.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextSelection.Request.Builder(text, startIndex, endIndex).build();
TextSelection selection = classifier.suggestSelection(null, null, request);
assertThat(
@@ -100,6 +125,24 @@
}
@Test
+ public void testSuggestSelection_localePreferenceIsPassedToModelFileManager() throws IOException {
+   // Checks that the default locales set on the request are handed to
+   // ModelFileManager#findBestModelFile when the annotator model is looked up.
+   String text = "Contact me at droid@android.com";
+   String selected = "droid";
+   int startIndex = text.indexOf(selected);
+   int endIndex = startIndex + selected.length();
+   TextSelection.Request request =
+       new TextSelection.Request.Builder(text, startIndex, endIndex)
+           .setDefaultLocales(LOCALES)
+           .build();
+
+   // The returned selection is not inspected; only the lookup arguments matter here.
+   classifier.suggestSelection(null, null, request);
+
+   verify(modelFileManager).findBestModelFile(eq(ModelType.ANNOTATOR), eq(LOCALES), any());
+ }
+
+ @Test
public void testSuggestSelection_url() throws IOException {
String text = "Visit http://www.android.com for more information";
String selected = "http";
@@ -109,9 +152,7 @@
int smartStartIndex = text.indexOf(suggested);
int smartEndIndex = smartStartIndex + suggested.length();
TextSelection.Request request =
- new TextSelection.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextSelection.Request.Builder(text, startIndex, endIndex).build();
TextSelection selection = classifier.suggestSelection(null, null, request);
assertThat(selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL));
@@ -124,14 +165,45 @@
int startIndex = text.indexOf(selected);
int endIndex = startIndex + selected.length();
TextSelection.Request request =
- new TextSelection.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextSelection.Request.Builder(text, startIndex, endIndex).build();
TextSelection selection = classifier.suggestSelection(null, null, request);
assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE));
}
+ /**
+  * Requests a selection with {@code setIncludeTextClassification(true)} and expects the returned
+  * selection to carry a URL classification that offers an {@code ACTION_VIEW} intent.
+  */
+ @SdkSuppress(minSdkVersion = 31, codeName = "S")
+ @Test
+ public void testSuggestSelection_includeTextClassification() throws IOException {
+   String text = "Visit http://www.android.com for more information";
+   String expectedUrl = "http://www.android.com";
+   int urlStart = text.indexOf(expectedUrl);
+   // Only the first character of the URL is passed in as the initial selection.
+   TextSelection.Request.Builder requestBuilder =
+       new TextSelection.Request.Builder(text, urlStart, /*endIndex=*/ urlStart + 1);
+   requestBuilder.setIncludeTextClassification(true);
+
+   TextSelection result = classifier.suggestSelection(null, null, requestBuilder.build());
+   TextClassification textClassification = result.getTextClassification();
+
+   assertThat(textClassification, isTextClassification(expectedUrl, TextClassifier.TYPE_URL));
+   assertThat(textClassification, containsIntentWithAction(Intent.ACTION_VIEW));
+ }
+
+ /**
+  * Requests a selection with {@code setIncludeTextClassification(false)} and expects the returned
+  * selection to have no text classification attached.
+  */
+ @SdkSuppress(minSdkVersion = 31, codeName = "S")
+ @Test
+ public void testSuggestSelection_notIncludeTextClassification() throws IOException {
+   String text = "Visit http://www.android.com for more information";
+   TextSelection.Request.Builder requestBuilder =
+       new TextSelection.Request.Builder(text, /*startIndex=*/ 0, /*endIndex=*/ 4);
+   requestBuilder.setIncludeTextClassification(false);
+
+   TextSelection result = classifier.suggestSelection(null, null, requestBuilder.build());
+   TextClassification textClassification = result.getTextClassification();
+
+   assertThat(textClassification).isNull();
+ }
+
@Test
public void testClassifyText() throws IOException {
String text = "Contact me at droid@android.com";
@@ -139,9 +211,7 @@
int startIndex = text.indexOf(classifiedText);
int endIndex = startIndex + classifiedText.length();
TextClassification.Request request =
- new TextClassification.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(text, startIndex, endIndex).build();
TextClassification classification =
classifier.classifyText(/* sessionId= */ null, null, request);
@@ -155,9 +225,7 @@
int startIndex = text.indexOf(classifiedText);
int endIndex = startIndex + classifiedText.length();
TextClassification.Request request =
- new TextClassification.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(text, startIndex, endIndex).build();
TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
@@ -168,9 +236,7 @@
public void testClassifyText_address() throws IOException {
String text = "Brandschenkestrasse 110, Zürich, Switzerland";
TextClassification.Request request =
- new TextClassification.Request.Builder(text, 0, text.length())
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(text, 0, text.length()).build();
TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS));
@@ -183,9 +249,7 @@
int startIndex = text.indexOf(classifiedText);
int endIndex = startIndex + classifiedText.length();
TextClassification.Request request =
- new TextClassification.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(text, startIndex, endIndex).build();
TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
@@ -199,9 +263,7 @@
int startIndex = text.indexOf(classifiedText);
int endIndex = startIndex + classifiedText.length();
TextClassification.Request request =
- new TextClassification.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(text, startIndex, endIndex).build();
TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE));
@@ -220,9 +282,7 @@
int startIndex = text.indexOf(classifiedText);
int endIndex = startIndex + classifiedText.length();
TextClassification.Request request =
- new TextClassification.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(text, startIndex, endIndex).build();
TextClassification classification = classifier.classifyText(null, null, request);
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME));
@@ -234,14 +294,12 @@
LocaleList.setDefault(LocaleList.forLanguageTags("en"));
String japaneseText = "これは日本語のテキストです";
TextClassification.Request request =
- new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length())
- .setDefaultLocales(LOCALES)
- .build();
+ new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length()).build();
TextClassification classification = classifier.classifyText(null, null, request);
RemoteAction translateAction = classification.getActions().get(0);
assertEquals(1, classification.getActions().size());
- assertEquals("Translate", translateAction.getTitle().toString());
+ assertEquals(Intent.ACTION_TRANSLATE, classification.getIntent().getAction());
assertEquals(translateAction, ExtrasUtils.findTranslateAction(classification));
Intent intent = ExtrasUtils.getActionsIntents(classification).get(0);
@@ -268,18 +326,17 @@
@Test
public void testGenerateLinks_exclude() throws IOException {
- String text = "You want apple@banana.com. See you tonight!";
+ String text = "The number is +12122537077. See you tonight!";
List<String> hints = ImmutableList.of();
List<String> included = ImmutableList.of();
- List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL);
+ List<String> excluded = Arrays.asList(TextClassifier.TYPE_PHONE);
TextLinks.Request request =
new TextLinks.Request.Builder(text)
.setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
- .setDefaultLocales(LOCALES)
.build();
assertThat(
classifier.generateLinks(null, null, request),
- not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL)));
+ not(isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE)));
}
@Test
@@ -289,7 +346,6 @@
TextLinks.Request request =
new TextLinks.Request.Builder(text)
.setEntityConfig(TextClassifier.EntityConfig.createWithExplicitEntityList(explicit))
- .setDefaultLocales(LOCALES)
.build();
assertThat(
classifier.generateLinks(null, null, request),
@@ -306,7 +362,6 @@
TextLinks.Request request =
new TextLinks.Request.Builder(text)
.setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
- .setDefaultLocales(LOCALES)
.build();
assertThat(
classifier.generateLinks(null, null, request),
@@ -509,6 +564,135 @@
assertThat(conversationActions.getConversationActions()).isEmpty();
}
+ /**
+  * With only {@code MODEL_DOWNLOAD_MANAGER_ENABLED} set, classifies the same request with model A
+  * (v701), model B (v801) and model A again, and expects every counter of {@code
+  * annotatorModelCache} to stay at zero throughout.
+  */
+ @Test
+ public void testUseCachedAnnotatorModelDisabled() throws IOException {
+   deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
+
+   String annotatorFilePath = TestDataUtils.getTestAnnotatorModelFile().getPath();
+   ModelFile annotatorModelA =
+       new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false);
+   ModelFile annotatorModelB =
+       new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false);
+
+   String englishText = "You can reach me on +12122537077.";
+   String classifiedText = "+12122537077";
+   TextClassification.Request request =
+       new TextClassification.Request.Builder(englishText, 0, englishText.length()).build();
+
+   // Check modelFileA v701
+   TextClassification classificationA = classifyWithAnnotatorModel(annotatorModelA, request);
+   assertThat(classificationA.getId()).contains("v701");
+   assertThat(classificationA.getText()).contains(classifiedText);
+   assertAnnotatorModelCacheStats(0, 0, 0, 0);
+
+   // Check modelFileB v801
+   TextClassification classificationB = classifyWithAnnotatorModel(annotatorModelB, request);
+   assertThat(classificationB.getId()).contains("v801");
+   assertThat(classificationB.getText()).contains(classifiedText);
+   assertAnnotatorModelCacheStats(0, 0, 0, 0);
+
+   // Reload modelFileA v701
+   TextClassification classificationAcached =
+       classifyWithAnnotatorModel(annotatorModelA, request);
+   assertThat(classificationAcached.getId()).contains("v701");
+   assertThat(classificationAcached.getText()).contains(classifiedText);
+   assertAnnotatorModelCacheStats(0, 0, 0, 0);
+ }
+
+ /**
+  * With {@code MODEL_DOWNLOAD_MANAGER_ENABLED} and {@code MULTI_ANNOTATOR_CACHE_ENABLED} both set,
+  * classifies the same request with model A (v701), model B (v801) and model A again. The first
+  * use of each model is expected to count as one miss and one put; the second use of model A is
+  * expected to count as a hit.
+  */
+ @Test
+ public void testUseCachedAnnotatorModelEnabled() throws IOException {
+   deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
+   deviceConfig.setConfig(TextClassifierSettings.MULTI_ANNOTATOR_CACHE_ENABLED, true);
+
+   String annotatorFilePath = TestDataUtils.getTestAnnotatorModelFile().getPath();
+   ModelFile annotatorModelA =
+       new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false);
+   ModelFile annotatorModelB =
+       new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false);
+
+   String englishText = "You can reach me on +12122537077.";
+   String classifiedText = "+12122537077";
+   TextClassification.Request request =
+       new TextClassification.Request.Builder(englishText, 0, englishText.length()).build();
+
+   // Check modelFileA v701
+   TextClassification classificationA = classifyWithAnnotatorModel(annotatorModelA, request);
+   assertThat(classificationA.getId()).contains("v701");
+   assertThat(classificationA.getText()).contains(classifiedText);
+   assertAnnotatorModelCacheStats(1, 0, 0, 1);
+
+   // Check modelFileB v801
+   TextClassification classificationB = classifyWithAnnotatorModel(annotatorModelB, request);
+   assertThat(classificationB.getId()).contains("v801");
+   assertThat(classificationB.getText()).contains(classifiedText);
+   assertAnnotatorModelCacheStats(2, 0, 0, 2);
+
+   // Reload modelFileA v701
+   TextClassification classificationAcached =
+       classifyWithAnnotatorModel(annotatorModelA, request);
+   assertThat(classificationAcached.getId()).contains("v701");
+   assertThat(classificationAcached.getText()).contains(classifiedText);
+   assertAnnotatorModelCacheStats(2, 0, 1, 2);
+ }
+
+ /**
+  * Stubs the mocked {@code modelFileManager} so that annotator lookups return {@code
+  * annotatorModel}, then classifies {@code request} with the classifier under test.
+  */
+ private TextClassification classifyWithAnnotatorModel(
+     ModelFile annotatorModel, TextClassification.Request request) throws IOException {
+   when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
+       .thenReturn(annotatorModel);
+   return classifier.classifyText(null, null, request);
+ }
+
+ /**
+  * Asserts the put, eviction, hit and miss counters of {@code annotatorModelCache} in one
+  * comparison, so that a failure message shows all four values.
+  */
+ private void assertAnnotatorModelCacheStats(int puts, int evictions, int hits, int misses) {
+   assertArrayEquals(
+       new int[] {puts, evictions, hits, misses},
+       new int[] {
+         annotatorModelCache.putCount(),
+         annotatorModelCache.evictionCount(),
+         annotatorModelCache.hitCount(),
+         annotatorModelCache.missCount()
+       });
+ }
+
private static void assertNoPackageInfoInExtras(Intent intent) {
assertThat(intent.getComponent()).isNull();
assertThat(intent.getPackage()).isNull();
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java
deleted file mode 100644
index 40838ac..0000000
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java
+++ /dev/null
@@ -1,507 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.common;
-
-import static com.android.textclassifier.common.ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT;
-import static com.google.common.truth.Truth.assertThat;
-
-import android.os.LocaleList;
-import androidx.test.core.app.ApplicationProvider;
-import androidx.test.ext.junit.runners.AndroidJUnit4;
-import androidx.test.filters.SmallTest;
-import com.android.textclassifier.TestDataUtils;
-import com.android.textclassifier.common.ModelFileManager.ModelFile;
-import com.android.textclassifier.common.ModelFileManager.RegularFileFullMatchLister;
-import com.android.textclassifier.common.ModelFileManager.RegularFilePatternMatchLister;
-import com.android.textclassifier.common.ModelType.ModelTypeDef;
-import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
-import com.android.textclassifier.testing.SetDefaultLocalesRule;
-import com.google.common.base.Optional;
-import com.google.common.collect.ImmutableList;
-import com.google.common.io.Files;
-import java.io.File;
-import java.io.IOException;
-import java.util.List;
-import java.util.Locale;
-import java.util.stream.Collectors;
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-
-@SmallTest
-@RunWith(AndroidJUnit4.class)
-public final class ModelFileManagerTest {
- private static final Locale DEFAULT_LOCALE = Locale.forLanguageTag("en-US");
-
- @ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR;
-
- @Mock private TextClassifierSettings.IDeviceConfig mockDeviceConfig;
-
- @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
-
- private File rootTestDir;
- private ModelFileManager modelFileManager;
-
- @Before
- public void setup() {
- MockitoAnnotations.initMocks(this);
-
- rootTestDir =
- new File(ApplicationProvider.getApplicationContext().getCacheDir(), "rootTestDir");
- rootTestDir.mkdirs();
- modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- new TextClassifierSettings(mockDeviceConfig));
- }
-
- @After
- public void removeTestDir() {
- recursiveDelete(rootTestDir);
- }
-
- @Test
- public void annotatorModelPreloaded() {
- verifyModelPreloadedAsAsset(ModelType.ANNOTATOR, "textclassifier/annotator.universal.model");
- }
-
- @Test
- public void actionsModelPreloaded() {
- verifyModelPreloadedAsAsset(
- ModelType.ACTIONS_SUGGESTIONS, "textclassifier/actions_suggestions.universal.model");
- }
-
- @Test
- public void langIdModelPreloaded() {
- verifyModelPreloadedAsAsset(ModelType.LANG_ID, "textclassifier/lang_id.model");
- }
-
- private void verifyModelPreloadedAsAsset(
- @ModelTypeDef String modelType, String expectedModelPath) {
- List<ModelFileManager.ModelFile> modelFiles = modelFileManager.listModelFiles(modelType);
- List<ModelFile> assetFiles =
- modelFiles.stream().filter(modelFile -> modelFile.isAsset).collect(Collectors.toList());
-
- assertThat(assetFiles).hasSize(1);
- assertThat(assetFiles.get(0).absolutePath).isEqualTo(expectedModelPath);
- }
-
- @Test
- public void findBestModel_versionCode() {
- ModelFileManager.ModelFile olderModelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File(rootTestDir, "a").getAbsolutePath(),
- /* version= */ 1,
- LANGUAGE_INDEPENDENT,
- /* isAsset= */ false);
- ModelFileManager.ModelFile newerModelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File(rootTestDir, "b").getAbsolutePath(),
- /* version= */ 2,
- LANGUAGE_INDEPENDENT,
- /* isAsset= */ false);
- ModelFileManager modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(modelType -> ImmutableList.of(olderModelFile, newerModelFile)));
-
- ModelFile bestModelFile = modelFileManager.findBestModelFile(MODEL_TYPE, null);
- assertThat(bestModelFile).isEqualTo(newerModelFile);
- }
-
- @Test
- public void findBestModel_languageDependentModelIsPreferred() {
- ModelFileManager.ModelFile languageIndependentModelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File(rootTestDir, "a").getAbsolutePath(),
- /* version= */ 1,
- LANGUAGE_INDEPENDENT,
- /* isAsset= */ false);
- ModelFileManager.ModelFile languageDependentModelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File(rootTestDir, "b").getAbsolutePath(),
- /* version= */ 2,
- DEFAULT_LOCALE.toLanguageTag(),
- /* isAsset= */ false);
- ModelFileManager modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(
- modelType ->
- ImmutableList.of(languageIndependentModelFile, languageDependentModelFile)));
-
- ModelFile bestModelFile =
- modelFileManager.findBestModelFile(MODEL_TYPE, new LocaleList(DEFAULT_LOCALE));
- assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
- }
-
- @Test
- public void findBestModel_noMatchedLanguageModel() {
- ModelFileManager.ModelFile languageIndependentModelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File(rootTestDir, "a").getAbsolutePath(),
- /* version= */ 1,
- LANGUAGE_INDEPENDENT,
- /* isAsset= */ false);
- ModelFileManager.ModelFile languageDependentModelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File(rootTestDir, "b").getAbsolutePath(),
- /* version= */ 2,
- DEFAULT_LOCALE.toLanguageTag(),
- /* isAsset= */ false);
- ModelFileManager modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(
- modelType ->
- ImmutableList.of(languageIndependentModelFile, languageDependentModelFile)));
-
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("zh-hk"));
- assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
- }
-
- @Test
- public void findBestModel_languageIsMoreImportantThanVersion() {
- ModelFileManager.ModelFile matchButOlderModel =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File(rootTestDir, "a").getAbsolutePath(),
- /* version= */ 1,
- "fr",
- /* isAsset= */ false);
- ModelFileManager.ModelFile mismatchButNewerModel =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File(rootTestDir, "b").getAbsolutePath(),
- /* version= */ 1,
- "ja",
- /* isAsset= */ false);
- ModelFileManager modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(
- modelType -> ImmutableList.of(matchButOlderModel, mismatchButNewerModel)));
-
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("fr"));
- assertThat(bestModelFile).isEqualTo(matchButOlderModel);
- }
-
- @Test
- public void findBestModel_preferMatchedLocaleModel() {
- ModelFileManager.ModelFile matchLocaleModel =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File(rootTestDir, "a").getAbsolutePath(),
- /* version= */ 1,
- "ja",
- /* isAsset= */ false);
- ModelFileManager.ModelFile languageIndependentModel =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File(rootTestDir, "b").getAbsolutePath(),
- /* version= */ 1,
- LANGUAGE_INDEPENDENT,
- /* isAsset= */ false);
- ModelFileManager modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(
- modelType -> ImmutableList.of(matchLocaleModel, languageIndependentModel)));
-
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("ja"));
-
- assertThat(bestModelFile).isEqualTo(matchLocaleModel);
- }
-
- @Test
- public void deleteUnusedModelFiles_olderModelDeleted() throws Exception {
- File model1 = new File(rootTestDir, "model1.fb");
- model1.createNewFile();
- File model2 = new File(rootTestDir, "model2.fb");
- model2.createNewFile();
- ModelFileManager.ModelFile modelFile1 =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
- ModelFileManager.ModelFile modelFile2 =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 2, "ja", /* isAsset= */ false);
- setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
- ModelFileManager modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
-
- modelFileManager.deleteUnusedModelFiles();
-
- assertThat(model1.exists()).isFalse();
- assertThat(model2.exists()).isTrue();
- }
-
- @Test
- public void deleteUnusedModelFiles_languageIndependentOlderModelDeleted() throws Exception {
- File model1 = new File(rootTestDir, "model1.fb");
- model1.createNewFile();
- File model2 = new File(rootTestDir, "model2.fb");
- model2.createNewFile();
- ModelFileManager.ModelFile modelFile1 =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- model1.getAbsolutePath(),
- /* version= */ 1,
- LANGUAGE_INDEPENDENT,
- /* isAsset= */ false);
- ModelFileManager.ModelFile modelFile2 =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- model2.getAbsolutePath(),
- /* version= */ 2,
- LANGUAGE_INDEPENDENT,
- /* isAsset= */ false);
- setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
- ModelFileManager modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
-
- modelFileManager.deleteUnusedModelFiles();
-
- assertThat(model1.exists()).isFalse();
- assertThat(model2.exists()).isTrue();
- }
-
- @Test
- public void deleteUnusedModelFiles_modelOnlySupportingLocalesNotInListDeleted() throws Exception {
- File model1 = new File(rootTestDir, "model1.fb");
- model1.createNewFile();
- File model2 = new File(rootTestDir, "model2.fb");
- model2.createNewFile();
- ModelFileManager.ModelFile modelFile1 =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
- ModelFileManager.ModelFile modelFile2 =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 1, "en", /* isAsset= */ false);
- setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
- ModelFileManager modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
-
- modelFileManager.deleteUnusedModelFiles();
-
- assertThat(model1.exists()).isTrue();
- assertThat(model2.exists()).isFalse();
- }
-
- @Test
- public void deleteUnusedModelFiles_multiLocalesInLocaleList() throws Exception {
- File model1 = new File(rootTestDir, "model1.fb");
- model1.createNewFile();
- File model2 = new File(rootTestDir, "model2.fb");
- model2.createNewFile();
- ModelFileManager.ModelFile modelFile1 =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
- ModelFileManager.ModelFile modelFile2 =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 2, "en", /* isAsset= */ false);
- setDefaultLocalesRule.set(
- new LocaleList(Locale.forLanguageTag("ja"), Locale.forLanguageTag("en")));
- ModelFileManager modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
- modelFileManager.deleteUnusedModelFiles();
-
- assertThat(model1.exists()).isTrue();
- assertThat(model2.exists()).isTrue();
- }
-
- @Test
- public void deleteUnusedModelFiles_readOnlyModelsUntouched() throws Exception {
- File readOnlyDir = new File(rootTestDir, "read_only/");
- readOnlyDir.mkdirs();
- File model1 = new File(readOnlyDir, "model1.fb");
- model1.createNewFile();
- readOnlyDir.setWritable(false);
- ModelFileManager.ModelFile modelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
- ModelFileManager modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(modelType -> ImmutableList.of(modelFile)));
- setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
-
- modelFileManager.deleteUnusedModelFiles();
-
- assertThat(model1.exists()).isTrue();
- }
-
- @Test
- public void modelFileEquals() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
-
- ModelFileManager.ModelFile modelB =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
-
- assertThat(modelA).isEqualTo(modelB);
- }
-
- @Test
- public void modelFile_different() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
- ModelFileManager.ModelFile modelB =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, "/path/b", /* version= */ 1, "ja", /* isAsset= */ false);
-
- assertThat(modelA).isNotEqualTo(modelB);
- }
-
- @Test
- public void modelFile_isPreferredTo_languageDependentIsBetter() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
-
- ModelFileManager.ModelFile modelB =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, "/path/b", /* version= */ 2, LANGUAGE_INDEPENDENT, /* isAsset= */ false);
-
- assertThat(modelA.isPreferredTo(modelB)).isTrue();
- }
-
- @Test
- public void modelFile_isPreferredTo_version() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
-
- ModelFileManager.ModelFile modelB =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, "/path/b", /* version= */ 1, "ja", /* isAsset= */ false);
-
- assertThat(modelA.isPreferredTo(modelB)).isTrue();
- }
-
- @Test
- public void modelFile_toModelInfo() {
- ModelFileManager.ModelFile modelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
-
- ModelInfo modelInfo = modelFile.toModelInfo();
-
- assertThat(modelInfo.toModelName()).isEqualTo("ja_v2");
- }
-
- @Test
- public void modelFile_toModelInfos() {
- ModelFile englishModelFile =
- new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 1, "en", /* isAsset= */ false);
- ModelFile japaneseModelFile =
- new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
-
- ImmutableList<Optional<ModelInfo>> modelInfos =
- ModelFileManager.ModelFile.toModelInfos(
- Optional.of(englishModelFile), Optional.of(japaneseModelFile));
-
- assertThat(
- modelInfos.stream()
- .map(modelFile -> modelFile.transform(ModelInfo::toModelName).or(""))
- .collect(Collectors.toList()))
- .containsExactly("en_v1", "ja_v2")
- .inOrder();
- }
-
- @Test
- public void regularFileFullMatchLister() throws IOException {
- File modelFile = new File(rootTestDir, "test.model");
- Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile);
- File wrongFile = new File(rootTestDir, "wrong.model");
- Files.copy(TestDataUtils.getTestAnnotatorModelFile(), wrongFile);
-
- RegularFileFullMatchLister regularFileFullMatchLister =
- new RegularFileFullMatchLister(MODEL_TYPE, modelFile, () -> true);
- ImmutableList<ModelFile> listedModels = regularFileFullMatchLister.list(MODEL_TYPE);
-
- assertThat(listedModels).hasSize(1);
- assertThat(listedModels.get(0).absolutePath).isEqualTo(modelFile.getAbsolutePath());
- assertThat(listedModels.get(0).isAsset).isFalse();
- }
-
- @Test
- public void regularFilePatternMatchLister() throws IOException {
- File modelFile1 = new File(rootTestDir, "annotator.en.model");
- Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
- File modelFile2 = new File(rootTestDir, "annotator.fr.model");
- Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile2);
- File mismatchedModelFile = new File(rootTestDir, "actions.en.model");
- Files.copy(TestDataUtils.getTestAnnotatorModelFile(), mismatchedModelFile);
-
- RegularFilePatternMatchLister regularFilePatternMatchLister =
- new RegularFilePatternMatchLister(
- MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> true);
- ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
-
- assertThat(listedModels).hasSize(2);
- assertThat(listedModels.get(0).isAsset).isFalse();
- assertThat(listedModels.get(1).isAsset).isFalse();
- assertThat(ImmutableList.of(listedModels.get(0).absolutePath, listedModels.get(1).absolutePath))
- .containsExactly(modelFile1.getAbsolutePath(), modelFile2.getAbsolutePath());
- }
-
- @Test
- public void regularFilePatternMatchLister_disabled() throws IOException {
- File modelFile1 = new File(rootTestDir, "annotator.en.model");
- Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
-
- RegularFilePatternMatchLister regularFilePatternMatchLister =
- new RegularFilePatternMatchLister(
- MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> false);
- ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
-
- assertThat(listedModels).isEmpty();
- }
-
- private static void recursiveDelete(File f) {
- if (f.isDirectory()) {
- for (File innerFile : f.listFiles()) {
- recursiveDelete(innerFile);
- }
- }
- f.delete();
- }
-}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileTest.java
new file mode 100644
index 0000000..75eb4cd
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileTest.java
@@ -0,0 +1,115 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common;
+
+import static com.android.textclassifier.common.ModelFile.LANGUAGE_INDEPENDENT;
+import static com.google.common.truth.Truth.assertThat;
+
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
+import com.google.common.base.Optional;
+import com.google.common.collect.ImmutableList;
+import java.util.stream.Collectors;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public final class ModelFileTest {
+ @ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR;
+
+ @Test
+ public void modelFileEquals() {
+ ModelFile modelA =
+ new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFile modelB =
+ new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
+
+ assertThat(modelA).isEqualTo(modelB);
+ }
+
+ @Test
+ public void modelFile_different() {
+ ModelFile modelA =
+ new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFile modelB =
+ new ModelFile(MODEL_TYPE, "/path/b", /* version= */ 1, "ja", /* isAsset= */ false);
+
+ assertThat(modelA).isNotEqualTo(modelB);
+ }
+
+ @Test
+ public void modelFile_isPreferredTo_languageDependentIsBetter() {
+ ModelFile modelA =
+ new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFile modelB =
+ new ModelFile(
+ MODEL_TYPE, "/path/b", /* version= */ 2, LANGUAGE_INDEPENDENT, /* isAsset= */ false);
+
+ assertThat(modelA.isPreferredTo(modelB)).isTrue();
+ }
+
+ @Test
+ public void modelFile_isPreferredTo_version() {
+ ModelFile modelA =
+ new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
+ ModelFile modelB =
+ new ModelFile(MODEL_TYPE, "/path/b", /* version= */ 1, "ja", /* isAsset= */ false);
+
+ assertThat(modelA.isPreferredTo(modelB)).isTrue();
+ }
+
+ @Test
+ public void modelFile_toModelInfo() {
+ ModelFile modelFile =
+ new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
+
+ ModelInfo modelInfo = modelFile.toModelInfo();
+
+ assertThat(modelInfo.toModelName()).isEqualTo("ja_v2");
+ }
+
+ @Test
+ public void modelFile_toModelInfo_universal() {
+ ModelFile modelFile =
+ new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 2, "*", /* isAsset= */ false);
+
+ ModelInfo modelInfo = modelFile.toModelInfo();
+
+ assertThat(modelInfo.toModelName()).isEqualTo("*_v2");
+ }
+
+ @Test
+ public void modelFile_toModelInfos() {
+ ModelFile englishModelFile =
+ new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 1, "en", /* isAsset= */ false);
+ ModelFile japaneseModelFile =
+ new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
+
+ ImmutableList<Optional<ModelInfo>> modelInfos =
+ ModelFile.toModelInfos(Optional.of(englishModelFile), Optional.of(japaneseModelFile));
+
+ assertThat(
+ modelInfos.stream()
+ .map(modelFile -> modelFile.transform(ModelInfo::toModelName).or(""))
+ .collect(Collectors.toList()))
+ .containsExactly("en_v1", "ja_v2")
+ .inOrder();
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/TextClassifierSettingsTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/TextClassifierSettingsTest.java
index 8072d72..17aef84 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/TextClassifierSettingsTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/TextClassifierSettingsTest.java
@@ -95,29 +95,7 @@
}
@Test
- public void getManifestURLSetting() {
- assertSettings(
- "manifest_url_annotator_en",
- "https://annotator",
- settings ->
- assertThat(settings.getManifestURL(ModelType.ANNOTATOR, "en"))
- .isEqualTo("https://annotator"));
- assertSettings(
- "manifest_url_lang_id_universal",
- "https://lang_id",
- settings ->
- assertThat(settings.getManifestURL(ModelType.LANG_ID, "universal"))
- .isEqualTo("https://lang_id"));
- assertSettings(
- "manifest_url_actions_suggestions_zh",
- "https://actions_suggestions",
- settings ->
- assertThat(settings.getManifestURL(ModelType.ACTIONS_SUGGESTIONS, "zh"))
- .isEqualTo("https://actions_suggestions"));
- }
-
- @Test
- public void getLanguageTagsForManifestURL() {
+ public void getLanguageTagsForManifestAndUrlMap() {
assertSettings(
ImmutableMap.of(
"manifest_url_annotator_en", "https://annotator-en",
@@ -125,8 +103,12 @@
"manifest_url_annotator_zh-hant-hk", "https://annotator-zh",
"manifest_url_lang_id_universal", "https://lang_id"),
settings ->
- assertThat(settings.getLanguageTagsForManifestURL(ModelType.ANNOTATOR))
- .containsExactly("en", "en-us", "zh-hant-hk"));
+ assertThat(settings.getLanguageTagAndManifestUrlMap(ModelType.ANNOTATOR))
+ .containsExactlyEntriesIn(
+ ImmutableMap.of(
+ "en", "https://annotator-en",
+ "en-us", "https://annotator-en-us",
+ "zh-hant-hk", "https://annotator-zh")));
assertSettings(
ImmutableMap.of(
@@ -135,8 +117,8 @@
"manifest_url_annotator_zh-hant-hk", "https://annotator-zh",
"manifest_url_lang_id_universal", "https://lang_id"),
settings ->
- assertThat(settings.getLanguageTagsForManifestURL(ModelType.LANG_ID))
- .containsExactly("universal"));
+ assertThat(settings.getLanguageTagAndManifestUrlMap(ModelType.LANG_ID))
+ .containsExactlyEntriesIn(ImmutableMap.of("universal", "https://lang_id")));
assertSettings(
ImmutableMap.of(
@@ -145,7 +127,7 @@
"manifest_url_annotator_zh-hant-hk", "https://annotator-zh",
"manifest_url_lang_id_universal", "https://lang_id"),
settings ->
- assertThat(settings.getLanguageTagsForManifestURL(ModelType.ACTIONS_SUGGESTIONS))
+ assertThat(settings.getLanguageTagAndManifestUrlMap(ModelType.ACTIONS_SUGGESTIONS))
.isEmpty());
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/TextSelectionCompatTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/TextSelectionCompatTest.java
new file mode 100644
index 0000000..3314fc3
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/TextSelectionCompatTest.java
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextSelection;
+import androidx.test.filters.SdkSuppress;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public final class TextSelectionCompatTest {
+
+ @SdkSuppress(minSdkVersion = 30)
+ @Test
+ public void shouldIncludeTextClassification_negative() {
+ TextSelection.Request request =
+ new TextSelection.Request.Builder("text", /*startIndex=*/ 0, /*endIndex=*/ 1).build();
+
+ assertThat(TextSelectionCompat.shouldIncludeTextClassification(request)).isFalse();
+ }
+
+ @SdkSuppress(minSdkVersion = 31, codeName = "S")
+ @Test
+ public void shouldIncludeTextClassification_positive() {
+ TextSelection.Request request =
+ new TextSelection.Request.Builder("text", /*startIndex=*/ 0, /*endIndex=*/ 1)
+ .setIncludeTextClassification(true)
+ .build();
+
+ assertThat(TextSelectionCompat.shouldIncludeTextClassification(request)).isTrue();
+ }
+
+  @SdkSuppress(minSdkVersion = 30, maxSdkVersion = 30)
+  @Test
+  public void setTextClassification_api30() {
+    TextSelection.Builder selectionBuilder =
+        new TextSelection.Builder(/*startIndex=*/ 0, /*endIndex=*/ 1);
+
+    // Passes as long as this call does not throw; nothing else is asserted.
+    TextSelectionCompat.setTextClassification(selectionBuilder, null);
+  }
+
+ @SdkSuppress(minSdkVersion = 31)
+ @Test
+ public void setTextClassification_api31() {
+ TextSelection.Builder selectionBuilder =
+ new TextSelection.Builder(/*startIndex=*/ 0, /*endIndex=*/ 1);
+ TextClassification classification = new TextClassification.Builder().setText("text").build();
+
+ TextSelectionCompat.setTextClassification(selectionBuilder, classification);
+
+ assertThat(selectionBuilder.build().getTextClassification()).isSameInstanceAs(classification);
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
index 216cd5d..3aab211 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
@@ -27,14 +27,18 @@
import com.google.android.textclassifier.RemoteActionTemplate;
import java.util.List;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
-import org.mockito.MockitoAnnotations;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
@SmallTest
@RunWith(AndroidJUnit4.class)
public class TemplateIntentFactoryTest {
+ @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+
private static final String TITLE_WITHOUT_ENTITY = "Map";
private static final String TITLE_WITH_ENTITY = "Map NW14D1";
private static final String DESCRIPTION = "Check the map";
@@ -71,7 +75,6 @@
@Before
public void setup() {
- MockitoAnnotations.initMocks(this);
templateIntentFactory = new TemplateIntentFactory();
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
index ffd2ee4..3a8fefc 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
@@ -86,8 +86,8 @@
return ImmutableList.copyOf(
metricsList.stream()
.flatMap(statsLogReport -> statsLogReport.getEventMetrics().getDataList().stream())
- .flatMap(eventMetricData -> backfillAggregatedAtomsinEventMetric(
- eventMetricData).stream())
+ .flatMap(
+ eventMetricData -> backfillAggregatedAtomsinEventMetric(eventMetricData).stream())
.sorted(Comparator.comparing(EventMetricData::getElapsedTimestampNanos))
.map(EventMetricData::getAtom)
.collect(Collectors.toList()));
@@ -136,7 +136,7 @@
}
private static ImmutableList<EventMetricData> backfillAggregatedAtomsinEventMetric(
- EventMetricData metricData) {
+ EventMetricData metricData) {
if (metricData.hasAtom()) {
return ImmutableList.of(metricData);
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierDownloadLoggerTestRule.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierDownloadLoggerTestRule.java
new file mode 100644
index 0000000..9c49cb1
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierDownloadLoggerTestRule.java
@@ -0,0 +1,127 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.statsd;
+
+import android.util.Log;
+import androidx.test.core.app.ApplicationProvider;
+import com.android.internal.os.StatsdConfigProto.StatsdConfig;
+import com.android.os.AtomsProto.Atom;
+import com.android.os.AtomsProto.TextClassifierDownloadReported;
+import com.android.os.AtomsProto.TextClassifierDownloadWorkCompleted;
+import com.android.os.AtomsProto.TextClassifierDownloadWorkScheduled;
+import com.google.common.collect.ImmutableList;
+import java.util.stream.Collectors;
+import org.junit.rules.ExternalResource;
+
+// TODO(licha): Make this generic and useful for other atoms.
+/** Test rule to set up/clean up statsd for download logger tests. */
+public final class TextClassifierDownloadLoggerTestRule extends ExternalResource {
+ private static final String TAG = "DownloadLoggerTestRule";
+
+ // Statsd config IDs, which are arbitrary.
+ private static final long CONFIG_ID_DOWNLOAD_REPORTED = 423779;
+ private static final long CONFIG_ID_DOWNLOAD_WORK_SCHEDULED = 42;
+ private static final long CONFIG_ID_DOWNLOAD_WORK_COMPLETED = 2021;
+
+ private static final long SHORT_TIMEOUT_MS = 1000;
+
+ @Override
+ public void before() throws Exception {
+ StatsdTestUtils.cleanup(CONFIG_ID_DOWNLOAD_REPORTED);
+ StatsdTestUtils.cleanup(CONFIG_ID_DOWNLOAD_WORK_SCHEDULED);
+ StatsdTestUtils.cleanup(CONFIG_ID_DOWNLOAD_WORK_COMPLETED);
+
+ StatsdConfig.Builder builder1 =
+ StatsdConfig.newBuilder()
+ .setId(CONFIG_ID_DOWNLOAD_REPORTED)
+ .addAllowedLogSource(ApplicationProvider.getApplicationContext().getPackageName());
+ StatsdTestUtils.addAtomMatcher(builder1, Atom.TEXT_CLASSIFIER_DOWNLOAD_REPORTED_FIELD_NUMBER);
+ StatsdTestUtils.pushConfig(builder1.build());
+
+ StatsdConfig.Builder builder2 =
+ StatsdConfig.newBuilder()
+ .setId(CONFIG_ID_DOWNLOAD_WORK_SCHEDULED)
+ .addAllowedLogSource(ApplicationProvider.getApplicationContext().getPackageName());
+ StatsdTestUtils.addAtomMatcher(
+ builder2, Atom.TEXT_CLASSIFIER_DOWNLOAD_WORK_SCHEDULED_FIELD_NUMBER);
+ StatsdTestUtils.pushConfig(builder2.build());
+
+ StatsdConfig.Builder builder3 =
+ StatsdConfig.newBuilder()
+ .setId(CONFIG_ID_DOWNLOAD_WORK_COMPLETED)
+ .addAllowedLogSource(ApplicationProvider.getApplicationContext().getPackageName());
+ StatsdTestUtils.addAtomMatcher(
+ builder3, Atom.TEXT_CLASSIFIER_DOWNLOAD_WORK_COMPLETED_FIELD_NUMBER);
+ StatsdTestUtils.pushConfig(builder3.build());
+ }
+
+  @Override
+  public void after() {
+    try {
+      StatsdTestUtils.cleanup(CONFIG_ID_DOWNLOAD_REPORTED);
+      StatsdTestUtils.cleanup(CONFIG_ID_DOWNLOAD_WORK_SCHEDULED);
+      StatsdTestUtils.cleanup(CONFIG_ID_DOWNLOAD_WORK_COMPLETED);
+    } catch (Exception e) { // Best-effort: log the cause, do not fail the test.
+      Log.e(TAG, "Failed to clean up statsd after tests.", e);
+    }
+  }
+
+ /**
+ * Gets a list of TextClassifierDownloadReported atoms written into statsd, sorted by increasing
+ * timestamp.
+ */
+ public ImmutableList<TextClassifierDownloadReported> getLoggedDownloadReportedAtoms()
+ throws Exception {
+ ImmutableList<Atom> loggedAtoms =
+ StatsdTestUtils.getLoggedAtoms(CONFIG_ID_DOWNLOAD_REPORTED, SHORT_TIMEOUT_MS);
+ return ImmutableList.copyOf(
+ loggedAtoms.stream()
+ .filter(Atom::hasTextClassifierDownloadReported)
+ .map(Atom::getTextClassifierDownloadReported)
+ .collect(Collectors.toList()));
+ }
+
+ /**
+ * Gets a list of TextClassifierDownloadWorkScheduled atoms written into statsd, sorted by
+ * increasing timestamp.
+ */
+ public ImmutableList<TextClassifierDownloadWorkScheduled> getLoggedDownloadWorkScheduledAtoms()
+ throws Exception {
+ ImmutableList<Atom> loggedAtoms =
+ StatsdTestUtils.getLoggedAtoms(CONFIG_ID_DOWNLOAD_WORK_SCHEDULED, SHORT_TIMEOUT_MS);
+ return ImmutableList.copyOf(
+ loggedAtoms.stream()
+ .filter(Atom::hasTextClassifierDownloadWorkScheduled)
+ .map(Atom::getTextClassifierDownloadWorkScheduled)
+ .collect(Collectors.toList()));
+ }
+
+ /**
+ * Gets a list of TextClassifierDownloadWorkCompleted atoms written into statsd, sorted by
+ * increasing timestamp.
+ */
+ public ImmutableList<TextClassifierDownloadWorkCompleted> getLoggedDownloadWorkCompletedAtoms()
+ throws Exception {
+ ImmutableList<Atom> loggedAtoms =
+ StatsdTestUtils.getLoggedAtoms(CONFIG_ID_DOWNLOAD_WORK_COMPLETED, SHORT_TIMEOUT_MS);
+ return ImmutableList.copyOf(
+ loggedAtoms.stream()
+ .filter(Atom::hasTextClassifierDownloadWorkCompleted)
+ .map(Atom::getTextClassifierDownloadWorkCompleted)
+ .collect(Collectors.toList()));
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloadedModelDatabaseTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloadedModelDatabaseTest.java
new file mode 100644
index 0000000..835f50b
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloadedModelDatabaseTest.java
@@ -0,0 +1,398 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.testng.Assert.expectThrows;
+
+import android.content.Context;
+import androidx.room.Room;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestModelCrossRef;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Model;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ModelView;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import java.io.IOException;
+import java.util.List;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@RunWith(AndroidJUnit4.class)
+public class DownloadedModelDatabaseTest {
+ private static final String MODEL_URL = "https://model.url"; // Key of the first test model.
+ private static final String MODEL_URL_2 = "https://model2.url"; // Key of the second test model.
+ private static final String MODEL_PATH = "/data/test.model";
+ private static final String MODEL_PATH_2 = "/data/test.model2";
+ private static final String MANIFEST_URL = "https://manifest.url";
+ private static final String MANIFEST_URL_2 = "https://manifest2.url";
+ private static final String MODEL_TYPE = ModelType.ANNOTATOR;
+ private static final String MODEL_TYPE_2 = ModelType.ACTIONS_SUGGESTIONS; // Differs from MODEL_TYPE.
+ private static final String LOCALE_TAG = "zh";
+
+ private DownloadedModelDatabase db; // In-memory database; built in createDb(), closed in closeDb().
+
+  @Before
+  public void createDb() {
+    Context appContext = ApplicationProvider.getApplicationContext(); // App under test.
+    db = Room.inMemoryDatabaseBuilder(appContext, DownloadedModelDatabase.class).build();
+  }
+
+ @After
+ public void closeDb() throws IOException {
+ db.close(); // Closes the in-memory database that createDb() built for this test.
+ }
+
+  @Test
+  public void insertModelAndRead() throws Exception {
+    // A model written through the DAO must come back unchanged from queryAllModels().
+    Model insertedModel = Model.create(MODEL_URL, MODEL_PATH);
+    db.dao().insert(insertedModel);
+    assertThat(db.dao().queryAllModels()).containsExactly(insertedModel);
+  }
+
+  @Test
+  public void insertModelAndDelete() throws Exception {
+    // Deleting the only stored model must leave queryAllModels() empty.
+    Model storedModel = Model.create(MODEL_URL, MODEL_PATH);
+    db.dao().insert(storedModel);
+    db.dao().deleteModels(ImmutableList.of(storedModel));
+    assertThat(db.dao().queryAllModels()).isEmpty();
+  }
+
+  @Test
+  public void insertManifestAndRead() throws Exception {
+    // A manifest written through the DAO must come back unchanged from queryAllManifests().
+    Manifest succeeded =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(succeeded);
+    assertThat(db.dao().queryAllManifests()).containsExactly(succeeded);
+  }
+
+  @Test
+  public void insertManifestAndDelete() throws Exception {
+    // Deleting the only stored manifest must leave queryAllManifests() empty.
+    Manifest stored =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(stored);
+    db.dao().deleteManifests(ImmutableList.of(stored));
+    assertThat(db.dao().queryAllManifests()).isEmpty();
+  }
+
+  @Test
+  public void insertManifestModelCrossRefAndRead() throws Exception {
+    // The cross reference points at a model row and a manifest row, so store those first.
+    db.dao().insert(Model.create(MODEL_URL, MODEL_PATH));
+    db.dao()
+        .insert(Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0));
+
+    ManifestModelCrossRef crossRef = ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL);
+    db.dao().insert(crossRef);
+
+    // The stored link must be the only row returned.
+    assertThat(db.dao().queryAllManifestModelCrossRefs()).containsExactly(crossRef);
+  }
+
+  @Test
+  public void insertManifestModelCrossRefAndDelete() throws Exception {
+    // The cross reference points at a model row and a manifest row, so store those first.
+    db.dao().insert(Model.create(MODEL_URL, MODEL_PATH));
+    db.dao()
+        .insert(Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0));
+    ManifestModelCrossRef crossRef = ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL);
+    db.dao().insert(crossRef);
+
+    db.dao().deleteManifestModelCrossRefs(ImmutableList.of(crossRef));
+
+    // Deleting the link explicitly must remove it from the cross-reference query.
+    assertThat(db.dao().queryAllManifestModelCrossRefs()).isEmpty();
+  }
+
+  @Test
+  public void insertManifestModelCrossRefAndDeleteManifest() throws Exception {
+    // Arrange: one model, one manifest, and the link between them.
+    db.dao().insert(Model.create(MODEL_URL, MODEL_PATH));
+    Manifest linkedManifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(linkedManifest);
+    db.dao().insert(ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL));
+
+    // Deleting the manifest is expected to take its cross reference with it (ON CASCADE).
+    db.dao().deleteManifests(ImmutableList.of(linkedManifest));
+
+    assertThat(db.dao().queryAllManifestModelCrossRefs()).isEmpty();
+  }
+
+  @Test
+  public void insertManifestModelCrossRefAndDeleteModel() throws Exception {
+    // Arrange: one model, one manifest, and the link between them.
+    Model linkedModel = Model.create(MODEL_URL, MODEL_PATH);
+    db.dao().insert(linkedModel);
+    db.dao()
+        .insert(Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0));
+    db.dao().insert(ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL));
+
+    // Deleting the model is expected to take its cross reference with it (ON CASCADE).
+    db.dao().deleteModels(ImmutableList.of(linkedModel));
+
+    assertThat(db.dao().queryAllManifestModelCrossRefs()).isEmpty();
+  }
+
+  @Test
+  public void insertManifestModelCrossRefWithoutManifest() throws Exception {
+    // Only the model exists, so the manifest side of the reference is dangling.
+    db.dao().insert(Model.create(MODEL_URL, MODEL_PATH));
+    ManifestModelCrossRef dangling = ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL);
+
+    expectThrows(Throwable.class, () -> db.dao().insert(dangling));
+  }
+
+  @Test
+  public void insertManifestModelCrossRefWithoutModel() throws Exception {
+    // Only the manifest exists, so the model side of the reference is dangling.
+    db.dao()
+        .insert(Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0));
+    ManifestModelCrossRef dangling = ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL);
+
+    expectThrows(Throwable.class, () -> db.dao().insert(dangling));
+  }
+
+  @Test
+  public void insertManifestEnrollmentAndRead() throws Exception {
+    // An enrollment refers to a manifest by URL, so the manifest row is stored first.
+    db.dao()
+        .insert(Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0));
+    ManifestEnrollment enrollment =
+        ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    db.dao().insert(enrollment);
+
+    assertThat(db.dao().queryAllManifestEnrollments()).containsExactly(enrollment);
+  }
+
+  @Test
+  public void insertManifestEnrollmentAndDelete() throws Exception {
+    // An enrollment refers to a manifest by URL, so the manifest row is stored first.
+    db.dao()
+        .insert(Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0));
+    ManifestEnrollment enrollment =
+        ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    db.dao().insert(enrollment);
+    db.dao().deleteManifestEnrollments(ImmutableList.of(enrollment));
+
+    assertThat(db.dao().queryAllManifestEnrollments()).isEmpty();
+  }
+
+  @Test
+  public void insertManifestEnrollmentAndDeleteManifest() throws Exception {
+    Manifest enrolledManifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(enrolledManifest);
+    db.dao().insert(ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL));
+
+    // Deleting the manifest is expected to remove the enrollment that refers to it.
+    db.dao().deleteManifests(ImmutableList.of(enrolledManifest));
+
+    assertThat(db.dao().queryAllManifestEnrollments()).isEmpty();
+  }
+
+  @Test
+  public void insertManifestEnrollmentWithoutManifest() throws Exception {
+    // No manifest row exists, so the enrollment's manifest URL points at nothing.
+    ManifestEnrollment orphan = ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    expectThrows(Throwable.class, () -> db.dao().insert(orphan));
+  }
+
+  @Test
+  public void insertModelViewAndRead() throws Exception {
+    // Store a model, a manifest, the link between them, and an enrollment for that manifest;
+    // together they should surface as a single ModelView.
+    Model model = Model.create(MODEL_URL, MODEL_PATH);
+    db.dao().insert(model);
+    db.dao()
+        .insert(Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0));
+    db.dao().insert(ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL));
+    ManifestEnrollment enrollment =
+        ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    db.dao().insert(enrollment);
+
+    // getOnlyElement fails unless exactly one view is returned.
+    List<ModelView> views = db.dao().queryAllModelViews();
+    ModelView onlyView = Iterables.getOnlyElement(views);
+    assertThat(onlyView.getManifestEnrollment()).isEqualTo(enrollment);
+    assertThat(onlyView.getModel()).isEqualTo(model);
+  }
+
+  @Test
+  public void queryModelWithModelUrl() throws Exception {
+    // Two models with different URLs; each lookup must return only the matching one.
+    Model first = Model.create(MODEL_URL, MODEL_PATH);
+    Model second = Model.create(MODEL_URL_2, MODEL_PATH_2);
+    db.dao().insert(first);
+    db.dao().insert(second);
+    assertThat(db.dao().queryModelWithModelUrl(MODEL_URL)).containsExactly(first);
+    assertThat(db.dao().queryModelWithModelUrl(MODEL_URL_2)).containsExactly(second);
+  }
+
+  @Test
+  public void queryManifestWithManifestUrl() throws Exception {
+    // Two manifests with different URLs and statuses; each lookup must return only its own row.
+    Manifest succeeded =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    Manifest failed =
+        Manifest.create(MANIFEST_URL_2, Manifest.STATUS_FAILED, /* failureCounts= */ 1);
+    db.dao().insert(succeeded);
+    db.dao().insert(failed);
+    assertThat(db.dao().queryManifestWithManifestUrl(MANIFEST_URL)).containsExactly(succeeded);
+    assertThat(db.dao().queryManifestWithManifestUrl(MANIFEST_URL_2)).containsExactly(failed);
+  }
+
+  @Test
+  public void queryManifestEnrollmentWithModelTypeAndLocaleTag() throws Exception {
+    // Two enrollments share LOCALE_TAG but differ in model type; each query must see only its own.
+    Manifest annotatorManifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    Manifest actionsManifest =
+        Manifest.create(MANIFEST_URL_2, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(annotatorManifest);
+    db.dao().insert(actionsManifest);
+    ManifestEnrollment annotatorEnrollment =
+        ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    ManifestEnrollment actionsEnrollment =
+        ManifestEnrollment.create(MODEL_TYPE_2, LOCALE_TAG, MANIFEST_URL_2);
+    db.dao().insert(annotatorEnrollment);
+    db.dao().insert(actionsEnrollment);
+    assertThat(db.dao().queryManifestEnrollmentWithModelTypeAndLocaleTag(MODEL_TYPE, LOCALE_TAG))
+        .containsExactly(annotatorEnrollment);
+    assertThat(db.dao().queryManifestEnrollmentWithModelTypeAndLocaleTag(MODEL_TYPE_2, LOCALE_TAG))
+        .containsExactly(actionsEnrollment);
+  }
+
+  @Test
+  public void insertManifestAndModelCrossRef() throws Exception {
+    db.dao().insert(Model.create(MODEL_URL, MODEL_PATH));
+    db.dao().insertManifestAndModelCrossRef(MANIFEST_URL, MODEL_URL); // Must add manifest + link.
+    Manifest expectedManifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    assertThat(db.dao().queryAllModels()).containsExactly(Model.create(MODEL_URL, MODEL_PATH));
+    assertThat(db.dao().queryAllManifests()).containsExactly(expectedManifest);
+    assertThat(db.dao().queryAllManifestModelCrossRefs())
+        .containsExactly(ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL));
+  }
+
+  @Test
+  public void increaseManifestFailureCounts() throws Exception {
+    db.dao().increaseManifestFailureCounts(MANIFEST_URL); // No row yet: first failure adds one.
+    Manifest row = Iterables.getOnlyElement(db.dao().queryManifestWithManifestUrl(MANIFEST_URL));
+    assertThat(row.getStatus()).isEqualTo(Manifest.STATUS_FAILED);
+    assertThat(row.getFailureCounts()).isEqualTo(1);
+    db.dao().increaseManifestFailureCounts(MANIFEST_URL); // Second failure bumps the same row.
+    row = Iterables.getOnlyElement(db.dao().queryManifestWithManifestUrl(MANIFEST_URL));
+    assertThat(row.getStatus()).isEqualTo(Manifest.STATUS_FAILED);
+    assertThat(row.getFailureCounts()).isEqualTo(2);
+  }
+
+  @Test
+  public void deleteUnusedManifestsAndModels_unusedManifestAndUnusedModel() throws Exception {
+    // Only MANIFEST_URL gets an enrollment, so MANIFEST_URL_2 and its model are the unused pair.
+    Model enrolledModel = Model.create(MODEL_URL, MODEL_PATH);
+    db.dao().insert(enrolledModel);
+    db.dao().insert(Model.create(MODEL_URL_2, MODEL_PATH_2));
+    Manifest enrolledManifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(enrolledManifest);
+    db.dao().insertManifestAndModelCrossRef(MANIFEST_URL, MODEL_URL);
+    Manifest unusedManifest =
+        Manifest.create(MANIFEST_URL_2, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(unusedManifest);
+    db.dao().insertManifestAndModelCrossRef(MANIFEST_URL_2, MODEL_URL_2);
+    db.dao().insert(ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL));
+
+    db.dao().deleteUnusedManifestsAndModels();
+
+    // The cleanup must keep exactly the enrolled manifest and the model it references.
+    assertThat(db.dao().queryAllManifests()).containsExactly(enrolledManifest);
+    assertThat(db.dao().queryAllModels()).containsExactly(enrolledModel);
+  }
+
+  @Test
+  public void deleteUnusedManifestsAndModels_unusedManifestAndSharedModel() throws Exception {
+    // Both manifests reference the same model, but only MANIFEST_URL is enrolled.
+    Model sharedModel = Model.create(MODEL_URL, MODEL_PATH);
+    db.dao().insert(sharedModel);
+    Manifest enrolledManifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(enrolledManifest);
+    db.dao().insertManifestAndModelCrossRef(MANIFEST_URL, MODEL_URL);
+    Manifest unusedManifest =
+        Manifest.create(MANIFEST_URL_2, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(unusedManifest);
+    db.dao().insertManifestAndModelCrossRef(MANIFEST_URL_2, MODEL_URL);
+    db.dao().insert(ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL));
+
+    db.dao().deleteUnusedManifestsAndModels();
+
+    assertThat(db.dao().queryAllManifests()).containsExactly(enrolledManifest);
+    assertThat(db.dao().queryAllModels()).containsExactly(sharedModel);
+  }
+
+  @Test
+  public void deleteUnusedManifestsAndModels_failedManifest() throws Exception {
+    // A FAILED manifest has no enrollment here, yet the cleanup must not delete its record.
+    Manifest failedManifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_FAILED, /* failureCounts= */ 1);
+    db.dao().insert(failedManifest);
+    db.dao().deleteUnusedManifestsAndModels();
+    assertThat(db.dao().queryAllManifests()).containsExactly(failedManifest);
+  }
+
+  @Test
+  public void deleteUnusedManifestsAndModels_unusedModels() throws Exception {
+    // MODEL_URL is linked to the enrolled manifest; MODEL_URL_2 is stored but no manifest
+    // references it.
+    Model linkedModel = Model.create(MODEL_URL, MODEL_PATH);
+    db.dao().insert(linkedModel);
+    db.dao().insert(Model.create(MODEL_URL_2, MODEL_PATH_2));
+    db.dao()
+        .insert(Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0));
+    db.dao().insertManifestAndModelCrossRef(MANIFEST_URL, MODEL_URL);
+    db.dao().insert(ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL));
+
+    db.dao().deleteUnusedManifestsAndModels();
+
+    // Only the unreferenced model may be removed.
+    assertThat(db.dao().queryAllModels()).containsExactly(linkedModel);
+  }
+
+  @Test
+  public void deleteUnusedManifestFailureRecords() throws Exception {
+    // Failure records whose URL is not in the list handed to the DAO must be dropped.
+    Manifest listed =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_FAILED, /* failureCounts= */ 1);
+    db.dao().insert(listed);
+    Manifest unlisted =
+        Manifest.create(MANIFEST_URL_2, Manifest.STATUS_FAILED, /* failureCounts= */ 1);
+    db.dao().insert(unlisted);
+    db.dao().deleteUnusedManifestFailureRecords(ImmutableList.of(MANIFEST_URL));
+    assertThat(db.dao().queryAllManifests()).containsExactly(listed);
+  }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloadedModelManagerImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloadedModelManagerImplTest.java
new file mode 100644
index 0000000..5ff4d89
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloadedModelManagerImplTest.java
@@ -0,0 +1,372 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.content.Context;
+import androidx.room.Room;
+import androidx.test.core.app.ApplicationProvider;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestModelCrossRef;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Model;
+import com.android.textclassifier.testing.TestingDeviceConfig;
+import com.google.common.collect.ImmutableMap;
+import java.io.File;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public final class DownloadedModelManagerImplTest {
+
+ private File modelDownloaderDir; // "test_dir" under the app's files dir; created in setUp().
+ private DownloadedModelDatabase db; // In-memory Room database handed to the manager.
+ private DownloadedModelManagerImpl downloadedModelManagerImpl; // Object under test.
+ private TestingDeviceConfig deviceConfig; // Backs `settings`; tests set flags through it.
+ private TextClassifierSettings settings;
+
+  @Before
+  public void setUp() {
+    Context appContext = ApplicationProvider.getApplicationContext();
+    modelDownloaderDir = new File(appContext.getFilesDir(), "test_dir");
+    modelDownloaderDir.mkdirs(); // Return value is not checked.
+    deviceConfig = new TestingDeviceConfig();
+    settings = new TextClassifierSettings(deviceConfig);
+    db = Room.inMemoryDatabaseBuilder(appContext, DownloadedModelDatabase.class).build();
+    downloadedModelManagerImpl =
+        DownloadedModelManagerImpl.getInstanceForTesting(db, modelDownloaderDir, settings);
+  }
+
+ @After
+ public void cleanUp() {
+ DownloaderTestUtils.deleteRecursively(modelDownloaderDir); // Removes test_dir and its files.
+ db.close(); // Closes the in-memory database built in setUp().
+ }
+
+ @Test
+ public void getModelDownloaderDir() throws Exception {
+ modelDownloaderDir.delete(); // Remove the directory setUp() created; the getter must restore it.
+ assertThat(downloadedModelManagerImpl.getModelDownloaderDir().exists()).isTrue();
+ assertThat(downloadedModelManagerImpl.getModelDownloaderDir()).isEqualTo(modelDownloaderDir);
+ }
+
+ @Test
+ public void listModels_cacheNotInitialized() throws Exception {
+ registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrlEn", "modelUrlEn", "modelPathEn");
+ registerManifestToDB(ModelType.ANNOTATOR, "zh", "manifestUrlZh", "modelUrlZh", "modelPathZh");
+ // The rows were written straight to the db; the very first listModels call must see them.
+ assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+ .containsExactly(new File("modelPathEn"), new File("modelPathZh"));
+ assertThat(downloadedModelManagerImpl.listModels(ModelType.LANG_ID)).isEmpty();
+ }
+
+ @Test
+ public void listModels_doNotListBlockedModels() throws Exception {
+ registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrlEn", "modelUrlEn", "modelPathEn");
+ registerManifestToDB(ModelType.ANNOTATOR, "zh", "manifestUrlZh", "modelUrlZh", "modelPathZh");
+ deviceConfig.setConfig(
+ TextClassifierSettings.MODEL_URL_BLOCKLIST,
+ String.format(
+ "%s%s%s",
+ "modelUrlEn", TextClassifierSettings.MODEL_URL_BLOCKLIST_SEPARATOR, "modelUrlXX"));
+
+ assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+ .containsExactly(new File("modelPathZh"));
+ }
+
+  @Test
+  public void listModels_cacheNotUpdatedUnlessOnDownloadCompleted() throws Exception {
+    registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrlEn", "modelUrlEn", "modelPathEn");
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .containsExactly(new File("modelPathEn"));
+    // A row written to the db after the first lookup is not picked up by listModels...
+    registerManifestToDB(ModelType.ANNOTATOR, "zh", "manifestUrlZh", "modelUrlZh", "modelPathZh");
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .containsExactly(new File("modelPathEn"));
+    // ...until onDownloadCompleted is called.
+    ImmutableMap<String, ManifestsToDownloadByType> completedDownloads =
+        ImmutableMap.of(
+            ModelType.ANNOTATOR,
+            ManifestsToDownloadByType.create(ImmutableMap.of("zh", "manifestUrlZh")));
+    downloadedModelManagerImpl.onDownloadCompleted(completedDownloads);
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .contains(new File("modelPathZh"));
+  }
+
+ @Test
+ public void getModel() throws Exception {
+ registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrl", "modelUrl", "modelPath");
+ assertThat(downloadedModelManagerImpl.getModel("modelUrl").getModelPath())
+ .isEqualTo("modelPath");
+ assertThat(downloadedModelManagerImpl.getModel("modelUrl2")).isNull(); // URL never registered.
+ }
+
+ @Test
+ public void getManifest() throws Exception {
+ registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrl", "modelUrl", "modelPath");
+ assertThat(downloadedModelManagerImpl.getManifest("manifestUrl")).isNotNull();
+ assertThat(downloadedModelManagerImpl.getManifest("manifestUrl2")).isNull(); // Never registered.
+ }
+
+  @Test
+  public void getManifestEnrollment() throws Exception {
+    registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrl", "modelUrl", "modelPath");
+    // The registered (type, locale) pair resolves to its manifest URL.
+    ManifestEnrollment enrolled =
+        downloadedModelManagerImpl.getManifestEnrollment(ModelType.ANNOTATOR, "en");
+    assertThat(enrolled.getManifestUrl()).isEqualTo("manifestUrl");
+    // A locale without an enrollment yields null.
+    assertThat(downloadedModelManagerImpl.getManifestEnrollment(ModelType.ANNOTATOR, "zh"))
+        .isNull();
+  }
+
+ @Test
+ public void registerModel() throws Exception {
+ downloadedModelManagerImpl.registerModel("modelUrl", "modelPath");
+ // The registered model must be retrievable by URL with the path it was registered under.
+ assertThat(downloadedModelManagerImpl.getModel("modelUrl").getModelPath())
+ .isEqualTo("modelPath");
+ }
+
+ @Test
+ public void registerManifest() throws Exception {
+ downloadedModelManagerImpl.registerModel("modelUrl", "modelPath");
+ downloadedModelManagerImpl.registerManifest("manifestUrl", "modelUrl");
+ // The model is registered first; afterwards the manifest must be retrievable by its URL.
+ assertThat(downloadedModelManagerImpl.getManifest("manifestUrl")).isNotNull();
+ }
+
+  @Test
+  public void registerManifestDownloadFailure() throws Exception {
+    downloadedModelManagerImpl.registerManifestDownloadFailure("manifestUrl");
+    // The first failure must produce a FAILED record with a failure count of one.
+    Manifest failureRecord = downloadedModelManagerImpl.getManifest("manifestUrl");
+    assertThat(failureRecord.getStatus()).isEqualTo(Manifest.STATUS_FAILED);
+    assertThat(failureRecord.getFailureCounts()).isEqualTo(1);
+  }
+
+  @Test
+  public void registerManifestEnrollment() throws Exception {
+    downloadedModelManagerImpl.registerModel("modelUrl", "modelPath");
+    downloadedModelManagerImpl.registerManifest("manifestUrl", "modelUrl");
+    downloadedModelManagerImpl.registerManifestEnrollment(ModelType.ANNOTATOR, "en", "manifestUrl");
+    // The enrollment must be retrievable by (model type, locale) with all three fields intact.
+    ManifestEnrollment enrolled =
+        downloadedModelManagerImpl.getManifestEnrollment(ModelType.ANNOTATOR, "en");
+    assertThat(enrolled.getModelType()).isEqualTo(ModelType.ANNOTATOR);
+    assertThat(enrolled.getLocaleTag()).isEqualTo("en");
+    assertThat(enrolled.getManifestUrl()).isEqualTo("manifestUrl");
+  }
+
+  @Test
+  public void onDownloadCompleted_newModelDownloaded() throws Exception {
+    ImmutableMap<String, ManifestsToDownloadByType> firstDownload =
+        ImmutableMap.of(
+            ModelType.ANNOTATOR,
+            ManifestsToDownloadByType.create(ImmutableMap.of("en", "manifestUrl1")));
+    File oldModelFile = new File(modelDownloaderDir, "modelFile1");
+    oldModelFile.createNewFile();
+    registerManifestToDB(
+        ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", oldModelFile.getAbsolutePath());
+    downloadedModelManagerImpl.onDownloadCompleted(firstDownload);
+    // After the first download the "en" model is on disk and listed.
+    assertThat(oldModelFile.exists()).isTrue();
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .containsExactly(oldModelFile);
+    // A second download enrolls manifestUrl2 for the same locale.
+    ImmutableMap<String, ManifestsToDownloadByType> secondDownload =
+        ImmutableMap.of(
+            ModelType.ANNOTATOR,
+            ManifestsToDownloadByType.create(ImmutableMap.of("en", "manifestUrl2")));
+    File newModelFile = new File(modelDownloaderDir, "modelFile2");
+    newModelFile.createNewFile();
+    registerManifestToDB(
+        ModelType.ANNOTATOR, "en", "manifestUrl2", "modelUrl2", newModelFile.getAbsolutePath());
+    downloadedModelManagerImpl.onDownloadCompleted(secondDownload);
+    // The replaced model file must be deleted; only the new one remains.
+    assertThat(oldModelFile.exists()).isFalse();
+    assertThat(newModelFile.exists()).isTrue();
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .containsExactly(newModelFile);
+  }
+
+  @Test
+  public void onDownloadCompleted_newModelDownloadFailed() throws Exception {
+    ImmutableMap<String, ManifestsToDownloadByType> firstDownload =
+        ImmutableMap.of(
+            ModelType.ANNOTATOR,
+            ManifestsToDownloadByType.create(ImmutableMap.of("en", "manifestUrl1")));
+    File existingModelFile = new File(modelDownloaderDir, "modelFile1");
+    existingModelFile.createNewFile();
+    registerManifestToDB(
+        ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", existingModelFile.getAbsolutePath());
+    downloadedModelManagerImpl.onDownloadCompleted(firstDownload);
+    // Baseline: the "en" model from manifestUrl1 is on disk and listed.
+    assertThat(existingModelFile.exists()).isTrue();
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .containsExactly(existingModelFile);
+    // The follow-up download of manifestUrl2 for "en" fails, leaving only a failure record.
+    ImmutableMap<String, ManifestsToDownloadByType> failedDownload =
+        ImmutableMap.of(
+            ModelType.ANNOTATOR,
+            ManifestsToDownloadByType.create(ImmutableMap.of("en", "manifestUrl2")));
+    downloadedModelManagerImpl.registerManifestDownloadFailure("manifestUrl2");
+    downloadedModelManagerImpl.onDownloadCompleted(failedDownload);
+    // The previously downloaded model must survive the failed replacement.
+    assertThat(existingModelFile.exists()).isTrue();
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .containsExactly(existingModelFile);
+  }
+
+  @Test
+  public void onDownloadCompleted_flatUnset() throws Exception {
+    ImmutableMap<String, ManifestsToDownloadByType> firstDownload =
+        ImmutableMap.of(
+            ModelType.ANNOTATOR,
+            ManifestsToDownloadByType.create(ImmutableMap.of("en", "manifestUrl1")));
+    File downloadedModelFile = new File(modelDownloaderDir, "modelFile1");
+    downloadedModelFile.createNewFile();
+    registerManifestToDB(
+        ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", downloadedModelFile.getAbsolutePath());
+    downloadedModelManagerImpl.onDownloadCompleted(firstDownload);
+    // Baseline: the model is on disk and listed.
+    assertThat(downloadedModelFile.exists()).isTrue();
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .containsExactly(downloadedModelFile);
+    // An empty request: presumably the "flag unset" case the test name refers to.
+    ImmutableMap<String, ManifestsToDownloadByType> nothingRequested = ImmutableMap.of();
+    downloadedModelManagerImpl.onDownloadCompleted(nothingRequested);
+    // With nothing requested, the model file is removed and no model is listed.
+    assertThat(downloadedModelFile.exists()).isFalse();
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR)).isEmpty();
+  }
+
+  @Test
+  public void onDownloadCompleted_cleanUpFailureRecords() throws Exception {
+    ImmutableMap<String, ManifestsToDownloadByType> requested =
+        ImmutableMap.of(
+            ModelType.ANNOTATOR,
+            ManifestsToDownloadByType.create(ImmutableMap.of("en", "manifestUrl1")));
+    downloadedModelManagerImpl.registerManifestDownloadFailure("manifestUrl1");
+    downloadedModelManagerImpl.registerManifestDownloadFailure("manifestUrl2");
+    downloadedModelManagerImpl.onDownloadCompleted(requested);
+    // Only the failure record of the manifest that is still requested is kept.
+    assertThat(downloadedModelManagerImpl.getManifest("manifestUrl1").getStatus())
+        .isEqualTo(Manifest.STATUS_FAILED);
+    assertThat(downloadedModelManagerImpl.getManifest("manifestUrl2")).isNull();
+  }
+
+  @Test
+  public void onDownloadCompleted_modelsForMultipleLocalesDownloaded() throws Exception {
+    ImmutableMap<String, ManifestsToDownloadByType> requested =
+        ImmutableMap.of(
+            ModelType.ANNOTATOR,
+            ManifestsToDownloadByType.create(
+                ImmutableMap.of("en", "manifestUrl1", "es", "manifestUrl2")));
+    // One model file per locale, each registered under its own manifest.
+    File enModelFile = new File(modelDownloaderDir, "modelFile1");
+    enModelFile.createNewFile();
+    registerManifestToDB(
+        ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", enModelFile.getAbsolutePath());
+    File esModelFile = new File(modelDownloaderDir, "modelFile2");
+    esModelFile.createNewFile();
+    registerManifestToDB(
+        ModelType.ANNOTATOR, "es", "manifestUrl2", "modelUrl2", esModelFile.getAbsolutePath());
+
+    downloadedModelManagerImpl.onDownloadCompleted(requested);
+
+    assertThat(enModelFile.exists()).isTrue();
+    assertThat(esModelFile.exists()).isTrue();
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .containsExactly(enModelFile, esModelFile);
+  }
+
+  @Test
+  public void onDownloadCompleted_multipleLocales_oneDownloadFailed() throws Exception {
+    File enModelFile = new File(modelDownloaderDir, "modelFile1");
+    enModelFile.createNewFile();
+    registerManifestToDB(
+        ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", enModelFile.getAbsolutePath());
+    // New request: "es" downloads fine, while the "en" replacement (manifestUrl3) fails.
+    ImmutableMap<String, ManifestsToDownloadByType> requested =
+        ImmutableMap.of(
+            ModelType.ANNOTATOR,
+            ManifestsToDownloadByType.create(
+                ImmutableMap.of("es", "manifestUrl2", "en", "manifestUrl3")));
+    File esModelFile = new File(modelDownloaderDir, "modelFile2");
+    esModelFile.createNewFile();
+    registerManifestToDB(
+        ModelType.ANNOTATOR, "es", "manifestUrl2", "modelUrl2", esModelFile.getAbsolutePath());
+    downloadedModelManagerImpl.registerManifestDownloadFailure("manifestUrl3");
+    downloadedModelManagerImpl.onDownloadCompleted(requested);
+    // The existing "en" model stays in use next to the newly downloaded "es" model.
+    assertThat(enModelFile.exists()).isTrue();
+    assertThat(esModelFile.exists()).isTrue();
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .containsExactly(enModelFile, esModelFile);
+  }
+
+  @Test
+  public void onDownoadCompleted_multipleLocales_replaceOldModel() throws Exception {
+    File modelFile1 = new File(modelDownloaderDir, "modelFile1");
+    modelFile1.createNewFile();
+    registerManifestToDB(
+        ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", modelFile1.getAbsolutePath());
+    // The new request replaces the "en" manifest (manifestUrl2) and adds "es" (manifestUrl3).
+    ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload =
+        ImmutableMap.of(
+            ModelType.ANNOTATOR,
+            ManifestsToDownloadByType.create(
+                ImmutableMap.of("en", "manifestUrl2", "es", "manifestUrl3")));
+    File modelFile2 = new File(modelDownloaderDir, "modelFile2");
+    modelFile2.createNewFile();
+    registerManifestToDB(
+        ModelType.ANNOTATOR, "en", "manifestUrl2", "modelUrl2", modelFile2.getAbsolutePath());
+    File modelFile3 = new File(modelDownloaderDir, "modelFile3");
+    modelFile3.createNewFile();
+    registerManifestToDB(
+        ModelType.ANNOTATOR, "es", "manifestUrl3", "modelUrl3", modelFile3.getAbsolutePath());
+
+    downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload);
+    // The superseded "en" model file must be gone, as in onDownloadCompleted_newModelDownloaded.
+    assertThat(modelFile1.exists()).isFalse();
+    assertThat(modelFile2.exists()).isTrue();
+    assertThat(modelFile3.exists()).isTrue();
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .containsExactly(modelFile2, modelFile3);
+  }
+
+ private void registerManifestToDB( // Writes directly to the db, bypassing the manager.
+ @ModelTypeDef String modelType,
+ String localeTag,
+ String manifestUrl,
+ String modelUrl,
+ String modelPath) {
+ db.dao().insert(Model.create(modelUrl, modelPath)); // Model and manifest rows go in first...
+ db.dao()
+ .insert(Manifest.create(manifestUrl, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0));
+ db.dao().insert(ManifestModelCrossRef.create(manifestUrl, modelUrl)); // ...then their link...
+ db.dao().insert(ManifestEnrollment.create(modelType, localeTag, manifestUrl)); // ...then enroll.
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloaderTestUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloaderTestUtils.java
new file mode 100644
index 0000000..37394e6
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloaderTestUtils.java
@@ -0,0 +1,47 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import androidx.work.WorkInfo;
+import androidx.work.WorkManager;
+import androidx.work.WorkQuery;
+import com.google.common.collect.ImmutableList;
+import java.io.File;
+import java.util.List;
+
+/** Utils for downloader logic testing. */
+final class DownloaderTestUtils {
+
+ public static List<WorkInfo> queryWorkInfos(WorkManager workManager, String queueName)
+ throws Exception {
+ WorkQuery workQuery =
+ WorkQuery.Builder.fromUniqueWorkNames(ImmutableList.of(queueName)).build();
+ return workManager.getWorkInfos(workQuery).get();
+ }
+
+ // MoreFiles#deleteRecursively is not available for Android guava.
+ public static void deleteRecursively(File f) {
+ if (f.isDirectory()) {
+ for (File innerFile : f.listFiles()) {
+ deleteRecursively(innerFile);
+ }
+ }
+ f.delete();
+ }
+
+ private DownloaderTestUtils() {}
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/LocaleUtilsTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/LocaleUtilsTest.java
new file mode 100644
index 0000000..a553c51
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/LocaleUtilsTest.java
@@ -0,0 +1,115 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.util.Pair;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.android.textclassifier.testing.TestingDeviceConfig;
+import com.google.common.collect.ImmutableList;
+import java.util.Locale;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for the locale-tag lookup helpers in {@link LocaleUtils}. */
+@RunWith(JUnit4.class)
+public final class LocaleUtilsTest {
+  private static final String MODEL_TYPE = ModelType.ANNOTATOR;
+  // Target locales shared by the tests below.
+  private static final Locale EN = Locale.forLanguageTag("en");
+  private static final Locale EN_US = Locale.forLanguageTag("en-US");
+
+  private TestingDeviceConfig deviceConfig;
+  private TextClassifierSettings settings;
+
+  @Before
+  public void setUp() {
+    deviceConfig = new TestingDeviceConfig();
+    settings = new TextClassifierSettings(deviceConfig);
+  }
+
+  /** Expects a candidate equal to the target tag to be returned. */
+  @Test
+  public void lookupBestLocaleTag_simpleMatch() {
+    assertThat(LocaleUtils.lookupBestLocaleTag(EN, ImmutableList.of("en", "zh"))).isEqualTo("en");
+  }
+
+  /** Expects null when no candidate is acceptable for the target locale. */
+  @Test
+  public void lookupBestLocaleTag_noMatch() {
+    // Different language.
+    assertThat(LocaleUtils.lookupBestLocaleTag(EN, ImmutableList.of("zh"))).isNull();
+    // Candidate is more specific than the target.
+    assertThat(LocaleUtils.lookupBestLocaleTag(EN, ImmutableList.of("en-uk"))).isNull();
+    // Same language, different region.
+    assertThat(LocaleUtils.lookupBestLocaleTag(EN_US, ImmutableList.of("en-uk"))).isNull();
+  }
+
+  /** Expects the most specific acceptable candidate for "en-US". */
+  @Test
+  public void lookupBestLocaleTag_partialMatch() {
+    // Only the language-level candidate matches.
+    assertThat(LocaleUtils.lookupBestLocaleTag(EN_US, ImmutableList.of("en", "zh")))
+        .isEqualTo("en");
+    // A candidate that also matches the region wins over the language-only one.
+    assertThat(LocaleUtils.lookupBestLocaleTag(EN_US, ImmutableList.of("en", "en-us")))
+        .isEqualTo("en-us");
+    // A candidate for another region is skipped in favour of the language-only one.
+    assertThat(LocaleUtils.lookupBestLocaleTag(EN_US, ImmutableList.of("en", "en-uk")))
+        .isEqualTo("en");
+  }
+
+  /** Expects the universal tag when it is the only candidate that can match. */
+  @Test
+  public void lookupBestLocaleTag_universalMatch() {
+    assertThat(
+            LocaleUtils.lookupBestLocaleTag(
+                EN, ImmutableList.of("zh", LocaleUtils.UNIVERSAL_LOCALE_TAG)))
+        .isEqualTo(LocaleUtils.UNIVERSAL_LOCALE_TAG);
+  }
+
+  /** Expects the tag and URL that were configured through the manifest URL flag. */
+  @Test
+  public void lookupBestLocaleTagAndManifestUrl_found() throws Exception {
+    setUpManifestUrl(MODEL_TYPE, "en", "url_1");
+
+    Pair<String, String> tagAndUrl =
+        LocaleUtils.lookupBestLocaleTagAndManifestUrl(MODEL_TYPE, EN, settings);
+    assertThat(tagAndUrl.first).isEqualTo("en");
+    assertThat(tagAndUrl.second).isEqualTo("url_1");
+  }
+
+  /** Expects null when no manifest URL flag has been configured. */
+  @Test
+  public void lookupBestLocaleTagAndManifestUrl_notFound() throws Exception {
+    Pair<String, String> tagAndUrl =
+        LocaleUtils.lookupBestLocaleTagAndManifestUrl(MODEL_TYPE, EN, settings);
+    assertThat(tagAndUrl).isNull();
+  }
+
+  /** Sets the manifest URL flag for {@code modelType} and {@code localeTag} to {@code url}. */
+  private void setUpManifestUrl(
+      @ModelType.ModelTypeDef String modelType, String localeTag, String url) {
+    // The flag name is built from TextClassifierSettings.MANIFEST_URL_TEMPLATE.
+    deviceConfig.setConfig(
+        String.format(TextClassifierSettings.MANIFEST_URL_TEMPLATE, modelType, localeTag), url);
+  }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadExceptionTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadExceptionTest.java
new file mode 100644
index 0000000..1e878b2
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadExceptionTest.java
@@ -0,0 +1,53 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public final class ModelDownloadExceptionTest {
+  private static final int ERROR_CODE = ModelDownloadException.FAILED_TO_DOWNLOAD_OTHER;
+  private static final int DOWNLOADER_LIB_ERROR_CODE = 500;
+  private static final int DEFAULT_LIB_ERROR_CODE =
+      ModelDownloadException.DEFAULT_DOWNLOADER_LIB_ERROR_CODE;
+  // Each test below goes through a different ModelDownloadException constructor overload.
+  @Test
+  public void getErrorCode_constructor1() {
+    assertCodes(new ModelDownloadException(ERROR_CODE, new Exception()), DEFAULT_LIB_ERROR_CODE);
+  }
+
+  @Test
+  public void getErrorCode_constructor2() {
+    assertCodes(new ModelDownloadException(ERROR_CODE, "error_msg"), DEFAULT_LIB_ERROR_CODE);
+  }
+
+  @Test
+  public void getErrorCode_constructor3() {
+    assertCodes(
+        new ModelDownloadException(ERROR_CODE, DOWNLOADER_LIB_ERROR_CODE, "error_msg"),
+        DOWNLOADER_LIB_ERROR_CODE);
+  }
+
+  private static void assertCodes(ModelDownloadException e, int expectedLibErrorCode) {
+    assertThat(e.getErrorCode()).isEqualTo(ERROR_CODE);
+    assertThat(e.getDownloaderLibErrorCode()).isEqualTo(expectedLibErrorCode);
+  }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
new file mode 100644
index 0000000..9e11c09
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
@@ -0,0 +1,257 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.when;
+
+import android.content.Context;
+import android.os.LocaleList;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.work.WorkInfo;
+import androidx.work.WorkManager;
+import androidx.work.testing.WorkManagerTestInitHelper;
+import com.android.os.AtomsProto.TextClassifierDownloadWorkScheduled;
+import com.android.os.AtomsProto.TextClassifierDownloadWorkScheduled.ReasonToSchedule;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.android.textclassifier.common.statsd.TextClassifierDownloadLoggerTestRule;
+import com.android.textclassifier.testing.SetDefaultLocalesRule;
+import com.android.textclassifier.testing.TestingDeviceConfig;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import com.google.common.util.concurrent.MoreExecutors;
+import java.io.File;
+import java.util.List;
+import java.util.Locale;
+import java.util.stream.Collectors;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
+
+@RunWith(AndroidJUnit4.class)
+public final class ModelDownloadManagerTest {
+ private static final String MODEL_PATH = "/data/test.model";
+ @ModelType.ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR;
+ private static final String LOCALE_TAG = "en";
+ private static final LocaleList DEFAULT_LOCALE_LIST = new LocaleList(new Locale(LOCALE_TAG));
+
+ @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
+
+ @Rule
+ public final TextClassifierDownloadLoggerTestRule loggerTestRule =
+ new TextClassifierDownloadLoggerTestRule();
+
+ @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+
+ private TestingDeviceConfig deviceConfig;
+ private WorkManager workManager;
+ private ModelDownloadManager downloadManager;
+ private ModelDownloadManager downloadManagerWithBadWorkManager;
+ @Mock DownloadedModelManager downloadedModelManager;
+
+ @Before
+ public void setUp() {
+ Context context = ApplicationProvider.getApplicationContext();
+ WorkManagerTestInitHelper.initializeTestWorkManager(context);
+
+ this.deviceConfig = new TestingDeviceConfig();
+ this.workManager = WorkManager.getInstance(context);
+ this.downloadManager =
+ new ModelDownloadManager(
+ context,
+ ModelDownloadWorker.class,
+ () -> workManager,
+ downloadedModelManager,
+ new TextClassifierSettings(deviceConfig),
+ MoreExecutors.newDirectExecutorService());
+ this.downloadManagerWithBadWorkManager =
+ new ModelDownloadManager(
+ context,
+ ModelDownloadWorker.class,
+ () -> {
+ throw new IllegalStateException("WorkManager may fail!");
+ },
+ downloadedModelManager,
+ new TextClassifierSettings(deviceConfig),
+ MoreExecutors.newDirectExecutorService());
+
+ setDefaultLocalesRule.set(DEFAULT_LOCALE_LIST);
+ deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
+ }
+
+ @After
+ public void tearDown() {
+ workManager.cancelUniqueWork(ModelDownloadManager.UNIQUE_QUEUE_NAME);
+ DownloaderTestUtils.deleteRecursively(
+ ApplicationProvider.getApplicationContext().getFilesDir());
+ }
+
+ @Test
+ public void onTextClassifierServiceCreated_workManagerCrashed() throws Exception {
+ assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty();
+ downloadManagerWithBadWorkManager.onTextClassifierServiceCreated();
+
+ // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test
+ TextClassifierDownloadWorkScheduled atom =
+ Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
+ assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.TCS_STARTED);
+ assertThat(atom.getFailedToSchedule()).isTrue();
+ }
+
+ @Test
+ public void onTextClassifierServiceCreated_requestEnqueued() throws Exception {
+ assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty();
+ downloadManager.onTextClassifierServiceCreated();
+
+ WorkInfo workInfo =
+ Iterables.getOnlyElement(
+ DownloaderTestUtils.queryWorkInfos(
+ workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME));
+ assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+ // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test
+ verifyWorkScheduledLogging(ReasonToSchedule.TCS_STARTED);
+ }
+
+ @Test
+ public void onTextClassifierServiceCreated_localeListOverridden() throws Exception {
+ assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty();
+ deviceConfig.setConfig(TextClassifierSettings.TESTING_LOCALE_LIST_OVERRIDE, "zh,fr");
+ downloadManager.onTextClassifierServiceCreated();
+
+ assertThat(Locale.getDefault()).isEqualTo(Locale.forLanguageTag("zh"));
+ assertThat(LocaleList.getDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr"));
+ assertThat(LocaleList.getAdjustedDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr"));
+ // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test
+ verifyWorkScheduledLogging(ReasonToSchedule.TCS_STARTED);
+ }
+
+ @Test
+ public void onLocaleChanged_workManagerCrashed() throws Exception {
+ downloadManagerWithBadWorkManager.onLocaleChanged();
+
+ TextClassifierDownloadWorkScheduled atom =
+ Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
+ assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.LOCALE_SETTINGS_CHANGED);
+ assertThat(atom.getFailedToSchedule()).isTrue();
+ }
+
+ @Test
+ public void onLocaleChanged_requestEnqueued() throws Exception {
+ downloadManager.onLocaleChanged();
+
+ WorkInfo workInfo =
+ Iterables.getOnlyElement(
+ DownloaderTestUtils.queryWorkInfos(
+ workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME));
+ assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+ verifyWorkScheduledLogging(ReasonToSchedule.LOCALE_SETTINGS_CHANGED);
+ }
+
+ @Test
+ public void onTextClassifierDeviceConfigChanged_workManagerCrashed() throws Exception {
+ downloadManagerWithBadWorkManager.onTextClassifierDeviceConfigChanged();
+
+ TextClassifierDownloadWorkScheduled atom =
+ Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
+ assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.DEVICE_CONFIG_UPDATED);
+ assertThat(atom.getFailedToSchedule()).isTrue();
+ }
+
+ @Test
+ public void onTextClassifierDeviceConfigChanged_requestEnqueued() throws Exception {
+ downloadManager.onTextClassifierDeviceConfigChanged();
+
+ WorkInfo workInfo =
+ Iterables.getOnlyElement(
+ DownloaderTestUtils.queryWorkInfos(
+ workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME));
+ assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+ verifyWorkScheduledLogging(ReasonToSchedule.DEVICE_CONFIG_UPDATED);
+ }
+
+ @Test
+ public void onTextClassifierDeviceConfigChanged_downloaderDisabled() throws Exception {
+ deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, false);
+ downloadManager.onTextClassifierDeviceConfigChanged();
+
+ assertThat(
+ DownloaderTestUtils.queryWorkInfos(workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME))
+ .isEmpty();
+ assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty();
+ }
+
+ @Test
+ public void onTextClassifierDeviceConfigChanged_newWorkDoNotReplaceOldWork() throws Exception {
+ downloadManager.onTextClassifierDeviceConfigChanged();
+ downloadManager.onTextClassifierDeviceConfigChanged();
+ List<WorkInfo> workInfos =
+ DownloaderTestUtils.queryWorkInfos(workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME);
+
+ assertThat(workInfos.stream().map(WorkInfo::getState).collect(Collectors.toList()))
+ .containsExactly(WorkInfo.State.ENQUEUED, WorkInfo.State.BLOCKED);
+ List<TextClassifierDownloadWorkScheduled> atoms =
+ loggerTestRule.getLoggedDownloadWorkScheduledAtoms();
+ assertThat(atoms).hasSize(2);
+ verifyWorkScheduledAtom(atoms.get(0), ReasonToSchedule.DEVICE_CONFIG_UPDATED);
+ verifyWorkScheduledAtom(atoms.get(1), ReasonToSchedule.DEVICE_CONFIG_UPDATED);
+ }
+
+ @Test
+ public void onTextClassifierDeviceConfigChanged_localeListOverridden() throws Exception {
+ deviceConfig.setConfig(TextClassifierSettings.TESTING_LOCALE_LIST_OVERRIDE, "zh,fr");
+ downloadManager.onTextClassifierDeviceConfigChanged();
+
+ assertThat(Locale.getDefault()).isEqualTo(Locale.forLanguageTag("zh"));
+ assertThat(LocaleList.getDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr"));
+ assertThat(LocaleList.getAdjustedDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr"));
+ verifyWorkScheduledLogging(ReasonToSchedule.DEVICE_CONFIG_UPDATED);
+ }
+
+ @Test
+ public void listDownloadedModels() throws Exception {
+ File modelFile = new File(MODEL_PATH);
+ when(downloadedModelManager.listModels(MODEL_TYPE)).thenReturn(ImmutableList.of(modelFile));
+
+ assertThat(downloadManager.listDownloadedModels(MODEL_TYPE)).containsExactly(modelFile);
+ }
+
+ @Test
+ public void listDownloadedModels_doNotCrashOnError() throws Exception {
+ when(downloadedModelManager.listModels(MODEL_TYPE)).thenThrow(new IllegalStateException());
+
+ assertThat(downloadManager.listDownloadedModels(MODEL_TYPE)).isEmpty();
+ }
+
+ private void verifyWorkScheduledLogging(ReasonToSchedule reasonToSchedule) throws Exception {
+ TextClassifierDownloadWorkScheduled atom =
+ Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
+ verifyWorkScheduledAtom(atom, reasonToSchedule);
+ }
+
+ private void verifyWorkScheduledAtom(
+ TextClassifierDownloadWorkScheduled atom, ReasonToSchedule reasonToSchedule) {
+ assertThat(atom.getReasonToSchedule()).isEqualTo(reasonToSchedule);
+ assertThat(atom.getFailedToSchedule()).isFalse();
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadWorkerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadWorkerTest.java
new file mode 100644
index 0000000..3646934
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadWorkerTest.java
@@ -0,0 +1,790 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import android.content.Context;
+import android.os.LocaleList;
+import androidx.room.Room;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.work.ListenableWorker;
+import androidx.work.WorkerFactory;
+import androidx.work.WorkerParameters;
+import androidx.work.testing.TestListenableWorkerBuilder;
+import com.android.os.AtomsProto.TextClassifierDownloadReported;
+import com.android.os.AtomsProto.TextClassifierDownloadReported.DownloadStatus;
+import com.android.os.AtomsProto.TextClassifierDownloadReported.FailureReason;
+import com.android.os.AtomsProto.TextClassifierDownloadWorkCompleted;
+import com.android.os.AtomsProto.TextClassifierDownloadWorkCompleted.WorkResult;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.android.textclassifier.common.statsd.TextClassifierDownloadLoggerTestRule;
+import com.android.textclassifier.testing.SetDefaultLocalesRule;
+import com.android.textclassifier.testing.TestingDeviceConfig;
+import com.google.common.collect.Iterables;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.MoreExecutors;
+import java.io.File;
+import java.time.Clock;
+import java.time.Instant;
+import java.util.List;
+import java.util.Locale;
+import java.util.stream.Collectors;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+@RunWith(JUnit4.class)
+public final class ModelDownloadWorkerTest {
+ private static final long WORK_ID = 123456789L;
+ private static final String MODEL_TYPE = ModelType.ANNOTATOR;
+ private static final String MODEL_TYPE_2 = ModelType.ACTIONS_SUGGESTIONS;
+ private static final TextClassifierDownloadReported.ModelType MODEL_TYPE_ATOM =
+ TextClassifierDownloadReported.ModelType.ANNOTATOR;
+ private static final String LOCALE_TAG = "en";
+ private static final String LOCALE_TAG_2 = "zh";
+ private static final String LOCALE_TAG_3 = "es";
+ private static final String MANIFEST_URL =
+ "https://www.gstatic.com/android/text_classifier/q/v711/en.fb.manifest";
+ private static final String MANIFEST_URL_2 =
+ "https://www.gstatic.com/android/text_classifier/q/v711/zh.fb.manifest";
+ private static final String MANIFEST_URL_3 =
+ "https://www.gstatic.com/android/text_classifier/q/v711/es.fb.manifest";
+ private static final String MODEL_URL =
+ "https://www.gstatic.com/android/text_classifier/q/v711/en.fb";
+ private static final String MODEL_URL_2 =
+ "https://www.gstatic.com/android/text_classifier/q/v711/zh.fb";
+ private static final String MODEL_URL_3 =
+ "https://www.gstatic.com/android/text_classifier/q/v711/es.fb";
+ private static final int RUN_ATTEMPT_COUNT = 1;
+ private static final int WORKER_MAX_RUN_ATTEMPT_COUNT = 5;
+ private static final int MANIFEST_MAX_ATTEMPT_COUNT = 2;
+ private static final ModelManifest.Model MODEL_PROTO =
+ ModelManifest.Model.newBuilder()
+ .setUrl(MODEL_URL)
+ .setSizeInBytes(1)
+ .setFingerprint("fingerprint")
+ .build();
+ private static final ModelManifest.Model MODEL_PROTO_2 =
+ ModelManifest.Model.newBuilder()
+ .setUrl(MODEL_URL_2)
+ .setSizeInBytes(1)
+ .setFingerprint("fingerprint")
+ .build();
+ private static final ModelManifest.Model MODEL_PROTO_3 =
+ ModelManifest.Model.newBuilder()
+ .setUrl(MODEL_URL_3)
+ .setSizeInBytes(1)
+ .setFingerprint("fingerprint")
+ .build();
+ // Each manifest proto below lists exactly one of the model protos above.
+ private static final ModelManifest MODEL_MANIFEST_PROTO =
+ ModelManifest.newBuilder().addModels(MODEL_PROTO).build();
+ private static final ModelManifest MODEL_MANIFEST_PROTO_2 =
+ ModelManifest.newBuilder().addModels(MODEL_PROTO_2).build();
+ private static final ModelManifest MODEL_MANIFEST_PROTO_3 =
+ ModelManifest.newBuilder().addModels(MODEL_PROTO_3).build();
+ private static final ModelDownloadException FAILED_TO_DOWNLOAD_EXCEPTION =
+ new ModelDownloadException(
+ ModelDownloadException.FAILED_TO_DOWNLOAD_OTHER, "failed to download");
+ private static final FailureReason FAILED_TO_DOWNLOAD_FAILURE_REASON =
+ TextClassifierDownloadReported.FailureReason.FAILED_TO_DOWNLOAD_OTHER;
+ private static final LocaleList DEFAULT_LOCALE_LIST = new LocaleList(new Locale(LOCALE_TAG)); // "en"
+ private static final LocaleList LOCALE_LIST_2 =
+ new LocaleList(new Locale(LOCALE_TAG), new Locale(LOCALE_TAG_2)); // "en", "zh"
+ private static final LocaleList LOCALE_LIST_3 =
+ new LocaleList(new Locale(LOCALE_TAG), new Locale(LOCALE_TAG_2), new Locale(LOCALE_TAG_3)); // "en", "zh", "es"
+ private static final Instant WORK_SCHEDULED_TIME = Instant.now();
+ private static final Instant WORK_STARTED_TIME = WORK_SCHEDULED_TIME.plusSeconds(100);
+ // Offsets are chosen so the four durations derived below differ: 2s, 4s, 100s and 15s.
+ private static final Instant DOWNLOAD_STARTED_TIME = WORK_STARTED_TIME.plusSeconds(1);
+ private static final Instant DOWNLOAD_ENDED_TIME = WORK_STARTED_TIME.plusSeconds(1 + 2);
+ private static final Instant DOWNLOAD_STARTED_TIME_2 = WORK_STARTED_TIME.plusSeconds(1 + 2 + 3);
+ private static final Instant DOWNLOAD_ENDED_TIME_2 = WORK_STARTED_TIME.plusSeconds(1 + 2 + 3 + 4);
+ private static final Instant WORK_ENDED_TIME = WORK_STARTED_TIME.plusSeconds(1 + 2 + 3 + 4 + 5);
+ private static final long DOWNLOAD_STARTED_TO_ENDED_MILLIS =
+ DOWNLOAD_ENDED_TIME.toEpochMilli() - DOWNLOAD_STARTED_TIME.toEpochMilli();
+ private static final long DOWNLOAD_STARTED_TO_ENDED_2_MILLIS =
+ DOWNLOAD_ENDED_TIME_2.toEpochMilli() - DOWNLOAD_STARTED_TIME_2.toEpochMilli();
+ private static final long WORK_SCHEDULED_TO_STARTED_MILLIS =
+ WORK_STARTED_TIME.toEpochMilli() - WORK_SCHEDULED_TIME.toEpochMilli();
+ private static final long WORK_STARTED_TO_ENDED_MILLIS =
+ WORK_ENDED_TIME.toEpochMilli() - WORK_STARTED_TIME.toEpochMilli();
+
+ @Mock private Clock clock; // instant() is stubbed per test with a fixed sequence of instants
+ @Mock private ModelDownloader modelDownloader; // manifest/model downloads are stubbed per test
+ private File modelDownloaderDir; // created under the cache dir in setUp(), deleted in cleanUp()
+ private File modelFile;
+ private File modelFile2;
+ private File modelFile3;
+ private DownloadedModelDatabase db; // in-memory Room database built in setUp()
+ private DownloadedModelManager downloadedModelManager;
+ private TestingDeviceConfig deviceConfig;
+ private TextClassifierSettings settings;
+
+ @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
+
+ @Rule
+ public final TextClassifierDownloadLoggerTestRule loggerTestRule =
+ new TextClassifierDownloadLoggerTestRule();
+
+  @Before
+  public void setUp() {
+    MockitoAnnotations.initMocks(this);
+    Context appContext = ApplicationProvider.getApplicationContext();
+    deviceConfig = new TestingDeviceConfig();
+    settings = new TextClassifierSettings(deviceConfig);
+    // Model files go to a dedicated cache sub-directory, which cleanUp() deletes again.
+    modelDownloaderDir = new File(appContext.getCacheDir(), "downloaded");
+    modelDownloaderDir.mkdirs();
+    modelFile = new File(modelDownloaderDir, "test.model");
+    modelFile2 = new File(modelDownloaderDir, "test2.model");
+    modelFile3 = new File(modelDownloaderDir, "test3.model");
+    db = Room.inMemoryDatabaseBuilder(appContext, DownloadedModelDatabase.class).build();
+    downloadedModelManager =
+        DownloadedModelManagerImpl.getInstanceForTesting(db, modelDownloaderDir, settings);
+
+    setDefaultLocalesRule.set(DEFAULT_LOCALE_LIST);
+    deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
+  }
+
+ @After
+ public void cleanUp() {
+ db.close(); // closes the in-memory database built in setUp()
+ DownloaderTestUtils.deleteRecursively(modelDownloaderDir); // removes the model files of this test
+ }
+
+  @Test
+  public void downloadSucceed() throws Exception {
+    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    modelFile.createNewFile();
+    when(modelDownloader.downloadManifest(MANIFEST_URL))
+        .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+    when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
+        .thenReturn(Futures.immediateFuture(modelFile));
+    when(clock.instant())
+        .thenReturn(WORK_STARTED_TIME, DOWNLOAD_STARTED_TIME, DOWNLOAD_ENDED_TIME, WORK_ENDED_TIME);
+    // Run the worker once against the stubbed downloader and clock.
+    ListenableWorker.Result result = createWorker(RUN_ATTEMPT_COUNT).startWork().get();
+    assertThat(result).isEqualTo(ListenableWorker.Result.success());
+    assertThat(modelFile.exists()).isTrue();
+    assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
+    verifySucceededDownloadLogging();
+    verifyWorkLogging(RUN_ATTEMPT_COUNT, WorkResult.SUCCESS_MODEL_DOWNLOADED);
+  }
+
+  @Test
+  public void downloadSucceed_modelAlreadyExists() throws Exception {
+    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    modelFile.createNewFile();
+    downloadedModelManager.registerModel(MODEL_URL, modelFile.getAbsolutePath());
+    when(modelDownloader.downloadManifest(MANIFEST_URL))
+        .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+    when(clock.instant())
+        .thenReturn(WORK_STARTED_TIME, DOWNLOAD_STARTED_TIME, DOWNLOAD_ENDED_TIME, WORK_ENDED_TIME);
+    // The model is registered up front and downloadModel() is left unstubbed in this test.
+    ListenableWorker.Result result = createWorker(RUN_ATTEMPT_COUNT).startWork().get();
+    assertThat(result).isEqualTo(ListenableWorker.Result.success());
+    assertThat(modelFile.exists()).isTrue();
+    assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
+    verifySucceededDownloadLogging();
+    verifyWorkLogging(RUN_ATTEMPT_COUNT, WorkResult.SUCCESS_MODEL_DOWNLOADED);
+  }
+
+  @Test
+  public void downloadSucceed_manifestAlreadyExists() throws Exception {
+    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    modelFile.createNewFile();
+    downloadedModelManager.registerModel(MODEL_URL, modelFile.getAbsolutePath());
+    downloadedModelManager.registerManifest(MANIFEST_URL, MODEL_URL);
+    when(clock.instant())
+        .thenReturn(WORK_STARTED_TIME, DOWNLOAD_STARTED_TIME, DOWNLOAD_ENDED_TIME, WORK_ENDED_TIME);
+    // Model and manifest are both registered up front; the downloader mock is left unstubbed.
+    ListenableWorker.Result result = createWorker(RUN_ATTEMPT_COUNT).startWork().get();
+    assertThat(result).isEqualTo(ListenableWorker.Result.success());
+    assertThat(modelFile.exists()).isTrue();
+    assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
+    verifySucceededDownloadLogging();
+    verifyWorkLogging(RUN_ATTEMPT_COUNT, WorkResult.SUCCESS_MODEL_DOWNLOADED);
+  }
+
+  @Test
+  public void downloadSucceed_downloadMultipleModels() throws Exception {
+    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    setUpManifestUrl(MODEL_TYPE_2, LOCALE_TAG, MANIFEST_URL_2);
+    modelFile.createNewFile();
+    modelFile2.createNewFile();
+    when(modelDownloader.downloadManifest(MANIFEST_URL))
+        .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+    when(modelDownloader.downloadManifest(MANIFEST_URL_2))
+        .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO_2));
+    when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
+        .thenReturn(Futures.immediateFuture(modelFile));
+    when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO_2))
+        .thenReturn(Futures.immediateFuture(modelFile2));
+    // Instants assume MODEL_TYPE downloads before MODEL_TYPE_2; otherwise this test is flaky.
+    when(clock.instant())
+        .thenReturn(
+            WORK_STARTED_TIME,
+            DOWNLOAD_STARTED_TIME,
+            DOWNLOAD_ENDED_TIME,
+            DOWNLOAD_STARTED_TIME_2,
+            DOWNLOAD_ENDED_TIME_2,
+            WORK_ENDED_TIME);
+
+    ListenableWorker.Result result = createWorker(RUN_ATTEMPT_COUNT).startWork().get();
+    assertThat(result).isEqualTo(ListenableWorker.Result.success());
+    assertThat(modelFile.exists()).isTrue();
+    assertThat(modelFile2.exists()).isTrue();
+    assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
+    assertThat(downloadedModelManager.listModels(MODEL_TYPE_2)).containsExactly(modelFile2);
+    List<TextClassifierDownloadReported> atoms = loggerTestRule.getLoggedDownloadReportedAtoms();
+    assertThat(atoms).hasSize(2);
+    assertThat(
+            atoms.stream()
+                .map(TextClassifierDownloadReported::getUrlSuffix)
+                .collect(Collectors.toList()))
+        .containsExactly(MANIFEST_URL, MANIFEST_URL_2);
+    assertThat(atoms.get(0).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+    assertThat(atoms.get(1).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+    assertThat(
+            atoms.stream()
+                .map(TextClassifierDownloadReported::getDownloadDurationMillis)
+                .collect(Collectors.toList()))
+        .containsExactly(DOWNLOAD_STARTED_TO_ENDED_MILLIS, DOWNLOAD_STARTED_TO_ENDED_2_MILLIS);
+    verifyWorkLogging(RUN_ATTEMPT_COUNT, WorkResult.SUCCESS_MODEL_DOWNLOADED);
+  }
+
+ @Test // Two manifest URLs return the same manifest proto; the shared model is fetched once.
+ public void downloadSucceed_shareSingleModelDownloadForMultipleManifest() throws Exception {
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ setUpManifestUrl(MODEL_TYPE_2, LOCALE_TAG, MANIFEST_URL_2);
+ when(modelDownloader.downloadManifest(MANIFEST_URL))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+ when(modelDownloader.downloadManifest(MANIFEST_URL_2))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO)); // same proto as MANIFEST_URL
+ modelFile.createNewFile();
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
+ .thenReturn(Futures.immediateFuture(modelFile));
+ // Instants are returned in call order; assumes MODEL_TYPE downloads first, o/w this is flaky.
+ when(clock.instant())
+ .thenReturn(
+ WORK_STARTED_TIME,
+ DOWNLOAD_STARTED_TIME,
+ DOWNLOAD_ENDED_TIME,
+ DOWNLOAD_STARTED_TIME_2,
+ DOWNLOAD_ENDED_TIME_2,
+ WORK_ENDED_TIME);
+
+ ModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+ assertThat(modelFile.exists()).isTrue();
+ assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
+ assertThat(downloadedModelManager.listModels(MODEL_TYPE_2)).containsExactly(modelFile);
+ verify(modelDownloader, times(1)).downloadModel(modelDownloaderDir, MODEL_PROTO);
+ List<TextClassifierDownloadReported> atoms = loggerTestRule.getLoggedDownloadReportedAtoms();
+ assertThat(atoms).hasSize(2); // still one download atom per manifest URL
+ assertThat(
+ atoms.stream()
+ .map(TextClassifierDownloadReported::getUrlSuffix)
+ .collect(Collectors.toList()))
+ .containsExactly(MANIFEST_URL, MANIFEST_URL_2);
+ assertThat(atoms.get(0).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+ assertThat(atoms.get(1).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+ assertThat(
+ atoms.stream()
+ .map(TextClassifierDownloadReported::getDownloadDurationMillis)
+ .collect(Collectors.toList()))
+ .containsExactly(DOWNLOAD_STARTED_TO_ENDED_MILLIS, DOWNLOAD_STARTED_TO_ENDED_2_MILLIS);
+ verifyWorkLogging(RUN_ATTEMPT_COUNT, WorkResult.SUCCESS_MODEL_DOWNLOADED);
+ }
+
+ @Test // Two model types are configured with one manifest URL; the manifest is fetched once.
+ public void downloadSucceed_shareManifest() throws Exception {
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ setUpManifestUrl(MODEL_TYPE_2, LOCALE_TAG, MANIFEST_URL);
+ when(modelDownloader.downloadManifest(MANIFEST_URL))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+ modelFile.createNewFile();
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
+ .thenReturn(Futures.immediateFuture(modelFile));
+ // Instants are returned in call order; assumes MODEL_TYPE downloads first, o/w this is flaky.
+ when(clock.instant())
+ .thenReturn(
+ WORK_STARTED_TIME,
+ DOWNLOAD_STARTED_TIME,
+ DOWNLOAD_ENDED_TIME,
+ DOWNLOAD_STARTED_TIME_2,
+ DOWNLOAD_ENDED_TIME_2,
+ WORK_ENDED_TIME);
+
+ ModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+ assertThat(modelFile.exists()).isTrue();
+ assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
+ assertThat(downloadedModelManager.listModels(MODEL_TYPE_2)).containsExactly(modelFile);
+ verify(modelDownloader, times(1)).downloadManifest(MANIFEST_URL);
+ List<TextClassifierDownloadReported> atoms = loggerTestRule.getLoggedDownloadReportedAtoms();
+ assertThat(atoms).hasSize(2); // one atom per model type, both carrying the shared URL
+ assertThat(
+ atoms.stream()
+ .map(TextClassifierDownloadReported::getUrlSuffix)
+ .collect(Collectors.toList()))
+ .containsExactly(MANIFEST_URL, MANIFEST_URL);
+ assertThat(atoms.get(0).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+ assertThat(atoms.get(1).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+ assertThat(
+ atoms.stream()
+ .map(TextClassifierDownloadReported::getDownloadDurationMillis)
+ .collect(Collectors.toList()))
+ .containsExactly(DOWNLOAD_STARTED_TO_ENDED_MILLIS, DOWNLOAD_STARTED_TO_ENDED_2_MILLIS);
+ verifyWorkLogging(RUN_ATTEMPT_COUNT, WorkResult.SUCCESS_MODEL_DOWNLOADED);
+ }
+
+ @Test // A failed manifest download makes the worker return retry() and log one failed atom.
+ public void downloadFailed_failedToDownloadManifest() throws Exception {
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ when(modelDownloader.downloadManifest(MANIFEST_URL))
+ .thenReturn(Futures.immediateFailedFuture(FAILED_TO_DOWNLOAD_EXCEPTION));
+ when(clock.instant())
+ .thenReturn(WORK_STARTED_TIME, DOWNLOAD_STARTED_TIME, DOWNLOAD_ENDED_TIME, WORK_ENDED_TIME);
+
+ ModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.retry());
+ verifyFailedDownloadLogging();
+ verifyWorkLogging(RUN_ATTEMPT_COUNT, WorkResult.RETRY_MODEL_DOWNLOAD_FAILED);
+ }
+
+ @Test // The manifest downloads but the model download fails: the worker returns retry().
+ public void downloadFailed_failedToDownloadModel() throws Exception {
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ when(modelDownloader.downloadManifest(MANIFEST_URL))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
+ .thenReturn(Futures.immediateFailedFuture(FAILED_TO_DOWNLOAD_EXCEPTION));
+ when(clock.instant())
+ .thenReturn(WORK_STARTED_TIME, DOWNLOAD_STARTED_TIME, DOWNLOAD_ENDED_TIME, WORK_ENDED_TIME);
+
+ ModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.retry());
+ verifyFailedDownloadLogging();
+ verifyWorkLogging(RUN_ATTEMPT_COUNT, WorkResult.RETRY_MODEL_DOWNLOAD_FAILED);
+ }
+
+ @Test // With the download manager flag off, the worker fails without logging any download.
+ public void downloadFailed_modelDownloadManagerDisabled() throws Exception {
+ deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, false);
+ when(clock.instant()).thenReturn(WORK_STARTED_TIME, WORK_ENDED_TIME); // no download instants
+
+ ModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.failure());
+ assertThat(loggerTestRule.getLoggedDownloadReportedAtoms()).isEmpty();
+ verifyWorkLogging(RUN_ATTEMPT_COUNT, WorkResult.FAILURE_MODEL_DOWNLOADER_DISABLED);
+ }
+
+ @Test // A worker created at WORKER_MAX_RUN_ATTEMPT_COUNT fails without logging any download.
+ public void downloadFailed_reachWorkerMaxRunAttempts() throws Exception {
+ when(clock.instant()).thenReturn(WORK_STARTED_TIME, WORK_ENDED_TIME); // no download instants
+
+ ModelDownloadWorker worker = createWorker(WORKER_MAX_RUN_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.failure());
+ assertThat(loggerTestRule.getLoggedDownloadReportedAtoms()).isEmpty();
+ verifyWorkLogging(WORKER_MAX_RUN_ATTEMPT_COUNT, WorkResult.FAILURE_MAX_RUN_ATTEMPT_REACHED);
+ }
+
+ @Test // A manifest that already failed the maximum number of times is skipped, not retried.
+ public void downloadSkipped_reachManifestMaxAttempts() throws Exception {
+ when(clock.instant()).thenReturn(WORK_STARTED_TIME, WORK_ENDED_TIME);
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ deviceConfig.setConfig(
+ TextClassifierSettings.MANIFEST_DOWNLOAD_MAX_ATTEMPTS, MANIFEST_MAX_ATTEMPT_COUNT);
+
+ for (int i = 0; i < MANIFEST_MAX_ATTEMPT_COUNT; i++) { // record the maximum number of failures
+ downloadedModelManager.registerManifestDownloadFailure(MANIFEST_URL);
+ }
+ ModelDownloadWorker worker = createWorker(MANIFEST_MAX_ATTEMPT_COUNT);
+ // NOTE(review): the manifest limit doubles as the run attempt count above; confirm intent.
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+ assertThat(loggerTestRule.getLoggedDownloadReportedAtoms()).isEmpty();
+ verifyWorkLogging(MANIFEST_MAX_ATTEMPT_COUNT, WorkResult.SUCCESS_NO_UPDATE_AVAILABLE);
+ }
+
+ @Test // A manifest that is already registered and enrolled triggers no new download.
+ public void downloadSkipped_manifestAlreadyProcessed() throws Exception {
+ when(clock.instant()).thenReturn(WORK_STARTED_TIME, WORK_ENDED_TIME);
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ modelFile.createNewFile();
+ downloadedModelManager.registerModel(MODEL_URL, modelFile.getAbsolutePath());
+ downloadedModelManager.registerManifest(MANIFEST_URL, MODEL_URL);
+ downloadedModelManager.registerManifestEnrollment(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+
+ ModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+ assertThat(loggerTestRule.getLoggedDownloadReportedAtoms()).isEmpty();
+ verifyWorkLogging(RUN_ATTEMPT_COUNT, WorkResult.SUCCESS_NO_UPDATE_AVAILABLE);
+ }
+
+ @Test // Multi-language on, two locales: one model type ends up with a model per locale.
+ public void downloadSucceeded_multiLanguageSupportEnabled() throws Exception {
+ setDefaultLocalesRule.set(LOCALE_LIST_2);
+ deviceConfig.setConfig(TextClassifierSettings.MULTI_LANGUAGE_SUPPORT_ENABLED, true);
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG_2, MANIFEST_URL_2);
+
+ when(modelDownloader.downloadManifest(MANIFEST_URL))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+ when(modelDownloader.downloadManifest(MANIFEST_URL_2))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO_2));
+
+ modelFile.createNewFile();
+ modelFile2.createNewFile();
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
+ .thenReturn(Futures.immediateFuture(modelFile));
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO_2))
+ .thenReturn(Futures.immediateFuture(modelFile2));
+ // Instants are returned in call order; assumes LOCALE_TAG downloads first, o/w this is flaky.
+ when(clock.instant())
+ .thenReturn(
+ WORK_STARTED_TIME,
+ DOWNLOAD_STARTED_TIME,
+ DOWNLOAD_ENDED_TIME,
+ DOWNLOAD_STARTED_TIME_2,
+ DOWNLOAD_ENDED_TIME_2,
+ WORK_ENDED_TIME);
+
+ ModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+ assertThat(modelFile.exists()).isTrue();
+ assertThat(modelFile2.exists()).isTrue();
+
+ assertThat(downloadedModelManager.listModels(MODEL_TYPE))
+ .containsExactly(modelFile, modelFile2);
+ List<TextClassifierDownloadReported> atoms = loggerTestRule.getLoggedDownloadReportedAtoms();
+ assertThat(atoms).hasSize(2); // one download atom per locale's manifest
+ assertThat(
+ atoms.stream()
+ .map(TextClassifierDownloadReported::getUrlSuffix)
+ .collect(Collectors.toList()))
+ .containsExactly(MANIFEST_URL, MANIFEST_URL_2);
+ assertThat(atoms.get(0).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+ assertThat(atoms.get(1).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+ assertThat(
+ atoms.stream()
+ .map(TextClassifierDownloadReported::getDownloadDurationMillis)
+ .collect(Collectors.toList()))
+ .containsExactly(DOWNLOAD_STARTED_TO_ENDED_MILLIS, DOWNLOAD_STARTED_TO_ENDED_2_MILLIS);
+ verifyWorkLogging(RUN_ATTEMPT_COUNT, WorkResult.SUCCESS_MODEL_DOWNLOADED);
+ }
+
+ @Test // Multi-language on with the default locale list: exactly one model is downloaded.
+ public void downloadSucceeded_multiLanguageSupportEnabled_singleLocale() throws Exception {
+ setDefaultLocalesRule.set(DEFAULT_LOCALE_LIST);
+ deviceConfig.setConfig(TextClassifierSettings.MULTI_LANGUAGE_SUPPORT_ENABLED, true);
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+
+ when(modelDownloader.downloadManifest(MANIFEST_URL))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+
+ modelFile.createNewFile();
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
+ .thenReturn(Futures.immediateFuture(modelFile));
+ when(clock.instant())
+ .thenReturn(WORK_STARTED_TIME, DOWNLOAD_STARTED_TIME, DOWNLOAD_ENDED_TIME, WORK_ENDED_TIME);
+
+ ModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+ assertThat(modelFile.exists()).isTrue();
+ assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
+ verifySucceededDownloadLogging();
+ verifyWorkLogging(RUN_ATTEMPT_COUNT, WorkResult.SUCCESS_MODEL_DOWNLOADED);
+ }
+
+ @Test // Two locales, second manifest pre-registered: only the first is downloaded.
+ public void downloadSucceeded_multiLanguageSupportEnabled_oneManifestAlreadyDownloaded()
+ throws Exception {
+ setDefaultLocalesRule.set(LOCALE_LIST_2);
+ deviceConfig.setConfig(TextClassifierSettings.MULTI_LANGUAGE_SUPPORT_ENABLED, true);
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG_2, MANIFEST_URL_2);
+
+ modelFile.createNewFile();
+ when(modelDownloader.downloadManifest(MANIFEST_URL))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
+ .thenReturn(Futures.immediateFuture(modelFile));
+ // Register the LOCALE_TAG_2 manifest and model up front; MANIFEST_URL_2 is never stubbed.
+ modelFile2.createNewFile();
+ downloadedModelManager.registerModel(MODEL_URL_2, modelFile2.getAbsolutePath());
+ downloadedModelManager.registerManifest(MANIFEST_URL_2, MODEL_URL_2);
+ downloadedModelManager.registerManifestEnrollment(MODEL_TYPE, LOCALE_TAG_2, MANIFEST_URL_2);
+
+ when(clock.instant())
+ .thenReturn(WORK_STARTED_TIME, DOWNLOAD_STARTED_TIME, DOWNLOAD_ENDED_TIME, WORK_ENDED_TIME);
+
+ ModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+ assertThat(modelFile.exists()).isTrue();
+ assertThat(modelFile2.exists()).isTrue();
+ assertThat(downloadedModelManager.listModels(MODEL_TYPE))
+ .containsExactly(modelFile, modelFile2);
+ verifySucceededDownloadLogging();
+ verifyWorkLogging(RUN_ATTEMPT_COUNT, WorkResult.SUCCESS_MODEL_DOWNLOADED);
+ }
+
+ @Test // Multi-language off: two locales are configured but only modelFile is expected.
+ public void downloadSucceeded_multiLanguageSupportDisabled() throws Exception {
+ setDefaultLocalesRule.set(LOCALE_LIST_2);
+ deviceConfig.setConfig(TextClassifierSettings.MULTI_LANGUAGE_SUPPORT_ENABLED, false);
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG_2, MANIFEST_URL_2);
+
+ when(modelDownloader.downloadManifest(MANIFEST_URL))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+ when(modelDownloader.downloadManifest(MANIFEST_URL_2))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO_2));
+
+ modelFile.createNewFile();
+ modelFile2.createNewFile();
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
+ .thenReturn(Futures.immediateFuture(modelFile));
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO_2))
+ .thenReturn(Futures.immediateFuture(modelFile2));
+ when(clock.instant())
+ .thenReturn(WORK_STARTED_TIME, DOWNLOAD_STARTED_TIME, DOWNLOAD_ENDED_TIME, WORK_ENDED_TIME);
+
+ ModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+ assertThat(modelFile.exists()).isTrue();
+ assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
+ verifySucceededDownloadLogging();
+ verifyWorkLogging(RUN_ATTEMPT_COUNT, WorkResult.SUCCESS_MODEL_DOWNLOADED);
+ }
+
+ @Test // Two locales, second manifest fails: the first model is kept and the worker retries.
+ public void oneDownloadFailed_multiLanguageSupportEnabled() throws Exception {
+ setDefaultLocalesRule.set(LOCALE_LIST_2);
+ deviceConfig.setConfig(TextClassifierSettings.MULTI_LANGUAGE_SUPPORT_ENABLED, true);
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG_2, MANIFEST_URL_2);
+
+ when(modelDownloader.downloadManifest(MANIFEST_URL))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+ when(modelDownloader.downloadManifest(MANIFEST_URL_2))
+ .thenReturn(Futures.immediateFailedFuture(FAILED_TO_DOWNLOAD_EXCEPTION));
+
+ modelFile.createNewFile();
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
+ .thenReturn(Futures.immediateFuture(modelFile));
+ // Instants are returned in call order; assumes LOCALE_TAG downloads first, o/w this is flaky.
+ when(clock.instant())
+ .thenReturn(
+ WORK_STARTED_TIME,
+ DOWNLOAD_STARTED_TIME,
+ DOWNLOAD_ENDED_TIME,
+ DOWNLOAD_STARTED_TIME_2,
+ DOWNLOAD_ENDED_TIME_2,
+ WORK_ENDED_TIME);
+
+ ModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.retry());
+ assertThat(modelFile.exists()).isTrue();
+ assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
+ List<TextClassifierDownloadReported> atoms = loggerTestRule.getLoggedDownloadReportedAtoms();
+ assertThat(atoms).hasSize(2); // the failed download is logged as well
+ assertThat(
+ atoms.stream()
+ .map(TextClassifierDownloadReported::getUrlSuffix)
+ .collect(Collectors.toList()))
+ .containsExactly(MANIFEST_URL, MANIFEST_URL_2);
+ assertThat(atoms.get(0).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+ assertThat(atoms.get(1).getDownloadStatus()).isEqualTo(DownloadStatus.FAILED_AND_RETRY);
+ assertThat(
+ atoms.stream()
+ .map(TextClassifierDownloadReported::getDownloadDurationMillis)
+ .collect(Collectors.toList()))
+ .containsExactly(DOWNLOAD_STARTED_TO_ENDED_MILLIS, DOWNLOAD_STARTED_TO_ENDED_2_MILLIS);
+ verifyWorkLogging(RUN_ATTEMPT_COUNT, WorkResult.RETRY_MODEL_DOWNLOAD_FAILED);
+ }
+
+ @Test // Three locales are configured, but only the first two models are expected.
+ public void downloadSucceeded_multiLanguageSupportEnabled_checkLimit() throws Exception {
+ setDefaultLocalesRule.set(LOCALE_LIST_3);
+ deviceConfig.setConfig(TextClassifierSettings.MULTI_LANGUAGE_SUPPORT_ENABLED, true);
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG_2, MANIFEST_URL_2);
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG_3, MANIFEST_URL_3);
+
+ when(modelDownloader.downloadManifest(MANIFEST_URL))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+ when(modelDownloader.downloadManifest(MANIFEST_URL_2))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO_2));
+ when(modelDownloader.downloadManifest(MANIFEST_URL_3))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO_3));
+
+ modelFile.createNewFile();
+ modelFile2.createNewFile();
+ modelFile3.createNewFile();
+
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
+ .thenReturn(Futures.immediateFuture(modelFile));
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO_2))
+ .thenReturn(Futures.immediateFuture(modelFile2));
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO_3))
+ .thenReturn(Futures.immediateFuture(modelFile3));
+ // Instants are returned in call order; assumes LOCALE_TAG downloads first, o/w this is flaky.
+ when(clock.instant())
+ .thenReturn(
+ WORK_STARTED_TIME,
+ DOWNLOAD_STARTED_TIME,
+ DOWNLOAD_ENDED_TIME,
+ DOWNLOAD_STARTED_TIME_2,
+ DOWNLOAD_ENDED_TIME_2,
+ WORK_ENDED_TIME);
+
+ ModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+ assertThat(modelFile.exists()).isTrue();
+ assertThat(modelFile2.exists()).isTrue();
+ assertThat(downloadedModelManager.listModels(MODEL_TYPE))
+ .containsExactly(modelFile, modelFile2);
+ List<TextClassifierDownloadReported> atoms = loggerTestRule.getLoggedDownloadReportedAtoms();
+ assertThat(atoms).hasSize(2); // presumably a limit of 2 set outside this test; TODO confirm
+ assertThat(
+ atoms.stream()
+ .map(TextClassifierDownloadReported::getUrlSuffix)
+ .collect(Collectors.toList()))
+ .containsExactly(MANIFEST_URL, MANIFEST_URL_2);
+ assertThat(atoms.get(0).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+ assertThat(atoms.get(1).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+ assertThat(
+ atoms.stream()
+ .map(TextClassifierDownloadReported::getDownloadDurationMillis)
+ .collect(Collectors.toList()))
+ .containsExactly(DOWNLOAD_STARTED_TO_ENDED_MILLIS, DOWNLOAD_STARTED_TO_ENDED_2_MILLIS);
+ verifyWorkLogging(RUN_ATTEMPT_COUNT, WorkResult.SUCCESS_MODEL_DOWNLOADED);
+ }
+
+ @Test // Multi-language is enabled for MODEL_TYPE_2 only, so MODEL_TYPE gets a single model.
+ public void downloadSucceed_multiLanguageSupportEnabled_onlyDownloadMultipleForEnabledModelType()
+ throws Exception {
+ setDefaultLocalesRule.set(LOCALE_LIST_2);
+ deviceConfig.setConfig(TextClassifierSettings.MULTI_LANGUAGE_SUPPORT_ENABLED, true);
+ deviceConfig.setConfig(
+ TextClassifierSettings.ENABLED_MODEL_TYPES_FOR_MULTI_LANGUAGE_SUPPORT, MODEL_TYPE_2);
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG_2, MANIFEST_URL_2);
+
+ when(modelDownloader.downloadManifest(MANIFEST_URL))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+ when(modelDownloader.downloadManifest(MANIFEST_URL_2))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO_2));
+
+ modelFile.createNewFile();
+ modelFile2.createNewFile();
+
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
+ .thenReturn(Futures.immediateFuture(modelFile));
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO_2))
+ .thenReturn(Futures.immediateFuture(modelFile2));
+ when(clock.instant())
+ .thenReturn(WORK_STARTED_TIME, DOWNLOAD_STARTED_TIME, DOWNLOAD_ENDED_TIME, WORK_ENDED_TIME);
+
+ ModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+ assertThat(modelFile.exists()).isTrue();
+ assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
+ verifySucceededDownloadLogging();
+ verifyWorkLogging(RUN_ATTEMPT_COUNT, WorkResult.SUCCESS_MODEL_DOWNLOADED);
+ }
+
+ private ModelDownloadWorker createWorker(int runAttemptCount) { // worker wired to test fields
+ return TestListenableWorkerBuilder.from(
+ ApplicationProvider.getApplicationContext(), ModelDownloadWorker.class)
+ .setRunAttemptCount(runAttemptCount)
+ .setWorkerFactory(
+ new WorkerFactory() { // custom factory so the test collaborators can be injected
+ @Override
+ public ListenableWorker createWorker(
+ Context appContext, String workerClassName, WorkerParameters workerParameters) {
+ return new ModelDownloadWorker(
+ appContext,
+ workerParameters,
+ MoreExecutors.newDirectExecutorService(),
+ modelDownloader,
+ downloadedModelManager,
+ settings,
+ WORK_ID,
+ clock,
+ WORK_SCHEDULED_TIME.toEpochMilli());
+ }
+ })
+ .build();
+ }
+
+ private void verifySucceededDownloadLogging() throws Exception { // exactly one SUCCEEDED atom
+ TextClassifierDownloadReported atom =
+ Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadReportedAtoms());
+ assertThat(atom.getWorkId()).isEqualTo(WORK_ID);
+ assertThat(atom.getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+ assertThat(atom.getModelType()).isEqualTo(MODEL_TYPE_ATOM);
+ assertThat(atom.getUrlSuffix()).isEqualTo(MANIFEST_URL);
+ assertThat(atom.getRunAttemptCount()).isEqualTo(RUN_ATTEMPT_COUNT);
+ assertThat(atom.getFailureReason()).isEqualTo(FailureReason.UNKNOWN_FAILURE_REASON);
+ assertThat(atom.getDownloadDurationMillis()).isEqualTo(DOWNLOAD_STARTED_TO_ENDED_MILLIS);
+ }
+
+ private void verifyFailedDownloadLogging() throws Exception { // exactly one FAILED_AND_RETRY atom
+ TextClassifierDownloadReported atom =
+ Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadReportedAtoms());
+ assertThat(atom.getWorkId()).isEqualTo(WORK_ID);
+ assertThat(atom.getDownloadStatus()).isEqualTo(DownloadStatus.FAILED_AND_RETRY);
+ assertThat(atom.getModelType()).isEqualTo(MODEL_TYPE_ATOM);
+ assertThat(atom.getUrlSuffix()).isEqualTo(MANIFEST_URL);
+ assertThat(atom.getRunAttemptCount()).isEqualTo(RUN_ATTEMPT_COUNT);
+ assertThat(atom.getFailureReason()).isEqualTo(FAILED_TO_DOWNLOAD_FAILURE_REASON);
+ assertThat(atom.getDownloaderLibFailureCode())
+ .isEqualTo(ModelDownloadException.DEFAULT_DOWNLOADER_LIB_ERROR_CODE);
+ assertThat(atom.getDownloadDurationMillis()).isEqualTo(DOWNLOAD_STARTED_TO_ENDED_MILLIS);
+ }
+
+ private void verifyWorkLogging(int runTimeAttempt, WorkResult workResult) throws Exception {
+ TextClassifierDownloadWorkCompleted atom = // exactly one work-completed atom is expected
+ Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkCompletedAtoms());
+ assertThat(atom.getWorkId()).isEqualTo(WORK_ID);
+ assertThat(atom.getWorkResult()).isEqualTo(workResult);
+ assertThat(atom.getRunAttemptCount()).isEqualTo(runTimeAttempt);
+ assertThat(atom.getWorkScheduledToStartedDurationMillis())
+ .isEqualTo(WORK_SCHEDULED_TO_STARTED_MILLIS);
+ assertThat(atom.getWorkStartedToEndedDurationMillis()).isEqualTo(WORK_STARTED_TO_ENDED_MILLIS);
+ }
+
+ private void setUpManifestUrl(
+ @ModelType.ModelTypeDef String modelType, String localeTag, String url) {
+ // Format the per-(model type, locale) manifest flag name and point that flag at the URL.
+ deviceConfig.setConfig(
+ String.format(TextClassifierSettings.MANIFEST_URL_TEMPLATE, modelType, localeTag), url);
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderImplTest.java
new file mode 100644
index 0000000..0818057
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderImplTest.java
@@ -0,0 +1,278 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static com.google.common.truth.Truth.assertThat;
+import static java.util.concurrent.TimeUnit.SECONDS;
+import static org.testng.Assert.expectThrows;
+
+import android.content.Context;
+import androidx.test.core.app.ApplicationProvider;
+import com.android.textclassifier.downloader.TestModelDownloaderService.DownloadResult;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.MoreExecutors;
+import java.io.File;
+import java.nio.file.Files;
+import java.util.concurrent.CancellationException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public final class ModelDownloaderImplTest {
+ private static final String MANIFEST_URL = "https://manifest.url";
+ private static final String MODEL_URL = "https://model.url";
+ private static final byte[] MODEL_CONTENT_BYTES = "content".getBytes();
+ private static final long MODEL_SIZE_IN_BYTES = 7L; // length of "content"
+ private static final String MODEL_FINGERPRINT = // 96 hex chars; presumably SHA-384, TODO confirm
+ "5406ebea1618e9b73a7290c5d716f0b47b4f1fbc5d8c"
+ + "5e78c9010a3e01c18d8594aa942e3536f7e01574245d34647523";
+ private static final ModelManifest.Model MODEL_PROTO =
+ ModelManifest.Model.newBuilder()
+ .setUrl(MODEL_URL)
+ .setSizeInBytes(MODEL_SIZE_IN_BYTES)
+ .setFingerprint(MODEL_FINGERPRINT)
+ .build();
+ private static final ModelManifest MODEL_MANIFEST_PROTO = // manifest listing only MODEL_PROTO
+ ModelManifest.newBuilder().addModels(MODEL_PROTO).build();
+
+ private ModelDownloaderImpl modelDownloaderImpl; // object under test, rebuilt in setUp()
+ private File modelDownloaderDir; // created in setUp(), deleted in tearDown()
+
+ @Before // Builds the downloader against the test service and creates the download directory.
+ public void setUp() {
+ Context context = ApplicationProvider.getApplicationContext();
+ this.modelDownloaderImpl =
+ new ModelDownloaderImpl(
+ context, MoreExecutors.newDirectExecutorService(), TestModelDownloaderService.class);
+ this.modelDownloaderDir = new File(context.getFilesDir(), "downloader");
+ this.modelDownloaderDir.mkdirs();
+ // Reset the test service's static state; each test first asserts it was never bound.
+ TestModelDownloaderService.reset();
+ }
+
+ @After // Deletes the download directory that setUp() created.
+ public void tearDown() {
+ DownloaderTestUtils.deleteRecursively(modelDownloaderDir);
+ }
+
+ @Test // A bind failure fails the future with FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN.
+ public void downloadManifest_failToBind() throws Exception {
+ assertThat(TestModelDownloaderService.hasEverBeenBound()).isFalse();
+ assertThat(TestModelDownloaderService.isBound()).isFalse();
+
+ TestModelDownloaderService.setBindSucceed(false);
+ ListenableFuture<ModelManifest> manifestFuture =
+ modelDownloaderImpl.downloadManifest(MANIFEST_URL);
+
+ Throwable t = expectThrows(Throwable.class, manifestFuture::get);
+ assertThat(t).hasCauseThat().isInstanceOf(ModelDownloadException.class);
+ ModelDownloadException e = (ModelDownloadException) t.getCause();
+ assertThat(e.getErrorCode())
+ .isEqualTo(ModelDownloadException.FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN);
+ assertThat(e.getDownloaderLibErrorCode())
+ .isEqualTo(ModelDownloadException.DEFAULT_DOWNLOADER_LIB_ERROR_CODE);
+ assertThat(TestModelDownloaderService.isBound()).isFalse();
+ assertThat(TestModelDownloaderService.hasEverBeenBound()).isFalse(); // never bound at all
+ }
+
+ @Test // Manifest bytes from the test service come back as the proto; the service then unbinds.
+ public void downloadManifest_succeed() throws Exception {
+ assertThat(TestModelDownloaderService.hasEverBeenBound()).isFalse();
+ assertThat(TestModelDownloaderService.isBound()).isFalse();
+
+ TestModelDownloaderService.setBindSucceed(true);
+ TestModelDownloaderService.setDownloadResult(
+ MANIFEST_URL, DownloadResult.SUCCEEDED, MODEL_MANIFEST_PROTO.toByteArray());
+ ListenableFuture<ModelManifest> manifestFuture =
+ modelDownloaderImpl.downloadManifest(MANIFEST_URL);
+
+ assertThat(manifestFuture.get()).isEqualTo(MODEL_MANIFEST_PROTO); // ProtoTruth is not available
+ assertThat(TestModelDownloaderService.getOnUnbindInvokedLatch().await(1L, SECONDS)).isTrue();
+ assertThat(TestModelDownloaderService.isBound()).isFalse();
+ assertThat(TestModelDownloaderService.hasEverBeenBound()).isTrue();
+ }
+
+ @Test // A FAILED result from the service fails the future with the service's code and message.
+ public void downloadManifest_failToDownload() throws Exception {
+ assertThat(TestModelDownloaderService.hasEverBeenBound()).isFalse();
+ assertThat(TestModelDownloaderService.isBound()).isFalse();
+
+ TestModelDownloaderService.setBindSucceed(true);
+ TestModelDownloaderService.setDownloadResult(MANIFEST_URL, DownloadResult.FAILED, null);
+ ListenableFuture<ModelManifest> manifestFuture =
+ modelDownloaderImpl.downloadManifest(MANIFEST_URL);
+
+ Throwable t = expectThrows(Throwable.class, manifestFuture::get);
+ assertThat(t).hasCauseThat().isInstanceOf(ModelDownloadException.class);
+ ModelDownloadException e = (ModelDownloadException) t.getCause();
+ assertThat(e.getErrorCode()).isEqualTo(ModelDownloadException.FAILED_TO_DOWNLOAD_OTHER);
+ assertThat(e.getDownloaderLibErrorCode())
+ .isEqualTo(TestModelDownloaderService.DOWNLOADER_LIB_ERROR_CODE);
+ assertThat(e).hasMessageThat().contains(TestModelDownloaderService.ERROR_MSG);
+ assertThat(TestModelDownloaderService.getOnUnbindInvokedLatch().await(1L, SECONDS)).isTrue();
+ assertThat(TestModelDownloaderService.isBound()).isFalse();
+ assertThat(TestModelDownloaderService.hasEverBeenBound()).isTrue();
+ }
+
+ @Test // The download succeeds but the bytes do not parse: FAILED_TO_PARSE_MANIFEST is expected.
+ public void downloadManifest_failToParse() throws Exception {
+ assertThat(TestModelDownloaderService.hasEverBeenBound()).isFalse();
+ assertThat(TestModelDownloaderService.isBound()).isFalse();
+
+ TestModelDownloaderService.setBindSucceed(true);
+ TestModelDownloaderService.setDownloadResult(
+ MANIFEST_URL, DownloadResult.SUCCEEDED, "randomString".getBytes()); // expected to fail parsing
+ ListenableFuture<ModelManifest> manifestFuture =
+ modelDownloaderImpl.downloadManifest(MANIFEST_URL);
+
+ Throwable t = expectThrows(Throwable.class, manifestFuture::get);
+ assertThat(t).hasCauseThat().isInstanceOf(ModelDownloadException.class);
+ ModelDownloadException e = (ModelDownloadException) t.getCause();
+ assertThat(e.getErrorCode()).isEqualTo(ModelDownloadException.FAILED_TO_PARSE_MANIFEST);
+ assertThat(e.getDownloaderLibErrorCode())
+ .isEqualTo(ModelDownloadException.DEFAULT_DOWNLOADER_LIB_ERROR_CODE);
+ assertThat(TestModelDownloaderService.getOnUnbindInvokedLatch().await(1L, SECONDS)).isTrue();
+ assertThat(TestModelDownloaderService.isBound()).isFalse();
+ assertThat(TestModelDownloaderService.hasEverBeenBound()).isTrue();
+ }
+
+ @Test // Cancelling the future while the service is bound must unbind the service.
+ public void downloadManifest_cancelAndUnbind() throws Exception {
+ assertThat(TestModelDownloaderService.hasEverBeenBound()).isFalse();
+ assertThat(TestModelDownloaderService.isBound()).isFalse();
+
+ TestModelDownloaderService.setBindSucceed(true);
+ TestModelDownloaderService.setDownloadResult(MANIFEST_URL, DownloadResult.DO_NOTHING, null);
+ ListenableFuture<ModelManifest> manifestFuture = // presumably stays pending under DO_NOTHING
+ modelDownloaderImpl.downloadManifest(MANIFEST_URL);
+
+ assertThat(TestModelDownloaderService.getOnBindInvokedLatch().await(1L, SECONDS)).isTrue();
+ assertThat(TestModelDownloaderService.isBound()).isTrue();
+ manifestFuture.cancel(true);
+
+ expectThrows(CancellationException.class, manifestFuture::get);
+ assertThat(TestModelDownloaderService.getOnUnbindInvokedLatch().await(1L, SECONDS)).isTrue();
+ assertThat(TestModelDownloaderService.isBound()).isFalse();
+ assertThat(TestModelDownloaderService.hasEverBeenBound()).isTrue();
+ }
+
+ @Test
+ public void downloadModel_failToBind() throws Exception {
+ assertThat(TestModelDownloaderService.hasEverBeenBound()).isFalse();
+ assertThat(TestModelDownloaderService.isBound()).isFalse();
+ // With bindSucceed == false the fake service's onBind() returns null instead of a binder.
+ TestModelDownloaderService.setBindSucceed(false);
+ ListenableFuture<File> modelFuture =
+ modelDownloaderImpl.downloadModel(modelDownloaderDir, MODEL_PROTO);
+ // Expect a broken-connection error; the service must never have counted as bound.
+ Throwable t = expectThrows(Throwable.class, modelFuture::get);
+ assertThat(t).hasCauseThat().isInstanceOf(ModelDownloadException.class);
+ ModelDownloadException e = (ModelDownloadException) t.getCause();
+ assertThat(e.getErrorCode())
+ .isEqualTo(ModelDownloadException.FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN);
+ assertThat(e.getDownloaderLibErrorCode())
+ .isEqualTo(ModelDownloadException.DEFAULT_DOWNLOADER_LIB_ERROR_CODE);
+ assertThat(TestModelDownloaderService.isBound()).isFalse();
+ assertThat(TestModelDownloaderService.hasEverBeenBound()).isFalse();
+ }
+
+ @Test
+ public void downloadModel_succeed() throws Exception {
+ assertThat(TestModelDownloaderService.hasEverBeenBound()).isFalse();
+ assertThat(TestModelDownloaderService.isBound()).isFalse();
+ // The fake service writes MODEL_CONTENT_BYTES to the target path and reports success.
+ TestModelDownloaderService.setBindSucceed(true);
+ TestModelDownloaderService.setDownloadResult(
+ MODEL_URL, DownloadResult.SUCCEEDED, MODEL_CONTENT_BYTES);
+ ListenableFuture<File> modelFuture =
+ modelDownloaderImpl.downloadModel(modelDownloaderDir, MODEL_PROTO);
+ // The returned file must live in modelDownloaderDir and hold exactly the served bytes.
+ File modelFile = modelFuture.get();
+ assertThat(modelFile.getParentFile()).isEqualTo(modelDownloaderDir);
+ assertThat(Files.readAllBytes(modelFile.toPath())).isEqualTo(MODEL_CONTENT_BYTES);
+ assertThat(TestModelDownloaderService.getOnUnbindInvokedLatch().await(1L, SECONDS)).isTrue();
+ assertThat(TestModelDownloaderService.isBound()).isFalse();
+ assertThat(TestModelDownloaderService.hasEverBeenBound()).isTrue();
+ }
+
+ @Test
+ public void downloadModel_failToDownload() throws Exception {
+ assertThat(TestModelDownloaderService.hasEverBeenBound()).isFalse();
+ assertThat(TestModelDownloaderService.isBound()).isFalse();
+ // FAILED makes the fake service call onFailure(DOWNLOADER_LIB_ERROR_CODE, ERROR_MSG).
+ TestModelDownloaderService.setBindSucceed(true);
+ TestModelDownloaderService.setDownloadResult(MODEL_URL, DownloadResult.FAILED, null);
+ ListenableFuture<File> modelFuture =
+ modelDownloaderImpl.downloadModel(modelDownloaderDir, MODEL_PROTO);
+ // The service-reported code and message must surface inside the ModelDownloadException.
+ Throwable t = expectThrows(Throwable.class, modelFuture::get);
+ assertThat(t).hasCauseThat().isInstanceOf(ModelDownloadException.class);
+ ModelDownloadException e = (ModelDownloadException) t.getCause();
+ assertThat(e.getErrorCode()).isEqualTo(ModelDownloadException.FAILED_TO_DOWNLOAD_OTHER);
+ assertThat(e.getDownloaderLibErrorCode())
+ .isEqualTo(TestModelDownloaderService.DOWNLOADER_LIB_ERROR_CODE);
+ assertThat(e).hasMessageThat().contains(TestModelDownloaderService.ERROR_MSG);
+ assertThat(TestModelDownloaderService.getOnUnbindInvokedLatch().await(1L, SECONDS)).isTrue();
+ assertThat(TestModelDownloaderService.isBound()).isFalse();
+ assertThat(TestModelDownloaderService.hasEverBeenBound()).isTrue();
+ }
+
+ @Test
+ public void downloadModel_failToValidate() throws Exception {
+ assertThat(TestModelDownloaderService.hasEverBeenBound()).isFalse();
+ assertThat(TestModelDownloaderService.isBound()).isFalse();
+ // The download "succeeds" but delivers arbitrary text instead of MODEL_CONTENT_BYTES.
+ TestModelDownloaderService.setBindSucceed(true);
+ TestModelDownloaderService.setDownloadResult(
+ MODEL_URL, DownloadResult.SUCCEEDED, "randomString".getBytes());
+ ListenableFuture<File> modelFuture =
+ modelDownloaderImpl.downloadModel(modelDownloaderDir, MODEL_PROTO);
+ // Expect FAILED_TO_VALIDATE_MODEL together with the default downloader-lib error code.
+ Throwable t = expectThrows(Throwable.class, modelFuture::get);
+ assertThat(t).hasCauseThat().isInstanceOf(ModelDownloadException.class);
+ ModelDownloadException e = (ModelDownloadException) t.getCause();
+ assertThat(e.getErrorCode()).isEqualTo(ModelDownloadException.FAILED_TO_VALIDATE_MODEL);
+ assertThat(e.getDownloaderLibErrorCode())
+ .isEqualTo(ModelDownloadException.DEFAULT_DOWNLOADER_LIB_ERROR_CODE);
+ assertThat(TestModelDownloaderService.getOnUnbindInvokedLatch().await(1L, SECONDS)).isTrue();
+ assertThat(TestModelDownloaderService.isBound()).isFalse();
+ assertThat(TestModelDownloaderService.hasEverBeenBound()).isTrue();
+ }
+
+ @Test
+ public void downloadModel_cancelAndUnbind() throws Exception {
+ assertThat(TestModelDownloaderService.hasEverBeenBound()).isFalse();
+ assertThat(TestModelDownloaderService.isBound()).isFalse();
+ // DO_NOTHING: the fake service accepts the request but never invokes the callback.
+ TestModelDownloaderService.setBindSucceed(true);
+ TestModelDownloaderService.setDownloadResult(MODEL_URL, DownloadResult.DO_NOTHING, null);
+ ListenableFuture<File> modelFuture =
+ modelDownloaderImpl.downloadModel(modelDownloaderDir, MODEL_PROTO);
+ // Cancel only after onBind() has run and the service reports itself as bound.
+ assertThat(TestModelDownloaderService.getOnBindInvokedLatch().await(1L, SECONDS)).isTrue();
+ assertThat(TestModelDownloaderService.isBound()).isTrue();
+ modelFuture.cancel(true);
+ // Cancellation must reach the caller, and the service must be unbound afterwards.
+ expectThrows(CancellationException.class, modelFuture::get);
+ assertThat(TestModelDownloaderService.getOnUnbindInvokedLatch().await(1L, SECONDS)).isTrue();
+ assertThat(TestModelDownloaderService.isBound()).isFalse();
+ assertThat(TestModelDownloaderService.hasEverBeenBound()).isTrue();
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java
new file mode 100644
index 0000000..e261158
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderIntegrationTest.java
@@ -0,0 +1,182 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.util.Log;
+import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextClassification.Request;
+import com.android.textclassifier.testing.ExtServicesTextClassifierRule;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class ModelDownloaderIntegrationTest {
+ private static final String TAG = "ModelDownloaderTest";
+ private static final String EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL =
+ "https://www.gstatic.com/android/text_classifier/r/experimental/v999999999/en.fb.manifest";
+ private static final String EXPERIMENTAL_EN_TAG = "en_v999999999";
+ private static final String V804_EN_ANNOTATOR_MANIFEST_URL =
+ "https://www.gstatic.com/android/text_classifier/r/v804/en.fb.manifest";
+ private static final String V804_RU_ANNOTATOR_MANIFEST_URL =
+ "https://www.gstatic.com/android/text_classifier/r/v804/ru.fb.manifest";
+ private static final String V804_EN_TAG = "en_v804";
+ private static final String V804_RU_TAG = "ru_v804";
+ private static final String FACTORY_MODEL_TAG = "*";
+ private static final int ASSERT_MAX_ATTEMPTS = 20;
+ private static final int ASSERT_SLEEP_BEFORE_RETRY_MS = 1000;
+
+ @Rule
+ public final ExtServicesTextClassifierRule extServicesTextClassifierRule =
+ new ExtServicesTextClassifierRule();
+
+ @Before
+ public void setup() throws Exception {
+ extServicesTextClassifierRule.addDeviceConfigOverride("config_updater_model_enabled", "false");
+ extServicesTextClassifierRule.addDeviceConfigOverride("model_download_manager_enabled", "true");
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "model_download_backoff_delay_in_millis", "5");
+ extServicesTextClassifierRule.addDeviceConfigOverride("testing_locale_list_override", "en-US");
+ extServicesTextClassifierRule.overrideDeviceConfig();
+ // NOTE(review): overrideDeviceConfig() presumably applies the overrides staged above; confirm.
+ extServicesTextClassifierRule.enableVerboseLogging();
+ // Verbose logging only takes effect after restarting ExtServices
+ extServicesTextClassifierRule.forceStopExtServices();
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ // Force-stop ExtServices again to reset the logging/locale_override state set up in setup().
+ extServicesTextClassifierRule.forceStopExtServices();
+ }
+
+ @Test
+ public void smokeTest() throws Exception {
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL);
+ // assertWithRetries() calls overrideDeviceConfig() and re-checks until the v804 tag shows up.
+ assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG));
+ }
+
+ @Test
+ public void downgradeModel() throws Exception {
+ // Download an experimental model.
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
+ // Poll until classification result ids carry the experimental version tag.
+ assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
+
+ // Downgrade to an older model.
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL);
+ // Result ids must now carry the older v804 tag instead.
+ assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG));
+ }
+
+ @Test
+ public void upgradeModel() throws Exception {
+ // Download a model.
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", V804_EN_ANNOTATOR_MANIFEST_URL);
+ // Poll until classification result ids carry the v804 version tag.
+ assertWithRetries(() -> verifyActiveEnglishModel(V804_EN_TAG));
+
+ // Upgrade to an experimental model.
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
+ // Result ids must now carry the experimental tag instead.
+ assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
+ }
+
+ @Test
+ public void clearFlag() throws Exception {
+ // Download a new model.
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
+ // Poll until classification result ids carry the experimental version tag.
+ assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
+
+ // Revert the flag by overriding the manifest URL with an empty string.
+ extServicesTextClassifierRule.addDeviceConfigOverride("manifest_url_annotator_en", "");
+ // The classifier is expected to fall back to the universal (factory) model.
+ assertWithRetries(
+ () -> verifyActiveModel(/* text= */ "abc", /* expectedVersion= */ FACTORY_MODEL_TAG));
+ }
+
+ @Test
+ public void modelsForMultipleLanguagesDownloaded() throws Exception {
+ extServicesTextClassifierRule.addDeviceConfigOverride("multi_language_support_enabled", "true");
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "testing_locale_list_override", "en-US,ru-RU");
+
+ // Point the English (en) annotator at the experimental manifest.
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_en", EXPERIMENTAL_EN_ANNOTATOR_MANIFEST_URL);
+
+ // Point the Russian (ru) annotator at the v804 manifest.
+ extServicesTextClassifierRule.addDeviceConfigOverride(
+ "manifest_url_annotator_ru", V804_RU_ANNOTATOR_MANIFEST_URL);
+ assertWithRetries(() -> verifyActiveEnglishModel(EXPERIMENTAL_EN_TAG));
+ // Russian text must be classified with the v804 ru model.
+ assertWithRetries(this::verifyActiveRussianModel);
+ // No manifest is configured for French, so the factory model tag is expected.
+ assertWithRetries(
+ () -> verifyActiveModel(/* text= */ "français", /* expectedVersion= */ FACTORY_MODEL_TAG));
+ }
+
+ /**
+ * Runs the check after re-applying overrides; retries on AssertionError, other errors propagate.
+ */
+ private void assertWithRetries(Runnable assertRunnable) throws Exception {
+ for (int attempt = 1; attempt <= ASSERT_MAX_ATTEMPTS; attempt++) {
+ try {
+ extServicesTextClassifierRule.overrideDeviceConfig();
+ assertRunnable.run();
+ return; // success. Bail out.
+ } catch (AssertionError ex) {
+ if (attempt == ASSERT_MAX_ATTEMPTS) { // last attempt, give up.
+ extServicesTextClassifierRule.dumpDefaultTextClassifierService();
+ throw ex;
+ }
+ Thread.sleep(ASSERT_SLEEP_BEFORE_RETRY_MS);
+ }
+ }
+ }
+
+ /** Classifies all of {@code text} and asserts the result id contains {@code expectedVersion}. */
+ private void verifyActiveModel(String text, String expectedVersion) {
+ Request wholeText = new Request.Builder(text, 0, text.length()).build();
+ String resultId =
+ extServicesTextClassifierRule.getTextClassifier().classifyText(wholeText).getId();
+ // The result id contains the name of the model that was just used.
+ Log.d(TAG, "verifyActiveModel. TextClassification ID: " + resultId);
+ assertThat(resultId).contains(expectedVersion);
+ }
+
+ private void verifyActiveEnglishModel(String expectedVersion) {
+ verifyActiveModel(/* text= */ "abc", expectedVersion);
+ }
+
+ private void verifyActiveRussianModel() {
+ verifyActiveModel(/* text= */ "привет", /* expectedVersion= */ V804_RU_TAG);
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java
new file mode 100644
index 0000000..76d04e0
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java
@@ -0,0 +1,195 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;
+import static org.testng.Assert.expectThrows;
+
+import androidx.test.core.app.ApplicationProvider;
+import com.google.android.downloader.DownloadConstraints;
+import com.google.android.downloader.DownloadRequest;
+import com.google.android.downloader.DownloadResult;
+import com.google.android.downloader.Downloader;
+import com.google.android.downloader.ErrorDetails;
+import com.google.android.downloader.RequestException;
+import com.google.android.downloader.SimpleFileDownloadDestination;
+import com.google.common.util.concurrent.FluentFuture;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.MoreExecutors;
+import com.google.common.util.concurrent.SettableFuture;
+import java.io.File;
+import java.net.URI;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
+
+@RunWith(JUnit4.class)
+public final class ModelDownloaderServiceImplTest {
+
+ @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+
+ private static final long BYTES_WRITTEN = 1L;
+ private static final String DOWNLOAD_URI =
+ "https://www.gstatic.com/android/text_classifier/r/v999/en.fb";
+ private static final int DOWNLOADER_LIB_ERROR_CODE = 500;
+ private static final String ERROR_MESSAGE = "err_msg";
+ private static final Exception DOWNLOADER_LIB_EXCEPTION =
+ new RequestException(
+ ErrorDetails.builder()
+ .setErrorMessage(ERROR_MESSAGE)
+ .setHttpStatusCode(DOWNLOADER_LIB_ERROR_CODE)
+ .build());
+
+ @Mock private Downloader downloader;
+ private File targetModelFile;
+ private File targetMetadataFile;
+ private ModelDownloaderServiceImpl modelDownloaderServiceImpl;
+ private TestSuccessCallbackImpl successCallback;
+ private TestFailureCallbackImpl failureCallback;
+
+ @Before
+ public void setUp() {
+ // Fresh target files, service under test and callbacks; the mocked Downloader is stubbed last.
+ File cacheDir = ApplicationProvider.getApplicationContext().getCacheDir();
+ targetModelFile = new File(cacheDir, "model.fb");
+ targetMetadataFile = ModelDownloaderServiceImpl.getMetadataFile(targetModelFile);
+ modelDownloaderServiceImpl =
+ new ModelDownloaderServiceImpl(MoreExecutors.newDirectExecutorService(), downloader);
+ successCallback = new TestSuccessCallbackImpl();
+ failureCallback = new TestFailureCallbackImpl();
+
+ targetModelFile.deleteOnExit();
+ targetMetadataFile.deleteOnExit();
+ when(downloader.newRequestBuilder(any(), any()))
+ .thenReturn(
+ DownloadRequest.newBuilder()
+ .setUri(URI.create(DOWNLOAD_URI))
+ .setDownloadConstraints(DownloadConstraints.NONE)
+ .setDestination(
+ new SimpleFileDownloadDestination(targetModelFile, targetMetadataFile)));
+ }
+
+ @Test
+ public void download_succeeded() throws Exception {
+ targetModelFile.createNewFile();
+ targetMetadataFile.createNewFile();
+ when(downloader.execute(any()))
+ .thenReturn(
+ FluentFuture.from(Futures.immediateFuture(DownloadResult.create(BYTES_WRITTEN))));
+ modelDownloaderServiceImpl.download(
+ DOWNLOAD_URI, targetModelFile.getAbsolutePath(), successCallback);
+ // Success reports the byte count; the model file is kept and the metadata file is gone.
+ assertThat(successCallback.getBytesWrittenFuture().get()).isEqualTo(BYTES_WRITTEN);
+ assertThat(targetModelFile.exists()).isTrue();
+ assertThat(targetMetadataFile.exists()).isFalse();
+ }
+
+ @Test
+ public void download_failed() throws Exception {
+ targetModelFile.createNewFile();
+ targetMetadataFile.createNewFile();
+ when(downloader.execute(any()))
+ .thenReturn(FluentFuture.from(Futures.immediateFailedFuture(DOWNLOADER_LIB_EXCEPTION)));
+ modelDownloaderServiceImpl.download(
+ DOWNLOAD_URI, targetModelFile.getAbsolutePath(), successCallback);
+ // The failure reaches the callback with the HTTP status and message; both files are gone.
+ Throwable t =
+ expectThrows(Throwable.class, () -> successCallback.getBytesWrittenFuture().get());
+ assertThat(t).hasCauseThat().isInstanceOf(ModelDownloadException.class);
+ ModelDownloadException e = (ModelDownloadException) t.getCause();
+ assertThat(e.getErrorCode()).isEqualTo(ModelDownloadException.FAILED_TO_DOWNLOAD_OTHER);
+ assertThat(e.getDownloaderLibErrorCode()).isEqualTo(DOWNLOADER_LIB_ERROR_CODE);
+ assertThat(e).hasMessageThat().contains(ERROR_MESSAGE);
+ assertThat(targetModelFile.exists()).isFalse();
+ assertThat(targetMetadataFile.exists()).isFalse();
+ }
+
+ @Test
+ public void download_succeeded_callbackFailed() throws Exception {
+ targetModelFile.createNewFile();
+ targetMetadataFile.createNewFile();
+ when(downloader.execute(any()))
+ .thenReturn(
+ FluentFuture.from(Futures.immediateFuture(DownloadResult.create(BYTES_WRITTEN))));
+ modelDownloaderServiceImpl.download(
+ DOWNLOAD_URI, targetModelFile.getAbsolutePath(), failureCallback);
+ // The callback throws from onSuccess(); the file outcome must match the plain success case.
+ assertThat(failureCallback.onSuccessCalled).isTrue();
+ assertThat(targetModelFile.exists()).isTrue();
+ assertThat(targetMetadataFile.exists()).isFalse();
+ }
+
+ @Test
+ public void download_failed_callbackFailed() throws Exception {
+ targetModelFile.createNewFile();
+ targetMetadataFile.createNewFile();
+ when(downloader.execute(any()))
+ .thenReturn(FluentFuture.from(Futures.immediateFailedFuture(DOWNLOADER_LIB_EXCEPTION)));
+ modelDownloaderServiceImpl.download(
+ DOWNLOAD_URI, targetModelFile.getAbsolutePath(), failureCallback);
+ // The callback throws from onFailure(); both files must still end up removed.
+ assertThat(failureCallback.onFailureCalled).isTrue();
+ assertThat(targetModelFile.exists()).isFalse();
+ assertThat(targetMetadataFile.exists()).isFalse();
+ }
+
+ // NOTICE: Had some problem mocking this AIDL interface, so created fake impls
+ private static final class TestSuccessCallbackImpl extends IModelDownloaderCallback.Stub {
+ private final SettableFuture<Long> bytesWrittenFuture = SettableFuture.<Long>create();
+
+ public ListenableFuture<Long> getBytesWrittenFuture() {
+ return bytesWrittenFuture;
+ }
+
+ @Override
+ public void onSuccess(long bytesWritten) {
+ bytesWrittenFuture.set(bytesWritten);
+ }
+
+ @Override
+ public void onFailure(int downloaderLibErrorCode, String errorMsg) {
+ bytesWrittenFuture.setException(
+ new ModelDownloadException(
+ ModelDownloadException.FAILED_TO_DOWNLOAD_OTHER, downloaderLibErrorCode, errorMsg));
+ }
+ }
+
+ private static final class TestFailureCallbackImpl extends IModelDownloaderCallback.Stub {
+ public boolean onSuccessCalled = false;
+ public boolean onFailureCalled = false;
+ // Each callback records that it ran and then throws, simulating a callback that fails.
+ @Override
+ public void onSuccess(long bytesWritten) {
+ onSuccessCalled = true;
+ throw new RuntimeException();
+ }
+
+ @Override
+ public void onFailure(int downloaderLibErrorCode, String errorMsg) {
+ onFailureCalled = true;
+ throw new RuntimeException();
+ }
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/TestModelDownloaderService.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/TestModelDownloaderService.java
new file mode 100644
index 0000000..a782b4c
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/TestModelDownloaderService.java
@@ -0,0 +1,142 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import android.app.Service;
+import android.content.Intent;
+import android.os.IBinder;
+import com.android.textclassifier.common.base.TcLog;
+import java.io.File;
+import java.nio.file.Files;
+import java.util.concurrent.CountDownLatch;
+import javax.annotation.Nullable;
+
+// TODO(licha): Find another way to test the service. (E.g. CtsTextClassifierService.java)
+/** Test Service of IModelDownloaderService. */
+public final class TestModelDownloaderService extends Service {
+ private static final String TAG = "TestModelDownloaderService";
+
+ public static final String GOOD_URI = "good_uri";
+ public static final String BAD_URI = "bad_uri";
+ public static final long BYTES_WRITTEN = 1L;
+ public static final int DOWNLOADER_LIB_ERROR_CODE = 500;
+ public static final String ERROR_MSG = "not good uri";
+
+ public enum DownloadResult {
+ SUCCEEDED, // write fileContent to the target path, then callback.onSuccess(BYTES_WRITTEN)
+ FAILED, // callback.onFailure(DOWNLOADER_LIB_ERROR_CODE, ERROR_MSG)
+ DO_NOTHING // never invoke the callback
+ }
+
+ // Shared between the test thread and the service's callbacks, which may run on other threads;
+ // volatile guarantees visibility only (no atomicity), which is sufficient for these tests.
+ private static volatile boolean boundBefore = false;
+ private static volatile boolean boundNow = false;
+ private static volatile CountDownLatch onBindInvokedLatch = new CountDownLatch(1);
+ private static volatile CountDownLatch onUnbindInvokedLatch = new CountDownLatch(1);
+ private static volatile boolean bindSucceed = false;
+ private static volatile String expectedUrl = null;
+ private static volatile DownloadResult downloadResult = DownloadResult.SUCCEEDED;
+ private static volatile byte[] fileContent = null;
+ /** Returns whether onBind() has handed out a binder since the last reset(). */
+ public static boolean hasEverBeenBound() {
+ return boundBefore;
+ }
+ /** Returns whether a binder was handed out and onUnbind() has not been invoked since. */
+ public static boolean isBound() {
+ return boundNow;
+ }
+ /** Returns the latch counted down whenever onBind() runs, even if the bind is refused. */
+ public static CountDownLatch getOnBindInvokedLatch() {
+ return onBindInvokedLatch;
+ }
+ /** Returns the latch counted down whenever onUnbind() runs. */
+ public static CountDownLatch getOnUnbindInvokedLatch() {
+ return onUnbindInvokedLatch;
+ }
+ /** Chooses whether onBind() hands out a binder (true) or returns null (false). */
+ public static void setBindSucceed(boolean bindSucceed) {
+ TestModelDownloaderService.bindSucceed = bindSucceed;
+ }
+ /** Sets the only URL download() accepts, the outcome to simulate and the bytes to write. */
+ public static void setDownloadResult(
+ String url, DownloadResult result, @Nullable byte[] fileContent) {
+ TestModelDownloaderService.expectedUrl = url;
+ TestModelDownloaderService.downloadResult = result;
+ TestModelDownloaderService.fileContent = fileContent;
+ }
+ /** Restores all static state, including the configured download outcome, to initial values. */
+ public static void reset() {
+ boundBefore = boundNow = false;
+ onBindInvokedLatch = new CountDownLatch(1);
+ onUnbindInvokedLatch = new CountDownLatch(1);
+ bindSucceed = false;
+ setDownloadResult(/* url= */ null, DownloadResult.SUCCEEDED, /* fileContent= */ null);
+ }
+
+ /** Returns the fake binder if bindSucceed is set, else null; always releases the bind latch. */
+ @Override
+ public IBinder onBind(Intent intent) {
+ try {
+ if (!bindSucceed) {
+ return null;
+ }
+ boundBefore = true;
+ boundNow = true;
+ return new TestModelDownloaderServiceImpl();
+ } finally {
+ onBindInvokedLatch.countDown();
+ }
+ }
+
+ /** Clears the bound flag and releases anyone awaiting the unbind latch. */
+ @Override
+ public boolean onUnbind(Intent intent) {
+ boundNow = false;
+ // The flag is cleared before the latch is released, so a test that awaited the latch
+ // observes isBound() == false.
+ onUnbindInvokedLatch.countDown();
+ return false;
+ }
+
+ /** Fake binder that answers download() with the outcome set via setDownloadResult(). */
+ private static final class TestModelDownloaderServiceImpl extends IModelDownloaderService.Stub {
+ @Override
+ public void download(String url, String targetFilePath, IModelDownloaderCallback callback) {
+ if (expectedUrl == null || !expectedUrl.equals(url)) {
+ throw new IllegalStateException("url does not match");
+ }
+ TcLog.d(TAG, String.format("Test Request: %s, %s, %s", url, targetFilePath, downloadResult));
+ try {
+ switch (downloadResult) {
+ case SUCCEEDED:
+ // Files.write creates or truncates the target itself; no createNewFile() is needed.
+ Files.write(new File(targetFilePath).toPath(), fileContent);
+ callback.onSuccess(BYTES_WRITTEN);
+ break;
+ case FAILED:
+ callback.onFailure(DOWNLOADER_LIB_ERROR_CODE, ERROR_MSG);
+ break;
+ case DO_NOTHING:
+ // Deliberately never answer; tests use this to exercise cancellation.
+ }
+ } catch (Throwable t) {
+ TcLog.e(TAG, "Test download failed; the caller may never receive a callback", t);
+ }
+ }
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/TextClassifierDownloadLoggerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/TextClassifierDownloadLoggerTest.java
new file mode 100644
index 0000000..ed76fa8
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/TextClassifierDownloadLoggerTest.java
@@ -0,0 +1,144 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import com.android.os.AtomsProto.TextClassifierDownloadReported;
+import com.android.os.AtomsProto.TextClassifierDownloadWorkCompleted;
+import com.android.os.AtomsProto.TextClassifierDownloadWorkScheduled;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.statsd.TextClassifierDownloadLoggerTestRule;
+import com.google.common.collect.Iterables;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@RunWith(AndroidJUnit4.class)
+public final class TextClassifierDownloadLoggerTest {
+ private static final String MODEL_TYPE = ModelType.LANG_ID;
+ private static final TextClassifierDownloadReported.ModelType MODEL_TYPE_ATOM =
+ TextClassifierDownloadReported.ModelType.LANG_ID;
+ private static final String URL =
+ "https://www.gstatic.com/android/text_classifier/x/v123/en.fb.manifest";
+ private static final int ERROR_CODE = ModelDownloadException.FAILED_TO_DOWNLOAD_404_ERROR;
+ private static final TextClassifierDownloadReported.FailureReason FAILURE_REASON_ATOM =
+ TextClassifierDownloadReported.FailureReason.FAILED_TO_DOWNLOAD_404_ERROR;
+ private static final int RUN_ATTEMPT_COUNT = 1;
+ private static final long WORK_ID = 123456789L;
+ private static final long DOWNLOAD_DURATION_MILLIS = 666L;
+ private static final int DOWNLOADER_LIB_ERROR_CODE = 500;
+ private static final int REASON_TO_SCHEDULE =
+ TextClassifierDownloadLogger.REASON_TO_SCHEDULE_TCS_STARTED;
+ private static final TextClassifierDownloadWorkScheduled.ReasonToSchedule
+ REASON_TO_SCHEDULE_ATOM = TextClassifierDownloadWorkScheduled.ReasonToSchedule.TCS_STARTED;
+ private static final int WORK_RESULT =
+ TextClassifierDownloadLogger.WORK_RESULT_SUCCESS_MODEL_DOWNLOADED;
+ private static final TextClassifierDownloadWorkCompleted.WorkResult WORK_RESULT_ATOM =
+ TextClassifierDownloadWorkCompleted.WorkResult.SUCCESS_MODEL_DOWNLOADED;
+ private static final long SCHEDULED_TO_START_DURATION_MILLIS = 777L;
+ private static final long STARTED_TO_FINISHED_DURATION_MILLIS = 888L;
+
+ @Rule
+ public final TextClassifierDownloadLoggerTestRule loggerTestRule =
+ new TextClassifierDownloadLoggerTestRule();
+
+ @Test
+ public void downloadSucceeded() throws Exception {
+ TextClassifierDownloadLogger.downloadSucceeded(
+ WORK_ID, MODEL_TYPE, URL, RUN_ATTEMPT_COUNT, DOWNLOAD_DURATION_MILLIS);
+ // Exactly one download-reported atom must be logged, mirroring the arguments above.
+ TextClassifierDownloadReported reported =
+ Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadReportedAtoms());
+ assertThat(reported.getWorkId()).isEqualTo(WORK_ID);
+ assertThat(reported.getDownloadStatus())
+ .isEqualTo(TextClassifierDownloadReported.DownloadStatus.SUCCEEDED);
+ assertThat(reported.getModelType()).isEqualTo(MODEL_TYPE_ATOM);
+ assertThat(reported.getUrlSuffix()).isEqualTo(URL);
+ assertThat(reported.getRunAttemptCount()).isEqualTo(RUN_ATTEMPT_COUNT);
+ assertThat(reported.getDownloadDurationMillis()).isEqualTo(DOWNLOAD_DURATION_MILLIS);
+ }
+
+ @Test
+ public void downloadFailed() throws Exception {
+ TextClassifierDownloadLogger.downloadFailed(
+ WORK_ID,
+ MODEL_TYPE,
+ URL,
+ ERROR_CODE,
+ RUN_ATTEMPT_COUNT,
+ DOWNLOADER_LIB_ERROR_CODE,
+ DOWNLOAD_DURATION_MILLIS);
+
+ TextClassifierDownloadReported atom =
+ Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadReportedAtoms());
+ assertThat(atom.getWorkId()).isEqualTo(WORK_ID);
+ assertThat(atom.getDownloadStatus())
+ .isEqualTo(TextClassifierDownloadReported.DownloadStatus.FAILED_AND_RETRY);
+ assertThat(atom.getModelType()).isEqualTo(MODEL_TYPE_ATOM);
+ assertThat(atom.getUrlSuffix()).isEqualTo(URL);
+ assertThat(atom.getRunAttemptCount()).isEqualTo(RUN_ATTEMPT_COUNT);
+ assertThat(atom.getFailureReason()).isEqualTo(FAILURE_REASON_ATOM);
+ assertThat(atom.getDownloaderLibFailureCode()).isEqualTo(DOWNLOADER_LIB_ERROR_CODE);
+ assertThat(atom.getDownloadDurationMillis()).isEqualTo(DOWNLOAD_DURATION_MILLIS);
+ }
+
+ @Test
+ public void downloadWorkScheduled_succeeded() throws Exception {
+ TextClassifierDownloadLogger.downloadWorkScheduled(
+ WORK_ID, REASON_TO_SCHEDULE, /* failedToSchedule= */ false);
+ // The single work-scheduled atom must record the reason and failedToSchedule == false.
+ TextClassifierDownloadWorkScheduled scheduled =
+ Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
+ assertThat(scheduled.getWorkId()).isEqualTo(WORK_ID);
+ assertThat(scheduled.getReasonToSchedule()).isEqualTo(REASON_TO_SCHEDULE_ATOM);
+ assertThat(scheduled.getFailedToSchedule()).isFalse();
+ }
+
+ @Test
+ public void downloadWorkScheduled_failed() throws Exception {
+ TextClassifierDownloadLogger.downloadWorkScheduled(
+ WORK_ID, REASON_TO_SCHEDULE, /* failedToSchedule= */ true);
+ // The single work-scheduled atom must record the reason and failedToSchedule == true.
+ TextClassifierDownloadWorkScheduled scheduled =
+ Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
+ assertThat(scheduled.getWorkId()).isEqualTo(WORK_ID);
+ assertThat(scheduled.getReasonToSchedule()).isEqualTo(REASON_TO_SCHEDULE_ATOM);
+ assertThat(scheduled.getFailedToSchedule()).isTrue();
+ }
+
+ @Test
+ public void downloadWorkCompleted() throws Exception {
+ TextClassifierDownloadLogger.downloadWorkCompleted(
+ WORK_ID,
+ WORK_RESULT,
+ RUN_ATTEMPT_COUNT,
+ SCHEDULED_TO_START_DURATION_MILLIS,
+ STARTED_TO_FINISHED_DURATION_MILLIS);
+ // The single work-completed atom must carry the mapped result and both durations.
+ TextClassifierDownloadWorkCompleted completed =
+ Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkCompletedAtoms());
+ assertThat(completed.getWorkId()).isEqualTo(WORK_ID);
+ assertThat(completed.getWorkResult()).isEqualTo(WORK_RESULT_ATOM);
+ assertThat(completed.getRunAttemptCount()).isEqualTo(RUN_ATTEMPT_COUNT);
+ assertThat(completed.getWorkScheduledToStartedDurationMillis())
+ .isEqualTo(SCHEDULED_TO_START_DURATION_MILLIS);
+ assertThat(completed.getWorkStartedToEndedDurationMillis())
+ .isEqualTo(STARTED_TO_FINISHED_DURATION_MILLIS);
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java
new file mode 100644
index 0000000..5f8247d
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/testing/ExtServicesTextClassifierRule.java
@@ -0,0 +1,181 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.testing;
+
+import android.app.UiAutomation;
+import android.content.pm.PackageManager;
+import android.content.pm.PackageManager.NameNotFoundException;
+import android.os.ParcelFileDescriptor;
+import android.provider.DeviceConfig;
+import android.util.Log;
+import android.view.textclassifier.TextClassificationManager;
+import android.view.textclassifier.TextClassifier;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.platform.app.InstrumentationRegistry;
+import com.google.common.io.ByteStreams;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import org.junit.rules.ExternalResource;
+
+/**
+ * A rule that manages a text classifier that is backed by the ExtServices.
+ *
+ * <p>While the rule is active, the test process holds the shell permission identity and the
+ * {@code textclassifier} DeviceConfig namespace is replaced by the overrides registered on this
+ * rule. Afterwards the properties captured at setup are written back and the identity is dropped.
+ */
+public final class ExtServicesTextClassifierRule extends ExternalResource {
+  private static final String TAG = "androidtc";
+  private static final String CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE =
+      "textclassifier_service_package_override";
+  private static final String PKG_NAME_GOOGLE_EXTSERVICES = "com.google.android.ext.services";
+  private static final String PKG_NAME_AOSP_EXTSERVICES = "android.ext.services";
+
+  private UiAutomation uiAutomation;
+  private DeviceConfig.Properties originalProperties;
+  private DeviceConfig.Properties.Builder newPropertiesBuilder;
+
+  /**
+   * Adopts the shell permission identity, snapshots the current DeviceConfig properties and
+   * applies an override that selects the installed ExtServices package as the text classifier
+   * service package.
+   *
+   * <p>JUnit's {@link ExternalResource} does not invoke {@link #after()} when this method throws,
+   * so the shell permission identity is dropped here if any setup step fails.
+   */
+  @Override
+  protected void before() throws Exception {
+    uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
+    uiAutomation.adoptShellPermissionIdentity();
+    boolean setUpCompleted = false;
+    try {
+      originalProperties = DeviceConfig.getProperties(DeviceConfig.NAMESPACE_TEXTCLASSIFIER);
+      newPropertiesBuilder =
+          new DeviceConfig.Properties.Builder(DeviceConfig.NAMESPACE_TEXTCLASSIFIER)
+              .setString(
+                  CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE, getExtServicesPackageName());
+      overrideDeviceConfig();
+      setUpCompleted = true;
+    } finally {
+      if (!setUpCompleted) {
+        uiAutomation.dropShellPermissionIdentity();
+      }
+    }
+  }
+
+  /** Restores the DeviceConfig properties captured in {@link #before()} (best effort). */
+  @Override
+  protected void after() {
+    try {
+      DeviceConfig.setProperties(originalProperties);
+    } catch (Throwable t) {
+      Log.e(TAG, "Failed to reset DeviceConfig", t);
+    } finally {
+      uiAutomation.dropShellPermissionIdentity();
+    }
+  }
+
+  /**
+   * Registers an extra DeviceConfig override. It only takes effect on the next call to {@link
+   * #overrideDeviceConfig()}.
+   */
+  public void addDeviceConfigOverride(String name, String value) {
+    newPropertiesBuilder.setString(name, value);
+  }
+
+  /**
+   * Overrides the TextClassifier DeviceConfig manually.
+   *
+   * <p>This will clean up all device configs not in newPropertiesBuilder.
+   *
+   * <p>We will need to call this every time before testing, because DeviceConfig can be synced in
+   * the background at any time. DeviceConfig#setSyncDisabledMode is to disable sync, however
+   * it's a hidden API.
+   */
+  public void overrideDeviceConfig() throws Exception {
+    DeviceConfig.setProperties(newPropertiesBuilder.build());
+  }
+
+  /** Force stop ExtServices. Force-stop-and-start can be helpful to reload some states. */
+  public void forceStopExtServices() {
+    runShellCommand("am force-stop " + PKG_NAME_GOOGLE_EXTSERVICES);
+    runShellCommand("am force-stop " + PKG_NAME_AOSP_EXTSERVICES);
+  }
+
+  /**
+   * Returns the classifier provided by {@link TextClassificationManager}, after resetting any
+   * classifier previously set on that manager.
+   */
+  public TextClassifier getTextClassifier() {
+    TextClassificationManager textClassificationManager =
+        ApplicationProvider.getApplicationContext()
+            .getSystemService(TextClassificationManager.class);
+    textClassificationManager.setTextClassifier(null); // Reset TC overrides
+    return textClassificationManager.getTextClassifier();
+  }
+
+  /**
+   * Logs a dump of the DefaultTextClassifierService hosted by the installed ExtServices package,
+   * followed by the current {@code textclassifier} DeviceConfig flags.
+   */
+  public void dumpDefaultTextClassifierService() {
+    runShellCommand(
+        "dumpsys activity service "
+            + getExtServicesPackageName()
+            + "/com.android.textclassifier.DefaultTextClassifierService");
+    runShellCommand("cmd device_config list textclassifier");
+  }
+
+  /** Sets the {@code log.tag.androidtc} system property to VERBOSE. */
+  public void enableVerboseLogging() {
+    runShellCommand("setprop log.tag.androidtc VERBOSE");
+  }
+
+  /**
+   * Runs {@code cmd} through {@link UiAutomation#executeShellCommand} and logs its output.
+   *
+   * <p>The caller is responsible for closing the returned descriptor, so it is wrapped in an
+   * auto-closing stream; a plain {@link FileInputStream} would leave the descriptor open.
+   */
+  private void runShellCommand(String cmd) {
+    Log.v(TAG, "run shell command: " + cmd);
+    try (FileInputStream output =
+        new ParcelFileDescriptor.AutoCloseInputStream(uiAutomation.executeShellCommand(cmd))) {
+      String cmdOutput = new String(ByteStreams.toByteArray(output), StandardCharsets.UTF_8);
+      if (!cmdOutput.isEmpty()) {
+        Log.d(TAG, "cmd output: " + cmdOutput);
+      }
+    } catch (IOException ioe) {
+      Log.w(TAG, "failed to get cmd output", ioe);
+    }
+  }
+
+  /**
+   * Returns the Google ExtServices package name when that package is installed, and the AOSP
+   * package name otherwise.
+   */
+  private static String getExtServicesPackageName() {
+    PackageManager packageManager = ApplicationProvider.getApplicationContext().getPackageManager();
+    try {
+      packageManager.getApplicationInfo(PKG_NAME_GOOGLE_EXTSERVICES, /* flags= */ 0);
+      return PKG_NAME_GOOGLE_EXTSERVICES;
+    } catch (NameNotFoundException e) {
+      return PKG_NAME_AOSP_EXTSERVICES;
+    }
+  }
+}
diff --git a/native/FlatBufferHeaders.bp b/native/FlatBufferHeaders.bp
index 950eee6..813ec6a 100644
--- a/native/FlatBufferHeaders.bp
+++ b/native/FlatBufferHeaders.bp
@@ -15,16 +15,9 @@
//
genrule {
- name: "libtextclassifier_fbgen_actions_actions_model",
- srcs: ["actions/actions_model.fbs"],
- out: ["actions/actions_model_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_actions_actions-entity-data",
- srcs: ["actions/actions-entity-data.fbs"],
- out: ["actions/actions-entity-data_generated.h"],
+ name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
+ srcs: ["lang_id/common/flatbuffers/model.fbs"],
+ out: ["lang_id/common/flatbuffers/model_generated.h"],
defaults: ["fbgen"],
}
@@ -36,20 +29,6 @@
}
genrule {
- name: "libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
- srcs: ["lang_id/common/flatbuffers/model.fbs"],
- out: ["lang_id/common/flatbuffers/model_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_annotator_person_name_person_name_model",
- srcs: ["annotator/person_name/person_name_model.fbs"],
- out: ["annotator/person_name/person_name_model_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
name: "libtextclassifier_fbgen_annotator_datetime_datetime",
srcs: ["annotator/datetime/datetime.fbs"],
out: ["annotator/datetime/datetime_generated.h"],
@@ -57,6 +36,13 @@
}
genrule {
+ name: "libtextclassifier_fbgen_annotator_model",
+ srcs: ["annotator/model.fbs"],
+ out: ["annotator/model_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
name: "libtextclassifier_fbgen_annotator_experimental_experimental",
srcs: ["annotator/experimental/experimental.fbs"],
out: ["annotator/experimental/experimental_generated.h"],
@@ -71,16 +57,9 @@
}
genrule {
- name: "libtextclassifier_fbgen_annotator_model",
- srcs: ["annotator/model.fbs"],
- out: ["annotator/model_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
- srcs: ["utils/flatbuffers/flatbuffers.fbs"],
- out: ["utils/flatbuffers/flatbuffers_generated.h"],
+ name: "libtextclassifier_fbgen_annotator_person_name_person_name_model",
+ srcs: ["annotator/person_name/person_name_model.fbs"],
+ out: ["annotator/person_name/person_name_model_generated.h"],
defaults: ["fbgen"],
}
@@ -92,23 +71,9 @@
}
genrule {
- name: "libtextclassifier_fbgen_utils_resources",
- srcs: ["utils/resources.fbs"],
- out: ["utils/resources_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_utils_zlib_buffer",
- srcs: ["utils/zlib/buffer.fbs"],
- out: ["utils/zlib/buffer_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_utils_container_bit-vector",
- srcs: ["utils/container/bit-vector.fbs"],
- out: ["utils/container/bit-vector_generated.h"],
+ name: "libtextclassifier_fbgen_utils_codepoint-range",
+ srcs: ["utils/codepoint-range.fbs"],
+ out: ["utils/codepoint-range_generated.h"],
defaults: ["fbgen"],
}
@@ -120,37 +85,16 @@
}
genrule {
- name: "libtextclassifier_fbgen_utils_normalization",
- srcs: ["utils/normalization.fbs"],
- out: ["utils/normalization_generated.h"],
+ name: "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
+ srcs: ["utils/flatbuffers/flatbuffers.fbs"],
+ out: ["utils/flatbuffers/flatbuffers_generated.h"],
defaults: ["fbgen"],
}
genrule {
- name: "libtextclassifier_fbgen_utils_grammar_semantics_expression",
- srcs: ["utils/grammar/semantics/expression.fbs"],
- out: ["utils/grammar/semantics/expression_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_utils_grammar_rules",
- srcs: ["utils/grammar/rules.fbs"],
- out: ["utils/grammar/rules_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_utils_grammar_testing_value",
- srcs: ["utils/grammar/testing/value.fbs"],
- out: ["utils/grammar/testing/value_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_utils_codepoint-range",
- srcs: ["utils/codepoint-range.fbs"],
- out: ["utils/codepoint-range_generated.h"],
+ name: "libtextclassifier_fbgen_utils_zlib_buffer",
+ srcs: ["utils/zlib/buffer.fbs"],
+ out: ["utils/zlib/buffer_generated.h"],
defaults: ["fbgen"],
}
@@ -162,12 +106,68 @@
}
genrule {
+ name: "libtextclassifier_fbgen_utils_grammar_testing_value",
+ srcs: ["utils/grammar/testing/value.fbs"],
+ out: ["utils/grammar/testing/value_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_grammar_rules",
+ srcs: ["utils/grammar/rules.fbs"],
+ out: ["utils/grammar/rules_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_grammar_semantics_expression",
+ srcs: ["utils/grammar/semantics/expression.fbs"],
+ out: ["utils/grammar/semantics/expression_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_resources",
+ srcs: ["utils/resources.fbs"],
+ out: ["utils/resources_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
name: "libtextclassifier_fbgen_utils_i18n_language-tag",
srcs: ["utils/i18n/language-tag.fbs"],
out: ["utils/i18n/language-tag_generated.h"],
defaults: ["fbgen"],
}
+genrule {
+ name: "libtextclassifier_fbgen_utils_normalization",
+ srcs: ["utils/normalization.fbs"],
+ out: ["utils/normalization_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_container_bit-vector",
+ srcs: ["utils/container/bit-vector.fbs"],
+ out: ["utils/container/bit-vector_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_actions_actions-entity-data",
+ srcs: ["actions/actions-entity-data.fbs"],
+ out: ["actions/actions-entity-data_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_actions_actions_model",
+ srcs: ["actions/actions_model.fbs"],
+ out: ["actions/actions_model_generated.h"],
+ defaults: ["fbgen"],
+}
+
cc_library_headers {
name: "libtextclassifier_flatbuffer_headers",
stl: "libc++_static",
@@ -178,50 +178,50 @@
"com.android.extservices",
],
generated_headers: [
- "libtextclassifier_fbgen_actions_actions_model",
- "libtextclassifier_fbgen_actions_actions-entity-data",
- "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
"libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
- "libtextclassifier_fbgen_annotator_person_name_person_name_model",
+ "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
"libtextclassifier_fbgen_annotator_datetime_datetime",
+ "libtextclassifier_fbgen_annotator_model",
"libtextclassifier_fbgen_annotator_experimental_experimental",
"libtextclassifier_fbgen_annotator_entity-data",
- "libtextclassifier_fbgen_annotator_model",
- "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
+ "libtextclassifier_fbgen_annotator_person_name_person_name_model",
"libtextclassifier_fbgen_utils_tflite_text_encoder_config",
- "libtextclassifier_fbgen_utils_resources",
- "libtextclassifier_fbgen_utils_zlib_buffer",
- "libtextclassifier_fbgen_utils_container_bit-vector",
- "libtextclassifier_fbgen_utils_intents_intent-config",
- "libtextclassifier_fbgen_utils_normalization",
- "libtextclassifier_fbgen_utils_grammar_semantics_expression",
- "libtextclassifier_fbgen_utils_grammar_rules",
"libtextclassifier_fbgen_utils_codepoint-range",
+ "libtextclassifier_fbgen_utils_intents_intent-config",
+ "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
+ "libtextclassifier_fbgen_utils_zlib_buffer",
"libtextclassifier_fbgen_utils_tokenizer",
+ "libtextclassifier_fbgen_utils_grammar_rules",
+ "libtextclassifier_fbgen_utils_grammar_semantics_expression",
+ "libtextclassifier_fbgen_utils_resources",
"libtextclassifier_fbgen_utils_i18n_language-tag",
+ "libtextclassifier_fbgen_utils_normalization",
+ "libtextclassifier_fbgen_utils_container_bit-vector",
+ "libtextclassifier_fbgen_actions_actions-entity-data",
+ "libtextclassifier_fbgen_actions_actions_model",
],
export_generated_headers: [
- "libtextclassifier_fbgen_actions_actions_model",
- "libtextclassifier_fbgen_actions_actions-entity-data",
- "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
"libtextclassifier_fbgen_lang_id_common_flatbuffers_model",
- "libtextclassifier_fbgen_annotator_person_name_person_name_model",
+ "libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
"libtextclassifier_fbgen_annotator_datetime_datetime",
+ "libtextclassifier_fbgen_annotator_model",
"libtextclassifier_fbgen_annotator_experimental_experimental",
"libtextclassifier_fbgen_annotator_entity-data",
- "libtextclassifier_fbgen_annotator_model",
- "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
+ "libtextclassifier_fbgen_annotator_person_name_person_name_model",
"libtextclassifier_fbgen_utils_tflite_text_encoder_config",
- "libtextclassifier_fbgen_utils_resources",
- "libtextclassifier_fbgen_utils_zlib_buffer",
- "libtextclassifier_fbgen_utils_container_bit-vector",
- "libtextclassifier_fbgen_utils_intents_intent-config",
- "libtextclassifier_fbgen_utils_normalization",
- "libtextclassifier_fbgen_utils_grammar_semantics_expression",
- "libtextclassifier_fbgen_utils_grammar_rules",
"libtextclassifier_fbgen_utils_codepoint-range",
+ "libtextclassifier_fbgen_utils_intents_intent-config",
+ "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers",
+ "libtextclassifier_fbgen_utils_zlib_buffer",
"libtextclassifier_fbgen_utils_tokenizer",
+ "libtextclassifier_fbgen_utils_grammar_rules",
+ "libtextclassifier_fbgen_utils_grammar_semantics_expression",
+ "libtextclassifier_fbgen_utils_resources",
"libtextclassifier_fbgen_utils_i18n_language-tag",
+ "libtextclassifier_fbgen_utils_normalization",
+ "libtextclassifier_fbgen_utils_container_bit-vector",
+ "libtextclassifier_fbgen_actions_actions-entity-data",
+ "libtextclassifier_fbgen_actions_actions_model",
],
}
diff --git a/native/JavaTests.bp b/native/JavaTests.bp
index 78d5748..9837173 100644
--- a/native/JavaTests.bp
+++ b/native/JavaTests.bp
@@ -17,30 +17,30 @@
filegroup {
name: "libtextclassifier_java_test_sources",
srcs: [
- "actions/grammar-actions_test.cc",
- "actions/actions-suggestions_test.cc",
- "annotator/pod_ner/pod-ner-impl_test.cc",
+ "annotator/datetime/datetime-grounder_test.cc",
"annotator/datetime/regex-parser_test.cc",
"annotator/datetime/grammar-parser_test.cc",
- "annotator/datetime/datetime-grounder_test.cc",
+ "annotator/pod_ner/pod-ner-impl_test.cc",
"utils/intents/intent-generator-test-lib.cc",
"utils/calendar/calendar_test.cc",
"utils/regex-match_test.cc",
"utils/grammar/parsing/lexer_test.cc",
+ "actions/actions-suggestions_test.cc",
+ "actions/grammar-actions_test.cc",
"annotator/number/number_test-include.cc",
"annotator/annotator_test-include.cc",
"annotator/grammar/grammar-annotator_test.cc",
"annotator/grammar/test-utils.cc",
"utils/utf8/unilib_test-include.cc",
+ "utils/grammar/parsing/parser_test.cc",
"utils/grammar/analyzer_test.cc",
"utils/grammar/semantics/composer_test.cc",
- "utils/grammar/semantics/evaluators/arithmetic-eval_test.cc",
+ "utils/grammar/semantics/evaluators/constituent-eval_test.cc",
"utils/grammar/semantics/evaluators/merge-values-eval_test.cc",
+ "utils/grammar/semantics/evaluators/parse-number-eval_test.cc",
+ "utils/grammar/semantics/evaluators/arithmetic-eval_test.cc",
+ "utils/grammar/semantics/evaluators/span-eval_test.cc",
"utils/grammar/semantics/evaluators/const-eval_test.cc",
"utils/grammar/semantics/evaluators/compose-eval_test.cc",
- "utils/grammar/semantics/evaluators/span-eval_test.cc",
- "utils/grammar/semantics/evaluators/parse-number-eval_test.cc",
- "utils/grammar/semantics/evaluators/constituent-eval_test.cc",
- "utils/grammar/parsing/parser_test.cc",
],
}
diff --git a/native/actions/actions-entity-data.bfbs b/native/actions/actions-entity-data.bfbs
index 7421579..6ebf1cf 100644
--- a/native/actions/actions-entity-data.bfbs
+++ b/native/actions/actions-entity-data.bfbs
Binary files differ
diff --git a/native/actions/actions-entity-data.fbs b/native/actions/actions-entity-data.fbs
index 21584b6..e906f93 100644
--- a/native/actions/actions-entity-data.fbs
+++ b/native/actions/actions-entity-data.fbs
@@ -18,7 +18,7 @@
namespace libtextclassifier3;
table ActionsEntityData {
// Extracted text.
- text:string (shared);
+ text:string (key, shared);
}
root_type libtextclassifier3.ActionsEntityData;
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index b1a042c..9f9a8d4 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -17,6 +17,7 @@
#include "actions/actions-suggestions.h"
#include <memory>
+#include <string>
#include <vector>
#include "utils/base/statusor.h"
@@ -40,6 +41,7 @@
#include "utils/strings/stringpiece.h"
#include "utils/strings/utf8.h"
#include "utils/utf8/unicodetext.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/lite/string_util.h"
namespace libtextclassifier3 {
@@ -809,12 +811,14 @@
void ActionsSuggestions::PopulateTextReplies(
const tflite::Interpreter* interpreter, int suggestion_index,
- int score_index, const std::string& type,
+ int score_index, const std::string& type, float priority_score,
+ const absl::flat_hash_set<std::string>& blocklist,
ActionsSuggestionsResponse* response) const {
const std::vector<tflite::StringRef> replies =
model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter);
const TensorView<float> scores =
model_executor_->OutputView<float>(score_index, interpreter);
+
for (int i = 0; i < replies.size(); i++) {
if (replies[i].len == 0) {
continue;
@@ -823,8 +827,12 @@
if (score < preconditions_.min_reply_score_threshold) {
continue;
}
- response->actions.push_back(
- {std::string(replies[i].str, replies[i].len), type, score});
+ std::string response_text(replies[i].str, replies[i].len);
+ if (blocklist.contains(response_text)) {
+ continue;
+ }
+
+ response->actions.push_back({response_text, type, score, priority_score});
}
}
@@ -909,10 +917,12 @@
// Read smart reply predictions.
if (!response->output_filtered_min_triggering_score &&
model_->tflite_model_spec()->output_replies() >= 0) {
+ absl::flat_hash_set<std::string> empty_blocklist;
PopulateTextReplies(interpreter,
model_->tflite_model_spec()->output_replies(),
model_->tflite_model_spec()->output_replies_scores(),
- model_->smart_reply_action_type()->str(), response);
+ model_->smart_reply_action_type()->str(),
+ /* priority_score */ 0.0, empty_blocklist, response);
}
// Read actions suggestions.
@@ -950,17 +960,26 @@
const int suggestions_index = metadata->output_suggestions();
const int suggestions_scores_index =
metadata->output_suggestions_scores();
+ absl::flat_hash_set<std::string> response_text_blocklist;
switch (metadata->prediction_type()) {
case PredictionType_NEXT_MESSAGE_PREDICTION:
if (!task_spec || task_spec->type()->size() == 0) {
TC3_LOG(WARNING) << "Task type not provided, use default "
"smart_reply_action_type!";
}
+ if (task_spec) {
+ if (task_spec->response_text_blocklist()) {
+ for (const auto& val : *task_spec->response_text_blocklist()) {
+ response_text_blocklist.insert(val->str());
+ }
+ }
+ }
PopulateTextReplies(
interpreter, suggestions_index, suggestions_scores_index,
task_spec ? task_spec->type()->str()
: model_->smart_reply_action_type()->str(),
- response);
+ task_spec ? task_spec->priority_score() : 0.0,
+ response_text_blocklist, response);
break;
case PredictionType_INTENT_TRIGGERING:
PopulateIntentTriggering(interpreter, suggestions_index,
diff --git a/native/actions/actions-suggestions.h b/native/actions/actions-suggestions.h
index 32edc78..87f55fb 100644
--- a/native/actions/actions-suggestions.h
+++ b/native/actions/actions-suggestions.h
@@ -43,6 +43,7 @@
#include "utils/utf8/unilib.h"
#include "utils/variant.h"
#include "utils/zlib/zlib.h"
+#include "absl/container/flat_hash_set.h"
namespace libtextclassifier3 {
@@ -176,7 +177,8 @@
void PopulateTextReplies(const tflite::Interpreter* interpreter,
int suggestion_index, int score_index,
- const std::string& type,
+ const std::string& type, float priority_score,
+ const absl::flat_hash_set<std::string>& blocklist,
ActionsSuggestionsResponse* response) const;
void PopulateIntentTriggering(const tflite::Interpreter* interpreter,
diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc
index 7fe69fc..b51ebc7 100644
--- a/native/actions/actions-suggestions_test.cc
+++ b/native/actions/actions-suggestions_test.cc
@@ -63,6 +63,8 @@
"actions_suggestions_test.multi_task_sr_emoji.model";
constexpr char kSensitiveTFliteModelFileName[] =
"actions_suggestions_test.sensitive_tflite.model";
+constexpr char kLiveRelayTFLiteModelFileName[] =
+ "actions_suggestions_test.live_relay.model";
std::string ReadFile(const std::string& file_name) {
std::ifstream file_stream(file_name);
@@ -1796,6 +1798,7 @@
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromMultiTaskSrEmojiModel) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kMultiTaskSrEmojiModelFileName);
+
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "hello?",
@@ -1805,9 +1808,48 @@
/*locales=*/"en"}}});
EXPECT_EQ(response.actions.size(), 5);
EXPECT_EQ(response.actions[0].response_text, "😁");
- EXPECT_EQ(response.actions[0].type, "EMOJI_CONCEPT");
- EXPECT_EQ(response.actions[1].response_text, "Yes");
- EXPECT_EQ(response.actions[1].type, "REPLY_SUGGESTION");
+ EXPECT_EQ(response.actions[0].type, "text_reply");
+ EXPECT_EQ(response.actions[1].response_text, "👋");
+ EXPECT_EQ(response.actions[1].type, "text_reply");
+ EXPECT_EQ(response.actions[2].response_text, "Yes");
+ EXPECT_EQ(response.actions[2].type, "text_reply");
+}
+
+TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelRemovesTextHeadEmoji) {
+  // Expects exactly three text_reply actions: two emoji, then "Okay".
+  std::unique_ptr<ActionsSuggestions> suggester =
+      LoadTestModel(kMultiTaskSrEmojiModelFileName);
+  const ActionsSuggestionsResponse result = suggester->SuggestActions(
+      {{{/*user_id=*/1, "a pleasure chatting",
+         /*reference_time_ms_utc=*/0,
+         /*reference_timezone=*/"Europe/Zurich",
+         /*annotations=*/{},
+         /*locales=*/"en"}}});
+  const auto& actions = result.actions;
+  EXPECT_EQ(actions.size(), 3);
+  EXPECT_EQ(actions[0].response_text, "😁");
+  EXPECT_EQ(actions[0].type, "text_reply");
+  EXPECT_EQ(actions[1].response_text, "😘");
+  EXPECT_EQ(actions[1].type, "text_reply");
+  EXPECT_EQ(actions[2].response_text, "Okay");
+  EXPECT_EQ(actions[2].type, "text_reply");
+}
+
+TEST_F(ActionsSuggestionsTest, LiveRelayModel) {
+  // Three actions are expected; only the leading two are pinned by content.
+  std::unique_ptr<ActionsSuggestions> suggester =
+      LoadTestModel(kLiveRelayTFLiteModelFileName);
+  const ActionsSuggestionsResponse result = suggester->SuggestActions(
+      {{{/*user_id=*/1, "Hi",
+         /*reference_time_ms_utc=*/0,
+         /*reference_timezone=*/"Europe/Zurich",
+         /*annotations=*/{},
+         /*locales=*/"en"}}});
+  EXPECT_EQ(result.actions.size(), 3);
+  EXPECT_EQ(result.actions[0].response_text, "Hi how are you doing");
+  EXPECT_EQ(result.actions[0].type, "text_reply");
+  EXPECT_EQ(result.actions[1].response_text, "Hi whats up");
+  EXPECT_EQ(result.actions[1].type, "text_reply");
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromSensitiveTfLiteModel) {
diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs
index 8c03eeb..0d8c7ad 100644
--- a/native/actions/actions_model.fbs
+++ b/native/actions/actions_model.fbs
@@ -36,6 +36,17 @@
ENTITY_ANNOTATION = 3,
}
+namespace libtextclassifier3;
+enum RankingOptionsSortType : int {
+ SORT_TYPE_UNSPECIFIED = 0,
+
+ // Rank results (or groups) by score, then type
+ SORT_TYPE_SCORE = 1,
+
+ // Rank results (or groups) by priority score, then score, then type
+ SORT_TYPE_PRIORITY_SCORE = 2,
+}
+
// Prediction metadata for an arbitrary task.
namespace libtextclassifier3;
table PredictionMetadata {
@@ -315,10 +326,11 @@
// Additional entity information.
serialized_entity_data:string (shared);
- // Priority score used for internal conflict resolution.
+ // For ranking and internal conflict resolution.
priority_score:float = 0;
entity_data:ActionsEntityData;
+ response_text_blocklist:[string];
}
// Options to specify triggering behaviour per action class.
@@ -416,6 +428,8 @@
// If true, keep actions from the same entities together for ranking.
group_by_annotations:bool = true;
+
+ sort_type:RankingOptionsSortType = SORT_TYPE_SCORE;
}
// Entity data to set from capturing groups.
diff --git a/native/actions/ranker.cc b/native/actions/ranker.cc
index d52ecaa..46e392a 100644
--- a/native/actions/ranker.cc
+++ b/native/actions/ranker.cc
@@ -20,6 +20,8 @@
#include <set>
#include <vector>
+#include "actions/actions_model_generated.h"
+
#if !defined(TC3_DISABLE_LUA)
#include "actions/lua-ranker.h"
#endif
@@ -34,11 +36,22 @@
namespace {
void SortByScoreAndType(std::vector<ActionSuggestion>* actions) {
- std::sort(actions->begin(), actions->end(),
- [](const ActionSuggestion& a, const ActionSuggestion& b) {
- return a.score > b.score ||
- (a.score >= b.score && a.type < b.type);
- });
+ std::stable_sort(actions->begin(), actions->end(),
+ [](const ActionSuggestion& a, const ActionSuggestion& b) {
+ return a.score > b.score ||
+ (a.score >= b.score && a.type < b.type);
+ });
+}
+
+void SortByPriorityAndScoreAndType(std::vector<ActionSuggestion>* actions) {
+  // Highest priority_score first; ties broken by higher score, then type.
+  const auto before = [](const ActionSuggestion& x, const ActionSuggestion& y) {
+    if (x.priority_score != y.priority_score) {
+      return x.priority_score > y.priority_score;
+    }
+    return x.score != y.score ? x.score > y.score : x.type < y.type;
+  };
+  std::stable_sort(actions->begin(), actions->end(), before);
}
template <typename T>
@@ -241,13 +254,8 @@
const reflection::Schema* annotations_entity_data_schema) const {
if (options_->deduplicate_suggestions() ||
options_->deduplicate_suggestions_by_span()) {
- // First order suggestions by priority score for deduplication.
- std::sort(
- response->actions.begin(), response->actions.end(),
- [](const ActionSuggestion& a, const ActionSuggestion& b) {
- return a.priority_score > b.priority_score ||
- (a.priority_score >= b.priority_score && a.score > b.score);
- });
+    // Order suggestions by [priority score -> score -> type] for deduplication
+ SortByPriorityAndScoreAndType(&response->actions);
// Deduplicate, keeping the higher score actions.
if (options_->deduplicate_suggestions()) {
@@ -275,6 +283,8 @@
}
}
+ bool sort_by_priority =
+ options_->sort_type() == RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
// Suppress smart replies if actions are present.
if (options_->suppress_smart_replies_with_actions()) {
std::vector<ActionSuggestion> non_smart_reply_actions;
@@ -316,17 +326,35 @@
// Sort within each group by score.
for (std::vector<ActionSuggestion>& group : groups) {
- SortByScoreAndType(&group);
+ if (sort_by_priority) {
+ SortByPriorityAndScoreAndType(&group);
+ } else {
+ SortByScoreAndType(&group);
+ }
}
- // Sort groups by maximum score.
- std::sort(groups.begin(), groups.end(),
- [](const std::vector<ActionSuggestion>& a,
- const std::vector<ActionSuggestion>& b) {
- return a.begin()->score > b.begin()->score ||
- (a.begin()->score >= b.begin()->score &&
- a.begin()->type < b.begin()->type);
- });
+ // Sort groups by maximum score or priority score.
+ if (sort_by_priority) {
+ std::stable_sort(
+ groups.begin(), groups.end(),
+ [](const std::vector<ActionSuggestion>& a,
+ const std::vector<ActionSuggestion>& b) {
+ return (a.begin()->priority_score > b.begin()->priority_score) ||
+ (a.begin()->priority_score >= b.begin()->priority_score &&
+ a.begin()->score > b.begin()->score) ||
+ (a.begin()->priority_score >= b.begin()->priority_score &&
+ a.begin()->score >= b.begin()->score &&
+ a.begin()->type < b.begin()->type);
+ });
+ } else {
+ std::stable_sort(groups.begin(), groups.end(),
+ [](const std::vector<ActionSuggestion>& a,
+ const std::vector<ActionSuggestion>& b) {
+ return a.begin()->score > b.begin()->score ||
+ (a.begin()->score >= b.begin()->score &&
+ a.begin()->type < b.begin()->type);
+ });
+ }
// Flatten result.
const size_t num_actions = response->actions.size();
@@ -336,9 +364,9 @@
response->actions.insert(response->actions.end(), actions.begin(),
actions.end());
}
-
+ } else if (sort_by_priority) {
+ SortByPriorityAndScoreAndType(&response->actions);
} else {
- // Order suggestions independently by score.
SortByScoreAndType(&response->actions);
}
diff --git a/native/actions/ranker_test.cc b/native/actions/ranker_test.cc
index b52cf45..5eba45f 100644
--- a/native/actions/ranker_test.cc
+++ b/native/actions/ranker_test.cc
@@ -18,6 +18,7 @@
#include <string>
+#include "actions/actions_model_generated.h"
#include "actions/types.h"
#include "utils/zlib/zlib.h"
#include "gmock/gmock.h"
@@ -308,12 +309,12 @@
response.actions.push_back({/*response_text=*/"",
/*type=*/"call_phone",
/*score=*/1.0,
- /*priority_score=*/1.0,
+ /*priority_score=*/0.0,
/*annotations=*/{annotation}});
response.actions.push_back({/*response_text=*/"",
/*type=*/"add_contact",
/*score=*/0.0,
- /*priority_score=*/0.0,
+ /*priority_score=*/1.0,
/*annotations=*/{annotation}});
}
response.actions.push_back({/*response_text=*/"How are you?",
@@ -338,6 +339,58 @@
IsAction("text_reply", "How are you?", 0.5)}));
}
+TEST(RankingTest, GroupsByAnnotationsSortedByPriority) {
+  // One phone annotation shared by three actions, plus an unannotated reply.
+  ActionSuggestionAnnotation phone;
+  phone.span = {/*message_index=*/0, /*span=*/{5, 8}, /*text=*/"911"};
+  phone.entity = ClassificationResult("phone", 1.0);
+
+  ActionsSuggestionsResponse response;
+  response.actions.push_back({/*response_text=*/"How are you?",
+                              /*type=*/"text_reply",
+                              /*score=*/2.0,
+                              /*priority_score=*/0.0});
+  response.actions.push_back({/*response_text=*/"",
+                              /*type=*/"add_contact",
+                              /*score=*/0.0,
+                              /*priority_score=*/1.0,
+                              /*annotations=*/{phone}});
+  response.actions.push_back({/*response_text=*/"",
+                              /*type=*/"call_phone",
+                              /*score=*/1.0,
+                              /*priority_score=*/0.0,
+                              /*annotations=*/{phone}});
+  response.actions.push_back({/*response_text=*/"",
+                              /*type=*/"add_contact2",
+                              /*score=*/0.5,
+                              /*priority_score=*/1.0,
+                              /*annotations=*/{phone}});
+
+  RankingOptionsT ranking_options;
+  ranking_options.group_by_annotations = true;
+  ranking_options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
+  flatbuffers::FlatBufferBuilder fbb;
+  fbb.Finish(RankingOptions::Pack(fbb, &ranking_options));
+  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
+      flatbuffers::GetRoot<RankingOptions>(fbb.GetBufferPointer()),
+      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
+
+  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
+  ranker->RankActions(conversation, &response);
+
+  // The text reply ends up last even though its score beats every other
+  // score, because its priority_score is below the highest priority_score
+  // among the actions carrying the 'phone' annotation.
+  EXPECT_THAT(response.actions,
+              testing::ElementsAreArray({
+                  // Group 1 (Phone annotation)
+                  IsAction("add_contact2", "", 0.5),  // priority_score=1.0
+                  IsAction("add_contact", "", 0.0),   // priority_score=1.0
+                  IsAction("call_phone", "", 1.0),    // priority_score=0.0
+                  IsAction("text_reply", "How are you?", 2.0),  // Group 2
+              }));
+}
+
TEST(RankingTest, SortsActionsByScore) {
const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
ActionsSuggestionsResponse response;
@@ -349,12 +402,12 @@
response.actions.push_back({/*response_text=*/"",
/*type=*/"call_phone",
/*score=*/1.0,
- /*priority_score=*/1.0,
+ /*priority_score=*/0.0,
/*annotations=*/{annotation}});
response.actions.push_back({/*response_text=*/"",
/*type=*/"add_contact",
/*score=*/0.0,
- /*priority_score=*/0.0,
+ /*priority_score=*/1.0,
/*annotations=*/{annotation}});
}
response.actions.push_back({/*response_text=*/"How are you?",
@@ -378,5 +431,40 @@
IsAction("add_contact", "", 0.0)}));
}
+TEST(RankingTest, SortsActionsByPriority) {
+ const Conversation conversation = {{{/*user_id=*/1, "hello?"}}};
+ ActionsSuggestionsResponse response;
+ // Emoji replies are given a higher priority_score.
+ response.actions.push_back({/*response_text=*/"😁",
+ /*type=*/"text_reply",
+ /*score=*/0.5,
+ /*priority_score=*/1.0});
+ response.actions.push_back({/*response_text=*/"👋",
+ /*type=*/"text_reply",
+ /*score=*/0.4,
+ /*priority_score=*/1.0});
+ response.actions.push_back({/*response_text=*/"Yes",
+ /*type=*/"text_reply",
+ /*score=*/1.0,
+ /*priority_score=*/0.0});
+ RankingOptionsT options;
+ // Don't group by annotation.
+ options.group_by_annotations = false;
+ options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(RankingOptions::Pack(builder, &options));
+ auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
+ flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
+ /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
+
+ ranker->RankActions(conversation, &response);
+
+ EXPECT_THAT(response.actions, testing::ElementsAreArray(
+ {IsAction("text_reply", "😁", 0.5),
+ IsAction("text_reply", "👋", 0.4),
+ // Ranked last because of priority score
+ IsAction("text_reply", "Yes", 1.0)}));
+}
+
} // namespace
} // namespace libtextclassifier3
diff --git a/native/actions/test_data/actions_suggestions_grammar_test.model b/native/actions/test_data/actions_suggestions_grammar_test.model
index d122687..0fa7f7e 100644
--- a/native/actions/test_data/actions_suggestions_grammar_test.model
+++ b/native/actions/test_data/actions_suggestions_grammar_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.live_relay.model b/native/actions/test_data/actions_suggestions_test.live_relay.model
new file mode 100644
index 0000000..6ff4302
--- /dev/null
+++ b/native/actions/test_data/actions_suggestions_test.live_relay.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model
index 2d97bc8..6107e98 100644
--- a/native/actions/test_data/actions_suggestions_test.model
+++ b/native/actions/test_data/actions_suggestions_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
index 567828b..436ed93 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
index 99f9040..935691d 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
index 504d8e0..2c9f74b 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
index 33926c2..cdb7523 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
index 730f603..ac28fa2 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
index 29fe077..d864b79 100644
--- a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
+++ b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
Binary files differ
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc
index e296a64..e0d4241 100644
--- a/native/annotator/annotator.cc
+++ b/native/annotator/annotator.cc
@@ -338,6 +338,12 @@
TC3_LOG(ERROR) << "Could not initialize selection executor.";
return;
}
+ }
+
+ // Even if the annotation mode is not enabled (for the neural network model),
+ // the selection feature processor is needed to tokenize the text for other
+ // models.
+ if (model_->selection_feature_options()) {
selection_feature_processor_.reset(
new FeatureProcessor(model_->selection_feature_options(), unilib_));
}
@@ -967,11 +973,11 @@
// Sort candidates according to their position in the input, so that the next
// code can assume that any connected component of overlapping spans forms a
// contiguous block.
- std::sort(candidates.annotated_spans[0].begin(),
- candidates.annotated_spans[0].end(),
- [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
- return a.span.first < b.span.first;
- });
+ std::stable_sort(candidates.annotated_spans[0].begin(),
+ candidates.annotated_spans[0].end(),
+ [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
+ return a.span.first < b.span.first;
+ });
std::vector<int> candidate_indices;
if (!ResolveConflicts(candidates.annotated_spans[0], context, tokens,
@@ -981,13 +987,14 @@
return original_click_indices;
}
- std::sort(candidate_indices.begin(), candidate_indices.end(),
- [this, &candidates](int a, int b) {
- return GetPriorityScore(
- candidates.annotated_spans[0][a].classification) >
- GetPriorityScore(
- candidates.annotated_spans[0][b].classification);
- });
+ std::stable_sort(
+ candidate_indices.begin(), candidate_indices.end(),
+ [this, &candidates](int a, int b) {
+ return GetPriorityScore(
+ candidates.annotated_spans[0][a].classification) >
+ GetPriorityScore(
+ candidates.annotated_spans[0][b].classification);
+ });
for (const int i : candidate_indices) {
if (SpansOverlap(candidates.annotated_spans[0][i].span, click_indices) &&
@@ -1167,7 +1174,7 @@
}
}
- std::sort(
+ std::stable_sort(
conflicting_indices.begin(), conflicting_indices.end(),
[this, &scores_lengths, candidates, conflicting_indices](int i, int j) {
if (scores_lengths[i].first == scores_lengths[j].first &&
@@ -1235,7 +1242,7 @@
chosen_indices_for_source_ptr->insert(considered_candidate);
}
- std::sort(chosen_indices->begin(), chosen_indices->end());
+ std::stable_sort(chosen_indices->begin(), chosen_indices->end());
return true;
}
@@ -1408,10 +1415,11 @@
// Sorts the classification results from high score to low score.
void SortClassificationResults(
std::vector<ClassificationResult>* classification_results) {
- std::sort(classification_results->begin(), classification_results->end(),
- [](const ClassificationResult& a, const ClassificationResult& b) {
- return a.score > b.score;
- });
+ std::stable_sort(
+ classification_results->begin(), classification_results->end(),
+ [](const ClassificationResult& a, const ClassificationResult& b) {
+ return a.score > b.score;
+ });
}
} // namespace
@@ -1930,10 +1938,11 @@
}
// Sort results according to score.
- std::sort(results.begin(), results.end(),
- [](const ClassificationResult& a, const ClassificationResult& b) {
- return a.score > b.score;
- });
+ std::stable_sort(
+ results.begin(), results.end(),
+ [](const ClassificationResult& a, const ClassificationResult& b) {
+ return a.score > b.score;
+ });
if (results.empty()) {
results = {{Collections::Other(), 1.0}};
@@ -1946,21 +1955,22 @@
const std::vector<Locale>& detected_text_language_tags,
const AnnotationOptions& options, InterpreterManager* interpreter_manager,
std::vector<Token>* tokens, std::vector<AnnotatedSpan>* result) const {
+ bool skip_model_annotatation = false;
if (model_->triggering_options() == nullptr ||
!(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
- return true;
+ skip_model_annotatation = true;
}
-
if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
ml_model_triggering_locales_,
/*default_value=*/true)) {
- return true;
+ skip_model_annotatation = true;
}
const UnicodeText context_unicode = UTF8ToUnicodeText(context,
/*do_copy=*/false);
std::vector<UnicodeTextRange> lines;
- if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
+ if (!selection_feature_processor_ ||
+ !selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
lines.push_back({context_unicode.begin(), context_unicode.end()});
} else {
lines = selection_feature_processor_->SplitContext(
@@ -1974,7 +1984,6 @@
: 0.f);
for (const UnicodeTextRange& line : lines) {
- FeatureProcessor::EmbeddingCache embedding_cache;
const std::string line_str =
UnicodeText::UTF8Substring(line.first, line.second);
@@ -1989,6 +1998,13 @@
const TokenSpan full_line_span = {
0, static_cast<TokenIndex>(line_tokens.size())};
+ tokens->insert(tokens->end(), line_tokens.begin(), line_tokens.end());
+
+ if (skip_model_annotatation) {
+ // We do not annotate, we only output the tokens.
+ continue;
+ }
+
// TODO(zilka): Add support for greater granularity of this check.
if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
line_tokens, full_line_span)) {
@@ -2017,68 +2033,46 @@
}
const int offset = std::distance(context_unicode.begin(), line.first);
- UnicodeText line_unicode;
- std::vector<UnicodeText::const_iterator> line_codepoints;
- if (options.enable_optimization) {
- if (local_chunks.empty()) {
- continue;
- }
- line_unicode = UTF8ToUnicodeText(line_str, /*do_copy=*/false);
- line_codepoints = line_unicode.Codepoints();
- line_codepoints.push_back(line_unicode.end());
+ if (local_chunks.empty()) {
+ continue;
}
+ const UnicodeText line_unicode =
+ UTF8ToUnicodeText(line_str, /*do_copy=*/false);
+ std::vector<UnicodeText::const_iterator> line_codepoints =
+ line_unicode.Codepoints();
+ line_codepoints.push_back(line_unicode.end());
+
+ FeatureProcessor::EmbeddingCache embedding_cache;
for (const TokenSpan& chunk : local_chunks) {
CodepointSpan codepoint_span =
TokenSpanToCodepointSpan(line_tokens, chunk);
- if (options.enable_optimization) {
- if (!codepoint_span.IsValid() ||
- codepoint_span.second > line_codepoints.size()) {
- continue;
- }
- codepoint_span = selection_feature_processor_->StripBoundaryCodepoints(
+ if (!codepoint_span.IsValid() ||
+ codepoint_span.second > line_codepoints.size()) {
+ continue;
+ }
+ codepoint_span = selection_feature_processor_->StripBoundaryCodepoints(
+ /*span_begin=*/line_codepoints[codepoint_span.first],
+ /*span_end=*/line_codepoints[codepoint_span.second], codepoint_span);
+ if (model_->selection_options()->strip_unpaired_brackets()) {
+ codepoint_span = StripUnpairedBrackets(
/*span_begin=*/line_codepoints[codepoint_span.first],
- /*span_end=*/line_codepoints[codepoint_span.second],
- codepoint_span);
- if (model_->selection_options()->strip_unpaired_brackets()) {
- codepoint_span = StripUnpairedBrackets(
- /*span_begin=*/line_codepoints[codepoint_span.first],
- /*span_end=*/line_codepoints[codepoint_span.second],
- codepoint_span, *unilib_);
- }
- } else {
- codepoint_span = selection_feature_processor_->StripBoundaryCodepoints(
- line_str, codepoint_span);
- if (model_->selection_options()->strip_unpaired_brackets()) {
- codepoint_span =
- StripUnpairedBrackets(context_unicode, codepoint_span, *unilib_);
- }
+ /*span_end=*/line_codepoints[codepoint_span.second], codepoint_span,
+ *unilib_);
}
// Skip empty spans.
if (codepoint_span.first != codepoint_span.second) {
std::vector<ClassificationResult> classification;
- if (options.enable_optimization) {
- if (!ModelClassifyText(
- line_unicode, line_tokens, detected_text_language_tags,
- /*span_begin=*/line_codepoints[codepoint_span.first],
- /*span_end=*/line_codepoints[codepoint_span.second], &line,
- codepoint_span, options, interpreter_manager,
- &embedding_cache, &classification, /*tokens=*/nullptr)) {
- TC3_LOG(ERROR) << "Could not classify text: "
- << (codepoint_span.first + offset) << " "
- << (codepoint_span.second + offset);
- return false;
- }
- } else {
- if (!ModelClassifyText(line_str, line_tokens,
- detected_text_language_tags, codepoint_span,
- options, interpreter_manager, &embedding_cache,
- &classification, /*tokens=*/nullptr)) {
- TC3_LOG(ERROR) << "Could not classify text: "
- << (codepoint_span.first + offset) << " "
- << (codepoint_span.second + offset);
- return false;
- }
+ if (!ModelClassifyText(
+ line_unicode, line_tokens, detected_text_language_tags,
+ /*span_begin=*/line_codepoints[codepoint_span.first],
+ /*span_end=*/line_codepoints[codepoint_span.second], &line,
+ codepoint_span, options, interpreter_manager, &embedding_cache,
+ &classification, /*tokens=*/nullptr)) {
+ TC3_LOG(ERROR) << "Could not classify text: "
+ << (codepoint_span.first + offset) << " "
+ << (codepoint_span.second + offset);
+ return false;
}
// Do not include the span if it's classified as "other".
@@ -2092,16 +2086,6 @@
}
}
}
-
- // If we are going line-by-line, we need to insert the tokens for each line.
- // But if not, we can optimize and just std::move the current line vector to
- // the output.
- if (selection_feature_processor_->GetOptions()
- ->only_use_line_with_click()) {
- tokens->insert(tokens->end(), line_tokens.begin(), line_tokens.end());
- } else {
- *tokens = std::move(line_tokens);
- }
}
return true;
}
@@ -2316,19 +2300,19 @@
// Also sort them according to the end position and collection, so that the
// deduplication code below can assume that same spans and classifications
// form contiguous blocks.
- std::sort(candidates->begin(), candidates->end(),
- [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
- if (a.span.first != b.span.first) {
- return a.span.first < b.span.first;
- }
+ std::stable_sort(candidates->begin(), candidates->end(),
+ [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
+ if (a.span.first != b.span.first) {
+ return a.span.first < b.span.first;
+ }
- if (a.span.second != b.span.second) {
- return a.span.second < b.span.second;
- }
+ if (a.span.second != b.span.second) {
+ return a.span.second < b.span.second;
+ }
- return a.classification[0].collection <
- b.classification[0].collection;
- });
+ return a.classification[0].collection <
+ b.classification[0].collection;
+ });
std::vector<int> candidate_indices;
if (!ResolveConflicts(*candidates, context, tokens,
@@ -2923,10 +2907,10 @@
return false;
}
}
- std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
- [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
- return lhs.score < rhs.score;
- });
+ std::stable_sort(scored_chunks.rbegin(), scored_chunks.rend(),
+ [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
+ return lhs.score < rhs.score;
+ });
// Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
// them greedily as long as they do not overlap with any previously picked
@@ -2955,7 +2939,7 @@
chunks->push_back(scored_chunk.token_span);
}
- std::sort(chunks->begin(), chunks->end());
+ std::stable_sort(chunks->begin(), chunks->end());
return true;
}
diff --git a/native/annotator/annotator_test-include.cc b/native/annotator/annotator_test-include.cc
index 3ecc201..a40779e 100644
--- a/native/annotator/annotator_test-include.cc
+++ b/native/annotator/annotator_test-include.cc
@@ -1253,34 +1253,6 @@
}));
}
-TEST_F(AnnotatorTest, AnnotatesWithBracketStrippingOptimized) {
- std::unique_ptr<Annotator> classifier = Annotator::FromPath(
- GetTestModelPath(), unilib_.get(), calendarlib_.get());
- ASSERT_TRUE(classifier);
-
- AnnotationOptions options;
- options.enable_optimization = true;
-
- EXPECT_THAT(classifier->Annotate("call me at (0845) 100 1000 today", options),
- ElementsAreArray({
- IsAnnotatedSpan(11, 26, "phone"),
- }));
-
- // Unpaired bracket stripping.
- EXPECT_THAT(classifier->Annotate("call me at (07038201818 today", options),
- ElementsAreArray({
- IsAnnotatedSpan(12, 23, "phone"),
- }));
- EXPECT_THAT(classifier->Annotate("call me at 07038201818) today", options),
- ElementsAreArray({
- IsAnnotatedSpan(11, 22, "phone"),
- }));
- EXPECT_THAT(classifier->Annotate("call me at )07038201818( today", options),
- ElementsAreArray({
- IsAnnotatedSpan(12, 23, "phone"),
- }));
-}
-
TEST_F(AnnotatorTest, AnnotatesOverlappingNumbers) {
std::unique_ptr<Annotator> classifier = Annotator::FromPath(
GetTestModelPath(), unilib_.get(), calendarlib_.get());
diff --git a/native/annotator/annotator_test-include.h b/native/annotator/annotator_test-include.h
index bcbb9e9..a7490e6 100644
--- a/native/annotator/annotator_test-include.h
+++ b/native/annotator/annotator_test-include.h
@@ -45,6 +45,7 @@
ValidateAndInitialize(libtextclassifier3::ViewModel(owned_buffer_.data(),
owned_buffer_.size()),
unilib, calendarlib);
+ AssertIsInitialized();
}
static std::unique_ptr<TestingAnnotator> FromUnownedBuffer(
@@ -59,6 +60,9 @@
}
using Annotator::ResolveConflicts;
+
+ private:
+ void AssertIsInitialized() { ASSERT_TRUE(IsInitialized()); }
};
class AnnotatorTest : public ::testing::TestWithParam<const char*> {
diff --git a/native/annotator/datetime/datetime-grounder.cc b/native/annotator/datetime/datetime-grounder.cc
index 7d5f440..ff0c775 100644
--- a/native/annotator/datetime/datetime-grounder.cc
+++ b/native/annotator/datetime/datetime-grounder.cc
@@ -16,6 +16,7 @@
#include "annotator/datetime/datetime-grounder.h"
+#include <algorithm>
#include <limits>
#include <unordered_map>
#include <vector>
@@ -250,10 +251,10 @@
}
// Sort the date time units by component type.
- std::sort(date_components.begin(), date_components.end(),
- [](DatetimeComponent a, DatetimeComponent b) {
- return a.component_type > b.component_type;
- });
+ std::stable_sort(date_components.begin(), date_components.end(),
+ [](DatetimeComponent a, DatetimeComponent b) {
+ return a.component_type > b.component_type;
+ });
result.datetime_components.swap(date_components);
datetime_parse_result.push_back(result);
}
diff --git a/native/annotator/datetime/extractor.cc b/native/annotator/datetime/extractor.cc
index 867c886..94a0961 100644
--- a/native/annotator/datetime/extractor.cc
+++ b/native/annotator/datetime/extractor.cc
@@ -16,6 +16,8 @@
#include "annotator/datetime/extractor.h"
+#include <algorithm>
+
#include "annotator/datetime/utils.h"
#include "annotator/model_generated.h"
#include "annotator/types.h"
@@ -347,10 +349,11 @@
}
}
- std::sort(found_numbers.begin(), found_numbers.end(),
- [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
- return a.first < b.first;
- });
+ std::stable_sort(
+ found_numbers.begin(), found_numbers.end(),
+ [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
+ return a.first < b.first;
+ });
int sum = 0;
int running_value = -1;
diff --git a/native/annotator/datetime/regex-parser.cc b/native/annotator/datetime/regex-parser.cc
index 4dc9c56..5daabd5 100644
--- a/native/annotator/datetime/regex-parser.cc
+++ b/native/annotator/datetime/regex-parser.cc
@@ -16,6 +16,7 @@
#include "annotator/datetime/regex-parser.h"
+#include <algorithm>
#include <iterator>
#include <set>
#include <unordered_set>
@@ -191,17 +192,17 @@
// Resolve conflicts by always picking the longer span and breaking ties by
// selecting the earlier entry in the list for a given locale.
- std::sort(indexed_found_spans.begin(), indexed_found_spans.end(),
- [](const std::pair<DatetimeParseResultSpan, int>& a,
- const std::pair<DatetimeParseResultSpan, int>& b) {
- if ((a.first.span.second - a.first.span.first) !=
- (b.first.span.second - b.first.span.first)) {
- return (a.first.span.second - a.first.span.first) >
- (b.first.span.second - b.first.span.first);
- } else {
- return a.second < b.second;
- }
- });
+ std::stable_sort(indexed_found_spans.begin(), indexed_found_spans.end(),
+ [](const std::pair<DatetimeParseResultSpan, int>& a,
+ const std::pair<DatetimeParseResultSpan, int>& b) {
+ if ((a.first.span.second - a.first.span.first) !=
+ (b.first.span.second - b.first.span.first)) {
+ return (a.first.span.second - a.first.span.first) >
+ (b.first.span.second - b.first.span.first);
+ } else {
+ return a.second < b.second;
+ }
+ });
std::vector<DatetimeParseResultSpan> results;
std::vector<DatetimeParseResultSpan> resolved_found_spans;
@@ -394,10 +395,10 @@
}
// Sort the date time units by component type.
- std::sort(date_components.begin(), date_components.end(),
- [](DatetimeComponent a, DatetimeComponent b) {
- return a.component_type > b.component_type;
- });
+ std::stable_sort(date_components.begin(), date_components.end(),
+ [](DatetimeComponent a, DatetimeComponent b) {
+ return a.component_type > b.component_type;
+ });
result.datetime_components.swap(date_components);
results->push_back(result);
}
diff --git a/native/annotator/entity-data.fbs b/native/annotator/entity-data.fbs
index f82eb44..eab00e1 100644
--- a/native/annotator/entity-data.fbs
+++ b/native/annotator/entity-data.fbs
@@ -73,6 +73,18 @@
datetime_component:[Datetime_.DatetimeComponent];
}
+namespace libtextclassifier3.EntityData_.Contact_.AlternativeNameInfo_;
+enum AlternativeNameSource : int {
+ NONE = 0,
+ NAME_CORRECTION_LOG = 1,
+}
+
+namespace libtextclassifier3.EntityData_.Contact_;
+table AlternativeNameInfo {
+ name:string (shared);
+ source:AlternativeNameInfo_.AlternativeNameSource;
+}
+
namespace libtextclassifier3.EntityData_;
table Contact {
name:string (shared);
@@ -81,6 +93,7 @@
email_address:string (shared);
phone_number:string (shared);
contact_id:string (shared);
+ alternative_name_info:[Contact_.AlternativeNameInfo];
}
namespace libtextclassifier3.EntityData_;
diff --git a/native/annotator/translate/translate.cc b/native/annotator/translate/translate.cc
index 640ceec..2c5a43c 100644
--- a/native/annotator/translate/translate.cc
+++ b/native/annotator/translate/translate.cc
@@ -16,6 +16,7 @@
#include "annotator/translate/translate.h"
+#include <algorithm>
#include <memory>
#include "annotator/collections.h"
@@ -142,11 +143,11 @@
result.push_back({key, value});
}
- std::sort(result.begin(), result.end(),
- [](TranslateAnnotator::LanguageConfidence& a,
- TranslateAnnotator::LanguageConfidence& b) {
- return a.confidence > b.confidence;
- });
+ std::stable_sort(result.begin(), result.end(),
+ [](const TranslateAnnotator::LanguageConfidence& a,
+ const TranslateAnnotator::LanguageConfidence& b) {
+ return a.confidence > b.confidence;
+ });
return result;
}
diff --git a/native/annotator/types.h b/native/annotator/types.h
index 45999cd..ada301c 100644
--- a/native/annotator/types.h
+++ b/native/annotator/types.h
@@ -437,7 +437,8 @@
ContactPointer contact_pointer;
std::string contact_name, contact_given_name, contact_family_name,
contact_nickname, contact_email_address, contact_phone_number,
- contact_account_type, contact_account_name, contact_id;
+ contact_account_type, contact_account_name, contact_id,
+ contact_alternate_name;
std::string app_name, app_package_name;
int64 numeric_value;
double numeric_double_value;
@@ -615,11 +616,6 @@
// If true, trigger dictionary on words that are of beginner level.
bool trigger_dictionary_on_beginner_words = false;
- // If true, enables an optimized code path for annotation.
- // The optimization caused crashes previously, which is why we are rolling it
- // out using this temporary flag. See: b/178503899
- bool enable_optimization = false;
-
bool operator==(const AnnotationOptions& other) const {
return this->is_serialized_entity_data_enabled ==
other.is_serialized_entity_data_enabled &&
diff --git a/native/lang_id/common/embedding-network.cc b/native/lang_id/common/embedding-network.cc
index 469cb1f..49c9ca0 100644
--- a/native/lang_id/common/embedding-network.cc
+++ b/native/lang_id/common/embedding-network.cc
@@ -16,6 +16,8 @@
#include "lang_id/common/embedding-network.h"
+#include <vector>
+
#include "lang_id/common/lite_base/integral-types.h"
#include "lang_id/common/lite_base/logging.h"
diff --git a/native/lang_id/common/fel/feature-extractor.cc b/native/lang_id/common/fel/feature-extractor.cc
index ab8a1a6..4e304fe 100644
--- a/native/lang_id/common/fel/feature-extractor.cc
+++ b/native/lang_id/common/fel/feature-extractor.cc
@@ -17,6 +17,7 @@
#include "lang_id/common/fel/feature-extractor.h"
#include <string>
+#include <vector>
#include "lang_id/common/fel/feature-types.h"
#include "lang_id/common/fel/fel-parser.h"
diff --git a/native/lang_id/common/fel/workspace.cc b/native/lang_id/common/fel/workspace.cc
index af41e29..60dcc46 100644
--- a/native/lang_id/common/fel/workspace.cc
+++ b/native/lang_id/common/fel/workspace.cc
@@ -18,6 +18,7 @@
#include <atomic>
#include <string>
+#include <vector>
namespace libtextclassifier3 {
namespace mobile {
diff --git a/native/lang_id/common/fel/workspace.h b/native/lang_id/common/fel/workspace.h
index f13d802..2ac5b26 100644
--- a/native/lang_id/common/fel/workspace.h
+++ b/native/lang_id/common/fel/workspace.h
@@ -23,6 +23,7 @@
#include <stddef.h>
+#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
diff --git a/native/lang_id/common/file/mmap.cc b/native/lang_id/common/file/mmap.cc
index 19afcc4..fc925ea 100644
--- a/native/lang_id/common/file/mmap.cc
+++ b/native/lang_id/common/file/mmap.cc
@@ -29,6 +29,8 @@
#endif
#include <sys/stat.h>
+#include <string>
+
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/lite_base/macros.h"
diff --git a/native/lang_id/common/lite_strings/str-split.cc b/native/lang_id/common/lite_strings/str-split.cc
index 199bb69..d227eec 100644
--- a/native/lang_id/common/lite_strings/str-split.cc
+++ b/native/lang_id/common/lite_strings/str-split.cc
@@ -16,6 +16,8 @@
#include "lang_id/common/lite_strings/str-split.h"
+#include <vector>
+
namespace libtextclassifier3 {
namespace mobile {
diff --git a/native/lang_id/common/math/softmax.cc b/native/lang_id/common/math/softmax.cc
index 750341d..249ed57 100644
--- a/native/lang_id/common/math/softmax.cc
+++ b/native/lang_id/common/math/softmax.cc
@@ -17,6 +17,7 @@
#include "lang_id/common/math/softmax.h"
#include <algorithm>
+#include <vector>
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/math/fastexp.h"
diff --git a/native/lang_id/fb_model/lang-id-from-fb.cc b/native/lang_id/fb_model/lang-id-from-fb.cc
index dc36fb7..51c8c47 100644
--- a/native/lang_id/fb_model/lang-id-from-fb.cc
+++ b/native/lang_id/fb_model/lang-id-from-fb.cc
@@ -16,7 +16,9 @@
#include "lang_id/fb_model/lang-id-from-fb.h"
+#include <memory>
#include <string>
+#include <utility>
#include "lang_id/fb_model/model-provider-from-fb.h"
diff --git a/native/lang_id/fb_model/model-provider-from-fb.cc b/native/lang_id/fb_model/model-provider-from-fb.cc
index 43bf860..d14d403 100644
--- a/native/lang_id/fb_model/model-provider-from-fb.cc
+++ b/native/lang_id/fb_model/model-provider-from-fb.cc
@@ -16,7 +16,9 @@
#include "lang_id/fb_model/model-provider-from-fb.h"
+#include <memory>
#include <string>
+#include <utility>
#include "lang_id/common/file/file-utils.h"
#include "lang_id/common/file/mmap.h"
diff --git a/native/lang_id/lang-id.cc b/native/lang_id/lang-id.cc
index 92359a9..f7c66f7 100644
--- a/native/lang_id/lang-id.cc
+++ b/native/lang_id/lang-id.cc
@@ -21,6 +21,7 @@
#include <memory>
#include <string>
#include <unordered_map>
+#include <utility>
#include <vector>
#include "lang_id/common/embedding-feature-interface.h"
diff --git a/native/lang_id/script/approx-script-data.cc b/native/lang_id/script/approx-script-data.cc
index 233653f..678a1e9 100755
--- a/native/lang_id/script/approx-script-data.cc
+++ b/native/lang_id/script/approx-script-data.cc
@@ -27,7 +27,7 @@
namespace mobile {
namespace approx_script_internal {
-const int kNumRanges = 376;
+const int kNumRanges = 389;
const uint32 kRangeFirst[] = {
65, // Range #0: [65, 90, Latin]
@@ -67,8 +67,8 @@
2048, // Range #34: [2048, 2110, Samaritan]
2112, // Range #35: [2112, 2142, Mandaic]
2144, // Range #36: [2144, 2154, Syriac]
- 2208, // Range #37: [2208, 2247, Arabic]
- 2259, // Range #38: [2259, 2273, Arabic]
+ 2160, // Range #37: [2160, 2193, Arabic]
+ 2200, // Range #38: [2200, 2273, Arabic]
2275, // Range #39: [2275, 2303, Arabic]
2304, // Range #40: [2304, 2384, Devanagari]
2389, // Range #41: [2389, 2403, Devanagari]
@@ -87,32 +87,32 @@
3031, // Range #54: [3031, 3031, Tamil]
3046, // Range #55: [3046, 3066, Tamil]
3072, // Range #56: [3072, 3149, Telugu]
- 3157, // Range #57: [3157, 3162, Telugu]
- 3168, // Range #58: [3168, 3183, Telugu]
- 3191, // Range #59: [3191, 3199, Telugu]
- 3200, // Range #60: [3200, 3277, Kannada]
- 3285, // Range #61: [3285, 3286, Kannada]
- 3294, // Range #62: [3294, 3314, Kannada]
- 3328, // Range #63: [3328, 3455, Malayalam]
- 3457, // Range #64: [3457, 3551, Sinhala]
- 3558, // Range #65: [3558, 3572, Sinhala]
- 3585, // Range #66: [3585, 3642, Thai]
- 3648, // Range #67: [3648, 3675, Thai]
- 3713, // Range #68: [3713, 3807, Lao]
- 3840, // Range #69: [3840, 4052, Tibetan]
- 4057, // Range #70: [4057, 4058, Tibetan]
- 4096, // Range #71: [4096, 4255, Myanmar]
- 4256, // Range #72: [4256, 4295, Georgian]
- 4301, // Range #73: [4301, 4346, Georgian]
- 4348, // Range #74: [4348, 4351, Georgian]
- 4352, // Range #75: [4352, 4607, Hangul]
- 4608, // Range #76: [4608, 5017, Ethiopic]
- 5024, // Range #77: [5024, 5117, Cherokee]
- 5120, // Range #78: [5120, 5759, Canadian_Aboriginal]
- 5760, // Range #79: [5760, 5788, Ogham]
- 5792, // Range #80: [5792, 5866, Runic]
- 5870, // Range #81: [5870, 5880, Runic]
- 5888, // Range #82: [5888, 5908, Tagalog]
+ 3157, // Range #57: [3157, 3183, Telugu]
+ 3191, // Range #58: [3191, 3199, Telugu]
+ 3200, // Range #59: [3200, 3277, Kannada]
+ 3285, // Range #60: [3285, 3286, Kannada]
+ 3293, // Range #61: [3293, 3314, Kannada]
+ 3328, // Range #62: [3328, 3455, Malayalam]
+ 3457, // Range #63: [3457, 3551, Sinhala]
+ 3558, // Range #64: [3558, 3572, Sinhala]
+ 3585, // Range #65: [3585, 3642, Thai]
+ 3648, // Range #66: [3648, 3675, Thai]
+ 3713, // Range #67: [3713, 3807, Lao]
+ 3840, // Range #68: [3840, 4052, Tibetan]
+ 4057, // Range #69: [4057, 4058, Tibetan]
+ 4096, // Range #70: [4096, 4255, Myanmar]
+ 4256, // Range #71: [4256, 4295, Georgian]
+ 4301, // Range #72: [4301, 4346, Georgian]
+ 4348, // Range #73: [4348, 4351, Georgian]
+ 4352, // Range #74: [4352, 4607, Hangul]
+ 4608, // Range #75: [4608, 5017, Ethiopic]
+ 5024, // Range #76: [5024, 5117, Cherokee]
+ 5120, // Range #77: [5120, 5759, Canadian_Aboriginal]
+ 5760, // Range #78: [5760, 5788, Ogham]
+ 5792, // Range #79: [5792, 5866, Runic]
+ 5870, // Range #80: [5870, 5880, Runic]
+ 5888, // Range #81: [5888, 5909, Tagalog]
+ 5919, // Range #82: [5919, 5919, Tagalog]
5920, // Range #83: [5920, 5940, Hanunoo]
5952, // Range #84: [5952, 5971, Buhid]
5984, // Range #85: [5984, 6003, Tagbanwa]
@@ -133,7 +133,7 @@
6688, // Range #100: [6688, 6793, Tai_Tham]
6800, // Range #101: [6800, 6809, Tai_Tham]
6816, // Range #102: [6816, 6829, Tai_Tham]
- 6912, // Range #103: [6912, 7036, Balinese]
+ 6912, // Range #103: [6912, 7038, Balinese]
7040, // Range #104: [7040, 7103, Sundanese]
7104, // Range #105: [7104, 7155, Batak]
7164, // Range #106: [7164, 7167, Batak]
@@ -164,7 +164,7 @@
8526, // Range #131: [8526, 8526, Latin]
8544, // Range #132: [8544, 8584, Latin]
10240, // Range #133: [10240, 10495, Braille]
- 11264, // Range #134: [11264, 11358, Glagolitic]
+ 11264, // Range #134: [11264, 11359, Glagolitic]
11360, // Range #135: [11360, 11391, Latin]
11392, // Range #136: [11392, 11507, Coptic]
11513, // Range #137: [11513, 11519, Coptic]
@@ -196,7 +196,7 @@
13008, // Range #163: [13008, 13054, Katakana]
13056, // Range #164: [13056, 13143, Katakana]
13312, // Range #165: [13312, 19903, Han]
- 19968, // Range #166: [19968, 40956, Han]
+ 19968, // Range #166: [19968, 40959, Han]
40960, // Range #167: [40960, 42182, Yi]
42192, // Range #168: [42192, 42239, Lisu]
42240, // Range #169: [42240, 42539, Vai]
@@ -204,208 +204,221 @@
42656, // Range #171: [42656, 42743, Bamum]
42786, // Range #172: [42786, 42887, Latin]
42891, // Range #173: [42891, 42954, Latin]
- 42997, // Range #174: [42997, 43007, Latin]
- 43008, // Range #175: [43008, 43052, Syloti_Nagri]
- 43072, // Range #176: [43072, 43127, Phags_Pa]
- 43136, // Range #177: [43136, 43205, Saurashtra]
- 43214, // Range #178: [43214, 43225, Saurashtra]
- 43232, // Range #179: [43232, 43263, Devanagari]
- 43264, // Range #180: [43264, 43309, Kayah_Li]
- 43311, // Range #181: [43311, 43311, Kayah_Li]
- 43312, // Range #182: [43312, 43347, Rejang]
- 43359, // Range #183: [43359, 43359, Rejang]
- 43360, // Range #184: [43360, 43388, Hangul]
- 43392, // Range #185: [43392, 43469, Javanese]
- 43472, // Range #186: [43472, 43487, Javanese]
- 43488, // Range #187: [43488, 43518, Myanmar]
- 43520, // Range #188: [43520, 43574, Cham]
- 43584, // Range #189: [43584, 43615, Cham]
- 43616, // Range #190: [43616, 43647, Myanmar]
- 43648, // Range #191: [43648, 43714, Tai_Viet]
- 43739, // Range #192: [43739, 43743, Tai_Viet]
- 43744, // Range #193: [43744, 43766, Meetei_Mayek]
- 43777, // Range #194: [43777, 43798, Ethiopic]
- 43808, // Range #195: [43808, 43822, Ethiopic]
- 43824, // Range #196: [43824, 43866, Latin]
- 43868, // Range #197: [43868, 43876, Latin]
- 43877, // Range #198: [43877, 43877, Greek]
- 43878, // Range #199: [43878, 43881, Latin]
- 43888, // Range #200: [43888, 43967, Cherokee]
- 43968, // Range #201: [43968, 44025, Meetei_Mayek]
- 44032, // Range #202: [44032, 55203, Hangul]
- 55216, // Range #203: [55216, 55291, Hangul]
- 63744, // Range #204: [63744, 64217, Han]
- 64256, // Range #205: [64256, 64262, Latin]
- 64275, // Range #206: [64275, 64279, Armenian]
- 64285, // Range #207: [64285, 64335, Hebrew]
- 64336, // Range #208: [64336, 64449, Arabic]
- 64467, // Range #209: [64467, 64829, Arabic]
- 64848, // Range #210: [64848, 64967, Arabic]
- 65008, // Range #211: [65008, 65021, Arabic]
- 65070, // Range #212: [65070, 65071, Cyrillic]
- 65136, // Range #213: [65136, 65276, Arabic]
- 65313, // Range #214: [65313, 65338, Latin]
- 65345, // Range #215: [65345, 65370, Latin]
- 65382, // Range #216: [65382, 65391, Katakana]
- 65393, // Range #217: [65393, 65437, Katakana]
- 65440, // Range #218: [65440, 65500, Hangul]
- 65536, // Range #219: [65536, 65629, Linear_B]
- 65664, // Range #220: [65664, 65786, Linear_B]
- 65856, // Range #221: [65856, 65934, Greek]
- 65952, // Range #222: [65952, 65952, Greek]
- 66176, // Range #223: [66176, 66204, Lycian]
- 66208, // Range #224: [66208, 66256, Carian]
- 66304, // Range #225: [66304, 66339, Old_Italic]
- 66349, // Range #226: [66349, 66351, Old_Italic]
- 66352, // Range #227: [66352, 66378, Gothic]
- 66384, // Range #228: [66384, 66426, Old_Permic]
- 66432, // Range #229: [66432, 66463, Ugaritic]
- 66464, // Range #230: [66464, 66517, Old_Persian]
- 66560, // Range #231: [66560, 66639, Deseret]
- 66640, // Range #232: [66640, 66687, Shavian]
- 66688, // Range #233: [66688, 66729, Osmanya]
- 66736, // Range #234: [66736, 66811, Osage]
- 66816, // Range #235: [66816, 66855, Elbasan]
- 66864, // Range #236: [66864, 66915, Caucasian_Albanian]
- 66927, // Range #237: [66927, 66927, Caucasian_Albanian]
- 67072, // Range #238: [67072, 67382, Linear_A]
- 67392, // Range #239: [67392, 67413, Linear_A]
- 67424, // Range #240: [67424, 67431, Linear_A]
- 67584, // Range #241: [67584, 67647, Cypriot]
- 67648, // Range #242: [67648, 67679, Imperial_Aramaic]
- 67680, // Range #243: [67680, 67711, Palmyrene]
- 67712, // Range #244: [67712, 67742, Nabataean]
- 67751, // Range #245: [67751, 67759, Nabataean]
- 67808, // Range #246: [67808, 67829, Hatran]
- 67835, // Range #247: [67835, 67839, Hatran]
- 67840, // Range #248: [67840, 67871, Phoenician]
- 67872, // Range #249: [67872, 67897, Lydian]
- 67903, // Range #250: [67903, 67903, Lydian]
- 67968, // Range #251: [67968, 67999, Meroitic_Hieroglyphs]
- 68000, // Range #252: [68000, 68095, Meroitic_Cursive]
- 68096, // Range #253: [68096, 68102, Kharoshthi]
- 68108, // Range #254: [68108, 68168, Kharoshthi]
- 68176, // Range #255: [68176, 68184, Kharoshthi]
- 68192, // Range #256: [68192, 68223, Old_South_Arabian]
- 68224, // Range #257: [68224, 68255, Old_North_Arabian]
- 68288, // Range #258: [68288, 68342, Manichaean]
- 68352, // Range #259: [68352, 68415, Avestan]
- 68416, // Range #260: [68416, 68447, Inscriptional_Parthian]
- 68448, // Range #261: [68448, 68466, Inscriptional_Pahlavi]
- 68472, // Range #262: [68472, 68479, Inscriptional_Pahlavi]
- 68480, // Range #263: [68480, 68497, Psalter_Pahlavi]
- 68505, // Range #264: [68505, 68508, Psalter_Pahlavi]
- 68521, // Range #265: [68521, 68527, Psalter_Pahlavi]
- 68608, // Range #266: [68608, 68680, Old_Turkic]
- 68736, // Range #267: [68736, 68786, Old_Hungarian]
- 68800, // Range #268: [68800, 68850, Old_Hungarian]
- 68858, // Range #269: [68858, 68863, Old_Hungarian]
- 68864, // Range #270: [68864, 68903, Hanifi_Rohingya]
- 68912, // Range #271: [68912, 68921, Hanifi_Rohingya]
- 69216, // Range #272: [69216, 69246, Arabic]
- 69248, // Range #273: [69248, 69297, Yezidi]
- 69376, // Range #274: [69376, 69415, Old_Sogdian]
- 69424, // Range #275: [69424, 69465, Sogdian]
- 69552, // Range #276: [69552, 69579, Chorasmian]
- 69600, // Range #277: [69600, 69622, Elymaic]
- 69632, // Range #278: [69632, 69743, Brahmi]
- 69759, // Range #279: [69759, 69759, Brahmi]
- 69760, // Range #280: [69760, 69825, Kaithi]
- 69837, // Range #281: [69837, 69837, Kaithi]
- 69840, // Range #282: [69840, 69864, Sora_Sompeng]
- 69872, // Range #283: [69872, 69881, Sora_Sompeng]
- 69888, // Range #284: [69888, 69959, Chakma]
- 69968, // Range #285: [69968, 70006, Mahajani]
- 70016, // Range #286: [70016, 70111, Sharada]
- 70113, // Range #287: [70113, 70132, Sinhala]
- 70144, // Range #288: [70144, 70206, Khojki]
- 70272, // Range #289: [70272, 70313, Multani]
- 70320, // Range #290: [70320, 70378, Khudawadi]
- 70384, // Range #291: [70384, 70393, Khudawadi]
- 70400, // Range #292: [70400, 70457, Grantha]
- 70460, // Range #293: [70460, 70480, Grantha]
- 70487, // Range #294: [70487, 70487, Grantha]
- 70493, // Range #295: [70493, 70516, Grantha]
- 70656, // Range #296: [70656, 70753, Newa]
- 70784, // Range #297: [70784, 70855, Tirhuta]
- 70864, // Range #298: [70864, 70873, Tirhuta]
- 71040, // Range #299: [71040, 71133, Siddham]
- 71168, // Range #300: [71168, 71236, Modi]
- 71248, // Range #301: [71248, 71257, Modi]
- 71264, // Range #302: [71264, 71276, Mongolian]
- 71296, // Range #303: [71296, 71352, Takri]
- 71360, // Range #304: [71360, 71369, Takri]
- 71424, // Range #305: [71424, 71487, Ahom]
- 71680, // Range #306: [71680, 71739, Dogra]
- 71840, // Range #307: [71840, 71922, Warang_Citi]
- 71935, // Range #308: [71935, 71935, Warang_Citi]
- 71936, // Range #309: [71936, 72006, Dives_Akuru]
- 72016, // Range #310: [72016, 72025, Dives_Akuru]
- 72096, // Range #311: [72096, 72164, Nandinagari]
- 72192, // Range #312: [72192, 72263, Zanabazar_Square]
- 72272, // Range #313: [72272, 72354, Soyombo]
- 72384, // Range #314: [72384, 72440, Pau_Cin_Hau]
- 72704, // Range #315: [72704, 72773, Bhaiksuki]
- 72784, // Range #316: [72784, 72812, Bhaiksuki]
- 72816, // Range #317: [72816, 72886, Marchen]
- 72960, // Range #318: [72960, 73031, Masaram_Gondi]
- 73040, // Range #319: [73040, 73049, Masaram_Gondi]
- 73056, // Range #320: [73056, 73112, Gunjala_Gondi]
- 73120, // Range #321: [73120, 73129, Gunjala_Gondi]
- 73440, // Range #322: [73440, 73464, Makasar]
- 73648, // Range #323: [73648, 73648, Lisu]
- 73664, // Range #324: [73664, 73713, Tamil]
- 73727, // Range #325: [73727, 73727, Tamil]
- 73728, // Range #326: [73728, 74649, Cuneiform]
- 74752, // Range #327: [74752, 74868, Cuneiform]
- 74880, // Range #328: [74880, 75075, Cuneiform]
- 77824, // Range #329: [77824, 78904, Egyptian_Hieroglyphs]
- 82944, // Range #330: [82944, 83526, Anatolian_Hieroglyphs]
- 92160, // Range #331: [92160, 92728, Bamum]
- 92736, // Range #332: [92736, 92783, Mro]
- 92880, // Range #333: [92880, 92917, Bassa_Vah]
- 92928, // Range #334: [92928, 92997, Pahawh_Hmong]
- 93008, // Range #335: [93008, 93047, Pahawh_Hmong]
- 93053, // Range #336: [93053, 93071, Pahawh_Hmong]
- 93760, // Range #337: [93760, 93850, Medefaidrin]
- 93952, // Range #338: [93952, 94087, Miao]
- 94095, // Range #339: [94095, 94111, Miao]
- 94176, // Range #340: [94176, 94176, Tangut]
- 94177, // Range #341: [94177, 94177, Nushu]
- 94180, // Range #342: [94180, 94180, Khitan_Small_Script]
- 94192, // Range #343: [94192, 94193, Han]
- 94208, // Range #344: [94208, 100343, Tangut]
- 100352, // Range #345: [100352, 101119, Tangut]
- 101120, // Range #346: [101120, 101589, Khitan_Small_Script]
- 101632, // Range #347: [101632, 101640, Tangut]
- 110592, // Range #348: [110592, 110592, Katakana]
- 110593, // Range #349: [110593, 110878, Hiragana]
- 110928, // Range #350: [110928, 110930, Hiragana]
- 110948, // Range #351: [110948, 110951, Katakana]
- 110960, // Range #352: [110960, 111355, Nushu]
- 113664, // Range #353: [113664, 113770, Duployan]
- 113776, // Range #354: [113776, 113800, Duployan]
- 113808, // Range #355: [113808, 113823, Duployan]
- 119296, // Range #356: [119296, 119365, Greek]
- 120832, // Range #357: [120832, 121483, SignWriting]
- 121499, // Range #358: [121499, 121519, SignWriting]
- 122880, // Range #359: [122880, 122922, Glagolitic]
- 123136, // Range #360: [123136, 123215, Nyiakeng_Puachue_Hmong]
- 123584, // Range #361: [123584, 123641, Wancho]
- 123647, // Range #362: [123647, 123647, Wancho]
- 124928, // Range #363: [124928, 125142, Mende_Kikakui]
- 125184, // Range #364: [125184, 125279, Adlam]
- 126464, // Range #365: [126464, 126523, Arabic]
- 126530, // Range #366: [126530, 126619, Arabic]
- 126625, // Range #367: [126625, 126651, Arabic]
- 126704, // Range #368: [126704, 126705, Arabic]
- 127488, // Range #369: [127488, 127488, Hiragana]
- 131072, // Range #370: [131072, 173789, Han]
- 173824, // Range #371: [173824, 177972, Han]
- 177984, // Range #372: [177984, 183969, Han]
- 183984, // Range #373: [183984, 191456, Han]
- 194560, // Range #374: [194560, 195101, Han]
- 196608, // Range #375: [196608, 201546, Han]
+ 42960, // Range #174: [42960, 42969, Latin]
+ 42994, // Range #175: [42994, 43007, Latin]
+ 43008, // Range #176: [43008, 43052, Syloti_Nagri]
+ 43072, // Range #177: [43072, 43127, Phags_Pa]
+ 43136, // Range #178: [43136, 43205, Saurashtra]
+ 43214, // Range #179: [43214, 43225, Saurashtra]
+ 43232, // Range #180: [43232, 43263, Devanagari]
+ 43264, // Range #181: [43264, 43309, Kayah_Li]
+ 43311, // Range #182: [43311, 43311, Kayah_Li]
+ 43312, // Range #183: [43312, 43347, Rejang]
+ 43359, // Range #184: [43359, 43359, Rejang]
+ 43360, // Range #185: [43360, 43388, Hangul]
+ 43392, // Range #186: [43392, 43469, Javanese]
+ 43472, // Range #187: [43472, 43487, Javanese]
+ 43488, // Range #188: [43488, 43518, Myanmar]
+ 43520, // Range #189: [43520, 43574, Cham]
+ 43584, // Range #190: [43584, 43615, Cham]
+ 43616, // Range #191: [43616, 43647, Myanmar]
+ 43648, // Range #192: [43648, 43714, Tai_Viet]
+ 43739, // Range #193: [43739, 43743, Tai_Viet]
+ 43744, // Range #194: [43744, 43766, Meetei_Mayek]
+ 43777, // Range #195: [43777, 43798, Ethiopic]
+ 43808, // Range #196: [43808, 43822, Ethiopic]
+ 43824, // Range #197: [43824, 43866, Latin]
+ 43868, // Range #198: [43868, 43876, Latin]
+ 43877, // Range #199: [43877, 43877, Greek]
+ 43878, // Range #200: [43878, 43881, Latin]
+ 43888, // Range #201: [43888, 43967, Cherokee]
+ 43968, // Range #202: [43968, 44025, Meetei_Mayek]
+ 44032, // Range #203: [44032, 55203, Hangul]
+ 55216, // Range #204: [55216, 55291, Hangul]
+ 63744, // Range #205: [63744, 64217, Han]
+ 64256, // Range #206: [64256, 64262, Latin]
+ 64275, // Range #207: [64275, 64279, Armenian]
+ 64285, // Range #208: [64285, 64335, Hebrew]
+ 64336, // Range #209: [64336, 64450, Arabic]
+ 64467, // Range #210: [64467, 64829, Arabic]
+ 64832, // Range #211: [64832, 64967, Arabic]
+ 64975, // Range #212: [64975, 64975, Arabic]
+ 65008, // Range #213: [65008, 65023, Arabic]
+ 65070, // Range #214: [65070, 65071, Cyrillic]
+ 65136, // Range #215: [65136, 65276, Arabic]
+ 65313, // Range #216: [65313, 65338, Latin]
+ 65345, // Range #217: [65345, 65370, Latin]
+ 65382, // Range #218: [65382, 65391, Katakana]
+ 65393, // Range #219: [65393, 65437, Katakana]
+ 65440, // Range #220: [65440, 65500, Hangul]
+ 65536, // Range #221: [65536, 65629, Linear_B]
+ 65664, // Range #222: [65664, 65786, Linear_B]
+ 65856, // Range #223: [65856, 65934, Greek]
+ 65952, // Range #224: [65952, 65952, Greek]
+ 66176, // Range #225: [66176, 66204, Lycian]
+ 66208, // Range #226: [66208, 66256, Carian]
+ 66304, // Range #227: [66304, 66339, Old_Italic]
+ 66349, // Range #228: [66349, 66351, Old_Italic]
+ 66352, // Range #229: [66352, 66378, Gothic]
+ 66384, // Range #230: [66384, 66426, Old_Permic]
+ 66432, // Range #231: [66432, 66463, Ugaritic]
+ 66464, // Range #232: [66464, 66517, Old_Persian]
+ 66560, // Range #233: [66560, 66639, Deseret]
+ 66640, // Range #234: [66640, 66687, Shavian]
+ 66688, // Range #235: [66688, 66729, Osmanya]
+ 66736, // Range #236: [66736, 66811, Osage]
+ 66816, // Range #237: [66816, 66855, Elbasan]
+ 66864, // Range #238: [66864, 66915, Caucasian_Albanian]
+ 66927, // Range #239: [66927, 66927, Caucasian_Albanian]
+ 66928, // Range #240: [66928, 67004, Vithkuqi]
+ 67072, // Range #241: [67072, 67382, Linear_A]
+ 67392, // Range #242: [67392, 67413, Linear_A]
+ 67424, // Range #243: [67424, 67431, Linear_A]
+ 67456, // Range #244: [67456, 67514, Latin]
+ 67584, // Range #245: [67584, 67647, Cypriot]
+ 67648, // Range #246: [67648, 67679, Imperial_Aramaic]
+ 67680, // Range #247: [67680, 67711, Palmyrene]
+ 67712, // Range #248: [67712, 67742, Nabataean]
+ 67751, // Range #249: [67751, 67759, Nabataean]
+ 67808, // Range #250: [67808, 67829, Hatran]
+ 67835, // Range #251: [67835, 67839, Hatran]
+ 67840, // Range #252: [67840, 67871, Phoenician]
+ 67872, // Range #253: [67872, 67897, Lydian]
+ 67903, // Range #254: [67903, 67903, Lydian]
+ 67968, // Range #255: [67968, 67999, Meroitic_Hieroglyphs]
+ 68000, // Range #256: [68000, 68095, Meroitic_Cursive]
+ 68096, // Range #257: [68096, 68102, Kharoshthi]
+ 68108, // Range #258: [68108, 68168, Kharoshthi]
+ 68176, // Range #259: [68176, 68184, Kharoshthi]
+ 68192, // Range #260: [68192, 68223, Old_South_Arabian]
+ 68224, // Range #261: [68224, 68255, Old_North_Arabian]
+ 68288, // Range #262: [68288, 68342, Manichaean]
+ 68352, // Range #263: [68352, 68415, Avestan]
+ 68416, // Range #264: [68416, 68447, Inscriptional_Parthian]
+ 68448, // Range #265: [68448, 68466, Inscriptional_Pahlavi]
+ 68472, // Range #266: [68472, 68479, Inscriptional_Pahlavi]
+ 68480, // Range #267: [68480, 68497, Psalter_Pahlavi]
+ 68505, // Range #268: [68505, 68508, Psalter_Pahlavi]
+ 68521, // Range #269: [68521, 68527, Psalter_Pahlavi]
+ 68608, // Range #270: [68608, 68680, Old_Turkic]
+ 68736, // Range #271: [68736, 68786, Old_Hungarian]
+ 68800, // Range #272: [68800, 68850, Old_Hungarian]
+ 68858, // Range #273: [68858, 68863, Old_Hungarian]
+ 68864, // Range #274: [68864, 68903, Hanifi_Rohingya]
+ 68912, // Range #275: [68912, 68921, Hanifi_Rohingya]
+ 69216, // Range #276: [69216, 69246, Arabic]
+ 69248, // Range #277: [69248, 69297, Yezidi]
+ 69376, // Range #278: [69376, 69415, Old_Sogdian]
+ 69424, // Range #279: [69424, 69465, Sogdian]
+ 69488, // Range #280: [69488, 69513, Old_Uyghur]
+ 69552, // Range #281: [69552, 69579, Chorasmian]
+ 69600, // Range #282: [69600, 69622, Elymaic]
+ 69632, // Range #283: [69632, 69749, Brahmi]
+ 69759, // Range #284: [69759, 69759, Brahmi]
+ 69760, // Range #285: [69760, 69826, Kaithi]
+ 69837, // Range #286: [69837, 69837, Kaithi]
+ 69840, // Range #287: [69840, 69864, Sora_Sompeng]
+ 69872, // Range #288: [69872, 69881, Sora_Sompeng]
+ 69888, // Range #289: [69888, 69959, Chakma]
+ 69968, // Range #290: [69968, 70006, Mahajani]
+ 70016, // Range #291: [70016, 70111, Sharada]
+ 70113, // Range #292: [70113, 70132, Sinhala]
+ 70144, // Range #293: [70144, 70206, Khojki]
+ 70272, // Range #294: [70272, 70313, Multani]
+ 70320, // Range #295: [70320, 70378, Khudawadi]
+ 70384, // Range #296: [70384, 70393, Khudawadi]
+ 70400, // Range #297: [70400, 70457, Grantha]
+ 70460, // Range #298: [70460, 70480, Grantha]
+ 70487, // Range #299: [70487, 70487, Grantha]
+ 70493, // Range #300: [70493, 70516, Grantha]
+ 70656, // Range #301: [70656, 70753, Newa]
+ 70784, // Range #302: [70784, 70855, Tirhuta]
+ 70864, // Range #303: [70864, 70873, Tirhuta]
+ 71040, // Range #304: [71040, 71133, Siddham]
+ 71168, // Range #305: [71168, 71236, Modi]
+ 71248, // Range #306: [71248, 71257, Modi]
+ 71264, // Range #307: [71264, 71276, Mongolian]
+ 71296, // Range #308: [71296, 71353, Takri]
+ 71360, // Range #309: [71360, 71369, Takri]
+ 71424, // Range #310: [71424, 71494, Ahom]
+ 71680, // Range #311: [71680, 71739, Dogra]
+ 71840, // Range #312: [71840, 71922, Warang_Citi]
+ 71935, // Range #313: [71935, 71935, Warang_Citi]
+ 71936, // Range #314: [71936, 72006, Dives_Akuru]
+ 72016, // Range #315: [72016, 72025, Dives_Akuru]
+ 72096, // Range #316: [72096, 72164, Nandinagari]
+ 72192, // Range #317: [72192, 72263, Zanabazar_Square]
+ 72272, // Range #318: [72272, 72354, Soyombo]
+ 72368, // Range #319: [72368, 72383, Canadian_Aboriginal]
+ 72384, // Range #320: [72384, 72440, Pau_Cin_Hau]
+ 72704, // Range #321: [72704, 72773, Bhaiksuki]
+ 72784, // Range #322: [72784, 72812, Bhaiksuki]
+ 72816, // Range #323: [72816, 72886, Marchen]
+ 72960, // Range #324: [72960, 73031, Masaram_Gondi]
+ 73040, // Range #325: [73040, 73049, Masaram_Gondi]
+ 73056, // Range #326: [73056, 73112, Gunjala_Gondi]
+ 73120, // Range #327: [73120, 73129, Gunjala_Gondi]
+ 73440, // Range #328: [73440, 73464, Makasar]
+ 73648, // Range #329: [73648, 73648, Lisu]
+ 73664, // Range #330: [73664, 73713, Tamil]
+ 73727, // Range #331: [73727, 73727, Tamil]
+ 73728, // Range #332: [73728, 74649, Cuneiform]
+ 74752, // Range #333: [74752, 74868, Cuneiform]
+ 74880, // Range #334: [74880, 75075, Cuneiform]
+ 77712, // Range #335: [77712, 77810, Cypro_Minoan]
+ 77824, // Range #336: [77824, 78904, Egyptian_Hieroglyphs]
+ 82944, // Range #337: [82944, 83526, Anatolian_Hieroglyphs]
+ 92160, // Range #338: [92160, 92728, Bamum]
+ 92736, // Range #339: [92736, 92783, Mro]
+ 92784, // Range #340: [92784, 92873, Tangsa]
+ 92880, // Range #341: [92880, 92917, Bassa_Vah]
+ 92928, // Range #342: [92928, 92997, Pahawh_Hmong]
+ 93008, // Range #343: [93008, 93047, Pahawh_Hmong]
+ 93053, // Range #344: [93053, 93071, Pahawh_Hmong]
+ 93760, // Range #345: [93760, 93850, Medefaidrin]
+ 93952, // Range #346: [93952, 94087, Miao]
+ 94095, // Range #347: [94095, 94111, Miao]
+ 94176, // Range #348: [94176, 94176, Tangut]
+ 94177, // Range #349: [94177, 94177, Nushu]
+ 94178, // Range #350: [94178, 94179, Han]
+ 94180, // Range #351: [94180, 94180, Khitan_Small_Script]
+ 94192, // Range #352: [94192, 94193, Han]
+ 94208, // Range #353: [94208, 100343, Tangut]
+ 100352, // Range #354: [100352, 101119, Tangut]
+ 101120, // Range #355: [101120, 101589, Khitan_Small_Script]
+ 101632, // Range #356: [101632, 101640, Tangut]
+ 110576, // Range #357: [110576, 110592, Katakana]
+ 110593, // Range #358: [110593, 110879, Hiragana]
+ 110880, // Range #359: [110880, 110882, Katakana]
+ 110928, // Range #360: [110928, 110930, Hiragana]
+ 110948, // Range #361: [110948, 110951, Katakana]
+ 110960, // Range #362: [110960, 111355, Nushu]
+ 113664, // Range #363: [113664, 113770, Duployan]
+ 113776, // Range #364: [113776, 113800, Duployan]
+ 113808, // Range #365: [113808, 113823, Duployan]
+ 119296, // Range #366: [119296, 119365, Greek]
+ 120832, // Range #367: [120832, 121483, SignWriting]
+ 121499, // Range #368: [121499, 121519, SignWriting]
+ 122624, // Range #369: [122624, 122654, Latin]
+ 122880, // Range #370: [122880, 122922, Glagolitic]
+ 123136, // Range #371: [123136, 123215, Nyiakeng_Puachue_Hmong]
+ 123536, // Range #372: [123536, 123566, Toto]
+ 123584, // Range #373: [123584, 123641, Wancho]
+ 123647, // Range #374: [123647, 123647, Wancho]
+ 124896, // Range #375: [124896, 124926, Ethiopic]
+ 124928, // Range #376: [124928, 125142, Mende_Kikakui]
+ 125184, // Range #377: [125184, 125279, Adlam]
+ 126464, // Range #378: [126464, 126523, Arabic]
+ 126530, // Range #379: [126530, 126619, Arabic]
+ 126625, // Range #380: [126625, 126651, Arabic]
+ 126704, // Range #381: [126704, 126705, Arabic]
+ 127488, // Range #382: [127488, 127488, Hiragana]
+ 131072, // Range #383: [131072, 173791, Han]
+ 173824, // Range #384: [173824, 177976, Han]
+ 177984, // Range #385: [177984, 183969, Han]
+ 183984, // Range #386: [183984, 191456, Han]
+ 194560, // Range #387: [194560, 195101, Han]
+ 196608, // Range #388: [196608, 201546, Han]
};
const uint16 kRangeSizeMinusOne[] = {
@@ -446,8 +459,8 @@
62, // Range #34: [2048, 2110, Samaritan]
30, // Range #35: [2112, 2142, Mandaic]
10, // Range #36: [2144, 2154, Syriac]
- 39, // Range #37: [2208, 2247, Arabic]
- 14, // Range #38: [2259, 2273, Arabic]
+ 33, // Range #37: [2160, 2193, Arabic]
+ 73, // Range #38: [2200, 2273, Arabic]
28, // Range #39: [2275, 2303, Arabic]
80, // Range #40: [2304, 2384, Devanagari]
14, // Range #41: [2389, 2403, Devanagari]
@@ -466,32 +479,32 @@
0, // Range #54: [3031, 3031, Tamil]
20, // Range #55: [3046, 3066, Tamil]
77, // Range #56: [3072, 3149, Telugu]
- 5, // Range #57: [3157, 3162, Telugu]
- 15, // Range #58: [3168, 3183, Telugu]
- 8, // Range #59: [3191, 3199, Telugu]
- 77, // Range #60: [3200, 3277, Kannada]
- 1, // Range #61: [3285, 3286, Kannada]
- 20, // Range #62: [3294, 3314, Kannada]
- 127, // Range #63: [3328, 3455, Malayalam]
- 94, // Range #64: [3457, 3551, Sinhala]
- 14, // Range #65: [3558, 3572, Sinhala]
- 57, // Range #66: [3585, 3642, Thai]
- 27, // Range #67: [3648, 3675, Thai]
- 94, // Range #68: [3713, 3807, Lao]
- 212, // Range #69: [3840, 4052, Tibetan]
- 1, // Range #70: [4057, 4058, Tibetan]
- 159, // Range #71: [4096, 4255, Myanmar]
- 39, // Range #72: [4256, 4295, Georgian]
- 45, // Range #73: [4301, 4346, Georgian]
- 3, // Range #74: [4348, 4351, Georgian]
- 255, // Range #75: [4352, 4607, Hangul]
- 409, // Range #76: [4608, 5017, Ethiopic]
- 93, // Range #77: [5024, 5117, Cherokee]
- 639, // Range #78: [5120, 5759, Canadian_Aboriginal]
- 28, // Range #79: [5760, 5788, Ogham]
- 74, // Range #80: [5792, 5866, Runic]
- 10, // Range #81: [5870, 5880, Runic]
- 20, // Range #82: [5888, 5908, Tagalog]
+ 26, // Range #57: [3157, 3183, Telugu]
+ 8, // Range #58: [3191, 3199, Telugu]
+ 77, // Range #59: [3200, 3277, Kannada]
+ 1, // Range #60: [3285, 3286, Kannada]
+ 21, // Range #61: [3293, 3314, Kannada]
+ 127, // Range #62: [3328, 3455, Malayalam]
+ 94, // Range #63: [3457, 3551, Sinhala]
+ 14, // Range #64: [3558, 3572, Sinhala]
+ 57, // Range #65: [3585, 3642, Thai]
+ 27, // Range #66: [3648, 3675, Thai]
+ 94, // Range #67: [3713, 3807, Lao]
+ 212, // Range #68: [3840, 4052, Tibetan]
+ 1, // Range #69: [4057, 4058, Tibetan]
+ 159, // Range #70: [4096, 4255, Myanmar]
+ 39, // Range #71: [4256, 4295, Georgian]
+ 45, // Range #72: [4301, 4346, Georgian]
+ 3, // Range #73: [4348, 4351, Georgian]
+ 255, // Range #74: [4352, 4607, Hangul]
+ 409, // Range #75: [4608, 5017, Ethiopic]
+ 93, // Range #76: [5024, 5117, Cherokee]
+ 639, // Range #77: [5120, 5759, Canadian_Aboriginal]
+ 28, // Range #78: [5760, 5788, Ogham]
+ 74, // Range #79: [5792, 5866, Runic]
+ 10, // Range #80: [5870, 5880, Runic]
+ 21, // Range #81: [5888, 5909, Tagalog]
+ 0, // Range #82: [5919, 5919, Tagalog]
20, // Range #83: [5920, 5940, Hanunoo]
19, // Range #84: [5952, 5971, Buhid]
19, // Range #85: [5984, 6003, Tagbanwa]
@@ -512,7 +525,7 @@
105, // Range #100: [6688, 6793, Tai_Tham]
9, // Range #101: [6800, 6809, Tai_Tham]
13, // Range #102: [6816, 6829, Tai_Tham]
- 124, // Range #103: [6912, 7036, Balinese]
+ 126, // Range #103: [6912, 7038, Balinese]
63, // Range #104: [7040, 7103, Sundanese]
51, // Range #105: [7104, 7155, Batak]
3, // Range #106: [7164, 7167, Batak]
@@ -543,7 +556,7 @@
0, // Range #131: [8526, 8526, Latin]
40, // Range #132: [8544, 8584, Latin]
255, // Range #133: [10240, 10495, Braille]
- 94, // Range #134: [11264, 11358, Glagolitic]
+ 95, // Range #134: [11264, 11359, Glagolitic]
31, // Range #135: [11360, 11391, Latin]
115, // Range #136: [11392, 11507, Coptic]
6, // Range #137: [11513, 11519, Coptic]
@@ -575,7 +588,7 @@
46, // Range #163: [13008, 13054, Katakana]
87, // Range #164: [13056, 13143, Katakana]
6591, // Range #165: [13312, 19903, Han]
- 20988, // Range #166: [19968, 40956, Han]
+ 20991, // Range #166: [19968, 40959, Han]
1222, // Range #167: [40960, 42182, Yi]
47, // Range #168: [42192, 42239, Lisu]
299, // Range #169: [42240, 42539, Vai]
@@ -583,208 +596,221 @@
87, // Range #171: [42656, 42743, Bamum]
101, // Range #172: [42786, 42887, Latin]
63, // Range #173: [42891, 42954, Latin]
- 10, // Range #174: [42997, 43007, Latin]
- 44, // Range #175: [43008, 43052, Syloti_Nagri]
- 55, // Range #176: [43072, 43127, Phags_Pa]
- 69, // Range #177: [43136, 43205, Saurashtra]
- 11, // Range #178: [43214, 43225, Saurashtra]
- 31, // Range #179: [43232, 43263, Devanagari]
- 45, // Range #180: [43264, 43309, Kayah_Li]
- 0, // Range #181: [43311, 43311, Kayah_Li]
- 35, // Range #182: [43312, 43347, Rejang]
- 0, // Range #183: [43359, 43359, Rejang]
- 28, // Range #184: [43360, 43388, Hangul]
- 77, // Range #185: [43392, 43469, Javanese]
- 15, // Range #186: [43472, 43487, Javanese]
- 30, // Range #187: [43488, 43518, Myanmar]
- 54, // Range #188: [43520, 43574, Cham]
- 31, // Range #189: [43584, 43615, Cham]
- 31, // Range #190: [43616, 43647, Myanmar]
- 66, // Range #191: [43648, 43714, Tai_Viet]
- 4, // Range #192: [43739, 43743, Tai_Viet]
- 22, // Range #193: [43744, 43766, Meetei_Mayek]
- 21, // Range #194: [43777, 43798, Ethiopic]
- 14, // Range #195: [43808, 43822, Ethiopic]
- 42, // Range #196: [43824, 43866, Latin]
- 8, // Range #197: [43868, 43876, Latin]
- 0, // Range #198: [43877, 43877, Greek]
- 3, // Range #199: [43878, 43881, Latin]
- 79, // Range #200: [43888, 43967, Cherokee]
- 57, // Range #201: [43968, 44025, Meetei_Mayek]
- 11171, // Range #202: [44032, 55203, Hangul]
- 75, // Range #203: [55216, 55291, Hangul]
- 473, // Range #204: [63744, 64217, Han]
- 6, // Range #205: [64256, 64262, Latin]
- 4, // Range #206: [64275, 64279, Armenian]
- 50, // Range #207: [64285, 64335, Hebrew]
- 113, // Range #208: [64336, 64449, Arabic]
- 362, // Range #209: [64467, 64829, Arabic]
- 119, // Range #210: [64848, 64967, Arabic]
- 13, // Range #211: [65008, 65021, Arabic]
- 1, // Range #212: [65070, 65071, Cyrillic]
- 140, // Range #213: [65136, 65276, Arabic]
- 25, // Range #214: [65313, 65338, Latin]
- 25, // Range #215: [65345, 65370, Latin]
- 9, // Range #216: [65382, 65391, Katakana]
- 44, // Range #217: [65393, 65437, Katakana]
- 60, // Range #218: [65440, 65500, Hangul]
- 93, // Range #219: [65536, 65629, Linear_B]
- 122, // Range #220: [65664, 65786, Linear_B]
- 78, // Range #221: [65856, 65934, Greek]
- 0, // Range #222: [65952, 65952, Greek]
- 28, // Range #223: [66176, 66204, Lycian]
- 48, // Range #224: [66208, 66256, Carian]
- 35, // Range #225: [66304, 66339, Old_Italic]
- 2, // Range #226: [66349, 66351, Old_Italic]
- 26, // Range #227: [66352, 66378, Gothic]
- 42, // Range #228: [66384, 66426, Old_Permic]
- 31, // Range #229: [66432, 66463, Ugaritic]
- 53, // Range #230: [66464, 66517, Old_Persian]
- 79, // Range #231: [66560, 66639, Deseret]
- 47, // Range #232: [66640, 66687, Shavian]
- 41, // Range #233: [66688, 66729, Osmanya]
- 75, // Range #234: [66736, 66811, Osage]
- 39, // Range #235: [66816, 66855, Elbasan]
- 51, // Range #236: [66864, 66915, Caucasian_Albanian]
- 0, // Range #237: [66927, 66927, Caucasian_Albanian]
- 310, // Range #238: [67072, 67382, Linear_A]
- 21, // Range #239: [67392, 67413, Linear_A]
- 7, // Range #240: [67424, 67431, Linear_A]
- 63, // Range #241: [67584, 67647, Cypriot]
- 31, // Range #242: [67648, 67679, Imperial_Aramaic]
- 31, // Range #243: [67680, 67711, Palmyrene]
- 30, // Range #244: [67712, 67742, Nabataean]
- 8, // Range #245: [67751, 67759, Nabataean]
- 21, // Range #246: [67808, 67829, Hatran]
- 4, // Range #247: [67835, 67839, Hatran]
- 31, // Range #248: [67840, 67871, Phoenician]
- 25, // Range #249: [67872, 67897, Lydian]
- 0, // Range #250: [67903, 67903, Lydian]
- 31, // Range #251: [67968, 67999, Meroitic_Hieroglyphs]
- 95, // Range #252: [68000, 68095, Meroitic_Cursive]
- 6, // Range #253: [68096, 68102, Kharoshthi]
- 60, // Range #254: [68108, 68168, Kharoshthi]
- 8, // Range #255: [68176, 68184, Kharoshthi]
- 31, // Range #256: [68192, 68223, Old_South_Arabian]
- 31, // Range #257: [68224, 68255, Old_North_Arabian]
- 54, // Range #258: [68288, 68342, Manichaean]
- 63, // Range #259: [68352, 68415, Avestan]
- 31, // Range #260: [68416, 68447, Inscriptional_Parthian]
- 18, // Range #261: [68448, 68466, Inscriptional_Pahlavi]
- 7, // Range #262: [68472, 68479, Inscriptional_Pahlavi]
- 17, // Range #263: [68480, 68497, Psalter_Pahlavi]
- 3, // Range #264: [68505, 68508, Psalter_Pahlavi]
- 6, // Range #265: [68521, 68527, Psalter_Pahlavi]
- 72, // Range #266: [68608, 68680, Old_Turkic]
- 50, // Range #267: [68736, 68786, Old_Hungarian]
- 50, // Range #268: [68800, 68850, Old_Hungarian]
- 5, // Range #269: [68858, 68863, Old_Hungarian]
- 39, // Range #270: [68864, 68903, Hanifi_Rohingya]
- 9, // Range #271: [68912, 68921, Hanifi_Rohingya]
- 30, // Range #272: [69216, 69246, Arabic]
- 49, // Range #273: [69248, 69297, Yezidi]
- 39, // Range #274: [69376, 69415, Old_Sogdian]
- 41, // Range #275: [69424, 69465, Sogdian]
- 27, // Range #276: [69552, 69579, Chorasmian]
- 22, // Range #277: [69600, 69622, Elymaic]
- 111, // Range #278: [69632, 69743, Brahmi]
- 0, // Range #279: [69759, 69759, Brahmi]
- 65, // Range #280: [69760, 69825, Kaithi]
- 0, // Range #281: [69837, 69837, Kaithi]
- 24, // Range #282: [69840, 69864, Sora_Sompeng]
- 9, // Range #283: [69872, 69881, Sora_Sompeng]
- 71, // Range #284: [69888, 69959, Chakma]
- 38, // Range #285: [69968, 70006, Mahajani]
- 95, // Range #286: [70016, 70111, Sharada]
- 19, // Range #287: [70113, 70132, Sinhala]
- 62, // Range #288: [70144, 70206, Khojki]
- 41, // Range #289: [70272, 70313, Multani]
- 58, // Range #290: [70320, 70378, Khudawadi]
- 9, // Range #291: [70384, 70393, Khudawadi]
- 57, // Range #292: [70400, 70457, Grantha]
- 20, // Range #293: [70460, 70480, Grantha]
- 0, // Range #294: [70487, 70487, Grantha]
- 23, // Range #295: [70493, 70516, Grantha]
- 97, // Range #296: [70656, 70753, Newa]
- 71, // Range #297: [70784, 70855, Tirhuta]
- 9, // Range #298: [70864, 70873, Tirhuta]
- 93, // Range #299: [71040, 71133, Siddham]
- 68, // Range #300: [71168, 71236, Modi]
- 9, // Range #301: [71248, 71257, Modi]
- 12, // Range #302: [71264, 71276, Mongolian]
- 56, // Range #303: [71296, 71352, Takri]
- 9, // Range #304: [71360, 71369, Takri]
- 63, // Range #305: [71424, 71487, Ahom]
- 59, // Range #306: [71680, 71739, Dogra]
- 82, // Range #307: [71840, 71922, Warang_Citi]
- 0, // Range #308: [71935, 71935, Warang_Citi]
- 70, // Range #309: [71936, 72006, Dives_Akuru]
- 9, // Range #310: [72016, 72025, Dives_Akuru]
- 68, // Range #311: [72096, 72164, Nandinagari]
- 71, // Range #312: [72192, 72263, Zanabazar_Square]
- 82, // Range #313: [72272, 72354, Soyombo]
- 56, // Range #314: [72384, 72440, Pau_Cin_Hau]
- 69, // Range #315: [72704, 72773, Bhaiksuki]
- 28, // Range #316: [72784, 72812, Bhaiksuki]
- 70, // Range #317: [72816, 72886, Marchen]
- 71, // Range #318: [72960, 73031, Masaram_Gondi]
- 9, // Range #319: [73040, 73049, Masaram_Gondi]
- 56, // Range #320: [73056, 73112, Gunjala_Gondi]
- 9, // Range #321: [73120, 73129, Gunjala_Gondi]
- 24, // Range #322: [73440, 73464, Makasar]
- 0, // Range #323: [73648, 73648, Lisu]
- 49, // Range #324: [73664, 73713, Tamil]
- 0, // Range #325: [73727, 73727, Tamil]
- 921, // Range #326: [73728, 74649, Cuneiform]
- 116, // Range #327: [74752, 74868, Cuneiform]
- 195, // Range #328: [74880, 75075, Cuneiform]
- 1080, // Range #329: [77824, 78904, Egyptian_Hieroglyphs]
- 582, // Range #330: [82944, 83526, Anatolian_Hieroglyphs]
- 568, // Range #331: [92160, 92728, Bamum]
- 47, // Range #332: [92736, 92783, Mro]
- 37, // Range #333: [92880, 92917, Bassa_Vah]
- 69, // Range #334: [92928, 92997, Pahawh_Hmong]
- 39, // Range #335: [93008, 93047, Pahawh_Hmong]
- 18, // Range #336: [93053, 93071, Pahawh_Hmong]
- 90, // Range #337: [93760, 93850, Medefaidrin]
- 135, // Range #338: [93952, 94087, Miao]
- 16, // Range #339: [94095, 94111, Miao]
- 0, // Range #340: [94176, 94176, Tangut]
- 0, // Range #341: [94177, 94177, Nushu]
- 0, // Range #342: [94180, 94180, Khitan_Small_Script]
- 1, // Range #343: [94192, 94193, Han]
- 6135, // Range #344: [94208, 100343, Tangut]
- 767, // Range #345: [100352, 101119, Tangut]
- 469, // Range #346: [101120, 101589, Khitan_Small_Script]
- 8, // Range #347: [101632, 101640, Tangut]
- 0, // Range #348: [110592, 110592, Katakana]
- 285, // Range #349: [110593, 110878, Hiragana]
- 2, // Range #350: [110928, 110930, Hiragana]
- 3, // Range #351: [110948, 110951, Katakana]
- 395, // Range #352: [110960, 111355, Nushu]
- 106, // Range #353: [113664, 113770, Duployan]
- 24, // Range #354: [113776, 113800, Duployan]
- 15, // Range #355: [113808, 113823, Duployan]
- 69, // Range #356: [119296, 119365, Greek]
- 651, // Range #357: [120832, 121483, SignWriting]
- 20, // Range #358: [121499, 121519, SignWriting]
- 42, // Range #359: [122880, 122922, Glagolitic]
- 79, // Range #360: [123136, 123215, Nyiakeng_Puachue_Hmong]
- 57, // Range #361: [123584, 123641, Wancho]
- 0, // Range #362: [123647, 123647, Wancho]
- 214, // Range #363: [124928, 125142, Mende_Kikakui]
- 95, // Range #364: [125184, 125279, Adlam]
- 59, // Range #365: [126464, 126523, Arabic]
- 89, // Range #366: [126530, 126619, Arabic]
- 26, // Range #367: [126625, 126651, Arabic]
- 1, // Range #368: [126704, 126705, Arabic]
- 0, // Range #369: [127488, 127488, Hiragana]
- 42717, // Range #370: [131072, 173789, Han]
- 4148, // Range #371: [173824, 177972, Han]
- 5985, // Range #372: [177984, 183969, Han]
- 7472, // Range #373: [183984, 191456, Han]
- 541, // Range #374: [194560, 195101, Han]
- 4938, // Range #375: [196608, 201546, Han]
+ 9, // Range #174: [42960, 42969, Latin]
+ 13, // Range #175: [42994, 43007, Latin]
+ 44, // Range #176: [43008, 43052, Syloti_Nagri]
+ 55, // Range #177: [43072, 43127, Phags_Pa]
+ 69, // Range #178: [43136, 43205, Saurashtra]
+ 11, // Range #179: [43214, 43225, Saurashtra]
+ 31, // Range #180: [43232, 43263, Devanagari]
+ 45, // Range #181: [43264, 43309, Kayah_Li]
+ 0, // Range #182: [43311, 43311, Kayah_Li]
+ 35, // Range #183: [43312, 43347, Rejang]
+ 0, // Range #184: [43359, 43359, Rejang]
+ 28, // Range #185: [43360, 43388, Hangul]
+ 77, // Range #186: [43392, 43469, Javanese]
+ 15, // Range #187: [43472, 43487, Javanese]
+ 30, // Range #188: [43488, 43518, Myanmar]
+ 54, // Range #189: [43520, 43574, Cham]
+ 31, // Range #190: [43584, 43615, Cham]
+ 31, // Range #191: [43616, 43647, Myanmar]
+ 66, // Range #192: [43648, 43714, Tai_Viet]
+ 4, // Range #193: [43739, 43743, Tai_Viet]
+ 22, // Range #194: [43744, 43766, Meetei_Mayek]
+ 21, // Range #195: [43777, 43798, Ethiopic]
+ 14, // Range #196: [43808, 43822, Ethiopic]
+ 42, // Range #197: [43824, 43866, Latin]
+ 8, // Range #198: [43868, 43876, Latin]
+ 0, // Range #199: [43877, 43877, Greek]
+ 3, // Range #200: [43878, 43881, Latin]
+ 79, // Range #201: [43888, 43967, Cherokee]
+ 57, // Range #202: [43968, 44025, Meetei_Mayek]
+ 11171, // Range #203: [44032, 55203, Hangul]
+ 75, // Range #204: [55216, 55291, Hangul]
+ 473, // Range #205: [63744, 64217, Han]
+ 6, // Range #206: [64256, 64262, Latin]
+ 4, // Range #207: [64275, 64279, Armenian]
+ 50, // Range #208: [64285, 64335, Hebrew]
+ 114, // Range #209: [64336, 64450, Arabic]
+ 362, // Range #210: [64467, 64829, Arabic]
+ 135, // Range #211: [64832, 64967, Arabic]
+ 0, // Range #212: [64975, 64975, Arabic]
+ 15, // Range #213: [65008, 65023, Arabic]
+ 1, // Range #214: [65070, 65071, Cyrillic]
+ 140, // Range #215: [65136, 65276, Arabic]
+ 25, // Range #216: [65313, 65338, Latin]
+ 25, // Range #217: [65345, 65370, Latin]
+ 9, // Range #218: [65382, 65391, Katakana]
+ 44, // Range #219: [65393, 65437, Katakana]
+ 60, // Range #220: [65440, 65500, Hangul]
+ 93, // Range #221: [65536, 65629, Linear_B]
+ 122, // Range #222: [65664, 65786, Linear_B]
+ 78, // Range #223: [65856, 65934, Greek]
+ 0, // Range #224: [65952, 65952, Greek]
+ 28, // Range #225: [66176, 66204, Lycian]
+ 48, // Range #226: [66208, 66256, Carian]
+ 35, // Range #227: [66304, 66339, Old_Italic]
+ 2, // Range #228: [66349, 66351, Old_Italic]
+ 26, // Range #229: [66352, 66378, Gothic]
+ 42, // Range #230: [66384, 66426, Old_Permic]
+ 31, // Range #231: [66432, 66463, Ugaritic]
+ 53, // Range #232: [66464, 66517, Old_Persian]
+ 79, // Range #233: [66560, 66639, Deseret]
+ 47, // Range #234: [66640, 66687, Shavian]
+ 41, // Range #235: [66688, 66729, Osmanya]
+ 75, // Range #236: [66736, 66811, Osage]
+ 39, // Range #237: [66816, 66855, Elbasan]
+ 51, // Range #238: [66864, 66915, Caucasian_Albanian]
+ 0, // Range #239: [66927, 66927, Caucasian_Albanian]
+ 76, // Range #240: [66928, 67004, Vithkuqi]
+ 310, // Range #241: [67072, 67382, Linear_A]
+ 21, // Range #242: [67392, 67413, Linear_A]
+ 7, // Range #243: [67424, 67431, Linear_A]
+ 58, // Range #244: [67456, 67514, Latin]
+ 63, // Range #245: [67584, 67647, Cypriot]
+ 31, // Range #246: [67648, 67679, Imperial_Aramaic]
+ 31, // Range #247: [67680, 67711, Palmyrene]
+ 30, // Range #248: [67712, 67742, Nabataean]
+ 8, // Range #249: [67751, 67759, Nabataean]
+ 21, // Range #250: [67808, 67829, Hatran]
+ 4, // Range #251: [67835, 67839, Hatran]
+ 31, // Range #252: [67840, 67871, Phoenician]
+ 25, // Range #253: [67872, 67897, Lydian]
+ 0, // Range #254: [67903, 67903, Lydian]
+ 31, // Range #255: [67968, 67999, Meroitic_Hieroglyphs]
+ 95, // Range #256: [68000, 68095, Meroitic_Cursive]
+ 6, // Range #257: [68096, 68102, Kharoshthi]
+ 60, // Range #258: [68108, 68168, Kharoshthi]
+ 8, // Range #259: [68176, 68184, Kharoshthi]
+ 31, // Range #260: [68192, 68223, Old_South_Arabian]
+ 31, // Range #261: [68224, 68255, Old_North_Arabian]
+ 54, // Range #262: [68288, 68342, Manichaean]
+ 63, // Range #263: [68352, 68415, Avestan]
+ 31, // Range #264: [68416, 68447, Inscriptional_Parthian]
+ 18, // Range #265: [68448, 68466, Inscriptional_Pahlavi]
+ 7, // Range #266: [68472, 68479, Inscriptional_Pahlavi]
+ 17, // Range #267: [68480, 68497, Psalter_Pahlavi]
+ 3, // Range #268: [68505, 68508, Psalter_Pahlavi]
+ 6, // Range #269: [68521, 68527, Psalter_Pahlavi]
+ 72, // Range #270: [68608, 68680, Old_Turkic]
+ 50, // Range #271: [68736, 68786, Old_Hungarian]
+ 50, // Range #272: [68800, 68850, Old_Hungarian]
+ 5, // Range #273: [68858, 68863, Old_Hungarian]
+ 39, // Range #274: [68864, 68903, Hanifi_Rohingya]
+ 9, // Range #275: [68912, 68921, Hanifi_Rohingya]
+ 30, // Range #276: [69216, 69246, Arabic]
+ 49, // Range #277: [69248, 69297, Yezidi]
+ 39, // Range #278: [69376, 69415, Old_Sogdian]
+ 41, // Range #279: [69424, 69465, Sogdian]
+ 25, // Range #280: [69488, 69513, Old_Uyghur]
+ 27, // Range #281: [69552, 69579, Chorasmian]
+ 22, // Range #282: [69600, 69622, Elymaic]
+ 117, // Range #283: [69632, 69749, Brahmi]
+ 0, // Range #284: [69759, 69759, Brahmi]
+ 66, // Range #285: [69760, 69826, Kaithi]
+ 0, // Range #286: [69837, 69837, Kaithi]
+ 24, // Range #287: [69840, 69864, Sora_Sompeng]
+ 9, // Range #288: [69872, 69881, Sora_Sompeng]
+ 71, // Range #289: [69888, 69959, Chakma]
+ 38, // Range #290: [69968, 70006, Mahajani]
+ 95, // Range #291: [70016, 70111, Sharada]
+ 19, // Range #292: [70113, 70132, Sinhala]
+ 62, // Range #293: [70144, 70206, Khojki]
+ 41, // Range #294: [70272, 70313, Multani]
+ 58, // Range #295: [70320, 70378, Khudawadi]
+ 9, // Range #296: [70384, 70393, Khudawadi]
+ 57, // Range #297: [70400, 70457, Grantha]
+ 20, // Range #298: [70460, 70480, Grantha]
+ 0, // Range #299: [70487, 70487, Grantha]
+ 23, // Range #300: [70493, 70516, Grantha]
+ 97, // Range #301: [70656, 70753, Newa]
+ 71, // Range #302: [70784, 70855, Tirhuta]
+ 9, // Range #303: [70864, 70873, Tirhuta]
+ 93, // Range #304: [71040, 71133, Siddham]
+ 68, // Range #305: [71168, 71236, Modi]
+ 9, // Range #306: [71248, 71257, Modi]
+ 12, // Range #307: [71264, 71276, Mongolian]
+ 57, // Range #308: [71296, 71353, Takri]
+ 9, // Range #309: [71360, 71369, Takri]
+ 70, // Range #310: [71424, 71494, Ahom]
+ 59, // Range #311: [71680, 71739, Dogra]
+ 82, // Range #312: [71840, 71922, Warang_Citi]
+ 0, // Range #313: [71935, 71935, Warang_Citi]
+ 70, // Range #314: [71936, 72006, Dives_Akuru]
+ 9, // Range #315: [72016, 72025, Dives_Akuru]
+ 68, // Range #316: [72096, 72164, Nandinagari]
+ 71, // Range #317: [72192, 72263, Zanabazar_Square]
+ 82, // Range #318: [72272, 72354, Soyombo]
+ 15, // Range #319: [72368, 72383, Canadian_Aboriginal]
+ 56, // Range #320: [72384, 72440, Pau_Cin_Hau]
+ 69, // Range #321: [72704, 72773, Bhaiksuki]
+ 28, // Range #322: [72784, 72812, Bhaiksuki]
+ 70, // Range #323: [72816, 72886, Marchen]
+ 71, // Range #324: [72960, 73031, Masaram_Gondi]
+ 9, // Range #325: [73040, 73049, Masaram_Gondi]
+ 56, // Range #326: [73056, 73112, Gunjala_Gondi]
+ 9, // Range #327: [73120, 73129, Gunjala_Gondi]
+ 24, // Range #328: [73440, 73464, Makasar]
+ 0, // Range #329: [73648, 73648, Lisu]
+ 49, // Range #330: [73664, 73713, Tamil]
+ 0, // Range #331: [73727, 73727, Tamil]
+ 921, // Range #332: [73728, 74649, Cuneiform]
+ 116, // Range #333: [74752, 74868, Cuneiform]
+ 195, // Range #334: [74880, 75075, Cuneiform]
+ 98, // Range #335: [77712, 77810, Cypro_Minoan]
+ 1080, // Range #336: [77824, 78904, Egyptian_Hieroglyphs]
+ 582, // Range #337: [82944, 83526, Anatolian_Hieroglyphs]
+ 568, // Range #338: [92160, 92728, Bamum]
+ 47, // Range #339: [92736, 92783, Mro]
+ 89, // Range #340: [92784, 92873, Tangsa]
+ 37, // Range #341: [92880, 92917, Bassa_Vah]
+ 69, // Range #342: [92928, 92997, Pahawh_Hmong]
+ 39, // Range #343: [93008, 93047, Pahawh_Hmong]
+ 18, // Range #344: [93053, 93071, Pahawh_Hmong]
+ 90, // Range #345: [93760, 93850, Medefaidrin]
+ 135, // Range #346: [93952, 94087, Miao]
+ 16, // Range #347: [94095, 94111, Miao]
+ 0, // Range #348: [94176, 94176, Tangut]
+ 0, // Range #349: [94177, 94177, Nushu]
+ 1, // Range #350: [94178, 94179, Han]
+ 0, // Range #351: [94180, 94180, Khitan_Small_Script]
+ 1, // Range #352: [94192, 94193, Han]
+ 6135, // Range #353: [94208, 100343, Tangut]
+ 767, // Range #354: [100352, 101119, Tangut]
+ 469, // Range #355: [101120, 101589, Khitan_Small_Script]
+ 8, // Range #356: [101632, 101640, Tangut]
+ 16, // Range #357: [110576, 110592, Katakana]
+ 286, // Range #358: [110593, 110879, Hiragana]
+ 2, // Range #359: [110880, 110882, Katakana]
+ 2, // Range #360: [110928, 110930, Hiragana]
+ 3, // Range #361: [110948, 110951, Katakana]
+ 395, // Range #362: [110960, 111355, Nushu]
+ 106, // Range #363: [113664, 113770, Duployan]
+ 24, // Range #364: [113776, 113800, Duployan]
+ 15, // Range #365: [113808, 113823, Duployan]
+ 69, // Range #366: [119296, 119365, Greek]
+ 651, // Range #367: [120832, 121483, SignWriting]
+ 20, // Range #368: [121499, 121519, SignWriting]
+ 30, // Range #369: [122624, 122654, Latin]
+ 42, // Range #370: [122880, 122922, Glagolitic]
+ 79, // Range #371: [123136, 123215, Nyiakeng_Puachue_Hmong]
+ 30, // Range #372: [123536, 123566, Toto]
+ 57, // Range #373: [123584, 123641, Wancho]
+ 0, // Range #374: [123647, 123647, Wancho]
+ 30, // Range #375: [124896, 124926, Ethiopic]
+ 214, // Range #376: [124928, 125142, Mende_Kikakui]
+ 95, // Range #377: [125184, 125279, Adlam]
+ 59, // Range #378: [126464, 126523, Arabic]
+ 89, // Range #379: [126530, 126619, Arabic]
+ 26, // Range #380: [126625, 126651, Arabic]
+ 1, // Range #381: [126704, 126705, Arabic]
+ 0, // Range #382: [127488, 127488, Hiragana]
+ 42719, // Range #383: [131072, 173791, Han]
+ 4152, // Range #384: [173824, 177976, Han]
+ 5985, // Range #385: [177984, 183969, Han]
+ 7472, // Range #386: [183984, 191456, Han]
+ 541, // Range #387: [194560, 195101, Han]
+ 4938, // Range #388: [196608, 201546, Han]
};
const uint8 kRangeScript[] = {
@@ -825,8 +851,8 @@
126, // Range #34: [2048, 2110, Samaritan]
84, // Range #35: [2112, 2142, Mandaic]
34, // Range #36: [2144, 2154, Syriac]
- 2, // Range #37: [2208, 2247, Arabic]
- 2, // Range #38: [2259, 2273, Arabic]
+ 2, // Range #37: [2160, 2193, Arabic]
+ 2, // Range #38: [2200, 2273, Arabic]
2, // Range #39: [2275, 2303, Arabic]
10, // Range #40: [2304, 2384, Devanagari]
10, // Range #41: [2389, 2403, Devanagari]
@@ -845,32 +871,32 @@
35, // Range #54: [3031, 3031, Tamil]
35, // Range #55: [3046, 3066, Tamil]
36, // Range #56: [3072, 3149, Telugu]
- 36, // Range #57: [3157, 3162, Telugu]
- 36, // Range #58: [3168, 3183, Telugu]