blob: a71cdefc4848ec5cecdfd5d226e83c9c41a89aab [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.textclassifier.downloader;
import static java.lang.Math.min;
import android.content.Context;
import android.os.LocaleList;
import android.util.ArrayMap;
import android.util.Pair;
import androidx.work.ListenableWorker;
import androidx.work.WorkerParameters;
import com.android.textclassifier.common.ModelType;
import com.android.textclassifier.common.ModelType.ModelTypeDef;
import com.android.textclassifier.common.TextClassifierServiceExecutors;
import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest;
import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment;
import com.android.textclassifier.downloader.DownloadedModelDatabase.Model;
import com.google.auto.value.AutoValue;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.FluentFuture;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import java.time.Clock;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
/** The WorkManager worker to download models for TextClassifierService. */
public final class ModelDownloadWorker extends ListenableWorker {
/** Log tag passed to every {@code TcLog} call in this worker. */
private static final String TAG = "ModelDownloadWorker";
/** Input-data key from which the public constructor reads the work id. */
public static final String INPUT_DATA_KEY_WORK_ID = "ModelDownloadWorker_workId";
/** Input-data key from which the public constructor reads the scheduling timestamp. */
public static final String INPUT_DATA_KEY_SCHEDULED_TIMESTAMP =
"ModelDownloadWorker_scheduledTimestamp";
/** Executor on which the download pipeline and all of its future callbacks run. */
private final ListeningExecutorService executorService;
/** Fetches manifests and model files. */
private final ModelDownloader downloader;
/** Lookup and registration of manifests, models and enrollments. */
private final DownloadedModelManager downloadedModelManager;
/** Source of the flags and limits consulted throughout this worker. */
private final TextClassifierSettings settings;
/** Identifier of this work, forwarded to every TextClassifierDownloadLogger call. */
private final long workId;
/** Time source behind {@code getCurrentTimeMillis()}; replaceable through the test constructor. */
private final Clock clock;
/** When the work was scheduled; presumably epoch millis like the clock readings it is compared to. */
private final long workScheduledTimeMillis;
/** Monitor guarding {@code pendingDownloads}. */
private final Object lock = new Object();
/** Clock reading taken at the top of startWork(); stays 0 until then. */
private long workStartedTimeMillis = 0;
/**
 * In-flight download futures keyed by manifest URL or model URL, so that a URL requested more
 * than once shares a single future.
 */
@GuardedBy("lock")
private final ArrayMap<String, ListenableFuture<Void>> pendingDownloads;
/**
 * Per-model-type manifests selected by checkAndDownloadModels(); null until that method has run.
 * Handed to {@code downloadedModelManager.onDownloadCompleted} once the downloads finish.
 */
private ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload;
/**
 * Creates the worker with its production collaborators.
 *
 * <p>The work id and the scheduling timestamp are read from the input data of {@code
 * workerParams}, with 0 passed as the fallback value for both.
 */
public ModelDownloadWorker(Context context, WorkerParameters workerParams) {
  super(context, workerParams);

  // Values supplied through the work's input data.
  this.workId = workerParams.getInputData().getLong(INPUT_DATA_KEY_WORK_ID, 0);
  this.workScheduledTimeMillis =
      workerParams.getInputData().getLong(INPUT_DATA_KEY_SCHEDULED_TIMESTAMP, 0);
  this.clock = Clock.systemUTC();

  // Collaborators; the downloader shares the executor obtained on the line above it.
  this.executorService = TextClassifierServiceExecutors.getDownloaderExecutor();
  this.downloader = new ModelDownloaderImpl(context, executorService);
  this.downloadedModelManager = DownloadedModelManagerImpl.getInstance(context);
  this.settings = new TextClassifierSettings(context);

  // Per-run bookkeeping starts out empty.
  this.pendingDownloads = new ArrayMap<>();
  this.manifestsToDownload = null;
}
/**
 * Test-only constructor: every collaborator, the clock, the work id and the scheduling timestamp
 * are injected by the caller instead of being created or read from the input data.
 */
@VisibleForTesting
ModelDownloadWorker(
    Context context,
    WorkerParameters workerParams,
    ListeningExecutorService executorService,
    ModelDownloader modelDownloader,
    DownloadedModelManager downloadedModelManager,
    TextClassifierSettings settings,
    long workId,
    Clock clock,
    long workScheduledTimeMillis) {
  super(context, workerParams);

  // Identity and timing of this piece of work.
  this.workId = workId;
  this.workScheduledTimeMillis = workScheduledTimeMillis;
  this.clock = clock;

  // Injected collaborators.
  this.executorService = executorService;
  this.downloader = modelDownloader;
  this.downloadedModelManager = downloadedModelManager;
  this.settings = settings;

  // Per-run bookkeeping starts out empty.
  this.pendingDownloads = new ArrayMap<>();
  this.manifestsToDownload = null;
}
/**
 * Starts the download work.
 *
 * <p>Fails immediately when the model downloader is disabled or the run attempt count has
 * reached the configured maximum. Otherwise the models are checked and downloaded on {@code
 * executorService}; the work succeeds when no download failed and is retried when at least one
 * download failed or an unexpected exception escaped the pipeline.
 */
@Override
public final ListenableFuture<ListenableWorker.Result> startWork() {
  TcLog.d(TAG, "Start download work...");
  workStartedTimeMillis = getCurrentTimeMillis();
  // Notice: startWork() is invoked on the main thread
  if (!settings.isModelDownloadManagerEnabled()) {
    TcLog.e(TAG, "Model Downloader is disabled. Abort the work.");
    return abortWork(TextClassifierDownloadLogger.WORK_RESULT_FAILURE_MODEL_DOWNLOADER_DISABLED);
  }
  if (getRunAttemptCount() >= settings.getModelDownloadWorkerMaxAttempts()) {
    TcLog.d(TAG, "Max attempt reached. Abort download work.");
    return abortWork(TextClassifierDownloadLogger.WORK_RESULT_FAILURE_MAX_RUN_ATTEMPT_REACHED);
  }
  // Everything past this point runs off the main thread.
  ListenableFuture<DownloadResult> allDownloads =
      Futures.submitAsync(this::checkAndDownloadModels, executorService);
  return FluentFuture.from(allDownloads)
      .transform(this::completeWork, executorService)
      .catching(Throwable.class, this::retryAfterUnexpectedFailure, executorService);
}

/** Logs {@code workResult} and returns an already-completed failure result. */
private ListenableFuture<ListenableWorker.Result> abortWork(int workResult) {
  logDownloadWorkCompleted(workResult);
  return Futures.immediateFuture(ListenableWorker.Result.failure());
}

/** Notifies the model manager, logs the outcome and maps the tally to a worker result. */
private ListenableWorker.Result completeWork(DownloadResult downloadResult) {
  Preconditions.checkNotNull(manifestsToDownload);
  downloadedModelManager.onDownloadCompleted(manifestsToDownload);
  TcLog.d(TAG, "Download work completed: " + downloadResult);
  if (downloadResult.failureCount() != 0) {
    logDownloadWorkCompleted(
        TextClassifierDownloadLogger.WORK_RESULT_RETRY_MODEL_DOWNLOAD_FAILED);
    return ListenableWorker.Result.retry();
  }
  logDownloadWorkCompleted(
      downloadResult.successCount() > 0
          ? TextClassifierDownloadLogger.WORK_RESULT_SUCCESS_MODEL_DOWNLOADED
          : TextClassifierDownloadLogger.WORK_RESULT_SUCCESS_NO_UPDATE_AVAILABLE);
  return ListenableWorker.Result.success();
}

/** Logs an exception that escaped the download pipeline and asks for a retry. */
private ListenableWorker.Result retryAfterUnexpectedFailure(Throwable t) {
  TcLog.e(TAG, "Unexpected Exception during downloading: ", t);
  logDownloadWorkCompleted(TextClassifierDownloadLogger.WORK_RESULT_RETRY_RUNTIME_EXCEPTION);
  return ListenableWorker.Result.retry();
}
/**
 * Checks device settings and returns the list of locales to download according to the multi
 * language support settings.
 *
 * <p>The result is never empty and always starts with the primary locale, i.e. the first entry
 * of {@link LocaleList#getAdjustedDefault()}. With multi language support disabled only that
 * locale is returned. With it enabled, the leading locales of the device list are returned, up
 * to {@code settings.getMultiLanguageModelsLimit()} entries.
 *
 * @return the locales to download models for, primary locale first
 */
private ImmutableList<Locale> getLocalesToDownload() {
  LocaleList localeList = LocaleList.getAdjustedDefault();
  Locale primaryLocale = localeList.get(0);
  if (!settings.isMultiLanguageSupportEnabled()) {
    return ImmutableList.of(primaryLocale);
  }
  // A zero or negative limit would otherwise yield an empty list and silently drop even the
  // primary locale, so always keep at least one entry.
  int size = Math.max(1, min(settings.getMultiLanguageModelsLimit(), localeList.size()));
  ImmutableList.Builder<Locale> localesToDownloadBuilder = ImmutableList.builder();
  for (int i = 0; i < size; i++) {
    localesToDownloadBuilder.add(localeList.get(i));
  }
  return localesToDownloadBuilder.build();
}
/**
 * Narrows {@code localeList} for one {@code modelType}.
 *
 * <p>Model types listed in {@code settings.getEnabledModelTypesForMultiLanguageSupport()} get the
 * whole {@code localeList}; every other type gets only {@link Locale#getDefault()}.
 */
private ImmutableList<Locale> getLocalesToDownloadByType(
    ImmutableList<Locale> localeList, @ModelTypeDef String modelType) {
  boolean multiLanguageEnabledForType =
      settings.getEnabledModelTypesForMultiLanguageSupport().contains(modelType);
  return multiLanguageEnabledForType ? localeList : ImmutableList.of(Locale.getDefault());
}
/**
 * Checks the device config and dispatches download tasks for all model types.
 *
 * <p>For every model type the best matching manifest is looked up for each locale to download.
 * The chosen (locale tag, manifest URL) pairs are recorded in {@code manifestsToDownload}, and a
 * download task is dispatched for each pair that {@link #shouldDownloadManifest} accepts.
 *
 * <p>Several device locales may resolve to the same best locale tag. Each tag is handled once
 * per model type, for the first (highest priority) locale that resolves to it.
 *
 * @return a future resolving to the number of succeeded and failed download tasks once all of
 *     them have completed
 */
private ListenableFuture<DownloadResult> checkAndDownloadModels() {
  ImmutableList<Locale> localesToDownload = getLocalesToDownload();
  ArrayList<ListenableFuture<Boolean>> downloadResultFutures = new ArrayList<>();
  ImmutableMap.Builder<String, ManifestsToDownloadByType> manifestsToDownloadBuilder =
      ImmutableMap.builder();
  for (String modelType : ModelType.values()) {
    ImmutableList<Locale> localesToDownloadByType =
        getLocalesToDownloadByType(localesToDownload, modelType);
    // Insertion-ordered and duplicate-tolerant, unlike ImmutableMap.Builder whose
    // buildOrThrow() rejects a key that was put twice.
    Map<String, String> localeTagToManifestUrl = new LinkedHashMap<>();
    for (Locale locale : localesToDownloadByType) {
      Pair<String, String> bestLocaleTagAndManifestUrl =
          LocaleUtils.lookupBestLocaleTagAndManifestUrl(modelType, locale, settings);
      if (bestLocaleTagAndManifestUrl == null) {
        TcLog.w(
            TAG,
            String.format(
                Locale.US, "No suitable manifest for %s, %s", modelType, locale.toLanguageTag()));
        continue;
      }
      String bestLocaleTag = bestLocaleTagAndManifestUrl.first;
      String manifestUrl = bestLocaleTagAndManifestUrl.second;
      if (localeTagToManifestUrl.containsKey(bestLocaleTag)) {
        // An earlier locale already resolved to this tag; recording or downloading it again
        // would only duplicate work.
        TcLog.d(
            TAG,
            String.format(
                Locale.US,
                "model type: %s, current locale tag: %s, best locale tag %s already handled",
                modelType,
                locale.toLanguageTag(),
                bestLocaleTag));
        continue;
      }
      localeTagToManifestUrl.put(bestLocaleTag, manifestUrl);
      TcLog.d(
          TAG,
          String.format(
              Locale.US,
              "model type: %s, current locale tag: %s, best locale tag: %s, manifest url: %s",
              modelType,
              locale.toLanguageTag(),
              bestLocaleTag,
              manifestUrl));
      if (!shouldDownloadManifest(modelType, bestLocaleTag, manifestUrl)) {
        continue;
      }
      downloadResultFutures.add(
          downloadManifestAndRegister(modelType, bestLocaleTag, manifestUrl));
    }
    manifestsToDownloadBuilder.put(
        modelType, ManifestsToDownloadByType.create(ImmutableMap.copyOf(localeTagToManifestUrl)));
  }
  manifestsToDownload = manifestsToDownloadBuilder.buildOrThrow();
  return Futures.whenAllComplete(downloadResultFutures)
      .call(
          () -> {
            TcLog.d(TAG, "All Download Tasks Completed");
            int successCount = 0;
            int failureCount = 0;
            for (ListenableFuture<Boolean> downloadResultFuture : downloadResultFutures) {
              // Each task resolves to true on success and false on a handled failure.
              if (Futures.getDone(downloadResultFuture)) {
                successCount += 1;
              } else {
                failureCount += 1;
              }
            }
            return DownloadResult.create(successCount, failureCount);
          },
          executorService);
}
/**
 * Decides whether the manifest at {@code manifestUrl} has to be downloaded for the given model
 * type and locale tag.
 *
 * <ul>
 *   <li>Unknown manifest: download.
 *   <li>Previously failed manifest: download only while its failure count is below {@code
 *       settings.getManifestDownloadMaxAttempts()}.
 *   <li>Otherwise: download when no enrollment exists for (model type, locale tag) or the
 *       enrollment points at a different manifest URL.
 * </ul>
 */
private boolean shouldDownloadManifest(
    @ModelTypeDef String modelType, String localeTag, String manifestUrl) {
  Manifest knownManifest = downloadedModelManager.getManifest(manifestUrl);
  if (knownManifest == null) {
    return true;
  }
  if (knownManifest.getStatus() == Manifest.STATUS_FAILED) {
    boolean attemptsLeft =
        knownManifest.getFailureCounts() < settings.getManifestDownloadMaxAttempts();
    if (!attemptsLeft) {
      TcLog.w(
          TAG,
          String.format(
              Locale.US,
              "Manifest failed too many times, stop retrying: %s %d",
              manifestUrl,
              knownManifest.getFailureCounts()));
    }
    return attemptsLeft;
  }
  ManifestEnrollment enrollment =
      downloadedModelManager.getManifestEnrollment(modelType, localeTag);
  if (enrollment == null) {
    return true;
  }
  return !manifestUrl.equals(enrollment.getManifestUrl());
}
/**
 * Downloads a single manifest together with the model configured inside it and enrolls it for
 * the given model type and locale tag.
 *
 * <p>The returned future resolves to {@code true} after a successful download and enrollment,
 * and to {@code false} when any {@link Throwable} was raised on the way; the failure is
 * registered and logged rather than propagated.
 */
private ListenableFuture<Boolean> downloadManifestAndRegister(
    @ModelTypeDef String modelType, String localeTag, String manifestUrl) {
  final long startedAtMillis = getCurrentTimeMillis();
  FluentFuture<Boolean> enrolled =
      FluentFuture.from(downloadManifest(manifestUrl))
          .transform(
              unused -> {
                recordManifestSuccess(modelType, localeTag, manifestUrl, startedAtMillis);
                return true;
              },
              executorService);
  return enrolled.catching(
      Throwable.class,
      t -> {
        recordManifestFailure(modelType, manifestUrl, startedAtMillis, t);
        return false;
      },
      executorService);
}

/** Enrolls the downloaded manifest and logs the successful download. */
private void recordManifestSuccess(
    @ModelTypeDef String modelType, String localeTag, String manifestUrl, long startedAtMillis) {
  downloadedModelManager.registerManifestEnrollment(modelType, localeTag, manifestUrl);
  TextClassifierDownloadLogger.downloadSucceeded(
      workId,
      modelType,
      manifestUrl,
      getRunAttemptCount(),
      getCurrentTimeMillis() - startedAtMillis);
  TcLog.d(TAG, "Manifest downloaded and registered: " + manifestUrl);
}

/** Registers the failed manifest and logs the failure with the most specific error codes. */
private void recordManifestFailure(
    @ModelTypeDef String modelType, String manifestUrl, long startedAtMillis, Throwable t) {
  downloadedModelManager.registerManifestDownloadFailure(manifestUrl);
  // Only a ModelDownloadException carries error codes; anything else is reported as unknown.
  ModelDownloadException downloadException =
      (t instanceof ModelDownloadException) ? (ModelDownloadException) t : null;
  int errorCode =
      downloadException != null
          ? downloadException.getErrorCode()
          : ModelDownloadException.UNKNOWN_FAILURE_REASON;
  int downloaderLibErrorCode =
      downloadException != null ? downloadException.getDownloaderLibErrorCode() : 0;
  TcLog.e(TAG, "Failed to download manfiest: " + manifestUrl, t);
  TextClassifierDownloadLogger.downloadFailed(
      workId,
      modelType,
      manifestUrl,
      errorCode,
      getRunAttemptCount(),
      downloaderLibErrorCode,
      getCurrentTimeMillis() - startedAtMillis);
}
/**
 * Downloads a manifest and its first model, then registers the manifest in the Manifest table.
 *
 * <p>Completes immediately when the manifest is already registered with a succeeded status, and
 * returns the existing future when a download for the same URL is already pending.
 */
private ListenableFuture<Void> downloadManifest(String manifestUrl) {
  synchronized (lock) {
    Manifest knownManifest = downloadedModelManager.getManifest(manifestUrl);
    boolean alreadyDownloaded =
        knownManifest != null && knownManifest.getStatus() == Manifest.STATUS_SUCCEEDED;
    if (alreadyDownloaded) {
      TcLog.d(TAG, "Manifest already downloaded: " + manifestUrl);
      return Futures.immediateVoidFuture();
    }
    // Only non-null futures are ever stored, so a null lookup means nothing is pending.
    ListenableFuture<Void> pending = pendingDownloads.get(manifestUrl);
    if (pending != null) {
      return pending;
    }
    // Stage 1: fetch the manifest, then the model it points at.
    FluentFuture<ModelManifest.Model> modelDownloaded =
        FluentFuture.from(downloader.downloadManifest(manifestUrl))
            .transformAsync(
                manifest -> {
                  ModelManifest.Model modelInfo = manifest.getModels(0);
                  return Futures.transform(
                      downloadModel(modelInfo), unused -> modelInfo, executorService);
                },
                executorService);
    // Stage 2: only once the model is in place is the manifest itself registered.
    ListenableFuture<Void> manifestRegistered =
        modelDownloaded.transform(
            modelInfo -> {
              downloadedModelManager.registerManifest(manifestUrl, modelInfo.getUrl());
              return null;
            },
            executorService);
    pendingDownloads.put(manifestUrl, manifestRegistered);
    return manifestRegistered;
  }
}
/**
 * Downloads a model file and registers it in the Model table.
 *
 * <p>Completes immediately when the model URL is already registered, and returns the existing
 * future when a download for the same URL is already pending.
 */
private ListenableFuture<Void> downloadModel(ModelManifest.Model modelInfo) {
  final String modelUrl = modelInfo.getUrl();
  synchronized (lock) {
    Model knownModel = downloadedModelManager.getModel(modelUrl);
    if (knownModel != null) {
      TcLog.d(TAG, "Model file already exists: " + knownModel.getModelPath());
      return Futures.immediateVoidFuture();
    }
    // Only non-null futures are ever stored, so a null lookup means nothing is pending.
    ListenableFuture<Void> pending = pendingDownloads.get(modelUrl);
    if (pending != null) {
      return pending;
    }
    ListenableFuture<Void> modelRegistered =
        FluentFuture.from(
                downloader.downloadModel(
                    downloadedModelManager.getModelDownloaderDir(), modelInfo))
            .transform(
                modelFile -> {
                  downloadedModelManager.registerModel(modelUrl, modelFile.getAbsolutePath());
                  TcLog.d(TAG, "Model File downloaded: " + modelUrl);
                  return null;
                },
                executorService);
    pendingDownloads.put(modelUrl, modelRegistered);
    return modelRegistered;
  }
}
/**
 * Called when the work gets interrupted by the system. As noted by the original author, the
 * result future should already have been cancelled in that case and, unless the interruption
 * stems from the REPLACE policy of the WorkManager unique queue, the interrupted work will be
 * rescheduled later. This override only logs the stop and reports the work result.
 */
@Override
public final void onStopped() {
  String stopMessage =
      String.format(Locale.US, "Stop download. Attempt:%d", getRunAttemptCount());
  TcLog.d(TAG, stopMessage);
  logDownloadWorkCompleted(TextClassifierDownloadLogger.WORK_RESULT_RETRY_STOPPED_BY_OS);
}
/** Returns the current instant of {@code clock} as milliseconds since the epoch. */
private long getCurrentTimeMillis() {
  final long nowMillis = clock.instant().toEpochMilli();
  return nowMillis;
}
/**
 * Reports the outcome of this work together with its run attempt count, the delay between
 * scheduling and start, and the time spent since the start.
 *
 * <p>A start time earlier than the scheduled time is logged as a warning and replaced by the
 * scheduled time, so the reported delay is never negative.
 *
 * @param workResult one of the {@code TextClassifierDownloadLogger.WORK_RESULT_*} values
 */
private void logDownloadWorkCompleted(int workResult) {
  boolean startedBeforeScheduled = workStartedTimeMillis < workScheduledTimeMillis;
  if (startedBeforeScheduled) {
    TcLog.w(
        TAG,
        String.format(
            Locale.US,
            "Bad workStartedTimeMillis: %d, workScheduledTimeMillis: %d",
            workStartedTimeMillis,
            workScheduledTimeMillis));
    workStartedTimeMillis = workScheduledTimeMillis;
  }
  int runAttemptCount = getRunAttemptCount();
  long queuedMillis = workStartedTimeMillis - workScheduledTimeMillis;
  long elapsedMillis = getCurrentTimeMillis() - workStartedTimeMillis;
  TextClassifierDownloadLogger.downloadWorkCompleted(
      workId, workResult, runAttemptCount, queuedMillis, elapsedMillis);
}
/** Immutable tally of how many download tasks succeeded and how many failed. */
@AutoValue
abstract static class DownloadResult {
  /** Creates a tally from the given success and failure counts. */
  public static DownloadResult create(int successCount, int failureCount) {
    return new AutoValue_ModelDownloadWorker_DownloadResult(successCount, failureCount);
  }

  /** Number of download tasks that resolved to {@code true}. */
  public abstract int successCount();

  /** Number of download tasks that resolved to {@code false}. */
  public abstract int failureCount();
}
}