| /* |
| * Copyright (C) 2018 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package com.android.textclassifier.downloader; |
| |
| import static android.content.Context.BIND_AUTO_CREATE; |
| import static android.content.Context.BIND_NOT_FOREGROUND; |
| |
| import android.content.ComponentName; |
| import android.content.Context; |
| import android.content.Intent; |
| import android.content.ServiceConnection; |
| import android.os.IBinder; |
| import androidx.concurrent.futures.CallbackToFutureAdapter; |
| import com.android.textclassifier.common.base.TcLog; |
| import com.android.textclassifier.protobuf.ExtensionRegistryLite; |
| import com.google.common.annotations.VisibleForTesting; |
| import com.google.common.base.Preconditions; |
| import com.google.common.hash.HashCode; |
| import com.google.common.hash.Hashing; |
| import com.google.common.io.Files; |
| import com.google.common.util.concurrent.FutureCallback; |
| import com.google.common.util.concurrent.Futures; |
| import com.google.common.util.concurrent.ListenableFuture; |
| import java.io.File; |
| import java.io.FileInputStream; |
| import java.io.IOException; |
| import java.net.URI; |
| import java.util.concurrent.ExecutorService; |
| |
| /** |
| * ModelDownloader implementation that forwards requests to ModelDownloaderService. This is to |
| * restrict the INTERNET permission to the service process only (instead of the whole ExtServices). |
| */ |
| final class ModelDownloaderImpl implements ModelDownloader { |
| private static final String TAG = "ModelDownloaderImpl"; |
| |
| private final Context context; |
| private final ExecutorService bgExecutorService; |
| private final Class<?> downloaderServiceClass; |
| |
| public ModelDownloaderImpl(Context context, ExecutorService bgExecutorService) { |
| this(context, bgExecutorService, ModelDownloaderService.class); |
| } |
| |
| @VisibleForTesting |
| ModelDownloaderImpl( |
| Context context, ExecutorService bgExecutorService, Class<?> downloaderServiceClass) { |
| this.context = context.getApplicationContext(); |
| this.bgExecutorService = bgExecutorService; |
| this.downloaderServiceClass = downloaderServiceClass; |
| } |
| |
| @Override |
| public ListenableFuture<ModelManifest> downloadManifest(String manifestUrl) { |
| File manifestFile = |
| new File(context.getCacheDir(), manifestUrl.replaceAll("[^A-Za-z0-9]", "_") + ".manifest"); |
| return Futures.transform( |
| download(URI.create(manifestUrl), manifestFile), |
| bytesWritten -> { |
| try { |
| return ModelManifest.parseFrom( |
| new FileInputStream(manifestFile), ExtensionRegistryLite.getEmptyRegistry()); |
| } catch (Throwable t) { |
| throw new ModelDownloadException(ModelDownloadException.FAILED_TO_PARSE_MANIFEST, t); |
| } finally { |
| manifestFile.delete(); |
| } |
| }, |
| bgExecutorService); |
| } |
| |
| @Override |
| public ListenableFuture<File> downloadModel(File targetDir, ModelManifest.Model model) { |
| File modelFile = new File(targetDir, model.getUrl().replaceAll("[^A-Za-z0-9]", "_") + ".model"); |
| ListenableFuture<File> modelFileFuture = |
| Futures.transform( |
| download(URI.create(model.getUrl()), modelFile), |
| bytesWritten -> { |
| validateModel(modelFile, model.getSizeInBytes(), model.getFingerprint()); |
| return modelFile; |
| }, |
| bgExecutorService); |
| Futures.addCallback( |
| modelFileFuture, |
| new FutureCallback<File>() { |
| @Override |
| public void onSuccess(File pendingModelFile) { |
| TcLog.v(TAG, "Download model successfully: " + pendingModelFile.getAbsolutePath()); |
| } |
| |
| @Override |
| public void onFailure(Throwable t) { |
| modelFile.delete(); |
| TcLog.e(TAG, "Failed to download: " + modelFile.getAbsolutePath(), t); |
| } |
| }, |
| bgExecutorService); |
| return modelFileFuture; |
| } |
| |
| // TODO(licha): Make this visible for testing. So we can avoid some duplicated test cases. |
| /** |
| * Downloads the file from uri to the targetFile. If the targetFile already exists, it will be |
| * deleted. Return bytes written if succeeds. |
| */ |
| private ListenableFuture<Long> download(URI uri, File targetFile) { |
| if (targetFile.exists()) { |
| TcLog.w( |
| TAG, |
| "Target file already exists. Delete it before downloading: " |
| + targetFile.getAbsolutePath()); |
| targetFile.delete(); |
| } |
| DownloaderServiceConnection conn = new DownloaderServiceConnection(); |
| ListenableFuture<IModelDownloaderService> downloaderServiceFuture = connect(conn); |
| ListenableFuture<Long> bytesWrittenFuture = |
| Futures.transformAsync( |
| downloaderServiceFuture, |
| service -> scheduleDownload(service, uri, targetFile), |
| bgExecutorService); |
| bytesWrittenFuture.addListener( |
| () -> { |
| try { |
| context.unbindService(conn); |
| } catch (IllegalArgumentException e) { |
| TcLog.e(TAG, "Error when unbind", e); |
| } |
| }, |
| bgExecutorService); |
| return bytesWrittenFuture; |
| } |
| |
| /** Model verification. Throws unchecked Exceptions if validation fails. */ |
| private static void validateModel(File pendingModelFile, long sizeInBytes, String fingerprint) { |
| if (!pendingModelFile.exists()) { |
| throw new ModelDownloadException( |
| ModelDownloadException.DOWNLOADED_FILE_MISSING, "PendingModelFile does not exist."); |
| } |
| if (pendingModelFile.length() != sizeInBytes) { |
| throw new ModelDownloadException( |
| ModelDownloadException.FAILED_TO_VALIDATE_MODEL, |
| String.format( |
| "PendingModelFile size does not match: expected [%d] actual [%d]", |
| sizeInBytes, pendingModelFile.length())); |
| } |
| try { |
| HashCode pendingModelFingerprint = |
| Files.asByteSource(pendingModelFile).hash(Hashing.sha384()); |
| if (!pendingModelFingerprint.equals(HashCode.fromString(fingerprint))) { |
| throw new ModelDownloadException( |
| ModelDownloadException.FAILED_TO_VALIDATE_MODEL, |
| String.format( |
| "PendingModelFile fingerprint does not match: expected [%s] actual [%s]", |
| fingerprint, pendingModelFingerprint)); |
| } |
| } catch (IOException e) { |
| throw new ModelDownloadException(ModelDownloadException.FAILED_TO_VALIDATE_MODEL, e); |
| } |
| TcLog.d(TAG, "Pending model file passed validation."); |
| } |
| |
| private ListenableFuture<IModelDownloaderService> connect(DownloaderServiceConnection conn) { |
| TcLog.d(TAG, "Starting a new connection to ModelDownloaderService"); |
| return CallbackToFutureAdapter.getFuture( |
| completer -> { |
| conn.attachCompleter(completer); |
| Intent intent = new Intent(context, downloaderServiceClass); |
| if (context.bindService(intent, conn, BIND_AUTO_CREATE | BIND_NOT_FOREGROUND)) { |
| return "Binding to service"; |
| } else { |
| completer.setException( |
| new ModelDownloadException( |
| ModelDownloadException.FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN, |
| "Unable to bind to service")); |
| return "Binding failed"; |
| } |
| }); |
| } |
| |
| // Here the returned download result future can be set by: 1) the service can invoke the callback |
| // and set the result/exception; 2) If the service crashed, the CallbackToFutureAdapter will try |
| // to fail the future when the callback is garbage collected. If somehow none of them worked, the |
| // restult future will hang there until time out. (WorkManager forces a 10-min running time.) |
| private static ListenableFuture<Long> scheduleDownload( |
| IModelDownloaderService service, URI uri, File targetFile) { |
| TcLog.d(TAG, "Scheduling a new download task with ModelDownloaderService"); |
| return CallbackToFutureAdapter.getFuture( |
| completer -> { |
| service.download( |
| uri.toString(), |
| targetFile.getAbsolutePath(), |
| new IModelDownloaderCallback.Stub() { |
| @Override |
| public void onSuccess(long bytesWritten) { |
| completer.set(bytesWritten); |
| } |
| |
| @Override |
| public void onFailure( |
| @ModelDownloadException.ErrorCode int errorCode, String errorMsg) { |
| completer.setException(new ModelDownloadException(errorCode, errorMsg)); |
| } |
| }); |
| return "downlaoderService.download"; |
| }); |
| } |
| |
| /** The implementation of {@link ServiceConnection} that handles changes in the connection. */ |
| @VisibleForTesting |
| static class DownloaderServiceConnection implements ServiceConnection { |
| private static final String TAG = "ModelDownloaderImpl.DownloaderServiceConnection"; |
| |
| private CallbackToFutureAdapter.Completer<IModelDownloaderService> completer; |
| |
| public void attachCompleter( |
| CallbackToFutureAdapter.Completer<IModelDownloaderService> completer) { |
| this.completer = completer; |
| } |
| |
| @Override |
| public void onServiceConnected(ComponentName componentName, IBinder iBinder) { |
| TcLog.d(TAG, "DownloaderService connected"); |
| completer.set(Preconditions.checkNotNull(IModelDownloaderService.Stub.asInterface(iBinder))); |
| } |
| |
| @Override |
| public void onServiceDisconnected(ComponentName componentName) { |
| // If this is invoked after onServiceConnected, it will be ignored by the completer. |
| completer.setException( |
| new ModelDownloadException( |
| ModelDownloadException.FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN, |
| "Service disconnected")); |
| } |
| |
| @Override |
| public void onBindingDied(ComponentName name) { |
| completer.setException( |
| new ModelDownloadException( |
| ModelDownloadException.FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN, "Binding died")); |
| } |
| |
| @Override |
| public void onNullBinding(ComponentName name) { |
| completer.setException( |
| new ModelDownloadException( |
| ModelDownloadException.FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN, |
| "Unable to bind to DownloaderService")); |
| } |
| } |
| } |