blob: 0614c2fd3bbaa87ddb4f071e3815b559397e4c94 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.textclassifier.downloader;
import androidx.annotation.IntDef;
import androidx.annotation.NonNull;
import androidx.room.ColumnInfo;
import androidx.room.Dao;
import androidx.room.Database;
import androidx.room.DatabaseView;
import androidx.room.Delete;
import androidx.room.Embedded;
import androidx.room.Entity;
import androidx.room.ForeignKey;
import androidx.room.Index;
import androidx.room.Insert;
import androidx.room.OnConflictStrategy;
import androidx.room.Query;
import androidx.room.RoomDatabase;
import androidx.room.Transaction;
import com.android.textclassifier.common.ModelType.ModelTypeDef;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.utils.IndentingPrintWriter;
import com.google.auto.value.AutoValue;
import com.google.common.collect.Iterables;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
/** Database storing info about models downloaded by model downloader */
@Database(
entities = {
DownloadedModelDatabase.Model.class,
DownloadedModelDatabase.Manifest.class,
DownloadedModelDatabase.ManifestModelCrossRef.class,
DownloadedModelDatabase.ManifestEnrollment.class
},
views = {DownloadedModelDatabase.ModelView.class},
version = 1,
exportSchema = true)
abstract class DownloadedModelDatabase extends RoomDatabase {
public static final String TAG = "DownloadedModelDatabase";
/** Rpresents a downloaded model file. */
@AutoValue
@Entity(
tableName = "model",
primaryKeys = {"model_url"})
abstract static class Model {
@AutoValue.CopyAnnotations
@ColumnInfo(name = "model_url")
@NonNull
public abstract String getModelUrl();
@AutoValue.CopyAnnotations
@ColumnInfo(name = "model_path")
@NonNull
public abstract String getModelPath();
public static Model create(String modelUrl, String modelPath) {
return new AutoValue_DownloadedModelDatabase_Model(modelUrl, modelPath);
}
}
/** Rpresents a manifest we processed. */
@AutoValue
@Entity(
tableName = "manifest",
primaryKeys = {"manifest_url"})
abstract static class Manifest {
// TODO(licha): Consider using Enum here
@Retention(RetentionPolicy.SOURCE)
@IntDef({STATUS_UNKNOWN, STATUS_FAILED, STATUS_SUCCEEDED})
@interface StatusDef {}
public static final int STATUS_UNKNOWN = 0;
/** Failed to download this manifest. Could be retried in the future. */
public static final int STATUS_FAILED = 1;
/** Downloaded this manifest successfully and it's currently in storage. */
public static final int STATUS_SUCCEEDED = 2;
@AutoValue.CopyAnnotations
@ColumnInfo(name = "manifest_url")
@NonNull
public abstract String getManifestUrl();
@AutoValue.CopyAnnotations
@ColumnInfo(name = "status")
@StatusDef
public abstract int getStatus();
@AutoValue.CopyAnnotations
@ColumnInfo(name = "failure_counts")
public abstract int getFailureCounts();
public static Manifest create(String manifestUrl, @StatusDef int status, int failureCounts) {
return new AutoValue_DownloadedModelDatabase_Manifest(manifestUrl, status, failureCounts);
}
}
/**
* Represents the relationship between manfiests and downloaded models.
*
* <p>A manifest can include multiple models, a model can also be included in multiple manifests.
* In different manifests, a model may have different configurations (e.g. primary model in
* manfiest A but dark model in B).
*/
@AutoValue
@Entity(
tableName = "manifest_model_cross_ref",
primaryKeys = {"manifest_url", "model_url"},
foreignKeys = {
@ForeignKey(
entity = Manifest.class,
parentColumns = "manifest_url",
childColumns = "manifest_url",
onDelete = ForeignKey.CASCADE),
@ForeignKey(
entity = Model.class,
parentColumns = "model_url",
childColumns = "model_url",
onDelete = ForeignKey.CASCADE),
},
indices = {
@Index(value = {"manifest_url"}),
@Index(value = {"model_url"}),
})
abstract static class ManifestModelCrossRef {
@AutoValue.CopyAnnotations
@ColumnInfo(name = "manifest_url")
@NonNull
public abstract String getManifestUrl();
@AutoValue.CopyAnnotations
@ColumnInfo(name = "model_url")
@NonNull
public abstract String getModelUrl();
public static ManifestModelCrossRef create(String manifestUrl, String modelUrl) {
return new AutoValue_DownloadedModelDatabase_ManifestModelCrossRef(manifestUrl, modelUrl);
}
}
/**
* Represents the relationship between user scenarios and manifests.
*
* <p>For each unique user scenario (i.e. modelType + localTag), we store the manifest we should
* use. The same manifest can be used for different scenarios.
*/
@AutoValue
@Entity(
tableName = "manifest_enrollment",
primaryKeys = {"model_type", "locale_tag"},
foreignKeys = {
@ForeignKey(
entity = Manifest.class,
parentColumns = "manifest_url",
childColumns = "manifest_url",
onDelete = ForeignKey.CASCADE)
},
indices = {@Index(value = {"manifest_url"})})
abstract static class ManifestEnrollment {
@AutoValue.CopyAnnotations
@ColumnInfo(name = "model_type")
@NonNull
@ModelTypeDef
public abstract String getModelType();
@AutoValue.CopyAnnotations
@ColumnInfo(name = "locale_tag")
@NonNull
public abstract String getLocaleTag();
@AutoValue.CopyAnnotations
@ColumnInfo(name = "manifest_url")
@NonNull
public abstract String getManifestUrl();
public static ManifestEnrollment create(
@ModelTypeDef String modelType, String localeTag, String manifestUrl) {
return new AutoValue_DownloadedModelDatabase_ManifestEnrollment(
modelType, localeTag, manifestUrl);
}
}
/** Represents the mapping from manfiest enrollments to models. */
@AutoValue
@DatabaseView(
value =
"SELECT manifest_enrollment.*, model.* "
+ "FROM manifest_enrollment "
+ "INNER JOIN manifest_model_cross_ref "
+ "ON manifest_enrollment.manifest_url = manifest_model_cross_ref.manifest_url "
+ "INNER JOIN model "
+ "ON manifest_model_cross_ref.model_url = model.model_url",
viewName = "model_view")
abstract static class ModelView {
@AutoValue.CopyAnnotations
@Embedded
@NonNull
public abstract ManifestEnrollment getManifestEnrollment();
@AutoValue.CopyAnnotations
@Embedded
@NonNull
public abstract Model getModel();
public static ModelView create(ManifestEnrollment manifestEnrollment, Model model) {
return new AutoValue_DownloadedModelDatabase_ModelView(manifestEnrollment, model);
}
}
@Dao
abstract static class DownloadedModelDatabaseDao {
// Full table scan
@Query("SELECT * FROM model")
abstract List<Model> queryAllModels();
@Query("SELECT * FROM manifest")
abstract List<Manifest> queryAllManifests();
@Query("SELECT * FROM manifest_model_cross_ref")
abstract List<ManifestModelCrossRef> queryAllManifestModelCrossRefs();
@Query("SELECT * FROM manifest_enrollment")
abstract List<ManifestEnrollment> queryAllManifestEnrollments();
@Query("SELECT * FROM model_view")
abstract List<ModelView> queryAllModelViews();
// Single table query
@Query("SELECT * FROM model WHERE model_url = :modelUrl")
abstract List<Model> queryModelWithModelUrl(String modelUrl);
@Query("SELECT * FROM manifest WHERE manifest_url = :manifestUrl")
abstract List<Manifest> queryManifestWithManifestUrl(String manifestUrl);
@Query(
"SELECT * FROM manifest_enrollment WHERE model_type = :modelType "
+ "AND locale_tag = :localeTag")
abstract List<ManifestEnrollment> queryManifestEnrollmentWithModelTypeAndLocaleTag(
@ModelTypeDef String modelType, String localeTag);
// Helpers for clean up
@Query(
"SELECT manifest.* FROM manifest "
+ "LEFT JOIN model_view "
+ "ON manifest.manifest_url = model_view.manifest_url "
+ "WHERE model_view.manifest_url IS NULL "
+ "AND manifest.status = 2")
abstract List<Manifest> queryUnusedManifests();
@Query(
"SELECT * FROM manifest WHERE manifest.status = 1 "
+ "AND manifest.manifest_url NOT IN (:manifestUrlsToKeep)")
abstract List<Manifest> queryUnusedManifestFailureRecords(List<String> manifestUrlsToKeep);
@Query(
"SELECT model.* FROM model LEFT JOIN model_view "
+ "ON model.model_url = model_view.model_url "
+ "WHERE model_view.model_url IS NULL")
abstract List<Model> queryUnusedModels();
// Insertion
@Insert(onConflict = OnConflictStrategy.REPLACE)
abstract void insert(Model model);
@Insert(onConflict = OnConflictStrategy.REPLACE)
abstract void insert(Manifest manifest);
@Insert(onConflict = OnConflictStrategy.REPLACE)
abstract void insert(ManifestModelCrossRef manifestModelCrossRef);
@Insert(onConflict = OnConflictStrategy.REPLACE)
abstract void insert(ManifestEnrollment manifestEnrollment);
@Transaction
void insertManifestAndModelCrossRef(String manifestUrl, String modelUrl) {
insert(Manifest.create(manifestUrl, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0));
insert(ManifestModelCrossRef.create(manifestUrl, modelUrl));
}
@Transaction
void increaseManifestFailureCounts(String manifestUrl) {
List<Manifest> manifests = queryManifestWithManifestUrl(manifestUrl);
if (manifests.isEmpty()) {
insert(Manifest.create(manifestUrl, Manifest.STATUS_FAILED, /* failureCounts= */ 1));
} else {
Manifest prevManifest = Iterables.getOnlyElement(manifests);
insert(
Manifest.create(
manifestUrl, Manifest.STATUS_FAILED, prevManifest.getFailureCounts() + 1));
}
}
// Deletion
@Delete
abstract void deleteModels(List<Model> models);
@Delete
abstract void deleteManifests(List<Manifest> manifests);
@Delete
abstract void deleteManifestModelCrossRefs(List<ManifestModelCrossRef> manifestModelCrossRefs);
@Delete
abstract void deleteManifestEnrollments(List<ManifestEnrollment> manifestEnrollments);
@Transaction
void deleteUnusedManifestsAndModels() {
// Because Manifest table is the parent table of ManifestModelCrossRef table, related cross
// ref row in that table will be deleted automatically
deleteManifests(queryUnusedManifests());
deleteModels(queryUnusedModels());
}
@Transaction
void deleteUnusedManifestFailureRecords(List<String> manifestUrlsToKeep) {
deleteManifests(queryUnusedManifestFailureRecords(manifestUrlsToKeep));
}
}
abstract DownloadedModelDatabaseDao dao();
/** Dump the database for debugging. */
void dump(IndentingPrintWriter printWriter, ExecutorService executorService) {
printWriter.println("DownloadedModelDatabase");
printWriter.increaseIndent();
try {
printWriter.println("Model Table:");
printWriter.increaseIndent();
List<Model> models = executorService.submit(() -> dao().queryAllModels()).get();
for (Model model : models) {
printWriter.println(model.toString());
}
printWriter.decreaseIndent();
printWriter.println("Manifest Table:");
printWriter.increaseIndent();
List<Manifest> manifests = executorService.submit(() -> dao().queryAllManifests()).get();
for (Manifest manifest : manifests) {
printWriter.println(manifest.toString());
}
printWriter.decreaseIndent();
printWriter.println("ManifestModelCrossRef Table:");
printWriter.increaseIndent();
List<ManifestModelCrossRef> manifestModelCrossRefs =
executorService.submit(() -> dao().queryAllManifestModelCrossRefs()).get();
for (ManifestModelCrossRef manifestModelCrossRef : manifestModelCrossRefs) {
printWriter.println(manifestModelCrossRef.toString());
}
printWriter.decreaseIndent();
printWriter.println("ManifestEnrollment Table:");
printWriter.increaseIndent();
List<ManifestEnrollment> manifestEnrollments =
executorService.submit(() -> dao().queryAllManifestEnrollments()).get();
for (ManifestEnrollment manifestEnrollment : manifestEnrollments) {
printWriter.println(manifestEnrollment.toString());
}
printWriter.decreaseIndent();
} catch (ExecutionException | InterruptedException e) {
TcLog.e(TAG, "Failed to dump the database", e);
}
printWriter.decreaseIndent();
}
}