| /* |
| * Copyright (C) 2018 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package com.android.textclassifier; |
| |
| import android.content.Context; |
| import android.os.LocaleList; |
| import android.os.ParcelFileDescriptor; |
| import android.text.TextUtils; |
| import androidx.annotation.GuardedBy; |
| import androidx.annotation.StringDef; |
| import com.android.textclassifier.ModelFileManager.ModelFile; |
| import com.android.textclassifier.ModelFileManager.ModelFile.ModelType; |
| import com.android.textclassifier.common.base.TcLog; |
| import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo; |
| import com.android.textclassifier.utils.IndentingPrintWriter; |
| import com.google.android.textclassifier.ActionsSuggestionsModel; |
| import com.google.android.textclassifier.AnnotatorModel; |
| import com.google.android.textclassifier.LangIdModel; |
| import com.google.common.annotations.VisibleForTesting; |
| import com.google.common.base.Optional; |
| import com.google.common.base.Preconditions; |
| import com.google.common.base.Splitter; |
| import com.google.common.collect.ImmutableList; |
| import com.google.common.collect.ImmutableMap; |
| import java.io.File; |
| import java.io.FileNotFoundException; |
| import java.io.IOException; |
| import java.lang.annotation.Retention; |
| import java.lang.annotation.RetentionPolicy; |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.List; |
| import java.util.Locale; |
| import java.util.Objects; |
| import java.util.function.Function; |
| import java.util.function.Supplier; |
| import java.util.regex.Matcher; |
| import java.util.regex.Pattern; |
| import java.util.stream.Collectors; |
| import javax.annotation.Nullable; |
| |
| /** |
| * Manages all model files in storage. {@link TextClassifierImpl} depends on this class to get the |
| * model files to load. |
| */ |
| final class ModelFileManager { |
| private static final String TAG = "ModelFileManager"; |
| private static final String DOWNLOAD_SUB_DIR_NAME = "textclassifier/downloads/models/"; |
| |
| private final File downloadModelDir; |
| private final ImmutableMap<String, Supplier<ImmutableList<ModelFile>>> modelFileSuppliers; |
| |
| /** Create a ModelFileManager based on hardcoded model file locations. */ |
| public ModelFileManager(Context context, TextClassifierSettings settings) { |
| Preconditions.checkNotNull(context); |
| Preconditions.checkNotNull(settings); |
| this.downloadModelDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME); |
| if (!downloadModelDir.exists()) { |
| downloadModelDir.mkdirs(); |
| } |
| |
| ImmutableMap.Builder<String, Supplier<ImmutableList<ModelFile>>> suppliersBuilder = |
| ImmutableMap.builder(); |
| for (String modelType : ModelType.values()) { |
| suppliersBuilder.put( |
| modelType, new ModelFileSupplierImpl(settings, modelType, downloadModelDir)); |
| } |
| this.modelFileSuppliers = suppliersBuilder.build(); |
| } |
| |
| @VisibleForTesting |
| ModelFileManager( |
| File downloadModelDir, |
| ImmutableMap<String, Supplier<ImmutableList<ModelFile>>> modelFileSuppliers) { |
| this.downloadModelDir = Preconditions.checkNotNull(downloadModelDir); |
| this.modelFileSuppliers = Preconditions.checkNotNull(modelFileSuppliers); |
| } |
| |
| /** |
| * Returns an immutable list of model files listed by the given model files supplier. |
| * |
| * @param modelType which type of model files to look for |
| */ |
| public ImmutableList<ModelFile> listModelFiles(@ModelType.ModelTypeDef String modelType) { |
| if (modelFileSuppliers.containsKey(modelType)) { |
| return modelFileSuppliers.get(modelType).get(); |
| } |
| return ImmutableList.of(); |
| } |
| |
| /** |
| * Returns the best model file for the given localelist, {@code null} if nothing is found. |
| * |
| * @param modelType the type of model to look up (e.g. annotator, lang_id, etc.) |
| * @param localeList an ordered list of user preferences for locales, use {@code null} if there is |
| * no preference. |
| */ |
| @Nullable |
| public ModelFile findBestModelFile( |
| @ModelType.ModelTypeDef String modelType, @Nullable LocaleList localeList) { |
| final String languages = |
| localeList == null || localeList.isEmpty() |
| ? LocaleList.getDefault().toLanguageTags() |
| : localeList.toLanguageTags(); |
| final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages); |
| |
| ModelFile bestModel = null; |
| for (ModelFile model : listModelFiles(modelType)) { |
| if (model.isAnyLanguageSupported(languageRangeList)) { |
| if (model.isPreferredTo(bestModel)) { |
| bestModel = model; |
| } |
| } |
| } |
| return bestModel; |
| } |
| |
| /** |
| * Returns a {@link File} that represents the destination to download a model. |
| * |
| * <p>Each model file's name is uniquely formatted based on its unique remote URL address. |
| * |
| * <p>{@link ModelDownloadManager} needs to call this to get the right location and file name. |
| * |
| * @param modelType the type of the model image to download |
| * @param url the unique remote url of the model image |
| */ |
| public File getDownloadTargetFile(@ModelType.ModelTypeDef String modelType, String url) { |
| String fileName = String.format("%s.%d.model", modelType, url.hashCode()); |
| return new File(downloadModelDir, fileName); |
| } |
| |
| /** |
| * Dumps the internal state for debugging. |
| * |
| * @param printWriter writer to write dumped states |
| */ |
| public void dump(IndentingPrintWriter printWriter) { |
| printWriter.println("ModelFileManager:"); |
| printWriter.increaseIndent(); |
| for (@ModelType.ModelTypeDef String modelType : ModelType.values()) { |
| printWriter.println(modelType + " model file(s):"); |
| printWriter.increaseIndent(); |
| for (ModelFile modelFile : listModelFiles(modelType)) { |
| printWriter.println(modelFile.toString()); |
| } |
| printWriter.decreaseIndent(); |
| } |
| printWriter.decreaseIndent(); |
| } |
| |
| /** Default implementation of the model file supplier. */ |
| @VisibleForTesting |
| static final class ModelFileSupplierImpl implements Supplier<ImmutableList<ModelFile>> { |
| private static final String FACTORY_MODEL_DIR = "/etc/textclassifier/"; |
| |
| private static final class ModelFileInfo { |
| private final String modelNameRegex; |
| private final String configUpdaterModelPath; |
| private final Function<Integer, Integer> versionSupplier; |
| private final Function<Integer, String> supportedLocalesSupplier; |
| |
| public ModelFileInfo( |
| String modelNameRegex, |
| String configUpdaterModelPath, |
| Function<Integer, Integer> versionSupplier, |
| Function<Integer, String> supportedLocalesSupplier) { |
| this.modelNameRegex = Preconditions.checkNotNull(modelNameRegex); |
| this.configUpdaterModelPath = Preconditions.checkNotNull(configUpdaterModelPath); |
| this.versionSupplier = Preconditions.checkNotNull(versionSupplier); |
| this.supportedLocalesSupplier = Preconditions.checkNotNull(supportedLocalesSupplier); |
| } |
| |
| public String getModelNameRegex() { |
| return modelNameRegex; |
| } |
| |
| public String getConfigUpdaterModelPath() { |
| return configUpdaterModelPath; |
| } |
| |
| public Function<Integer, Integer> getVersionSupplier() { |
| return versionSupplier; |
| } |
| |
| public Function<Integer, String> getSupportedLocalesSupplier() { |
| return supportedLocalesSupplier; |
| } |
| } |
| |
| private static final ImmutableMap<String, ModelFileInfo> MODEL_FILE_INFO_MAP = |
| ImmutableMap.<String, ModelFileInfo>builder() |
| .put( |
| ModelType.ANNOTATOR, |
| new ModelFileInfo( |
| "(annotator|textclassifier)\\.(.*)\\.model", |
| "/data/misc/textclassifier/textclassifier.model", |
| AnnotatorModel::getVersion, |
| AnnotatorModel::getLocales)) |
| .put( |
| ModelType.LANG_ID, |
| new ModelFileInfo( |
| "lang_id.model", |
| "/data/misc/textclassifier/lang_id.model", |
| LangIdModel::getVersion, |
| fd -> ModelFile.LANGUAGE_INDEPENDENT)) |
| .put( |
| ModelType.ACTIONS_SUGGESTIONS, |
| new ModelFileInfo( |
| "actions_suggestions\\.(.*)\\.model", |
| "/data/misc/textclassifier/actions_suggestions.model", |
| ActionsSuggestionsModel::getVersion, |
| ActionsSuggestionsModel::getLocales)) |
| .build(); |
| |
| private final TextClassifierSettings settings; |
| @ModelType.ModelTypeDef private final String modelType; |
| private final File configUpdaterModelFile; |
| private final File downloaderModelDir; |
| private final File factoryModelDir; |
| private final Pattern modelFilenamePattern; |
| private final Function<Integer, Integer> versionSupplier; |
| private final Function<Integer, String> supportedLocalesSupplier; |
| private final Object lock = new Object(); |
| |
| @GuardedBy("lock") |
| private ImmutableList<ModelFile> factoryModels; |
| |
| public ModelFileSupplierImpl( |
| TextClassifierSettings settings, |
| @ModelType.ModelTypeDef String modelType, |
| File downloaderModelDir) { |
| this( |
| settings, |
| modelType, |
| new File(FACTORY_MODEL_DIR), |
| MODEL_FILE_INFO_MAP.get(modelType).getModelNameRegex(), |
| new File(MODEL_FILE_INFO_MAP.get(modelType).getConfigUpdaterModelPath()), |
| downloaderModelDir, |
| MODEL_FILE_INFO_MAP.get(modelType).getVersionSupplier(), |
| MODEL_FILE_INFO_MAP.get(modelType).getSupportedLocalesSupplier()); |
| } |
| |
| @VisibleForTesting |
| ModelFileSupplierImpl( |
| TextClassifierSettings settings, |
| @ModelType.ModelTypeDef String modelType, |
| File factoryModelDir, |
| String modelFileNameRegex, |
| File configUpdaterModelFile, |
| File downloaderModelDir, |
| Function<Integer, Integer> versionSupplier, |
| Function<Integer, String> supportedLocalesSupplier) { |
| this.settings = Preconditions.checkNotNull(settings); |
| this.modelType = Preconditions.checkNotNull(modelType); |
| this.factoryModelDir = Preconditions.checkNotNull(factoryModelDir); |
| this.modelFilenamePattern = Pattern.compile(Preconditions.checkNotNull(modelFileNameRegex)); |
| this.configUpdaterModelFile = Preconditions.checkNotNull(configUpdaterModelFile); |
| this.downloaderModelDir = Preconditions.checkNotNull(downloaderModelDir); |
| this.versionSupplier = Preconditions.checkNotNull(versionSupplier); |
| this.supportedLocalesSupplier = Preconditions.checkNotNull(supportedLocalesSupplier); |
| } |
| |
| @Override |
| public ImmutableList<ModelFile> get() { |
| final List<ModelFile> modelFiles = new ArrayList<>(); |
| // The dwonloader and config updater model have higher precedences. |
| if (downloaderModelDir.exists() && settings.isModelDownloadManagerEnabled()) { |
| modelFiles.addAll(getMatchedModelFiles(downloaderModelDir)); |
| } |
| if (configUpdaterModelFile.exists()) { |
| final ModelFile updatedModel = createModelFile(configUpdaterModelFile); |
| if (updatedModel != null) { |
| modelFiles.add(updatedModel); |
| } |
| } |
| // Factory models should never have overlapping locales, so the order doesn't matter. |
| synchronized (lock) { |
| if (factoryModels == null) { |
| factoryModels = getMatchedModelFiles(factoryModelDir); |
| } |
| modelFiles.addAll(factoryModels); |
| } |
| return ImmutableList.copyOf(modelFiles); |
| } |
| |
| private ImmutableList<ModelFile> getMatchedModelFiles(File parentDir) { |
| ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder(); |
| if (parentDir.exists() && parentDir.isDirectory()) { |
| final File[] files = parentDir.listFiles(); |
| for (File file : files) { |
| final Matcher matcher = modelFilenamePattern.matcher(file.getName()); |
| if (matcher.matches() && file.isFile()) { |
| final ModelFile model = createModelFile(file); |
| if (model != null) { |
| modelFilesBuilder.add(model); |
| } |
| } |
| } |
| } |
| return modelFilesBuilder.build(); |
| } |
| |
| /** Returns null if the path did not point to a compatible model. */ |
| @Nullable |
| private ModelFile createModelFile(File file) { |
| if (!file.exists()) { |
| return null; |
| } |
| ParcelFileDescriptor modelFd = null; |
| try { |
| modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY); |
| if (modelFd == null) { |
| return null; |
| } |
| final int modelFdInt = modelFd.getFd(); |
| final int version = versionSupplier.apply(modelFdInt); |
| final String supportedLocalesStr = supportedLocalesSupplier.apply(modelFdInt); |
| if (supportedLocalesStr.isEmpty()) { |
| TcLog.d(TAG, "Ignoring " + file.getAbsolutePath()); |
| return null; |
| } |
| final List<Locale> supportedLocales = new ArrayList<>(); |
| for (String langTag : Splitter.on(',').split(supportedLocalesStr)) { |
| supportedLocales.add(Locale.forLanguageTag(langTag)); |
| } |
| return new ModelFile( |
| modelType, |
| file, |
| version, |
| supportedLocales, |
| supportedLocalesStr, |
| ModelFile.LANGUAGE_INDEPENDENT.equals(supportedLocalesStr)); |
| } catch (FileNotFoundException e) { |
| TcLog.e(TAG, "Failed to find " + file.getAbsolutePath(), e); |
| return null; |
| } finally { |
| maybeCloseAndLogError(modelFd); |
| } |
| } |
| |
| /** Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur. */ |
| private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) { |
| if (fd == null) { |
| return; |
| } |
| try { |
| fd.close(); |
| } catch (IOException e) { |
| TcLog.e(TAG, "Error closing file.", e); |
| } |
| } |
| } |
| |
| /** Describes TextClassifier model files on disk. */ |
| public static final class ModelFile { |
| public static final String LANGUAGE_INDEPENDENT = "*"; |
| |
| @ModelType.ModelTypeDef private final String modelType; |
| private final File file; |
| private final int version; |
| private final List<Locale> supportedLocales; |
| private final String supportedLocalesStr; |
| private final boolean languageIndependent; |
| |
| public ModelFile( |
| @ModelType.ModelTypeDef String modelType, |
| File file, |
| int version, |
| List<Locale> supportedLocales, |
| String supportedLocalesStr, |
| boolean languageIndependent) { |
| this.modelType = Preconditions.checkNotNull(modelType); |
| this.file = Preconditions.checkNotNull(file); |
| this.version = version; |
| this.supportedLocales = Preconditions.checkNotNull(supportedLocales); |
| this.supportedLocalesStr = Preconditions.checkNotNull(supportedLocalesStr); |
| this.languageIndependent = languageIndependent; |
| } |
| |
| /** Returns the type of this model, defined in {@link ModelType}. */ |
| @ModelType.ModelTypeDef |
| public String getModelType() { |
| return modelType; |
| } |
| |
| /** Returns the absolute path to the model file. */ |
| public String getPath() { |
| return file.getAbsolutePath(); |
| } |
| |
| /** Returns a name to use for id generation, effectively the name of the model file. */ |
| public String getName() { |
| return file.getName(); |
| } |
| |
| /** Returns the version tag in the model's metadata. */ |
| public int getVersion() { |
| return version; |
| } |
| |
| /** Returns whether the language supports any language in the given ranges. */ |
| public boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) { |
| Preconditions.checkNotNull(languageRanges); |
| return languageIndependent || Locale.lookup(languageRanges, supportedLocales) != null; |
| } |
| |
| /** Returns if this model file is preferred to the given one. */ |
| public boolean isPreferredTo(@Nullable ModelFile model) { |
| // A model is preferred to no model. |
| if (model == null) { |
| return true; |
| } |
| |
| // A language-specific model is preferred to a language independent |
| // model. |
| if (!languageIndependent && model.languageIndependent) { |
| return true; |
| } |
| if (languageIndependent && !model.languageIndependent) { |
| return false; |
| } |
| |
| // A higher-version model is preferred. |
| if (version > model.getVersion()) { |
| return true; |
| } |
| return false; |
| } |
| |
| @Override |
| public int hashCode() { |
| return Objects.hash(getPath()); |
| } |
| |
| @Override |
| public boolean equals(Object other) { |
| if (this == other) { |
| return true; |
| } |
| if (other instanceof ModelFile) { |
| final ModelFile otherModel = (ModelFile) other; |
| return TextUtils.equals(getPath(), otherModel.getPath()); |
| } |
| return false; |
| } |
| |
| public ModelInfo toModelInfo() { |
| return new ModelInfo(getVersion(), supportedLocalesStr); |
| } |
| |
| @Override |
| public String toString() { |
| return String.format( |
| Locale.US, |
| "ModelFile { type=%s path=%s name=%s version=%d locales=%s }", |
| modelType, |
| getPath(), |
| getName(), |
| version, |
| supportedLocalesStr); |
| } |
| |
| public static ImmutableList<Optional<ModelInfo>> toModelInfos( |
| Optional<ModelFile>... modelFiles) { |
| return Arrays.stream(modelFiles) |
| .map(modelFile -> modelFile.transform(ModelFile::toModelInfo)) |
| .collect(Collectors.collectingAndThen(Collectors.toList(), ImmutableList::copyOf)); |
| } |
| |
| /** Effectively an enum class to represent types of models. */ |
| public static final class ModelType { |
| @Retention(RetentionPolicy.SOURCE) |
| @StringDef({ANNOTATOR, LANG_ID, ACTIONS_SUGGESTIONS}) |
| public @interface ModelTypeDef {} |
| |
| public static final String ANNOTATOR = "annotator"; |
| public static final String LANG_ID = "lang_id"; |
| public static final String ACTIONS_SUGGESTIONS = "actions_suggestions"; |
| |
| public static final ImmutableList<String> VALUES = |
| ImmutableList.of(ANNOTATOR, LANG_ID, ACTIONS_SUGGESTIONS); |
| |
| public static ImmutableList<String> values() { |
| return VALUES; |
| } |
| |
| private ModelType() {} |
| } |
| } |
| } |