blob: 20ae592099480fb9ead84e1f44b3a05d16d0f21d [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.textclassifier;
import static com.android.textclassifier.common.ModelFile.LANGUAGE_INDEPENDENT;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.Mockito.when;
import android.content.Context;
import android.os.LocaleList;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
import com.android.textclassifier.ModelFileManagerImpl.DownloaderModelsLister;
import com.android.textclassifier.ModelFileManagerImpl.RegularFileFullMatchLister;
import com.android.textclassifier.ModelFileManagerImpl.RegularFilePatternMatchLister;
import com.android.textclassifier.common.ModelFile;
import com.android.textclassifier.common.ModelType;
import com.android.textclassifier.common.ModelType.ModelTypeDef;
import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.downloader.DownloadedModelManager;
import com.android.textclassifier.downloader.ModelDownloadManager;
import com.android.textclassifier.downloader.ModelDownloadWorker;
import com.android.textclassifier.testing.SetDefaultLocalesRule;
import com.android.textclassifier.testing.TestingDeviceConfig;
import com.google.common.collect.ImmutableList;
import com.google.common.io.Files;
import com.google.common.util.concurrent.MoreExecutors;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;
/**
 * Tests for {@link ModelFileManagerImpl}.
 *
 * <p>Covers:
 *
 * <ul>
 *   <li>discovery of the preloaded asset models;
 *   <li>best-model selection in {@code findBestModelFile} (version, locale and multi-language
 *       rules);
 *   <li>the individual model listers ({@link DownloaderModelsLister}, {@link
 *       RegularFileFullMatchLister}, {@link RegularFilePatternMatchLister}).
 * </ul>
 */
@SmallTest
@RunWith(AndroidJUnit4.class)
public final class ModelFileManagerImplTest {
  private static final Locale DEFAULT_LOCALE = Locale.forLanguageTag("en-US");
  @ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR;

  private TestingDeviceConfig deviceConfig;

  @Mock private DownloadedModelManager downloadedModelManager;

  @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
  @Rule public final MockitoRule mocks = MockitoJUnit.rule();

  /** Scratch directory under the app cache dir; deleted after every test. */
  private File rootTestDir;

  private ModelFileManagerImpl modelFileManager;
  private ModelDownloadManager modelDownloadManager;
  private TextClassifierSettings settings;

  @Before
  public void setup() {
    Context context = ApplicationProvider.getApplicationContext();
    deviceConfig = new TestingDeviceConfig();
    rootTestDir = new File(context.getCacheDir(), "rootTestDir");
    rootTestDir.mkdirs();
    settings = new TextClassifierSettings(deviceConfig);
    modelDownloadManager =
        new ModelDownloadManager(
            context,
            ModelDownloadWorker.class,
            downloadedModelManager,
            settings,
            // Direct executor so that any work submitted by the manager runs synchronously.
            MoreExecutors.newDirectExecutorService());
    modelFileManager = new ModelFileManagerImpl(context, modelDownloadManager, settings);
    setDefaultLocalesRule.set(new LocaleList(DEFAULT_LOCALE));
  }

  @After
  public void removeTestDir() {
    // Guard against setup() having failed before the directory was assigned, so the original
    // failure is not masked by an NPE here.
    if (rootTestDir != null) {
      recursiveDelete(rootTestDir);
    }
  }

  @Test
  public void annotatorModelPreloaded() {
    verifyModelPreloadedAsAsset(ModelType.ANNOTATOR, "textclassifier/annotator.universal.model");
  }

  @Test
  public void actionsModelPreloaded() {
    verifyModelPreloadedAsAsset(
        ModelType.ACTIONS_SUGGESTIONS, "textclassifier/actions_suggestions.universal.model");
  }

  @Test
  public void langIdModelPreloaded() {
    verifyModelPreloadedAsAsset(ModelType.LANG_ID, "textclassifier/lang_id.model");
  }

  /**
   * Asserts that exactly one asset model of {@code modelType} is listed and that it is located at
   * {@code expectedModelPath}.
   */
  private void verifyModelPreloadedAsAsset(
      @ModelTypeDef String modelType, String expectedModelPath) {
    List<ModelFile> modelFiles = modelFileManager.listModelFiles(modelType);
    List<ModelFile> assetFiles =
        modelFiles.stream().filter(modelFile -> modelFile.isAsset).collect(Collectors.toList());

    assertThat(assetFiles).hasSize(1);
    assertThat(assetFiles.get(0).absolutePath).isEqualTo(expectedModelPath);
  }

  @Test
  public void findBestModel_versionCode() {
    ModelFile olderModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 1);
    ModelFile newerModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 2);
    ModelFileManager modelFileManager = createModelFileManager(olderModelFile, newerModelFile);

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, /* localePreferences= */ null, /* detectedLocales= */ null);

    assertThat(bestModelFile).isEqualTo(newerModelFile);
  }

  @Test
  public void findBestModel_languageDependentModelIsPreferred() {
    ModelFile languageIndependentModelFile =
        createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 1);
    ModelFile languageDependentModelFile =
        createModelFile(DEFAULT_LOCALE.toLanguageTag(), /* version= */ 1);
    ModelFileManager modelFileManager =
        createModelFileManager(languageIndependentModelFile, languageDependentModelFile);

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, new LocaleList(DEFAULT_LOCALE), /* detectedLocales= */ null);

    assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
  }

  @Test
  public void findBestModel_noMatchedLanguageModel() {
    ModelFile languageIndependentModelFile =
        createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 1);
    ModelFile languageDependentModelFile = createModelFile("zh-hk", /* version= */ 1);
    ModelFileManager modelFileManager =
        createModelFileManager(languageIndependentModelFile, languageDependentModelFile);

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, new LocaleList(DEFAULT_LOCALE), /* detectedLocales= */ null);

    assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
  }

  @Test
  public void findBestModel_languageIsMoreImportantThanVersion() {
    ModelFile matchButOlderModel =
        createModelFile(DEFAULT_LOCALE.toLanguageTag(), /* version= */ 1);
    ModelFile mismatchButNewerModel = createModelFile("zh-hk", /* version= */ 2);
    ModelFileManager modelFileManager =
        createModelFileManager(matchButOlderModel, mismatchButNewerModel);

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, new LocaleList(DEFAULT_LOCALE), /* detectedLocales= */ null);

    assertThat(bestModelFile).isEqualTo(matchButOlderModel);
  }

  @Test
  public void findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_onlyCheckLanguage() {
    setDefaultLocalesRule.set(LocaleList.forLanguageTags("zh"));
    ModelFile languageIndependentModelFile =
        createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 1);
    ModelFile languageDependentModelFile = createModelFile("zh", /* version= */ 1);
    ModelFileManager modelFileManager =
        createModelFileManager(languageIndependentModelFile, languageDependentModelFile);

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, LocaleList.forLanguageTags("zh-hk"), /* detectedLocales= */ null);

    assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
  }

  @Test
  public void findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_match() {
    setDefaultLocalesRule.set(LocaleList.forLanguageTags("zh-hk"));
    ModelFile languageIndependentModelFile =
        createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 1);
    ModelFile languageDependentModelFile = createModelFile("zh", /* version= */ 1);
    ModelFileManager modelFileManager =
        createModelFileManager(languageIndependentModelFile, languageDependentModelFile);

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, LocaleList.forLanguageTags("zh"), /* detectedLocales= */ null);

    assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
  }

  @Test
  public void findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_doNotMatch() {
    setDefaultLocalesRule.set(LocaleList.forLanguageTags("en"));
    ModelFile languageIndependentModelFile =
        createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 1);
    ModelFile languageDependentModelFile = createModelFile("zh", /* version= */ 1);
    ModelFileManager modelFileManager =
        createModelFileManager(languageIndependentModelFile, languageDependentModelFile);

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, LocaleList.forLanguageTags("zh"), /* detectedLocales= */ null);

    assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
  }

  @Test
  public void findBestModel_onlyPrimaryLocaleConsidered_noLocalePreferencesProvided() {
    setDefaultLocalesRule.set(
        new LocaleList(Locale.forLanguageTag("en"), Locale.forLanguageTag("zh-hk")));
    ModelFile languageIndependentModelFile =
        createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 1);
    ModelFile nonPrimaryLocaleModelFile = createModelFile("zh-hk", /* version= */ 1);
    ModelFileManager modelFileManager =
        createModelFileManager(languageIndependentModelFile, nonPrimaryLocaleModelFile);

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, /* localePreferences= */ null, /* detectedLocales= */ null);

    assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
  }

  @Test
  public void findBestModel_onlyPrimaryLocaleConsidered_localePreferencesProvided() {
    setDefaultLocalesRule.set(
        new LocaleList(Locale.forLanguageTag("en"), Locale.forLanguageTag("zh-hk")));
    ModelFile languageIndependentModelFile =
        createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 1);
    ModelFile nonPrimaryLocalePreferenceModelFile = createModelFile("zh-hk", /* version= */ 1);
    ModelFileManager modelFileManager =
        createModelFileManager(languageIndependentModelFile, nonPrimaryLocalePreferenceModelFile);

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE,
            new LocaleList(Locale.forLanguageTag("en"), Locale.forLanguageTag("zh-hk")),
            /* detectedLocales= */ null);

    assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
  }

  @Test
  public void findBestModel_multiLanguageEnabled_noMatchedModel() {
    setDefaultLocalesRule.set(LocaleList.forLanguageTags("en"));
    deviceConfig.setConfig(TextClassifierSettings.MULTI_LANGUAGE_SUPPORT_ENABLED, true);
    ModelFile primaryLocalePreferenceModelFile = createModelFile("en", /* version= */ 1);
    ModelFile secondaryLocalePreferenceModelFile = createModelFile("zh-hk", /* version= */ 1);
    ModelFileManager modelFileManager =
        createModelFileManager(
            primaryLocalePreferenceModelFile, secondaryLocalePreferenceModelFile);
    // Neither the requested (ja, fy) nor the detected (hr) locales have a model; the model for
    // the default locale ("en") is expected to be chosen.
    final LocaleList requestLocalePreferences =
        new LocaleList(Locale.forLanguageTag("ja"), Locale.forLanguageTag("fy"));
    final LocaleList detectedLocalePreferences = LocaleList.forLanguageTags("hr");

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, requestLocalePreferences, detectedLocalePreferences);

    assertThat(bestModelFile).isEqualTo(primaryLocalePreferenceModelFile);
  }

  @Test
  public void findBestModel_multiLanguageEnabled_matchDetected() {
    setDefaultLocalesRule.set(
        new LocaleList(Locale.forLanguageTag("en-GB"), Locale.forLanguageTag("zh-hk")));
    deviceConfig.setConfig(TextClassifierSettings.MULTI_LANGUAGE_SUPPORT_ENABLED, true);
    ModelFile localePreferenceModelFile = createModelFile("zh", /* version= */ 1);
    ModelFileManager modelFileManager = createModelFileManager(localePreferenceModelFile);
    final LocaleList requestLocalePreferences =
        new LocaleList(Locale.forLanguageTag("ja"), Locale.forLanguageTag("zh"));
    final LocaleList detectedLocalePreferences = LocaleList.forLanguageTags("zh");

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, requestLocalePreferences, detectedLocalePreferences);

    assertThat(bestModelFile).isEqualTo(localePreferenceModelFile);
  }

  @Test
  public void findBestModel_multiLanguageDisabled_matchDetected() {
    setDefaultLocalesRule.set(
        new LocaleList(Locale.forLanguageTag("en-GB"), Locale.forLanguageTag("zh-hk")));
    deviceConfig.setConfig(TextClassifierSettings.MULTI_LANGUAGE_SUPPORT_ENABLED, false);
    ModelFile nonLocalePreferenceModelFile = createModelFile("zh", /* version= */ 1);
    ModelFileManager modelFileManager = createModelFileManager(nonLocalePreferenceModelFile);
    // Only a "zh" model exists while the request asks for "en"; with multi-language support
    // disabled no model is expected to be returned.
    final LocaleList requestLocalePreferences = new LocaleList(Locale.forLanguageTag("en"));
    final LocaleList detectedLocalePreferences = LocaleList.getEmptyLocaleList();

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, requestLocalePreferences, detectedLocalePreferences);

    assertThat(bestModelFile).isNull();
  }

  @Test
  public void downloaderModelsLister() throws IOException {
    File annotatorFile = new File(rootTestDir, "annotator.model");
    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), annotatorFile);
    File langIdFile = new File(rootTestDir, "langId.model");
    Files.copy(TestDataUtils.getLangIdModelFile(), langIdFile);
    deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
    DownloaderModelsLister downloaderModelsLister =
        new DownloaderModelsLister(modelDownloadManager, settings);

    when(downloadedModelManager.listModels(MODEL_TYPE)).thenReturn(Arrays.asList(annotatorFile));
    when(downloadedModelManager.listModels(ModelType.LANG_ID))
        .thenReturn(Arrays.asList(langIdFile));
    when(downloadedModelManager.listModels(ModelType.ACTIONS_SUGGESTIONS))
        .thenReturn(new ArrayList<>());

    assertThat(downloaderModelsLister.list(MODEL_TYPE))
        .containsExactly(ModelFile.createFromRegularFile(annotatorFile, MODEL_TYPE));
    assertThat(downloaderModelsLister.list(ModelType.LANG_ID))
        .containsExactly(ModelFile.createFromRegularFile(langIdFile, ModelType.LANG_ID));
    assertThat(downloaderModelsLister.list(ModelType.ACTIONS_SUGGESTIONS)).isEmpty();
  }

  @Test
  public void downloaderModelsLister_checkModelFileManager() throws IOException {
    File annotatorFile = new File(rootTestDir, "test.model");
    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), annotatorFile);
    deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);

    when(downloadedModelManager.listModels(MODEL_TYPE)).thenReturn(Arrays.asList(annotatorFile));

    assertThat(modelFileManager.listModelFiles(MODEL_TYPE))
        .contains(ModelFile.createFromRegularFile(annotatorFile, MODEL_TYPE));
  }

  @Test
  public void downloaderModelsLister_disabled() throws IOException {
    File annotatorFile = new File(rootTestDir, "test.model");
    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), annotatorFile);
    deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, false);
    DownloaderModelsLister downloaderModelsLister =
        new DownloaderModelsLister(modelDownloadManager, settings);

    when(downloadedModelManager.listModels(MODEL_TYPE)).thenReturn(Arrays.asList(annotatorFile));

    assertThat(downloaderModelsLister.list(MODEL_TYPE)).isEmpty();
  }

  @Test
  public void regularFileFullMatchLister() throws IOException {
    File modelFile = new File(rootTestDir, "test.model");
    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile);
    // A second file in the same directory that must not be picked up by the full-match lister.
    File wrongFile = new File(rootTestDir, "wrong.model");
    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), wrongFile);
    RegularFileFullMatchLister regularFileFullMatchLister =
        new RegularFileFullMatchLister(MODEL_TYPE, modelFile, () -> true);

    ImmutableList<ModelFile> listedModels = regularFileFullMatchLister.list(MODEL_TYPE);

    assertThat(listedModels).hasSize(1);
    assertThat(listedModels.get(0).absolutePath).isEqualTo(modelFile.getAbsolutePath());
    assertThat(listedModels.get(0).isAsset).isFalse();
  }

  @Test
  public void regularFilePatternMatchLister() throws IOException {
    File modelFile1 = new File(rootTestDir, "annotator.en.model");
    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
    File modelFile2 = new File(rootTestDir, "annotator.fr.model");
    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile2);
    // Does not match the "annotator\\.(.*)\\.model" pattern and must be ignored.
    File mismatchedModelFile = new File(rootTestDir, "actions.en.model");
    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), mismatchedModelFile);
    RegularFilePatternMatchLister regularFilePatternMatchLister =
        new RegularFilePatternMatchLister(
            MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> true);

    ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);

    assertThat(listedModels).hasSize(2);
    assertThat(listedModels.get(0).isAsset).isFalse();
    assertThat(listedModels.get(1).isAsset).isFalse();
    assertThat(ImmutableList.of(listedModels.get(0).absolutePath, listedModels.get(1).absolutePath))
        .containsExactly(modelFile1.getAbsolutePath(), modelFile2.getAbsolutePath());
  }

  @Test
  public void regularFilePatternMatchLister_disabled() throws IOException {
    File modelFile1 = new File(rootTestDir, "annotator.en.model");
    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
    RegularFilePatternMatchLister regularFilePatternMatchLister =
        new RegularFilePatternMatchLister(
            MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> false);

    ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);

    assertThat(listedModels).isEmpty();
  }

  /**
   * Creates a {@link ModelFileManager} with a single lister that returns {@code modelFiles} for
   * every requested model type.
   */
  private ModelFileManager createModelFileManager(ModelFile... modelFiles) {
    return new ModelFileManagerImpl(
        ApplicationProvider.getApplicationContext(),
        ImmutableList.of(modelType -> ImmutableList.copyOf(modelFiles)),
        settings);
  }

  /**
   * Creates a non-asset {@link ModelFile} of type {@link #MODEL_TYPE}. Its path under {@link
   * #rootTestDir} is derived from the locale tags and version. Nothing is written to disk.
   */
  private ModelFile createModelFile(String supportedLocaleTags, int version) {
    return new ModelFile(
        MODEL_TYPE,
        new File(rootTestDir, String.format("%s-%d", supportedLocaleTags, version))
            .getAbsolutePath(),
        version,
        supportedLocaleTags,
        /* isAsset= */ false);
  }

  /**
   * Deletes {@code f} and, if it is a directory, everything beneath it. Best-effort: the result
   * of {@link File#delete()} is ignored.
   */
  private static void recursiveDelete(File f) {
    // listFiles() returns null both for non-directories and on I/O errors, so a null check is
    // required even when f is a directory.
    File[] children = f.listFiles();
    if (children != null) {
      for (File child : children) {
        recursiveDelete(child);
      }
    }
    f.delete();
  }
}