Snap for 7132927 from ba8933394323a4348a73de83a194c97b117d09e3 to mainline-conscrypt-release
Change-Id: I8935cb9255ee330ec59c1473ec8be47c73f7ef73
diff --git a/java/Android.bp b/java/Android.bp
index 893b423..163512e 100644
--- a/java/Android.bp
+++ b/java/Android.bp
@@ -35,6 +35,9 @@
sdk_version: "system_current",
min_sdk_version: "30",
manifest: "AndroidManifest.xml",
+ aaptflags: [
+ "-0 .model",
+ ],
}
// Similar to TextClassifierServiceLib, but without the AndroidManifest.
@@ -52,6 +55,10 @@
],
sdk_version: "system_current",
min_sdk_version: "30",
+ aaptflags: [
+ "-0 .model",
+ ],
+
}
java_library {
@@ -66,6 +73,8 @@
genrule {
name: "statslog-textclassifier-java-gen",
tools: ["stats-log-api-gen"],
- cmd: "$(location stats-log-api-gen) --java $(out) --module textclassifier --javaPackage com.android.textclassifier.common.statsd --javaClass TextClassifierStatsLog",
+ cmd: "$(location stats-log-api-gen) --java $(out) --module textclassifier" +
+ " --javaPackage com.android.textclassifier.common.statsd" +
+ " --javaClass TextClassifierStatsLog --minApiLevel 30",
out: ["com/android/textclassifier/common/statsd/TextClassifierStatsLog.java"],
}
diff --git a/java/AndroidManifest.xml b/java/AndroidManifest.xml
index 7c251e4..f2dfcb7 100644
--- a/java/AndroidManifest.xml
+++ b/java/AndroidManifest.xml
@@ -18,7 +18,7 @@
-->
<!--
- This manifest file is for the standalone TCS used for testing.
+ This manifest file is for the TCS library.
The TCS is typically shipped as part of ExtServices and is configured
in ExtServices's manifest.
-->
@@ -27,15 +27,15 @@
android:versionCode="1"
android:versionName="1.0.0">
- <uses-sdk android:minSdkVersion="29" android:targetSdkVersion="29"/>
+ <uses-sdk android:minSdkVersion="30" android:targetSdkVersion="30"/>
+ <uses-permission android:name="android.permission.QUERY_ALL_PACKAGES" />
<uses-permission android:name="android.permission.ACCESS_COARSE_LOCATION" />
<uses-permission android:name="android.permission.RECEIVE_BOOT_COMPLETED" />
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE"/>
+ <uses-permission android:name="android.permission.INTERNET" />
- <application android:label="@string/tcs_app_name"
- android:icon="@drawable/tcs_app_icon"
- android:extractNativeLibs="false">
+ <application>
<service
android:exported="true"
android:name=".DefaultTextClassifierService"
diff --git a/java/assets/textclassifier/actions_suggestions.universal.model b/java/assets/textclassifier/actions_suggestions.universal.model
new file mode 100755
index 0000000..f74fed4
--- /dev/null
+++ b/java/assets/textclassifier/actions_suggestions.universal.model
Binary files differ
diff --git a/java/assets/textclassifier/annotator.universal.model b/java/assets/textclassifier/annotator.universal.model
new file mode 100755
index 0000000..09f1e0b
--- /dev/null
+++ b/java/assets/textclassifier/annotator.universal.model
Binary files differ
diff --git a/java/assets/textclassifier/lang_id.model b/java/assets/textclassifier/lang_id.model
new file mode 100644
index 0000000..e94dada
--- /dev/null
+++ b/java/assets/textclassifier/lang_id.model
Binary files differ
diff --git a/java/res/drawable/tcs_app_icon.xml b/java/res/drawable/tcs_app_icon.xml
deleted file mode 100644
index 8cce7ca..0000000
--- a/java/res/drawable/tcs_app_icon.xml
+++ /dev/null
@@ -1,11 +0,0 @@
-<?xml version="1.0" encoding="utf-8"?>
-<vector xmlns:android="http://schemas.android.com/apk/res/android"
- android:width="24dp"
- android:height="24dp"
- android:viewportWidth="24"
- android:viewportHeight="24">
-
- <path
- android:fillColor="#000000"
- android:pathData="M2.5 4v3h5v12h3V7h5V4h-13zm19 5h-9v3h3v7h3v-7h3V9z" />
-</vector>
\ No newline at end of file
diff --git a/java/src/com/android/textclassifier/DefaultTextClassifierService.java b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
index 3d3b359..b145aa8 100644
--- a/java/src/com/android/textclassifier/DefaultTextClassifierService.java
+++ b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
@@ -16,6 +16,10 @@
package com.android.textclassifier;
+import android.content.BroadcastReceiver;
+import android.content.Context;
+import android.content.Intent;
+import android.content.IntentFilter;
import android.os.CancellationSignal;
import android.service.textclassifier.TextClassifierService;
import android.view.textclassifier.ConversationActions;
@@ -27,7 +31,10 @@
import android.view.textclassifier.TextLinks;
import android.view.textclassifier.TextSelection;
import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger;
import com.android.textclassifier.utils.IndentingPrintWriter;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
@@ -37,44 +44,56 @@
import java.io.FileDescriptor;
import java.io.PrintWriter;
import java.util.concurrent.Callable;
+import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
/** An implementation of a TextClassifierService. */
public final class DefaultTextClassifierService extends TextClassifierService {
private static final String TAG = "default_tcs";
+ private final Injector injector;
// TODO: Figure out do we need more concurrency.
- private final ListeningExecutorService normPriorityExecutor =
- MoreExecutors.listeningDecorator(
- Executors.newFixedThreadPool(
- /* nThreads= */ 2,
- new ThreadFactoryBuilder()
- .setNameFormat("tcs-norm-prio-executor")
- .setPriority(Thread.NORM_PRIORITY)
- .build()));
-
- private final ListeningExecutorService lowPriorityExecutor =
- MoreExecutors.listeningDecorator(
- Executors.newSingleThreadExecutor(
- new ThreadFactoryBuilder()
- .setNameFormat("tcs-low-prio-executor")
- .setPriority(Thread.NORM_PRIORITY - 1)
- .build()));
-
+ private ListeningExecutorService normPriorityExecutor;
+ private ListeningExecutorService lowPriorityExecutor;
private TextClassifierImpl textClassifier;
+ private TextClassifierSettings settings;
+ private ModelFileManager modelFileManager;
+ private BroadcastReceiver localeChangedReceiver;
+
+ public DefaultTextClassifierService() {
+ this.injector = new InjectorImpl(this);
+ }
+
+ @VisibleForTesting
+ DefaultTextClassifierService(Injector injector) {
+ this.injector = Preconditions.checkNotNull(injector);
+ }
+
+ private TextClassifierApiUsageLogger textClassifierApiUsageLogger;
@Override
public void onCreate() {
super.onCreate();
- TextClassifierSettings settings = new TextClassifierSettings();
- ModelFileManager modelFileManager = new ModelFileManager(this, settings);
- textClassifier = new TextClassifierImpl(this, settings, modelFileManager);
+ settings = injector.createTextClassifierSettings();
+ modelFileManager = injector.createModelFileManager(settings);
+ normPriorityExecutor = injector.createNormPriorityExecutor();
+ lowPriorityExecutor = injector.createLowPriorityExecutor();
+ textClassifier = injector.createTextClassifierImpl(settings, modelFileManager);
+ localeChangedReceiver = new LocaleChangedReceiver(modelFileManager);
+
+ textClassifierApiUsageLogger =
+ injector.createTextClassifierApiUsageLogger(settings, lowPriorityExecutor);
+
+ injector
+ .getContext()
+ .registerReceiver(localeChangedReceiver, new IntentFilter(Intent.ACTION_LOCALE_CHANGED));
}
@Override
public void onDestroy() {
super.onDestroy();
+ injector.getContext().unregisterReceiver(localeChangedReceiver);
}
@Override
@@ -84,7 +103,11 @@
CancellationSignal cancellationSignal,
Callback<TextSelection> callback) {
handleRequestAsync(
- () -> textClassifier.suggestSelection(request), callback, cancellationSignal);
+ () -> textClassifier.suggestSelection(request),
+ callback,
+ textClassifierApiUsageLogger.createSession(
+ TextClassifierApiUsageLogger.API_TYPE_SUGGEST_SELECTION, sessionId),
+ cancellationSignal);
}
@Override
@@ -93,7 +116,12 @@
TextClassification.Request request,
CancellationSignal cancellationSignal,
Callback<TextClassification> callback) {
- handleRequestAsync(() -> textClassifier.classifyText(request), callback, cancellationSignal);
+ handleRequestAsync(
+ () -> textClassifier.classifyText(request),
+ callback,
+ textClassifierApiUsageLogger.createSession(
+ TextClassifierApiUsageLogger.API_TYPE_CLASSIFY_TEXT, sessionId),
+ cancellationSignal);
}
@Override
@@ -102,7 +130,12 @@
TextLinks.Request request,
CancellationSignal cancellationSignal,
Callback<TextLinks> callback) {
- handleRequestAsync(() -> textClassifier.generateLinks(request), callback, cancellationSignal);
+ handleRequestAsync(
+ () -> textClassifier.generateLinks(request),
+ callback,
+ textClassifierApiUsageLogger.createSession(
+ TextClassifierApiUsageLogger.API_TYPE_GENERATE_LINKS, sessionId),
+ cancellationSignal);
}
@Override
@@ -112,7 +145,11 @@
CancellationSignal cancellationSignal,
Callback<ConversationActions> callback) {
handleRequestAsync(
- () -> textClassifier.suggestConversationActions(request), callback, cancellationSignal);
+ () -> textClassifier.suggestConversationActions(request),
+ callback,
+ textClassifierApiUsageLogger.createSession(
+ TextClassifierApiUsageLogger.API_TYPE_SUGGEST_CONVERSATION_ACTIONS, sessionId),
+ cancellationSignal);
}
@Override
@@ -121,7 +158,12 @@
TextLanguage.Request request,
CancellationSignal cancellationSignal,
Callback<TextLanguage> callback) {
- handleRequestAsync(() -> textClassifier.detectLanguage(request), callback, cancellationSignal);
+ handleRequestAsync(
+ () -> textClassifier.detectLanguage(request),
+ callback,
+ textClassifierApiUsageLogger.createSession(
+ TextClassifierApiUsageLogger.API_TYPE_DETECT_LANGUAGES, sessionId),
+ cancellationSignal);
}
@Override
@@ -143,7 +185,10 @@
}
private <T> void handleRequestAsync(
- Callable<T> callable, Callback<T> callback, CancellationSignal cancellationSignal) {
+ Callable<T> callable,
+ Callback<T> callback,
+ TextClassifierApiUsageLogger.Session apiLoggerSession,
+ CancellationSignal cancellationSignal) {
ListenableFuture<T> result = normPriorityExecutor.submit(callable);
Futures.addCallback(
result,
@@ -151,12 +196,14 @@
@Override
public void onSuccess(T result) {
callback.onSuccess(result);
+ apiLoggerSession.reportSuccess();
}
@Override
public void onFailure(Throwable t) {
TcLog.e(TAG, "onFailure: ", t);
callback.onFailure(t.getMessage());
+ apiLoggerSession.reportFailure();
}
},
MoreExecutors.directExecutor());
@@ -183,4 +230,101 @@
},
MoreExecutors.directExecutor());
}
+
+ /**
+ * Receiver listening to locale change event. Ask ModelFileManager to do clean-up upon receiving.
+ */
+ static class LocaleChangedReceiver extends BroadcastReceiver {
+ private final ModelFileManager modelFileManager;
+
+ LocaleChangedReceiver(ModelFileManager modelFileManager) {
+ this.modelFileManager = modelFileManager;
+ }
+
+ @Override
+ public void onReceive(Context context, Intent intent) {
+ modelFileManager.deleteUnusedModelFiles();
+ }
+ }
+
+ // Do not call any of these methods, except the constructor, before Service.onCreate is called.
+ private static class InjectorImpl implements Injector {
+ // Do not access the context object before Service.onCreate is invoked.
+ private final Context context;
+
+ private InjectorImpl(Context context) {
+ this.context = Preconditions.checkNotNull(context);
+ }
+
+ @Override
+ public Context getContext() {
+ return context;
+ }
+
+ @Override
+ public ModelFileManager createModelFileManager(TextClassifierSettings settings) {
+ return new ModelFileManager(context, settings);
+ }
+
+ @Override
+ public TextClassifierSettings createTextClassifierSettings() {
+ return new TextClassifierSettings();
+ }
+
+ @Override
+ public TextClassifierImpl createTextClassifierImpl(
+ TextClassifierSettings settings, ModelFileManager modelFileManager) {
+ return new TextClassifierImpl(context, settings, modelFileManager);
+ }
+
+ @Override
+ public ListeningExecutorService createNormPriorityExecutor() {
+ return MoreExecutors.listeningDecorator(
+ Executors.newFixedThreadPool(
+ /* nThreads= */ 2,
+ new ThreadFactoryBuilder()
+ .setNameFormat("tcs-norm-prio-executor")
+ .setPriority(Thread.NORM_PRIORITY)
+ .build()));
+ }
+
+ @Override
+ public ListeningExecutorService createLowPriorityExecutor() {
+ return MoreExecutors.listeningDecorator(
+ Executors.newSingleThreadExecutor(
+ new ThreadFactoryBuilder()
+ .setNameFormat("tcs-low-prio-executor")
+ .setPriority(Thread.NORM_PRIORITY - 1)
+ .build()));
+ }
+
+ @Override
+ public TextClassifierApiUsageLogger createTextClassifierApiUsageLogger(
+ TextClassifierSettings settings, Executor executor) {
+ return new TextClassifierApiUsageLogger(
+ settings::getTextClassifierApiLogSampleRate, executor);
+ }
+ }
+
+ /**
+ * Provides dependencies to the {@link DefaultTextClassifierService}. This makes the service
+ * class testable.
+ */
+ interface Injector {
+ Context getContext();
+
+ ModelFileManager createModelFileManager(TextClassifierSettings settings);
+
+ TextClassifierSettings createTextClassifierSettings();
+
+ TextClassifierImpl createTextClassifierImpl(
+ TextClassifierSettings settings, ModelFileManager modelFileManager);
+
+ ListeningExecutorService createNormPriorityExecutor();
+
+ ListeningExecutorService createLowPriorityExecutor();
+
+ TextClassifierApiUsageLogger createTextClassifierApiUsageLogger(
+ TextClassifierSettings settings, Executor executor);
+ }
}
diff --git a/java/src/com/android/textclassifier/ModelFileManager.java b/java/src/com/android/textclassifier/ModelFileManager.java
index 0552ad2..9bc31fb 100644
--- a/java/src/com/android/textclassifier/ModelFileManager.java
+++ b/java/src/com/android/textclassifier/ModelFileManager.java
@@ -17,13 +17,15 @@
package com.android.textclassifier;
import android.content.Context;
+import android.content.res.AssetFileDescriptor;
+import android.content.res.AssetManager;
import android.os.LocaleList;
import android.os.ParcelFileDescriptor;
-import android.text.TextUtils;
+import android.util.ArraySet;
import androidx.annotation.GuardedBy;
import androidx.annotation.StringDef;
-import com.android.textclassifier.ModelFileManager.ModelFile;
-import com.android.textclassifier.ModelFileManager.ModelFile.ModelType;
+import androidx.collection.ArrayMap;
+import com.android.textclassifier.ModelFileManager.ModelType.ModelTypeDef;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
import com.android.textclassifier.utils.IndentingPrintWriter;
@@ -31,63 +33,104 @@
import com.google.android.textclassifier.AnnotatorModel;
import com.google.android.textclassifier.LangIdModel;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Function;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
-import com.google.common.base.Splitter;
+import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
import java.io.File;
-import java.io.FileNotFoundException;
import java.io.IOException;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
-import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
+import java.util.Map;
import java.util.Objects;
-import java.util.function.Function;
-import java.util.function.Supplier;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
+// TODO(licha): Consider making this a singleton class
+// TODO(licha): Check whether this is thread-safe
/**
* Manages all model files in storage. {@link TextClassifierImpl} depends on this class to get the
* model files to load.
*/
final class ModelFileManager {
+
private static final String TAG = "ModelFileManager";
+
private static final String DOWNLOAD_SUB_DIR_NAME = "textclassifier/downloads/models/";
+ private static final File CONFIG_UPDATER_DIR = new File("/data/misc/textclassifier/");
+ private static final String ASSETS_DIR = "textclassifier";
- private final File downloadModelDir;
- private final ImmutableMap<String, Supplier<ImmutableList<ModelFile>>> modelFileSuppliers;
+ private final List<ModelFileLister> modelFileListers;
+ private final File modelDownloaderDir;
- /** Create a ModelFileManager based on hardcoded model file locations. */
public ModelFileManager(Context context, TextClassifierSettings settings) {
Preconditions.checkNotNull(context);
Preconditions.checkNotNull(settings);
- this.downloadModelDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
- if (!downloadModelDir.exists()) {
- downloadModelDir.mkdirs();
- }
- ImmutableMap.Builder<String, Supplier<ImmutableList<ModelFile>>> suppliersBuilder =
- ImmutableMap.builder();
- for (String modelType : ModelType.values()) {
- suppliersBuilder.put(
- modelType, new ModelFileSupplierImpl(settings, modelType, downloadModelDir));
- }
- this.modelFileSuppliers = suppliersBuilder.build();
+ AssetManager assetManager = context.getAssets();
+ this.modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
+ modelFileListers =
+ ImmutableList.of(
+ // Annotator models.
+ new RegularFilePatternMatchLister(
+ ModelType.ANNOTATOR,
+ this.modelDownloaderDir,
+ "annotator\\.(.*)\\.model",
+ settings::isModelDownloadManagerEnabled),
+ new RegularFileFullMatchLister(
+ ModelType.ANNOTATOR,
+ new File(CONFIG_UPDATER_DIR, "textclassifier.model"),
+ /* isEnabled= */ () -> true),
+ new AssetFilePatternMatchLister(
+ assetManager,
+ ModelType.ANNOTATOR,
+ ASSETS_DIR,
+ "annotator\\.(.*)\\.model",
+ /* isEnabled= */ () -> true),
+ // Actions models.
+ new RegularFilePatternMatchLister(
+ ModelType.ACTIONS_SUGGESTIONS,
+ this.modelDownloaderDir,
+ "actions_suggestions\\.(.*)\\.model",
+ settings::isModelDownloadManagerEnabled),
+ new RegularFileFullMatchLister(
+ ModelType.ACTIONS_SUGGESTIONS,
+ new File(CONFIG_UPDATER_DIR, "actions_suggestions.model"),
+ /* isEnabled= */ () -> true),
+ new AssetFilePatternMatchLister(
+ assetManager,
+ ModelType.ACTIONS_SUGGESTIONS,
+ ASSETS_DIR,
+ "actions_suggestions\\.(.*)\\.model",
+ /* isEnabled= */ () -> true),
+ // LangID models.
+ new RegularFilePatternMatchLister(
+ ModelType.LANG_ID,
+ this.modelDownloaderDir,
+ "lang_id\\.(.*)\\.model",
+ settings::isModelDownloadManagerEnabled),
+ new RegularFileFullMatchLister(
+ ModelType.LANG_ID,
+ new File(CONFIG_UPDATER_DIR, "lang_id.model"),
+ /* isEnabled= */ () -> true),
+ new AssetFilePatternMatchLister(
+ assetManager,
+ ModelType.LANG_ID,
+ ASSETS_DIR,
+ "lang_id.model",
+ /* isEnabled= */ () -> true));
}
@VisibleForTesting
- ModelFileManager(
- File downloadModelDir,
- ImmutableMap<String, Supplier<ImmutableList<ModelFile>>> modelFileSuppliers) {
- this.downloadModelDir = Preconditions.checkNotNull(downloadModelDir);
- this.modelFileSuppliers = Preconditions.checkNotNull(modelFileSuppliers);
+ ModelFileManager(Context context, List<ModelFileLister> modelFileListers) {
+ this.modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
+ this.modelFileListers = ImmutableList.copyOf(modelFileListers);
}
/**
@@ -95,27 +138,203 @@
*
* @param modelType which type of model files to look for
*/
- public ImmutableList<ModelFile> listModelFiles(@ModelType.ModelTypeDef String modelType) {
- if (modelFileSuppliers.containsKey(modelType)) {
- return modelFileSuppliers.get(modelType).get();
+ public ImmutableList<ModelFile> listModelFiles(@ModelTypeDef String modelType) {
+ Preconditions.checkNotNull(modelType);
+
+ ImmutableList.Builder<ModelFile> modelFiles = new ImmutableList.Builder<>();
+ for (ModelFileLister modelFileLister : modelFileListers) {
+ modelFiles.addAll(modelFileLister.list(modelType));
}
- return ImmutableList.of();
+ return modelFiles.build();
+ }
+
+ /** Lists model files. */
+ public interface ModelFileLister {
+ List<ModelFile> list(@ModelTypeDef String modelType);
+ }
+
+ /** Lists model files by performing full match on file path. */
+ public static class RegularFileFullMatchLister implements ModelFileLister {
+ private final String modelType;
+ private final File targetFile;
+ private final Supplier<Boolean> isEnabled;
+
+ /**
+ * @param modelType the type of the model
+ * @param targetFile the expected model file
+ * @param isEnabled whether this lister is enabled
+ */
+ public RegularFileFullMatchLister(
+ @ModelTypeDef String modelType, File targetFile, Supplier<Boolean> isEnabled) {
+ this.modelType = Preconditions.checkNotNull(modelType);
+ this.targetFile = Preconditions.checkNotNull(targetFile);
+ this.isEnabled = Preconditions.checkNotNull(isEnabled);
+ }
+
+ @Override
+ public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
+ if (!this.modelType.equals(modelType)) {
+ return ImmutableList.of();
+ }
+ if (!isEnabled.get()) {
+ return ImmutableList.of();
+ }
+ if (!targetFile.exists()) {
+ return ImmutableList.of();
+ }
+ try {
+ return ImmutableList.of(ModelFile.createFromRegularFile(targetFile, modelType));
+ } catch (IOException e) {
+ TcLog.e(
+ TAG, "Failed to call createFromRegularFile with: " + targetFile.getAbsolutePath(), e);
+ }
+ return ImmutableList.of();
+ }
+ }
+
+ /** Lists model file in a specified folder by doing pattern matching on file names. */
+ public static class RegularFilePatternMatchLister implements ModelFileLister {
+ private final String modelType;
+ private final File folder;
+ private final Pattern fileNamePattern;
+ private final Supplier<Boolean> isEnabled;
+
+ /**
+ * @param modelType the type of the model
+ * @param folder the folder to list files
+ * @param fileNameRegex the regex to match the file name in the specified folder
+ * @param isEnabled whether the lister is enabled
+ */
+ public RegularFilePatternMatchLister(
+ @ModelTypeDef String modelType,
+ File folder,
+ String fileNameRegex,
+ Supplier<Boolean> isEnabled) {
+ this.modelType = Preconditions.checkNotNull(modelType);
+ this.folder = Preconditions.checkNotNull(folder);
+ this.fileNamePattern = Pattern.compile(Preconditions.checkNotNull(fileNameRegex));
+ this.isEnabled = Preconditions.checkNotNull(isEnabled);
+ }
+
+ @Override
+ public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
+ if (!this.modelType.equals(modelType)) {
+ return ImmutableList.of();
+ }
+ if (!isEnabled.get()) {
+ return ImmutableList.of();
+ }
+ if (!folder.isDirectory()) {
+ return ImmutableList.of();
+ }
+ File[] files = folder.listFiles();
+ if (files == null) {
+ return ImmutableList.of();
+ }
+ ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
+ for (File file : files) {
+ final Matcher matcher = fileNamePattern.matcher(file.getName());
+ if (!matcher.matches() || !file.isFile()) {
+ continue;
+ }
+ try {
+ modelFilesBuilder.add(ModelFile.createFromRegularFile(file, modelType));
+ } catch (IOException e) {
+ TcLog.w(TAG, "Failed to call createFromRegularFile with: " + file.getAbsolutePath());
+ }
+ }
+ return modelFilesBuilder.build();
+ }
+ }
+
+ /** Lists the model files preloaded in the APK file. */
+ public static class AssetFilePatternMatchLister implements ModelFileLister {
+ private final AssetManager assetManager;
+ private final String modelType;
+ private final String pathToList;
+ private final Pattern fileNamePattern;
+ private final Supplier<Boolean> isEnabled;
+ private final Object lock = new Object();
+ // Assets won't change without updating the app, so cache the result for performance reasons.
+ @GuardedBy("lock")
+ private final Map<String, ImmutableList<ModelFile>> resultCache;
+
+ /**
+ * @param modelType the type of the model.
+ * @param pathToList the folder to list files
+ * @param fileNameRegex the regex to match the file name in the specified folder
+ * @param isEnabled whether this lister is enabled
+ */
+ public AssetFilePatternMatchLister(
+ AssetManager assetManager,
+ @ModelTypeDef String modelType,
+ String pathToList,
+ String fileNameRegex,
+ Supplier<Boolean> isEnabled) {
+ this.assetManager = Preconditions.checkNotNull(assetManager);
+ this.modelType = Preconditions.checkNotNull(modelType);
+ this.pathToList = Preconditions.checkNotNull(pathToList);
+ this.fileNamePattern = Pattern.compile(Preconditions.checkNotNull(fileNameRegex));
+ this.isEnabled = Preconditions.checkNotNull(isEnabled);
+ resultCache = new ArrayMap<>();
+ }
+
+ @Override
+ public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
+ if (!this.modelType.equals(modelType)) {
+ return ImmutableList.of();
+ }
+ if (!isEnabled.get()) {
+ return ImmutableList.of();
+ }
+ synchronized (lock) {
+ if (resultCache.get(modelType) != null) {
+ return resultCache.get(modelType);
+ }
+ String[] fileNames = null;
+ try {
+ fileNames = assetManager.list(pathToList);
+ } catch (IOException e) {
+ TcLog.e(TAG, "Failed to list assets", e);
+ }
+ if (fileNames == null) {
+ return ImmutableList.of();
+ }
+ ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
+ for (String fileName : fileNames) {
+ final Matcher matcher = fileNamePattern.matcher(fileName);
+ if (!matcher.matches()) {
+ continue;
+ }
+ String absolutePath =
+ new StringBuilder(pathToList).append('/').append(fileName).toString();
+ try {
+ modelFilesBuilder.add(ModelFile.createFromAsset(assetManager, absolutePath, modelType));
+ } catch (IOException e) {
+ TcLog.w(TAG, "Failed to call createFromAsset with: " + absolutePath);
+ }
+ }
+ ImmutableList<ModelFile> result = modelFilesBuilder.build();
+ resultCache.put(modelType, result);
+ return result;
+ }
+ }
}
/**
* Returns the best model file for the given localelist, {@code null} if nothing is found.
*
* @param modelType the type of model to look up (e.g. annotator, lang_id, etc.)
- * @param localeList an ordered list of user preferences for locales, use {@code null} if there is
- * no preference.
+ * @param localePreferences an ordered list of user preferences for locales, use {@code null} if
+ * there is no preference.
*/
@Nullable
public ModelFile findBestModelFile(
- @ModelType.ModelTypeDef String modelType, @Nullable LocaleList localeList) {
+ @ModelTypeDef String modelType, @Nullable LocaleList localePreferences) {
final String languages =
- localeList == null || localeList.isEmpty()
+ localePreferences == null || localePreferences.isEmpty()
? LocaleList.getDefault().toLanguageTags()
- : localeList.toLanguageTags();
+ : localePreferences.toLanguageTags();
final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
ModelFile bestModel = null;
@@ -130,18 +349,46 @@
}
/**
+ * Deletes model files that are not preferred for any locales in user's preference.
+ *
+ * <p>This method will be invoked as a clean-up after we download a new model successfully. Race
+ * conditions are hard to avoid because we do not hold locks for files. But it should rarely cause
+ * any issues since it's safe to delete a model file in use (b/c we mmap it to memory).
+ */
+ public void deleteUnusedModelFiles() {
+ TcLog.d(TAG, "Start to delete unused model files.");
+ LocaleList localeList = LocaleList.getDefault();
+ for (@ModelTypeDef String modelType : ModelType.values()) {
+ ArraySet<ModelFile> allModelFiles = new ArraySet<>(listModelFiles(modelType));
+ for (int i = 0; i < localeList.size(); i++) {
+ // If a model file is preferred for any locale in the locale list, then keep it
+ ModelFile bestModel = findBestModelFile(modelType, new LocaleList(localeList.get(i)));
+ allModelFiles.remove(bestModel);
+ }
+ for (ModelFile modelFile : allModelFiles) {
+ if (modelFile.canWrite()) {
+ TcLog.d(TAG, "Deleting model: " + modelFile);
+ if (!modelFile.delete()) {
+ TcLog.w(TAG, "Failed to delete model: " + modelFile);
+ }
+ }
+ }
+ }
+ }
+
+ /**
* Returns a {@link File} that represents the destination to download a model.
*
- * <p>Each model file's name is uniquely formatted based on its unique remote URL address.
+ * <p>Each model file's name is uniquely formatted based on its unique remote manifest URL.
*
* <p>{@link ModelDownloadManager} needs to call this to get the right location and file name.
*
* @param modelType the type of the model image to download
- * @param url the unique remote url of the model image
+ * @param manifestUrl the unique remote url of the model manifest
*/
- public File getDownloadTargetFile(@ModelType.ModelTypeDef String modelType, String url) {
- String fileName = String.format("%s.%d.model", modelType, url.hashCode());
- return new File(downloadModelDir, fileName);
+ public File getDownloadTargetFile(@ModelTypeDef String modelType, String manifestUrl) {
+ String fileName = String.format("%s.%d.model", modelType, manifestUrl.hashCode());
+ return new File(modelDownloaderDir, fileName);
}
/**
@@ -152,7 +399,7 @@
public void dump(IndentingPrintWriter printWriter) {
printWriter.println("ModelFileManager:");
printWriter.increaseIndent();
- for (@ModelType.ModelTypeDef String modelType : ModelType.values()) {
+ for (@ModelTypeDef String modelType : ModelType.values()) {
printWriter.println(modelType + " model file(s):");
printWriter.increaseIndent();
for (ModelFile modelFile : listModelFiles(modelType)) {
@@ -163,260 +410,102 @@
printWriter.decreaseIndent();
}
- /** Default implementation of the model file supplier. */
- @VisibleForTesting
- static final class ModelFileSupplierImpl implements Supplier<ImmutableList<ModelFile>> {
- private static final String FACTORY_MODEL_DIR = "/etc/textclassifier/";
+ /** Fetch metadata of a model file. */
+ private static class ModelInfoFetcher {
+ private final Function<AssetFileDescriptor, Integer> versionFetcher;
+ private final Function<AssetFileDescriptor, String> supportedLocalesFetcher;
- private static final class ModelFileInfo {
- private final String modelNameRegex;
- private final String configUpdaterModelPath;
- private final Function<Integer, Integer> versionSupplier;
- private final Function<Integer, String> supportedLocalesSupplier;
-
- public ModelFileInfo(
- String modelNameRegex,
- String configUpdaterModelPath,
- Function<Integer, Integer> versionSupplier,
- Function<Integer, String> supportedLocalesSupplier) {
- this.modelNameRegex = Preconditions.checkNotNull(modelNameRegex);
- this.configUpdaterModelPath = Preconditions.checkNotNull(configUpdaterModelPath);
- this.versionSupplier = Preconditions.checkNotNull(versionSupplier);
- this.supportedLocalesSupplier = Preconditions.checkNotNull(supportedLocalesSupplier);
- }
-
- public String getModelNameRegex() {
- return modelNameRegex;
- }
-
- public String getConfigUpdaterModelPath() {
- return configUpdaterModelPath;
- }
-
- public Function<Integer, Integer> getVersionSupplier() {
- return versionSupplier;
- }
-
- public Function<Integer, String> getSupportedLocalesSupplier() {
- return supportedLocalesSupplier;
- }
+ private ModelInfoFetcher(
+ Function<AssetFileDescriptor, Integer> versionFetcher,
+ Function<AssetFileDescriptor, String> supportedLocalesFetcher) {
+ this.versionFetcher = versionFetcher;
+ this.supportedLocalesFetcher = supportedLocalesFetcher;
}
- private static final ImmutableMap<String, ModelFileInfo> MODEL_FILE_INFO_MAP =
- ImmutableMap.<String, ModelFileInfo>builder()
- .put(
- ModelType.ANNOTATOR,
- new ModelFileInfo(
- "(annotator|textclassifier)\\.(.*)\\.model",
- "/data/misc/textclassifier/textclassifier.model",
- AnnotatorModel::getVersion,
- AnnotatorModel::getLocales))
- .put(
- ModelType.LANG_ID,
- new ModelFileInfo(
- "lang_id.model",
- "/data/misc/textclassifier/lang_id.model",
- LangIdModel::getVersion,
- fd -> ModelFile.LANGUAGE_INDEPENDENT))
- .put(
- ModelType.ACTIONS_SUGGESTIONS,
- new ModelFileInfo(
- "actions_suggestions\\.(.*)\\.model",
- "/data/misc/textclassifier/actions_suggestions.model",
- ActionsSuggestionsModel::getVersion,
- ActionsSuggestionsModel::getLocales))
- .build();
-
- private final TextClassifierSettings settings;
- @ModelType.ModelTypeDef private final String modelType;
- private final File configUpdaterModelFile;
- private final File downloaderModelDir;
- private final File factoryModelDir;
- private final Pattern modelFilenamePattern;
- private final Function<Integer, Integer> versionSupplier;
- private final Function<Integer, String> supportedLocalesSupplier;
- private final Object lock = new Object();
-
- @GuardedBy("lock")
- private ImmutableList<ModelFile> factoryModels;
-
- public ModelFileSupplierImpl(
- TextClassifierSettings settings,
- @ModelType.ModelTypeDef String modelType,
- File downloaderModelDir) {
- this(
- settings,
- modelType,
- new File(FACTORY_MODEL_DIR),
- MODEL_FILE_INFO_MAP.get(modelType).getModelNameRegex(),
- new File(MODEL_FILE_INFO_MAP.get(modelType).getConfigUpdaterModelPath()),
- downloaderModelDir,
- MODEL_FILE_INFO_MAP.get(modelType).getVersionSupplier(),
- MODEL_FILE_INFO_MAP.get(modelType).getSupportedLocalesSupplier());
+ int getVersion(AssetFileDescriptor assetFileDescriptor) {
+ return versionFetcher.apply(assetFileDescriptor);
}
- @VisibleForTesting
- ModelFileSupplierImpl(
- TextClassifierSettings settings,
- @ModelType.ModelTypeDef String modelType,
- File factoryModelDir,
- String modelFileNameRegex,
- File configUpdaterModelFile,
- File downloaderModelDir,
- Function<Integer, Integer> versionSupplier,
- Function<Integer, String> supportedLocalesSupplier) {
- this.settings = Preconditions.checkNotNull(settings);
- this.modelType = Preconditions.checkNotNull(modelType);
- this.factoryModelDir = Preconditions.checkNotNull(factoryModelDir);
- this.modelFilenamePattern = Pattern.compile(Preconditions.checkNotNull(modelFileNameRegex));
- this.configUpdaterModelFile = Preconditions.checkNotNull(configUpdaterModelFile);
- this.downloaderModelDir = Preconditions.checkNotNull(downloaderModelDir);
- this.versionSupplier = Preconditions.checkNotNull(versionSupplier);
- this.supportedLocalesSupplier = Preconditions.checkNotNull(supportedLocalesSupplier);
+ String getSupportedLocales(AssetFileDescriptor assetFileDescriptor) {
+ return supportedLocalesFetcher.apply(assetFileDescriptor);
}
- @Override
- public ImmutableList<ModelFile> get() {
- final List<ModelFile> modelFiles = new ArrayList<>();
- // The dwonloader and config updater model have higher precedences.
- if (downloaderModelDir.exists() && settings.isModelDownloadManagerEnabled()) {
- modelFiles.addAll(getMatchedModelFiles(downloaderModelDir));
+ static ModelInfoFetcher create(@ModelTypeDef String modelType) {
+ switch (modelType) {
+ case ModelType.ANNOTATOR:
+ return new ModelInfoFetcher(AnnotatorModel::getVersion, AnnotatorModel::getLocales);
+ case ModelType.ACTIONS_SUGGESTIONS:
+ return new ModelInfoFetcher(
+ ActionsSuggestionsModel::getVersion, ActionsSuggestionsModel::getLocales);
+ case ModelType.LANG_ID:
+ return new ModelInfoFetcher(
+ LangIdModel::getVersion, afd -> ModelFile.LANGUAGE_INDEPENDENT);
+ default: // fall out
}
- if (configUpdaterModelFile.exists()) {
- final ModelFile updatedModel = createModelFile(configUpdaterModelFile);
- if (updatedModel != null) {
- modelFiles.add(updatedModel);
- }
- }
- // Factory models should never have overlapping locales, so the order doesn't matter.
- synchronized (lock) {
- if (factoryModels == null) {
- factoryModels = getMatchedModelFiles(factoryModelDir);
- }
- modelFiles.addAll(factoryModels);
- }
- return ImmutableList.copyOf(modelFiles);
- }
-
- private ImmutableList<ModelFile> getMatchedModelFiles(File parentDir) {
- ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
- if (parentDir.exists() && parentDir.isDirectory()) {
- final File[] files = parentDir.listFiles();
- for (File file : files) {
- final Matcher matcher = modelFilenamePattern.matcher(file.getName());
- if (matcher.matches() && file.isFile()) {
- final ModelFile model = createModelFile(file);
- if (model != null) {
- modelFilesBuilder.add(model);
- }
- }
- }
- }
- return modelFilesBuilder.build();
- }
-
- /** Returns null if the path did not point to a compatible model. */
- @Nullable
- private ModelFile createModelFile(File file) {
- if (!file.exists()) {
- return null;
- }
- ParcelFileDescriptor modelFd = null;
- try {
- modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
- if (modelFd == null) {
- return null;
- }
- final int modelFdInt = modelFd.getFd();
- final int version = versionSupplier.apply(modelFdInt);
- final String supportedLocalesStr = supportedLocalesSupplier.apply(modelFdInt);
- if (supportedLocalesStr.isEmpty()) {
- TcLog.d(TAG, "Ignoring " + file.getAbsolutePath());
- return null;
- }
- final List<Locale> supportedLocales = new ArrayList<>();
- for (String langTag : Splitter.on(',').split(supportedLocalesStr)) {
- supportedLocales.add(Locale.forLanguageTag(langTag));
- }
- return new ModelFile(
- modelType,
- file,
- version,
- supportedLocales,
- supportedLocalesStr,
- ModelFile.LANGUAGE_INDEPENDENT.equals(supportedLocalesStr));
- } catch (FileNotFoundException e) {
- TcLog.e(TAG, "Failed to find " + file.getAbsolutePath(), e);
- return null;
- } finally {
- maybeCloseAndLogError(modelFd);
- }
- }
-
- /** Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur. */
- private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
- if (fd == null) {
- return;
- }
- try {
- fd.close();
- } catch (IOException e) {
- TcLog.e(TAG, "Error closing file.", e);
- }
+ throw new IllegalStateException("Unsupported model types");
}
}
/** Describes TextClassifier model files on disk. */
- public static final class ModelFile {
- public static final String LANGUAGE_INDEPENDENT = "*";
+ public static class ModelFile {
+ @VisibleForTesting static final String LANGUAGE_INDEPENDENT = "*";
- @ModelType.ModelTypeDef private final String modelType;
- private final File file;
- private final int version;
- private final List<Locale> supportedLocales;
- private final String supportedLocalesStr;
- private final boolean languageIndependent;
+ @ModelTypeDef public final String modelType;
+ public final String absolutePath;
+ public final int version;
+ public final LocaleList supportedLocales;
+ public final boolean languageIndependent;
+ public final boolean isAsset;
- public ModelFile(
- @ModelType.ModelTypeDef String modelType,
- File file,
+ public static ModelFile createFromRegularFile(File file, @ModelTypeDef String modelType)
+ throws IOException {
+ ParcelFileDescriptor pfd =
+ ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
+ try (AssetFileDescriptor afd = new AssetFileDescriptor(pfd, 0, file.length())) {
+ return createFromAssetFileDescriptor(
+ file.getAbsolutePath(), modelType, afd, /* isAsset= */ false);
+ }
+ }
+
+ public static ModelFile createFromAsset(
+ AssetManager assetManager, String absolutePath, @ModelTypeDef String modelType)
+ throws IOException {
+ try (AssetFileDescriptor assetFileDescriptor = assetManager.openFd(absolutePath)) {
+ return createFromAssetFileDescriptor(
+ absolutePath, modelType, assetFileDescriptor, /* isAsset= */ true);
+ }
+ }
+
+ private static ModelFile createFromAssetFileDescriptor(
+ String absolutePath,
+ @ModelTypeDef String modelType,
+ AssetFileDescriptor assetFileDescriptor,
+ boolean isAsset) {
+ ModelInfoFetcher modelInfoFetcher = ModelInfoFetcher.create(modelType);
+ return new ModelFile(
+ modelType,
+ absolutePath,
+ modelInfoFetcher.getVersion(assetFileDescriptor),
+ modelInfoFetcher.getSupportedLocales(assetFileDescriptor),
+ isAsset);
+ }
+
+ @VisibleForTesting
+ ModelFile(
+ @ModelTypeDef String modelType,
+ String absolutePath,
int version,
- List<Locale> supportedLocales,
- String supportedLocalesStr,
- boolean languageIndependent) {
- this.modelType = Preconditions.checkNotNull(modelType);
- this.file = Preconditions.checkNotNull(file);
+ String supportedLocaleTags,
+ boolean isAsset) {
+ this.modelType = modelType;
+ this.absolutePath = absolutePath;
this.version = version;
- this.supportedLocales = Preconditions.checkNotNull(supportedLocales);
- this.supportedLocalesStr = Preconditions.checkNotNull(supportedLocalesStr);
- this.languageIndependent = languageIndependent;
- }
-
- /** Returns the type of this model, defined in {@link ModelType}. */
- @ModelType.ModelTypeDef
- public String getModelType() {
- return modelType;
- }
-
- /** Returns the absolute path to the model file. */
- public String getPath() {
- return file.getAbsolutePath();
- }
-
- /** Returns a name to use for id generation, effectively the name of the model file. */
- public String getName() {
- return file.getName();
- }
-
- /** Returns the version tag in the model's metadata. */
- public int getVersion() {
- return version;
- }
-
- /** Returns whether the language supports any language in the given ranges. */
- public boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
- Preconditions.checkNotNull(languageRanges);
- return languageIndependent || Locale.lookup(languageRanges, supportedLocales) != null;
+ this.languageIndependent = LANGUAGE_INDEPENDENT.equals(supportedLocaleTags);
+ this.supportedLocales =
+ languageIndependent
+ ? LocaleList.getEmptyLocaleList()
+ : LocaleList.forLanguageTags(supportedLocaleTags);
+ this.isAsset = isAsset;
}
/** Returns if this model file is preferred to the given one. */
@@ -436,70 +525,111 @@
}
// A higher-version model is preferred.
- if (version > model.getVersion()) {
+ if (version > model.version) {
return true;
}
return false;
}
+ /** Returns whether the language supports any language in the given ranges. */
+ public boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
+ Preconditions.checkNotNull(languageRanges);
+ if (languageIndependent) {
+ return true;
+ }
+ List<String> supportedLocaleTags =
+ Arrays.asList(supportedLocales.toLanguageTags().split(","));
+ return Locale.lookupTag(languageRanges, supportedLocaleTags) != null;
+ }
+
+ public AssetFileDescriptor open(AssetManager assetManager) throws IOException {
+ if (isAsset) {
+ return assetManager.openFd(absolutePath);
+ }
+ File file = new File(absolutePath);
+ ParcelFileDescriptor parcelFileDescriptor =
+ ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
+ return new AssetFileDescriptor(parcelFileDescriptor, 0, file.length());
+ }
+
+ public boolean canWrite() {
+ if (isAsset) {
+ return false;
+ }
+ return new File(absolutePath).canWrite();
+ }
+
+ public boolean delete() {
+ if (isAsset) {
+ throw new IllegalStateException("asset is read-only, deleting it is not allowed.");
+ }
+ return new File(absolutePath).delete();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (!(o instanceof ModelFile)) {
+ return false;
+ }
+ ModelFile modelFile = (ModelFile) o;
+ return version == modelFile.version
+ && languageIndependent == modelFile.languageIndependent
+ && isAsset == modelFile.isAsset
+ && Objects.equals(modelType, modelFile.modelType)
+ && Objects.equals(absolutePath, modelFile.absolutePath)
+ && Objects.equals(supportedLocales, modelFile.supportedLocales);
+ }
+
@Override
public int hashCode() {
- return Objects.hash(getPath());
- }
-
- @Override
- public boolean equals(Object other) {
- if (this == other) {
- return true;
- }
- if (other instanceof ModelFile) {
- final ModelFile otherModel = (ModelFile) other;
- return TextUtils.equals(getPath(), otherModel.getPath());
- }
- return false;
+ return Objects.hash(
+ modelType, absolutePath, version, supportedLocales, languageIndependent, isAsset);
}
public ModelInfo toModelInfo() {
- return new ModelInfo(getVersion(), supportedLocalesStr);
+ return new ModelInfo(version, supportedLocales.toLanguageTags());
}
@Override
public String toString() {
return String.format(
Locale.US,
- "ModelFile { type=%s path=%s name=%s version=%d locales=%s }",
+ "ModelFile { type=%s path=%s version=%d locales=%s isAsset=%b}",
modelType,
- getPath(),
- getName(),
+ absolutePath,
version,
- supportedLocalesStr);
+ languageIndependent ? LANGUAGE_INDEPENDENT : supportedLocales.toLanguageTags(),
+ isAsset);
}
public static ImmutableList<Optional<ModelInfo>> toModelInfos(
- Optional<ModelFile>... modelFiles) {
+ Optional<ModelFileManager.ModelFile>... modelFiles) {
return Arrays.stream(modelFiles)
- .map(modelFile -> modelFile.transform(ModelFile::toModelInfo))
+ .map(modelFile -> modelFile.transform(ModelFileManager.ModelFile::toModelInfo))
.collect(Collectors.collectingAndThen(Collectors.toList(), ImmutableList::copyOf));
}
+ }
- /** Effectively an enum class to represent types of models. */
- public static final class ModelType {
- @Retention(RetentionPolicy.SOURCE)
- @StringDef({ANNOTATOR, LANG_ID, ACTIONS_SUGGESTIONS})
- public @interface ModelTypeDef {}
+ /** Effectively an enum class to represent types of models. */
+ public static final class ModelType {
+ @Retention(RetentionPolicy.SOURCE)
+ @StringDef({ANNOTATOR, LANG_ID, ACTIONS_SUGGESTIONS})
+ @interface ModelTypeDef {}
- public static final String ANNOTATOR = "annotator";
- public static final String LANG_ID = "lang_id";
- public static final String ACTIONS_SUGGESTIONS = "actions_suggestions";
+ public static final String ANNOTATOR = "annotator";
+ public static final String LANG_ID = "lang_id";
+ public static final String ACTIONS_SUGGESTIONS = "actions_suggestions";
- public static final ImmutableList<String> VALUES =
- ImmutableList.of(ANNOTATOR, LANG_ID, ACTIONS_SUGGESTIONS);
+ public static final ImmutableList<String> VALUES =
+ ImmutableList.of(ANNOTATOR, LANG_ID, ACTIONS_SUGGESTIONS);
- public static ImmutableList<String> values() {
- return VALUES;
- }
-
- private ModelType() {}
+ public static ImmutableList<String> values() {
+ return VALUES;
}
+
+ private ModelType() {}
}
}
diff --git a/java/src/com/android/textclassifier/TextClassifierImpl.java b/java/src/com/android/textclassifier/TextClassifierImpl.java
index 9a213e4..b824ed0 100644
--- a/java/src/com/android/textclassifier/TextClassifierImpl.java
+++ b/java/src/com/android/textclassifier/TextClassifierImpl.java
@@ -22,11 +22,11 @@
import android.app.RemoteAction;
import android.content.Context;
import android.content.Intent;
+import android.content.res.AssetFileDescriptor;
import android.icu.util.ULocale;
import android.os.Bundle;
import android.os.LocaleList;
import android.os.Looper;
-import android.os.ParcelFileDescriptor;
import android.util.ArrayMap;
import android.view.View.OnClickListener;
import android.view.textclassifier.ConversationAction;
@@ -43,7 +43,7 @@
import androidx.annotation.WorkerThread;
import androidx.core.util.Pair;
import com.android.textclassifier.ModelFileManager.ModelFile;
-import com.android.textclassifier.ModelFileManager.ModelFile.ModelType;
+import com.android.textclassifier.ModelFileManager.ModelType;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.intent.LabeledIntent;
import com.android.textclassifier.common.intent.TemplateIntentFactory;
@@ -62,8 +62,6 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableList;
-import java.io.File;
-import java.io.FileNotFoundException;
import java.io.IOException;
import java.time.ZoneId;
import java.time.ZonedDateTime;
@@ -87,7 +85,6 @@
private final Context context;
private final ModelFileManager modelFileManager;
- private final TextClassifier fallback;
private final GenerateLinksLogger generateLinksLogger;
private final Object lock = new Object();
@@ -118,137 +115,116 @@
private final TemplateIntentFactory templateIntentFactory;
TextClassifierImpl(
- Context context,
- TextClassifierSettings settings,
- ModelFileManager modelFileManager,
- TextClassifier fallback) {
+ Context context, TextClassifierSettings settings, ModelFileManager modelFileManager) {
this.context = Preconditions.checkNotNull(context);
this.settings = Preconditions.checkNotNull(settings);
this.modelFileManager = Preconditions.checkNotNull(modelFileManager);
- this.fallback = Preconditions.checkNotNull(fallback);
generateLinksLogger = new GenerateLinksLogger(this.settings.getGenerateLinksLogSampleRate());
templateIntentFactory = new TemplateIntentFactory();
}
- TextClassifierImpl(
- Context context, TextClassifierSettings settings, ModelFileManager modelFileManager) {
- this(context, settings, modelFileManager, TextClassifier.NO_OP);
+ @WorkerThread
+ TextSelection suggestSelection(TextSelection.Request request) throws IOException {
+ Preconditions.checkNotNull(request);
+ checkMainThread();
+ final int rangeLength = request.getEndIndex() - request.getStartIndex();
+ final String string = request.getText().toString();
+ Preconditions.checkArgument(!string.isEmpty(), "input string should not be empty");
+ Preconditions.checkArgument(
+ rangeLength <= settings.getClassifyTextMaxRangeLength(), "range is too large");
+ final String localesString = concatenateLocales(request.getDefaultLocales());
+ final LangIdModel langIdModel = getLangIdImpl();
+ final String detectLanguageTags =
+ String.join(",", detectLanguageTags(langIdModel, request.getText()));
+ final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault());
+ final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
+ final int[] startEnd =
+ annotatorImpl.suggestSelection(
+ string,
+ request.getStartIndex(),
+ request.getEndIndex(),
+ AnnotatorModel.SelectionOptions.builder()
+ .setLocales(localesString)
+ .setDetectedTextLanguageTags(detectLanguageTags)
+ .build());
+ final int start = startEnd[0];
+ final int end = startEnd[1];
+ if (start >= end
+ || start < 0
+ || start > request.getStartIndex()
+ || end > string.length()
+ || end < request.getEndIndex()) {
+ throw new IllegalArgumentException("Got bad indices for input text. Ignoring result.");
+ }
+ final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
+ final AnnotatorModel.ClassificationResult[] results =
+ annotatorImpl.classifyText(
+ string,
+ start,
+ end,
+ AnnotatorModel.ClassificationOptions.builder()
+ .setReferenceTimeMsUtc(refTime.toInstant().toEpochMilli())
+ .setReferenceTimezone(refTime.getZone().getId())
+ .setLocales(localesString)
+ .setDetectedTextLanguageTags(detectLanguageTags)
+ .setUserFamiliarLanguageTags(LocaleList.getDefault().toLanguageTags())
+ .build(),
+ // Passing null here to suppress intent generation
+ // TODO: Use an explicit flag to suppress it.
+ /* appContext */ null,
+ /* deviceLocales */ null);
+ final int size = results.length;
+ for (int i = 0; i < size; i++) {
+ tsBuilder.setEntityType(results[i].getCollection(), results[i].getScore());
+ }
+ final String resultId =
+ createAnnotatorId(string, request.getStartIndex(), request.getEndIndex());
+ return tsBuilder.setId(resultId).build();
}
@WorkerThread
- TextSelection suggestSelection(TextSelection.Request request) {
+ TextClassification classifyText(TextClassification.Request request) throws IOException {
Preconditions.checkNotNull(request);
checkMainThread();
- try {
- final int rangeLength = request.getEndIndex() - request.getStartIndex();
- final String string = request.getText().toString();
- if (string.length() > 0 && rangeLength <= settings.getSuggestSelectionMaxRangeLength()) {
- final String localesString = concatenateLocales(request.getDefaultLocales());
- final Optional<LangIdModel> langIdModel = getLangIdImpl();
- final String detectLanguageTags =
- String.join(",", detectLanguageTags(langIdModel, request.getText()));
- final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault());
- final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
- final int[] startEnd =
- annotatorImpl.suggestSelection(
+ LangIdModel langId = getLangIdImpl();
+ List<String> detectLanguageTags = detectLanguageTags(langId, request.getText());
+ final int rangeLength = request.getEndIndex() - request.getStartIndex();
+ final String string = request.getText().toString();
+ Preconditions.checkArgument(!string.isEmpty(), "input string should not be empty");
+ Preconditions.checkArgument(
+ rangeLength <= settings.getClassifyTextMaxRangeLength(), "range is too large");
+
+ final String localesString = concatenateLocales(request.getDefaultLocales());
+ final ZonedDateTime refTime =
+ request.getReferenceTime() != null
+ ? request.getReferenceTime()
+ : ZonedDateTime.now(ZoneId.systemDefault());
+ final AnnotatorModel.ClassificationResult[] results =
+ getAnnotatorImpl(request.getDefaultLocales())
+ .classifyText(
string,
request.getStartIndex(),
request.getEndIndex(),
- AnnotatorModel.SelectionOptions.builder()
+ AnnotatorModel.ClassificationOptions.builder()
+ .setReferenceTimeMsUtc(refTime.toInstant().toEpochMilli())
+ .setReferenceTimezone(refTime.getZone().getId())
.setLocales(localesString)
- .setDetectedTextLanguageTags(detectLanguageTags)
- .build());
- final int start = startEnd[0];
- final int end = startEnd[1];
- if (start < end
- && start >= 0
- && end <= string.length()
- && start <= request.getStartIndex()
- && end >= request.getEndIndex()) {
- final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
- final AnnotatorModel.ClassificationResult[] results =
- annotatorImpl.classifyText(
- string,
- start,
- end,
- AnnotatorModel.ClassificationOptions.builder()
- .setReferenceTimeMsUtc(refTime.toInstant().toEpochMilli())
- .setReferenceTimezone(refTime.getZone().getId())
- .setLocales(localesString)
- .setDetectedTextLanguageTags(detectLanguageTags)
- .setUserFamiliarLanguageTags(LocaleList.getDefault().toLanguageTags())
- .build(),
- // Passing null here to suppress intent generation
- // TODO: Use an explicit flag to suppress it.
- /* appContext */ null,
- /* deviceLocales */ null);
- final int size = results.length;
- for (int i = 0; i < size; i++) {
- tsBuilder.setEntityType(results[i].getCollection(), results[i].getScore());
- }
- final String resultId =
- createAnnotatorId(string, request.getStartIndex(), request.getEndIndex());
- return tsBuilder.setId(resultId).build();
- } else {
- // We can not trust the result. Log the issue and ignore the result.
- TcLog.d(TAG, "Got bad indices for input text. Ignoring result.");
- }
- }
- } catch (Throwable t) {
- // Avoid throwing from this method. Log the error.
- TcLog.e(TAG, "Error suggesting selection for text. No changes to selection suggested.", t);
+ .setDetectedTextLanguageTags(String.join(",", detectLanguageTags))
+ .setAnnotationUsecase(AnnotatorModel.AnnotationUsecase.SMART.getValue())
+ .setUserFamiliarLanguageTags(LocaleList.getDefault().toLanguageTags())
+ .build(),
+ context,
+ getResourceLocalesString());
+ if (results.length == 0) {
+ throw new IllegalStateException("Empty text classification. Something went wrong.");
}
- // Getting here means something went wrong, return a NO_OP result.
- return fallback.suggestSelection(request);
+ return createClassificationResult(
+ results, string, request.getStartIndex(), request.getEndIndex(), langId);
}
@WorkerThread
- TextClassification classifyText(TextClassification.Request request) {
- Preconditions.checkNotNull(request);
- checkMainThread();
- try {
- Optional<LangIdModel> langId = getLangIdImpl();
- List<String> detectLanguageTags = detectLanguageTags(langId, request.getText());
- final int rangeLength = request.getEndIndex() - request.getStartIndex();
- final String string = request.getText().toString();
- if (string.length() > 0 && rangeLength <= settings.getClassifyTextMaxRangeLength()) {
- final String localesString = concatenateLocales(request.getDefaultLocales());
- final ZonedDateTime refTime =
- request.getReferenceTime() != null
- ? request.getReferenceTime()
- : ZonedDateTime.now(ZoneId.systemDefault());
- final AnnotatorModel.ClassificationResult[] results =
- getAnnotatorImpl(request.getDefaultLocales())
- .classifyText(
- string,
- request.getStartIndex(),
- request.getEndIndex(),
- AnnotatorModel.ClassificationOptions.builder()
- .setReferenceTimeMsUtc(refTime.toInstant().toEpochMilli())
- .setReferenceTimezone(refTime.getZone().getId())
- .setLocales(localesString)
- .setDetectedTextLanguageTags(String.join(",", detectLanguageTags))
- .setAnnotationUsecase(AnnotatorModel.AnnotationUsecase.SMART.getValue())
- .setUserFamiliarLanguageTags(LocaleList.getDefault().toLanguageTags())
- .build(),
- context,
- getResourceLocalesString());
- if (results.length > 0) {
- return createClassificationResult(
- results, string, request.getStartIndex(), request.getEndIndex(), langId);
- }
- }
- } catch (Throwable t) {
- // Avoid throwing from this method. Log the error.
- TcLog.e(TAG, "Error getting text classification info.", t);
- }
- // Getting here means something went wrong, return a NO_OP result.
- return fallback.classifyText(request);
- }
-
- @WorkerThread
- TextLinks generateLinks(TextLinks.Request request) {
+ TextLinks generateLinks(TextLinks.Request request) throws IOException {
Preconditions.checkNotNull(request);
Preconditions.checkArgument(
request.getText().length() <= getMaxGenerateLinksTextLength(),
@@ -259,75 +235,69 @@
final String textString = request.getText().toString();
final TextLinks.Builder builder = new TextLinks.Builder(textString);
- try {
- final long startTimeMs = System.currentTimeMillis();
- final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault());
- final Collection<String> entitiesToIdentify =
- request.getEntityConfig() != null
- ? request
- .getEntityConfig()
- .resolveEntityListModifications(
- getEntitiesForHints(request.getEntityConfig().getHints()))
- : settings.getEntityListDefault();
- final String localesString = concatenateLocales(request.getDefaultLocales());
- Optional<LangIdModel> langId = getLangIdImpl();
- ImmutableList<String> detectLanguageTags = detectLanguageTags(langId, request.getText());
- final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
- final boolean isSerializedEntityDataEnabled =
- ExtrasUtils.isSerializedEntityDataEnabled(request);
- final AnnotatorModel.AnnotatedSpan[] annotations =
- annotatorImpl.annotate(
- textString,
- AnnotatorModel.AnnotationOptions.builder()
- .setReferenceTimeMsUtc(refTime.toInstant().toEpochMilli())
- .setReferenceTimezone(refTime.getZone().getId())
- .setLocales(localesString)
- .setDetectedTextLanguageTags(String.join(",", detectLanguageTags))
- .setEntityTypes(entitiesToIdentify)
- .setAnnotationUsecase(AnnotatorModel.AnnotationUsecase.SMART.getValue())
- .setIsSerializedEntityDataEnabled(isSerializedEntityDataEnabled)
- .build());
- for (AnnotatorModel.AnnotatedSpan span : annotations) {
- final AnnotatorModel.ClassificationResult[] results = span.getClassification();
- if (results.length == 0 || !entitiesToIdentify.contains(results[0].getCollection())) {
- continue;
- }
- final Map<String, Float> entityScores = new ArrayMap<>();
- for (int i = 0; i < results.length; i++) {
- entityScores.put(results[i].getCollection(), results[i].getScore());
- }
- Bundle extras = new Bundle();
- if (isSerializedEntityDataEnabled) {
- ExtrasUtils.putEntities(extras, results);
- }
- builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores, extras);
+ final long startTimeMs = System.currentTimeMillis();
+ final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault());
+ final Collection<String> entitiesToIdentify =
+ request.getEntityConfig() != null
+ ? request
+ .getEntityConfig()
+ .resolveEntityListModifications(
+ getEntitiesForHints(request.getEntityConfig().getHints()))
+ : settings.getEntityListDefault();
+ final String localesString = concatenateLocales(request.getDefaultLocales());
+ LangIdModel langId = getLangIdImpl();
+ ImmutableList<String> detectLanguageTags = detectLanguageTags(langId, request.getText());
+ final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
+ final boolean isSerializedEntityDataEnabled =
+ ExtrasUtils.isSerializedEntityDataEnabled(request);
+ final AnnotatorModel.AnnotatedSpan[] annotations =
+ annotatorImpl.annotate(
+ textString,
+ AnnotatorModel.AnnotationOptions.builder()
+ .setReferenceTimeMsUtc(refTime.toInstant().toEpochMilli())
+ .setReferenceTimezone(refTime.getZone().getId())
+ .setLocales(localesString)
+ .setDetectedTextLanguageTags(String.join(",", detectLanguageTags))
+ .setEntityTypes(entitiesToIdentify)
+ .setAnnotationUsecase(AnnotatorModel.AnnotationUsecase.SMART.getValue())
+ .setIsSerializedEntityDataEnabled(isSerializedEntityDataEnabled)
+ .build());
+ for (AnnotatorModel.AnnotatedSpan span : annotations) {
+ final AnnotatorModel.ClassificationResult[] results = span.getClassification();
+ if (results.length == 0 || !entitiesToIdentify.contains(results[0].getCollection())) {
+ continue;
}
- final TextLinks links = builder.build();
- final long endTimeMs = System.currentTimeMillis();
- final String callingPackageName =
- request.getCallingPackageName() == null
- ? context.getPackageName() // local (in process) TC.
- : request.getCallingPackageName();
- Optional<ModelInfo> annotatorModelInfo;
- Optional<ModelInfo> langIdModelInfo;
- synchronized (lock) {
- annotatorModelInfo =
- Optional.fromNullable(annotatorModelInUse).transform(ModelFile::toModelInfo);
- langIdModelInfo = Optional.fromNullable(langIdModelInUse).transform(ModelFile::toModelInfo);
+ final Map<String, Float> entityScores = new ArrayMap<>();
+ for (AnnotatorModel.ClassificationResult result : results) {
+ entityScores.put(result.getCollection(), result.getScore());
}
- generateLinksLogger.logGenerateLinks(
- request.getText(),
- links,
- callingPackageName,
- endTimeMs - startTimeMs,
- annotatorModelInfo,
- langIdModelInfo);
- return links;
- } catch (Throwable t) {
- // Avoid throwing from this method. Log the error.
- TcLog.e(TAG, "Error getting links info.", t);
+ Bundle extras = new Bundle();
+ if (isSerializedEntityDataEnabled) {
+ ExtrasUtils.putEntities(extras, results);
+ }
+ builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores, extras);
}
- return fallback.generateLinks(request);
+ final TextLinks links = builder.build();
+ final long endTimeMs = System.currentTimeMillis();
+ final String callingPackageName =
+ request.getCallingPackageName() == null
+ ? context.getPackageName() // local (in process) TC.
+ : request.getCallingPackageName();
+ Optional<ModelInfo> annotatorModelInfo;
+ Optional<ModelInfo> langIdModelInfo;
+ synchronized (lock) {
+ annotatorModelInfo =
+ Optional.fromNullable(annotatorModelInUse).transform(ModelFile::toModelInfo);
+ langIdModelInfo = Optional.fromNullable(langIdModelInUse).transform(ModelFile::toModelInfo);
+ }
+ generateLinksLogger.logGenerateLinks(
+ request.getText(),
+ links,
+ callingPackageName,
+ endTimeMs - startTimeMs,
+ annotatorModelInfo,
+ langIdModelInfo);
+ return links;
}
int getMaxGenerateLinksTextLength() {
@@ -364,60 +334,42 @@
TextClassifierEventConverter.fromPlatform(event));
}
- TextLanguage detectLanguage(TextLanguage.Request request) {
+ TextLanguage detectLanguage(TextLanguage.Request request) throws IOException {
Preconditions.checkNotNull(request);
checkMainThread();
- try {
- final TextLanguage.Builder builder = new TextLanguage.Builder();
- Optional<LangIdModel> langIdImpl = getLangIdImpl();
- if (langIdImpl.isPresent()) {
- final LangIdModel.LanguageResult[] langResults =
- langIdImpl.get().detectLanguages(request.getText().toString());
- for (int i = 0; i < langResults.length; i++) {
- builder.putLocale(
- ULocale.forLanguageTag(langResults[i].getLanguage()), langResults[i].getScore());
- }
- return builder.build();
- }
- } catch (Throwable t) {
- // Avoid throwing from this method. Log the error.
- TcLog.e(TAG, "Error detecting text language.", t);
+ final TextLanguage.Builder builder = new TextLanguage.Builder();
+ LangIdModel langIdImpl = getLangIdImpl();
+ final LangIdModel.LanguageResult[] langResults =
+ langIdImpl.detectLanguages(request.getText().toString());
+ for (LangIdModel.LanguageResult langResult : langResults) {
+ builder.putLocale(ULocale.forLanguageTag(langResult.getLanguage()), langResult.getScore());
}
- return fallback.detectLanguage(request);
+ return builder.build();
}
- ConversationActions suggestConversationActions(ConversationActions.Request request) {
+ ConversationActions suggestConversationActions(ConversationActions.Request request)
+ throws IOException {
Preconditions.checkNotNull(request);
checkMainThread();
- try {
- ActionsSuggestionsModel actionsImpl = getActionsImpl();
- if (actionsImpl == null) {
- // Actions model is optional, fallback if it is not available.
- return fallback.suggestConversationActions(request);
- }
- Optional<LangIdModel> langId = getLangIdImpl();
- ActionsSuggestionsModel.ConversationMessage[] nativeMessages =
- ActionsSuggestionsHelper.toNativeMessages(
- request.getConversation(), text -> detectLanguageTags(langId, text));
- if (nativeMessages.length == 0) {
- return fallback.suggestConversationActions(request);
- }
- ActionsSuggestionsModel.Conversation nativeConversation =
- new ActionsSuggestionsModel.Conversation(nativeMessages);
-
- ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions =
- actionsImpl.suggestActionsWithIntents(
- nativeConversation,
- null,
- context,
- getResourceLocalesString(),
- getAnnotatorImpl(LocaleList.getDefault()));
- return createConversationActionResult(request, nativeSuggestions);
- } catch (Throwable t) {
- // Avoid throwing from this method. Log the error.
- TcLog.e(TAG, "Error suggesting conversation actions.", t);
+ ActionsSuggestionsModel actionsImpl = getActionsImpl();
+ LangIdModel langId = getLangIdImpl();
+ ActionsSuggestionsModel.ConversationMessage[] nativeMessages =
+ ActionsSuggestionsHelper.toNativeMessages(
+ request.getConversation(), text -> detectLanguageTags(langId, text));
+ if (nativeMessages.length == 0) {
+ return new ConversationActions(ImmutableList.of(), /* id= */ null);
}
- return fallback.suggestConversationActions(request);
+ ActionsSuggestionsModel.Conversation nativeConversation =
+ new ActionsSuggestionsModel.Conversation(nativeMessages);
+
+ ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions =
+ actionsImpl.suggestActionsWithIntents(
+ nativeConversation,
+ null,
+ context,
+ getResourceLocalesString(),
+ getAnnotatorImpl(LocaleList.getDefault()));
+ return createConversationActionResult(request, nativeSuggestions);
}
/**
@@ -482,94 +434,61 @@
return request.getTypeConfig().resolveEntityListModifications(defaultActionTypes);
}
- private AnnotatorModel getAnnotatorImpl(LocaleList localeList) throws FileNotFoundException {
+ private AnnotatorModel getAnnotatorImpl(LocaleList localeList) throws IOException {
synchronized (lock) {
localeList = localeList == null ? LocaleList.getDefault() : localeList;
final ModelFileManager.ModelFile bestModel =
modelFileManager.findBestModelFile(ModelType.ANNOTATOR, localeList);
if (bestModel == null) {
- throw new FileNotFoundException("No annotator model for " + localeList.toLanguageTags());
+ throw new IllegalStateException("Failed to find the best annotator model");
}
if (annotatorImpl == null || !Objects.equals(annotatorModelInUse, bestModel)) {
TcLog.d(TAG, "Loading " + bestModel);
- final ParcelFileDescriptor pfd =
- ParcelFileDescriptor.open(
- new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
- try {
- if (pfd != null) {
- // The current annotator model may be still used by another thread / model.
- // Do not call close() here, and let the GC to clean it up when no one else
- // is using it.
- annotatorImpl = new AnnotatorModel(pfd.getFd());
- Optional<LangIdModel> langIdModel = getLangIdImpl();
- if (langIdModel.isPresent()) {
- annotatorImpl.setLangIdModel(langIdModel.get());
- }
- annotatorModelInUse = bestModel;
- }
- } finally {
- maybeCloseAndLogError(pfd);
+ // The current annotator model may be still used by another thread / model.
+ // Do not call close() here, and let the GC to clean it up when no one else
+ // is using it.
+ try (AssetFileDescriptor afd = bestModel.open(context.getAssets())) {
+ annotatorImpl = new AnnotatorModel(afd);
+ annotatorImpl.setLangIdModel(getLangIdImpl());
+ annotatorModelInUse = bestModel;
}
}
return annotatorImpl;
}
}
- private Optional<LangIdModel> getLangIdImpl() {
+ private LangIdModel getLangIdImpl() throws IOException {
synchronized (lock) {
final ModelFileManager.ModelFile bestModel =
- modelFileManager.findBestModelFile(ModelType.LANG_ID, /* localeList= */ null);
+ modelFileManager.findBestModelFile(ModelType.LANG_ID, /* localePreferences= */ null);
if (bestModel == null) {
- return Optional.absent();
+ throw new IllegalStateException("Failed to find the best LangID model.");
}
if (langIdImpl == null || !Objects.equals(langIdModelInUse, bestModel)) {
TcLog.d(TAG, "Loading " + bestModel);
- final ParcelFileDescriptor pfd;
- try {
- pfd =
- ParcelFileDescriptor.open(
- new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
- } catch (FileNotFoundException e) {
- TcLog.e(TAG, "Failed to open the LangID model file", e);
- return Optional.absent();
- }
- try {
- if (pfd != null) {
- langIdImpl = new LangIdModel(pfd.getFd());
- langIdModelInUse = bestModel;
- }
- } finally {
- maybeCloseAndLogError(pfd);
+ try (AssetFileDescriptor afd = bestModel.open(context.getAssets())) {
+ langIdImpl = new LangIdModel(afd);
+ langIdModelInUse = bestModel;
}
}
- return Optional.of(langIdImpl);
+ return langIdImpl;
}
}
- @Nullable
- private ActionsSuggestionsModel getActionsImpl() throws FileNotFoundException {
+ private ActionsSuggestionsModel getActionsImpl() throws IOException {
synchronized (lock) {
// TODO: Use LangID to determine the locale we should use here?
final ModelFileManager.ModelFile bestModel =
modelFileManager.findBestModelFile(
ModelType.ACTIONS_SUGGESTIONS, LocaleList.getDefault());
if (bestModel == null) {
- return null;
+ throw new IllegalStateException("Failed to find the best actions model");
}
if (actionsImpl == null || !Objects.equals(actionModelInUse, bestModel)) {
TcLog.d(TAG, "Loading " + bestModel);
- final ParcelFileDescriptor pfd =
- ParcelFileDescriptor.open(
- new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
- try {
- if (pfd == null) {
- TcLog.d(TAG, "Failed to read the model file: " + bestModel.getPath());
- return null;
- }
- actionsImpl = new ActionsSuggestionsModel(pfd.getFd());
+ try (AssetFileDescriptor afd = bestModel.open(context.getAssets())) {
+ actionsImpl = new ActionsSuggestionsModel(afd);
actionModelInUse = bestModel;
- } finally {
- maybeCloseAndLogError(pfd);
}
}
return actionsImpl;
@@ -597,7 +516,7 @@
String text,
int start,
int end,
- Optional<LangIdModel> langId) {
+ LangIdModel langId) {
final String classifiedText = text.substring(start, end);
final TextClassification.Builder builder =
new TextClassification.Builder().setText(classifiedText);
@@ -644,10 +563,7 @@
actionIntents.add(intent);
}
Bundle extras = new Bundle();
- Optional<Bundle> foreignLanguageExtra =
- langId
- .transform(model -> maybeCreateExtrasForTranslate(actionIntents, model))
- .or(Optional.<Bundle>absent());
+ Optional<Bundle> foreignLanguageExtra = maybeCreateExtrasForTranslate(actionIntents, langId);
if (foreignLanguageExtra.isPresent()) {
ExtrasUtils.putForeignLanguageExtra(extras, foreignLanguageExtra.get());
}
@@ -694,16 +610,10 @@
topLanguageWithScore.first, topLanguageWithScore.second, langId.getVersion()));
}
- private ImmutableList<String> detectLanguageTags(
- Optional<LangIdModel> langId, CharSequence text) {
- return langId
- .transform(
- model -> {
- float threshold = getLangIdThreshold(model);
- EntityConfidence languagesConfidence = detectLanguages(model, text, threshold);
- return ImmutableList.copyOf(languagesConfidence.getEntities());
- })
- .or(ImmutableList.of());
+ private ImmutableList<String> detectLanguageTags(LangIdModel langId, CharSequence text) {
+ float threshold = getLangIdThreshold(langId);
+ EntityConfidence languagesConfidence = detectLanguages(langId, text, threshold);
+ return ImmutableList.copyOf(languagesConfidence.getEntities());
}
/**
@@ -734,7 +644,6 @@
printWriter.increaseIndent();
modelFileManager.dump(printWriter);
- printWriter.printPair("mFallback", fallback);
printWriter.decreaseIndent();
printWriter.println();
@@ -753,19 +662,6 @@
}
}
- /** Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur. */
- private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
- if (fd == null) {
- return;
- }
-
- try {
- fd.close();
- } catch (IOException e) {
- TcLog.e(TAG, "Error closing file.", e);
- }
- }
-
private static void checkMainThread() {
if (Looper.myLooper() == Looper.getMainLooper()) {
TcLog.e(TAG, "TextClassifier called on main thread", new Exception());
@@ -775,7 +671,10 @@
private static PendingIntent createPendingIntent(
final Context context, final Intent intent, int requestCode) {
return PendingIntent.getActivity(
- context, requestCode, intent, PendingIntent.FLAG_UPDATE_CURRENT);
+ context,
+ requestCode,
+ intent,
+ PendingIntent.FLAG_UPDATE_CURRENT | PendingIntent.FLAG_IMMUTABLE);
}
@Nullable
diff --git a/java/src/com/android/textclassifier/TextClassifierSettings.java b/java/src/com/android/textclassifier/TextClassifierSettings.java
index 005bd7c..b13d166 100644
--- a/java/src/com/android/textclassifier/TextClassifierSettings.java
+++ b/java/src/com/android/textclassifier/TextClassifierSettings.java
@@ -20,7 +20,8 @@
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.TextClassifier;
import androidx.annotation.NonNull;
-import com.android.textclassifier.ModelFileManager.ModelFile.ModelType;
+import com.android.textclassifier.ModelFileManager.ModelType;
+import com.android.textclassifier.ModelFileManager.ModelType.ModelTypeDef;
import com.android.textclassifier.utils.IndentingPrintWriter;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Splitter;
@@ -44,6 +45,7 @@
* @see android.provider.DeviceConfig#NAMESPACE_TEXTCLASSIFIER
*/
public final class TextClassifierSettings {
+ private static final String TAG = "TextClassifierSettings";
public static final String NAMESPACE = DeviceConfig.NAMESPACE_TEXTCLASSIFIER;
private static final String DELIMITER = ":";
@@ -109,6 +111,12 @@
/** Whether to enable model downloading with ModelDownloadManager */
@VisibleForTesting
static final String MODEL_DOWNLOAD_MANAGER_ENABLED = "model_download_manager_enabled";
+ /** Type of network to download model manifest. A String value of androidx.work.NetworkType. */
+ private static final String MANIFEST_DOWNLOAD_REQUIRED_NETWORK_TYPE =
+ "manifest_download_required_network_type";
+ /** Max attempts allowed for a single ModelDownloader downloading task. */
+ @VisibleForTesting
+ static final String MODEL_DOWNLOAD_MAX_ATTEMPTS = "model_download_max_attempts";
/** The prefix of the URL to download models. E.g. https://www.gstatic.com/android/ */
@VisibleForTesting static final String ANNOTATOR_URL_PREFIX = "annotator_url_prefix";
@@ -126,6 +134,9 @@
static final String PRIMARY_ACTIONS_SUGGESTIONS_URL_SUFFIX =
"primary_actions_suggestions_url_suffix";
+ /** Sampling rate for TextClassifier API logging. */
+ static final String TEXTCLASSIFIER_API_LOG_SAMPLE_RATE = "textclassifier_api_log_sample_rate";
+
/**
* A colon(:) separated string that specifies the configuration to use when including surrounding
* context text in language detection queries.
@@ -190,6 +201,9 @@
private static final boolean TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT = true;
private static final boolean DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT = true;
private static final boolean MODEL_DOWNLOAD_MANAGER_ENABLED_DEFAULT = false;
+ // Manifest files are usually small, default to any network type
+ private static final String MANIFEST_DOWNLOAD_REQUIRED_NETWORK_TYPE_DEFAULT = "NOT_ROAMING";
+ private static final int MODEL_DOWNLOAD_MAX_ATTEMPTS_DEFAULT = 5;
private static final String ANNOTATOR_URL_PREFIX_DEFAULT =
"https://www.gstatic.com/android/text_classifier/";
private static final String LANG_ID_URL_PREFIX_DEFAULT =
@@ -200,8 +214,12 @@
private static final String PRIMARY_LANG_ID_URL_SUFFIX_DEFAULT = "";
private static final String PRIMARY_ACTIONS_SUGGESTIONS_URL_SUFFIX_DEFAULT = "";
private static final float[] LANG_ID_CONTEXT_SETTINGS_DEFAULT = new float[] {20f, 1.0f, 0.4f};
+ /**
+ * Sampling rate for API logging. For example, 100 means there is a 0.01 chance that the API call
+ * is the logged.
+ */
+ private static final int TEXTCLASSIFIER_API_LOG_SAMPLE_RATE_DEFAULT = 10;
- @VisibleForTesting
interface IDeviceConfig {
default int getInt(@NonNull String namespace, @NonNull String name, @NonNull int defaultValue) {
return defaultValue;
@@ -257,7 +275,7 @@
}
@VisibleForTesting
- TextClassifierSettings(IDeviceConfig deviceConfig) {
+ public TextClassifierSettings(IDeviceConfig deviceConfig) {
this.deviceConfig = deviceConfig;
}
@@ -344,7 +362,12 @@
NAMESPACE, MODEL_DOWNLOAD_MANAGER_ENABLED, MODEL_DOWNLOAD_MANAGER_ENABLED_DEFAULT);
}
- public String getModelURLPrefix(@ModelType.ModelTypeDef String modelType) {
+ public int getModelDownloadMaxAttempts() {
+ return deviceConfig.getInt(
+ NAMESPACE, MODEL_DOWNLOAD_MAX_ATTEMPTS, MODEL_DOWNLOAD_MAX_ATTEMPTS_DEFAULT);
+ }
+
+ public String getModelURLPrefix(@ModelTypeDef String modelType) {
switch (modelType) {
case ModelType.ANNOTATOR:
return deviceConfig.getString(
@@ -359,7 +382,7 @@
}
}
- public String getPrimaryModelURLSuffix(@ModelType.ModelTypeDef String modelType) {
+ public String getPrimaryModelURLSuffix(@ModelTypeDef String modelType) {
switch (modelType) {
case ModelType.ANNOTATOR:
return deviceConfig.getString(
@@ -377,37 +400,43 @@
}
}
+ public int getTextClassifierApiLogSampleRate() {
+ return deviceConfig.getInt(
+ NAMESPACE, TEXTCLASSIFIER_API_LOG_SAMPLE_RATE, TEXTCLASSIFIER_API_LOG_SAMPLE_RATE_DEFAULT);
+ }
+
void dump(IndentingPrintWriter pw) {
pw.println("TextClassifierSettings:");
pw.increaseIndent();
- pw.printPair("classify_text_max_range_length", getClassifyTextMaxRangeLength());
- pw.printPair("detect_language_from_text_enabled", isDetectLanguagesFromTextEnabled());
- pw.printPair("entity_list_default", getEntityListDefault());
- pw.printPair("entity_list_editable", getEntityListEditable());
- pw.printPair("entity_list_not_editable", getEntityListNotEditable());
- pw.printPair("generate_links_log_sample_rate", getGenerateLinksLogSampleRate());
- pw.printPair("generate_links_max_text_length", getGenerateLinksMaxTextLength());
- pw.printPair("in_app_conversation_action_types_default", getInAppConversationActionTypes());
- pw.printPair("lang_id_context_settings", Arrays.toString(getLangIdContextSettings()));
- pw.printPair("lang_id_threshold_override", getLangIdThresholdOverride());
- pw.printPair("translate_action_threshold", getTranslateActionThreshold());
+ pw.printPair(CLASSIFY_TEXT_MAX_RANGE_LENGTH, getClassifyTextMaxRangeLength());
+ pw.printPair(DETECT_LANGUAGES_FROM_TEXT_ENABLED, isDetectLanguagesFromTextEnabled());
+ pw.printPair(ENTITY_LIST_DEFAULT, getEntityListDefault());
+ pw.printPair(ENTITY_LIST_EDITABLE, getEntityListEditable());
+ pw.printPair(ENTITY_LIST_NOT_EDITABLE, getEntityListNotEditable());
+ pw.printPair(GENERATE_LINKS_LOG_SAMPLE_RATE, getGenerateLinksLogSampleRate());
+ pw.printPair(GENERATE_LINKS_MAX_TEXT_LENGTH, getGenerateLinksMaxTextLength());
+ pw.printPair(IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT, getInAppConversationActionTypes());
+ pw.printPair(LANG_ID_CONTEXT_SETTINGS, Arrays.toString(getLangIdContextSettings()));
+ pw.printPair(LANG_ID_THRESHOLD_OVERRIDE, getLangIdThresholdOverride());
+ pw.printPair(TRANSLATE_ACTION_THRESHOLD, getTranslateActionThreshold());
pw.printPair(
- "notification_conversation_action_types_default", getNotificationConversationActionTypes());
- pw.printPair("suggest_selection_max_range_length", getSuggestSelectionMaxRangeLength());
- pw.printPair("user_language_profile_enabled", isUserLanguageProfileEnabled());
- pw.printPair("template_intent_factory_enabled", isTemplateIntentFactoryEnabled());
- pw.printPair("translate_in_classification_enabled", isTranslateInClassificationEnabled());
- pw.printPair("model_download_manager_enabled", isModelDownloadManagerEnabled());
- pw.printPair("annotator_url_prefix", getModelURLPrefix(ModelType.ANNOTATOR));
- pw.printPair("lang_id_url_prefix", getModelURLPrefix(ModelType.LANG_ID));
- pw.printPair(
- "actions_suggestions_url_prefix", getModelURLPrefix(ModelType.ACTIONS_SUGGESTIONS));
+ NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT, getNotificationConversationActionTypes());
+ pw.printPair(SUGGEST_SELECTION_MAX_RANGE_LENGTH, getSuggestSelectionMaxRangeLength());
+ pw.printPair(USER_LANGUAGE_PROFILE_ENABLED, isUserLanguageProfileEnabled());
+ pw.printPair(TEMPLATE_INTENT_FACTORY_ENABLED, isTemplateIntentFactoryEnabled());
+ pw.printPair(TRANSLATE_IN_CLASSIFICATION_ENABLED, isTranslateInClassificationEnabled());
+ pw.printPair(MODEL_DOWNLOAD_MANAGER_ENABLED, isModelDownloadManagerEnabled());
+ pw.printPair(MODEL_DOWNLOAD_MAX_ATTEMPTS, getModelDownloadMaxAttempts());
+ pw.printPair(ANNOTATOR_URL_PREFIX, getModelURLPrefix(ModelType.ANNOTATOR));
+ pw.printPair(LANG_ID_URL_PREFIX, getModelURLPrefix(ModelType.LANG_ID));
+ pw.printPair(ACTIONS_SUGGESTIONS_URL_PREFIX, getModelURLPrefix(ModelType.ACTIONS_SUGGESTIONS));
pw.decreaseIndent();
- pw.printPair("primary_annotator_url_suffix", getPrimaryModelURLSuffix(ModelType.ANNOTATOR));
- pw.printPair("primary_lang_id_url_suffix", getPrimaryModelURLSuffix(ModelType.LANG_ID));
+ pw.printPair(PRIMARY_ANNOTATOR_URL_SUFFIX, getPrimaryModelURLSuffix(ModelType.ANNOTATOR));
+ pw.printPair(PRIMARY_LANG_ID_URL_SUFFIX, getPrimaryModelURLSuffix(ModelType.LANG_ID));
pw.printPair(
- "primary_actions_suggestions_url_suffix",
+ PRIMARY_ACTIONS_SUGGESTIONS_URL_SUFFIX,
getPrimaryModelURLSuffix(ModelType.ACTIONS_SUGGESTIONS));
+ pw.printPair(TEXTCLASSIFIER_API_LOG_SAMPLE_RATE, getTextClassifierApiLogSampleRate());
pw.decreaseIndent();
}
diff --git a/java/src/com/android/textclassifier/common/base/TcLog.java b/java/src/com/android/textclassifier/common/base/TcLog.java
index 87f1187..05a2443 100644
--- a/java/src/com/android/textclassifier/common/base/TcLog.java
+++ b/java/src/com/android/textclassifier/common/base/TcLog.java
@@ -16,6 +16,8 @@
package com.android.textclassifier.common.base;
+import android.util.Log;
+
/**
* Logging for android.view.textclassifier package.
*
@@ -31,27 +33,30 @@
public static final String TAG = "androidtc";
/** true: Enables full logging. false: Limits logging to debug level. */
- public static final boolean ENABLE_FULL_LOGGING =
- android.util.Log.isLoggable(TAG, android.util.Log.VERBOSE);
+ public static final boolean ENABLE_FULL_LOGGING = Log.isLoggable(TAG, Log.VERBOSE);
private TcLog() {}
public static void v(String tag, String msg) {
if (ENABLE_FULL_LOGGING) {
- android.util.Log.v(getTag(tag), msg);
+ Log.v(getTag(tag), msg);
}
}
public static void d(String tag, String msg) {
- android.util.Log.d(getTag(tag), msg);
+ Log.d(getTag(tag), msg);
}
public static void w(String tag, String msg) {
- android.util.Log.w(getTag(tag), msg);
+ Log.w(getTag(tag), msg);
+ }
+
+ public static void e(String tag, String msg) {
+ Log.e(getTag(tag), msg);
}
public static void e(String tag, String msg, Throwable tr) {
- android.util.Log.e(getTag(tag), msg, tr);
+ Log.e(getTag(tag), msg, tr);
}
private static String getTag(String customTag) {
diff --git a/java/src/com/android/textclassifier/common/intent/LabeledIntent.java b/java/src/com/android/textclassifier/common/intent/LabeledIntent.java
index b32e1ce..5c420ad 100644
--- a/java/src/com/android/textclassifier/common/intent/LabeledIntent.java
+++ b/java/src/com/android/textclassifier/common/intent/LabeledIntent.java
@@ -210,7 +210,10 @@
private static PendingIntent createPendingIntent(
final Context context, final Intent intent, int requestCode) {
return PendingIntent.getActivity(
- context, requestCode, intent, PendingIntent.FLAG_UPDATE_CURRENT);
+ context,
+ requestCode,
+ intent,
+ PendingIntent.FLAG_UPDATE_CURRENT | PendingIntent.FLAG_IMMUTABLE);
}
@Nullable
diff --git a/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java b/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java
index 45785f1..822eb77 100644
--- a/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java
+++ b/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java
@@ -141,7 +141,7 @@
String annotatorModelName = annotatorModel.transform(ModelInfo::toModelName).or("");
String langIdModelName = langIdModel.transform(ModelInfo::toModelName).or("");
TextClassifierStatsLog.write(
- TextClassifierEventLogger.TEXT_LINKIFY_EVENT_ATOM_ID,
+ TextClassifierStatsLog.TEXT_LINKIFY_EVENT,
callId,
TextClassifierEvent.TYPE_LINKS_GENERATED,
annotatorModelName,
diff --git a/java/src/com/android/textclassifier/common/statsd/TextClassifierApiUsageLogger.java b/java/src/com/android/textclassifier/common/statsd/TextClassifierApiUsageLogger.java
new file mode 100644
index 0000000..8a79d74
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/statsd/TextClassifierApiUsageLogger.java
@@ -0,0 +1,142 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.statsd;
+
+import static java.lang.annotation.RetentionPolicy.SOURCE;
+
+import android.os.SystemClock;
+import android.view.textclassifier.TextClassificationSessionId;
+import androidx.annotation.IntDef;
+import androidx.annotation.Nullable;
+import com.android.textclassifier.common.base.TcLog;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Supplier;
+import java.lang.annotation.Retention;
+import java.util.Locale;
+import java.util.Random;
+import java.util.concurrent.Executor;
+
+/** Logs the TextClassifier API usages. */
+public final class TextClassifierApiUsageLogger {
+ private static final String TAG = "ApiUsageLogger";
+
+ public static final int API_TYPE_SUGGEST_SELECTION =
+ TextClassifierStatsLog.TEXT_CLASSIFIER_API_USAGE_REPORTED__API_TYPE__SUGGEST_SELECTION;
+ public static final int API_TYPE_CLASSIFY_TEXT =
+ TextClassifierStatsLog.TEXT_CLASSIFIER_API_USAGE_REPORTED__API_TYPE__CLASSIFY_TEXT;
+ public static final int API_TYPE_GENERATE_LINKS =
+ TextClassifierStatsLog.TEXT_CLASSIFIER_API_USAGE_REPORTED__API_TYPE__GENERATE_LINKS;
+ public static final int API_TYPE_SUGGEST_CONVERSATION_ACTIONS =
+ TextClassifierStatsLog
+ .TEXT_CLASSIFIER_API_USAGE_REPORTED__API_TYPE__SUGGEST_CONVERSATION_ACTIONS;
+ public static final int API_TYPE_DETECT_LANGUAGES =
+ TextClassifierStatsLog.TEXT_CLASSIFIER_API_USAGE_REPORTED__API_TYPE__DETECT_LANGUAGES;
+
+ /** The type of the API. */
+ @Retention(SOURCE)
+ @IntDef({
+ API_TYPE_SUGGEST_SELECTION,
+ API_TYPE_CLASSIFY_TEXT,
+ API_TYPE_GENERATE_LINKS,
+ API_TYPE_SUGGEST_CONVERSATION_ACTIONS,
+ API_TYPE_DETECT_LANGUAGES
+ })
+ public @interface ApiType {}
+
+ private final Executor executor;
+
+ private final Supplier<Integer> sampleRateSupplier;
+
+ private final Random random;
+
+ /**
+ * @param sampleRateSupplier The rate at which log events are written. (e.g. 100 means there is a
+ * 0.01 chance that a call to logGenerateLinks results in an event being written). To write
+ * all events, pass 1. To disable logging, pass any number < 1. Sampling is used to reduce the
+ * amount of logging data generated.
+ * @param executor that is used to execute the logging work.
+ */
+ public TextClassifierApiUsageLogger(Supplier<Integer> sampleRateSupplier, Executor executor) {
+ this.executor = Preconditions.checkNotNull(executor);
+ this.sampleRateSupplier = sampleRateSupplier;
+ this.random = new Random();
+ }
+
+ public Session createSession(
+ @ApiType int apiType, @Nullable TextClassificationSessionId sessionId) {
+ return new Session(apiType, sessionId);
+ }
+
+ /** A session to log an API invocation. Creates a new session for each API call. */
+ public final class Session {
+ @ApiType private final int apiType;
+ @Nullable private final TextClassificationSessionId sessionId;
+ private final long beginElapsedRealTime;
+
+ private Session(@ApiType int apiType, @Nullable TextClassificationSessionId sessionId) {
+ this.apiType = apiType;
+ this.sessionId = sessionId;
+ beginElapsedRealTime = SystemClock.elapsedRealtime();
+ }
+
+ public void reportSuccess() {
+ reportInternal(/* success= */ true);
+ }
+
+ public void reportFailure() {
+ reportInternal(/* success= */ false);
+ }
+
+ private void reportInternal(boolean success) {
+ if (!shouldLog()) {
+ return;
+ }
+ final long latencyInMillis = SystemClock.elapsedRealtime() - beginElapsedRealTime;
+ if (TcLog.ENABLE_FULL_LOGGING) {
+ TcLog.v(
+ TAG,
+ String.format(
+ Locale.ENGLISH,
+ "TextClassifierApiUsageLogger: apiType=%d success=%b latencyInMillis=%d",
+ apiType,
+ success,
+ latencyInMillis));
+ }
+ executor.execute(
+ () ->
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.TEXT_CLASSIFIER_API_USAGE_REPORTED,
+ apiType,
+ success
+ ? TextClassifierStatsLog
+ .TEXT_CLASSIFIER_API_USAGE_REPORTED__RESULT_TYPE__SUCCESS
+ : TextClassifierStatsLog
+ .TEXT_CLASSIFIER_API_USAGE_REPORTED__RESULT_TYPE__FAIL,
+ latencyInMillis,
+ sessionId == null ? "" : sessionId.getValue()));
+ }
+ }
+
+ /** Returns whether this particular event should be logged. */
+ private boolean shouldLog() {
+ if (sampleRateSupplier.get() < 1) {
+ return false;
+ } else {
+ return random.nextInt(sampleRateSupplier.get()) == 0;
+ }
+ }
+}
diff --git a/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java b/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java
index 307be6b..6678142 100644
--- a/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java
+++ b/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java
@@ -34,11 +34,6 @@
/** Logs {@link android.view.textclassifier.TextClassifierEvent}. */
public final class TextClassifierEventLogger {
private static final String TAG = "TCEventLogger";
- // These constants are defined in atoms.proto.
- private static final int TEXT_SELECTION_EVENT_ATOM_ID = 219;
- static final int TEXT_LINKIFY_EVENT_ATOM_ID = 220;
- private static final int CONVERSATION_ACTIONS_EVENT_ATOM_ID = 221;
- private static final int LANGUAGE_DETECTION_EVENT_ATOM_ID = 222;
/** Emits a text classifier event to the logs. */
public void writeEvent(
@@ -68,7 +63,7 @@
TextClassifierEvent.TextSelectionEvent event) {
ImmutableList<String> modelNames = getModelNames(event);
TextClassifierStatsLog.write(
- TEXT_SELECTION_EVENT_ATOM_ID,
+ TextClassifierStatsLog.TEXT_SELECTION_EVENT,
sessionId == null ? null : sessionId.getValue(),
getEventType(event),
getItemAt(modelNames, /* index= */ 0, /* defaultValue= */ null),
@@ -98,7 +93,7 @@
TextClassificationSessionId sessionId, TextClassifierEvent.TextLinkifyEvent event) {
ImmutableList<String> modelNames = getModelNames(event);
TextClassifierStatsLog.write(
- TEXT_LINKIFY_EVENT_ATOM_ID,
+ TextClassifierStatsLog.TEXT_LINKIFY_EVENT,
sessionId == null ? null : sessionId.getValue(),
event.getEventType(),
getItemAt(modelNames, /* index= */ 0, /* defaultValue= */ null),
@@ -119,7 +114,7 @@
String resultId = nullToEmpty(event.getResultId());
ImmutableList<String> modelNames = ResultIdUtils.getModelNames(resultId);
TextClassifierStatsLog.write(
- CONVERSATION_ACTIONS_EVENT_ATOM_ID,
+ TextClassifierStatsLog.CONVERSATION_ACTIONS_EVENT,
// TODO: Update ExtServices to set the session id.
sessionId == null
? Hashing.goodFastHash(64).hashString(resultId, UTF_8).toString()
@@ -140,7 +135,7 @@
@Nullable TextClassificationSessionId sessionId,
TextClassifierEvent.LanguageDetectionEvent event) {
TextClassifierStatsLog.write(
- LANGUAGE_DETECTION_EVENT_ATOM_ID,
+ TextClassifierStatsLog.LANGUAGE_DETECTION_EVENT,
sessionId == null ? null : sessionId.getValue(),
event.getEventType(),
getItemAt(getModelNames(event), /* index= */ 0, /* defaultValue= */ null),
diff --git a/java/tests/instrumentation/Android.bp b/java/tests/instrumentation/Android.bp
index a0cd0ec..fa31894 100644
--- a/java/tests/instrumentation/Android.bp
+++ b/java/tests/instrumentation/Android.bp
@@ -55,5 +55,5 @@
instrumentation_for: "TextClassifierService",
- data: ["testdata/*"]
+ data: ["testdata/*"],
}
\ No newline at end of file
diff --git a/java/tests/instrumentation/AndroidManifest.xml b/java/tests/instrumentation/AndroidManifest.xml
index 5fde758..3ee30da 100644
--- a/java/tests/instrumentation/AndroidManifest.xml
+++ b/java/tests/instrumentation/AndroidManifest.xml
@@ -2,7 +2,7 @@
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.android.textclassifier.tests">
- <uses-sdk android:minSdkVersion="29" android:targetSdkVersion="30"/>
+ <uses-sdk android:minSdkVersion="30" android:targetSdkVersion="30"/>
<uses-permission android:name="android.permission.QUERY_ALL_PACKAGES" />
<uses-permission android:name="android.permission.MANAGE_EXTERNAL_STORAGE"/>
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
new file mode 100644
index 0000000..1c4f7f8
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
@@ -0,0 +1,303 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.verify;
+
+import android.content.Context;
+import android.os.Binder;
+import android.os.CancellationSignal;
+import android.os.Parcel;
+import android.service.textclassifier.TextClassifierService;
+import android.view.textclassifier.ConversationAction;
+import android.view.textclassifier.ConversationActions;
+import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextClassificationSessionId;
+import android.view.textclassifier.TextClassifier;
+import android.view.textclassifier.TextLanguage;
+import android.view.textclassifier.TextLinks;
+import android.view.textclassifier.TextLinks.TextLink;
+import android.view.textclassifier.TextSelection;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.android.internal.os.StatsdConfigProto.StatsdConfig;
+import com.android.os.AtomsProto;
+import com.android.os.AtomsProto.Atom;
+import com.android.os.AtomsProto.TextClassifierApiUsageReported;
+import com.android.os.AtomsProto.TextClassifierApiUsageReported.ApiType;
+import com.android.os.AtomsProto.TextClassifierApiUsageReported.ResultType;
+import com.android.textclassifier.common.statsd.StatsdTestUtils;
+import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
+import java.util.List;
+import java.util.concurrent.Executor;
+import java.util.stream.Collectors;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class DefaultTextClassifierServiceTest {
+ /** A statsd config ID, which is arbitrary. */
+ private static final long CONFIG_ID = 689777;
+
+ private static final long SHORT_TIMEOUT_MS = 1000;
+
+ private static final String SESSION_ID = "abcdef";
+
+ private TestInjector testInjector;
+ private DefaultTextClassifierService defaultTextClassifierService;
+ @Mock private TextClassifierService.Callback<TextClassification> textClassificationCallback;
+ @Mock private TextClassifierService.Callback<TextSelection> textSelectionCallback;
+ @Mock private TextClassifierService.Callback<TextLinks> textLinksCallback;
+ @Mock private TextClassifierService.Callback<ConversationActions> conversationActionsCallback;
+ @Mock private TextClassifierService.Callback<TextLanguage> textLanguageCallback;
+
+ @Before
+ public void setup() {
+ MockitoAnnotations.initMocks(this);
+
+ testInjector = new TestInjector(ApplicationProvider.getApplicationContext());
+ defaultTextClassifierService = new DefaultTextClassifierService(testInjector);
+ defaultTextClassifierService.onCreate();
+ }
+
+ @Before
+ public void setupStatsdTestUtils() throws Exception {
+ StatsdTestUtils.cleanup(CONFIG_ID);
+
+ StatsdConfig.Builder builder =
+ StatsdConfig.newBuilder()
+ .setId(CONFIG_ID)
+ .addAllowedLogSource(ApplicationProvider.getApplicationContext().getPackageName());
+ StatsdTestUtils.addAtomMatcher(builder, Atom.TEXT_CLASSIFIER_API_USAGE_REPORTED_FIELD_NUMBER);
+ StatsdTestUtils.pushConfig(builder.build());
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ StatsdTestUtils.cleanup(CONFIG_ID);
+ }
+
+ @Test
+ public void classifyText_success() throws Exception {
+ String text = "www.android.com";
+ TextClassification.Request request =
+ new TextClassification.Request.Builder(text, 0, text.length()).build();
+
+ defaultTextClassifierService.onClassifyText(
+ createTextClassificationSessionId(),
+ request,
+ new CancellationSignal(),
+ textClassificationCallback);
+
+ ArgumentCaptor<TextClassification> captor = ArgumentCaptor.forClass(TextClassification.class);
+ verify(textClassificationCallback).onSuccess(captor.capture());
+ assertThat(captor.getValue().getEntityCount()).isGreaterThan(0);
+ assertThat(captor.getValue().getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
+ verifyApiUsageLog(ApiType.CLASSIFY_TEXT, ResultType.SUCCESS);
+ }
+
+ @Test
+ public void suggestSelection_success() throws Exception {
+ String text = "Visit http://www.android.com for more information";
+ String selected = "http";
+ String suggested = "http://www.android.com";
+ int start = text.indexOf(selected);
+ int end = start + suggested.length();
+ TextSelection.Request request = new TextSelection.Request.Builder(text, start, end).build();
+
+ defaultTextClassifierService.onSuggestSelection(
+ createTextClassificationSessionId(),
+ request,
+ new CancellationSignal(),
+ textSelectionCallback);
+
+ ArgumentCaptor<TextSelection> captor = ArgumentCaptor.forClass(TextSelection.class);
+ verify(textSelectionCallback).onSuccess(captor.capture());
+ assertThat(captor.getValue().getEntityCount()).isGreaterThan(0);
+ assertThat(captor.getValue().getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
+ verifyApiUsageLog(ApiType.SUGGEST_SELECTION, ResultType.SUCCESS);
+ }
+
+ @Test
+ public void generateLinks_success() throws Exception {
+ String text = "Visit http://www.android.com for more information";
+ TextLinks.Request request = new TextLinks.Request.Builder(text).build();
+
+ defaultTextClassifierService.onGenerateLinks(
+ createTextClassificationSessionId(), request, new CancellationSignal(), textLinksCallback);
+
+ ArgumentCaptor<TextLinks> captor = ArgumentCaptor.forClass(TextLinks.class);
+ verify(textLinksCallback).onSuccess(captor.capture());
+ assertThat(captor.getValue().getLinks()).hasSize(1);
+ TextLink textLink = captor.getValue().getLinks().iterator().next();
+ assertThat(textLink.getEntityCount()).isGreaterThan(0);
+ assertThat(textLink.getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
+ verifyApiUsageLog(ApiType.GENERATE_LINKS, ResultType.SUCCESS);
+ }
+
+ @Test
+ public void detectLanguage_success() throws Exception {
+ String text = "ピカチュウ";
+ TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
+
+ defaultTextClassifierService.onDetectLanguage(
+ createTextClassificationSessionId(),
+ request,
+ new CancellationSignal(),
+ textLanguageCallback);
+
+ ArgumentCaptor<TextLanguage> captor = ArgumentCaptor.forClass(TextLanguage.class);
+ verify(textLanguageCallback).onSuccess(captor.capture());
+ assertThat(captor.getValue().getLocaleHypothesisCount()).isGreaterThan(0);
+ assertThat(captor.getValue().getLocale(0).toLanguageTag()).isEqualTo("ja");
+ verifyApiUsageLog(ApiType.DETECT_LANGUAGES, ResultType.SUCCESS);
+ }
+
+ @Test
+ public void suggestConversationActions_success() throws Exception {
+ ConversationActions.Message message =
+ new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
+ .setText("Checkout www.android.com")
+ .build();
+ ConversationActions.Request request =
+ new ConversationActions.Request.Builder(ImmutableList.of(message)).build();
+
+ defaultTextClassifierService.onSuggestConversationActions(
+ createTextClassificationSessionId(),
+ request,
+ new CancellationSignal(),
+ conversationActionsCallback);
+
+ ArgumentCaptor<ConversationActions> captor = ArgumentCaptor.forClass(ConversationActions.class);
+ verify(conversationActionsCallback).onSuccess(captor.capture());
+ List<ConversationAction> conversationActions = captor.getValue().getConversationActions();
+ assertThat(conversationActions.size()).isGreaterThan(0);
+ assertThat(conversationActions.get(0).getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL);
+ verifyApiUsageLog(ApiType.SUGGEST_CONVERSATION_ACTIONS, ResultType.SUCCESS);
+ }
+
+ @Test
+ public void missingModelFile_onFailureShouldBeCalled() throws Exception {
+ testInjector.setModelFileManager(
+ new ModelFileManager(ApplicationProvider.getApplicationContext(), ImmutableList.of()));
+ defaultTextClassifierService.onCreate();
+
+ TextClassification.Request request = new TextClassification.Request.Builder("hi", 0, 2).build();
+ defaultTextClassifierService.onClassifyText(
+ createTextClassificationSessionId(),
+ request,
+ new CancellationSignal(),
+ textClassificationCallback);
+
+ verify(textClassificationCallback).onFailure(Mockito.anyString());
+ verifyApiUsageLog(ApiType.CLASSIFY_TEXT, ResultType.FAIL);
+ }
+
+ private static void verifyApiUsageLog(
+ AtomsProto.TextClassifierApiUsageReported.ApiType expectedApiType,
+ AtomsProto.TextClassifierApiUsageReported.ResultType expectedResultApiType)
+ throws Exception {
+ ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
+ ImmutableList<TextClassifierApiUsageReported> loggedEvents =
+ ImmutableList.copyOf(
+ loggedAtoms.stream()
+ .map(Atom::getTextClassifierApiUsageReported)
+ .collect(Collectors.toList()));
+ assertThat(loggedEvents).hasSize(1);
+ TextClassifierApiUsageReported loggedEvent = loggedEvents.get(0);
+ assertThat(loggedEvent.getLatencyMillis()).isGreaterThan(0L);
+ assertThat(loggedEvent.getApiType()).isEqualTo(expectedApiType);
+ assertThat(loggedEvent.getResultType()).isEqualTo(expectedResultApiType);
+ assertThat(loggedEvent.getSessionId()).isEqualTo(SESSION_ID);
+ }
+
+ private static TextClassificationSessionId createTextClassificationSessionId() {
+ // Used a hack to create TextClassificationSessionId because its constructor is @hide.
+ Parcel parcel = Parcel.obtain();
+ parcel.writeString(SESSION_ID);
+ parcel.writeStrongBinder(new Binder());
+ parcel.setDataPosition(0);
+ return TextClassificationSessionId.CREATOR.createFromParcel(parcel);
+ }
+
+ private static final class TestInjector implements DefaultTextClassifierService.Injector {
+ private final Context context;
+ private ModelFileManager modelFileManager;
+
+ private TestInjector(Context context) {
+ this.context = Preconditions.checkNotNull(context);
+ }
+
+ private void setModelFileManager(ModelFileManager modelFileManager) {
+ this.modelFileManager = modelFileManager;
+ }
+
+ @Override
+ public Context getContext() {
+ return context;
+ }
+
+ @Override
+ public ModelFileManager createModelFileManager(TextClassifierSettings settings) {
+ if (modelFileManager == null) {
+ return TestDataUtils.createModelFileManagerForTesting(context);
+ }
+ return modelFileManager;
+ }
+
+ @Override
+ public TextClassifierSettings createTextClassifierSettings() {
+ return new TextClassifierSettings();
+ }
+
+ @Override
+ public TextClassifierImpl createTextClassifierImpl(
+ TextClassifierSettings settings, ModelFileManager modelFileManager) {
+ return new TextClassifierImpl(context, settings, modelFileManager);
+ }
+
+ @Override
+ public ListeningExecutorService createNormPriorityExecutor() {
+ return MoreExecutors.newDirectExecutorService();
+ }
+
+ @Override
+ public ListeningExecutorService createLowPriorityExecutor() {
+ return MoreExecutors.newDirectExecutorService();
+ }
+
+ @Override
+ public TextClassifierApiUsageLogger createTextClassifierApiUsageLogger(
+ TextClassifierSettings settings, Executor executor) {
+ return new TextClassifierApiUsageLogger(
+ /* sampleRateSupplier= */ () -> 1, MoreExecutors.directExecutor());
+ }
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
index 8ef3908..de819ef 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
@@ -16,29 +16,31 @@
package com.android.textclassifier;
+import static com.android.textclassifier.ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT;
import static com.google.common.truth.Truth.assertThat;
-import static org.mockito.ArgumentMatchers.anyBoolean;
-import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.Mockito.when;
import android.os.LocaleList;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
import com.android.textclassifier.ModelFileManager.ModelFile;
+import com.android.textclassifier.ModelFileManager.ModelType;
+import com.android.textclassifier.ModelFileManager.ModelType.ModelTypeDef;
+import com.android.textclassifier.ModelFileManager.RegularFileFullMatchLister;
+import com.android.textclassifier.ModelFileManager.RegularFilePatternMatchLister;
import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
+import com.android.textclassifier.testing.SetDefaultLocalesRule;
import com.google.common.base.Optional;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
+import com.google.common.io.Files;
import java.io.File;
import java.io.IOException;
-import java.util.Collections;
import java.util.List;
import java.util.Locale;
-import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.junit.After;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
@@ -51,22 +53,16 @@
private static final String URL = "http://www.gstatic.com/android/text_classifier/q/711/en.fb";
private static final String URL_2 = "http://www.gstatic.com/android/text_classifier/q/712/en.fb";
- @ModelFile.ModelType.ModelTypeDef
- private static final String MODEL_TYPE = ModelFile.ModelType.ANNOTATOR;
+ @ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR;
- @ModelFile.ModelType.ModelTypeDef
- private static final String MODEL_TYPE_2 = ModelFile.ModelType.LANG_ID;
+ @ModelTypeDef private static final String MODEL_TYPE_2 = ModelType.LANG_ID;
- @Mock private Supplier<ImmutableList<ModelFile>> modelFileSupplier;
@Mock private TextClassifierSettings.IDeviceConfig mockDeviceConfig;
- private File rootTestDir;
- private File factoryModelDir;
- private File configUpdaterModelFile;
- private File downloaderModelDir;
+ @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
+ private File rootTestDir;
private ModelFileManager modelFileManager;
- private ModelFileManager.ModelFileSupplierImpl modelFileSupplierImpl;
@Before
public void setup() {
@@ -74,28 +70,11 @@
rootTestDir =
new File(ApplicationProvider.getApplicationContext().getCacheDir(), "rootTestDir");
- factoryModelDir = new File(rootTestDir, "factory");
- configUpdaterModelFile = new File(rootTestDir, "configupdater.model");
- downloaderModelDir = new File(rootTestDir, "downloader");
-
- modelFileManager =
- new ModelFileManager(downloaderModelDir, ImmutableMap.of(MODEL_TYPE, modelFileSupplier));
- modelFileSupplierImpl =
- new ModelFileManager.ModelFileSupplierImpl(
- new TextClassifierSettings(mockDeviceConfig),
- MODEL_TYPE,
- factoryModelDir,
- "test\\d.model",
- configUpdaterModelFile,
- downloaderModelDir,
- fd -> 1,
- fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT);
-
rootTestDir.mkdirs();
- factoryModelDir.mkdirs();
- downloaderModelDir.mkdirs();
-
- Locale.setDefault(DEFAULT_LOCALE);
+ modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ new TextClassifierSettings(mockDeviceConfig));
}
@After
@@ -104,100 +83,106 @@
}
@Test
- public void get() {
- ModelFileManager.ModelFile modelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(), "", true);
- when(modelFileSupplier.get()).thenReturn(ImmutableList.of(modelFile));
+ public void annotatorModelPreloaded() {
+ verifyModelPreloadedAsAsset(ModelType.ANNOTATOR, "textclassifier/annotator.universal.model");
+ }
- List<ModelFileManager.ModelFile> modelFiles = modelFileManager.listModelFiles(MODEL_TYPE);
+ @Test
+ public void actionsModelPreloaded() {
+ verifyModelPreloadedAsAsset(
+ ModelType.ACTIONS_SUGGESTIONS, "textclassifier/actions_suggestions.universal.model");
+ }
- assertThat(modelFiles).hasSize(1);
- assertThat(modelFiles.get(0)).isEqualTo(modelFile);
+ @Test
+ public void langIdModelPreloaded() {
+ verifyModelPreloadedAsAsset(ModelType.LANG_ID, "textclassifier/lang_id.model");
+ }
+
+ private void verifyModelPreloadedAsAsset(
+ @ModelTypeDef String modelType, String expectedModelPath) {
+ List<ModelFileManager.ModelFile> modelFiles = modelFileManager.listModelFiles(modelType);
+ List<ModelFile> assetFiles =
+ modelFiles.stream().filter(modelFile -> modelFile.isAsset).collect(Collectors.toList());
+
+ assertThat(assetFiles).hasSize(1);
+ assertThat(assetFiles.get(0).absolutePath).isEqualTo(expectedModelPath);
}
@Test
public void findBestModel_versionCode() {
ModelFileManager.ModelFile olderModelFile =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(), "", true);
-
+ MODEL_TYPE,
+ new File(rootTestDir, "a").getAbsolutePath(),
+ /* version= */ 1,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
ModelFileManager.ModelFile newerModelFile =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/b"), 2, ImmutableList.of(), "", true);
- when(modelFileSupplier.get()).thenReturn(ImmutableList.of(olderModelFile, newerModelFile));
+ MODEL_TYPE,
+ new File(rootTestDir, "b").getAbsolutePath(),
+ /* version= */ 2,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(olderModelFile, newerModelFile)));
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.getEmptyLocaleList());
-
+ ModelFile bestModelFile = modelFileManager.findBestModelFile(MODEL_TYPE, null);
assertThat(bestModelFile).isEqualTo(newerModelFile);
}
@Test
public void findBestModel_languageDependentModelIsPreferred() {
- Locale locale = Locale.forLanguageTag("ja");
ModelFileManager.ModelFile languageIndependentModelFile =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(), "", true);
-
+ MODEL_TYPE,
+ new File(rootTestDir, "a").getAbsolutePath(),
+ /* version= */ 1,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
ModelFileManager.ModelFile languageDependentModelFile =
new ModelFileManager.ModelFile(
MODEL_TYPE,
- new File("/path/b"),
- 1,
- Collections.singletonList(locale),
- locale.toLanguageTag(),
- false);
- when(modelFileSupplier.get())
- .thenReturn(ImmutableList.of(languageIndependentModelFile, languageDependentModelFile));
+ new File(rootTestDir, "b").getAbsolutePath(),
+ /* version= */ 2,
+ DEFAULT_LOCALE.toLanguageTag(),
+ /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(
+ modelType ->
+ ImmutableList.of(languageIndependentModelFile, languageDependentModelFile)));
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(
- MODEL_TYPE, LocaleList.forLanguageTags(locale.toLanguageTag()));
+ ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(MODEL_TYPE, new LocaleList(DEFAULT_LOCALE));
assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
}
@Test
public void findBestModel_noMatchedLanguageModel() {
- Locale locale = Locale.forLanguageTag("ja");
ModelFileManager.ModelFile languageIndependentModelFile =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(), "", true);
-
+ MODEL_TYPE,
+ new File(rootTestDir, "a").getAbsolutePath(),
+ /* version= */ 1,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
ModelFileManager.ModelFile languageDependentModelFile =
new ModelFileManager.ModelFile(
MODEL_TYPE,
- new File("/path/b"),
- 1,
- Collections.singletonList(locale),
- locale.toLanguageTag(),
- false);
-
- when(modelFileSupplier.get())
- .thenReturn(ImmutableList.of(languageIndependentModelFile, languageDependentModelFile));
-
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("zh-hk"));
- assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
- }
-
- @Test
- public void findBestModel_noMatchedLanguageModel_defaultLocaleModelExists() {
- ModelFileManager.ModelFile languageIndependentModelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(), "", true);
-
- ModelFileManager.ModelFile languageDependentModelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/b"),
- 1,
- Collections.singletonList(DEFAULT_LOCALE),
+ new File(rootTestDir, "b").getAbsolutePath(),
+ /* version= */ 2,
DEFAULT_LOCALE.toLanguageTag(),
- false);
-
- when(modelFileSupplier.get())
- .thenReturn(ImmutableList.of(languageIndependentModelFile, languageDependentModelFile));
+ /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(
+ modelType ->
+ ImmutableList.of(languageIndependentModelFile, languageDependentModelFile)));
ModelFileManager.ModelFile bestModelFile =
modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("zh-hk"));
@@ -209,23 +194,22 @@
ModelFileManager.ModelFile matchButOlderModel =
new ModelFileManager.ModelFile(
MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("fr")),
+ new File(rootTestDir, "a").getAbsolutePath(),
+ /* version= */ 1,
"fr",
- false);
-
+ /* isAsset= */ false);
ModelFileManager.ModelFile mismatchButNewerModel =
new ModelFileManager.ModelFile(
MODEL_TYPE,
- new File("/path/b"),
- 2,
- Collections.singletonList(Locale.forLanguageTag("ja")),
+ new File(rootTestDir, "b").getAbsolutePath(),
+ /* version= */ 1,
"ja",
- false);
-
- when(modelFileSupplier.get())
- .thenReturn(ImmutableList.of(matchButOlderModel, mismatchButNewerModel));
+ /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(
+ modelType -> ImmutableList.of(matchButOlderModel, mismatchButNewerModel)));
ModelFileManager.ModelFile bestModelFile =
modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("fr"));
@@ -233,21 +217,26 @@
}
@Test
- public void findBestModel_languageIsMoreImportantThanVersion_bestModelComesFirst() {
+ public void findBestModel_preferMatchedLocaleModel() {
ModelFileManager.ModelFile matchLocaleModel =
new ModelFileManager.ModelFile(
MODEL_TYPE,
- new File("/path/b"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
+ new File(rootTestDir, "a").getAbsolutePath(),
+ /* version= */ 1,
"ja",
- false);
-
+ /* isAsset= */ false);
ModelFileManager.ModelFile languageIndependentModel =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 2, ImmutableList.of(), "", true);
- when(modelFileSupplier.get())
- .thenReturn(ImmutableList.of(matchLocaleModel, languageIndependentModel));
+ MODEL_TYPE,
+ new File(rootTestDir, "b").getAbsolutePath(),
+ /* version= */ 1,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(
+ modelType -> ImmutableList.of(matchLocaleModel, languageIndependentModel)));
ModelFileManager.ModelFile bestModelFile =
modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("ja"));
@@ -256,9 +245,135 @@
}
@Test
+ public void deleteUnusedModelFiles_olderModelDeleted() throws Exception {
+ File model1 = new File(rootTestDir, "model1.fb");
+ model1.createNewFile();
+ File model2 = new File(rootTestDir, "model2.fb");
+ model2.createNewFile();
+ ModelFileManager.ModelFile modelFile1 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFileManager.ModelFile modelFile2 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 2, "ja", /* isAsset= */ false);
+ setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
+
+ modelFileManager.deleteUnusedModelFiles();
+
+ assertThat(model1.exists()).isFalse();
+ assertThat(model2.exists()).isTrue();
+ }
+
+ @Test
+ public void deleteUnusedModelFiles_languageIndependentOlderModelDeleted() throws Exception {
+ File model1 = new File(rootTestDir, "model1.fb");
+ model1.createNewFile();
+ File model2 = new File(rootTestDir, "model2.fb");
+ model2.createNewFile();
+ ModelFileManager.ModelFile modelFile1 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE,
+ model1.getAbsolutePath(),
+ /* version= */ 1,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
+ ModelFileManager.ModelFile modelFile2 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE,
+ model2.getAbsolutePath(),
+ /* version= */ 2,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
+ setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
+
+ modelFileManager.deleteUnusedModelFiles();
+
+ assertThat(model1.exists()).isFalse();
+ assertThat(model2.exists()).isTrue();
+ }
+
+ @Test
+ public void deleteUnusedModelFiles_modelOnlySupportingLocalesNotInListDeleted() throws Exception {
+ File model1 = new File(rootTestDir, "model1.fb");
+ model1.createNewFile();
+ File model2 = new File(rootTestDir, "model2.fb");
+ model2.createNewFile();
+ ModelFileManager.ModelFile modelFile1 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFileManager.ModelFile modelFile2 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 1, "en", /* isAsset= */ false);
+ setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
+
+ modelFileManager.deleteUnusedModelFiles();
+
+ assertThat(model1.exists()).isTrue();
+ assertThat(model2.exists()).isFalse();
+ }
+
+ @Test
+ public void deleteUnusedModelFiles_multiLocalesInLocaleList() throws Exception {
+ File model1 = new File(rootTestDir, "model1.fb");
+ model1.createNewFile();
+ File model2 = new File(rootTestDir, "model2.fb");
+ model2.createNewFile();
+ ModelFileManager.ModelFile modelFile1 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFileManager.ModelFile modelFile2 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 2, "en", /* isAsset= */ false);
+ setDefaultLocalesRule.set(
+ new LocaleList(Locale.forLanguageTag("ja"), Locale.forLanguageTag("en")));
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
+ modelFileManager.deleteUnusedModelFiles();
+
+ assertThat(model1.exists()).isTrue();
+ assertThat(model2.exists()).isTrue();
+ }
+
+ @Test
+ public void deleteUnusedModelFiles_readOnlyModelsUntouched() throws Exception {
+ File readOnlyDir = new File(rootTestDir, "read_only/");
+ readOnlyDir.mkdirs();
+ File model1 = new File(readOnlyDir, "model1.fb");
+ model1.createNewFile();
+ readOnlyDir.setWritable(false);
+ ModelFileManager.ModelFile modelFile =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(modelFile)));
+ setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
+
+ modelFileManager.deleteUnusedModelFiles();
+
+ assertThat(model1.exists()).isTrue();
+ }
+
+ @Test
public void getDownloadTargetFile_targetFileInCorrectDir() {
File targetFile = modelFileManager.getDownloadTargetFile(MODEL_TYPE, URL);
- assertThat(targetFile.getParentFile()).isEqualTo(downloaderModelDir);
+ assertThat(targetFile.getAbsolutePath())
+ .startsWith(ApplicationProvider.getApplicationContext().getFilesDir().getAbsolutePath());
}
@Test
@@ -278,21 +393,11 @@
public void modelFileEquals() {
ModelFileManager.ModelFile modelA =
new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
+ MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
ModelFileManager.ModelFile modelB =
new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
+ MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
assertThat(modelA).isEqualTo(modelB);
}
@@ -301,67 +406,23 @@
public void modelFile_different() {
ModelFileManager.ModelFile modelA =
new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
+ MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
ModelFileManager.ModelFile modelB =
new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/b"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
+ MODEL_TYPE, "/path/b", /* version= */ 1, "ja", /* isAsset= */ false);
assertThat(modelA).isNotEqualTo(modelB);
}
@Test
- public void modelFile_getPath() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- assertThat(modelA.getPath()).isEqualTo("/path/a");
- }
-
- @Test
- public void modelFile_getName() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- assertThat(modelA.getName()).isEqualTo("a");
- }
-
- @Test
public void modelFile_isPreferredTo_languageDependentIsBetter() {
ModelFileManager.ModelFile modelA =
new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
+ MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
ModelFileManager.ModelFile modelB =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/b"), 2, ImmutableList.of(), "", true);
+ MODEL_TYPE, "/path/b", /* version= */ 2, LANGUAGE_INDEPENDENT, /* isAsset= */ false);
assertThat(modelA.isPreferredTo(modelB)).isTrue();
}
@@ -370,16 +431,11 @@
public void modelFile_isPreferredTo_version() {
ModelFileManager.ModelFile modelA =
new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 2,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
+ MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
ModelFileManager.ModelFile modelB =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/b"), 1, ImmutableList.of(), "", false);
+ MODEL_TYPE, "/path/b", /* version= */ 1, "ja", /* isAsset= */ false);
assertThat(modelA.isPreferredTo(modelB)).isTrue();
}
@@ -388,7 +444,7 @@
public void modelFile_toModelInfo() {
ModelFileManager.ModelFile modelFile =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 2, ImmutableList.of(Locale.JAPANESE), "ja", false);
+ MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
ModelInfo modelInfo = modelFile.toModelInfo();
@@ -398,11 +454,9 @@
@Test
public void modelFile_toModelInfos() {
ModelFile englishModelFile =
- new ModelFile(
- MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(Locale.ENGLISH), "en", false);
+ new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 1, "en", /* isAsset= */ false);
ModelFile japaneseModelFile =
- new ModelFile(
- MODEL_TYPE, new File("/path/a"), 2, ImmutableList.of(Locale.JAPANESE), "ja", false);
+ new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
ImmutableList<Optional<ModelInfo>> modelInfos =
ModelFileManager.ModelFile.toModelInfos(
@@ -417,64 +471,53 @@
}
@Test
- public void testFileSupplierImpl_updatedFileOnly() throws IOException {
- when(mockDeviceConfig.getBoolean(
- eq(TextClassifierSettings.NAMESPACE),
- eq(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED),
- anyBoolean()))
- .thenReturn(false);
- configUpdaterModelFile.createNewFile();
- File downloaderModelFile = new File(downloaderModelDir, "test0.model");
- downloaderModelFile.createNewFile();
- File model1 = new File(factoryModelDir, "test1.model");
- model1.createNewFile();
- File model2 = new File(factoryModelDir, "test2.model");
- model2.createNewFile();
- new File(factoryModelDir, "not_match_regex.model").createNewFile();
+ public void regularFileFullMatchLister() throws IOException {
+ File modelFile = new File(rootTestDir, "test.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile);
+ File wrongFile = new File(rootTestDir, "wrong.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), wrongFile);
- List<ModelFileManager.ModelFile> modelFiles = modelFileSupplierImpl.get();
- List<String> modelFilePaths =
- modelFiles.stream().map(modelFile -> modelFile.getPath()).collect(Collectors.toList());
+ RegularFileFullMatchLister regularFileFullMatchLister =
+ new RegularFileFullMatchLister(MODEL_TYPE, modelFile, () -> true);
+ ImmutableList<ModelFile> listedModels = regularFileFullMatchLister.list(MODEL_TYPE);
- assertThat(modelFiles).hasSize(3);
- assertThat(modelFilePaths)
- .containsExactly(
- configUpdaterModelFile.getAbsolutePath(),
- model1.getAbsolutePath(),
- model2.getAbsolutePath());
+ assertThat(listedModels).hasSize(1);
+ assertThat(listedModels.get(0).absolutePath).isEqualTo(modelFile.getAbsolutePath());
+ assertThat(listedModels.get(0).isAsset).isFalse();
}
@Test
- public void testFileSupplierImpl_includeDownloaderFile() throws IOException {
- when(mockDeviceConfig.getBoolean(
- eq(TextClassifierSettings.NAMESPACE),
- eq(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED),
- anyBoolean()))
- .thenReturn(true);
- configUpdaterModelFile.createNewFile();
- File downloaderModelFile = new File(downloaderModelDir, "test0.model");
- downloaderModelFile.createNewFile();
- File factoryModelFile = new File(factoryModelDir, "test1.model");
- factoryModelFile.createNewFile();
+ public void regularFilePatternMatchLister() throws IOException {
+ File modelFile1 = new File(rootTestDir, "annotator.en.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
+ File modelFile2 = new File(rootTestDir, "annotator.fr.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile2);
+ File mismatchedModelFile = new File(rootTestDir, "actions.en.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), mismatchedModelFile);
- List<ModelFileManager.ModelFile> modelFiles = modelFileSupplierImpl.get();
- List<String> modelFilePaths =
- modelFiles.stream().map(ModelFile::getPath).collect(Collectors.toList());
+ RegularFilePatternMatchLister regularFilePatternMatchLister =
+ new RegularFilePatternMatchLister(
+ MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> true);
+ ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
- assertThat(modelFiles).hasSize(3);
- assertThat(modelFilePaths)
- .containsExactly(
- configUpdaterModelFile.getAbsolutePath(),
- downloaderModelFile.getAbsolutePath(),
- factoryModelFile.getAbsolutePath());
+ assertThat(listedModels).hasSize(2);
+ assertThat(listedModels.get(0).absolutePath).isEqualTo(modelFile1.getAbsolutePath());
+ assertThat(listedModels.get(0).isAsset).isFalse();
+ assertThat(listedModels.get(1).absolutePath).isEqualTo(modelFile2.getAbsolutePath());
+ assertThat(listedModels.get(1).isAsset).isFalse();
}
@Test
- public void testFileSupplierImpl_empty() {
- factoryModelDir.delete();
- List<ModelFileManager.ModelFile> modelFiles = modelFileSupplierImpl.get();
+ public void regularFilePatternMatchLister_disabled() throws IOException {
+ File modelFile1 = new File(rootTestDir, "annotator.en.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
- assertThat(modelFiles).hasSize(0);
+ RegularFilePatternMatchLister regularFilePatternMatchLister =
+ new RegularFilePatternMatchLister(
+ MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> false);
+ ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
+
+ assertThat(listedModels).isEmpty();
}
private static void recursiveDelete(File f) {
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
index 88c0ac8..7565a0b 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
@@ -16,14 +16,45 @@
package com.android.textclassifier;
+import android.content.Context;
+import com.android.textclassifier.ModelFileManager.ModelType;
+import com.android.textclassifier.ModelFileManager.RegularFileFullMatchLister;
+import com.google.common.collect.ImmutableList;
import java.io.File;
/** Utils to access test data files. */
public final class TestDataUtils {
+ private static final String TEST_ANNOTATOR_MODEL_PATH = "testdata/annotator.model";
+ private static final String TEST_ACTIONS_MODEL_PATH = "testdata/actions.model";
+ private static final String TEST_LANGID_MODEL_PATH = "testdata/langid.model";
+
/** Returns the root folder that contains the test data. */
public static File getTestDataFolder() {
return new File("/data/local/tmp/TextClassifierServiceTest/");
}
+ public static File getTestAnnotatorModelFile() {
+ return new File(getTestDataFolder(), TEST_ANNOTATOR_MODEL_PATH);
+ }
+
+ public static File getTestActionsModelFile() {
+ return new File(getTestDataFolder(), TEST_ACTIONS_MODEL_PATH);
+ }
+
+ public static File getLangIdModelFile() {
+ return new File(getTestDataFolder(), TEST_LANGID_MODEL_PATH);
+ }
+
+ public static ModelFileManager createModelFileManagerForTesting(Context context) {
+ return new ModelFileManager(
+ context,
+ ImmutableList.of(
+ new RegularFileFullMatchLister(
+ ModelType.ANNOTATOR, getTestAnnotatorModelFile(), () -> true),
+ new RegularFileFullMatchLister(
+ ModelType.ACTIONS_SUGGESTIONS, getTestActionsModelFile(), () -> true),
+ new RegularFileFullMatchLister(ModelType.LANG_ID, getLangIdModelFile(), () -> true)));
+ }
+
private TestDataUtils() {}
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
new file mode 100644
index 0000000..27ea7f0
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
@@ -0,0 +1,212 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.app.UiAutomation;
+import android.content.pm.PackageManager;
+import android.content.pm.PackageManager.NameNotFoundException;
+import android.icu.util.ULocale;
+import android.provider.DeviceConfig;
+import android.view.textclassifier.ConversationAction;
+import android.view.textclassifier.ConversationActions;
+import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextClassificationManager;
+import android.view.textclassifier.TextClassifier;
+import android.view.textclassifier.TextLanguage;
+import android.view.textclassifier.TextLinks;
+import android.view.textclassifier.TextLinks.TextLink;
+import android.view.textclassifier.TextSelection;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import androidx.test.platform.app.InstrumentationRegistry;
+import com.google.common.collect.ImmutableList;
+import java.util.ArrayList;
+import java.util.List;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExternalResource;
+import org.junit.runner.RunWith;
+
+/**
+ * End-to-end tests for the {@link TextClassifier} APIs.
+ *
+ * <p>Unlike {@link TextClassifierImplTest}, we are trying to run the tests in an environment that is
+ * closer to the production environment. For example, we are not injecting the model files.
+ */
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class TextClassifierApiTest {
+
+ private TextClassifier textClassifier;
+
+ @Rule
+ public final ExtServicesTextClassifierRule extServicesTextClassifierRule =
+ new ExtServicesTextClassifierRule();
+
+ @Before
+ public void setup() {
+ textClassifier = extServicesTextClassifierRule.getTextClassifier();
+ }
+
+ @Test
+ public void suggestSelection() {
+ String text = "Visit http://www.android.com for more information";
+ String selected = "http";
+ String suggested = "http://www.android.com";
+ int startIndex = text.indexOf(selected);
+ int endIndex = startIndex + selected.length();
+ int smartStartIndex = text.indexOf(suggested);
+ int smartEndIndex = smartStartIndex + suggested.length();
+
+ TextSelection.Request request =
+ new TextSelection.Request.Builder(text, startIndex, endIndex).build();
+
+ TextSelection selection = textClassifier.suggestSelection(request);
+ assertThat(selection.getEntityCount()).isGreaterThan(0);
+ assertThat(selection.getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
+ assertThat(selection.getSelectionStartIndex()).isEqualTo(smartStartIndex);
+ assertThat(selection.getSelectionEndIndex()).isEqualTo(smartEndIndex);
+ }
+
+ @Test
+ public void classifyText() {
+ String text = "Contact me at droid@android.com";
+ String classifiedText = "droid@android.com";
+ int startIndex = text.indexOf(classifiedText);
+ int endIndex = startIndex + classifiedText.length();
+ TextClassification.Request request =
+ new TextClassification.Request.Builder(text, startIndex, endIndex).build();
+
+ TextClassification classification = textClassifier.classifyText(request);
+ assertThat(classification.getEntityCount()).isGreaterThan(0);
+ assertThat(classification.getEntity(0)).isEqualTo(TextClassifier.TYPE_EMAIL);
+ assertThat(classification.getText()).isEqualTo(classifiedText);
+ assertThat(classification.getActions()).isNotEmpty();
+ }
+
+ @Test
+ public void generateLinks() {
+ String text = "Check this out, http://www.android.com";
+
+ TextLinks.Request request = new TextLinks.Request.Builder(text).build();
+
+ TextLinks textLinks = textClassifier.generateLinks(request);
+
+ List<TextLink> links = new ArrayList<>(textLinks.getLinks());
+ assertThat(textLinks.getText().toString()).isEqualTo(text);
+ assertThat(links).hasSize(1);
+ assertThat(links.get(0).getEntityCount()).isGreaterThan(0);
+ assertThat(links.get(0).getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
+ assertThat(links.get(0).getConfidenceScore(TextClassifier.TYPE_URL)).isGreaterThan(0f);
+ }
+
+ @Test
+ public void detectedLanguage() {
+ String text = "朝、ピカチュウ";
+ TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
+
+ TextLanguage textLanguage = textClassifier.detectLanguage(request);
+
+ assertThat(textLanguage.getLocaleHypothesisCount()).isGreaterThan(0);
+ assertThat(textLanguage.getLocale(0).getLanguage()).isEqualTo("ja");
+ assertThat(textLanguage.getConfidenceScore(ULocale.JAPANESE)).isGreaterThan(0f);
+ }
+
+ @Test
+ public void suggestConversationActions() {
+ ConversationActions.Message message =
+ new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
+ .setText("Check this out: https://www.android.com")
+ .build();
+ ConversationActions.Request request =
+ new ConversationActions.Request.Builder(ImmutableList.of(message)).build();
+
+ ConversationActions conversationActions = textClassifier.suggestConversationActions(request);
+
+ assertThat(conversationActions.getConversationActions()).hasSize(1);
+ ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
+ assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL);
+ assertThat(conversationAction.getAction()).isNotNull();
+ }
+
+ /** A rule that manages a text classifier that is backed by the ExtServices. */
+ private static class ExtServicesTextClassifierRule extends ExternalResource {
+ private static final String CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE =
+ "textclassifier_service_package_override";
+ private static final String PKG_NAME_GOOGLE_EXTSERVICES = "com.google.android.ext.services";
+ private static final String PKG_NAME_AOSP_EXTSERVICES = "android.ext.services";
+
+ private String textClassifierServiceOverrideFlagOldValue;
+
+ @Override
+ protected void before() {
+ UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
+ try {
+ uiAutomation.adoptShellPermissionIdentity();
+ textClassifierServiceOverrideFlagOldValue =
+ DeviceConfig.getString(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
+ null);
+ DeviceConfig.setProperty(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
+ getExtServicesPackageName(),
+ /* makeDefault= */ false);
+ } finally {
+ uiAutomation.dropShellPermissionIdentity();
+ }
+ }
+
+ @Override
+ protected void after() {
+ UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
+ try {
+ uiAutomation.adoptShellPermissionIdentity();
+ DeviceConfig.setProperty(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
+ textClassifierServiceOverrideFlagOldValue,
+ /* makeDefault= */ false);
+ } finally {
+ uiAutomation.dropShellPermissionIdentity();
+ }
+ }
+
+ private static String getExtServicesPackageName() {
+ PackageManager packageManager =
+ ApplicationProvider.getApplicationContext().getPackageManager();
+ try {
+ packageManager.getApplicationInfo(PKG_NAME_GOOGLE_EXTSERVICES, /* flags= */ 0);
+ return PKG_NAME_GOOGLE_EXTSERVICES;
+ } catch (NameNotFoundException e) {
+ return PKG_NAME_AOSP_EXTSERVICES;
+ }
+ }
+
+ public TextClassifier getTextClassifier() {
+ TextClassificationManager textClassificationManager =
+ ApplicationProvider.getApplicationContext()
+ .getSystemService(TextClassificationManager.class);
+ return textClassificationManager.getTextClassifier();
+ }
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
index 22674dd..06ec640 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
@@ -38,20 +38,16 @@
import android.view.textclassifier.TextLanguage;
import android.view.textclassifier.TextLinks;
import android.view.textclassifier.TextSelection;
+import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
-import com.android.textclassifier.ModelFileManager.ModelFile;
-import com.android.textclassifier.ModelFileManager.ModelFile.ModelType;
import com.android.textclassifier.testing.FakeContextBuilder;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import java.io.File;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
-import java.util.Locale;
-import java.util.function.Supplier;
import org.hamcrest.BaseMatcher;
import org.hamcrest.Description;
import org.hamcrest.Matcher;
@@ -68,51 +64,8 @@
private static final String NO_TYPE = null;
private TextClassifierImpl classifier;
- private static final ImmutableMap<String, Supplier<ImmutableList<ModelFile>>>
- MODEL_FILES_SUPPLIER =
- new ImmutableMap.Builder<String, Supplier<ImmutableList<ModelFile>>>()
- .put(
- ModelType.ANNOTATOR,
- () ->
- ImmutableList.of(
- new ModelFile(
- ModelType.ANNOTATOR,
- new File(
- TestDataUtils.getTestDataFolder(), "testdata/annotator.model"),
- 711,
- ImmutableList.of(Locale.ENGLISH),
- "en",
- false)))
- .put(
- ModelType.ACTIONS_SUGGESTIONS,
- (Supplier<ImmutableList<ModelFile>>)
- () ->
- ImmutableList.of(
- new ModelFile(
- ModelType.ACTIONS_SUGGESTIONS,
- new File(
- TestDataUtils.getTestDataFolder(), "testdata/actions.model"),
- 104,
- ImmutableList.of(Locale.ENGLISH),
- "en",
- false)))
- .put(
- ModelType.LANG_ID,
- (Supplier<ImmutableList<ModelFile>>)
- () ->
- ImmutableList.of(
- new ModelFile(
- ModelType.LANG_ID,
- new File(
- TestDataUtils.getTestDataFolder(), "testdata/langid.model"),
- 1,
- ImmutableList.of(),
- "*",
- true)))
- .build();
private final ModelFileManager modelFileManager =
- new ModelFileManager(
- /* downloadModelDir= */ TestDataUtils.getTestDataFolder(), MODEL_FILES_SUPPLIER);
+ TestDataUtils.createModelFileManagerForTesting(ApplicationProvider.getApplicationContext());
@Before
public void setup() {
@@ -126,7 +79,7 @@
}
@Test
- public void testSuggestSelection() {
+ public void testSuggestSelection() throws IOException {
String text = "Contact me at droid@android.com";
String selected = "droid";
String suggested = "droid@android.com";
@@ -145,7 +98,7 @@
}
@Test
- public void testSuggestSelection_url() {
+ public void testSuggestSelection_url() throws IOException {
String text = "Visit http://www.android.com for more information";
String selected = "http";
String suggested = "http://www.android.com";
@@ -163,7 +116,7 @@
}
@Test
- public void testSmartSelection_withEmoji() {
+ public void testSmartSelection_withEmoji() throws IOException {
String text = "\uD83D\uDE02 Hello.";
String selected = "Hello";
int startIndex = text.indexOf(selected);
@@ -178,7 +131,7 @@
}
@Test
- public void testClassifyText() {
+ public void testClassifyText() throws IOException {
String text = "Contact me at droid@android.com";
String classifiedText = "droid@android.com";
int startIndex = text.indexOf(classifiedText);
@@ -193,7 +146,7 @@
}
@Test
- public void testClassifyText_url() {
+ public void testClassifyText_url() throws IOException {
String text = "Visit www.android.com for more information";
String classifiedText = "www.android.com";
int startIndex = text.indexOf(classifiedText);
@@ -209,7 +162,7 @@
}
@Test
- public void testClassifyText_address() {
+ public void testClassifyText_address() throws IOException {
String text = "Brandschenkestrasse 110, Zürich, Switzerland";
TextClassification.Request request =
new TextClassification.Request.Builder(text, 0, text.length())
@@ -221,7 +174,7 @@
}
@Test
- public void testClassifyText_url_inCaps() {
+ public void testClassifyText_url_inCaps() throws IOException {
String text = "Visit HTTP://ANDROID.COM for more information";
String classifiedText = "HTTP://ANDROID.COM";
int startIndex = text.indexOf(classifiedText);
@@ -237,7 +190,7 @@
}
@Test
- public void testClassifyText_date() {
+ public void testClassifyText_date() throws IOException {
String text = "Let's meet on January 9, 2018.";
String classifiedText = "January 9, 2018";
int startIndex = text.indexOf(classifiedText);
@@ -258,7 +211,7 @@
}
@Test
- public void testClassifyText_datetime() {
+ public void testClassifyText_datetime() throws IOException {
String text = "Let's meet 2018/01/01 10:30:20.";
String classifiedText = "2018/01/01 10:30:20";
int startIndex = text.indexOf(classifiedText);
@@ -273,7 +226,7 @@
}
@Test
- public void testClassifyText_foreignText() {
+ public void testClassifyText_foreignText() throws IOException {
LocaleList originalLocales = LocaleList.getDefault();
LocaleList.setDefault(LocaleList.forLanguageTags("en"));
String japaneseText = "これは日本語のテキストです";
@@ -302,7 +255,7 @@
}
@Test
- public void testGenerateLinks_phone() {
+ public void testGenerateLinks_phone() throws IOException {
String text = "The number is +12122537077. See you tonight!";
TextLinks.Request request = new TextLinks.Request.Builder(text).build();
assertThat(
@@ -311,7 +264,7 @@
}
@Test
- public void testGenerateLinks_exclude() {
+ public void testGenerateLinks_exclude() throws IOException {
String text = "You want apple@banana.com. See you tonight!";
List<String> hints = ImmutableList.of();
List<String> included = ImmutableList.of();
@@ -327,7 +280,7 @@
}
@Test
- public void testGenerateLinks_explicit_address() {
+ public void testGenerateLinks_explicit_address() throws IOException {
String text = "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you!";
List<String> explicit = Arrays.asList(TextClassifier.TYPE_ADDRESS);
TextLinks.Request request =
@@ -342,7 +295,7 @@
}
@Test
- public void testGenerateLinks_exclude_override() {
+ public void testGenerateLinks_exclude_override() throws IOException {
String text = "You want apple@banana.com. See you tonight!";
List<String> hints = ImmutableList.of();
List<String> included = Arrays.asList(TextClassifier.TYPE_EMAIL);
@@ -358,7 +311,7 @@
}
@Test
- public void testGenerateLinks_maxLength() {
+ public void testGenerateLinks_maxLength() throws IOException {
char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength()];
Arrays.fill(manySpaces, ' ');
TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
@@ -367,7 +320,7 @@
}
@Test
- public void testApplyLinks_unsupportedCharacter() {
+ public void testApplyLinks_unsupportedCharacter() throws IOException {
Spannable url = new SpannableString("\u202Emoc.diordna.com");
TextLinks.Request request = new TextLinks.Request.Builder(url).build();
assertEquals(
@@ -384,7 +337,7 @@
}
@Test
- public void testGenerateLinks_entityData() {
+ public void testGenerateLinks_entityData() throws IOException {
String text = "The number is +12122537077.";
Bundle extras = new Bundle();
ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true);
@@ -401,7 +354,7 @@
}
@Test
- public void testGenerateLinks_entityData_disabled() {
+ public void testGenerateLinks_entityData_disabled() throws IOException {
String text = "The number is +12122537077.";
TextLinks.Request request = new TextLinks.Request.Builder(text).build();
@@ -414,7 +367,7 @@
}
@Test
- public void testDetectLanguage() {
+ public void testDetectLanguage() throws IOException {
String text = "This is English text";
TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
TextLanguage textLanguage = classifier.detectLanguage(request);
@@ -422,7 +375,7 @@
}
@Test
- public void testDetectLanguage_japanese() {
+ public void testDetectLanguage_japanese() throws IOException {
String text = "これは日本語のテキストです";
TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
TextLanguage textLanguage = classifier.detectLanguage(request);
@@ -430,7 +383,7 @@
}
@Test
- public void testSuggestConversationActions_textReplyOnly_maxOne() {
+ public void testSuggestConversationActions_textReplyOnly_maxOne() throws IOException {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("Where are you?")
@@ -454,7 +407,7 @@
}
@Test
- public void testSuggestConversationActions_textReplyOnly_noMax() {
+ public void testSuggestConversationActions_textReplyOnly_noMax() throws IOException {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("Where are you?")
@@ -477,7 +430,7 @@
}
@Test
- public void testSuggestConversationActions_openUrl() {
+ public void testSuggestConversationActions_openUrl() throws IOException {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("Check this out: https://www.android.com")
@@ -504,7 +457,7 @@
}
@Test
- public void testSuggestConversationActions_copy() {
+ public void testSuggestConversationActions_copy() throws IOException {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("Authentication code: 12345")
@@ -532,7 +485,7 @@
}
@Test
- public void testSuggestConversationActions_deduplicate() {
+ public void testSuggestConversationActions_deduplicate() throws IOException {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("a@android.com b@android.com")
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java
index c0a823e..e1e7982 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java
@@ -22,7 +22,7 @@
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
import androidx.test.platform.app.InstrumentationRegistry;
-import com.android.textclassifier.ModelFileManager.ModelFile.ModelType;
+import com.android.textclassifier.ModelFileManager.ModelType;
import java.util.function.Consumer;
import org.junit.After;
import org.junit.Before;
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/GenerateLinksLoggerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/GenerateLinksLoggerTest.java
index c2a911a..6c66dd5 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/GenerateLinksLoggerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/GenerateLinksLoggerTest.java
@@ -49,6 +49,8 @@
/** A statsd config ID, which is arbitrary. */
private static final long CONFIG_ID = 689777;
+ private static final long SHORT_TIMEOUT_MS = 1000;
+
private static final ModelInfo ANNOTATOR_MODEL =
new ModelInfo(1, ImmutableList.of(Locale.ENGLISH));
private static final ModelInfo LANGID_MODEL =
@@ -92,7 +94,7 @@
LATENCY_MS,
Optional.of(ANNOTATOR_MODEL),
Optional.of(LANGID_MODEL));
- ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
ImmutableList<TextLinkifyEvent> loggedEvents =
ImmutableList.copyOf(
@@ -157,7 +159,7 @@
LATENCY_MS,
Optional.of(ANNOTATOR_MODEL),
Optional.of(LANGID_MODEL));
- ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
ImmutableList<TextLinkifyEvent> loggedEvents =
ImmutableList.copyOf(
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
index b52509c..1bcd7b7 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
@@ -46,7 +46,6 @@
/** Util functions to make statsd testing easier by using adb shell cmd stats commands. */
public class StatsdTestUtils {
private static final String TAG = "StatsdTestUtils";
- private static final long LONG_WAIT_MS = 5000;
private StatsdTestUtils() {}
@@ -74,9 +73,10 @@
/**
* Extracts logged atoms from the report, sorted by logging time, and deletes the saved report.
*/
- public static ImmutableList<Atom> getLoggedAtoms(long configId) throws Exception {
+ public static ImmutableList<Atom> getLoggedAtoms(long configId, long timeoutMillis)
+ throws Exception {
// There is no callback to notify us the log is collected. So we do a wait here.
- Thread.sleep(LONG_WAIT_MS);
+ Thread.sleep(timeoutMillis);
ConfigMetricsReportList reportList = getAndRemoveReportList(configId);
assertThat(reportList.getReportsCount()).isEqualTo(1);
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierApiUsageLoggerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierApiUsageLoggerTest.java
new file mode 100644
index 0000000..b9b7a95
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierApiUsageLoggerTest.java
@@ -0,0 +1,156 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.statsd;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.os.Binder;
+import android.os.Parcel;
+import android.view.textclassifier.TextClassificationSessionId;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.LargeTest;
+import com.android.internal.os.StatsdConfigProto.StatsdConfig;
+import com.android.os.AtomsProto.Atom;
+import com.android.os.AtomsProto.TextClassifierApiUsageReported;
+import com.android.os.AtomsProto.TextClassifierApiUsageReported.ApiType;
+import com.android.os.AtomsProto.TextClassifierApiUsageReported.ResultType;
+import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger.Session;
+import com.google.common.collect.ImmutableList;
+import com.google.common.util.concurrent.MoreExecutors;
+import java.util.stream.Collectors;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@RunWith(AndroidJUnit4.class)
+@LargeTest
+public class TextClassifierApiUsageLoggerTest {
+ /** A statsd config ID, which is arbitrary. */
+ private static final long CONFIG_ID = 689777;
+
+ private static final long SHORT_TIMEOUT_MS = 1000;
+
+ private static final String SESSION_ID = "abcdef";
+
+ @Before
+ public void setup() throws Exception {
+ StatsdTestUtils.cleanup(CONFIG_ID);
+
+ StatsdConfig.Builder builder =
+ StatsdConfig.newBuilder()
+ .setId(CONFIG_ID)
+ .addAllowedLogSource(ApplicationProvider.getApplicationContext().getPackageName());
+ StatsdTestUtils.addAtomMatcher(builder, Atom.TEXT_CLASSIFIER_API_USAGE_REPORTED_FIELD_NUMBER);
+ StatsdTestUtils.pushConfig(builder.build());
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ StatsdTestUtils.cleanup(CONFIG_ID);
+ }
+
+ @Test
+ public void reportSuccess() throws Exception {
+ TextClassifierApiUsageLogger textClassifierApiUsageLogger =
+ new TextClassifierApiUsageLogger(
+ /* sampleRateSupplier= */ () -> 1, MoreExecutors.directExecutor());
+ Session session =
+ textClassifierApiUsageLogger.createSession(
+ TextClassifierApiUsageLogger.API_TYPE_SUGGEST_SELECTION,
+ createTextClassificationSessionId());
+ // so that the latency we log is greater than 0.
+ Thread.sleep(10);
+ session.reportSuccess();
+
+ ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
+
+ ImmutableList<TextClassifierApiUsageReported> loggedEvents =
+ ImmutableList.copyOf(
+ loggedAtoms.stream()
+ .map(Atom::getTextClassifierApiUsageReported)
+ .collect(Collectors.toList()));
+
+ assertThat(loggedEvents).hasSize(1);
+ TextClassifierApiUsageReported event = loggedEvents.get(0);
+ assertThat(event.getApiType()).isEqualTo(ApiType.SUGGEST_SELECTION);
+ assertThat(event.getResultType()).isEqualTo(ResultType.SUCCESS);
+ assertThat(event.getLatencyMillis()).isGreaterThan(0L);
+ assertThat(event.getSessionId()).isEqualTo(SESSION_ID);
+ }
+
+ @Test
+ public void reportFailure() throws Exception {
+ TextClassifierApiUsageLogger textClassifierApiUsageLogger =
+ new TextClassifierApiUsageLogger(
+ /* sampleRateSupplier= */ () -> 1, MoreExecutors.directExecutor());
+ Session session =
+ textClassifierApiUsageLogger.createSession(
+ TextClassifierApiUsageLogger.API_TYPE_CLASSIFY_TEXT,
+ createTextClassificationSessionId());
+ // so that the latency we log is greater than 0.
+ Thread.sleep(10);
+ session.reportFailure();
+
+ ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
+
+ ImmutableList<TextClassifierApiUsageReported> loggedEvents =
+ ImmutableList.copyOf(
+ loggedAtoms.stream()
+ .map(Atom::getTextClassifierApiUsageReported)
+ .collect(Collectors.toList()));
+
+ assertThat(loggedEvents).hasSize(1);
+ TextClassifierApiUsageReported event = loggedEvents.get(0);
+ assertThat(event.getApiType()).isEqualTo(ApiType.CLASSIFY_TEXT);
+ assertThat(event.getResultType()).isEqualTo(ResultType.FAIL);
+ assertThat(event.getLatencyMillis()).isGreaterThan(0L);
+ assertThat(event.getSessionId()).isEqualTo(SESSION_ID);
+ }
+
+ @Test
+ public void noLog_sampleRateZero() throws Exception {
+ TextClassifierApiUsageLogger textClassifierApiUsageLogger =
+ new TextClassifierApiUsageLogger(
+ /* sampleRateSupplier= */ () -> 0, MoreExecutors.directExecutor());
+ Session session =
+ textClassifierApiUsageLogger.createSession(
+ TextClassifierApiUsageLogger.API_TYPE_CLASSIFY_TEXT,
+ createTextClassificationSessionId());
+ session.reportSuccess();
+
+ ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
+
+ ImmutableList<TextClassifierApiUsageReported> loggedEvents =
+ ImmutableList.copyOf(
+ loggedAtoms.stream()
+ .map(Atom::getTextClassifierApiUsageReported)
+ .collect(Collectors.toList()));
+
+ assertThat(loggedEvents).isEmpty();
+ }
+
+ private static TextClassificationSessionId createTextClassificationSessionId() {
+ // Used a hack to create TextClassificationSessionId because its constructor is @hide.
+ Parcel parcel = Parcel.obtain();
+ parcel.writeString(SESSION_ID);
+ parcel.writeStrongBinder(new Binder());
+ parcel.setDataPosition(0);
+ return TextClassificationSessionId.CREATOR.createFromParcel(parcel);
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierEventLoggerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierEventLoggerTest.java
index 719fc31..f105e26 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierEventLoggerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierEventLoggerTest.java
@@ -45,6 +45,8 @@
/** A statsd config ID, which is arbitrary. */
private static final long CONFIG_ID = 689777;
+ private static final long SHORT_TIMEOUT_MS = 1000;
+
private TextClassifierEventLogger textClassifierEventLogger;
@Before
@@ -102,7 +104,7 @@
.setPackageName(PKG_NAME)
.setLangidModelName("und_v1")
.build();
- ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
assertThat(atoms).hasSize(1);
assertThat(atoms.get(0).getTextSelectionEvent()).isEqualTo(event);
}
@@ -119,7 +121,7 @@
textClassifierEventLogger.writeEvent(sessionId, textSelectionEvent);
- ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
assertThat(atoms).hasSize(1);
assertThat(atoms.get(0).getTextSelectionEvent().getEventType())
.isEqualTo(EventType.SMART_SELECTION_SINGLE);
@@ -137,7 +139,7 @@
textClassifierEventLogger.writeEvent(sessionId, textSelectionEvent);
- ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
assertThat(atoms).hasSize(1);
assertThat(atoms.get(0).getTextSelectionEvent().getEventType())
.isEqualTo(EventType.SMART_SELECTION_MULTI);
@@ -155,7 +157,7 @@
textClassifierEventLogger.writeEvent(sessionId, textSelectionEvent);
- ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
assertThat(atoms).hasSize(1);
assertThat(atoms.get(0).getTextSelectionEvent().getEventType())
.isEqualTo(EventType.AUTO_SELECTION);
@@ -189,7 +191,7 @@
.setPackageName(PKG_NAME)
.setLangidModelName("und_v1")
.build();
- ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
assertThat(atoms).hasSize(1);
assertThat(atoms.get(0).getTextLinkifyEvent()).isEqualTo(event);
}
@@ -223,7 +225,7 @@
.setAnnotatorModelName("zh_v2")
.setLangidModelName("und_v3")
.build();
- ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
assertThat(atoms).hasSize(1);
assertThat(atoms.get(0).getConversationActionsEvent()).isEqualTo(event);
}
@@ -254,7 +256,7 @@
.setActionIndex(1)
.setPackageName(PKG_NAME)
.build();
- ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
assertThat(atoms).hasSize(1);
assertThat(atoms.get(0).getLanguageDetectionEvent()).isEqualTo(event);
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/subjects/EntitySubject.java b/java/tests/instrumentation/src/com/android/textclassifier/subjects/EntitySubject.java
deleted file mode 100644
index 3585f87..0000000
--- a/java/tests/instrumentation/src/com/android/textclassifier/subjects/EntitySubject.java
+++ /dev/null
@@ -1,51 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.subjects;
-
-import static com.google.common.truth.Truth.assertAbout;
-
-import com.android.textclassifier.Entity;
-import com.google.common.truth.FailureMetadata;
-import com.google.common.truth.Subject;
-import javax.annotation.Nullable;
-
-/** Test helper for checking {@link com.android.textclassifier.Entity} results. */
-public final class EntitySubject extends Subject<EntitySubject, Entity> {
-
- private static final float TOLERANCE = 0.0001f;
-
- private final Entity entity;
-
- public static EntitySubject assertThat(@Nullable Entity entity) {
- return assertAbout(EntitySubject::new).that(entity);
- }
-
- private EntitySubject(FailureMetadata failureMetadata, @Nullable Entity entity) {
- super(failureMetadata, entity);
- this.entity = entity;
- }
-
- public void isMatchWithinTolerance(@Nullable Entity entity) {
- if (!entity.getEntityType().equals(this.entity.getEntityType())) {
- failWithActual("expected to have type", entity.getEntityType());
- }
- check("expected to have confidence score")
- .that(entity.getScore())
- .isWithin(TOLERANCE)
- .of(this.entity.getScore());
- }
-}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/SetDefaultLocalesRule.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/SetDefaultLocalesRule.java
new file mode 100644
index 0000000..ec1405b
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/testing/SetDefaultLocalesRule.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.testing;
+
+import android.os.LocaleList;
+import org.junit.rules.ExternalResource;
+
+public class SetDefaultLocalesRule extends ExternalResource {
+
+ private LocaleList originalValue;
+
+ @Override
+ protected void before() throws Throwable {
+ super.before();
+ originalValue = LocaleList.getDefault();
+ }
+
+ public void set(LocaleList newValue) {
+ LocaleList.setDefault(newValue);
+ }
+
+ @Override
+ protected void after() {
+ super.after();
+ LocaleList.setDefault(originalValue);
+ }
+}
diff --git a/jni/com/google/android/textclassifier/AnnotatorModel.java b/jni/com/google/android/textclassifier/AnnotatorModel.java
index d2001ed..47a369e 100644
--- a/jni/com/google/android/textclassifier/AnnotatorModel.java
+++ b/jni/com/google/android/textclassifier/AnnotatorModel.java
@@ -342,6 +342,8 @@
@Nullable private final String contactNickname;
@Nullable private final String contactEmailAddress;
@Nullable private final String contactPhoneNumber;
+ @Nullable private final String contactAccountType;
+ @Nullable private final String contactAccountName;
@Nullable private final String contactId;
@Nullable private final String appName;
@Nullable private final String appPackageName;
@@ -363,6 +365,8 @@
@Nullable String contactNickname,
@Nullable String contactEmailAddress,
@Nullable String contactPhoneNumber,
+ @Nullable String contactAccountType,
+ @Nullable String contactAccountName,
@Nullable String contactId,
@Nullable String appName,
@Nullable String appPackageName,
@@ -382,6 +386,8 @@
this.contactNickname = contactNickname;
this.contactEmailAddress = contactEmailAddress;
this.contactPhoneNumber = contactPhoneNumber;
+ this.contactAccountType = contactAccountType;
+ this.contactAccountName = contactAccountName;
this.contactId = contactId;
this.appName = appName;
this.appPackageName = appPackageName;
@@ -444,6 +450,16 @@
}
@Nullable
+ public String getContactAccountType() {
+ return contactAccountType;
+ }
+
+ @Nullable
+ public String getContactAccountName() {
+ return contactAccountName;
+ }
+
+ @Nullable
public String getContactId() {
return contactId;
}
@@ -550,22 +566,40 @@
public InputFragment(String text) {
this.text = text;
this.datetimeOptionsNullable = null;
+ this.boundingBoxTop = 0;
+ this.boundingBoxHeight = 0;
}
- public InputFragment(String text, DatetimeOptions datetimeOptions) {
+ public InputFragment(
+ String text,
+ DatetimeOptions datetimeOptions,
+ float boundingBoxTop,
+ float boundingBoxHeight) {
this.text = text;
this.datetimeOptionsNullable = datetimeOptions;
+ this.boundingBoxTop = boundingBoxTop;
+ this.boundingBoxHeight = boundingBoxHeight;
}
private final String text;
// The DatetimeOptions can't be Optional because the _api16 build of the TCLib SDK does not
// support java.util.Optional.
private final DatetimeOptions datetimeOptionsNullable;
+ private final float boundingBoxTop;
+ private final float boundingBoxHeight;
public String getText() {
return text;
}
+ public float getBoundingBoxTop() {
+ return boundingBoxTop;
+ }
+
+ public float getBoundingBoxHeight() {
+ return boundingBoxHeight;
+ }
+
public boolean hasDatetimeOptions() {
return datetimeOptionsNullable != null;
}
@@ -588,6 +622,7 @@
private final double userLocationLng;
private final float userLocationAccuracyMeters;
private final boolean usePodNer;
+ private final boolean useVocabAnnotator;
private SelectionOptions(
@Nullable String locales,
@@ -596,7 +631,8 @@
double userLocationLat,
double userLocationLng,
float userLocationAccuracyMeters,
- boolean usePodNer) {
+ boolean usePodNer,
+ boolean useVocabAnnotator) {
this.locales = locales;
this.detectedTextLanguageTags = detectedTextLanguageTags;
this.annotationUsecase = annotationUsecase;
@@ -604,6 +640,7 @@
this.userLocationLng = userLocationLng;
this.userLocationAccuracyMeters = userLocationAccuracyMeters;
this.usePodNer = usePodNer;
+ this.useVocabAnnotator = useVocabAnnotator;
}
/** Can be used to build a SelectionsOptions instance. */
@@ -615,6 +652,7 @@
private double userLocationLng = INVALID_LONGITUDE;
private float userLocationAccuracyMeters = INVALID_LOCATION_ACCURACY_METERS;
private boolean usePodNer = true;
+ private boolean useVocabAnnotator = true;
public Builder setLocales(@Nullable String locales) {
this.locales = locales;
@@ -651,6 +689,11 @@
return this;
}
+ public Builder setUseVocabAnnotator(boolean useVocabAnnotator) {
+ this.useVocabAnnotator = useVocabAnnotator;
+ return this;
+ }
+
public SelectionOptions build() {
return new SelectionOptions(
locales,
@@ -659,7 +702,8 @@
userLocationLat,
userLocationLng,
userLocationAccuracyMeters,
- usePodNer);
+ usePodNer,
+ useVocabAnnotator);
}
}
@@ -697,6 +741,10 @@
public boolean getUsePodNer() {
return usePodNer;
}
+
+ public boolean getUseVocabAnnotator() {
+ return useVocabAnnotator;
+ }
}
/** Represents options for the classifyText call. */
@@ -712,6 +760,7 @@
private final String userFamiliarLanguageTags;
private final boolean usePodNer;
private final boolean triggerDictionaryOnBeginnerWords;
+ private final boolean useVocabAnnotator;
private ClassificationOptions(
long referenceTimeMsUtc,
@@ -724,7 +773,8 @@
float userLocationAccuracyMeters,
String userFamiliarLanguageTags,
boolean usePodNer,
- boolean triggerDictionaryOnBeginnerWords) {
+ boolean triggerDictionaryOnBeginnerWords,
+ boolean useVocabAnnotator) {
this.referenceTimeMsUtc = referenceTimeMsUtc;
this.referenceTimezone = referenceTimezone;
this.locales = locales;
@@ -736,6 +786,7 @@
this.userFamiliarLanguageTags = userFamiliarLanguageTags;
this.usePodNer = usePodNer;
this.triggerDictionaryOnBeginnerWords = triggerDictionaryOnBeginnerWords;
+ this.useVocabAnnotator = useVocabAnnotator;
}
/** Can be used to build a ClassificationOptions instance. */
@@ -751,6 +802,7 @@
private String userFamiliarLanguageTags = "";
private boolean usePodNer = true;
private boolean triggerDictionaryOnBeginnerWords = false;
+ private boolean useVocabAnnotator = true;
public Builder setReferenceTimeMsUtc(long referenceTimeMsUtc) {
this.referenceTimeMsUtc = referenceTimeMsUtc;
@@ -808,6 +860,11 @@
return this;
}
+ public Builder setUseVocabAnnotator(boolean useVocabAnnotator) {
+ this.useVocabAnnotator = useVocabAnnotator;
+ return this;
+ }
+
public ClassificationOptions build() {
return new ClassificationOptions(
referenceTimeMsUtc,
@@ -820,7 +877,8 @@
userLocationAccuracyMeters,
userFamiliarLanguageTags,
usePodNer,
- triggerDictionaryOnBeginnerWords);
+ triggerDictionaryOnBeginnerWords,
+ useVocabAnnotator);
}
}
@@ -874,6 +932,10 @@
public boolean getTriggerDictionaryOnBeginnerWords() {
return triggerDictionaryOnBeginnerWords;
}
+
+ public boolean getUseVocabAnnotator() {
+ return useVocabAnnotator;
+ }
}
/** Represents options for the annotate call. */
@@ -893,6 +955,7 @@
private final float userLocationAccuracyMeters;
private final boolean usePodNer;
private final boolean triggerDictionaryOnBeginnerWords;
+ private final boolean useVocabAnnotator;
private AnnotationOptions(
long referenceTimeMsUtc,
@@ -909,7 +972,8 @@
double userLocationLng,
float userLocationAccuracyMeters,
boolean usePodNer,
- boolean triggerDictionaryOnBeginnerWords) {
+ boolean triggerDictionaryOnBeginnerWords,
+ boolean useVocabAnnotator) {
this.referenceTimeMsUtc = referenceTimeMsUtc;
this.referenceTimezone = referenceTimezone;
this.locales = locales;
@@ -925,6 +989,7 @@
this.hasPersonalizationPermission = hasPersonalizationPermission;
this.usePodNer = usePodNer;
this.triggerDictionaryOnBeginnerWords = triggerDictionaryOnBeginnerWords;
+ this.useVocabAnnotator = useVocabAnnotator;
}
/** Can be used to build an AnnotationOptions instance. */
@@ -944,6 +1009,7 @@
private float userLocationAccuracyMeters = INVALID_LOCATION_ACCURACY_METERS;
private boolean usePodNer = true;
private boolean triggerDictionaryOnBeginnerWords = false;
+ private boolean useVocabAnnotator = true;
public Builder setReferenceTimeMsUtc(long referenceTimeMsUtc) {
this.referenceTimeMsUtc = referenceTimeMsUtc;
@@ -1020,6 +1086,11 @@
return this;
}
+ public Builder setUseVocabAnnotator(boolean useVocabAnnotator) {
+ this.useVocabAnnotator = useVocabAnnotator;
+ return this;
+ }
+
public AnnotationOptions build() {
return new AnnotationOptions(
referenceTimeMsUtc,
@@ -1036,7 +1107,8 @@
userLocationLng,
userLocationAccuracyMeters,
usePodNer,
- triggerDictionaryOnBeginnerWords);
+ triggerDictionaryOnBeginnerWords,
+ useVocabAnnotator);
}
}
@@ -1106,6 +1178,10 @@
public boolean getTriggerDictionaryOnBeginnerWords() {
return triggerDictionaryOnBeginnerWords;
}
+
+ public boolean getUseVocabAnnotator() {
+ return useVocabAnnotator;
+ }
}
/**
diff --git a/jni/com/google/android/textclassifier/LangIdModel.java b/jni/com/google/android/textclassifier/LangIdModel.java
index 0015826..890c9b0 100644
--- a/jni/com/google/android/textclassifier/LangIdModel.java
+++ b/jni/com/google/android/textclassifier/LangIdModel.java
@@ -16,6 +16,7 @@
package com.google.android.textclassifier;
+import android.content.res.AssetFileDescriptor;
import java.util.concurrent.atomic.AtomicBoolean;
/**
@@ -48,6 +49,29 @@
}
}
+ /**
+ * Creates a new instance of LangId predictor, using the provided model image, given as an {@link
+ * AssetFileDescriptor}.
+ */
+ public LangIdModel(AssetFileDescriptor assetFileDescriptor) {
+ modelPtr =
+ nativeNewWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength());
+ if (modelPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize LangId from asset file descriptor.");
+ }
+ }
+
+ /** Creates a new instance of LangId predictor, using the provided model image. */
+ public LangIdModel(int fd, long offset, long size) {
+ modelPtr = nativeNewWithOffset(fd, offset, size);
+ if (modelPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize LangId from given file descriptor.");
+ }
+ }
+
/** Detects the languages for given text. */
public LanguageResult[] detectLanguages(String text) {
return nativeDetectLanguages(modelPtr, text);
@@ -95,14 +119,22 @@
return nativeGetVersion(modelPtr);
}
- public float getLangIdThreshold() {
- return nativeGetLangIdThreshold(modelPtr);
- }
-
public static int getVersion(int fd) {
return nativeGetVersionFromFd(fd);
}
+ /** Returns the version of the model. */
+ public static int getVersion(AssetFileDescriptor assetFileDescriptor) {
+ return nativeGetVersionWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength());
+ }
+
+ public float getLangIdThreshold() {
+ return nativeGetLangIdThreshold(modelPtr);
+ }
+
/** Retrieves the pointer to the native object. */
long getNativePointer() {
return modelPtr;
@@ -130,6 +162,8 @@
private static native long nativeNewFromPath(String path);
+ private static native long nativeNewWithOffset(int fd, long offset, long size);
+
private native LanguageResult[] nativeDetectLanguages(long nativePtr, String text);
private native void nativeClose(long nativePtr);
@@ -143,4 +177,6 @@
private native float nativeGetLangIdNoiseThreshold(long nativePtr);
private native int nativeGetMinTextSizeInBytes(long nativePtr);
+
+ private static native int nativeGetVersionWithOffset(int fd, long offset, long size);
}
diff --git a/native/Android.bp b/native/Android.bp
index aeac511..e9cbe13 100644
--- a/native/Android.bp
+++ b/native/Android.bp
@@ -89,7 +89,7 @@
"-DTC3_UNILIB_JAVAICU",
"-DTC3_CALENDAR_JAVAICU",
"-DTC3_AOSP",
- "-DTC3_VOCAB_ANNOTATOR_DUMMY"
+ "-DTC3_VOCAB_ANNOTATOR_DUMMY",
],
product_variables: {
@@ -204,8 +204,6 @@
],
compile_multilib: "prefer32",
-
- sdk_variant_only: true
}
// ------------------------------------
@@ -216,6 +214,7 @@
defaults: ["libtextclassifier_defaults"],
srcs: [
":libtextclassifier_java_test_sources",
+ "annotator/datetime/testing/*.cc",
"actions/test-utils.cc",
"utils/testing/annotator.cc",
"utils/testing/logging_event_listener.cc",
@@ -250,6 +249,7 @@
],
jni_uses_sdk_apis: true,
data: [
+ "**/*.bfbs",
"**/test_data/*",
],
test_config: "JavaTest.xml",
diff --git a/native/FlatBufferHeaders.bp b/native/FlatBufferHeaders.bp
index ba9f677..6248d2a 100644
--- a/native/FlatBufferHeaders.bp
+++ b/native/FlatBufferHeaders.bp
@@ -43,20 +43,6 @@
}
genrule {
- name: "libtextclassifier_fbgen_annotator_grammar_dates_timezone-code",
- srcs: ["annotator/grammar/dates/timezone-code.fbs"],
- out: ["annotator/grammar/dates/timezone-code_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_annotator_grammar_dates_dates",
- srcs: ["annotator/grammar/dates/dates.fbs"],
- out: ["annotator/grammar/dates/dates_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
name: "libtextclassifier_fbgen_annotator_model",
srcs: ["annotator/model.fbs"],
out: ["annotator/model_generated.h"],
@@ -78,6 +64,13 @@
}
genrule {
+ name: "libtextclassifier_fbgen_annotator_datetime_datetime",
+ srcs: ["annotator/datetime/datetime.fbs"],
+ out: ["annotator/datetime/datetime_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
name: "libtextclassifier_fbgen_annotator_entity-data",
srcs: ["annotator/entity-data.fbs"],
out: ["annotator/entity-data_generated.h"],
@@ -85,9 +78,16 @@
}
genrule {
- name: "libtextclassifier_fbgen_utils_grammar_next_semantics_expression",
- srcs: ["utils/grammar/next/semantics/expression.fbs"],
- out: ["utils/grammar/next/semantics/expression_generated.h"],
+ name: "libtextclassifier_fbgen_utils_grammar_testing_value",
+ srcs: ["utils/grammar/testing/value.fbs"],
+ out: ["utils/grammar/testing/value_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_grammar_semantics_expression",
+ srcs: ["utils/grammar/semantics/expression.fbs"],
+ out: ["utils/grammar/semantics/expression_generated.h"],
defaults: ["fbgen"],
}
@@ -182,13 +182,13 @@
"libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
"libtextclassifier_fbgen_actions_actions_model",
"libtextclassifier_fbgen_actions_actions-entity-data",
- "libtextclassifier_fbgen_annotator_grammar_dates_timezone-code",
- "libtextclassifier_fbgen_annotator_grammar_dates_dates",
"libtextclassifier_fbgen_annotator_model",
"libtextclassifier_fbgen_annotator_person_name_person_name_model",
"libtextclassifier_fbgen_annotator_experimental_experimental",
+ "libtextclassifier_fbgen_annotator_datetime_datetime",
"libtextclassifier_fbgen_annotator_entity-data",
- "libtextclassifier_fbgen_utils_grammar_next_semantics_expression",
+ "libtextclassifier_fbgen_utils_grammar_testing_value",
+ "libtextclassifier_fbgen_utils_grammar_semantics_expression",
"libtextclassifier_fbgen_utils_grammar_rules",
"libtextclassifier_fbgen_utils_normalization",
"libtextclassifier_fbgen_utils_resources",
@@ -206,13 +206,13 @@
"libtextclassifier_fbgen_lang_id_common_flatbuffers_embedding-network",
"libtextclassifier_fbgen_actions_actions_model",
"libtextclassifier_fbgen_actions_actions-entity-data",
- "libtextclassifier_fbgen_annotator_grammar_dates_timezone-code",
- "libtextclassifier_fbgen_annotator_grammar_dates_dates",
"libtextclassifier_fbgen_annotator_model",
"libtextclassifier_fbgen_annotator_person_name_person_name_model",
"libtextclassifier_fbgen_annotator_experimental_experimental",
+ "libtextclassifier_fbgen_annotator_datetime_datetime",
"libtextclassifier_fbgen_annotator_entity-data",
- "libtextclassifier_fbgen_utils_grammar_next_semantics_expression",
+ "libtextclassifier_fbgen_utils_grammar_testing_value",
+ "libtextclassifier_fbgen_utils_grammar_semantics_expression",
"libtextclassifier_fbgen_utils_grammar_rules",
"libtextclassifier_fbgen_utils_normalization",
"libtextclassifier_fbgen_utils_resources",
diff --git a/native/JavaTests.bp b/native/JavaTests.bp
index af2ae1c..1c5099d 100644
--- a/native/JavaTests.bp
+++ b/native/JavaTests.bp
@@ -17,14 +17,27 @@
filegroup {
name: "libtextclassifier_java_test_sources",
srcs: [
+ "actions/actions-suggestions_test.cc",
"actions/grammar-actions_test.cc",
- "annotator/datetime/parser_test.cc",
+ "annotator/datetime/regex-parser_test.cc",
+ "utils/grammar/parsing/lexer_test.cc",
+ "utils/regex-match_test.cc",
+ "utils/calendar/calendar_test.cc",
"utils/intents/intent-generator-test-lib.cc",
"annotator/grammar/grammar-annotator_test.cc",
"annotator/grammar/test-utils.cc",
"annotator/number/number_test-include.cc",
"annotator/annotator_test-include.cc",
"utils/utf8/unilib_test-include.cc",
- "utils/calendar/calendar_test-include.cc",
+ "utils/grammar/parsing/parser_test.cc",
+ "utils/grammar/analyzer_test.cc",
+ "utils/grammar/semantics/composer_test.cc",
+ "utils/grammar/semantics/evaluators/merge-values-eval_test.cc",
+ "utils/grammar/semantics/evaluators/constituent-eval_test.cc",
+ "utils/grammar/semantics/evaluators/parse-number-eval_test.cc",
+ "utils/grammar/semantics/evaluators/arithmetic-eval_test.cc",
+ "utils/grammar/semantics/evaluators/span-eval_test.cc",
+ "utils/grammar/semantics/evaluators/const-eval_test.cc",
+ "utils/grammar/semantics/evaluators/compose-eval_test.cc",
],
}
diff --git a/native/actions/actions-entity-data.bfbs b/native/actions/actions-entity-data.bfbs
index d3f13e4..7421579 100644
--- a/native/actions/actions-entity-data.bfbs
+++ b/native/actions/actions-entity-data.bfbs
Binary files differ
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index f550cc7..a9edde9 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -1341,7 +1341,13 @@
// Check that messages are valid utf8.
for (const ConversationMessage& message : conversation.messages) {
- if (!IsValidUTF8(message.text.data(), message.text.size())) {
+ if (message.text.size() > std::numeric_limits<int>::max()) {
+ TC3_LOG(ERROR) << "Rejecting too long input: " << message.text.size();
+ return {};
+ }
+
+ if (!unilib_->IsValidUtf8(UTF8ToUnicodeText(
+ message.text.data(), message.text.size(), /*do_copy=*/false))) {
TC3_LOG(ERROR) << "Not valid utf8 provided.";
return response;
}
diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc
index ed92981..55aa852 100644
--- a/native/actions/actions-suggestions_test.cc
+++ b/native/actions/actions-suggestions_test.cc
@@ -19,6 +19,7 @@
#include <fstream>
#include <iterator>
#include <memory>
+#include <string>
#include "actions/actions_model_generated.h"
#include "actions/test-utils.h"
@@ -30,6 +31,7 @@
#include "utils/flatbuffers/mutable.h"
#include "utils/grammar/utils/rules.h"
#include "utils/hash/farmhash.h"
+#include "utils/jvm-test-utils.h"
#include "utils/test-data-test-utils.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -46,6 +48,8 @@
using ::testing::SizeIs;
constexpr char kModelFileName[] = "actions_suggestions_test.model";
+constexpr char kModelGrammarFileName[] =
+ "actions_suggestions_grammar_test.model";
constexpr char kMultiTaskModelFileName[] =
"actions_suggestions_test.multi_task_9heads.model";
constexpr char kHashGramModelFileName[] =
@@ -60,28 +64,30 @@
class ActionsSuggestionsTest : public testing::Test {
protected:
- ActionsSuggestionsTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
- std::unique_ptr<ActionsSuggestions> LoadTestModel() {
- return ActionsSuggestions::FromPath(GetModelPath() + kModelFileName,
- &unilib_);
+ explicit ActionsSuggestionsTest() : unilib_(CreateUniLibForTesting()) {}
+ std::unique_ptr<ActionsSuggestions> LoadTestModel(
+ const std::string model_file_name) {
+ return ActionsSuggestions::FromPath(GetModelPath() + model_file_name,
+ unilib_.get());
}
std::unique_ptr<ActionsSuggestions> LoadHashGramTestModel() {
return ActionsSuggestions::FromPath(GetModelPath() + kHashGramModelFileName,
- &unilib_);
+ unilib_.get());
}
std::unique_ptr<ActionsSuggestions> LoadMultiTaskTestModel() {
return ActionsSuggestions::FromPath(
- GetModelPath() + kMultiTaskModelFileName, &unilib_);
+ GetModelPath() + kMultiTaskModelFileName, unilib_.get());
}
- UniLib unilib_;
+ std::unique_ptr<UniLib> unilib_;
};
TEST_F(ActionsSuggestionsTest, InstantiateActionSuggestions) {
- EXPECT_THAT(LoadTestModel(), NotNull());
+ EXPECT_THAT(LoadTestModel(kModelFileName), NotNull());
}
TEST_F(ActionsSuggestionsTest, ProducesEmptyResponseOnInvalidInput) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "Where are you?\xf0\x9f",
@@ -91,8 +97,23 @@
EXPECT_THAT(response.actions, IsEmpty());
}
+TEST_F(ActionsSuggestionsTest, ProducesEmptyResponseOnInvalidUtf8) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
+
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1,
+ "(857) 225-3556 \xed\xa0\x80\xed\xa0\x80\xed\xa0\x80\xed\xa0\x80",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}});
+ EXPECT_THAT(response.actions, IsEmpty());
+}
+
TEST_F(ActionsSuggestionsTest, SuggestsActions) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
@@ -102,7 +123,8 @@
}
TEST_F(ActionsSuggestionsTest, SuggestsNoActionsForUnknownLocale) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
@@ -112,7 +134,8 @@
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromAnnotations) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
AnnotatedSpan annotation;
annotation.span = {11, 15};
annotation.classification = {ClassificationResult("address", 1.0)};
@@ -156,7 +179,7 @@
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
+ builder.GetSize(), unilib_.get());
AnnotatedSpan annotation;
annotation.span = {11, 15};
@@ -213,7 +236,7 @@
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
+ builder.GetSize(), unilib_.get());
AnnotatedSpan annotation;
annotation.span = {11, 15};
@@ -239,7 +262,8 @@
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromDuplicatedAnnotations) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
AnnotatedSpan flight_annotation;
flight_annotation.span = {11, 15};
flight_annotation.classification = {ClassificationResult("flight", 2.5)};
@@ -280,7 +304,7 @@
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
+ builder.GetSize(), unilib_.get());
AnnotatedSpan flight_annotation;
flight_annotation.span = {11, 15};
flight_annotation.classification = {ClassificationResult("flight", 2.5)};
@@ -380,7 +404,7 @@
actions_model->annotation_actions_spec->max_history_from_last_person =
1;
},
- &unilib_);
+ unilib_.get());
EXPECT_THAT(response.actions, SizeIs(1));
EXPECT_EQ(response.actions[0].type, "track_flight");
}
@@ -395,7 +419,7 @@
actions_model->annotation_actions_spec->max_history_from_last_person =
3;
},
- &unilib_);
+ unilib_.get());
EXPECT_THAT(response.actions, SizeIs(2));
EXPECT_EQ(response.actions[0].type, "track_flight");
EXPECT_EQ(response.actions[1].type, "send_email");
@@ -411,7 +435,7 @@
actions_model->annotation_actions_spec->max_history_from_last_person =
1;
},
- &unilib_);
+ unilib_.get());
EXPECT_THAT(response.actions, SizeIs(2));
EXPECT_EQ(response.actions[0].type, "track_flight");
EXPECT_EQ(response.actions[1].type, "send_email");
@@ -428,7 +452,7 @@
actions_model->annotation_actions_spec->max_history_from_last_person =
1;
},
- &unilib_);
+ unilib_.get());
EXPECT_THAT(response.actions, SizeIs(3));
EXPECT_EQ(response.actions[0].type, "track_flight");
EXPECT_EQ(response.actions[1].type, "send_email");
@@ -446,7 +470,7 @@
actions_model->annotation_actions_spec->max_history_from_last_person =
1;
},
- &unilib_);
+ unilib_.get());
EXPECT_THAT(response.actions, SizeIs(3));
EXPECT_EQ(response.actions[0].type, "track_flight");
EXPECT_EQ(response.actions[1].type, "send_email");
@@ -464,7 +488,7 @@
actions_model->annotation_actions_spec->max_history_from_last_person =
1;
},
- &unilib_);
+ unilib_.get());
EXPECT_THAT(response.actions, SizeIs(4));
EXPECT_EQ(response.actions[0].type, "track_flight");
EXPECT_EQ(response.actions[1].type, "send_email");
@@ -503,7 +527,7 @@
[](ActionsModelT* actions_model) {
actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
},
- &unilib_,
+ unilib_.get(),
/*expected_size=*/1 /*no smart reply, only actions*/
);
}
@@ -513,7 +537,7 @@
[](ActionsModelT* actions_model) {
actions_model->preconditions->min_reply_score_threshold = 1.0;
},
- &unilib_,
+ unilib_.get(),
/*expected_size=*/1 /*no smart reply, only actions*/
);
}
@@ -523,7 +547,7 @@
[](ActionsModelT* actions_model) {
actions_model->preconditions->max_sensitive_topic_score = 0.0;
},
- &unilib_,
+ unilib_.get(),
/*expected_size=*/4 /* no sensitive prediction in test model*/);
}
@@ -532,7 +556,7 @@
[](ActionsModelT* actions_model) {
actions_model->preconditions->max_input_length = 0;
},
- &unilib_);
+ unilib_.get());
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMinInputLength) {
@@ -540,7 +564,7 @@
[](ActionsModelT* actions_model) {
actions_model->preconditions->min_input_length = 100;
},
- &unilib_);
+ unilib_.get());
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithPreconditionsOverwrite) {
@@ -551,7 +575,7 @@
TriggeringPreconditions::Pack(builder, &preconditions_overwrite));
TestSuggestActionsWithThreshold(
// Keep model untouched.
- [](ActionsModelT* actions_model) {}, &unilib_,
+ [](ActionsModelT* actions_model) {}, unilib_.get(),
/*expected_size=*/0,
std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
builder.GetSize()));
@@ -568,7 +592,7 @@
actions_model->low_confidence_rules->regex_rule.back()->pattern =
"low-ground";
},
- &unilib_);
+ unilib_.get());
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsLowConfidenceInputOutput) {
@@ -617,7 +641,7 @@
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
+ builder.GetSize(), unilib_.get());
ASSERT_TRUE(actions_suggestions);
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
@@ -686,7 +710,7 @@
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, serialize_preconditions);
+ builder.GetSize(), unilib_.get(), serialize_preconditions);
ASSERT_TRUE(actions_suggestions);
const ActionsSuggestionsResponse response =
@@ -719,7 +743,7 @@
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
+ builder.GetSize(), unilib_.get());
AnnotatedSpan annotation;
annotation.span = {11, 15};
annotation.classification = {
@@ -749,7 +773,7 @@
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
+ builder.GetSize(), unilib_.get());
AnnotatedSpan annotation;
annotation.span = {11, 15};
annotation.classification = {
@@ -770,8 +794,31 @@
EXPECT_EQ(response.actions[0].score, 1.0);
}
+TEST_F(ActionsSuggestionsTest, SuggestsActionsFromPhoneGrammarAnnotations) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelGrammarFileName);
+ AnnotatedSpan annotation;
+ annotation.span = {11, 15};
+ annotation.classification = {ClassificationResult("phone", 0.0)};
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "Contact us at: *1234",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{annotation},
+ /*locales=*/"en"}}});
+ ASSERT_GE(response.actions.size(), 1);
+ EXPECT_EQ(response.actions.front().type, "call_phone");
+ EXPECT_EQ(response.actions.front().score, 0.0);
+ EXPECT_EQ(response.actions.front().priority_score, 0.0);
+ EXPECT_EQ(response.actions.front().annotations.size(), 1);
+ EXPECT_EQ(response.actions.front().annotations.front().span.span.first, 15);
+ EXPECT_EQ(response.actions.front().annotations.front().span.span.second, 20);
+}
+
TEST_F(ActionsSuggestionsTest, CreateActionsFromClassificationResult) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
AnnotatedSpan annotation;
annotation.span = {8, 12};
annotation.classification = {
@@ -849,7 +896,7 @@
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
+ builder.GetSize(), unilib_.get());
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
@@ -913,7 +960,7 @@
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
+ builder.GetSize(), unilib_.get());
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
@@ -963,7 +1010,7 @@
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
+ builder.GetSize(), unilib_.get());
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
@@ -996,11 +1043,11 @@
// Setup test rules.
action_grammar_rules->rules.reset(new grammar::RulesSetT);
grammar::Rules rules;
- rules.Add("<knock>", {"<^>", "ventura", "!?", "<$>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
- /*callback_param=*/0);
+ rules.Add(
+ "<knock>", {"<^>", "ventura", "!?", "<$>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/0);
rules.Finalize().Serialize(/*include_debug_information=*/false,
action_grammar_rules->rules.get());
action_grammar_rules->actions.emplace_back(new RulesModel_::RuleActionSpecT);
@@ -1021,7 +1068,7 @@
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
+ builder.GetSize(), unilib_.get());
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
@@ -1036,7 +1083,7 @@
#if defined(TC3_UNILIB_ICU) && !defined(TEST_NO_DATETIME)
TEST_F(ActionsSuggestionsTest, CreatesActionsWithAnnotationsFromGrammarRules) {
std::unique_ptr<Annotator> annotator =
- Annotator::FromPath(GetModelPath() + "en.fb", &unilib_);
+ Annotator::FromPath(GetModelPath() + "en.fb", unilib_.get());
const std::string actions_model_string =
ReadFile(GetModelPath() + kModelFileName);
std::unique_ptr<ActionsModelT> actions_model =
@@ -1056,11 +1103,11 @@
// Setup test rules.
action_grammar_rules->rules.reset(new grammar::RulesSetT);
grammar::Rules rules;
- rules.Add("<event>", {"it", "is", "at", "<time>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
- /*callback_param=*/0);
+ rules.Add(
+ "<event>", {"it", "is", "at", "<time>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/0);
rules.BindAnnotation("<time>", "time");
rules.AddAnnotation("datetime");
rules.Finalize().Serialize(/*include_debug_information=*/false,
@@ -1082,7 +1129,7 @@
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
+ builder.GetSize(), unilib_.get());
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
@@ -1097,7 +1144,8 @@
#endif
TEST_F(ActionsSuggestionsTest, DeduplicateActions) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
@@ -1139,7 +1187,7 @@
ActionsModel::Pack(builder, actions_model.get()));
actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
+ builder.GetSize(), unilib_.get());
response = actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
@@ -1149,7 +1197,8 @@
}
TEST_F(ActionsSuggestionsTest, DeduplicateConflictingActions) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
AnnotatedSpan annotation;
annotation.span = {7, 11};
annotation.classification = {
@@ -1195,7 +1244,7 @@
ActionsModel::Pack(builder, actions_model.get()));
actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
+ builder.GetSize(), unilib_.get());
response = actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "I'm on LX38",
@@ -1209,7 +1258,8 @@
#endif
TEST_F(ActionsSuggestionsTest, RanksActions) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
std::vector<AnnotatedSpan> annotations(2);
annotations[0].span = {11, 15};
annotations[0].classification = {ClassificationResult("address", 1.0)};
@@ -1308,15 +1358,15 @@
}
};
- const UniLib unilib_;
+ std::unique_ptr<UniLib> unilib_;
};
TestingMessageEmbedder::TestingMessageEmbedder(const ActionsModel* model)
- : INIT_UNILIB_FOR_TESTING(unilib_) {
+ : unilib_(CreateUniLibForTesting()) {
model_ = model;
const ActionsTokenFeatureProcessorOptions* options =
model->feature_processor_options();
- feature_processor_.reset(new ActionsFeatureProcessor(options, &unilib_));
+ feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_.get()));
embedding_executor_.reset(new FakeEmbeddingExecutor());
EXPECT_TRUE(
EmbedTokenId(options->padding_token_id(), &embedded_padding_token_));
diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs
index 7d626a8..1548816 100755
--- a/native/actions/actions_model.fbs
+++ b/native/actions/actions_model.fbs
@@ -14,16 +14,16 @@
// limitations under the License.
//
-include "annotator/model.fbs";
-include "utils/normalization.fbs";
include "actions/actions-entity-data.fbs";
-include "utils/codepoint-range.fbs";
-include "utils/resources.fbs";
-include "utils/zlib/buffer.fbs";
include "utils/grammar/rules.fbs";
include "utils/tokenizer.fbs";
-include "utils/intents/intent-config.fbs";
include "utils/flatbuffers/flatbuffers.fbs";
+include "utils/codepoint-range.fbs";
+include "utils/zlib/buffer.fbs";
+include "utils/normalization.fbs";
+include "annotator/model.fbs";
+include "utils/intents/intent-config.fbs";
+include "utils/resources.fbs";
file_identifier "TC3A";
diff --git a/native/actions/grammar-actions.cc b/native/actions/grammar-actions.cc
index 597ee59..e925086 100644
--- a/native/actions/grammar-actions.cc
+++ b/native/actions/grammar-actions.cc
@@ -16,187 +16,14 @@
#include "actions/grammar-actions.h"
-#include <algorithm>
-#include <unordered_map>
-
#include "actions/feature-processor.h"
#include "actions/utils.h"
#include "annotator/types.h"
-#include "utils/grammar/callback-delegate.h"
-#include "utils/grammar/match.h"
-#include "utils/grammar/matcher.h"
-#include "utils/grammar/rules-utils.h"
-#include "utils/i18n/language-tag_generated.h"
+#include "utils/base/arena.h"
+#include "utils/base/statusor.h"
#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
-namespace {
-
-class GrammarActionsCallbackDelegate : public grammar::CallbackDelegate {
- public:
- GrammarActionsCallbackDelegate(const UniLib* unilib,
- const RulesModel_::GrammarRules* grammar_rules)
- : unilib_(*unilib), grammar_rules_(grammar_rules) {}
-
- // Handle a grammar rule match in the actions grammar.
- void MatchFound(const grammar::Match* match, grammar::CallbackId type,
- int64 value, grammar::Matcher* matcher) override {
- switch (static_cast<GrammarActions::Callback>(type)) {
- case GrammarActions::Callback::kActionRuleMatch: {
- HandleRuleMatch(match, /*rule_id=*/value);
- return;
- }
- default:
- grammar::CallbackDelegate::MatchFound(match, type, value, matcher);
- }
- }
-
- // Deduplicate, verify and populate actions from grammar matches.
- bool GetActions(const Conversation& conversation,
- const std::string& smart_reply_action_type,
- const MutableFlatbufferBuilder* entity_data_builder,
- std::vector<ActionSuggestion>* action_suggestions) const {
- std::vector<UnicodeText::const_iterator> codepoint_offsets;
- const UnicodeText message_unicode =
- UTF8ToUnicodeText(conversation.messages.back().text,
- /*do_copy=*/false);
- for (auto it = message_unicode.begin(); it != message_unicode.end(); it++) {
- codepoint_offsets.push_back(it);
- }
- codepoint_offsets.push_back(message_unicode.end());
- for (const grammar::Derivation& candidate :
- grammar::DeduplicateDerivations(candidates_)) {
- // Check that assertions are fulfilled.
- if (!VerifyAssertions(candidate.match)) {
- continue;
- }
- if (!InstantiateActionsFromMatch(
- codepoint_offsets,
- /*message_index=*/conversation.messages.size() - 1,
- smart_reply_action_type, candidate, entity_data_builder,
- action_suggestions)) {
- return false;
- }
- }
- return true;
- }
-
- private:
- // Handles action rule matches.
- void HandleRuleMatch(const grammar::Match* match, const int64 rule_id) {
- candidates_.push_back(grammar::Derivation{match, rule_id});
- }
-
- // Instantiates action suggestions from verified and deduplicated rule matches
- // and appends them to the result.
- // Expects the message as codepoints for text extraction from capturing
- // matches as well as the index of the message, for correct span production.
- bool InstantiateActionsFromMatch(
- const std::vector<UnicodeText::const_iterator>& message_codepoint_offsets,
- int message_index, const std::string& smart_reply_action_type,
- const grammar::Derivation& candidate,
- const MutableFlatbufferBuilder* entity_data_builder,
- std::vector<ActionSuggestion>* result) const {
- const RulesModel_::GrammarRules_::RuleMatch* rule_match =
- grammar_rules_->rule_match()->Get(candidate.rule_id);
- if (rule_match == nullptr || rule_match->action_id() == nullptr) {
- TC3_LOG(ERROR) << "No rule action defined.";
- return false;
- }
-
- // Gather active capturing matches.
- std::unordered_map<uint16, const grammar::Match*> capturing_matches;
- for (const grammar::MappingMatch* match :
- grammar::SelectAllOfType<grammar::MappingMatch>(
- candidate.match, grammar::Match::kMappingMatch)) {
- capturing_matches[match->id] = match;
- }
-
- // Instantiate actions from the rule match.
- for (const uint16 action_id : *rule_match->action_id()) {
- const RulesModel_::RuleActionSpec* action_spec =
- grammar_rules_->actions()->Get(action_id);
- std::vector<ActionSuggestionAnnotation> annotations;
-
- std::unique_ptr<MutableFlatbuffer> entity_data =
- entity_data_builder != nullptr ? entity_data_builder->NewRoot()
- : nullptr;
-
- // Set information from capturing matches.
- if (action_spec->capturing_group() != nullptr) {
- for (const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group :
- *action_spec->capturing_group()) {
- auto it = capturing_matches.find(group->group_id());
- if (it == capturing_matches.end()) {
- // Capturing match is not active, skip.
- continue;
- }
-
- const grammar::Match* capturing_match = it->second;
- StringPiece match_text = StringPiece(
- message_codepoint_offsets[capturing_match->codepoint_span.first]
- .utf8_data(),
- message_codepoint_offsets[capturing_match->codepoint_span.second]
- .utf8_data() -
- message_codepoint_offsets[capturing_match->codepoint_span
- .first]
- .utf8_data());
- UnicodeText normalized_match_text =
- NormalizeMatchText(unilib_, group, match_text);
-
- if (!MergeEntityDataFromCapturingMatch(
- group, normalized_match_text.ToUTF8String(),
- entity_data.get())) {
- TC3_LOG(ERROR)
- << "Could not merge entity data from a capturing match.";
- return false;
- }
-
- // Add smart reply suggestions.
- SuggestTextRepliesFromCapturingMatch(entity_data_builder, group,
- normalized_match_text,
- smart_reply_action_type, result);
-
- // Add annotation.
- ActionSuggestionAnnotation annotation;
- if (FillAnnotationFromCapturingMatch(
- /*span=*/capturing_match->codepoint_span, group,
- /*message_index=*/message_index, match_text, &annotation)) {
- if (group->use_annotation_match()) {
- const grammar::AnnotationMatch* annotation_match =
- grammar::SelectFirstOfType<grammar::AnnotationMatch>(
- capturing_match, grammar::Match::kAnnotationMatch);
- if (!annotation_match) {
- TC3_LOG(ERROR) << "Could not get annotation for match.";
- return false;
- }
- annotation.entity = *annotation_match->annotation;
- }
- annotations.push_back(std::move(annotation));
- }
- }
- }
-
- if (action_spec->action() != nullptr) {
- ActionSuggestion suggestion;
- suggestion.annotations = annotations;
- FillSuggestionFromSpec(action_spec->action(), entity_data.get(),
- &suggestion);
- result->push_back(std::move(suggestion));
- }
- }
- return true;
- }
-
- const UniLib& unilib_;
- const RulesModel_::GrammarRules* grammar_rules_;
-
- // All action rule match candidates.
- // Grammar rule matches are recorded, deduplicated, verified and then
- // instantiated.
- std::vector<grammar::Derivation> candidates_;
-};
-} // namespace
GrammarActions::GrammarActions(
const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
@@ -205,15 +32,104 @@
: unilib_(*unilib),
grammar_rules_(grammar_rules),
tokenizer_(CreateTokenizer(grammar_rules->tokenizer_options(), unilib)),
- lexer_(unilib, grammar_rules->rules()),
entity_data_builder_(entity_data_builder),
- smart_reply_action_type_(smart_reply_action_type),
- rules_locales_(ParseRulesLocales(grammar_rules->rules())) {}
+ analyzer_(unilib, grammar_rules->rules(), tokenizer_.get()),
+ smart_reply_action_type_(smart_reply_action_type) {}
+bool GrammarActions::InstantiateActionsFromMatch(
+ const grammar::TextContext& text_context, const int message_index,
+ const grammar::Derivation& derivation,
+ std::vector<ActionSuggestion>* result) const {
+ const RulesModel_::GrammarRules_::RuleMatch* rule_match =
+ grammar_rules_->rule_match()->Get(derivation.rule_id);
+ if (rule_match == nullptr || rule_match->action_id() == nullptr) {
+ TC3_LOG(ERROR) << "No rule action defined.";
+ return false;
+ }
+
+ // Gather active capturing matches.
+ std::unordered_map<uint16, const grammar::ParseTree*> capturing_matches;
+ for (const grammar::MappingNode* mapping_node :
+ grammar::SelectAllOfType<grammar::MappingNode>(
+ derivation.parse_tree, grammar::ParseTree::Type::kMapping)) {
+ capturing_matches[mapping_node->id] = mapping_node;
+ }
+
+ // Instantiate actions from the rule match.
+ for (const uint16 action_id : *rule_match->action_id()) {
+ const RulesModel_::RuleActionSpec* action_spec =
+ grammar_rules_->actions()->Get(action_id);
+ std::vector<ActionSuggestionAnnotation> annotations;
+
+ std::unique_ptr<MutableFlatbuffer> entity_data =
+ entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
+ : nullptr;
+
+ // Set information from capturing matches.
+ if (action_spec->capturing_group() != nullptr) {
+ for (const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group :
+ *action_spec->capturing_group()) {
+ auto it = capturing_matches.find(group->group_id());
+ if (it == capturing_matches.end()) {
+ // Capturing match is not active, skip.
+ continue;
+ }
+
+ const grammar::ParseTree* capturing_match = it->second;
+ const UnicodeText match_text =
+ text_context.Span(capturing_match->codepoint_span);
+ UnicodeText normalized_match_text =
+ NormalizeMatchText(unilib_, group, match_text);
+
+ if (!MergeEntityDataFromCapturingMatch(
+ group, normalized_match_text.ToUTF8String(),
+ entity_data.get())) {
+ TC3_LOG(ERROR)
+ << "Could not merge entity data from a capturing match.";
+ return false;
+ }
+
+ // Add smart reply suggestions.
+ SuggestTextRepliesFromCapturingMatch(entity_data_builder_, group,
+ normalized_match_text,
+ smart_reply_action_type_, result);
+
+ // Add annotation.
+ ActionSuggestionAnnotation annotation;
+ if (FillAnnotationFromCapturingMatch(
+ /*span=*/capturing_match->codepoint_span, group,
+ /*message_index=*/message_index, match_text.ToUTF8String(),
+ &annotation)) {
+ if (group->use_annotation_match()) {
+ std::vector<const grammar::AnnotationNode*> annotations =
+ grammar::SelectAllOfType<grammar::AnnotationNode>(
+ capturing_match, grammar::ParseTree::Type::kAnnotation);
+ if (annotations.size() != 1) {
+ TC3_LOG(ERROR) << "Could not get annotation for match.";
+ return false;
+ }
+ annotation.entity = *annotations.front()->annotation;
+ }
+ annotations.push_back(std::move(annotation));
+ }
+ }
+ }
+
+ if (action_spec->action() != nullptr) {
+ ActionSuggestion suggestion;
+ suggestion.annotations = annotations;
+ FillSuggestionFromSpec(action_spec->action(), entity_data.get(),
+ &suggestion);
+ result->push_back(std::move(suggestion));
+ }
+ }
+ return true;
+}
bool GrammarActions::SuggestActions(
const Conversation& conversation,
std::vector<ActionSuggestion>* result) const {
- if (grammar_rules_->rules()->rules() == nullptr) {
+ if (grammar_rules_->rules()->rules() == nullptr ||
+ conversation.messages.back().text.empty()) {
// Nothing to do.
return true;
}
@@ -225,30 +141,32 @@
return false;
}
- // Select locale matching rules.
- std::vector<const grammar::RulesSet_::Rules*> locale_rules =
- SelectLocaleMatchingShards(grammar_rules_->rules(), rules_locales_,
- locales);
- if (locale_rules.empty()) {
- // Nothing to do.
- return true;
+ const int message_index = conversation.messages.size() - 1;
+ grammar::TextContext text = analyzer_.BuildTextContextForInput(
+ UTF8ToUnicodeText(conversation.messages.back().text, /*do_copy=*/false),
+ locales);
+ text.annotations = conversation.messages.back().annotations;
+
+ UnsafeArena arena(/*block_size=*/16 << 10);
+ StatusOr<std::vector<grammar::EvaluatedDerivation>> evaluated_derivations =
+ analyzer_.Parse(text, &arena);
+ // TODO(b/171294882): Return the status here and below.
+ if (!evaluated_derivations.ok()) {
+ TC3_LOG(ERROR) << "Could not run grammar analyzer: "
+ << evaluated_derivations.status().error_message();
+ return false;
}
- GrammarActionsCallbackDelegate callback_handler(&unilib_, grammar_rules_);
- grammar::Matcher matcher(&unilib_, grammar_rules_->rules(), locale_rules,
- &callback_handler);
+ for (const grammar::EvaluatedDerivation& evaluated_derivation :
+ evaluated_derivations.ValueOrDie()) {
+ if (!InstantiateActionsFromMatch(text, message_index,
+ evaluated_derivation.derivation, result)) {
+ TC3_LOG(ERROR) << "Could not instantiate actions from a grammar match.";
+ return false;
+ }
+ }
- const UnicodeText text =
- UTF8ToUnicodeText(conversation.messages.back().text, /*do_copy=*/false);
-
- // Run grammar on last message.
- lexer_.Process(text, tokenizer_->Tokenize(text),
- /*annotations=*/&conversation.messages.back().annotations,
- &matcher);
-
- // Populate results.
- return callback_handler.GetActions(conversation, smart_reply_action_type_,
- entity_data_builder_, result);
+ return true;
}
} // namespace libtextclassifier3
diff --git a/native/actions/grammar-actions.h b/native/actions/grammar-actions.h
index ea8c2b4..2a1725f 100644
--- a/native/actions/grammar-actions.h
+++ b/native/actions/grammar-actions.h
@@ -23,10 +23,10 @@
#include "actions/actions_model_generated.h"
#include "actions/types.h"
#include "utils/flatbuffers/mutable.h"
-#include "utils/grammar/lexer.h"
-#include "utils/grammar/types.h"
+#include "utils/grammar/analyzer.h"
+#include "utils/grammar/evaluated-derivation.h"
+#include "utils/grammar/text-context.h"
#include "utils/i18n/locale.h"
-#include "utils/strings/stringpiece.h"
#include "utils/tokenizer.h"
#include "utils/utf8/unilib.h"
@@ -35,8 +35,6 @@
// Grammar backed actions suggestions.
class GrammarActions {
public:
- enum class Callback : grammar::CallbackId { kActionRuleMatch = 1 };
-
explicit GrammarActions(const UniLib* unilib,
const RulesModel_::GrammarRules* grammar_rules,
const MutableFlatbufferBuilder* entity_data_builder,
@@ -47,15 +45,18 @@
std::vector<ActionSuggestion>* result) const;
private:
+ // Creates action suggestions from a grammar match result.
+ bool InstantiateActionsFromMatch(const grammar::TextContext& text_context,
+ int message_index,
+ const grammar::Derivation& derivation,
+ std::vector<ActionSuggestion>* result) const;
+
const UniLib& unilib_;
const RulesModel_::GrammarRules* grammar_rules_;
const std::unique_ptr<Tokenizer> tokenizer_;
- const grammar::Lexer lexer_;
const MutableFlatbufferBuilder* entity_data_builder_;
+ const grammar::Analyzer analyzer_;
const std::string smart_reply_action_type_;
-
- // Pre-parsed locales of the rules.
- const std::vector<std::vector<Locale>> rules_locales_;
};
} // namespace libtextclassifier3
diff --git a/native/actions/grammar-actions_test.cc b/native/actions/grammar-actions_test.cc
index e738dee..9fe73d4 100644
--- a/native/actions/grammar-actions_test.cc
+++ b/native/actions/grammar-actions_test.cc
@@ -25,6 +25,7 @@
#include "utils/flatbuffers/flatbuffers.h"
#include "utils/flatbuffers/mutable.h"
#include "utils/grammar/rules_generated.h"
+#include "utils/grammar/types.h"
#include "utils/grammar/utils/rules.h"
#include "utils/jvm-test-utils.h"
#include "gmock/gmock.h"
@@ -69,14 +70,6 @@
false;
}
- flatbuffers::DetachedBuffer PackRules(
- const RulesModel_::GrammarRulesT& action_grammar_rules) const {
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(
- RulesModel_::GrammarRules::Pack(builder, &action_grammar_rules));
- return builder.Release();
- }
-
int AddActionSpec(const std::string& type, const std::string& response_text,
const std::vector<AnnotationSpec>& annotations,
RulesModel_::GrammarRulesT* action_grammar_rules) const {
@@ -156,19 +149,16 @@
rules.Add(
"<knock>", {"<^>", "knock", "knock", ".?", "<$>"},
/*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
/*callback_param=*/
AddRuleMatch({AddSmartReplySpec("Who's there?", &action_grammar_rules),
AddSmartReplySpec("Yes?", &action_grammar_rules)},
&action_grammar_rules));
rules.Finalize().Serialize(/*include_debug_information=*/false,
action_grammar_rules.rules.get());
- flatbuffers::DetachedBuffer serialized_rules =
- PackRules(action_grammar_rules);
- TestGrammarActions grammar_actions(
- unilib_.get(),
- flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()));
+ OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
+ PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
+ TestGrammarActions grammar_actions(unilib_.get(), model.get());
std::vector<ActionSuggestion> result;
EXPECT_TRUE(grammar_actions.SuggestActions(
@@ -186,15 +176,15 @@
action_grammar_rules.rules.reset(new grammar::RulesSetT);
grammar::Rules rules;
- rules.Add("<scripted_reply>",
- {"<^>", "text", "<captured_reply>", "to", "<command>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
- /*callback_param=*/
- AddRuleMatch({AddCapturingMatchSmartReplySpec(
- /*match_id=*/0, &action_grammar_rules)},
- &action_grammar_rules));
+ rules.Add(
+ "<scripted_reply>",
+ {"<^>", "text", "<captured_reply>", "to", "<command>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleMatch({AddCapturingMatchSmartReplySpec(
+ /*match_id=*/0, &action_grammar_rules)},
+ &action_grammar_rules));
// <command> ::= unsubscribe | cancel | confirm | receive
rules.Add("<command>", {"unsubscribe"});
@@ -212,11 +202,9 @@
rules.Finalize().Serialize(/*include_debug_information=*/false,
action_grammar_rules.rules.get());
- flatbuffers::DetachedBuffer serialized_rules =
- PackRules(action_grammar_rules);
- TestGrammarActions grammar_actions(
- unilib_.get(),
- flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()));
+ OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
+ PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
+ TestGrammarActions grammar_actions(unilib_.get(), model.get());
{
std::vector<ActionSuggestion> result;
@@ -248,8 +236,7 @@
rules.Add(
"<call_phone>", {"please", "dial", "<phone>"},
/*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
/*callback_param=*/
AddRuleMatch({AddActionSpec("call_phone", /*response_text=*/"",
/*annotations=*/{{0 /*value*/, "phone"}},
@@ -262,11 +249,9 @@
/*value=*/0);
rules.Finalize().Serialize(/*include_debug_information=*/false,
action_grammar_rules.rules.get());
- flatbuffers::DetachedBuffer serialized_rules =
- PackRules(action_grammar_rules);
- TestGrammarActions grammar_actions(
- unilib_.get(),
- flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()));
+ OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
+ PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
+ TestGrammarActions grammar_actions(unilib_.get(), model.get());
std::vector<ActionSuggestion> result;
EXPECT_TRUE(grammar_actions.SuggestActions(
@@ -289,16 +274,14 @@
rules.Add(
"<knock>", {"<^>", "knock", "knock", ".?", "<$>"},
/*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
/*callback_param=*/
AddRuleMatch({AddSmartReplySpec("Who's there?", &action_grammar_rules)},
&action_grammar_rules));
rules.Add(
"<toc>", {"<knock>"},
/*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
/*callback_param=*/
AddRuleMatch({AddSmartReplySpec("Qui est là?", &action_grammar_rules)},
&action_grammar_rules),
@@ -312,11 +295,9 @@
new LanguageTagT);
action_grammar_rules.rules->rules.back()->locale.back()->language = "fr";
- flatbuffers::DetachedBuffer serialized_rules =
- PackRules(action_grammar_rules);
- TestGrammarActions grammar_actions(
- unilib_.get(),
- flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()));
+ OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
+ PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
+ TestGrammarActions grammar_actions(unilib_.get(), model.get());
// Check default.
{
@@ -367,8 +348,7 @@
rules.Add(
"<track_flight>", {"<flight>", "<context_assertion>?"},
/*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
/*callback_param=*/
AddRuleMatch({AddActionSpec("track_flight", /*response_text=*/"",
/*annotations=*/{{0 /*value*/, "flight"}},
@@ -382,11 +362,9 @@
rules.Finalize().Serialize(/*include_debug_information=*/false,
action_grammar_rules.rules.get());
- flatbuffers::DetachedBuffer serialized_rules =
- PackRules(action_grammar_rules);
- TestGrammarActions grammar_actions(
- unilib_.get(),
- flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()));
+ OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
+ PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
+ TestGrammarActions grammar_actions(unilib_.get(), model.get());
std::vector<ActionSuggestion> result;
EXPECT_TRUE(grammar_actions.SuggestActions(
@@ -424,20 +402,18 @@
action_grammar_rules.actions[spec_id]->action->entity_data->text =
"I have the high ground.";
- rules.Add("<greeting>", {"<^>", "hello", "there", "<$>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
- /*callback_param=*/
- AddRuleMatch({spec_id}, &action_grammar_rules));
+ rules.Add(
+ "<greeting>", {"<^>", "hello", "there", "<$>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleMatch({spec_id}, &action_grammar_rules));
rules.Finalize().Serialize(/*include_debug_information=*/false,
action_grammar_rules.rules.get());
- flatbuffers::DetachedBuffer serialized_rules =
- PackRules(action_grammar_rules);
- TestGrammarActions grammar_actions(
- unilib_.get(),
- flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()),
- entity_data_builder_.get());
+ OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
+ PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
+ TestGrammarActions grammar_actions(unilib_.get(), model.get(),
+ entity_data_builder_.get());
std::vector<ActionSuggestion> result;
EXPECT_TRUE(grammar_actions.SuggestActions(
@@ -505,20 +481,18 @@
/*value=*/location_match_id);
rules.AddValueMapping("<greeting>", {"hello", "<captured_location>"},
/*value=*/greeting_match_id);
- rules.Add("<test>", {"<^>", "<greeting>", "<$>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
- /*callback_param=*/
- AddRuleMatch({spec_id}, &action_grammar_rules));
+ rules.Add(
+ "<test>", {"<^>", "<greeting>", "<$>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleMatch({spec_id}, &action_grammar_rules));
rules.Finalize().Serialize(/*include_debug_information=*/false,
action_grammar_rules.rules.get());
- flatbuffers::DetachedBuffer serialized_rules =
- PackRules(action_grammar_rules);
- TestGrammarActions grammar_actions(
- unilib_.get(),
- flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()),
- entity_data_builder_.get());
+ OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
+ PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
+ TestGrammarActions grammar_actions(unilib_.get(), model.get(),
+ entity_data_builder_.get());
std::vector<ActionSuggestion> result;
EXPECT_TRUE(grammar_actions.SuggestActions(
@@ -563,20 +537,18 @@
rules.AddValueMapping("<greeting>", {"<^>", "hello", "there", "<$>"},
/*value=*/0);
- rules.Add("<test>", {"<greeting>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
- /*callback_param=*/
- AddRuleMatch({spec_id}, &action_grammar_rules));
+ rules.Add(
+ "<test>", {"<greeting>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleMatch({spec_id}, &action_grammar_rules));
rules.Finalize().Serialize(/*include_debug_information=*/false,
action_grammar_rules.rules.get());
- flatbuffers::DetachedBuffer serialized_rules =
- PackRules(action_grammar_rules);
- TestGrammarActions grammar_actions(
- unilib_.get(),
- flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()),
- entity_data_builder_.get());
+ OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
+ PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
+ TestGrammarActions grammar_actions(unilib_.get(), model.get(),
+ entity_data_builder_.get());
std::vector<ActionSuggestion> result;
EXPECT_TRUE(grammar_actions.SuggestActions(
@@ -601,17 +573,17 @@
SetTokenizerOptions(&action_grammar_rules);
action_grammar_rules.rules.reset(new grammar::RulesSetT);
grammar::Rules rules;
- rules.Add("<call_phone>", {"please", "dial", "<phone>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
- /*callback_param=*/
- AddRuleMatch({AddActionSpec("call_phone", /*response_text=*/"",
- /*annotations=*/
- {{0 /*value*/, "phone",
- /*use_annotation_match=*/true}},
- &action_grammar_rules)},
- &action_grammar_rules));
+ rules.Add(
+ "<call_phone>", {"please", "dial", "<phone>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleMatch({AddActionSpec("call_phone", /*response_text=*/"",
+ /*annotations=*/
+ {{0 /*value*/, "phone",
+ /*use_annotation_match=*/true}},
+ &action_grammar_rules)},
+ &action_grammar_rules));
rules.AddValueMapping("<phone>", {"<phone_annotation>"},
/*value=*/0);
@@ -627,11 +599,9 @@
action_grammar_rules.rules->nonterminals->annotation_nt.back()->value =
ir.GetNonterminalForName("<phone_annotation>");
- flatbuffers::DetachedBuffer serialized_rules =
- PackRules(action_grammar_rules);
- TestGrammarActions grammar_actions(
- unilib_.get(),
- flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()));
+ OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
+ PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
+ TestGrammarActions grammar_actions(unilib_.get(), model.get());
std::vector<ActionSuggestion> result;
@@ -667,25 +637,23 @@
rules.AddWithExclusion("<tokens_but_not_excluded>", {"<token>", "<token>"},
/*excluded_nonterminal=*/"<excluded>");
- rules.Add("<set_reminder>",
- {"do", "not", "forget", "to", "<tokens_but_not_excluded>"},
- /*callback=*/
- static_cast<grammar::CallbackId>(
- GrammarActions::Callback::kActionRuleMatch),
- /*callback_param=*/
- AddRuleMatch({AddActionSpec("set_reminder", /*response_text=*/"",
- /*annotations=*/
- {}, &action_grammar_rules)},
- &action_grammar_rules));
+ rules.Add(
+ "<set_reminder>",
+ {"do", "not", "forget", "to", "<tokens_but_not_excluded>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
+ /*callback_param=*/
+ AddRuleMatch({AddActionSpec("set_reminder", /*response_text=*/"",
+ /*annotations=*/
+ {}, &action_grammar_rules)},
+ &action_grammar_rules));
rules.Finalize().Serialize(/*include_debug_information=*/false,
action_grammar_rules.rules.get());
- flatbuffers::DetachedBuffer serialized_rules =
- PackRules(action_grammar_rules);
- TestGrammarActions grammar_actions(
- unilib_.get(),
- flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()),
- entity_data_builder_.get());
+ OwnedFlatbuffer<RulesModel_::GrammarRules, std::string> model(
+ PackFlatbuffer<RulesModel_::GrammarRules>(&action_grammar_rules));
+ TestGrammarActions grammar_actions(unilib_.get(), model.get(),
+ entity_data_builder_.get());
{
std::vector<ActionSuggestion> result;
diff --git a/native/actions/regex-actions.cc b/native/actions/regex-actions.cc
index 9a2b5a4..9d91c73 100644
--- a/native/actions/regex-actions.cc
+++ b/native/actions/regex-actions.cc
@@ -93,6 +93,9 @@
bool RegexActions::InitializeRulesModel(
const RulesModel* rules, ZlibDecompressor* decompressor,
std::vector<CompiledRule>* compiled_rules) const {
+ if (rules->regex_rule() == nullptr) {
+ return true;
+ }
for (const RulesModel_::RegexRule* rule : *rules->regex_rule()) {
std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
UncompressMakeRegexPattern(
diff --git a/native/actions/test_data/actions_suggestions_grammar_test.model b/native/actions/test_data/actions_suggestions_grammar_test.model
new file mode 100644
index 0000000..ae6dc60
--- /dev/null
+++ b/native/actions/test_data/actions_suggestions_grammar_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model
index 5d265c1..52f932e 100644
--- a/native/actions/test_data/actions_suggestions_test.model
+++ b/native/actions/test_data/actions_suggestions_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
index e6d8758..6145540 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
index 708b0be..de8520a 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
Binary files differ
diff --git a/native/actions/utils.cc b/native/actions/utils.cc
index 53714d6..648f04d 100644
--- a/native/actions/utils.cc
+++ b/native/actions/utils.cc
@@ -77,13 +77,18 @@
const UniLib& unilib,
const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
StringPiece match_text) {
- UnicodeText normalized_match_text =
- UTF8ToUnicodeText(match_text, /*do_copy=*/false);
- if (group->normalization_options() != nullptr) {
- normalized_match_text = NormalizeText(
- unilib, group->normalization_options(), normalized_match_text);
+ return NormalizeMatchText(unilib, group,
+ UTF8ToUnicodeText(match_text, /*do_copy=*/false));
+}
+
+UnicodeText NormalizeMatchText(
+ const UniLib& unilib,
+ const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
+ const UnicodeText match_text) {
+ if (group->normalization_options() == nullptr) {
+ return match_text;
}
- return normalized_match_text;
+ return NormalizeText(unilib, group->normalization_options(), match_text);
}
bool FillAnnotationFromCapturingMatch(
diff --git a/native/actions/utils.h b/native/actions/utils.h
index d8bdec2..4838464 100644
--- a/native/actions/utils.h
+++ b/native/actions/utils.h
@@ -50,6 +50,11 @@
const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
StringPiece match_text);
+UnicodeText NormalizeMatchText(
+ const UniLib& unilib,
+ const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
+ const UnicodeText match_text);
+
// Fills the fields in an annotation from a capturing match.
bool FillAnnotationFromCapturingMatch(
const CodepointSpan& span,
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc
index a2d8281..2635820 100644
--- a/native/annotator/annotator.cc
+++ b/native/annotator/annotator.cc
@@ -20,12 +20,14 @@
#include <cmath>
#include <cstddef>
#include <iterator>
+#include <limits>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>
#include "annotator/collections.h"
+#include "annotator/datetime/regex-parser.h"
#include "annotator/flatbuffer-utils.h"
#include "annotator/knowledge/knowledge-engine-types.h"
#include "annotator/model_generated.h"
@@ -33,7 +35,9 @@
#include "utils/base/logging.h"
#include "utils/base/status.h"
#include "utils/base/statusor.h"
+#include "utils/calendar/calendar.h"
#include "utils/checksum.h"
+#include "utils/i18n/locale-list.h"
#include "utils/i18n/locale.h"
#include "utils/math/softmax.h"
#include "utils/normalization.h"
@@ -105,12 +109,8 @@
}
// Returns whether the provided input is valid:
-// * Valid utf8 text.
// * Sane span indices.
bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan& span) {
- if (!context.is_valid()) {
- return false;
- }
return (span.first >= 0 && span.first < span.second &&
span.second <= context.size_codepoints());
}
@@ -127,37 +127,6 @@
return ints_set;
}
-DateAnnotationOptions ToDateAnnotationOptions(
- const GrammarDatetimeModel_::AnnotationOptions* fb_annotation_options,
- const std::string& reference_timezone, const int64 reference_time_ms_utc) {
- DateAnnotationOptions result_annotation_options;
- result_annotation_options.base_timestamp_millis = reference_time_ms_utc;
- result_annotation_options.reference_timezone = reference_timezone;
- if (fb_annotation_options != nullptr) {
- result_annotation_options.enable_special_day_offset =
- fb_annotation_options->enable_special_day_offset();
- result_annotation_options.merge_adjacent_components =
- fb_annotation_options->merge_adjacent_components();
- result_annotation_options.enable_date_range =
- fb_annotation_options->enable_date_range();
- result_annotation_options.include_preposition =
- fb_annotation_options->include_preposition();
- if (fb_annotation_options->extra_requested_dates() != nullptr) {
- for (const auto& extra_requested_date :
- *fb_annotation_options->extra_requested_dates()) {
- result_annotation_options.extra_requested_dates.push_back(
- extra_requested_date->str());
- }
- }
- if (fb_annotation_options->ignored_spans() != nullptr) {
- for (const auto& ignored_span : *fb_annotation_options->ignored_spans()) {
- result_annotation_options.ignored_spans.push_back(ignored_span->str());
- }
- }
- }
- return result_annotation_options;
-}
-
} // namespace
tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
@@ -446,25 +415,9 @@
return;
}
}
- if (model_->grammar_datetime_model() &&
- model_->grammar_datetime_model()->datetime_rules()) {
- cfg_datetime_parser_.reset(new dates::CfgDatetimeAnnotator(
- unilib_,
- /*tokenizer_options=*/
- model_->grammar_datetime_model()->grammar_tokenizer_options(),
- calendarlib_,
- /*datetime_rules=*/model_->grammar_datetime_model()->datetime_rules(),
- model_->grammar_datetime_model()->target_classification_score(),
- model_->grammar_datetime_model()->priority_score()));
- if (!cfg_datetime_parser_) {
- TC3_LOG(ERROR) << "Could not initialize context free grammar based "
- "datetime parser.";
- return;
- }
- }
if (model_->datetime_model()) {
- datetime_parser_ = DatetimeParser::Instance(
+ datetime_parser_ = RegexDatetimeParser::Instance(
model_->datetime_model(), unilib_, calendarlib_, decompressor.get());
if (!datetime_parser_) {
TC3_LOG(ERROR) << "Could not initialize datetime parser.";
@@ -662,7 +615,11 @@
return true;
}
-void Annotator::SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id) {
+bool Annotator::SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id) {
+ if (lang_id == nullptr) {
+ return false;
+ }
+
lang_id_ = lang_id;
if (lang_id_ != nullptr && model_->translate_annotator_options() &&
model_->translate_annotator_options()->enabled()) {
@@ -671,6 +628,7 @@
} else {
translate_annotator_.reset(nullptr);
}
+ return true;
}
bool Annotator::InitializePersonNameEngineFromUnownedBuffer(const void* buffer,
@@ -854,6 +812,11 @@
CodepointSpan Annotator::SuggestSelection(
const std::string& context, CodepointSpan click_indices,
const SelectionOptions& options) const {
+ if (context.size() > std::numeric_limits<int>::max()) {
+ TC3_LOG(ERROR) << "Rejecting too long input: " << context.size();
+ return {};
+ }
+
CodepointSpan original_click_indices = click_indices;
if (!initialized_) {
TC3_LOG(ERROR) << "Not initialized";
@@ -885,6 +848,11 @@
const UnicodeText context_unicode = UTF8ToUnicodeText(context,
/*do_copy=*/false);
+ if (!unilib_->IsValidUtf8(context_unicode)) {
+ TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
+ return original_click_indices;
+ }
+
if (!IsValidSpanInput(context_unicode, click_indices)) {
TC3_VLOG(1)
<< "Trying to run SuggestSelection with invalid input, indices: "
@@ -987,9 +955,11 @@
candidates.annotated_spans[0].push_back(grammar_suggested_span);
}
- if (pod_ner_annotator_ != nullptr && options.use_pod_ner) {
- candidates.annotated_spans[0].push_back(
- pod_ner_annotator_->SuggestSelection(context_unicode, click_indices));
+ AnnotatedSpan pod_ner_suggested_span;
+ if (pod_ner_annotator_ != nullptr && options.use_pod_ner &&
+ pod_ner_annotator_->SuggestSelection(context_unicode, click_indices,
+ &pod_ner_suggested_span)) {
+ candidates.annotated_spans[0].push_back(pod_ner_suggested_span);
}
if (experimental_annotator_ != nullptr) {
@@ -1009,7 +979,7 @@
std::vector<int> candidate_indices;
if (!ResolveConflicts(candidates.annotated_spans[0], context, tokens,
- detected_text_language_tags, options.annotation_usecase,
+ detected_text_language_tags, options,
&interpreter_manager, &candidate_indices)) {
TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
return original_click_indices;
@@ -1034,7 +1004,8 @@
!filtered_collections_selection_.empty()) {
if (!ModelClassifyText(
context, detected_text_language_tags,
- candidates.annotated_spans[0][i].span, &interpreter_manager,
+ candidates.annotated_spans[0][i].span, options,
+ &interpreter_manager,
/*embedding_cache=*/nullptr,
&candidates.annotated_spans[0][i].classification)) {
return original_click_indices;
@@ -1079,8 +1050,8 @@
const std::vector<AnnotatedSpan>& candidates, const std::string& context,
const std::vector<Token>& cached_tokens,
const std::vector<Locale>& detected_text_language_tags,
- AnnotationUsecase annotation_usecase,
- InterpreterManager* interpreter_manager, std::vector<int>* result) const {
+ const BaseOptions& options, InterpreterManager* interpreter_manager,
+ std::vector<int>* result) const {
result->clear();
result->reserve(candidates.size());
for (int i = 0; i < candidates.size();) {
@@ -1092,8 +1063,8 @@
std::vector<int> candidate_indices;
if (!ResolveConflict(context, cached_tokens, candidates,
detected_text_language_tags, i,
- first_non_overlapping, annotation_usecase,
- interpreter_manager, &candidate_indices)) {
+ first_non_overlapping, options, interpreter_manager,
+ &candidate_indices)) {
return false;
}
result->insert(result->end(), candidate_indices.begin(),
@@ -1159,7 +1130,7 @@
const std::string& context, const std::vector<Token>& cached_tokens,
const std::vector<AnnotatedSpan>& candidates,
const std::vector<Locale>& detected_text_language_tags, int start_index,
- int end_index, AnnotationUsecase annotation_usecase,
+ int end_index, const BaseOptions& options,
InterpreterManager* interpreter_manager,
std::vector<int>* chosen_indices) const {
std::vector<int> conflicting_indices;
@@ -1180,7 +1151,7 @@
// classification to determine its priority:
std::vector<ClassificationResult> classification;
if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
- candidates[i].span, interpreter_manager,
+ candidates[i].span, options, interpreter_manager,
/*embedding_cache=*/nullptr, &classification)) {
return false;
}
@@ -1222,11 +1193,13 @@
}
const bool needs_conflict_resolution =
- annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_SMART ||
- (annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW &&
+ options.annotation_usecase ==
+ AnnotationUsecase_ANNOTATION_USECASE_SMART ||
+ (options.annotation_usecase ==
+ AnnotationUsecase_ANNOTATION_USECASE_RAW &&
do_conflict_resolution_in_raw_mode_);
if (needs_conflict_resolution &&
- DoSourcesConflict(annotation_usecase, source_set_pair.first,
+ DoSourcesConflict(options.annotation_usecase, source_set_pair.first,
candidates[considered_candidate].source) &&
DoesCandidateConflict(considered_candidate, candidates,
source_set_pair.second)) {
@@ -1376,12 +1349,12 @@
bool Annotator::ModelClassifyText(
const std::string& context,
const std::vector<Locale>& detected_text_language_tags,
- const CodepointSpan& selection_indices,
+ const CodepointSpan& selection_indices, const BaseOptions& options,
InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results) const {
return ModelClassifyText(context, {}, detected_text_language_tags,
- selection_indices, interpreter_manager,
+ selection_indices, options, interpreter_manager,
embedding_cache, classification_results);
}
@@ -1451,20 +1424,20 @@
bool Annotator::ModelClassifyText(
const std::string& context, const std::vector<Token>& cached_tokens,
const std::vector<Locale>& detected_text_language_tags,
- const CodepointSpan& selection_indices,
+ const CodepointSpan& selection_indices, const BaseOptions& options,
InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results) const {
std::vector<Token> tokens;
return ModelClassifyText(context, cached_tokens, detected_text_language_tags,
- selection_indices, interpreter_manager,
+ selection_indices, options, interpreter_manager,
embedding_cache, classification_results, &tokens);
}
bool Annotator::ModelClassifyText(
const std::string& context, const std::vector<Token>& cached_tokens,
const std::vector<Locale>& detected_text_language_tags,
- const CodepointSpan& selection_indices,
+ const CodepointSpan& selection_indices, const BaseOptions& options,
InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results,
@@ -1608,14 +1581,14 @@
return true;
}
} else if (top_collection == Collections::Dictionary()) {
- if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
+ if ((options.use_vocab_annotator && vocab_annotator_) ||
+ !Locale::IsAnyLocaleSupported(detected_text_language_tags,
dictionary_locales_,
/*default_value=*/false)) {
*classification_results = {{Collections::Other(), 1.0}};
return true;
}
}
-
*classification_results = {{top_collection, /*arg_score=*/1.0,
/*arg_priority_score=*/scores[best_score_index]}};
@@ -1694,7 +1667,7 @@
const std::string& context, const CodepointSpan& selection_indices,
const ClassificationOptions& options,
std::vector<ClassificationResult>* classification_results) const {
- if (!datetime_parser_ && !cfg_datetime_parser_) {
+ if (!datetime_parser_) {
return true;
}
@@ -1702,35 +1675,20 @@
UTF8ToUnicodeText(context, /*do_copy=*/false)
.UTF8Substring(selection_indices.first, selection_indices.second);
- std::vector<DatetimeParseResultSpan> datetime_spans;
-
- if (cfg_datetime_parser_) {
- if (!(model_->grammar_datetime_model()->enabled_modes() &
- ModeFlag_CLASSIFICATION)) {
- return true;
- }
- std::vector<Locale> parsed_locales;
- ParseLocales(options.locales, &parsed_locales);
- cfg_datetime_parser_->Parse(
- selection_text,
- ToDateAnnotationOptions(
- model_->grammar_datetime_model()->annotation_options(),
- options.reference_timezone, options.reference_time_ms_utc),
- parsed_locales, &datetime_spans);
+ LocaleList locale_list = LocaleList::ParseFrom(options.locales);
+ StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
+ datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
+ options.reference_timezone, locale_list,
+ ModeFlag_CLASSIFICATION,
+ options.annotation_usecase,
+ /*anchor_start_end=*/true);
+ if (!result_status.ok()) {
+ TC3_LOG(ERROR) << "Error during parsing datetime.";
+ return false;
}
- if (datetime_parser_) {
- if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
- options.reference_timezone, options.locales,
- ModeFlag_CLASSIFICATION,
- options.annotation_usecase,
- /*anchor_start_end=*/true, &datetime_spans)) {
- TC3_LOG(ERROR) << "Error during parsing datetime.";
- return false;
- }
- }
-
- for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
+ for (const DatetimeParseResultSpan& datetime_span :
+ result_status.ValueOrDie()) {
// Only consider the result valid if the selection and extracted datetime
// spans exactly match.
if (CodepointSpan(datetime_span.span.first + selection_indices.first,
@@ -1755,6 +1713,10 @@
std::vector<ClassificationResult> Annotator::ClassifyText(
const std::string& context, const CodepointSpan& selection_indices,
const ClassificationOptions& options) const {
+ if (context.size() > std::numeric_limits<int>::max()) {
+ TC3_LOG(ERROR) << "Rejecting too long input: " << context.size();
+ return {};
+ }
if (!initialized_) {
TC3_LOG(ERROR) << "Not initialized";
return {};
@@ -1782,8 +1744,15 @@
return {};
}
- if (!IsValidSpanInput(UTF8ToUnicodeText(context, /*do_copy=*/false),
- selection_indices)) {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+
+ if (!unilib_->IsValidUtf8(context_unicode)) {
+ TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
+ return {};
+ }
+
+ if (!IsValidSpanInput(context_unicode, selection_indices)) {
TC3_VLOG(1) << "Trying to run ClassifyText with invalid input: "
<< selection_indices.first << " " << selection_indices.second;
return {};
@@ -1857,9 +1826,6 @@
candidates.back().source = AnnotatedSpan::Source::DATETIME;
}
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
-
// Try the number annotator.
// TODO(b/126579108): Propagate error status.
ClassificationResult number_annotator_result;
@@ -1905,7 +1871,7 @@
}
ClassificationResult vocab_annotator_result;
- if (vocab_annotator_ &&
+ if (vocab_annotator_ && options.use_vocab_annotator &&
vocab_annotator_->ClassifyText(
context_unicode, selection_indices, detected_text_language_tags,
options.trigger_dictionary_on_beginner_words,
@@ -1929,7 +1895,7 @@
std::vector<Token> tokens;
if (!ModelClassifyText(
context, /*cached_tokens=*/{}, detected_text_language_tags,
- selection_indices, &interpreter_manager,
+ selection_indices, options, &interpreter_manager,
/*embedding_cache=*/nullptr, &model_results, &tokens)) {
return {};
}
@@ -1939,7 +1905,7 @@
std::vector<int> candidate_indices;
if (!ResolveConflicts(candidates, context, tokens,
- detected_text_language_tags, options.annotation_usecase,
+ detected_text_language_tags, options,
&interpreter_manager, &candidate_indices)) {
TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
return {};
@@ -1969,8 +1935,8 @@
bool Annotator::ModelAnnotate(
const std::string& context,
const std::vector<Locale>& detected_text_language_tags,
- InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
- std::vector<AnnotatedSpan>* result) const {
+ const BaseOptions& options, InterpreterManager* interpreter_manager,
+ std::vector<Token>* tokens, std::vector<AnnotatedSpan>* result) const {
if (model_->triggering_options() == nullptr ||
!(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
return true;
@@ -2003,24 +1969,26 @@
const std::string line_str =
UnicodeText::UTF8Substring(line.first, line.second);
- *tokens = selection_feature_processor_->Tokenize(line_str);
+ std::vector<Token> line_tokens;
+ line_tokens = selection_feature_processor_->Tokenize(line_str);
+
selection_feature_processor_->RetokenizeAndFindClick(
line_str, {0, std::distance(line.first, line.second)},
selection_feature_processor_->GetOptions()->only_use_line_with_click(),
- tokens,
+ &line_tokens,
/*click_pos=*/nullptr);
- const TokenSpan full_line_span = {0,
- static_cast<TokenIndex>(tokens->size())};
+ const TokenSpan full_line_span = {
+ 0, static_cast<TokenIndex>(line_tokens.size())};
// TODO(zilka): Add support for greater granularity of this check.
if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
- *tokens, full_line_span)) {
+ line_tokens, full_line_span)) {
continue;
}
std::unique_ptr<CachedFeatures> cached_features;
if (!selection_feature_processor_->ExtractFeatures(
- *tokens, full_line_span,
+ line_tokens, full_line_span,
/*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
embedding_executor_.get(),
/*embedding_cache=*/nullptr,
@@ -2032,7 +2000,7 @@
}
std::vector<TokenSpan> local_chunks;
- if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
+ if (!ModelChunk(line_tokens.size(), /*span_of_interest=*/full_line_span,
interpreter_manager->SelectionInterpreter(),
*cached_features, &local_chunks)) {
TC3_LOG(ERROR) << "Could not chunk.";
@@ -2043,7 +2011,7 @@
for (const TokenSpan& chunk : local_chunks) {
CodepointSpan codepoint_span =
selection_feature_processor_->StripBoundaryCodepoints(
- line_str, TokenSpanToCodepointSpan(*tokens, chunk));
+ line_str, TokenSpanToCodepointSpan(line_tokens, chunk));
if (model_->selection_options()->strip_unpaired_brackets()) {
codepoint_span =
StripUnpairedBrackets(context_unicode, codepoint_span, *unilib_);
@@ -2052,9 +2020,10 @@
// Skip empty spans.
if (codepoint_span.first != codepoint_span.second) {
std::vector<ClassificationResult> classification;
- if (!ModelClassifyText(line_str, *tokens, detected_text_language_tags,
- codepoint_span, interpreter_manager,
- &embedding_cache, &classification)) {
+ if (!ModelClassifyText(line_str, line_tokens,
+ detected_text_language_tags, codepoint_span,
+ options, interpreter_manager, &embedding_cache,
+ &classification)) {
TC3_LOG(ERROR) << "Could not classify text: "
<< (codepoint_span.first + offset) << " "
<< (codepoint_span.second + offset);
@@ -2072,6 +2041,16 @@
}
}
}
+
+ // If we are going line-by-line, we need to insert the tokens for each line.
+ // But if not, we can optimize and just std::move the current line vector to
+ // the output.
+ if (selection_feature_processor_->GetOptions()
+ ->only_use_line_with_click()) {
+ tokens->insert(tokens->end(), line_tokens.begin(), line_tokens.end());
+ } else {
+ *tokens = std::move(line_tokens);
+ }
}
return true;
}
@@ -2134,10 +2113,6 @@
const UnicodeText context_unicode =
UTF8ToUnicodeText(context, /*do_copy=*/false);
- if (!context_unicode.is_valid()) {
- return Status(StatusCode::INVALID_ARGUMENT,
- "Context string isn't valid UTF8.");
- }
std::vector<Locale> detected_text_language_tags;
if (!ParseLocales(options.detected_text_language_tags,
@@ -2157,16 +2132,34 @@
InterpreterManager interpreter_manager(selection_executor_.get(),
classification_executor_.get());
+ const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
+ const bool is_raw_usecase =
+ options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW;
+
// Annotate with the selection model.
+ const bool model_annotations_enabled =
+ !is_raw_usecase || IsAnyModelEntityTypeEnabled(is_entity_type_enabled);
std::vector<Token> tokens;
- if (!ModelAnnotate(context, detected_text_language_tags, &interpreter_manager,
- &tokens, candidates)) {
+ if (model_annotations_enabled &&
+ !ModelAnnotate(context, detected_text_language_tags, options,
+ &interpreter_manager, &tokens, candidates)) {
return Status(StatusCode::INTERNAL, "Couldn't run ModelAnnotate.");
+ } else if (!model_annotations_enabled) {
+ // If the ML model didn't run, we need to tokenize to support the other
+ // annotators that depend on the tokens.
+ // Optimization could be made to only do this when an annotator that uses
+ // the tokens is enabled, but it's unclear if the added complexity is worth
+ // it.
+ if (selection_feature_processor_ != nullptr) {
+ tokens = selection_feature_processor_->Tokenize(context_unicode);
+ }
}
- const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
// Annotate with the regular expression models.
- if (!RegexChunk(
+ const bool regex_annotations_enabled =
+ !is_raw_usecase || IsAnyRegexEntityTypeEnabled(is_entity_type_enabled);
+ if (regex_annotations_enabled &&
+ !RegexChunk(
UTF8ToUnicodeText(context, /*do_copy=*/false),
annotation_regex_patterns_, options.is_serialized_entity_data_enabled,
is_entity_type_enabled, options.annotation_usecase, candidates)) {
@@ -2174,6 +2167,8 @@
}
// Annotate with the datetime model.
+ // NOTE: Datetime can be disabled even in the SMART usecase, because it's been
+ // relatively slow for some clients.
if ((is_entity_type_enabled(Collections::Date()) ||
is_entity_type_enabled(Collections::DateTime())) &&
!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
@@ -2185,27 +2180,26 @@
}
// Annotate with the contact engine.
- if (contact_engine_ &&
+ const bool contact_annotations_enabled =
+ !is_raw_usecase || is_entity_type_enabled(Collections::Contact());
+ if (contact_annotations_enabled && contact_engine_ &&
!contact_engine_->Chunk(context_unicode, tokens, candidates)) {
return Status(StatusCode::INTERNAL, "Couldn't run contact engine Chunk.");
}
// Annotate with the installed app engine.
- if (installed_app_engine_ &&
+ const bool app_annotations_enabled =
+ !is_raw_usecase || is_entity_type_enabled(Collections::App());
+ if (app_annotations_enabled && installed_app_engine_ &&
!installed_app_engine_->Chunk(context_unicode, tokens, candidates)) {
return Status(StatusCode::INTERNAL,
"Couldn't run installed app engine Chunk.");
}
// Annotate with the number annotator.
- bool number_annotations_enabled = true;
- // Disable running the annotator in RAW mode if the number/percentage
- // annotations are not explicitly requested.
- if (options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW &&
- !is_entity_type_enabled(Collections::Number()) &&
- !is_entity_type_enabled(Collections::Percentage())) {
- number_annotations_enabled = false;
- }
+ const bool number_annotations_enabled =
+ !is_raw_usecase || (is_entity_type_enabled(Collections::Number()) ||
+ is_entity_type_enabled(Collections::Percentage()));
if (number_annotations_enabled && number_annotator_ != nullptr &&
!number_annotator_->FindAll(context_unicode, options.annotation_usecase,
candidates)) {
@@ -2214,8 +2208,9 @@
}
// Annotate with the duration annotator.
- if (is_entity_type_enabled(Collections::Duration()) &&
- duration_annotator_ != nullptr &&
+ const bool duration_annotations_enabled =
+ !is_raw_usecase || is_entity_type_enabled(Collections::Duration());
+ if (duration_annotations_enabled && duration_annotator_ != nullptr &&
!duration_annotator_->FindAll(context_unicode, tokens,
options.annotation_usecase, candidates)) {
return Status(StatusCode::INTERNAL,
@@ -2223,8 +2218,9 @@
}
// Annotate with the person name engine.
- if (is_entity_type_enabled(Collections::PersonName()) &&
- person_name_engine_ &&
+ const bool person_annotations_enabled =
+ !is_raw_usecase || is_entity_type_enabled(Collections::PersonName());
+ if (person_annotations_enabled && person_name_engine_ &&
!person_name_engine_->Chunk(context_unicode, tokens, candidates)) {
return Status(StatusCode::INTERNAL,
"Couldn't run person name engine Chunk.");
@@ -2238,13 +2234,19 @@
}
// Annotate with the POD NER annotator.
- if (pod_ner_annotator_ != nullptr && options.use_pod_ner &&
+ const bool pod_ner_annotations_enabled =
+ !is_raw_usecase || IsAnyPodNerEntityTypeEnabled(is_entity_type_enabled);
+ if (pod_ner_annotations_enabled && pod_ner_annotator_ != nullptr &&
+ options.use_pod_ner &&
!pod_ner_annotator_->Annotate(context_unicode, candidates)) {
return Status(StatusCode::INTERNAL, "Couldn't run POD NER annotator.");
}
// Annotate with the vocab annotator.
- if (vocab_annotator_ != nullptr &&
+ const bool vocab_annotations_enabled =
+ !is_raw_usecase || is_entity_type_enabled(Collections::Dictionary());
+ if (vocab_annotations_enabled && vocab_annotator_ != nullptr &&
+ options.use_vocab_annotator &&
!vocab_annotator_->Annotate(context_unicode, detected_text_language_tags,
options.trigger_dictionary_on_beginner_words,
candidates)) {
@@ -2279,7 +2281,7 @@
std::vector<int> candidate_indices;
if (!ResolveConflicts(*candidates, context, tokens,
- detected_text_language_tags, options.annotation_usecase,
+ detected_text_language_tags, options,
&interpreter_manager, &candidate_indices)) {
return Status(StatusCode::INTERNAL, "Couldn't resolve conflicts.");
}
@@ -2330,15 +2332,21 @@
std::vector<std::string> text_to_annotate;
text_to_annotate.reserve(string_fragments.size());
+ std::vector<FragmentMetadata> fragment_metadata;
+ fragment_metadata.reserve(string_fragments.size());
for (const auto& string_fragment : string_fragments) {
text_to_annotate.push_back(string_fragment.text);
+ fragment_metadata.push_back(
+ {.relative_bounding_box_top = string_fragment.bounding_box_top,
+ .relative_bounding_box_height = string_fragment.bounding_box_height});
}
// KnowledgeEngine is special, because it supports annotation of multiple
// fragments at once.
if (knowledge_engine_ &&
!knowledge_engine_
- ->ChunkMultipleSpans(text_to_annotate, options.annotation_usecase,
+ ->ChunkMultipleSpans(text_to_annotate, fragment_metadata,
+ options.annotation_usecase,
options.location_context, options.permissions,
options.annotate_mode, &annotation_candidates)
.ok()) {
@@ -2391,6 +2399,18 @@
std::vector<AnnotatedSpan> Annotator::Annotate(
const std::string& context, const AnnotationOptions& options) const {
+ if (context.size() > std::numeric_limits<int>::max()) {
+ TC3_LOG(ERROR) << "Rejecting too long input.";
+ return {};
+ }
+
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+ if (!unilib_->IsValidUtf8(context_unicode)) {
+ TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
+ return {};
+ }
+
std::vector<InputFragment> string_fragments;
string_fragments.push_back({.text = context});
StatusOr<Annotations> annotations =
@@ -2688,6 +2708,58 @@
return true;
}
+bool Annotator::IsAnyModelEntityTypeEnabled(
+ const EnabledEntityTypes& is_entity_type_enabled) const {
+ if (model_->classification_feature_options() == nullptr ||
+ model_->classification_feature_options()->collections() == nullptr) {
+ return false;
+ }
+ for (int i = 0;
+ i < model_->classification_feature_options()->collections()->size();
+ i++) {
+ if (is_entity_type_enabled(model_->classification_feature_options()
+ ->collections()
+ ->Get(i)
+ ->str())) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool Annotator::IsAnyRegexEntityTypeEnabled(
+ const EnabledEntityTypes& is_entity_type_enabled) const {
+ if (model_->regex_model() == nullptr ||
+ model_->regex_model()->patterns() == nullptr) {
+ return false;
+ }
+ for (int i = 0; i < model_->regex_model()->patterns()->size(); i++) {
+ if (is_entity_type_enabled(model_->regex_model()
+ ->patterns()
+ ->Get(i)
+ ->collection_name()
+ ->str())) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool Annotator::IsAnyPodNerEntityTypeEnabled(
+ const EnabledEntityTypes& is_entity_type_enabled) const {
+ if (pod_ner_annotator_ == nullptr) {
+ return false;
+ }
+
+ for (const std::string& collection :
+ pod_ner_annotator_->GetSupportedCollections()) {
+ if (is_entity_type_enabled(collection)) {
+ return true;
+ }
+ }
+ return false;
+}
+
bool Annotator::RegexChunk(const UnicodeText& context_unicode,
const std::vector<int>& rules,
bool is_serialized_entity_data_enabled,
@@ -3011,31 +3083,21 @@
AnnotationUsecase annotation_usecase,
bool is_serialized_entity_data_enabled,
std::vector<AnnotatedSpan>* result) const {
- std::vector<DatetimeParseResultSpan> datetime_spans;
- if (cfg_datetime_parser_) {
- if (!(model_->grammar_datetime_model()->enabled_modes() & mode)) {
- return true;
- }
- std::vector<Locale> parsed_locales;
- ParseLocales(locales, &parsed_locales);
- cfg_datetime_parser_->Parse(
- context_unicode.ToUTF8String(),
- ToDateAnnotationOptions(
- model_->grammar_datetime_model()->annotation_options(),
- reference_timezone, reference_time_ms_utc),
- parsed_locales, &datetime_spans);
+ if (!datetime_parser_) {
+ return true;
+ }
+ LocaleList locale_list = LocaleList::ParseFrom(locales);
+ StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
+ datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
+ reference_timezone, locale_list, mode,
+ annotation_usecase,
+ /*anchor_start_end=*/false);
+ if (!result_status.ok()) {
+ return false;
}
- if (datetime_parser_) {
- if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
- reference_timezone, locales, mode,
- annotation_usecase,
- /*anchor_start_end=*/false, &datetime_spans)) {
- return false;
- }
- }
-
- for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
+ for (const DatetimeParseResultSpan& datetime_span :
+ result_status.ValueOrDie()) {
AnnotatedSpan annotated_span;
annotated_span.span = datetime_span.span;
for (const DatetimeParseResult& parse_result : datetime_span.data) {
diff --git a/native/annotator/annotator.h b/native/annotator/annotator.h
index f55be4d..5397f56 100644
--- a/native/annotator/annotator.h
+++ b/native/annotator/annotator.h
@@ -30,7 +30,6 @@
#include "annotator/duration/duration.h"
#include "annotator/experimental/experimental.h"
#include "annotator/feature-processor.h"
-#include "annotator/grammar/dates/cfg-datetime-annotator.h"
#include "annotator/grammar/grammar-annotator.h"
#include "annotator/installed_app/installed-app-engine.h"
#include "annotator/knowledge/knowledge-engine.h"
@@ -46,6 +45,7 @@
#include "annotator/zlib-utils.h"
#include "utils/base/status.h"
#include "utils/base/statusor.h"
+#include "utils/calendar/calendar.h"
#include "utils/flatbuffers/flatbuffers.h"
#include "utils/flatbuffers/mutable.h"
#include "utils/i18n/locale.h"
@@ -174,7 +174,7 @@
bool InitializeExperimentalAnnotators();
// Sets up the lang-id instance that should be used.
- void SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id);
+ bool SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id);
// Runs inference for given a context and current selection (i.e. index
// of the first and one past last selected characters (utf8 codepoint
@@ -261,7 +261,7 @@
const std::string& context,
const std::vector<Token>& cached_tokens,
const std::vector<Locale>& detected_text_language_tags,
- AnnotationUsecase annotation_usecase,
+ const BaseOptions& options,
InterpreterManager* interpreter_manager,
std::vector<int>* result) const;
@@ -273,7 +273,7 @@
const std::vector<AnnotatedSpan>& candidates,
const std::vector<Locale>& detected_text_language_tags,
int start_index, int end_index,
- AnnotationUsecase annotation_usecase,
+ const BaseOptions& options,
InterpreterManager* interpreter_manager,
std::vector<int>* chosen_indices) const;
@@ -292,7 +292,7 @@
bool ModelClassifyText(
const std::string& context, const std::vector<Token>& cached_tokens,
const std::vector<Locale>& detected_text_language_tags,
- const CodepointSpan& selection_indices,
+ const CodepointSpan& selection_indices, const BaseOptions& options,
InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results,
@@ -302,7 +302,7 @@
bool ModelClassifyText(
const std::string& context, const std::vector<Token>& cached_tokens,
const std::vector<Locale>& detected_text_language_tags,
- const CodepointSpan& selection_indices,
+ const CodepointSpan& selection_indices, const BaseOptions& options,
InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results) const;
@@ -311,7 +311,7 @@
bool ModelClassifyText(
const std::string& context,
const std::vector<Locale>& detected_text_language_tags,
- const CodepointSpan& selection_indices,
+ const CodepointSpan& selection_indices, const BaseOptions& options,
InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results) const;
@@ -342,6 +342,7 @@
// reuse.
bool ModelAnnotate(const std::string& context,
const std::vector<Locale>& detected_text_language_tags,
+ const BaseOptions& options,
InterpreterManager* interpreter_manager,
std::vector<Token>* tokens,
std::vector<AnnotatedSpan>* result) const;
@@ -440,8 +441,6 @@
std::unique_ptr<const FeatureProcessor> classification_feature_processor_;
std::unique_ptr<const DatetimeParser> datetime_parser_;
- std::unique_ptr<const dates::CfgDatetimeAnnotator> cfg_datetime_parser_;
-
std::unique_ptr<const GrammarAnnotator> grammar_annotator_;
std::string owned_buffer_;
@@ -485,6 +484,18 @@
std::string* quantity,
int* exponent) const;
+ // Returns true if any of the ff-model entity types is enabled.
+ bool IsAnyModelEntityTypeEnabled(
+ const EnabledEntityTypes& is_entity_type_enabled) const;
+
+ // Returns true if any of the regex entity types is enabled.
+ bool IsAnyRegexEntityTypeEnabled(
+ const EnabledEntityTypes& is_entity_type_enabled) const;
+
+ // Returns true if any of the POD NER entity types is enabled.
+ bool IsAnyPodNerEntityTypeEnabled(
+ const EnabledEntityTypes& is_entity_type_enabled) const;
+
std::unique_ptr<ScopedMmap> mmap_;
bool initialized_ = false;
bool enabled_for_annotation_ = false;
diff --git a/native/annotator/annotator_jni.cc b/native/annotator/annotator_jni.cc
index 8d5ad33..7f095f9 100644
--- a/native/annotator/annotator_jni.cc
+++ b/native/annotator/annotator_jni.cc
@@ -205,6 +205,22 @@
env, classification_result.contact_phone_number.c_str()));
}
+ ScopedLocalRef<jstring> contact_account_type;
+ if (!classification_result.contact_account_type.empty()) {
+ TC3_ASSIGN_OR_RETURN(
+ contact_account_type,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_account_type.c_str()));
+ }
+
+ ScopedLocalRef<jstring> contact_account_name;
+ if (!classification_result.contact_account_name.empty()) {
+ TC3_ASSIGN_OR_RETURN(
+ contact_account_name,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_account_name.c_str()));
+ }
+
ScopedLocalRef<jstring> contact_id;
if (!classification_result.contact_id.empty()) {
TC3_ASSIGN_OR_RETURN(
@@ -275,7 +291,8 @@
row_datetime_parse.get(), serialized_knowledge_result.get(),
contact_name.get(), contact_given_name.get(), contact_family_name.get(),
contact_nickname.get(), contact_email_address.get(),
- contact_phone_number.get(), contact_id.get(), app_name.get(),
+ contact_phone_number.get(), contact_account_type.get(),
+ contact_account_name.get(), contact_id.get(), app_name.get(),
app_package_name.get(), extras.get(), serialized_entity_data.get(),
remote_action_templates_result.get(), classification_result.duration_ms,
classification_result.numeric_value,
@@ -304,13 +321,23 @@
JniHelper::GetMethodID(
env, result_class.get(), "<init>",
"(Ljava/lang/String;FL" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$DatetimeResult;[BLjava/lang/String;Ljava/lang/String;Ljava/lang/"
- "String;"
- "Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/"
- "String;"
- "Ljava/lang/String;Ljava/lang/String;[L" TC3_PACKAGE_PATH
- "" TC3_NAMED_VARIANT_CLASS_NAME_STR ";[B[L" TC3_PACKAGE_PATH
- "" TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";JJD)V"));
+ "$DatetimeResult;"
+ "[B"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "[L" TC3_PACKAGE_PATH "" TC3_NAMED_VARIANT_CLASS_NAME_STR ";"
+ "[B"
+ "[L" TC3_PACKAGE_PATH "" TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";"
+ "JJD)V"));
TC3_ASSIGN_OR_RETURN(const jmethodID datetime_parse_class_constructor,
JniHelper::GetMethodID(env, datetime_parse_class.get(),
"<init>", "(JI)V"));
diff --git a/native/annotator/annotator_jni_common.cc b/native/annotator/annotator_jni_common.cc
index a049a22..a6f636f 100644
--- a/native/annotator/annotator_jni_common.cc
+++ b/native/annotator/annotator_jni_common.cc
@@ -126,6 +126,13 @@
TC3_ASSIGN_OR_RETURN(bool use_pod_ner, JniHelper::CallBooleanMethod(
env, joptions, get_use_pod_ner));
+ // .getUseVocabAnnotator()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_use_vocab_annotator,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getUseVocabAnnotator", "()Z"));
+ TC3_ASSIGN_OR_RETURN(
+ bool use_vocab_annotator,
+ JniHelper::CallBooleanMethod(env, joptions, get_use_vocab_annotator));
T options;
TC3_ASSIGN_OR_RETURN(options.locales,
JStringToUtf8String(env, locales.get()));
@@ -140,6 +147,7 @@
options.location_context = {user_location_lat, user_location_lng,
user_location_accuracy_meters};
options.use_pod_ner = use_pod_ner;
+ options.use_vocab_annotator = use_vocab_annotator;
return options;
}
} // namespace
@@ -419,6 +427,24 @@
.reference_timezone = reference_timezone};
}
+ // .getBoundingBoxHeight()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_bounding_box_height,
+ JniHelper::GetMethodID(env, fragment_class.get(),
+ "getBoundingBoxHeight", "()F"));
+ TC3_ASSIGN_OR_RETURN(
+ float bounding_box_height,
+ JniHelper::CallFloatMethod(env, jfragment, get_bounding_box_height));
+
+ fragment.bounding_box_height = bounding_box_height;
+
+ // .getBoundingBoxTop()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_bounding_box_top,
+ JniHelper::GetMethodID(env, fragment_class.get(),
+ "getBoundingBoxTop", "()F"));
+ TC3_ASSIGN_OR_RETURN(
+ float bounding_box_top,
+ JniHelper::CallFloatMethod(env, jfragment, get_bounding_box_top));
+ fragment.bounding_box_top = bounding_box_top;
return fragment;
}
} // namespace libtextclassifier3
diff --git a/native/annotator/annotator_test-include.cc b/native/annotator/annotator_test-include.cc
index a8fda33..3ed91e1 100644
--- a/native/annotator/annotator_test-include.cc
+++ b/native/annotator/annotator_test-include.cc
@@ -22,9 +22,11 @@
#include <type_traits>
#include "annotator/annotator.h"
+#include "annotator/collections.h"
#include "annotator/model_generated.h"
#include "annotator/test-utils.h"
#include "annotator/types-test-util.h"
+#include "annotator/types.h"
#include "utils/grammar/utils/rules.h"
#include "utils/testing/annotator.h"
#include "lang_id/fb_model/lang-id-from-fb.h"
@@ -42,41 +44,8 @@
std::string GetTestModelPath() { return GetModelPath() + "test_model.fb"; }
-std::string GetModelWithGrammarPath() {
- return GetModelPath() + "test_grammar_model.fb";
-}
-
-void FillDatetimeAnnotationOptionsToModel(
- ModelT* unpacked_model, const DateAnnotationOptions& options) {
- unpacked_model->grammar_datetime_model->annotation_options.reset(
- new GrammarDatetimeModel_::AnnotationOptionsT);
- unpacked_model->grammar_datetime_model->annotation_options
- ->enable_date_range = options.enable_date_range;
- unpacked_model->grammar_datetime_model->annotation_options
- ->include_preposition = options.include_preposition;
- unpacked_model->grammar_datetime_model->annotation_options
- ->merge_adjacent_components = options.merge_adjacent_components;
- unpacked_model->grammar_datetime_model->annotation_options
- ->enable_special_day_offset = options.enable_special_day_offset;
- for (const auto& extra_dates_rule_id : options.extra_requested_dates) {
- unpacked_model->grammar_datetime_model->annotation_options
- ->extra_requested_dates.push_back(extra_dates_rule_id);
- }
-}
-
-std::string GetGrammarModel(const DateAnnotationOptions& options) {
- const std::string test_model = ReadFile(GetModelWithGrammarPath());
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
- FillDatetimeAnnotationOptionsToModel(unpacked_model.get(), options);
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
- return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize());
-}
-
-std::string GetGrammarModel() {
- DateAnnotationOptions options;
- return GetGrammarModel(options);
+std::string GetModelWithVocabPath() {
+ return GetModelPath() + "test_vocab_model.fb";
}
void ExpectFirstEntityIsMoney(const std::vector<AnnotatedSpan>& result,
@@ -150,31 +119,57 @@
VerifyClassifyText(classifier.get());
}
-TEST_F(AnnotatorTest, ClassifyTextWithGrammar) {
- const std::string grammar_model = GetGrammarModel();
- std::unique_ptr<Annotator> classifier =
- Annotator::FromUnownedBuffer(grammar_model.c_str(), grammar_model.size(),
- unilib_.get(), calendarlib_.get());
- VerifyClassifyText(std::move(classifier.get()));
-}
-
TEST_F(AnnotatorTest, ClassifyTextLocalesAndDictionary) {
std::unique_ptr<Annotator> classifier = Annotator::FromPath(
GetTestModelPath(), unilib_.get(), calendarlib_.get());
ASSERT_TRUE(classifier);
- EXPECT_EQ("other", FirstResult(classifier->ClassifyText("isotope", {0, 6})));
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText("isotope", {0, 7})));
ClassificationOptions classification_options;
classification_options.detected_text_language_tags = "en";
EXPECT_EQ("dictionary", FirstResult(classifier->ClassifyText(
- "isotope", {0, 6}, classification_options)));
+ "isotope", {0, 7}, classification_options)));
classification_options.detected_text_language_tags = "uz";
EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
- "isotope", {0, 6}, classification_options)));
+ "isotope", {0, 7}, classification_options)));
}
+TEST_F(AnnotatorTest, ClassifyTextUseVocabAnnotatorWithoutVocabModel) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ ClassificationOptions classification_options;
+ classification_options.detected_text_language_tags = "en";
+ classification_options.use_vocab_annotator = true;
+
+ EXPECT_EQ("dictionary", FirstResult(classifier->ClassifyText(
+ "isotope", {0, 7}, classification_options)));
+}
+
+#ifdef TC3_VOCAB_ANNOTATOR_IMPL
+TEST_F(AnnotatorTest, ClassifyTextWithVocabModel) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetModelWithVocabPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ ClassificationOptions classification_options;
+ classification_options.detected_text_language_tags = "en";
+
+ // The FFModel model does not annotate "integrity" as "dictionary", but the
+ // vocab annotator does. So we can use that to check if the vocab annotator is
+ // in use.
+ classification_options.use_vocab_annotator = true;
+ EXPECT_EQ("dictionary", FirstResult(classifier->ClassifyText(
+ "integrity", {0, 9}, classification_options)));
+ classification_options.use_vocab_annotator = false;
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
+ "integrity", {0, 9}, classification_options)));
+}
+#endif // TC3_VOCAB_ANNOTATOR_IMPL
+
TEST_F(AnnotatorTest, ClassifyTextDisabledFail) {
const std::string test_model = ReadFile(GetTestModelPath());
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
@@ -1296,13 +1291,6 @@
VerifyAnnotatesDurationsInRawMode(classifier.get());
}
-TEST_F(AnnotatorTest, AnnotatesDurationsInRawModeWithDatetimeGrammar) {
- const std::string test_model = GetGrammarModel();
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- test_model.c_str(), test_model.size(), unilib_.get(), calendarlib_.get());
- VerifyAnnotatesDurationsInRawMode(classifier.get());
-}
-
TEST_F(AnnotatorTest, DurationAndRelativeTimeCanOverlapInRawMode) {
std::unique_ptr<Annotator> classifier = Annotator::FromPath(
GetTestModelPath(), unilib_.get(), calendarlib_.get());
@@ -1602,14 +1590,6 @@
VerifyClassifyTextDateInZurichTimezone(classifier.get());
}
-TEST_F(AnnotatorTest, ClassifyTextDateInZurichTimezoneWithDatetimeGrammar) {
- const std::string grammar_model = GetGrammarModel();
- std::unique_ptr<Annotator> classifier =
- Annotator::FromUnownedBuffer(grammar_model.c_str(), grammar_model.size(),
- unilib_.get(), calendarlib_.get());
- VerifyClassifyTextDateInZurichTimezone(classifier.get());
-}
-
void VerifyClassifyTextDateInLATimezone(const Annotator* classifier) {
EXPECT_TRUE(classifier);
ClassificationOptions options;
@@ -1629,14 +1609,6 @@
VerifyClassifyTextDateInLATimezone(classifier.get());
}
-TEST_F(AnnotatorTest, ClassifyTextDateInLATimezoneWithDatetimeGrammar) {
- const std::string grammar_model = GetGrammarModel();
- std::unique_ptr<Annotator> classifier =
- Annotator::FromUnownedBuffer(grammar_model.c_str(), grammar_model.size(),
- unilib_.get(), calendarlib_.get());
- VerifyClassifyTextDateInLATimezone(classifier.get());
-}
-
void VerifyClassifyTextDateOnAotherLine(const Annotator* classifier) {
EXPECT_TRUE(classifier);
ClassificationOptions options;
@@ -1658,14 +1630,6 @@
VerifyClassifyTextDateOnAotherLine(classifier.get());
}
-TEST_F(AnnotatorTest, ClassifyTextDateOnAotherLineWithDatetimeGrammar) {
- const std::string grammar_model = GetGrammarModel();
- std::unique_ptr<Annotator> classifier =
- Annotator::FromUnownedBuffer(grammar_model.c_str(), grammar_model.size(),
- unilib_.get(), calendarlib_.get());
- VerifyClassifyTextDateOnAotherLine(classifier.get());
-}
-
void VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(
const Annotator* classifier) {
EXPECT_TRUE(classifier);
@@ -1688,15 +1652,6 @@
VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(classifier.get());
}
-TEST_F(AnnotatorTest,
- ClassifyTextWhenLocaleUSParsesDateAsMonthDayWithDatetimeGrammar) {
- const std::string grammar_model = GetGrammarModel();
- std::unique_ptr<Annotator> classifier =
- Annotator::FromUnownedBuffer(grammar_model.c_str(), grammar_model.size(),
- unilib_.get(), calendarlib_.get());
- VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(classifier.get());
-}
-
TEST_F(AnnotatorTest, ClassifyTextWhenLocaleGermanyParsesDateAsMonthDay) {
std::unique_ptr<Annotator> classifier = Annotator::FromPath(
GetTestModelPath(), unilib_.get(), calendarlib_.get());
@@ -1795,7 +1750,7 @@
rules.Add(
"<famous_person>", {"<tv_detective>"},
/*callback=*/
- static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
/*callback_param=*/0 /* rule classification result */);
// Set result.
@@ -1837,10 +1792,11 @@
{MakeAnnotatedSpan({0, 1}, "phone", 1.0)}};
std::vector<Locale> locales = {Locale::FromBCP47("en")};
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
std::vector<int> chosen;
classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
- locales,
- AnnotationUsecase_ANNOTATION_USECASE_SMART,
+ locales, options,
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0}));
}
@@ -1857,10 +1813,11 @@
}};
std::vector<Locale> locales = {Locale::FromBCP47("en")};
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
std::vector<int> chosen;
classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
- locales,
- AnnotationUsecase_ANNOTATION_USECASE_SMART,
+ locales, options,
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4}));
}
@@ -1875,10 +1832,11 @@
}};
std::vector<Locale> locales = {Locale::FromBCP47("en")};
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
std::vector<int> chosen;
classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
- locales,
- AnnotationUsecase_ANNOTATION_USECASE_SMART,
+ locales, options,
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0, 2}));
}
@@ -1893,10 +1851,11 @@
}};
std::vector<Locale> locales = {Locale::FromBCP47("en")};
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
std::vector<int> chosen;
classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
- locales,
- AnnotationUsecase_ANNOTATION_USECASE_SMART,
+ locales, options,
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({1}));
}
@@ -1912,10 +1871,11 @@
}};
std::vector<Locale> locales = {Locale::FromBCP47("en")};
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
std::vector<int> chosen;
classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
- locales,
- AnnotationUsecase_ANNOTATION_USECASE_SMART,
+ locales, options,
/*interpreter_manager=*/nullptr, &chosen);
// Picks the first and the last annotations because they do not overlap.
EXPECT_THAT(chosen, ElementsAreArray({0, 3}));
@@ -1947,10 +1907,11 @@
}};
std::vector<Locale> locales = {Locale::FromBCP47("en")};
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
std::vector<int> chosen;
classifier->ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
- locales,
- AnnotationUsecase_ANNOTATION_USECASE_SMART,
+ locales, options,
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({2}));
}
@@ -1967,10 +1928,11 @@
}};
std::vector<Locale> locales = {Locale::FromBCP47("en")};
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
std::vector<int> chosen;
classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
- locales,
- AnnotationUsecase_ANNOTATION_USECASE_SMART,
+ locales, options,
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4}));
}
@@ -1985,9 +1947,11 @@
}};
std::vector<Locale> locales = {Locale::FromBCP47("en")};
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
std::vector<int> chosen;
classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
- locales, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ locales, options,
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0, 1}));
}
@@ -2002,9 +1966,11 @@
}};
std::vector<Locale> locales = {Locale::FromBCP47("en")};
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
std::vector<int> chosen;
classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
- locales, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ locales, options,
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0, 1}));
}
@@ -2020,9 +1986,11 @@
}};
std::vector<Locale> locales = {Locale::FromBCP47("en")};
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
std::vector<int> chosen;
classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
- locales, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ locales, options,
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0, 1}));
}
@@ -2036,9 +2004,11 @@
}};
std::vector<Locale> locales = {Locale::FromBCP47("en")};
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
std::vector<int> chosen;
classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
- locales, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ locales, options,
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0}));
}
@@ -2058,9 +2028,11 @@
}};
std::vector<Locale> locales = {Locale::FromBCP47("en")};
+ BaseOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
std::vector<int> chosen;
classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
- locales, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ locales, options,
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0, 1}));
}
@@ -2098,13 +2070,6 @@
VerifyLongInput(classifier.get());
}
-TEST_F(AnnotatorTest, LongInputWithDatetimeGrammar) {
- const std::string test_model = GetGrammarModel();
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- test_model.c_str(), test_model.size(), unilib_.get(), calendarlib_.get());
- VerifyLongInput(classifier.get());
-}
-
// These coarse tests are there only to make sure the execution happens in
// reasonable amount of time.
TEST_F(AnnotatorTest, LongInputNoResultCheck) {
@@ -2551,15 +2516,6 @@
VerifyClassifyTextOutputsDatetimeEntityData(classifier.get());
}
-TEST_F(AnnotatorTest,
- ClassifyTextOutputsDatetimeEntityDataWithDatetimeGrammar) {
- const std::string grammar_model = GetGrammarModel();
- std::unique_ptr<Annotator> classifier =
- Annotator::FromUnownedBuffer(grammar_model.c_str(), grammar_model.size(),
- unilib_.get(), calendarlib_.get());
- VerifyClassifyTextOutputsDatetimeEntityData(classifier.get());
-}
-
void VerifyAnnotateOutputsDatetimeEntityData(const Annotator* classifier) {
EXPECT_TRUE(classifier);
std::vector<AnnotatedSpan> result;
@@ -2612,14 +2568,6 @@
VerifyAnnotateOutputsDatetimeEntityData(classifier.get());
}
-TEST_F(AnnotatorTest, AnnotateOutputsDatetimeEntityDataWithDatetimeGrammar) {
- const std::string grammar_model = GetGrammarModel();
- std::unique_ptr<Annotator> classifier =
- Annotator::FromUnownedBuffer(grammar_model.c_str(), grammar_model.size(),
- unilib_.get(), calendarlib_.get());
- VerifyAnnotateOutputsDatetimeEntityData(classifier.get());
-}
-
TEST_F(AnnotatorTest, AnnotateOutputsMoneyEntityData) {
std::unique_ptr<Annotator> classifier = Annotator::FromPath(
GetTestModelPath(), unilib_.get(), calendarlib_.get());
@@ -2864,42 +2812,6 @@
}
} // namespace
-TEST_F(AnnotatorTest, WorksWithBothRegexAndGrammarDatetimeAtOnce) {
- // This creates a model that has both regex and grammar datetime. The regex
- // one is broken, and only matches strings "THIS_MATCHES_IN_REGEX_MODEL",
- // so that we can use it for testing that both models are used correctly.
- const std::string test_model = ReadFile(GetModelWithGrammarPath());
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
- TC3_CHECK(libtextclassifier3::DecompressModel(unpacked_model.get()));
- AddDummyRegexDatetimeModel(unpacked_model.get());
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize(), unilib_.get(), calendarlib_.get());
-
- ASSERT_TRUE(classifier);
- EXPECT_THAT(classifier->Annotate("THIS_MATCHES_IN_REGEX_MODEL, and this in "
- "grammars: February 2, 2020"),
- ElementsAreArray({
- // Regex model match
- IsAnnotatedSpan(0, 27, "date"),
- // Grammar model match
- IsAnnotatedSpan(51, 67, "date"),
- }));
- EXPECT_EQ(
- FirstResult(classifier->ClassifyText(
- "THIS_MATCHES_IN_REGEX_MODEL, and this in grammars: February 2, 2020",
- {0, 27})),
- "date");
- EXPECT_EQ(
- FirstResult(classifier->ClassifyText(
- "THIS_MATCHES_IN_REGEX_MODEL, and this in grammars: February 2, 2020",
- {51, 67})),
- "date");
-}
-
TEST_F(AnnotatorTest, AnnotateFiltersOutExactDuplicates) {
std::unique_ptr<Annotator> classifier = Annotator::FromPath(
GetTestModelPath(), unilib_.get(), calendarlib_.get());
@@ -2920,83 +2832,87 @@
EXPECT_EQ(num_phones, 1);
}
-TEST_F(AnnotatorTest, AnnotateUsingGrammar) {
- const std::string grammar_model = GetGrammarModel();
- std::unique_ptr<Annotator> classifier =
- Annotator::FromUnownedBuffer(grammar_model.c_str(), grammar_model.size(),
- unilib_.get(), calendarlib_.get());
+// This test tests the optimizations in Annotator, which make some of the
+// annotators not run in the RAW mode when not requested. We test here that the
+// results indeed don't contain such annotations. However, this is a bick hacky,
+// since one could also add post-filtering, in which case these tests would
+// trivially pass.
+TEST_F(AnnotatorTest, RawModeOptimizationWorks) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
ASSERT_TRUE(classifier);
- const std::string test_string =
- "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
- "number is 853 225 3556";
- EXPECT_THAT(classifier->Annotate(test_string),
- ElementsAreArray({
- IsAnnotatedSpan(19, 24, "date"),
- IsAnnotatedSpan(28, 55, "address"),
- IsAnnotatedSpan(79, 91, "phone"),
- }));
+ AnnotationOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
+ // Requesting a non-existing type to avoid overlap with existing types.
+ options.entity_types.insert("some_unknown_entity_type");
+
+ // Normally, the following command would produce the following annotations:
+ // Span(19, 24, date, 1.000000),
+ // Span(53, 56, number, 1.000000),
+ // Span(53, 80, address, 1.000000),
+ // Span(128, 142, phone, 1.000000),
+ // Span(129, 132, number, 1.000000),
+ // Span(192, 200, phone, 1.000000),
+ // Span(192, 206, datetime, 1.000000),
+ // Span(246, 253, number, 1.000000),
+ // Span(246, 253, phone, 1.000000),
+ // Span(292, 293, number, 1.000000),
+ // Span(292, 301, duration, 1.000000) }
+ // But because of the optimizations, it doesn't produce anything, since
+ // we didn't request any of these entities.
+ EXPECT_THAT(classifier->Annotate(R"--(I saw Barack Obama today
+ 350 Third Street, Cambridge
+ my phone number is (853) 225-3556
+ this is when we met: 1.9.2021 13:00
+ my number: 1234567
+ duration: 3 minutes
+ )--",
+ options),
+ IsEmpty());
}
-TEST_F(AnnotatorTest, AnnotateGrammarPriority) {
- const std::string grammar_model = GetGrammarModel();
- std::unique_ptr<Annotator> classifier =
- Annotator::FromUnownedBuffer(grammar_model.c_str(), grammar_model.size(),
- unilib_.get(), calendarlib_.get());
-
+TEST_F(AnnotatorTest, AnnotateSupportsPointwiseCollectionFilteringInRawMode) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
ASSERT_TRUE(classifier);
+ struct Example {
+ std::string collection;
+ std::string text;
+ };
- // "May 8, 2015, 545" have two annotation
- // 1) {May 8, 2015}, 545 -> Datetime (High Priority Score)
- // 2) May 8, {2015, 545} -> phone number (Low Priority Score)
- const std::string test_string =
- "May 8, 2015, 545, shares Comment Now that Joss Whedon";
- EXPECT_THAT(classifier->Annotate(test_string),
- ElementsAreArray({IsAnnotatedSpan(0, 11, "date")}));
-}
+ // These examples contain one example per annotator, to check that each of
+ // the annotators can work in the RAW mode on its own.
+ //
+ // WARNING: This list doesn't contain yet entries for the app, contact, and
+ // person annotators. Hopefully this won't be needed once b/155214735 is
+ // fixed and the piping shared across annotators.
+ std::vector<Example> examples{
+ // ML Model.
+ {.collection = Collections::Address(),
+ .text = "... 350 Third Street, Cambridge ..."},
+ // Datetime annotator.
+ {.collection = Collections::DateTime(), .text = "... 1.9.2020 10:00 ..."},
+ // Duration annotator.
+ {.collection = Collections::Duration(),
+ .text = "... 3 hours and 9 seconds ..."},
+ // Regex annotator.
+ {.collection = Collections::Email(),
+ .text = "... platypus@theanimal.org ..."},
+ // Number annotator.
+ {.collection = Collections::Number(), .text = "... 100 ..."},
+ };
-TEST_F(AnnotatorTest, AnnotateGrammarDatetimeRangesDisable) {
- const std::string test_model = GetGrammarModel();
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- test_model.c_str(), test_model.size(), unilib_.get(), calendarlib_.get());
- ASSERT_TRUE(classifier);
+ for (const Example& example : examples) {
+ AnnotationOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
+ options.entity_types.insert(example.collection);
- EXPECT_THAT(classifier->Annotate("11 Jan to 3 Feb"),
- ElementsAreArray({IsAnnotatedSpan(0, 6, "date"),
- IsAnnotatedSpan(10, 15, "date")}));
-
- EXPECT_THAT(classifier->Annotate("11 - 14 of Feb"),
- ElementsAreArray({IsAnnotatedSpan(5, 14, "date")}));
-
- EXPECT_THAT(classifier->Annotate("Monday 10 - 11pm"),
- ElementsAreArray({IsAnnotatedSpan(0, 6, "date"),
- IsAnnotatedSpan(12, 16, "datetime")}));
-
- EXPECT_THAT(classifier->Annotate("7:20am - 8:00pm"),
- ElementsAreArray({IsAnnotatedSpan(0, 6, "datetime"),
- IsAnnotatedSpan(9, 15, "datetime")}));
-}
-
-TEST_F(AnnotatorTest, AnnotateGrammarDatetimeRangesEnable) {
- DateAnnotationOptions options;
- options.enable_date_range = true;
- options.merge_adjacent_components = true;
- const std::string test_model = GetGrammarModel(options);
- std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
- test_model.c_str(), test_model.size(), unilib_.get(), calendarlib_.get());
- ASSERT_TRUE(classifier);
-
- EXPECT_THAT(classifier->Annotate("11 Jan to 3 Feb"),
- ElementsAreArray({IsAnnotatedSpan(0, 15, "date")}));
-
- EXPECT_THAT(classifier->Annotate("11 - 14 of Feb"),
- ElementsAreArray({IsAnnotatedSpan(0, 14, "date")}));
-
- EXPECT_THAT(classifier->Annotate("Monday 10 - 11pm"),
- ElementsAreArray({IsAnnotatedSpan(0, 16, "datetime")}));
-
- EXPECT_THAT(classifier->Annotate("7:20am - 8:00pm"),
- ElementsAreArray({IsAnnotatedSpan(0, 15, "datetime")}));
+ EXPECT_THAT(classifier->Annotate(example.text, options),
+ Contains(IsAnnotationWithType(example.collection)))
+ << " text: '" << example.text
+ << "', collection: " << example.collection;
+ }
}
TEST_F(AnnotatorTest, InitializeFromString) {
@@ -3008,5 +2924,54 @@
EXPECT_THAT(classifier->Annotate("(857) 225-3556"), Not(IsEmpty()));
}
+// Regression test for cl/338280366. Enabling only_use_line_with_click had
+// the effect that some annotators in the previous code releases would
+// receive only the last line of the input text. This test has the entity on the
+// first line (duration).
+TEST_F(AnnotatorTest, RegressionTestOnlyUseLineWithClickLastLine) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ std::unique_ptr<Annotator> classifier;
+
+ // With unrestricted number of tokens should behave normally.
+ unpacked_model->selection_feature_options->only_use_line_with_click = true;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+ classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ AnnotationOptions options;
+ options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
+
+ const std::vector<AnnotatedSpan> annotations =
+ classifier->Annotate("let's meet in 3 hours\nbut not now", options);
+
+ EXPECT_THAT(annotations, Contains(IsDurationSpan(
+ /*start=*/14, /*end=*/21,
+ /*duration_ms=*/3 * 60 * 60 * 1000)));
+}
+
+TEST_F(AnnotatorTest, DoesntProcessInvalidUtf8) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ const std::string invalid_utf8_text_with_phone_number =
+ "(857) 225-3556 \xed\xa0\x80\xed\xa0\x80\xed\xa0\x80\xed\xa0\x80";
+
+ std::unique_ptr<Annotator> classifier =
+ Annotator::FromString(test_model, unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+ EXPECT_THAT(classifier->Annotate(invalid_utf8_text_with_phone_number),
+ IsEmpty());
+ EXPECT_THAT(
+ classifier->SuggestSelection(invalid_utf8_text_with_phone_number, {1, 4}),
+ Eq(CodepointSpan{1, 4}));
+ EXPECT_THAT(
+ classifier->ClassifyText(invalid_utf8_text_with_phone_number, {0, 14}),
+ IsEmpty());
+}
+
} // namespace test_internal
} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/datetime.fbs b/native/annotator/datetime/datetime.fbs
new file mode 100755
index 0000000..8012cdc
--- /dev/null
+++ b/native/annotator/datetime/datetime.fbs
@@ -0,0 +1,145 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// Meridiem field.
+namespace libtextclassifier3.grammar.datetime;
+enum Meridiem : int {
+ UNKNOWN = 0,
+
+ // Ante meridiem: Before noon
+ AM = 1,
+
+ // Post meridiem: After noon
+ PM = 2,
+}
+
+// Enum represents a unit of date and time in the expression.
+// Next field: 10
+namespace libtextclassifier3.grammar.datetime;
+enum ComponentType : int {
+ UNSPECIFIED = 0,
+
+ // Year of the date seen in the text match.
+ YEAR = 1,
+
+ // Month of the year starting with January = 1.
+ MONTH = 2,
+
+ // Week (7 days).
+ WEEK = 3,
+
+ // Day of week, start of the week is Sunday & its value is 1.
+ DAY_OF_WEEK = 4,
+
+ // Day of the month starting with 1.
+ DAY_OF_MONTH = 5,
+
+ // Hour of the day.
+ HOUR = 6,
+
+ // Minute of the hour with a range of 0-59.
+ MINUTE = 7,
+
+ // Seconds of the minute with a range of 0-59.
+ SECOND = 8,
+
+ // Meridiem field i.e. AM/PM.
+ MERIDIEM = 9,
+}
+
+namespace libtextclassifier3.grammar.datetime;
+table TimeZone {
+ // Offset from UTC/GTM in minutes.
+ utc_offset_mins:int;
+}
+
+namespace libtextclassifier3.grammar.datetime.RelativeDatetimeComponent_;
+enum Modifier : int {
+ UNSPECIFIED = 0,
+ NEXT = 1,
+ THIS = 2,
+ LAST = 3,
+ NOW = 4,
+ TOMORROW = 5,
+ YESTERDAY = 6,
+}
+
+// Message for representing the relative date-time component in date-time
+// expressions.
+// Next field: 4
+namespace libtextclassifier3.grammar.datetime;
+table RelativeDatetimeComponent {
+ component_type:ComponentType = UNSPECIFIED;
+ modifier:RelativeDatetimeComponent_.Modifier = UNSPECIFIED;
+ value:int;
+}
+
+// AbsoluteDateTime represents date-time expressions that are not ambiguous.
+// Next field: 11
+namespace libtextclassifier3.grammar.datetime;
+table AbsoluteDateTime {
+ // Year value of the date seen in the text match.
+ year:int = -1;
+
+ // Month value of the year starting with January = 1.
+ month:int = -1;
+
+ // Day value of the month starting with 1.
+ day:int = -1;
+
+ // Day of week, start of the week is Sunday and its value is 1.
+ week_day:int = -1;
+
+ // Hour value of the day.
+ hour:int = -1;
+
+ // Minute value of the hour with a range of 0-59.
+ minute:int = -1;
+
+ // Seconds value of the minute with a range of 0-59.
+ second:int = -1;
+
+ partial_second:double = -1;
+
+ // Meridiem field i.e. AM/PM.
+ meridiem:Meridiem;
+
+ time_zone:TimeZone;
+}
+
+// Message to represent relative datetime expressions.
+// It encodes expressions
+// - Where modifier such as before/after shift the date e.g.[three days ago],
+// [2 days after March 1st].
+// - When prefix make the expression relative e.g. [next weekend],
+// [last Monday].
+// Next field: 3
+namespace libtextclassifier3.grammar.datetime;
+table RelativeDateTime {
+ relative_datetime_component:[RelativeDatetimeComponent];
+
+ // The base could be an absolute datetime point for example: "March 1", a
+ // relative datetime point, for example: "2 days before March 1"
+ base:AbsoluteDateTime;
+}
+
+// Datetime result.
+namespace libtextclassifier3.grammar.datetime;
+table UngroundedDatetime {
+ absolute_datetime:AbsoluteDateTime;
+ relative_datetime:RelativeDateTime;
+}
+
diff --git a/native/annotator/datetime/parser.h b/native/annotator/datetime/parser.h
index 8b58388..3b3e578 100644
--- a/native/annotator/datetime/parser.h
+++ b/native/annotator/datetime/parser.h
@@ -19,18 +19,13 @@
#include <memory>
#include <string>
-#include <unordered_map>
-#include <unordered_set>
#include <vector>
-#include "annotator/datetime/extractor.h"
-#include "annotator/model_generated.h"
#include "annotator/types.h"
-#include "utils/base/integral_types.h"
-#include "utils/calendar/calendar.h"
+#include "utils/base/statusor.h"
+#include "utils/i18n/locale-list.h"
+#include "utils/i18n/locale.h"
#include "utils/utf8/unicodetext.h"
-#include "utils/utf8/unilib.h"
-#include "utils/zlib/zlib.h"
namespace libtextclassifier3 {
@@ -38,87 +33,25 @@
// time.
class DatetimeParser {
public:
- static std::unique_ptr<DatetimeParser> Instance(
- const DatetimeModel* model, const UniLib* unilib,
- const CalendarLib* calendarlib, ZlibDecompressor* decompressor);
+ virtual ~DatetimeParser() = default;
// Parses the dates in 'input' and fills result. Makes sure that the results
// do not overlap.
// If 'anchor_start_end' is true the extracted results need to start at the
// beginning of 'input' and end at the end of it.
- bool Parse(const std::string& input, int64 reference_time_ms_utc,
- const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, AnnotationUsecase annotation_usecase,
- bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* results) const;
+ virtual StatusOr<std::vector<DatetimeParseResultSpan>> Parse(
+ const std::string& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const LocaleList& locale_list,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end) const = 0;
// Same as above but takes UnicodeText.
- bool Parse(const UnicodeText& input, int64 reference_time_ms_utc,
- const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, AnnotationUsecase annotation_usecase,
- bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* results) const;
-
- protected:
- explicit DatetimeParser(const DatetimeModel* model, const UniLib* unilib,
- const CalendarLib* calendarlib,
- ZlibDecompressor* decompressor);
-
- // Returns a list of locale ids for given locale spec string (comma-separated
- // locale names). Assigns the first parsed locale to reference_locale.
- std::vector<int> ParseAndExpandLocales(const std::string& locales,
- std::string* reference_locale) const;
-
- // Helper function that finds datetime spans, only using the rules associated
- // with the given locales.
- bool FindSpansUsingLocales(
- const std::vector<int>& locale_ids, const UnicodeText& input,
- const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ virtual StatusOr<std::vector<DatetimeParseResultSpan>> Parse(
+ const UnicodeText& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const LocaleList& locale_list,
ModeFlag mode, AnnotationUsecase annotation_usecase,
- bool anchor_start_end, const std::string& reference_locale,
- std::unordered_set<int>* executed_rules,
- std::vector<DatetimeParseResultSpan>* found_spans) const;
-
- bool ParseWithRule(const CompiledRule& rule, const UnicodeText& input,
- int64 reference_time_ms_utc,
- const std::string& reference_timezone,
- const std::string& reference_locale, const int locale_id,
- bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* result) const;
-
- // Converts the current match in 'matcher' into DatetimeParseResult.
- bool ExtractDatetime(const CompiledRule& rule,
- const UniLib::RegexMatcher& matcher,
- int64 reference_time_ms_utc,
- const std::string& reference_timezone,
- const std::string& reference_locale, int locale_id,
- std::vector<DatetimeParseResult>* results,
- CodepointSpan* result_span) const;
-
- // Parse and extract information from current match in 'matcher'.
- bool HandleParseMatch(const CompiledRule& rule,
- const UniLib::RegexMatcher& matcher,
- int64 reference_time_ms_utc,
- const std::string& reference_timezone,
- const std::string& reference_locale, int locale_id,
- std::vector<DatetimeParseResultSpan>* result) const;
-
- private:
- bool initialized_;
- const UniLib& unilib_;
- const CalendarLib& calendarlib_;
- std::vector<CompiledRule> rules_;
- std::unordered_map<int, std::vector<int>> locale_to_rules_;
- std::vector<std::unique_ptr<const UniLib::RegexPattern>> extractor_rules_;
- std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>
- type_and_locale_to_extractor_rule_;
- std::unordered_map<std::string, int> locale_string_to_id_;
- std::vector<int> default_locale_ids_;
- bool use_extractors_for_locating_;
- bool generate_alternative_interpretations_when_ambiguous_;
- bool prefer_future_for_unspecified_date_;
+ bool anchor_start_end) const = 0;
};
-
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_PARSER_H_
diff --git a/native/annotator/datetime/parser.cc b/native/annotator/datetime/regex-parser.cc
similarity index 69%
rename from native/annotator/datetime/parser.cc
rename to native/annotator/datetime/regex-parser.cc
index 72fd3ab..4dc9c56 100644
--- a/native/annotator/datetime/parser.cc
+++ b/native/annotator/datetime/regex-parser.cc
@@ -14,33 +14,36 @@
* limitations under the License.
*/
-#include "annotator/datetime/parser.h"
+#include "annotator/datetime/regex-parser.h"
+#include <iterator>
#include <set>
#include <unordered_set>
#include "annotator/datetime/extractor.h"
#include "annotator/datetime/utils.h"
+#include "utils/base/statusor.h"
#include "utils/calendar/calendar.h"
#include "utils/i18n/locale.h"
#include "utils/strings/split.h"
#include "utils/zlib/zlib_regex.h"
namespace libtextclassifier3 {
-std::unique_ptr<DatetimeParser> DatetimeParser::Instance(
+std::unique_ptr<DatetimeParser> RegexDatetimeParser::Instance(
const DatetimeModel* model, const UniLib* unilib,
const CalendarLib* calendarlib, ZlibDecompressor* decompressor) {
- std::unique_ptr<DatetimeParser> result(
- new DatetimeParser(model, unilib, calendarlib, decompressor));
+ std::unique_ptr<RegexDatetimeParser> result(
+ new RegexDatetimeParser(model, unilib, calendarlib, decompressor));
if (!result->initialized_) {
result.reset();
}
return result;
}
-DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib* unilib,
- const CalendarLib* calendarlib,
- ZlibDecompressor* decompressor)
+RegexDatetimeParser::RegexDatetimeParser(const DatetimeModel* model,
+ const UniLib* unilib,
+ const CalendarLib* calendarlib,
+ ZlibDecompressor* decompressor)
: unilib_(*unilib), calendarlib_(*calendarlib) {
initialized_ = false;
@@ -113,23 +116,24 @@
initialized_ = true;
}
-bool DatetimeParser::Parse(
+StatusOr<std::vector<DatetimeParseResultSpan>> RegexDatetimeParser::Parse(
const std::string& input, const int64 reference_time_ms_utc,
- const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* results) const {
+ const std::string& reference_timezone, const LocaleList& locale_list,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end) const {
return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
- reference_time_ms_utc, reference_timezone, locales, mode,
- annotation_usecase, anchor_start_end, results);
+ reference_time_ms_utc, reference_timezone, locale_list, mode,
+ annotation_usecase, anchor_start_end);
}
-bool DatetimeParser::FindSpansUsingLocales(
+StatusOr<std::vector<DatetimeParseResultSpan>>
+RegexDatetimeParser::FindSpansUsingLocales(
const std::vector<int>& locale_ids, const UnicodeText& input,
const int64 reference_time_ms_utc, const std::string& reference_timezone,
ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
const std::string& reference_locale,
- std::unordered_set<int>* executed_rules,
- std::vector<DatetimeParseResultSpan>* found_spans) const {
+ std::unordered_set<int>* executed_rules) const {
+ std::vector<DatetimeParseResultSpan> found_spans;
for (const int locale_id : locale_ids) {
auto rules_it = locale_to_rules_.find(locale_id);
if (rules_it == locale_to_rules_.end()) {
@@ -152,34 +156,33 @@
}
executed_rules->insert(rule_id);
-
- if (!ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
- reference_timezone, reference_locale, locale_id,
- anchor_start_end, found_spans)) {
- return false;
- }
+ TC3_ASSIGN_OR_RETURN(
+ const std::vector<DatetimeParseResultSpan>& found_spans_per_rule,
+ ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
+ reference_timezone, reference_locale, locale_id,
+ anchor_start_end));
+ found_spans.insert(std::end(found_spans),
+ std::begin(found_spans_per_rule),
+ std::end(found_spans_per_rule));
}
}
- return true;
+ return found_spans;
}
-bool DatetimeParser::Parse(
+StatusOr<std::vector<DatetimeParseResultSpan>> RegexDatetimeParser::Parse(
const UnicodeText& input, const int64 reference_time_ms_utc,
- const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* results) const {
- std::vector<DatetimeParseResultSpan> found_spans;
+ const std::string& reference_timezone, const LocaleList& locale_list,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end) const {
std::unordered_set<int> executed_rules;
- std::string reference_locale;
const std::vector<int> requested_locales =
- ParseAndExpandLocales(locales, &reference_locale);
- if (!FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
- reference_timezone, mode, annotation_usecase,
- anchor_start_end, reference_locale,
- &executed_rules, &found_spans)) {
- return false;
- }
-
+ ParseAndExpandLocales(locale_list.GetLocaleTags());
+ TC3_ASSIGN_OR_RETURN(
+ const std::vector<DatetimeParseResultSpan>& found_spans,
+ FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
+ reference_timezone, mode, annotation_usecase,
+ anchor_start_end, locale_list.GetReferenceLocale(),
+ &executed_rules));
std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
indexed_found_spans.reserve(found_spans.size());
for (int i = 0; i < found_spans.size(); i++) {
@@ -200,39 +203,46 @@
}
});
- found_spans.clear();
+ std::vector<DatetimeParseResultSpan> results;
+ std::vector<DatetimeParseResultSpan> resolved_found_spans;
+ resolved_found_spans.reserve(indexed_found_spans.size());
for (auto& span_index_pair : indexed_found_spans) {
- found_spans.push_back(span_index_pair.first);
+ resolved_found_spans.push_back(span_index_pair.first);
}
std::set<int, std::function<bool(int, int)>> chosen_indices_set(
- [&found_spans](int a, int b) {
- return found_spans[a].span.first < found_spans[b].span.first;
+ [&resolved_found_spans](int a, int b) {
+ return resolved_found_spans[a].span.first <
+ resolved_found_spans[b].span.first;
});
- for (int i = 0; i < found_spans.size(); ++i) {
- if (!DoesCandidateConflict(i, found_spans, chosen_indices_set)) {
+ for (int i = 0; i < resolved_found_spans.size(); ++i) {
+ if (!DoesCandidateConflict(i, resolved_found_spans, chosen_indices_set)) {
chosen_indices_set.insert(i);
- results->push_back(found_spans[i]);
+ results.push_back(resolved_found_spans[i]);
}
}
-
- return true;
+ return results;
}
-bool DatetimeParser::HandleParseMatch(
- const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
- int64 reference_time_ms_utc, const std::string& reference_timezone,
- const std::string& reference_locale, int locale_id,
- std::vector<DatetimeParseResultSpan>* result) const {
+StatusOr<std::vector<DatetimeParseResultSpan>>
+RegexDatetimeParser::HandleParseMatch(const CompiledRule& rule,
+ const UniLib::RegexMatcher& matcher,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& reference_locale,
+ int locale_id) const {
+ std::vector<DatetimeParseResultSpan> results;
int status = UniLib::RegexMatcher::kNoError;
const int start = matcher.Start(&status);
if (status != UniLib::RegexMatcher::kNoError) {
- return false;
+    return Status(StatusCode::INTERNAL,
+                  "Failed to get the start offset of the last match.");
}
const int end = matcher.End(&status);
if (status != UniLib::RegexMatcher::kNoError) {
- return false;
+    return Status(StatusCode::INTERNAL,
+                  "Failed to get the end offset of the last match.");
}
DatetimeParseResultSpan parse_result;
@@ -240,7 +250,7 @@
if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
reference_locale, locale_id, &alternatives,
&parse_result.span)) {
- return false;
+ return Status(StatusCode::INTERNAL, "Failed to extract Datetime.");
}
if (!use_extractors_for_locating_) {
@@ -257,49 +267,44 @@
parse_result.data.push_back(alternative);
}
}
- result->push_back(parse_result);
- return true;
+ results.push_back(parse_result);
+ return results;
}
-bool DatetimeParser::ParseWithRule(
- const CompiledRule& rule, const UnicodeText& input,
- const int64 reference_time_ms_utc, const std::string& reference_timezone,
- const std::string& reference_locale, const int locale_id,
- bool anchor_start_end, std::vector<DatetimeParseResultSpan>* result) const {
+StatusOr<std::vector<DatetimeParseResultSpan>>
+RegexDatetimeParser::ParseWithRule(const CompiledRule& rule,
+ const UnicodeText& input,
+ const int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& reference_locale,
+ const int locale_id,
+ bool anchor_start_end) const {
+ std::vector<DatetimeParseResultSpan> results;
std::unique_ptr<UniLib::RegexMatcher> matcher =
rule.compiled_regex->Matcher(input);
int status = UniLib::RegexMatcher::kNoError;
if (anchor_start_end) {
if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
- if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
- reference_timezone, reference_locale, locale_id,
- result)) {
- return false;
- }
+ return HandleParseMatch(rule, *matcher, reference_time_ms_utc,
+ reference_timezone, reference_locale, locale_id);
}
} else {
while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
- if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
- reference_timezone, reference_locale, locale_id,
- result)) {
- return false;
- }
+ TC3_ASSIGN_OR_RETURN(
+ const std::vector<DatetimeParseResultSpan>& pattern_occurrence,
+ HandleParseMatch(rule, *matcher, reference_time_ms_utc,
+ reference_timezone, reference_locale, locale_id));
+ results.insert(std::end(results), std::begin(pattern_occurrence),
+ std::end(pattern_occurrence));
}
}
- return true;
+ return results;
}
-std::vector<int> DatetimeParser::ParseAndExpandLocales(
- const std::string& locales, std::string* reference_locale) const {
- std::vector<StringPiece> split_locales = strings::Split(locales, ',');
- if (!split_locales.empty()) {
- *reference_locale = split_locales[0].ToString();
- } else {
- *reference_locale = "";
- }
-
+std::vector<int> RegexDatetimeParser::ParseAndExpandLocales(
+ const std::vector<StringPiece>& locales) const {
std::vector<int> result;
- for (const StringPiece& locale_str : split_locales) {
+ for (const StringPiece& locale_str : locales) {
auto locale_it = locale_string_to_id_.find(locale_str.ToString());
if (locale_it != locale_string_to_id_.end()) {
result.push_back(locale_it->second);
@@ -348,14 +353,12 @@
return result;
}
-bool DatetimeParser::ExtractDatetime(const CompiledRule& rule,
- const UniLib::RegexMatcher& matcher,
- const int64 reference_time_ms_utc,
- const std::string& reference_timezone,
- const std::string& reference_locale,
- int locale_id,
- std::vector<DatetimeParseResult>* results,
- CodepointSpan* result_span) const {
+bool RegexDatetimeParser::ExtractDatetime(
+ const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
+ const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ const std::string& reference_locale, int locale_id,
+ std::vector<DatetimeParseResult>* results,
+ CodepointSpan* result_span) const {
DatetimeParsedData parse;
DatetimeExtractor extractor(rule, matcher, locale_id, &unilib_,
extractor_rules_,
diff --git a/native/annotator/datetime/regex-parser.h b/native/annotator/datetime/regex-parser.h
new file mode 100644
index 0000000..e820c21
--- /dev/null
+++ b/native/annotator/datetime/regex-parser.h
@@ -0,0 +1,123 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_REGEX_PARSER_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_REGEX_PARSER_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "annotator/datetime/extractor.h"
+#include "annotator/datetime/parser.h"
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/base/integral_types.h"
+#include "utils/base/statusor.h"
+#include "utils/calendar/calendar.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3 {
+
+// Parses datetime expressions in the input and resolves them to actual absolute
+// time.
+class RegexDatetimeParser : public DatetimeParser {
+ public:
+ static std::unique_ptr<DatetimeParser> Instance(
+ const DatetimeModel* model, const UniLib* unilib,
+ const CalendarLib* calendarlib, ZlibDecompressor* decompressor);
+
+ // Parses the dates in 'input' and fills result. Makes sure that the results
+ // do not overlap.
+ // If 'anchor_start_end' is true the extracted results need to start at the
+ // beginning of 'input' and end at the end of it.
+ StatusOr<std::vector<DatetimeParseResultSpan>> Parse(
+ const std::string& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const LocaleList& locale_list,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end) const override;
+
+ // Same as above but takes UnicodeText.
+ StatusOr<std::vector<DatetimeParseResultSpan>> Parse(
+ const UnicodeText& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const LocaleList& locale_list,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end) const override;
+
+ protected:
+ explicit RegexDatetimeParser(const DatetimeModel* model, const UniLib* unilib,
+ const CalendarLib* calendarlib,
+ ZlibDecompressor* decompressor);
+
+ // Returns a list of locale ids for given locale spec string (collection of
+ // locale names).
+ std::vector<int> ParseAndExpandLocales(
+ const std::vector<StringPiece>& locales) const;
+
+ // Helper function that finds datetime spans, only using the rules associated
+ // with the given locales.
+ StatusOr<std::vector<DatetimeParseResultSpan>> FindSpansUsingLocales(
+ const std::vector<int>& locale_ids, const UnicodeText& input,
+ const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end, const std::string& reference_locale,
+ std::unordered_set<int>* executed_rules) const;
+
+ StatusOr<std::vector<DatetimeParseResultSpan>> ParseWithRule(
+ const CompiledRule& rule, const UnicodeText& input,
+ int64 reference_time_ms_utc, const std::string& reference_timezone,
+ const std::string& reference_locale, const int locale_id,
+ bool anchor_start_end) const;
+
+ // Converts the current match in 'matcher' into DatetimeParseResult.
+ bool ExtractDatetime(const CompiledRule& rule,
+ const UniLib::RegexMatcher& matcher,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& reference_locale, int locale_id,
+ std::vector<DatetimeParseResult>* results,
+ CodepointSpan* result_span) const;
+
+ // Parse and extract information from current match in 'matcher'.
+ StatusOr<std::vector<DatetimeParseResultSpan>> HandleParseMatch(
+ const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
+ int64 reference_time_ms_utc, const std::string& reference_timezone,
+ const std::string& reference_locale, int locale_id) const;
+
+ private:
+ bool initialized_;
+ const UniLib& unilib_;
+ const CalendarLib& calendarlib_;
+ std::vector<CompiledRule> rules_;
+ std::unordered_map<int, std::vector<int>> locale_to_rules_;
+ std::vector<std::unique_ptr<const UniLib::RegexPattern>> extractor_rules_;
+ std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>
+ type_and_locale_to_extractor_rule_;
+ std::unordered_map<std::string, int> locale_string_to_id_;
+ std::vector<int> default_locale_ids_;
+ bool use_extractors_for_locating_;
+ bool generate_alternative_interpretations_when_ambiguous_;
+ bool prefer_future_for_unspecified_date_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_REGEX_PARSER_H_
diff --git a/native/annotator/datetime/parser_test.cc b/native/annotator/datetime/regex-parser_test.cc
similarity index 88%
rename from native/annotator/datetime/parser_test.cc
rename to native/annotator/datetime/regex-parser_test.cc
index 76b033d..a0d9adf 100644
--- a/native/annotator/datetime/parser_test.cc
+++ b/native/annotator/datetime/regex-parser_test.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "annotator/datetime/parser.h"
+#include "annotator/datetime/regex-parser.h"
#include <time.h>
@@ -24,8 +24,11 @@
#include <string>
#include "annotator/annotator.h"
+#include "annotator/datetime/testing/base-parser-test.h"
+#include "annotator/datetime/testing/datetime-component-builder.h"
#include "annotator/model_generated.h"
#include "annotator/types-test-util.h"
+#include "utils/i18n/locale-list.h"
#include "utils/jvm-test-utils.h"
#include "utils/test-data-test-utils.h"
#include "utils/testing/annotator.h"
@@ -33,48 +36,9 @@
#include "gtest/gtest.h"
using std::vector;
-using testing::ElementsAreArray;
namespace libtextclassifier3 {
namespace {
-// Builder class to construct the DatetimeComponents and make the test readable.
-class DatetimeComponentsBuilder {
- public:
- DatetimeComponentsBuilder Add(DatetimeComponent::ComponentType type,
- int value) {
- DatetimeComponent component;
- component.component_type = type;
- component.value = value;
- return AddComponent(component);
- }
-
- DatetimeComponentsBuilder Add(
- DatetimeComponent::ComponentType type, int value,
- DatetimeComponent::RelativeQualifier relative_qualifier,
- int relative_count) {
- DatetimeComponent component;
- component.component_type = type;
- component.value = value;
- component.relative_qualifier = relative_qualifier;
- component.relative_count = relative_count;
- return AddComponent(component);
- }
-
- std::vector<DatetimeComponent> Build() {
- std::vector<DatetimeComponent> result(datetime_components_);
- datetime_components_.clear();
- return result;
- }
-
- private:
- DatetimeComponentsBuilder AddComponent(
- const DatetimeComponent& datetime_component) {
- datetime_components_.push_back(datetime_component);
- return *this;
- }
- std::vector<DatetimeComponent> datetime_components_;
-};
-
std::string GetModelPath() { return GetTestDataPath("annotator/test_data/"); }
std::string ReadFile(const std::string& file_name) {
@@ -82,7 +46,7 @@
return std::string(std::istreambuf_iterator<char>(file_stream), {});
}
-class ParserTest : public testing::Test {
+class RegexDatetimeParserTest : public DateTimeParserTest {
public:
void SetUp() override {
// Loads default unmodified model. Individual tests can call LoadModel to
@@ -104,139 +68,12 @@
TC3_CHECK(parser_);
}
- bool HasNoResult(const std::string& text, bool anchor_start_end = false,
- const std::string& timezone = "Europe/Zurich",
- AnnotationUsecase annotation_usecase =
- AnnotationUsecase_ANNOTATION_USECASE_SMART) {
- std::vector<DatetimeParseResultSpan> results;
- if (!parser_->Parse(text, 0, timezone, /*locales=*/"", ModeFlag_ANNOTATION,
- annotation_usecase, anchor_start_end, &results)) {
- TC3_LOG(ERROR) << text;
- TC3_CHECK(false);
- }
- return results.empty();
+ // Exposes the date time parser for tests and evaluations.
+ const DatetimeParser* DatetimeParserForTests() const override {
+ return classifier_->DatetimeParserForTests();
}
- bool ParsesCorrectly(const std::string& marked_text,
- const vector<int64>& expected_ms_utcs,
- DatetimeGranularity expected_granularity,
- vector<vector<DatetimeComponent>> datetime_components,
- bool anchor_start_end = false,
- const std::string& timezone = "Europe/Zurich",
- const std::string& locales = "en-US",
- AnnotationUsecase annotation_usecase =
- AnnotationUsecase_ANNOTATION_USECASE_SMART) {
- const UnicodeText marked_text_unicode =
- UTF8ToUnicodeText(marked_text, /*do_copy=*/false);
- auto brace_open_it =
- std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '{');
- auto brace_end_it =
- std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '}');
- TC3_CHECK(brace_open_it != marked_text_unicode.end());
- TC3_CHECK(brace_end_it != marked_text_unicode.end());
-
- std::string text;
- text +=
- UnicodeText::UTF8Substring(marked_text_unicode.begin(), brace_open_it);
- text += UnicodeText::UTF8Substring(std::next(brace_open_it), brace_end_it);
- text += UnicodeText::UTF8Substring(std::next(brace_end_it),
- marked_text_unicode.end());
-
- std::vector<DatetimeParseResultSpan> results;
-
- if (!parser_->Parse(text, 0, timezone, locales, ModeFlag_ANNOTATION,
- annotation_usecase, anchor_start_end, &results)) {
- TC3_LOG(ERROR) << text;
- TC3_CHECK(false);
- }
- if (results.empty()) {
- TC3_LOG(ERROR) << "No results.";
- return false;
- }
-
- const int expected_start_index =
- std::distance(marked_text_unicode.begin(), brace_open_it);
- // The -1 below is to account for the opening bracket character.
- const int expected_end_index =
- std::distance(marked_text_unicode.begin(), brace_end_it) - 1;
-
- std::vector<DatetimeParseResultSpan> filtered_results;
- for (const DatetimeParseResultSpan& result : results) {
- if (SpansOverlap(result.span,
- {expected_start_index, expected_end_index})) {
- filtered_results.push_back(result);
- }
- }
- std::vector<DatetimeParseResultSpan> expected{
- {{expected_start_index, expected_end_index},
- {},
- /*target_classification_score=*/1.0,
- /*priority_score=*/1.0}};
- expected[0].data.resize(expected_ms_utcs.size());
- for (int i = 0; i < expected_ms_utcs.size(); i++) {
- expected[0].data[i] = {expected_ms_utcs[i], expected_granularity,
- datetime_components[i]};
- }
-
- const bool matches =
- testing::Matches(ElementsAreArray(expected))(filtered_results);
- if (!matches) {
- TC3_LOG(ERROR) << "Expected: " << expected[0];
- if (filtered_results.empty()) {
- TC3_LOG(ERROR) << "But got no results.";
- }
- TC3_LOG(ERROR) << "Actual: " << filtered_results[0];
- }
-
- return matches;
- }
-
- bool ParsesCorrectly(const std::string& marked_text,
- const int64 expected_ms_utc,
- DatetimeGranularity expected_granularity,
- vector<vector<DatetimeComponent>> datetime_components,
- bool anchor_start_end = false,
- const std::string& timezone = "Europe/Zurich",
- const std::string& locales = "en-US",
- AnnotationUsecase annotation_usecase =
- AnnotationUsecase_ANNOTATION_USECASE_SMART) {
- return ParsesCorrectly(marked_text, vector<int64>{expected_ms_utc},
- expected_granularity, datetime_components,
- anchor_start_end, timezone, locales,
- annotation_usecase);
- }
-
- bool ParsesCorrectlyGerman(
- const std::string& marked_text, const vector<int64>& expected_ms_utcs,
- DatetimeGranularity expected_granularity,
- vector<vector<DatetimeComponent>> datetime_components) {
- return ParsesCorrectly(marked_text, expected_ms_utcs, expected_granularity,
- datetime_components,
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"de");
- }
-
- bool ParsesCorrectlyGerman(
- const std::string& marked_text, const int64 expected_ms_utc,
- DatetimeGranularity expected_granularity,
- vector<vector<DatetimeComponent>> datetime_components) {
- return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
- datetime_components,
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"de");
- }
-
- bool ParsesCorrectlyChinese(
- const std::string& marked_text, const int64 expected_ms_utc,
- DatetimeGranularity expected_granularity,
- vector<vector<DatetimeComponent>> datetime_components) {
- return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
- datetime_components,
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"zh");
- }
-
- protected:
+ private:
std::string model_buffer_;
std::unique_ptr<Annotator> classifier_;
const DatetimeParser* parser_;
@@ -245,7 +82,7 @@
};
// Test with just a few cases to make debugging of general failures easier.
-TEST_F(ParserTest, ParseShort) {
+TEST_F(RegexDatetimeParserTest, ParseShort) {
EXPECT_TRUE(ParsesCorrectly(
"{January 1, 1988}", 567990000000, GRANULARITY_DAY,
{DatetimeComponentsBuilder()
@@ -255,7 +92,7 @@
.Build()}));
}
-TEST_F(ParserTest, Parse) {
+TEST_F(RegexDatetimeParserTest, Parse) {
EXPECT_TRUE(ParsesCorrectly(
"{January 1, 1988}", 567990000000, GRANULARITY_DAY,
{DatetimeComponentsBuilder()
@@ -696,7 +533,7 @@
/*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
}
-TEST_F(ParserTest, ParseWithAnchor) {
+TEST_F(RegexDatetimeParserTest, ParseWithAnchor) {
EXPECT_TRUE(ParsesCorrectly(
"{January 1, 1988}", 567990000000, GRANULARITY_DAY,
{DatetimeComponentsBuilder()
@@ -725,7 +562,7 @@
/*anchor_start_end=*/true));
}
-TEST_F(ParserTest, ParseWithRawUsecase) {
+TEST_F(RegexDatetimeParserTest, ParseWithRawUsecase) {
// Annotated for RAW usecase.
EXPECT_TRUE(ParsesCorrectly(
"{tomorrow}", 82800000, GRANULARITY_DAY,
@@ -784,7 +621,7 @@
}
// For details please see b/155437137
-TEST_F(ParserTest, PastRelativeDatetime) {
+TEST_F(RegexDatetimeParserTest, PastRelativeDatetime) {
EXPECT_TRUE(ParsesCorrectly(
"called you {last Saturday}",
-432000000 /* Fri 1969-12-26 16:00:00 PST */, GRANULARITY_DAY,
@@ -830,7 +667,7 @@
/*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
}
-TEST_F(ParserTest, AddsADayWhenTimeInThePastAndDayNotSpecified) {
+TEST_F(RegexDatetimeParserTest, AddsADayWhenTimeInThePastAndDayNotSpecified) {
// ParsesCorrectly uses 0 as the reference time, which corresponds to:
// "Thu Jan 01 1970 01:00:00" Zurich time. So if we pass "0:30" here, it means
// it is in the past, and so the parser should move this to the next day ->
@@ -845,7 +682,8 @@
.Build()}));
}
-TEST_F(ParserTest, DoesNotAddADayWhenTimeInThePastAndDayNotSpecifiedDisabled) {
+TEST_F(RegexDatetimeParserTest,
+ DoesNotAddADayWhenTimeInThePastAndDayNotSpecifiedDisabled) {
// ParsesCorrectly uses 0 as the reference time, which corresponds to:
// "Thu Jan 01 1970 01:00:00" Zurich time. So if we pass "0:30" here, it means
// it is in the past. The parameter prefer_future_when_unspecified_day is
@@ -867,7 +705,7 @@
.Build()}));
}
-TEST_F(ParserTest, ParsesNoonAndMidnightCorrectly) {
+TEST_F(RegexDatetimeParserTest, ParsesNoonAndMidnightCorrectly) {
EXPECT_TRUE(ParsesCorrectly(
"{January 1, 1988 12:30am}", 567991800000, GRANULARITY_MINUTE,
{DatetimeComponentsBuilder()
@@ -899,7 +737,7 @@
.Build()}));
}
-TEST_F(ParserTest, ParseGerman) {
+TEST_F(RegexDatetimeParserTest, ParseGerman) {
EXPECT_TRUE(ParsesCorrectlyGerman(
"{Januar 1 2018}", 1514761200000, GRANULARITY_DAY,
{DatetimeComponentsBuilder()
@@ -1309,7 +1147,7 @@
.Build()}));
}
-TEST_F(ParserTest, ParseChinese) {
+TEST_F(RegexDatetimeParserTest, ParseChinese) {
EXPECT_TRUE(ParsesCorrectlyChinese(
"{明天 7 上午}", 108000000, GRANULARITY_HOUR,
{DatetimeComponentsBuilder()
@@ -1320,7 +1158,7 @@
.Build()}));
}
-TEST_F(ParserTest, ParseNonUs) {
+TEST_F(RegexDatetimeParserTest, ParseNonUs) {
auto first_may_2015 =
DatetimeComponentsBuilder()
.Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
@@ -1339,7 +1177,7 @@
/*timezone=*/"Europe/Zurich", /*locales=*/"en"));
}
-TEST_F(ParserTest, ParseUs) {
+TEST_F(RegexDatetimeParserTest, ParseUs) {
auto five_january_2015 =
DatetimeComponentsBuilder()
.Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 5)
@@ -1359,7 +1197,7 @@
/*locales=*/"es-US"));
}
-TEST_F(ParserTest, ParseUnknownLanguage) {
+TEST_F(RegexDatetimeParserTest, ParseUnknownLanguage) {
EXPECT_TRUE(ParsesCorrectly(
"bylo to {31. 12. 2015} v 6 hodin", 1451516400000, GRANULARITY_DAY,
{DatetimeComponentsBuilder()
@@ -1371,7 +1209,7 @@
/*timezone=*/"Europe/Zurich", /*locales=*/"xx"));
}
-TEST_F(ParserTest, WhenAlternativesEnabledGeneratesAlternatives) {
+TEST_F(RegexDatetimeParserTest, WhenAlternativesEnabledGeneratesAlternatives) {
LoadModel([](ModelT* model) {
model->datetime_model->generate_alternative_interpretations_when_ambiguous =
true;
@@ -1422,7 +1260,8 @@
.Build()}));
}
-TEST_F(ParserTest, WhenAlternativesDisabledDoesNotGenerateAlternatives) {
+TEST_F(RegexDatetimeParserTest,
+ WhenAlternativesDisabledDoesNotGenerateAlternatives) {
LoadModel([](ModelT* model) {
model->datetime_model->generate_alternative_interpretations_when_ambiguous =
false;
@@ -1490,19 +1329,19 @@
unilib_ = CreateUniLibForTesting();
calendarlib_ = CreateCalendarLibForTesting();
parser_ =
- DatetimeParser::Instance(model_fb, unilib_.get(), calendarlib_.get(),
- /*decompressor=*/nullptr);
+ RegexDatetimeParser::Instance(model_fb, unilib_.get(), calendarlib_.get(),
+ /*decompressor=*/nullptr);
ASSERT_TRUE(parser_);
}
bool ParserLocaleTest::HasResult(const std::string& input,
const std::string& locales) {
- std::vector<DatetimeParseResultSpan> results;
- EXPECT_TRUE(parser_->Parse(
+ StatusOr<std::vector<DatetimeParseResultSpan>> results = parser_->Parse(
input, /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"", locales, ModeFlag_ANNOTATION,
- AnnotationUsecase_ANNOTATION_USECASE_SMART, false, &results));
- return results.size() == 1;
+ /*reference_timezone=*/"", LocaleList::ParseFrom(locales),
+ ModeFlag_ANNOTATION, AnnotationUsecase_ANNOTATION_USECASE_SMART, false);
+ EXPECT_TRUE(results.ok());
+ return results.ValueOrDie().size() == 1;
}
TEST_F(ParserLocaleTest, English) {
diff --git a/native/annotator/datetime/testing/base-parser-test.cc b/native/annotator/datetime/testing/base-parser-test.cc
new file mode 100644
index 0000000..d8dd723
--- /dev/null
+++ b/native/annotator/datetime/testing/base-parser-test.cc
@@ -0,0 +1,162 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/datetime/testing/base-parser-test.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "utils/i18n/locale-list.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using std::vector;
+using testing::ElementsAreArray;
+
+namespace libtextclassifier3 {
+
+bool DateTimeParserTest::HasNoResult(const std::string& text,
+ bool anchor_start_end,
+ const std::string& timezone,
+ AnnotationUsecase annotation_usecase) {
+ StatusOr<std::vector<DatetimeParseResultSpan>> results_status =
+ DatetimeParserForTests()->Parse(
+ text, 0, timezone, LocaleList::ParseFrom(/*locale_tags=*/""),
+ ModeFlag_ANNOTATION, annotation_usecase, anchor_start_end);
+ if (!results_status.ok()) {
+ TC3_LOG(ERROR) << text;
+ TC3_CHECK(false);
+ }
+ return results_status.ValueOrDie().empty();
+}
+
+bool DateTimeParserTest::ParsesCorrectly(
+ const std::string& marked_text, const vector<int64>& expected_ms_utcs,
+ DatetimeGranularity expected_granularity,
+ vector<vector<DatetimeComponent>> datetime_components,
+ bool anchor_start_end, const std::string& timezone,
+ const std::string& locales, AnnotationUsecase annotation_usecase) {
+ const UnicodeText marked_text_unicode =
+ UTF8ToUnicodeText(marked_text, /*do_copy=*/false);
+ auto brace_open_it =
+ std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '{');
+ auto brace_end_it =
+ std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '}');
+ TC3_CHECK(brace_open_it != marked_text_unicode.end());
+ TC3_CHECK(brace_end_it != marked_text_unicode.end());
+
+ std::string text;
+ text +=
+ UnicodeText::UTF8Substring(marked_text_unicode.begin(), brace_open_it);
+ text += UnicodeText::UTF8Substring(std::next(brace_open_it), brace_end_it);
+ text += UnicodeText::UTF8Substring(std::next(brace_end_it),
+ marked_text_unicode.end());
+
+ StatusOr<std::vector<DatetimeParseResultSpan>> results_status =
+ DatetimeParserForTests()->Parse(
+ text, 0, timezone, LocaleList::ParseFrom(locales),
+ ModeFlag_ANNOTATION, annotation_usecase, anchor_start_end);
+ if (!results_status.ok()) {
+ TC3_LOG(ERROR) << text;
+ TC3_CHECK(false);
+ }
+ // const std::vector<DatetimeParseResultSpan>& results =
+ // results_status.ValueOrDie();
+ if (results_status.ValueOrDie().empty()) {
+ TC3_LOG(ERROR) << "No results.";
+ return false;
+ }
+
+ const int expected_start_index =
+ std::distance(marked_text_unicode.begin(), brace_open_it);
+ // The -1 below is to account for the opening bracket character.
+ const int expected_end_index =
+ std::distance(marked_text_unicode.begin(), brace_end_it) - 1;
+
+ std::vector<DatetimeParseResultSpan> filtered_results;
+ for (const DatetimeParseResultSpan& result : results_status.ValueOrDie()) {
+ if (SpansOverlap(result.span, {expected_start_index, expected_end_index})) {
+ filtered_results.push_back(result);
+ }
+ }
+ std::vector<DatetimeParseResultSpan> expected{
+ {{expected_start_index, expected_end_index},
+ {},
+ /*target_classification_score=*/1.0,
+ /*priority_score=*/1.0}};
+ expected[0].data.resize(expected_ms_utcs.size());
+ for (int i = 0; i < expected_ms_utcs.size(); i++) {
+ expected[0].data[i] = {expected_ms_utcs[i], expected_granularity,
+ datetime_components[i]};
+ }
+
+ const bool matches =
+ testing::Matches(ElementsAreArray(expected))(filtered_results);
+ if (!matches) {
+ TC3_LOG(ERROR) << "Expected: " << expected[0];
+ if (filtered_results.empty()) {
+ TC3_LOG(ERROR) << "But got no results.";
+ }
+ TC3_LOG(ERROR) << "Actual: " << filtered_results[0];
+ }
+
+ return matches;
+}
+
+bool DateTimeParserTest::ParsesCorrectly(
+ const std::string& marked_text, const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ vector<vector<DatetimeComponent>> datetime_components,
+ bool anchor_start_end, const std::string& timezone,
+ const std::string& locales, AnnotationUsecase annotation_usecase) {
+ return ParsesCorrectly(marked_text, vector<int64>{expected_ms_utc},
+ expected_granularity, datetime_components,
+ anchor_start_end, timezone, locales,
+ annotation_usecase);
+}
+
+bool DateTimeParserTest::ParsesCorrectlyGerman(
+ const std::string& marked_text, const vector<int64>& expected_ms_utcs,
+ DatetimeGranularity expected_granularity,
+ vector<vector<DatetimeComponent>> datetime_components) {
+ return ParsesCorrectly(marked_text, expected_ms_utcs, expected_granularity,
+ datetime_components,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"de");
+}
+
+bool DateTimeParserTest::ParsesCorrectlyGerman(
+ const std::string& marked_text, const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ vector<vector<DatetimeComponent>> datetime_components) {
+ return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
+ datetime_components,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"de");
+}
+
+bool DateTimeParserTest::ParsesCorrectlyChinese(
+ const std::string& marked_text, const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ vector<vector<DatetimeComponent>> datetime_components) {
+ return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
+ datetime_components,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"zh");
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/testing/base-parser-test.h b/native/annotator/datetime/testing/base-parser-test.h
new file mode 100644
index 0000000..3465a04
--- /dev/null
+++ b/native/annotator/datetime/testing/base-parser-test.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_TESTING_BASE_PARSER_TEST_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_TESTING_BASE_PARSER_TEST_H_
+
+#include <string>
+#include <vector>
+
+#include "annotator/datetime/parser.h"
+#include "utils/base/integral_types.h"
+#include "annotator/types.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+
+class DateTimeParserTest : public testing::Test {
+ public:
+ bool HasNoResult(const std::string& text, bool anchor_start_end = false,
+ const std::string& timezone = "Europe/Zurich",
+ AnnotationUsecase annotation_usecase =
+ AnnotationUsecase_ANNOTATION_USECASE_SMART);
+
+ bool ParsesCorrectly(
+ const std::string& marked_text,
+ const std::vector<int64>& expected_ms_utcs,
+ DatetimeGranularity expected_granularity,
+ std::vector<std::vector<DatetimeComponent>> datetime_components,
+ bool anchor_start_end = false,
+ const std::string& timezone = "Europe/Zurich",
+ const std::string& locales = "en-US",
+ AnnotationUsecase annotation_usecase =
+ AnnotationUsecase_ANNOTATION_USECASE_SMART);
+
+ bool ParsesCorrectly(
+ const std::string& marked_text, const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ std::vector<std::vector<DatetimeComponent>> datetime_components,
+ bool anchor_start_end = false,
+ const std::string& timezone = "Europe/Zurich",
+ const std::string& locales = "en-US",
+ AnnotationUsecase annotation_usecase =
+ AnnotationUsecase_ANNOTATION_USECASE_SMART);
+
+ bool ParsesCorrectlyGerman(
+ const std::string& marked_text,
+ const std::vector<int64>& expected_ms_utcs,
+ DatetimeGranularity expected_granularity,
+ std::vector<std::vector<DatetimeComponent>> datetime_components);
+
+ bool ParsesCorrectlyGerman(
+ const std::string& marked_text, const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ std::vector<std::vector<DatetimeComponent>> datetime_components);
+
+ bool ParsesCorrectlyChinese(
+ const std::string& marked_text, const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ std::vector<std::vector<DatetimeComponent>> datetime_components);
+
+ // Exposes the date time parser for tests and evaluations.
+ virtual const DatetimeParser* DatetimeParserForTests() const = 0;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_TESTING_BASE_PARSER_TEST_H_
diff --git a/native/annotator/datetime/testing/datetime-component-builder.cc b/native/annotator/datetime/testing/datetime-component-builder.cc
new file mode 100644
index 0000000..f0764da
--- /dev/null
+++ b/native/annotator/datetime/testing/datetime-component-builder.cc
@@ -0,0 +1,50 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/datetime/testing/datetime-component-builder.h"
+
+namespace libtextclassifier3 {
+
+DatetimeComponentsBuilder DatetimeComponentsBuilder::Add(
+ DatetimeComponent::ComponentType type, int value) {
+ DatetimeComponent component;
+ component.component_type = type;
+ component.value = value;
+ return AddComponent(component);
+}
+
+DatetimeComponentsBuilder DatetimeComponentsBuilder::Add(
+ DatetimeComponent::ComponentType type, int value,
+ DatetimeComponent::RelativeQualifier relative_qualifier,
+ int relative_count) {
+ DatetimeComponent component;
+ component.component_type = type;
+ component.value = value;
+ component.relative_qualifier = relative_qualifier;
+ component.relative_count = relative_count;
+ return AddComponent(component);
+}
+
+std::vector<DatetimeComponent> DatetimeComponentsBuilder::Build() {
+ return std::move(datetime_components_);
+}
+
+DatetimeComponentsBuilder DatetimeComponentsBuilder::AddComponent(
+ const DatetimeComponent& datetime_component) {
+ datetime_components_.push_back(datetime_component);
+ return *this;
+}
+} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/testing/datetime-component-builder.h b/native/annotator/datetime/testing/datetime-component-builder.h
new file mode 100644
index 0000000..a6a9f36
--- /dev/null
+++ b/native/annotator/datetime/testing/datetime-component-builder.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_TESTING_DATETIME_COMPONENT_BUILDER_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_TESTING_DATETIME_COMPONENT_BUILDER_H_
+
+#include "annotator/types.h"
+
+namespace libtextclassifier3 {
+
+// Builder class to construct the DatetimeComponents and make the test readable.
+class DatetimeComponentsBuilder {
+ public:
+ DatetimeComponentsBuilder Add(DatetimeComponent::ComponentType type,
+ int value);
+
+ DatetimeComponentsBuilder Add(
+ DatetimeComponent::ComponentType type, int value,
+ DatetimeComponent::RelativeQualifier relative_qualifier,
+ int relative_count);
+
+ std::vector<DatetimeComponent> Build();
+
+ private:
+ DatetimeComponentsBuilder AddComponent(
+ const DatetimeComponent& datetime_component);
+ std::vector<DatetimeComponent> datetime_components_;
+};
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_TESTING_DATETIME_COMPONENT_BUILDER_H_
diff --git a/native/annotator/grammar/dates/annotations/annotation-options.h b/native/annotator/grammar/dates/annotations/annotation-options.h
deleted file mode 100755
index 29e9939..0000000
--- a/native/annotator/grammar/dates/annotations/annotation-options.h
+++ /dev/null
@@ -1,96 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_OPTIONS_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_OPTIONS_H_
-
-#include <string>
-#include <vector>
-
-#include "utils/base/integral_types.h"
-
-namespace libtextclassifier3 {
-
-// Options for date/datetime/date range annotations.
-struct DateAnnotationOptions {
- // If enabled, extract special day offset like today, yesterday, etc.
- bool enable_special_day_offset;
-
- // If true, merge the adjacent day of week, time and date. e.g.
- // "20/2/2016 at 8pm" is extracted as a single instance instead of two
- // instance: "20/2/2016" and "8pm".
- bool merge_adjacent_components;
-
- // List the extra id of requested dates.
- std::vector<std::string> extra_requested_dates;
-
- // If true, try to include preposition to the extracted annotation. e.g.
- // "at 6pm". if it's false, only 6pm is included. offline-actions has special
- // requirements to include preposition.
- bool include_preposition;
-
- // The base timestamp (milliseconds) which used to convert relative time to
- // absolute time.
- // e.g.:
- // base timestamp is 2016/4/25, then tomorrow will be converted to
- // 2016/4/26.
- // base timestamp is 2016/4/25 10:30:20am, then 1 days, 2 hours, 10 minutes
- // and 5 seconds ago will be converted to 2016/4/24 08:20:15am
- int64 base_timestamp_millis;
-
- // If enabled, extract range in date annotator.
- // input: Monday, 5-6pm
- // If the flag is true, The extracted annotation only contains 1 range
- // instance which is from Monday 5pm to 6pm.
- // If the flag is false, The extracted annotation contains two date
- // instance: "Monday" and "6pm".
- bool enable_date_range;
-
- // Timezone in which the input text was written
- std::string reference_timezone;
- // Localization params.
- // The format of the locale lists should be "<lang_code-<county_code>"
- // comma-separated list of two-character language/country pairs.
- std::string locales;
-
- // If enabled, the annotation/rule_match priority score is used to set the and
- // priority score of the annotation.
- // In case of false the annotation priority score are set from
- // GrammarDatetimeModel's priority_score
- bool use_rule_priority_score;
-
- // If enabled, annotator will try to resolve the ambiguity by generating
- // possible alternative interpretations of the input text
- // e.g. '9:45' will be resolved to '9:45 AM' and '9:45 PM'.
- bool generate_alternative_interpretations_when_ambiguous;
-
- // List the ignored span in the date string e.g. 12 March @12PM, here '@'
- // can be ignored tokens.
- std::vector<std::string> ignored_spans;
-
- // Default Constructor
- DateAnnotationOptions()
- : enable_special_day_offset(true),
- merge_adjacent_components(true),
- include_preposition(false),
- base_timestamp_millis(0),
- enable_date_range(false),
- use_rule_priority_score(false),
- generate_alternative_interpretations_when_ambiguous(false) {}
-};
-
-} // namespace libtextclassifier3
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_OPTIONS_H_
diff --git a/native/annotator/grammar/dates/annotations/annotation-util.cc b/native/annotator/grammar/dates/annotations/annotation-util.cc
deleted file mode 100644
index 438206f..0000000
--- a/native/annotator/grammar/dates/annotations/annotation-util.cc
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/grammar/dates/annotations/annotation-util.h"
-
-#include <algorithm>
-
-#include "utils/base/logging.h"
-
-namespace libtextclassifier3 {
-
-int GetPropertyIndex(StringPiece name, const AnnotationData& annotation_data) {
- for (int i = 0; i < annotation_data.properties.size(); ++i) {
- if (annotation_data.properties[i].name == name.ToString()) {
- return i;
- }
- }
- return -1;
-}
-
-int GetPropertyIndex(StringPiece name, const Annotation& annotation) {
- return GetPropertyIndex(name, annotation.data);
-}
-
-int GetIntProperty(StringPiece name, const Annotation& annotation) {
- return GetIntProperty(name, annotation.data);
-}
-
-int GetIntProperty(StringPiece name, const AnnotationData& annotation_data) {
- const int index = GetPropertyIndex(name, annotation_data);
- if (index < 0) {
- TC3_DCHECK_GE(index, 0)
- << "No property with name " << name.ToString() << ".";
- return 0;
- }
-
- if (annotation_data.properties.at(index).int_values.size() != 1) {
- TC3_DCHECK_EQ(annotation_data.properties[index].int_values.size(), 1);
- return 0;
- }
-
- return annotation_data.properties.at(index).int_values.at(0);
-}
-
-int AddIntProperty(StringPiece name, int value, Annotation* annotation) {
- return AddRepeatedIntProperty(name, &value, 1, annotation);
-}
-
-int AddIntProperty(StringPiece name, int value,
- AnnotationData* annotation_data) {
- return AddRepeatedIntProperty(name, &value, 1, annotation_data);
-}
-
-int AddRepeatedIntProperty(StringPiece name, const int* start, int size,
- Annotation* annotation) {
- return AddRepeatedIntProperty(name, start, size, &annotation->data);
-}
-
-int AddRepeatedIntProperty(StringPiece name, const int* start, int size,
- AnnotationData* annotation_data) {
- Property property;
- property.name = name.ToString();
- auto first = start;
- auto last = start + size;
- while (first != last) {
- property.int_values.push_back(*first);
- first++;
- }
- annotation_data->properties.push_back(property);
- return annotation_data->properties.size() - 1;
-}
-
-int AddAnnotationDataProperty(const std::string& key,
- const AnnotationData& value,
- AnnotationData* annotation_data) {
- Property property;
- property.name = key;
- property.annotation_data_values.push_back(value);
- annotation_data->properties.push_back(property);
- return annotation_data->properties.size() - 1;
-}
-
-int AddAnnotationDataProperty(const std::string& key,
- const AnnotationData& value,
- Annotation* annotation) {
- return AddAnnotationDataProperty(key, value, &annotation->data);
-}
-} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/dates/annotations/annotation-util.h b/native/annotator/grammar/dates/annotations/annotation-util.h
deleted file mode 100644
index e4afbfe..0000000
--- a/native/annotator/grammar/dates/annotations/annotation-util.h
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_UTIL_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_UTIL_H_
-
-#include "annotator/grammar/dates/annotations/annotation.h"
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-
-// Return the index of property in annotation.data().properties().
-// Return -1 if the property does not exist.
-int GetPropertyIndex(StringPiece name, const Annotation& annotation);
-
-// Return the index of property in thing.properties().
-// Return -1 if the property does not exist.
-int GetPropertyIndex(StringPiece name, const AnnotationData& annotation_data);
-
-// Return the single int value for property 'name' of the annotation.
-// Returns 0 if the property does not exist or does not contain a single int
-// value.
-int GetIntProperty(StringPiece name, const Annotation& annotation);
-
-// Return the single float value for property 'name' of the annotation.
-// Returns 0 if the property does not exist or does not contain a single int
-// value.
-int GetIntProperty(StringPiece name, const AnnotationData& annotation_data);
-
-// Add a new property with a single int value to an Annotation instance.
-// Return the index of the property.
-int AddIntProperty(StringPiece name, const int value, Annotation* annotation);
-
-// Add a new property with a single int value to a Thing instance.
-// Return the index of the property.
-int AddIntProperty(StringPiece name, const int value,
- AnnotationData* annotation_data);
-
-// Add a new property with repeated int values to an Annotation instance.
-// Return the index of the property.
-int AddRepeatedIntProperty(StringPiece name, const int* start, int size,
- Annotation* annotation);
-
-// Add a new property with repeated int values to a Thing instance.
-// Return the index of the property.
-int AddRepeatedIntProperty(StringPiece name, const int* start, int size,
- AnnotationData* annotation_data);
-
-// Add a new property with Thing value.
-// Return the index of the property.
-int AddAnnotationDataProperty(const std::string& key,
- const AnnotationData& value,
- Annotation* annotation);
-
-// Add a new property with Thing value.
-// Return the index of the property.
-int AddAnnotationDataProperty(const std::string& key,
- const AnnotationData& value,
- AnnotationData* annotation_data);
-
-} // namespace libtextclassifier3
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_UTIL_H_
diff --git a/native/annotator/grammar/dates/annotations/annotation-util_test.cc b/native/annotator/grammar/dates/annotations/annotation-util_test.cc
deleted file mode 100644
index 6d25d64..0000000
--- a/native/annotator/grammar/dates/annotations/annotation-util_test.cc
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/grammar/dates/annotations/annotation-util.h"
-
-#include "annotator/grammar/dates/annotations/annotation.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-TEST(AnnotationUtilTest, VerifyIntFunctions) {
- Annotation annotation;
-
- int index_key1 = AddIntProperty("key1", 1, &annotation);
- int index_key2 = AddIntProperty("key2", 2, &annotation);
-
- static const int kValuesKey3[] = {3, 4, 5};
- int index_key3 =
- AddRepeatedIntProperty("key3", kValuesKey3, /*size=*/3, &annotation);
-
- EXPECT_EQ(2, GetIntProperty("key2", annotation));
- EXPECT_EQ(1, GetIntProperty("key1", annotation));
-
- EXPECT_EQ(index_key1, GetPropertyIndex("key1", annotation));
- EXPECT_EQ(index_key2, GetPropertyIndex("key2", annotation));
- EXPECT_EQ(index_key3, GetPropertyIndex("key3", annotation));
- EXPECT_EQ(-1, GetPropertyIndex("invalid_key", annotation));
-}
-
-TEST(AnnotationUtilTest, VerifyAnnotationDataFunctions) {
- Annotation annotation;
-
- AnnotationData true_annotation_data;
- Property true_property;
- true_property.bool_values.push_back(true);
- true_annotation_data.properties.push_back(true_property);
- int index_key1 =
- AddAnnotationDataProperty("key1", true_annotation_data, &annotation);
-
- AnnotationData false_annotation_data;
- Property false_property;
- false_property.bool_values.push_back(false);
- true_annotation_data.properties.push_back(false_property);
- int index_key2 =
- AddAnnotationDataProperty("key2", false_annotation_data, &annotation);
-
- EXPECT_EQ(index_key1, GetPropertyIndex("key1", annotation));
- EXPECT_EQ(index_key2, GetPropertyIndex("key2", annotation));
- EXPECT_EQ(-1, GetPropertyIndex("invalid_key", annotation));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/dates/annotations/annotation.h b/native/annotator/grammar/dates/annotations/annotation.h
deleted file mode 100644
index e6ddb09..0000000
--- a/native/annotator/grammar/dates/annotations/annotation.h
+++ /dev/null
@@ -1,71 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_H_
-
-#include <string>
-#include <vector>
-
-#include "utils/base/integral_types.h"
-
-namespace libtextclassifier3 {
-
-struct AnnotationData;
-
-// Define enum for each annotation.
-enum GrammarAnnotationType {
- // Date&time like "May 1", "12:20pm", etc.
- DATETIME = 0,
- // Datetime range like "2pm - 3pm".
- DATETIME_RANGE = 1,
-};
-
-struct Property {
- // TODO(hassan): Replace the name with enum e.g. PropertyType.
- std::string name;
- // At most one of these will have any values.
- std::vector<bool> bool_values;
- std::vector<int64> int_values;
- std::vector<double> double_values;
- std::vector<std::string> string_values;
- std::vector<AnnotationData> annotation_data_values;
-};
-
-struct AnnotationData {
- // TODO(hassan): Replace it type with GrammarAnnotationType
- std::string type;
- std::vector<Property> properties;
-};
-
-// Represents an annotation instance.
-// lets call it either AnnotationDetails
-struct Annotation {
- // Codepoint offsets into the original text specifying the substring of the
- // text that was annotated.
- int32 begin;
- int32 end;
-
- // Annotation priority score which can be used to resolve conflict between
- // annotators.
- float annotator_priority_score;
-
- // Represents the details of the annotation instance, including the type of
- // the annotation instance and its properties.
- AnnotationData data;
-};
-} // namespace libtextclassifier3
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_H_
diff --git a/native/annotator/grammar/dates/cfg-datetime-annotator.cc b/native/annotator/grammar/dates/cfg-datetime-annotator.cc
deleted file mode 100644
index 99d3be0..0000000
--- a/native/annotator/grammar/dates/cfg-datetime-annotator.cc
+++ /dev/null
@@ -1,139 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/grammar/dates/cfg-datetime-annotator.h"
-
-#include "annotator/datetime/utils.h"
-#include "annotator/grammar/dates/annotations/annotation-options.h"
-#include "annotator/grammar/utils.h"
-#include "utils/strings/split.h"
-#include "utils/tokenizer.h"
-#include "utils/utf8/unicodetext.h"
-
-namespace libtextclassifier3::dates {
-namespace {
-
-static std::string GetReferenceLocale(const std::string& locales) {
- std::vector<StringPiece> split_locales = strings::Split(locales, ',');
- if (!split_locales.empty()) {
- return split_locales[0].ToString();
- }
- return "";
-}
-
-static void InterpretParseData(const DatetimeParsedData& datetime_parsed_data,
- const DateAnnotationOptions& options,
- const CalendarLib& calendarlib,
- int64* interpreted_time_ms_utc,
- DatetimeGranularity* granularity) {
- DatetimeGranularity local_granularity =
- calendarlib.GetGranularity(datetime_parsed_data);
- if (!calendarlib.InterpretParseData(
- datetime_parsed_data, options.base_timestamp_millis,
- options.reference_timezone, GetReferenceLocale(options.locales),
- /*prefer_future_for_unspecified_date=*/true, interpreted_time_ms_utc,
- granularity)) {
- TC3_LOG(WARNING) << "Failed to extract time in millis and Granularity.";
- // Fallingback to DatetimeParsedData's finest granularity
- *granularity = local_granularity;
- }
-}
-
-} // namespace
-
-CfgDatetimeAnnotator::CfgDatetimeAnnotator(
- const UniLib* unilib, const GrammarTokenizerOptions* tokenizer_options,
- const CalendarLib* calendar_lib, const DatetimeRules* datetime_rules,
- const float annotator_target_classification_score,
- const float annotator_priority_score)
- : calendar_lib_(*calendar_lib),
- tokenizer_(BuildTokenizer(unilib, tokenizer_options)),
- parser_(unilib, datetime_rules),
- annotator_target_classification_score_(
- annotator_target_classification_score),
- annotator_priority_score_(annotator_priority_score) {}
-
-void CfgDatetimeAnnotator::Parse(
- const std::string& input, const DateAnnotationOptions& annotation_options,
- const std::vector<Locale>& locales,
- std::vector<DatetimeParseResultSpan>* results) const {
- Parse(UTF8ToUnicodeText(input, /*do_copy=*/false), annotation_options,
- locales, results);
-}
-
-void CfgDatetimeAnnotator::ProcessDatetimeParseResult(
- const DateAnnotationOptions& annotation_options,
- const DatetimeParseResult& datetime_parse_result,
- std::vector<DatetimeParseResult>* results) const {
- DatetimeParsedData datetime_parsed_data;
- datetime_parsed_data.AddDatetimeComponents(
- datetime_parse_result.datetime_components);
-
- std::vector<DatetimeParsedData> interpretations;
- if (annotation_options.generate_alternative_interpretations_when_ambiguous) {
- FillInterpretations(datetime_parsed_data,
- calendar_lib_.GetGranularity(datetime_parsed_data),
- &interpretations);
- } else {
- interpretations.emplace_back(datetime_parsed_data);
- }
- for (const DatetimeParsedData& interpretation : interpretations) {
- results->emplace_back();
- interpretation.GetDatetimeComponents(&results->back().datetime_components);
- InterpretParseData(interpretation, annotation_options, calendar_lib_,
- &(results->back().time_ms_utc),
- &(results->back().granularity));
- std::sort(results->back().datetime_components.begin(),
- results->back().datetime_components.end(),
- [](const DatetimeComponent& a, const DatetimeComponent& b) {
- return a.component_type > b.component_type;
- });
- }
-}
-
-void CfgDatetimeAnnotator::Parse(
- const UnicodeText& input, const DateAnnotationOptions& annotation_options,
- const std::vector<Locale>& locales,
- std::vector<DatetimeParseResultSpan>* results) const {
- std::vector<DatetimeParseResultSpan> grammar_datetime_parse_result_spans =
- parser_.Parse(input.data(), tokenizer_.Tokenize(input), locales,
- annotation_options);
-
- for (const DatetimeParseResultSpan& grammar_datetime_parse_result_span :
- grammar_datetime_parse_result_spans) {
- DatetimeParseResultSpan datetime_parse_result_span;
- datetime_parse_result_span.span.first =
- grammar_datetime_parse_result_span.span.first;
- datetime_parse_result_span.span.second =
- grammar_datetime_parse_result_span.span.second;
- datetime_parse_result_span.priority_score = annotator_priority_score_;
- if (annotation_options.use_rule_priority_score) {
- datetime_parse_result_span.priority_score =
- grammar_datetime_parse_result_span.priority_score;
- }
- datetime_parse_result_span.target_classification_score =
- annotator_target_classification_score_;
- for (const DatetimeParseResult& grammar_datetime_parse_result :
- grammar_datetime_parse_result_span.data) {
- ProcessDatetimeParseResult(annotation_options,
- grammar_datetime_parse_result,
- &datetime_parse_result_span.data);
- }
- results->emplace_back(datetime_parse_result_span);
- }
-}
-
-} // namespace libtextclassifier3::dates
diff --git a/native/annotator/grammar/dates/cfg-datetime-annotator.h b/native/annotator/grammar/dates/cfg-datetime-annotator.h
deleted file mode 100644
index 73c9b7b..0000000
--- a/native/annotator/grammar/dates/cfg-datetime-annotator.h
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_CFG_DATETIME_ANNOTATOR_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_CFG_DATETIME_ANNOTATOR_H_
-
-#include "annotator/grammar/dates/annotations/annotation.h"
-#include "annotator/grammar/dates/dates_generated.h"
-#include "annotator/grammar/dates/parser.h"
-#include "annotator/grammar/dates/utils/annotation-keys.h"
-#include "annotator/model_generated.h"
-#include "utils/calendar/calendar.h"
-#include "utils/i18n/locale.h"
-#include "utils/tokenizer.h"
-#include "utils/utf8/unilib.h"
-
-namespace libtextclassifier3::dates {
-
-// Helper class to convert the parsed datetime expression from AnnotationList
-// (List of annotation generated from Grammar rules) to DatetimeParseResultSpan.
-class CfgDatetimeAnnotator {
- public:
- explicit CfgDatetimeAnnotator(
- const UniLib* unilib, const GrammarTokenizerOptions* tokenizer_options,
- const CalendarLib* calendar_lib, const DatetimeRules* datetime_rules,
- const float annotator_target_classification_score,
- const float annotator_priority_score);
-
- // CfgDatetimeAnnotator is neither copyable nor movable.
- CfgDatetimeAnnotator(const CfgDatetimeAnnotator&) = delete;
- CfgDatetimeAnnotator& operator=(const CfgDatetimeAnnotator&) = delete;
-
- // Parses the dates in 'input' and fills result. Makes sure that the results
- // do not overlap.
- // Method will return false if input does not contain any datetime span.
- void Parse(const std::string& input,
- const DateAnnotationOptions& annotation_options,
- const std::vector<Locale>& locales,
- std::vector<DatetimeParseResultSpan>* results) const;
-
- // UnicodeText version of parse.
- void Parse(const UnicodeText& input,
- const DateAnnotationOptions& annotation_options,
- const std::vector<Locale>& locales,
- std::vector<DatetimeParseResultSpan>* results) const;
-
- private:
- void ProcessDatetimeParseResult(
- const DateAnnotationOptions& annotation_options,
- const DatetimeParseResult& datetime_parse_result,
- std::vector<DatetimeParseResult>* results) const;
-
- const CalendarLib& calendar_lib_;
- const Tokenizer tokenizer_;
- DateParser parser_;
- const float annotator_target_classification_score_;
- const float annotator_priority_score_;
-};
-
-} // namespace libtextclassifier3::dates
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_CFG_DATETIME_ANNOTATOR_H_
diff --git a/native/annotator/grammar/dates/dates.fbs b/native/annotator/grammar/dates/dates.fbs
deleted file mode 100755
index b54e0f0..0000000
--- a/native/annotator/grammar/dates/dates.fbs
+++ /dev/null
@@ -1,351 +0,0 @@
-//
-// Copyright (C) 2018 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-include "utils/grammar/rules.fbs";
-include "annotator/grammar/dates/timezone-code.fbs";
-
-// Type identifiers of all non-trivial matches.
-namespace libtextclassifier3.dates;
-enum MatchType : int {
- UNKNOWN = 0,
-
- // Match of a date extraction rule.
- DATETIME_RULE = 1,
-
- // Match of a date range extraction rule.
- DATETIME_RANGE_RULE = 2,
-
- // Match defined by an ExtractionRule (e.g., a single time-result that is
- // matched by a time-rule, which is ready to be output individually, with
- // this kind of match, we can retrieve it in range rules).
- DATETIME = 3,
-
- // Match defined by TermValue.
- TERM_VALUE = 4,
-
- // Matches defined by Nonterminal.
- NONTERMINAL = 5,
-
- DIGITS = 6,
- YEAR = 7,
- MONTH = 8,
- DAY = 9,
- HOUR = 10,
- MINUTE = 11,
- SECOND = 12,
- FRACTION_SECOND = 13,
- DAY_OF_WEEK = 14,
- TIME_VALUE = 15,
- TIME_SPAN = 16,
- TIME_ZONE_NAME = 17,
- TIME_ZONE_OFFSET = 18,
- TIME_PERIOD = 19,
- RELATIVE_DATE = 20,
- COMBINED_DIGITS = 21,
-}
-
-namespace libtextclassifier3.dates;
-enum BCAD : int {
- BCAD_NONE = -1,
- BC = 0,
- AD = 1,
-}
-
-namespace libtextclassifier3.dates;
-enum DayOfWeek : int {
- DOW_NONE = -1,
- SUNDAY = 1,
- MONDAY = 2,
- TUESDAY = 3,
- WEDNESDAY = 4,
- THURSDAY = 5,
- FRIDAY = 6,
- SATURDAY = 7,
-}
-
-namespace libtextclassifier3.dates;
-enum TimespanCode : int {
- TIMESPAN_CODE_NONE = -1,
- AM = 0,
- PM = 1,
- NOON = 2,
- MIDNIGHT = 3,
-
- // English "tonight".
- TONIGHT = 11,
-}
-
-// The datetime grammar rules.
-namespace libtextclassifier3.dates;
-table DatetimeRules {
- // The context free grammar rules.
- rules:grammar.RulesSet;
-
- // Values associated with grammar rule matches.
- extraction_rule:[ExtractionRuleParameter];
-
- term_value:[TermValue];
- nonterminal_value:[NonterminalValue];
-}
-
-namespace libtextclassifier3.dates;
-table TermValue {
- value:int;
-
- // A time segment e.g. 10AM - 12AM
- time_span_spec:TimeSpanSpec;
-
- // Time zone information representation
- time_zone_name_spec:TimeZoneNameSpec;
-}
-
-// Define nonterms from terms or other nonterms.
-namespace libtextclassifier3.dates;
-table NonterminalValue {
- // Mapping value.
- value:TermValue;
-
- // Parameter describing formatting choices for nonterminal messages
- nonterminal_parameter:NonterminalParameter;
-
- // Parameter interpreting past/future dates (e.g. "last year")
- relative_parameter:RelativeParameter;
-
- // Format info for nonterminals representing times.
- time_value_parameter:TimeValueParameter;
-
- // Parameter describing the format of time-zone info - e.g. "UTC-8"
- time_zone_offset_parameter:TimeZoneOffsetParameter;
-}
-
-namespace libtextclassifier3.dates.RelativeParameter_;
-enum RelativeType : int {
- NONE = 0,
- YEAR = 1,
- MONTH = 2,
- DAY = 3,
- WEEK = 4,
- HOUR = 5,
- MINUTE = 6,
- SECOND = 7,
-}
-
-namespace libtextclassifier3.dates.RelativeParameter_;
-enum Period : int {
- PERIOD_UNKNOWN = 0,
- PERIOD_PAST = 1,
- PERIOD_FUTURE = 2,
-}
-
-// Relative interpretation.
-// Indicates which day the day of week could be, for example "next Friday"
-// could means the Friday which is the closest Friday or the Friday in the
-// next week.
-namespace libtextclassifier3.dates.RelativeParameter_;
-enum Interpretation : int {
- UNKNOWN = 0,
-
- // The closest X in the past.
- NEAREST_LAST = 1,
-
- // The X before the closest X in the past.
- SECOND_LAST = 2,
-
- // The closest X in the future.
- NEAREST_NEXT = 3,
-
- // The X after the closest X in the future.
- SECOND_NEXT = 4,
-
- // X in the previous one.
- PREVIOUS = 5,
-
- // X in the coming one.
- COMING = 6,
-
- // X in current one, it can be both past and future.
- CURRENT = 7,
-
- // Some X.
- SOME = 8,
-
- // The closest X, it can be both past and future.
- NEAREST = 9,
-}
-
-namespace libtextclassifier3.dates;
-table RelativeParameter {
- type:RelativeParameter_.RelativeType = NONE;
- period:RelativeParameter_.Period = PERIOD_UNKNOWN;
- day_of_week_interpretation:[RelativeParameter_.Interpretation];
-}
-
-namespace libtextclassifier3.dates.NonterminalParameter_;
-enum Flag : int {
- IS_SPELLED = 1,
-}
-
-namespace libtextclassifier3.dates;
-table NonterminalParameter {
- // Bit-wise OR Flag.
- flag:uint = 0;
-
- combined_digits_format:string (shared);
-}
-
-namespace libtextclassifier3.dates.TimeValueParameter_;
-enum TimeValueValidation : int {
- // Allow extra spaces between sub-components in time-value.
- ALLOW_EXTRA_SPACE = 1,
- // 1 << 0
-
- // Disallow colon- or dot-context with digits for time-value.
- DISALLOW_COLON_DOT_CONTEXT = 2,
- // 1 << 1
-}
-
-namespace libtextclassifier3.dates;
-table TimeValueParameter {
- validation:uint = 0;
- // Bitwise-OR
-
- flag:uint = 0;
- // Bitwise-OR
-}
-
-namespace libtextclassifier3.dates.TimeZoneOffsetParameter_;
-enum Format : int {
- // Offset is in an uncategorized format.
- FORMAT_UNKNOWN = 0,
-
- // Offset contains 1-digit hour only, e.g. "UTC-8".
- FORMAT_H = 1,
-
- // Offset contains 2-digit hour only, e.g. "UTC-08".
- FORMAT_HH = 2,
-
- // Offset contains 1-digit hour and minute, e.g. "UTC-8:00".
- FORMAT_H_MM = 3,
-
- // Offset contains 2-digit hour and minute, e.g. "UTC-08:00".
- FORMAT_HH_MM = 4,
-
- // Offset contains 3-digit hour-and-minute, e.g. "UTC-800".
- FORMAT_HMM = 5,
-
- // Offset contains 4-digit hour-and-minute, e.g. "UTC-0800".
- FORMAT_HHMM = 6,
-}
-
-namespace libtextclassifier3.dates;
-table TimeZoneOffsetParameter {
- format:TimeZoneOffsetParameter_.Format = FORMAT_UNKNOWN;
-}
-
-namespace libtextclassifier3.dates.ExtractionRuleParameter_;
-enum ExtractionValidation : int {
- // Boundary checking for final match.
- LEFT_BOUND = 1,
-
- RIGHT_BOUND = 2,
- SPELLED_YEAR = 4,
- SPELLED_MONTH = 8,
- SPELLED_DAY = 16,
-
- // Without this validation-flag set, unconfident time-zone expression
- // are discarded in the output-callback, e.g. "-08:00, +8".
- ALLOW_UNCONFIDENT_TIME_ZONE = 32,
-}
-
-// Parameter info for extraction rule, help rule explanation.
-namespace libtextclassifier3.dates;
-table ExtractionRuleParameter {
- // Bit-wise OR Validation.
- validation:uint = 0;
-
- priority_delta:int;
- id:string (shared);
-
- // The score reflects the confidence score of the date/time match, which is
- // set while creating grammar rules.
- // e.g. given we have the rule which detect "22.33" as a HH.MM then because
- // of ambiguity the confidence of this match maybe relatively less.
- annotator_priority_score:float;
-}
-
-// Internal structure used to describe an hour-mapping segment.
-namespace libtextclassifier3.dates.TimeSpanSpec_;
-table Segment {
- // From 0 to 24, the beginning hour of the segment, always included.
- begin:int;
-
- // From 0 to 24, the ending hour of the segment, not included if the
- // segment is not closed. The value 0 means the beginning of the next
- // day, the same value as "begin" means a time-point.
- end:int;
-
- // From -24 to 24, the mapping offset in hours from spanned expressions
- // to 24-hour expressions. The value 0 means identical mapping.
- offset:int;
-
- // True if the segment is a closed one instead of a half-open one.
- // Always set it to true when describing time-points.
- is_closed:bool = false;
-
- // True if a strict check should be performed onto the segment which
- // disallows already-offset hours to be used in spanned expressions,
- // e.g. 15:30PM.
- is_strict:bool = false;
-
- // True if the time-span can be used without an explicitly specified
- // hour value, then it can generate an exact time point (the "begin"
- // o'clock sharp, like "noon") or a time range, like "Tonight".
- is_stand_alone:bool = false;
-}
-
-namespace libtextclassifier3.dates;
-table TimeSpanSpec {
- code:TimespanCode;
- segment:[TimeSpanSpec_.Segment];
-}
-
-namespace libtextclassifier3.dates.TimeZoneNameSpec_;
-enum TimeZoneType : int {
- // The corresponding name might represent a standard or daylight-saving
- // time-zone, depending on some external information, e.g. the date.
- AMBIGUOUS = 0,
-
- // The corresponding name represents a standard time-zone.
- STANDARD = 1,
-
- // The corresponding name represents a daylight-saving time-zone.
- DAYLIGHT = 2,
-}
-
-namespace libtextclassifier3.dates;
-table TimeZoneNameSpec {
- code:TimezoneCode;
- type:TimeZoneNameSpec_.TimeZoneType = AMBIGUOUS;
-
- // Set to true if the corresponding name is internationally used as an
- // abbreviation (or expression) of UTC. For example, "GMT" and "Z".
- is_utc:bool = false;
-
- // Set to false if the corresponding name is not an abbreviation. For example,
- // "Pacific Time" and "China Standard Time".
- is_abbreviation:bool = true;
-}
-
diff --git a/native/annotator/grammar/dates/extractor.cc b/native/annotator/grammar/dates/extractor.cc
deleted file mode 100644
index d2db23e..0000000
--- a/native/annotator/grammar/dates/extractor.cc
+++ /dev/null
@@ -1,913 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/grammar/dates/extractor.h"
-
-#include <initializer_list>
-#include <map>
-
-#include "annotator/grammar/dates/utils/date-match.h"
-#include "annotator/grammar/dates/utils/date-utils.h"
-#include "utils/base/casts.h"
-#include "utils/base/logging.h"
-#include "utils/strings/numbers.h"
-
-namespace libtextclassifier3::dates {
-namespace {
-
-// Helper struct for time-related components.
-// Extracts all subnodes of a specified type.
-struct MatchComponents {
- MatchComponents(const grammar::Match* root,
- std::initializer_list<int16> types)
- : root(root),
- components(grammar::SelectAll(
- root, [root, &types](const grammar::Match* node) {
- if (node == root || node->type == grammar::Match::kUnknownType) {
- return false;
- }
- for (const int64 type : types) {
- if (node->type == type) {
- return true;
- }
- }
- return false;
- })) {}
-
- // Returns the index of the first submatch of the specified type or -1 if not
- // found.
- int IndexOf(const int16 type, const int start_index = 0) const {
- for (int i = start_index; i < components.size(); i++) {
- if (components[i]->type == type) {
- return i;
- }
- }
- return -1;
- }
-
- // Returns the first submatch of the specified type, or nullptr if not found.
- template <typename T>
- const T* SubmatchOf(const int16 type, const int start_index = 0) const {
- return SubmatchAt<T>(IndexOf(type, start_index));
- }
-
- template <typename T>
- const T* SubmatchAt(const int index) const {
- if (index < 0) {
- return nullptr;
- }
- return static_cast<const T*>(components[index]);
- }
-
- const grammar::Match* root;
- std::vector<const grammar::Match*> components;
-};
-
-// Helper method to check whether a time value has valid components.
-bool IsValidTimeValue(const TimeValueMatch& time_value) {
- // Can only specify seconds if minutes are present.
- if (time_value.minute == NO_VAL && time_value.second != NO_VAL) {
- return false;
- }
- // Can only specify fraction of seconds if seconds are present.
- if (time_value.second == NO_VAL && time_value.fraction_second >= 0.0) {
- return false;
- }
-
- const int8 h = time_value.hour;
- const int8 m = (time_value.minute < 0 ? 0 : time_value.minute);
- const int8 s = (time_value.second < 0 ? 0 : time_value.second);
- const double f =
- (time_value.fraction_second < 0.0 ? 0.0 : time_value.fraction_second);
-
- // Check value bounds.
- if (h == NO_VAL || h > 24 || m > 59 || s > 60) {
- return false;
- }
- if (h == 24 && (m != 0 || s != 0 || f > 0.0)) {
- return false;
- }
- if (s == 60 && m != 59) {
- return false;
- }
- return true;
-}
-
-int ParseLeadingDec32Value(const char* c_str) {
- int value;
- if (ParseInt32(c_str, &value)) {
- return value;
- }
- return NO_VAL;
-}
-
-double ParseLeadingDoubleValue(const char* c_str) {
- double value;
- if (ParseDouble(c_str, &value)) {
- return value;
- }
- return NO_VAL;
-}
-
-// Extracts digits as an integer and adds a typed match accordingly.
-template <typename T>
-void CheckDigits(const grammar::Match* match,
- const NonterminalValue* nonterminal, StringPiece match_text,
- grammar::Matcher* matcher) {
- TC3_CHECK(match->IsUnaryRule());
- const int value = ParseLeadingDec32Value(match_text.ToString().c_str());
- if (!T::IsValid(value)) {
- return;
- }
- const int num_digits = match_text.size();
- T* result = matcher->AllocateAndInitMatch<T>(
- match->lhs, match->codepoint_span, match->match_offset);
- result->Reset();
- result->nonterminal = nonterminal;
- result->value = value;
- result->count_of_digits = num_digits;
- result->is_zero_prefixed = (num_digits >= 2 && match_text[0] == '0');
- matcher->AddMatch(result);
-}
-
-// Extracts digits as a decimal (as fraction, as if a "0." is prefixed) and
-// adds a typed match to the `er accordingly.
-template <typename T>
-void CheckDigitsAsFraction(const grammar::Match* match,
- const NonterminalValue* nonterminal,
- StringPiece match_text, grammar::Matcher* matcher) {
- TC3_CHECK(match->IsUnaryRule());
- // TODO(smillius): Should should be achievable in a more straight-forward way.
- const double value =
- ParseLeadingDoubleValue(("0." + match_text.ToString()).data());
- if (!T::IsValid(value)) {
- return;
- }
- T* result = matcher->AllocateAndInitMatch<T>(
- match->lhs, match->codepoint_span, match->match_offset);
- result->Reset();
- result->nonterminal = nonterminal;
- result->value = value;
- result->count_of_digits = match_text.size();
- matcher->AddMatch(result);
-}
-
-// Extracts consecutive digits as multiple integers according to a format and
-// adds a type match to the matcher accordingly.
-template <typename T>
-void CheckCombinedDigits(const grammar::Match* match,
- const NonterminalValue* nonterminal,
- StringPiece match_text, grammar::Matcher* matcher) {
- TC3_CHECK(match->IsUnaryRule());
- const std::string& format =
- nonterminal->nonterminal_parameter()->combined_digits_format()->str();
- if (match_text.size() != format.size()) {
- return;
- }
-
- static std::map<char, CombinedDigitsMatch::Index>& kCombinedDigitsMatchIndex =
- *[]() {
- return new std::map<char, CombinedDigitsMatch::Index>{
- {'Y', CombinedDigitsMatch::INDEX_YEAR},
- {'M', CombinedDigitsMatch::INDEX_MONTH},
- {'D', CombinedDigitsMatch::INDEX_DAY},
- {'h', CombinedDigitsMatch::INDEX_HOUR},
- {'m', CombinedDigitsMatch::INDEX_MINUTE},
- {'s', CombinedDigitsMatch::INDEX_SECOND}};
- }();
-
- struct Segment {
- const int index;
- const int length;
- const int value;
- };
- std::vector<Segment> segments;
- int slice_start = 0;
- while (slice_start < format.size()) {
- int slice_end = slice_start + 1;
- // Advace right as long as we have the same format character.
- while (slice_end < format.size() &&
- format[slice_start] == format[slice_end]) {
- slice_end++;
- }
-
- const int slice_length = slice_end - slice_start;
- const int value = ParseLeadingDec32Value(
- std::string(match_text.data() + slice_start, slice_length).c_str());
-
- auto index = kCombinedDigitsMatchIndex.find(format[slice_start]);
- if (index == kCombinedDigitsMatchIndex.end()) {
- return;
- }
- if (!T::IsValid(index->second, value)) {
- return;
- }
- segments.push_back(Segment{index->second, slice_length, value});
- slice_start = slice_end;
- }
- T* result = matcher->AllocateAndInitMatch<T>(
- match->lhs, match->codepoint_span, match->match_offset);
- result->Reset();
- result->nonterminal = nonterminal;
- for (const Segment& segment : segments) {
- result->values[segment.index] = segment.value;
- }
- result->count_of_digits = match_text.size();
- result->is_zero_prefixed =
- (match_text[0] == '0' && segments.front().length >= 2);
- matcher->AddMatch(result);
-}
-
-// Retrieves the corresponding value from an associated term-value mapping for
-// the nonterminal and adds a typed match to the matcher accordingly.
-template <typename T>
-void CheckMappedValue(const grammar::Match* match,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- const TermValueMatch* term =
- grammar::SelectFirstOfType<TermValueMatch>(match, MatchType_TERM_VALUE);
- if (term == nullptr) {
- return;
- }
- const int value = term->term_value->value();
- if (!T::IsValid(value)) {
- return;
- }
- T* result = matcher->AllocateAndInitMatch<T>(
- match->lhs, match->codepoint_span, match->match_offset);
- result->Reset();
- result->nonterminal = nonterminal;
- result->value = value;
- matcher->AddMatch(result);
-}
-
-// Checks if there is an associated value in the corresponding nonterminal and
-// adds a typed match to the matcher accordingly.
-template <typename T>
-void CheckDirectValue(const grammar::Match* match,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- const int value = nonterminal->value()->value();
- if (!T::IsValid(value)) {
- return;
- }
- T* result = matcher->AllocateAndInitMatch<T>(
- match->lhs, match->codepoint_span, match->match_offset);
- result->Reset();
- result->nonterminal = nonterminal;
- result->value = value;
- matcher->AddMatch(result);
-}
-
-template <typename T>
-void CheckAndAddDirectOrMappedValue(const grammar::Match* match,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- if (nonterminal->value() != nullptr) {
- CheckDirectValue<T>(match, nonterminal, matcher);
- } else {
- CheckMappedValue<T>(match, nonterminal, matcher);
- }
-}
-
-template <typename T>
-void CheckAndAddNumericValue(const grammar::Match* match,
- const NonterminalValue* nonterminal,
- StringPiece match_text,
- grammar::Matcher* matcher) {
- if (nonterminal->nonterminal_parameter() != nullptr &&
- nonterminal->nonterminal_parameter()->flag() &
- NonterminalParameter_::Flag_IS_SPELLED) {
- CheckMappedValue<T>(match, nonterminal, matcher);
- } else {
- CheckDigits<T>(match, nonterminal, match_text, matcher);
- }
-}
-
-// Tries to parse as digital time value.
-bool ParseDigitalTimeValue(const std::vector<UnicodeText::const_iterator>& text,
- const MatchComponents& components,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- // Required fields.
- const HourMatch* hour = components.SubmatchOf<HourMatch>(MatchType_HOUR);
- if (hour == nullptr || hour->count_of_digits == 0) {
- return false;
- }
-
- // Optional fields.
- const MinuteMatch* minute =
- components.SubmatchOf<MinuteMatch>(MatchType_MINUTE);
- if (minute != nullptr && minute->count_of_digits == 0) {
- return false;
- }
- const SecondMatch* second =
- components.SubmatchOf<SecondMatch>(MatchType_SECOND);
- if (second != nullptr && second->count_of_digits == 0) {
- return false;
- }
- const FractionSecondMatch* fraction_second =
- components.SubmatchOf<FractionSecondMatch>(MatchType_FRACTION_SECOND);
- if (fraction_second != nullptr && fraction_second->count_of_digits == 0) {
- return false;
- }
-
- // Validation.
- uint32 validation = nonterminal->time_value_parameter()->validation();
- const grammar::Match* end = hour;
- if (minute != nullptr) {
- if (second != nullptr) {
- if (fraction_second != nullptr) {
- end = fraction_second;
- } else {
- end = second;
- }
- } else {
- end = minute;
- }
- }
-
- // Check if there is any extra space between h m s f.
- if ((validation &
- TimeValueParameter_::TimeValueValidation_ALLOW_EXTRA_SPACE) == 0) {
- // Check whether there is whitespace between token.
- if (minute != nullptr && minute->HasLeadingWhitespace()) {
- return false;
- }
- if (second != nullptr && second->HasLeadingWhitespace()) {
- return false;
- }
- if (fraction_second != nullptr && fraction_second->HasLeadingWhitespace()) {
- return false;
- }
- }
-
- // Check if there is any ':' or '.' as a prefix or suffix.
- if (validation &
- TimeValueParameter_::TimeValueValidation_DISALLOW_COLON_DOT_CONTEXT) {
- const int begin_pos = hour->codepoint_span.first;
- const int end_pos = end->codepoint_span.second;
- if (begin_pos > 1 &&
- (*text[begin_pos - 1] == ':' || *text[begin_pos - 1] == '.') &&
- isdigit(*text[begin_pos - 2])) {
- return false;
- }
- // Last valid codepoint is at text.size() - 2 as we added the end position
- // of text for easier span extraction.
- if (end_pos < text.size() - 2 &&
- (*text[end_pos] == ':' || *text[end_pos] == '.') &&
- isdigit(*text[end_pos + 1])) {
- return false;
- }
- }
-
- TimeValueMatch time_value;
- time_value.Init(components.root->lhs, components.root->codepoint_span,
- components.root->match_offset);
- time_value.Reset();
- time_value.hour_match = hour;
- time_value.minute_match = minute;
- time_value.second_match = second;
- time_value.fraction_second_match = fraction_second;
- time_value.is_hour_zero_prefixed = hour->is_zero_prefixed;
- time_value.is_minute_one_digit =
- (minute != nullptr && minute->count_of_digits == 1);
- time_value.is_second_one_digit =
- (second != nullptr && second->count_of_digits == 1);
- time_value.hour = hour->value;
- time_value.minute = (minute != nullptr ? minute->value : NO_VAL);
- time_value.second = (second != nullptr ? second->value : NO_VAL);
- time_value.fraction_second =
- (fraction_second != nullptr ? fraction_second->value : NO_VAL);
-
- if (!IsValidTimeValue(time_value)) {
- return false;
- }
-
- TimeValueMatch* result = matcher->AllocateMatch<TimeValueMatch>();
- *result = time_value;
- matcher->AddMatch(result);
- return true;
-}
-
-// Tries to parsing a time from spelled out time components.
-bool ParseSpelledTimeValue(const MatchComponents& components,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- // Required fields.
- const HourMatch* hour = components.SubmatchOf<HourMatch>(MatchType_HOUR);
- if (hour == nullptr || hour->count_of_digits != 0) {
- return false;
- }
- // Optional fields.
- const MinuteMatch* minute =
- components.SubmatchOf<MinuteMatch>(MatchType_MINUTE);
- if (minute != nullptr && minute->count_of_digits != 0) {
- return false;
- }
- const SecondMatch* second =
- components.SubmatchOf<SecondMatch>(MatchType_SECOND);
- if (second != nullptr && second->count_of_digits != 0) {
- return false;
- }
-
- uint32 validation = nonterminal->time_value_parameter()->validation();
- // Check if there is any extra space between h m s.
- if ((validation &
- TimeValueParameter_::TimeValueValidation_ALLOW_EXTRA_SPACE) == 0) {
- // Check whether there is whitespace between token.
- if (minute != nullptr && minute->HasLeadingWhitespace()) {
- return false;
- }
- if (second != nullptr && second->HasLeadingWhitespace()) {
- return false;
- }
- }
-
- TimeValueMatch time_value;
- time_value.Init(components.root->lhs, components.root->codepoint_span,
- components.root->match_offset);
- time_value.Reset();
- time_value.hour_match = hour;
- time_value.minute_match = minute;
- time_value.second_match = second;
- time_value.is_hour_zero_prefixed = hour->is_zero_prefixed;
- time_value.is_minute_one_digit =
- (minute != nullptr && minute->count_of_digits == 1);
- time_value.is_second_one_digit =
- (second != nullptr && second->count_of_digits == 1);
- time_value.hour = hour->value;
- time_value.minute = (minute != nullptr ? minute->value : NO_VAL);
- time_value.second = (second != nullptr ? second->value : NO_VAL);
-
- if (!IsValidTimeValue(time_value)) {
- return false;
- }
-
- TimeValueMatch* result = matcher->AllocateMatch<TimeValueMatch>();
- *result = time_value;
- matcher->AddMatch(result);
- return true;
-}
-
-// Reconstructs and validates a time value from a match.
-void CheckTimeValue(const std::vector<UnicodeText::const_iterator>& text,
- const grammar::Match* match,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- MatchComponents components(
- match, {MatchType_HOUR, MatchType_MINUTE, MatchType_SECOND,
- MatchType_FRACTION_SECOND});
- if (ParseDigitalTimeValue(text, components, nonterminal, matcher)) {
- return;
- }
- if (ParseSpelledTimeValue(components, nonterminal, matcher)) {
- return;
- }
-}
-
-// Validates a time span match.
-void CheckTimeSpan(const grammar::Match* match,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- const TermValueMatch* ts_name =
- grammar::SelectFirstOfType<TermValueMatch>(match, MatchType_TERM_VALUE);
- const TermValue* term_value = ts_name->term_value;
- TC3_CHECK(term_value != nullptr);
- TC3_CHECK(term_value->time_span_spec() != nullptr);
- const TimeSpanSpec* ts_spec = term_value->time_span_spec();
- TimeSpanMatch* time_span = matcher->AllocateAndInitMatch<TimeSpanMatch>(
- match->lhs, match->codepoint_span, match->match_offset);
- time_span->Reset();
- time_span->nonterminal = nonterminal;
- time_span->time_span_spec = ts_spec;
- time_span->time_span_code = ts_spec->code();
- matcher->AddMatch(time_span);
-}
-
-// Validates a time period match.
-void CheckTimePeriod(const std::vector<UnicodeText::const_iterator>& text,
- const grammar::Match* match,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- int period_value = NO_VAL;
-
- // If a value mapping exists, use it.
- if (nonterminal->value() != nullptr) {
- period_value = nonterminal->value()->value();
- } else if (const TermValueMatch* term =
- grammar::SelectFirstOfType<TermValueMatch>(
- match, MatchType_TERM_VALUE)) {
- period_value = term->term_value->value();
- } else if (const grammar::Match* digits =
- grammar::SelectFirstOfType<grammar::Match>(
- match, grammar::Match::kDigitsType)) {
- period_value = ParseLeadingDec32Value(
- std::string(text[digits->codepoint_span.first].utf8_data(),
- text[digits->codepoint_span.second].utf8_data() -
- text[digits->codepoint_span.first].utf8_data())
- .c_str());
- }
-
- if (period_value <= NO_VAL) {
- return;
- }
-
- TimePeriodMatch* result = matcher->AllocateAndInitMatch<TimePeriodMatch>(
- match->lhs, match->codepoint_span, match->match_offset);
- result->Reset();
- result->nonterminal = nonterminal;
- result->value = period_value;
- matcher->AddMatch(result);
-}
-
-// Reconstructs a date from a relative date rule match.
-void CheckRelativeDate(const DateAnnotationOptions& options,
- const grammar::Match* match,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- if (!options.enable_special_day_offset &&
- grammar::SelectFirstOfType<TermValueMatch>(match, MatchType_TERM_VALUE) !=
- nullptr) {
- // Special day offsets, like "Today", "Tomorrow" etc. are not enabled.
- return;
- }
-
- RelativeMatch* relative_match = matcher->AllocateAndInitMatch<RelativeMatch>(
- match->lhs, match->codepoint_span, match->match_offset);
- relative_match->Reset();
- relative_match->nonterminal = nonterminal;
-
- // Fill relative date information from individual components.
- grammar::Traverse(match, [match, relative_match](const grammar::Match* node) {
- // Ignore the current match.
- if (node == match || node->type == grammar::Match::kUnknownType) {
- return true;
- }
-
- if (node->type == MatchType_TERM_VALUE) {
- const int value =
- static_cast<const TermValueMatch*>(node)->term_value->value();
- relative_match->day = abs(value);
- if (value >= 0) {
- // Marks "today" as in the future.
- relative_match->is_future_date = true;
- }
- relative_match->existing |=
- (RelativeMatch::HAS_DAY | RelativeMatch::HAS_IS_FUTURE);
- return false;
- }
-
- // Parse info from nonterminal.
- const NonterminalValue* nonterminal =
- static_cast<const NonterminalMatch*>(node)->nonterminal;
- if (nonterminal != nullptr &&
- nonterminal->relative_parameter() != nullptr) {
- const RelativeParameter* relative_parameter =
- nonterminal->relative_parameter();
- if (relative_parameter->period() !=
- RelativeParameter_::Period_PERIOD_UNKNOWN) {
- relative_match->is_future_date =
- (relative_parameter->period() ==
- RelativeParameter_::Period_PERIOD_FUTURE);
- relative_match->existing |= RelativeMatch::HAS_IS_FUTURE;
- }
- if (relative_parameter->day_of_week_interpretation() != nullptr) {
- relative_match->day_of_week_nonterminal = nonterminal;
- relative_match->existing |= RelativeMatch::HAS_DAY_OF_WEEK;
- }
- }
-
- // Relative day of week.
- if (node->type == MatchType_DAY_OF_WEEK) {
- relative_match->day_of_week =
- static_cast<const DayOfWeekMatch*>(node)->value;
- return false;
- }
-
- if (node->type != MatchType_TIME_PERIOD) {
- return true;
- }
-
- const TimePeriodMatch* period = static_cast<const TimePeriodMatch*>(node);
- switch (nonterminal->relative_parameter()->type()) {
- case RelativeParameter_::RelativeType_YEAR: {
- relative_match->year = period->value;
- relative_match->existing |= RelativeMatch::HAS_YEAR;
- break;
- }
- case RelativeParameter_::RelativeType_MONTH: {
- relative_match->month = period->value;
- relative_match->existing |= RelativeMatch::HAS_MONTH;
- break;
- }
- case RelativeParameter_::RelativeType_WEEK: {
- relative_match->week = period->value;
- relative_match->existing |= RelativeMatch::HAS_WEEK;
- break;
- }
- case RelativeParameter_::RelativeType_DAY: {
- relative_match->day = period->value;
- relative_match->existing |= RelativeMatch::HAS_DAY;
- break;
- }
- case RelativeParameter_::RelativeType_HOUR: {
- relative_match->hour = period->value;
- relative_match->existing |= RelativeMatch::HAS_HOUR;
- break;
- }
- case RelativeParameter_::RelativeType_MINUTE: {
- relative_match->minute = period->value;
- relative_match->existing |= RelativeMatch::HAS_MINUTE;
- break;
- }
- case RelativeParameter_::RelativeType_SECOND: {
- relative_match->second = period->value;
- relative_match->existing |= RelativeMatch::HAS_SECOND;
- break;
- }
- default:
- break;
- }
-
- return true;
- });
- matcher->AddMatch(relative_match);
-}
-
-bool IsValidTimeZoneOffset(const int time_zone_offset) {
- return (time_zone_offset >= -720 && time_zone_offset <= 840 &&
- time_zone_offset % 15 == 0);
-}
-
-// Parses, validates and adds a time zone offset match.
-void CheckTimeZoneOffset(const grammar::Match* match,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- MatchComponents components(
- match, {MatchType_DIGITS, MatchType_TERM_VALUE, MatchType_NONTERMINAL});
- const TermValueMatch* tz_sign =
- components.SubmatchOf<TermValueMatch>(MatchType_TERM_VALUE);
- if (tz_sign == nullptr) {
- return;
- }
- const int sign = tz_sign->term_value->value();
- TC3_CHECK(sign == -1 || sign == 1);
-
- const int tz_digits_index = components.IndexOf(MatchType_DIGITS);
- if (tz_digits_index < 0) {
- return;
- }
- const DigitsMatch* tz_digits =
- components.SubmatchAt<DigitsMatch>(tz_digits_index);
- if (tz_digits == nullptr) {
- return;
- }
-
- int offset;
- if (tz_digits->count_of_digits >= 3) {
- offset = (tz_digits->value / 100) * 60 + (tz_digits->value % 100);
- } else {
- offset = tz_digits->value * 60;
- if (const DigitsMatch* tz_digits_extra = components.SubmatchOf<DigitsMatch>(
- MatchType_DIGITS, /*start_index=*/tz_digits_index + 1)) {
- offset += tz_digits_extra->value;
- }
- }
-
- const NonterminalMatch* tz_offset =
- components.SubmatchOf<NonterminalMatch>(MatchType_NONTERMINAL);
- if (tz_offset == nullptr) {
- return;
- }
-
- const int time_zone_offset = sign * offset;
- if (!IsValidTimeZoneOffset(time_zone_offset)) {
- return;
- }
-
- TimeZoneOffsetMatch* result =
- matcher->AllocateAndInitMatch<TimeZoneOffsetMatch>(
- match->lhs, match->codepoint_span, match->match_offset);
- result->Reset();
- result->nonterminal = nonterminal;
- result->time_zone_offset_param =
- tz_offset->nonterminal->time_zone_offset_parameter();
- result->time_zone_offset = time_zone_offset;
- matcher->AddMatch(result);
-}
-
-// Validates and adds a time zone name match.
-void CheckTimeZoneName(const grammar::Match* match,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- TC3_CHECK(match->IsUnaryRule());
- const TermValueMatch* tz_name =
- static_cast<const TermValueMatch*>(match->unary_rule_rhs());
- if (tz_name == nullptr) {
- return;
- }
- const TimeZoneNameSpec* tz_name_spec =
- tz_name->term_value->time_zone_name_spec();
- TimeZoneNameMatch* result = matcher->AllocateAndInitMatch<TimeZoneNameMatch>(
- match->lhs, match->codepoint_span, match->match_offset);
- result->Reset();
- result->nonterminal = nonterminal;
- result->time_zone_name_spec = tz_name_spec;
- result->time_zone_code = tz_name_spec->code();
- matcher->AddMatch(result);
-}
-
-// Adds a mapped term value match containing its value.
-void AddTermValue(const grammar::Match* match, const TermValue* term_value,
- grammar::Matcher* matcher) {
- TermValueMatch* term_match = matcher->AllocateAndInitMatch<TermValueMatch>(
- match->lhs, match->codepoint_span, match->match_offset);
- term_match->Reset();
- term_match->term_value = term_value;
- matcher->AddMatch(term_match);
-}
-
-// Adds a match for a nonterminal.
-void AddNonterminal(const grammar::Match* match,
- const NonterminalValue* nonterminal,
- grammar::Matcher* matcher) {
- NonterminalMatch* result =
- matcher->AllocateAndInitMatch<NonterminalMatch>(*match);
- result->Reset();
- result->nonterminal = nonterminal;
- matcher->AddMatch(result);
-}
-
-// Adds a match for an extraction rule that is potentially used in a date range
-// rule.
-void AddExtractionRuleMatch(const grammar::Match* match,
- const ExtractionRuleParameter* rule,
- grammar::Matcher* matcher) {
- ExtractionMatch* result =
- matcher->AllocateAndInitMatch<ExtractionMatch>(*match);
- result->Reset();
- result->extraction_rule = rule;
- matcher->AddMatch(result);
-}
-
-} // namespace
-
-void DateExtractor::HandleExtractionRuleMatch(
- const ExtractionRuleParameter* rule, const grammar::Match* match,
- grammar::Matcher* matcher) {
- if (rule->id() != nullptr) {
- const std::string rule_id = rule->id()->str();
- bool keep = false;
- for (const std::string& extra_requested_dates_id :
- options_.extra_requested_dates) {
- if (extra_requested_dates_id == rule_id) {
- keep = true;
- break;
- }
- }
- if (!keep) {
- return;
- }
- }
- output_.push_back(
- Output{rule, matcher->AllocateAndInitMatch<grammar::Match>(*match)});
-}
-
-void DateExtractor::HandleRangeExtractionRuleMatch(const grammar::Match* match,
- grammar::Matcher* matcher) {
- // Collect the two datetime roots that make up the range.
- std::vector<const grammar::Match*> parts;
- grammar::Traverse(match, [match, &parts](const grammar::Match* node) {
- if (node == match || node->type == grammar::Match::kUnknownType) {
- // Just continue traversing the match.
- return true;
- }
-
- // Collect, but don't expand the individual datetime nodes.
- parts.push_back(node);
- return false;
- });
- TC3_CHECK_EQ(parts.size(), 2);
- range_output_.push_back(
- RangeOutput{matcher->AllocateAndInitMatch<grammar::Match>(*match),
- /*from=*/parts[0], /*to=*/parts[1]});
-}
-
-void DateExtractor::MatchFound(const grammar::Match* match,
- const grammar::CallbackId type,
- const int64 value, grammar::Matcher* matcher) {
- switch (type) {
- case MatchType_DATETIME_RULE: {
- HandleExtractionRuleMatch(
- /*rule=*/
- datetime_rules_->extraction_rule()->Get(value), match, matcher);
- return;
- }
- case MatchType_DATETIME_RANGE_RULE: {
- HandleRangeExtractionRuleMatch(match, matcher);
- return;
- }
- case MatchType_DATETIME: {
- // If an extraction rule is also part of a range extraction rule, then the
- // extraction rule is treated as a rule match and nonterminal match.
- // This type is used to match the rule as non terminal.
- AddExtractionRuleMatch(
- match, datetime_rules_->extraction_rule()->Get(value), matcher);
- return;
- }
- case MatchType_TERM_VALUE: {
- // Handle mapped terms.
- AddTermValue(match, datetime_rules_->term_value()->Get(value), matcher);
- return;
- }
- default:
- break;
- }
-
- // Handle non-terminals.
- const NonterminalValue* nonterminal =
- datetime_rules_->nonterminal_value()->Get(value);
- StringPiece match_text =
- StringPiece(text_[match->codepoint_span.first].utf8_data(),
- text_[match->codepoint_span.second].utf8_data() -
- text_[match->codepoint_span.first].utf8_data());
- switch (type) {
- case MatchType_NONTERMINAL:
- AddNonterminal(match, nonterminal, matcher);
- break;
- case MatchType_DIGITS:
- CheckDigits<DigitsMatch>(match, nonterminal, match_text, matcher);
- break;
- case MatchType_YEAR:
- CheckDigits<YearMatch>(match, nonterminal, match_text, matcher);
- break;
- case MatchType_MONTH:
- CheckAndAddNumericValue<MonthMatch>(match, nonterminal, match_text,
- matcher);
- break;
- case MatchType_DAY:
- CheckAndAddNumericValue<DayMatch>(match, nonterminal, match_text,
- matcher);
- break;
- case MatchType_DAY_OF_WEEK:
- CheckAndAddDirectOrMappedValue<DayOfWeekMatch>(match, nonterminal,
- matcher);
- break;
- case MatchType_HOUR:
- CheckAndAddNumericValue<HourMatch>(match, nonterminal, match_text,
- matcher);
- break;
- case MatchType_MINUTE:
- CheckAndAddNumericValue<MinuteMatch>(match, nonterminal, match_text,
- matcher);
- break;
- case MatchType_SECOND:
- CheckAndAddNumericValue<SecondMatch>(match, nonterminal, match_text,
- matcher);
- break;
- case MatchType_FRACTION_SECOND:
- CheckDigitsAsFraction<FractionSecondMatch>(match, nonterminal, match_text,
- matcher);
- break;
- case MatchType_TIME_VALUE:
- CheckTimeValue(text_, match, nonterminal, matcher);
- break;
- case MatchType_TIME_SPAN:
- CheckTimeSpan(match, nonterminal, matcher);
- break;
- case MatchType_TIME_ZONE_NAME:
- CheckTimeZoneName(match, nonterminal, matcher);
- break;
- case MatchType_TIME_ZONE_OFFSET:
- CheckTimeZoneOffset(match, nonterminal, matcher);
- break;
- case MatchType_TIME_PERIOD:
- CheckTimePeriod(text_, match, nonterminal, matcher);
- break;
- case MatchType_RELATIVE_DATE:
- CheckRelativeDate(options_, match, nonterminal, matcher);
- break;
- case MatchType_COMBINED_DIGITS:
- CheckCombinedDigits<CombinedDigitsMatch>(match, nonterminal, match_text,
- matcher);
- break;
- default:
- TC3_VLOG(ERROR) << "Unhandled match type: " << type;
- }
-}
-
-} // namespace libtextclassifier3::dates
diff --git a/native/annotator/grammar/dates/extractor.h b/native/annotator/grammar/dates/extractor.h
deleted file mode 100644
index 58c8880..0000000
--- a/native/annotator/grammar/dates/extractor.h
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_EXTRACTOR_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_EXTRACTOR_H_
-
-#include <vector>
-
-#include "annotator/grammar/dates/annotations/annotation-options.h"
-#include "annotator/grammar/dates/dates_generated.h"
-#include "utils/base/integral_types.h"
-#include "utils/grammar/callback-delegate.h"
-#include "utils/grammar/match.h"
-#include "utils/grammar/matcher.h"
-#include "utils/grammar/types.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/utf8/unicodetext.h"
-
-namespace libtextclassifier3::dates {
-
-// A helper class for the datetime parser that extracts structured data from
-// the datetime grammar matches.
-// It handles simple sanity checking of the rule matches and interacts with the
-// grammar matcher to extract all datetime occurrences in a text.
-class DateExtractor : public grammar::CallbackDelegate {
- public:
- // Represents a date match for an extraction rule.
- struct Output {
- const ExtractionRuleParameter* rule = nullptr;
- const grammar::Match* match = nullptr;
- };
-
- // Represents a date match from a range extraction rule.
- struct RangeOutput {
- const grammar::Match* match = nullptr;
- const grammar::Match* from = nullptr;
- const grammar::Match* to = nullptr;
- };
-
- DateExtractor(const std::vector<UnicodeText::const_iterator>& text,
- const DateAnnotationOptions& options,
- const DatetimeRules* datetime_rules)
- : text_(text), options_(options), datetime_rules_(datetime_rules) {}
-
- // Handle a rule match in the date time grammar.
- // This checks the type of the match and does type dependent checks.
- void MatchFound(const grammar::Match* match, grammar::CallbackId type,
- int64 value, grammar::Matcher* matcher) override;
-
- const std::vector<Output>& output() const { return output_; }
- const std::vector<RangeOutput>& range_output() const { return range_output_; }
-
- private:
- // Extracts a date from a root rule match.
- void HandleExtractionRuleMatch(const ExtractionRuleParameter* rule,
- const grammar::Match* match,
- grammar::Matcher* matcher);
-
- // Extracts a date range from a root rule match.
- void HandleRangeExtractionRuleMatch(const grammar::Match* match,
- grammar::Matcher* matcher);
-
- const std::vector<UnicodeText::const_iterator>& text_;
- const DateAnnotationOptions& options_;
- const DatetimeRules* datetime_rules_;
-
- // Extraction results.
- std::vector<Output> output_;
- std::vector<RangeOutput> range_output_;
-};
-
-} // namespace libtextclassifier3::dates
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_EXTRACTOR_H_
diff --git a/native/annotator/grammar/dates/parser.cc b/native/annotator/grammar/dates/parser.cc
deleted file mode 100644
index 37e65fc..0000000
--- a/native/annotator/grammar/dates/parser.cc
+++ /dev/null
@@ -1,794 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/grammar/dates/parser.h"
-
-#include "annotator/grammar/dates/extractor.h"
-#include "annotator/grammar/dates/utils/date-match.h"
-#include "annotator/grammar/dates/utils/date-utils.h"
-#include "utils/base/integral_types.h"
-#include "utils/base/logging.h"
-#include "utils/base/macros.h"
-#include "utils/grammar/lexer.h"
-#include "utils/grammar/matcher.h"
-#include "utils/grammar/rules_generated.h"
-#include "utils/grammar/types.h"
-#include "utils/strings/split.h"
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3::dates {
-namespace {
-
-// Helper methods to validate individual components from a date match.
-
-// Checks the validation requirement of a rule against a match.
-// For example if the rule asks for `SPELLED_MONTH`, then we check that the
-// match has the right flag.
-bool CheckMatchValidationAndFlag(
- const grammar::Match* match, const ExtractionRuleParameter* rule,
- const ExtractionRuleParameter_::ExtractionValidation validation,
- const NonterminalParameter_::Flag flag) {
- if (rule == nullptr || (rule->validation() & validation) == 0) {
- // No validation requirement.
- return true;
- }
- const NonterminalParameter* nonterminal_parameter =
- static_cast<const NonterminalMatch*>(match)
- ->nonterminal->nonterminal_parameter();
- return (nonterminal_parameter != nullptr &&
- (nonterminal_parameter->flag() & flag) != 0);
-}
-
-bool GenerateDate(const ExtractionRuleParameter* rule,
- const grammar::Match* match, DateMatch* date) {
- bool is_valid = true;
-
- // Post check and assign date components.
- grammar::Traverse(match, [rule, date, &is_valid](const grammar::Match* node) {
- switch (node->type) {
- case MatchType_YEAR: {
- if (CheckMatchValidationAndFlag(
- node, rule,
- ExtractionRuleParameter_::ExtractionValidation_SPELLED_YEAR,
- NonterminalParameter_::Flag_IS_SPELLED)) {
- date->year_match = static_cast<const YearMatch*>(node);
- date->year = date->year_match->value;
- } else {
- is_valid = false;
- }
- break;
- }
- case MatchType_MONTH: {
- if (CheckMatchValidationAndFlag(
- node, rule,
- ExtractionRuleParameter_::ExtractionValidation_SPELLED_MONTH,
- NonterminalParameter_::Flag_IS_SPELLED)) {
- date->month_match = static_cast<const MonthMatch*>(node);
- date->month = date->month_match->value;
- } else {
- is_valid = false;
- }
- break;
- }
- case MatchType_DAY: {
- if (CheckMatchValidationAndFlag(
- node, rule,
- ExtractionRuleParameter_::ExtractionValidation_SPELLED_DAY,
- NonterminalParameter_::Flag_IS_SPELLED)) {
- date->day_match = static_cast<const DayMatch*>(node);
- date->day = date->day_match->value;
- } else {
- is_valid = false;
- }
- break;
- }
- case MatchType_DAY_OF_WEEK: {
- date->day_of_week_match = static_cast<const DayOfWeekMatch*>(node);
- date->day_of_week =
- static_cast<DayOfWeek>(date->day_of_week_match->value);
- break;
- }
- case MatchType_TIME_VALUE: {
- date->time_value_match = static_cast<const TimeValueMatch*>(node);
- date->hour = date->time_value_match->hour;
- date->minute = date->time_value_match->minute;
- date->second = date->time_value_match->second;
- date->fraction_second = date->time_value_match->fraction_second;
- return false;
- }
- case MatchType_TIME_SPAN: {
- date->time_span_match = static_cast<const TimeSpanMatch*>(node);
- date->time_span_code = date->time_span_match->time_span_code;
- return false;
- }
- case MatchType_TIME_ZONE_NAME: {
- date->time_zone_name_match =
- static_cast<const TimeZoneNameMatch*>(node);
- date->time_zone_code = date->time_zone_name_match->time_zone_code;
- return false;
- }
- case MatchType_TIME_ZONE_OFFSET: {
- date->time_zone_offset_match =
- static_cast<const TimeZoneOffsetMatch*>(node);
- date->time_zone_offset = date->time_zone_offset_match->time_zone_offset;
- return false;
- }
- case MatchType_RELATIVE_DATE: {
- date->relative_match = static_cast<const RelativeMatch*>(node);
- return false;
- }
- case MatchType_COMBINED_DIGITS: {
- date->combined_digits_match =
- static_cast<const CombinedDigitsMatch*>(node);
- if (date->combined_digits_match->HasYear()) {
- date->year = date->combined_digits_match->GetYear();
- }
- if (date->combined_digits_match->HasMonth()) {
- date->month = date->combined_digits_match->GetMonth();
- }
- if (date->combined_digits_match->HasDay()) {
- date->day = date->combined_digits_match->GetDay();
- }
- if (date->combined_digits_match->HasHour()) {
- date->hour = date->combined_digits_match->GetHour();
- }
- if (date->combined_digits_match->HasMinute()) {
- date->minute = date->combined_digits_match->GetMinute();
- }
- if (date->combined_digits_match->HasSecond()) {
- date->second = date->combined_digits_match->GetSecond();
- }
- return false;
- }
- default:
- // Expand node further.
- return true;
- }
-
- return false;
- });
-
- if (is_valid) {
- date->begin = match->codepoint_span.first;
- date->end = match->codepoint_span.second;
- date->priority = rule ? rule->priority_delta() : 0;
- date->annotator_priority_score =
- rule ? rule->annotator_priority_score() : 0.0;
- }
- return is_valid;
-}
-
-bool GenerateFromOrToDateRange(const grammar::Match* match, DateMatch* date) {
- return GenerateDate(
- /*rule=*/(
- match->type == MatchType_DATETIME
- ? static_cast<const ExtractionMatch*>(match)->extraction_rule
- : nullptr),
- match, date);
-}
-
-bool GenerateDateRange(const grammar::Match* match, const grammar::Match* from,
- const grammar::Match* to, DateRangeMatch* date_range) {
- if (!GenerateFromOrToDateRange(from, &date_range->from)) {
- TC3_LOG(WARNING) << "Failed to generate date for `from`.";
- return false;
- }
- if (!GenerateFromOrToDateRange(to, &date_range->to)) {
- TC3_LOG(WARNING) << "Failed to generate date for `to`.";
- return false;
- }
- date_range->begin = match->codepoint_span.first;
- date_range->end = match->codepoint_span.second;
- return true;
-}
-
-bool NormalizeHour(DateMatch* date) {
- if (date->time_span_match == nullptr) {
- // Nothing to do.
- return true;
- }
- return NormalizeHourByTimeSpan(date->time_span_match->time_span_spec, date);
-}
-
-void CheckAndSetAmbiguousHour(DateMatch* date) {
- if (date->HasHour()) {
- // Use am-pm ambiguity as default.
- if (!date->HasTimeSpanCode() && date->hour >= 1 && date->hour <= 12 &&
- !(date->time_value_match != nullptr &&
- date->time_value_match->hour_match != nullptr &&
- date->time_value_match->hour_match->is_zero_prefixed)) {
- date->SetAmbiguousHourProperties(2, 12);
- }
- }
-}
-
-// Normalizes a date candidate.
-// Returns whether the candidate was successfully normalized.
-bool NormalizeDate(DateMatch* date) {
- // Normalize hour.
- if (!NormalizeHour(date)) {
- TC3_VLOG(ERROR) << "Hour normalization (according to time-span) failed."
- << date->DebugString();
- return false;
- }
- CheckAndSetAmbiguousHour(date);
- if (!date->IsValid()) {
- TC3_VLOG(ERROR) << "Fields inside date instance are ill-formed "
- << date->DebugString();
- }
- return true;
-}
-
-// Copies the field from one DateMatch to another whose field is null. for
-// example: if the from is "May 1, 8pm", and the to is "9pm", "May 1" will be
-// copied to "to". Now we only copy fields for date range requirement.fv
-void CopyFieldsForDateMatch(const DateMatch& from, DateMatch* to) {
- if (from.time_span_match != nullptr && to->time_span_match == nullptr) {
- to->time_span_match = from.time_span_match;
- to->time_span_code = from.time_span_code;
- }
- if (from.month_match != nullptr && to->month_match == nullptr) {
- to->month_match = from.month_match;
- to->month = from.month;
- }
-}
-
-// Normalizes a date range candidate.
-// Returns whether the date range was successfully normalized.
-bool NormalizeDateRange(DateRangeMatch* date_range) {
- CopyFieldsForDateMatch(date_range->from, &date_range->to);
- CopyFieldsForDateMatch(date_range->to, &date_range->from);
- return (NormalizeDate(&date_range->from) && NormalizeDate(&date_range->to));
-}
-
-bool CheckDate(const DateMatch& date, const ExtractionRuleParameter* rule) {
- // It's possible that "time_zone_name_match == NULL" when
- // "HasTimeZoneCode() == true", or "time_zone_offset_match == NULL" when
- // "HasTimeZoneOffset() == true" due to inference between endpoints, so we
- // must check if they really exist before using them.
- if (date.HasTimeZoneOffset()) {
- if (date.HasTimeZoneCode()) {
- if (date.time_zone_name_match != nullptr) {
- TC3_CHECK(date.time_zone_name_match->time_zone_name_spec != nullptr);
- const TimeZoneNameSpec* spec =
- date.time_zone_name_match->time_zone_name_spec;
- if (!spec->is_utc()) {
- return false;
- }
- if (!spec->is_abbreviation()) {
- return false;
- }
- }
- } else if (date.time_zone_offset_match != nullptr) {
- TC3_CHECK(date.time_zone_offset_match->time_zone_offset_param != nullptr);
- const TimeZoneOffsetParameter* param =
- date.time_zone_offset_match->time_zone_offset_param;
- if (param->format() == TimeZoneOffsetParameter_::Format_FORMAT_H ||
- param->format() == TimeZoneOffsetParameter_::Format_FORMAT_HH) {
- return false;
- }
- if (!(rule->validation() &
- ExtractionRuleParameter_::
- ExtractionValidation_ALLOW_UNCONFIDENT_TIME_ZONE)) {
- if (param->format() == TimeZoneOffsetParameter_::Format_FORMAT_H_MM ||
- param->format() == TimeZoneOffsetParameter_::Format_FORMAT_HH_MM ||
- param->format() == TimeZoneOffsetParameter_::Format_FORMAT_HMM) {
- return false;
- }
- }
- }
- }
-
- // Case: 1 April could be extracted as year 1, month april.
- // We simply remove this case.
- if (!date.HasBcAd() && date.year_match != nullptr && date.year < 1000) {
- // We allow case like 11/5/01
- if (date.HasMonth() && date.HasDay() &&
- date.year_match->count_of_digits == 2) {
- } else {
- return false;
- }
- }
-
- // Ignore the date if the year is larger than 9999 (The maximum number of 4
- // digits).
- if (date.year_match != nullptr && date.year > 9999) {
- TC3_VLOG(ERROR) << "Year is greater than 9999.";
- return false;
- }
-
- // Case: spelled may could be month 5, it also used very common as modal
- // verbs. We ignore spelled may as month.
- if ((rule->validation() &
- ExtractionRuleParameter_::ExtractionValidation_SPELLED_MONTH) &&
- date.month == 5 && !date.HasYear() && !date.HasDay()) {
- return false;
- }
-
- return true;
-}
-
-bool CheckContext(const std::vector<UnicodeText::const_iterator>& text,
- const DateExtractor::Output& output) {
- const uint32 validation = output.rule->validation();
-
- // Nothing to check if we don't have any validation requirements for the
- // span boundaries.
- if ((validation &
- (ExtractionRuleParameter_::ExtractionValidation_LEFT_BOUND |
- ExtractionRuleParameter_::ExtractionValidation_RIGHT_BOUND)) == 0) {
- return true;
- }
-
- const int begin = output.match->codepoint_span.first;
- const int end = output.match->codepoint_span.second;
-
- // So far, we only check that the adjacent character cannot be a separator,
- // like /, - or .
- if ((validation &
- ExtractionRuleParameter_::ExtractionValidation_LEFT_BOUND) != 0) {
- if (begin > 0 && (*text[begin - 1] == '/' || *text[begin - 1] == '-' ||
- *text[begin - 1] == ':')) {
- return false;
- }
- }
- if ((validation &
- ExtractionRuleParameter_::ExtractionValidation_RIGHT_BOUND) != 0) {
- // Last valid codepoint is at text.size() - 2 as we added the end position
- // of text for easier span extraction.
- if (end < text.size() - 1 &&
- (*text[end] == '/' || *text[end] == '-' || *text[end] == ':')) {
- return false;
- }
- }
-
- return true;
-}
-
-// Validates a date match. Returns true if the candidate is valid.
-bool ValidateDate(const std::vector<UnicodeText::const_iterator>& text,
- const DateExtractor::Output& output, const DateMatch& date) {
- if (!CheckDate(date, output.rule)) {
- return false;
- }
- if (!CheckContext(text, output)) {
- return false;
- }
- return true;
-}
-
-// Builds matched date instances from the grammar output.
-std::vector<DateMatch> BuildDateMatches(
- const std::vector<UnicodeText::const_iterator>& text,
- const std::vector<DateExtractor::Output>& outputs) {
- std::vector<DateMatch> result;
- for (const DateExtractor::Output& output : outputs) {
- DateMatch date;
- if (GenerateDate(output.rule, output.match, &date)) {
- if (!NormalizeDate(&date)) {
- continue;
- }
- if (!ValidateDate(text, output, date)) {
- continue;
- }
- result.push_back(date);
- }
- }
- return result;
-}
-
-// Builds matched date range instances from the grammar output.
-std::vector<DateRangeMatch> BuildDateRangeMatches(
- const std::vector<UnicodeText::const_iterator>& text,
- const std::vector<DateExtractor::RangeOutput>& range_outputs) {
- std::vector<DateRangeMatch> result;
- for (const DateExtractor::RangeOutput& range_output : range_outputs) {
- DateRangeMatch date_range;
- if (GenerateDateRange(range_output.match, range_output.from,
- range_output.to, &date_range)) {
- if (!NormalizeDateRange(&date_range)) {
- continue;
- }
- result.push_back(date_range);
- }
- }
- return result;
-}
-
-template <typename T>
-void RemoveDeletedMatches(const std::vector<bool>& removed,
- std::vector<T>* matches) {
- int input = 0;
- for (int next = 0; next < matches->size(); ++next) {
- if (removed[next]) {
- continue;
- }
- if (input != next) {
- (*matches)[input] = (*matches)[next];
- }
- input++;
- }
- matches->resize(input);
-}
-
-// Removes duplicated date or date range instances.
-// Overlapping date and date ranges are not considered here.
-template <typename T>
-void RemoveDuplicatedDates(std::vector<T>* matches) {
- // Assumption: matches are sorted ascending by (begin, end).
- std::vector<bool> removed(matches->size(), false);
- for (int i = 0; i < matches->size(); i++) {
- if (removed[i]) {
- continue;
- }
- const T& candidate = matches->at(i);
- for (int j = i + 1; j < matches->size(); j++) {
- if (removed[j]) {
- continue;
- }
- const T& next = matches->at(j);
-
- // Not overlapping.
- if (next.begin >= candidate.end) {
- break;
- }
-
- // If matching the same span of text, then check the priority.
- if (candidate.begin == next.begin && candidate.end == next.end) {
- if (candidate.GetPriority() < next.GetPriority()) {
- removed[i] = true;
- break;
- } else {
- removed[j] = true;
- continue;
- }
- }
-
- // Checks if `next` is fully covered by fields of `candidate`.
- if (next.end <= candidate.end) {
- removed[j] = true;
- continue;
- }
-
- // Checks whether `candidate`/`next` is a refinement.
- if (IsRefinement(candidate, next)) {
- removed[j] = true;
- continue;
- } else if (IsRefinement(next, candidate)) {
- removed[i] = true;
- break;
- }
- }
- }
- RemoveDeletedMatches(removed, matches);
-}
-
-// Filters out simple overtriggering simple matches.
-bool IsBlacklistedDate(const UniLib& unilib,
- const std::vector<UnicodeText::const_iterator>& text,
- const DateMatch& match) {
- const int begin = match.begin;
- const int end = match.end;
- if (end - begin != 3) {
- return false;
- }
-
- std::string text_lower =
- unilib
- .ToLowerText(
- UTF8ToUnicodeText(text[begin].utf8_data(),
- text[end].utf8_data() - text[begin].utf8_data(),
- /*do_copy=*/false))
- .ToUTF8String();
-
- // "sun" is not a good abbreviation for a standalone day of the week.
- if (match.IsStandaloneRelativeDayOfWeek() &&
- (text_lower == "sun" || text_lower == "mon")) {
- return true;
- }
-
- // "mar" is not a good abbreviation for single month.
- if (match.HasMonth() && text_lower == "mar") {
- return true;
- }
-
- return false;
-}
-
-// Checks if two date matches are adjacent and mergeable.
-bool AreDateMatchesAdjacentAndMergeable(
- const UniLib& unilib, const std::vector<UnicodeText::const_iterator>& text,
- const std::vector<std::string>& ignored_spans, const DateMatch& prev,
- const DateMatch& next) {
- // Check the context between the two matches.
- if (next.begin <= prev.end) {
- // The two matches are not adjacent.
- return false;
- }
- UnicodeText span;
- for (int i = prev.end; i < next.begin; i++) {
- const char32 codepoint = *text[i];
- if (unilib.IsWhitespace(codepoint)) {
- continue;
- }
- span.push_back(unilib.ToLower(codepoint));
- }
- if (span.empty()) {
- return true;
- }
- const std::string span_text = span.ToUTF8String();
- bool matched = false;
- for (const std::string& ignored_span : ignored_spans) {
- if (span_text == ignored_span) {
- matched = true;
- break;
- }
- }
- if (!matched) {
- return false;
- }
- return IsDateMatchMergeable(prev, next);
-}
-
-// Merges adjacent date and date range.
-// For e.g. Monday, 5-10pm, the date "Monday" and the time range "5-10pm" will
-// be merged
-void MergeDateRangeAndDate(const UniLib& unilib,
- const std::vector<UnicodeText::const_iterator>& text,
- const std::vector<std::string>& ignored_spans,
- const std::vector<DateMatch>& dates,
- std::vector<DateRangeMatch>* date_ranges) {
- // For each range, check the date before or after the it to see if they could
- // be merged. Both the range and date array are sorted, so we only need to
- // scan the date array once.
- int next_date = 0;
- for (int i = 0; i < date_ranges->size(); i++) {
- DateRangeMatch* date_range = &date_ranges->at(i);
- // So far we only merge time range with a date.
- if (!date_range->from.HasHour()) {
- continue;
- }
-
- for (; next_date < dates.size(); next_date++) {
- const DateMatch& date = dates[next_date];
-
- // If the range is before the date, we check whether `date_range->to` can
- // be merged with the date.
- if (date_range->end <= date.begin) {
- DateMatch merged_date = date;
- if (AreDateMatchesAdjacentAndMergeable(unilib, text, ignored_spans,
- date_range->to, date)) {
- MergeDateMatch(date_range->to, &merged_date, /*update_span=*/true);
- date_range->to = merged_date;
- date_range->end = date_range->to.end;
- MergeDateMatch(date, &date_range->from, /*update_span=*/false);
- next_date++;
-
- // Check the second date after the range to see if it could be merged
- // further. For example: 10-11pm, Monday, May 15. 10-11pm is merged
- // with Monday and then we check that it could be merged with May 15
- // as well.
- if (next_date < dates.size()) {
- DateMatch next_match = dates[next_date];
- if (AreDateMatchesAdjacentAndMergeable(
- unilib, text, ignored_spans, date_range->to, next_match)) {
- MergeDateMatch(date_range->to, &next_match, /*update_span=*/true);
- date_range->to = next_match;
- date_range->end = date_range->to.end;
- MergeDateMatch(dates[next_date], &date_range->from,
- /*update_span=*/false);
- next_date++;
- }
- }
- }
- // Since the range is before the date, we try to check if the next range
- // could be merged with the current date.
- break;
- } else if (date_range->end > date.end && date_range->begin > date.begin) {
- // If the range is after the date, we check if `date_range.from` can be
- // merged with the date. Here is a special case, the date before range
- // could be partially overlapped. This is because the range.from could
- // be extracted as year in date. For example: March 3, 10-11pm is
- // extracted as date March 3, 2010 and the range 10-11pm. In this
- // case, we simply clear the year from date.
- DateMatch merged_date = date;
- if (date.HasYear() &&
- date.year_match->codepoint_span.second > date_range->begin) {
- merged_date.year_match = nullptr;
- merged_date.year = NO_VAL;
- merged_date.end = date.year_match->match_offset;
- }
- // Check and merge the range and the date before the range.
- if (AreDateMatchesAdjacentAndMergeable(unilib, text, ignored_spans,
- merged_date, date_range->from)) {
- MergeDateMatch(merged_date, &date_range->from, /*update_span=*/true);
- date_range->begin = date_range->from.begin;
- MergeDateMatch(merged_date, &date_range->to, /*update_span=*/false);
-
- // Check if the second date before the range can be merged as well.
- if (next_date > 0) {
- DateMatch prev_match = dates[next_date - 1];
- if (prev_match.end <= date_range->from.begin) {
- if (AreDateMatchesAdjacentAndMergeable(unilib, text,
- ignored_spans, prev_match,
- date_range->from)) {
- MergeDateMatch(prev_match, &date_range->from,
- /*update_span=*/true);
- date_range->begin = date_range->from.begin;
- MergeDateMatch(prev_match, &date_range->to,
- /*update_span=*/false);
- }
- }
- }
- next_date++;
- break;
- } else {
- // Since the date is before the date range, we move to the next date
- // to check if it could be merged with the current range.
- continue;
- }
- } else {
- // The date is either fully overlapped by the date range or the date
- // span end is after the date range. Move to the next date in both
- // cases.
- }
- }
- }
-}
-
-// Removes the dates which are part of a range. e.g. in "May 1 - 3", the date
-// "May 1" is fully contained in the range.
-void RemoveOverlappedDateByRange(const std::vector<DateRangeMatch>& ranges,
- std::vector<DateMatch>* dates) {
- int next_date = 0;
- std::vector<bool> removed(dates->size(), false);
- for (int i = 0; i < ranges.size(); ++i) {
- const auto& range = ranges[i];
- for (; next_date < dates->size(); ++next_date) {
- const auto& date = dates->at(next_date);
- // So far we don't touch the partially overlapped case.
- if (date.begin >= range.begin && date.end <= range.end) {
- // Fully contained.
- removed[next_date] = true;
- } else if (date.end <= range.begin) {
- continue; // date is behind range, go to next date
- } else if (date.begin >= range.end) {
- break; // range is behind date, go to next range
- }
- }
- }
- RemoveDeletedMatches(removed, dates);
-}
-
-// Converts candidate dates and date ranges.
-void FillDateInstances(
- const UniLib& unilib, const std::vector<UnicodeText::const_iterator>& text,
- const DateAnnotationOptions& options, std::vector<DateMatch>* date_matches,
- std::vector<DatetimeParseResultSpan>* datetime_parse_result_spans) {
- int i = 0;
- for (int j = 1; j < date_matches->size(); j++) {
- if (options.merge_adjacent_components &&
- AreDateMatchesAdjacentAndMergeable(unilib, text, options.ignored_spans,
- date_matches->at(i),
- date_matches->at(j))) {
- MergeDateMatch(date_matches->at(i), &date_matches->at(j), true);
- } else {
- if (!IsBlacklistedDate(unilib, text, date_matches->at(i))) {
- DatetimeParseResultSpan datetime_parse_result_span;
- FillDateInstance(date_matches->at(i), &datetime_parse_result_span);
- datetime_parse_result_spans->push_back(datetime_parse_result_span);
- }
- }
- i = j;
- }
- if (!IsBlacklistedDate(unilib, text, date_matches->at(i))) {
- DatetimeParseResultSpan datetime_parse_result_span;
- FillDateInstance(date_matches->at(i), &datetime_parse_result_span);
- datetime_parse_result_spans->push_back(datetime_parse_result_span);
- }
-}
-
-void FillDateRangeInstances(
- const std::vector<DateRangeMatch>& date_range_matches,
- std::vector<DatetimeParseResultSpan>* datetime_parse_result_spans) {
- for (const DateRangeMatch& date_range_match : date_range_matches) {
- DatetimeParseResultSpan datetime_parse_result_span;
- FillDateRangeInstance(date_range_match, &datetime_parse_result_span);
- datetime_parse_result_spans->push_back(datetime_parse_result_span);
- }
-}
-
-// Fills `DatetimeParseResultSpan` from `DateMatch` and `DateRangeMatch`
-// instances.
-std::vector<DatetimeParseResultSpan> GetOutputAsAnnotationList(
- const UniLib& unilib, const DateExtractor& extractor,
- const std::vector<UnicodeText::const_iterator>& text,
- const DateAnnotationOptions& options) {
- std::vector<DatetimeParseResultSpan> datetime_parse_result_spans;
- std::vector<DateMatch> date_matches =
- BuildDateMatches(text, extractor.output());
-
- std::sort(
- date_matches.begin(), date_matches.end(),
- // Order by increasing begin, and decreasing end (decreasing length).
- [](const DateMatch& a, const DateMatch& b) {
- return (a.begin < b.begin || (a.begin == b.begin && a.end > b.end));
- });
-
- if (!date_matches.empty()) {
- RemoveDuplicatedDates(&date_matches);
- }
-
- if (options.enable_date_range) {
- std::vector<DateRangeMatch> date_range_matches =
- BuildDateRangeMatches(text, extractor.range_output());
-
- if (!date_range_matches.empty()) {
- std::sort(
- date_range_matches.begin(), date_range_matches.end(),
- // Order by increasing begin, and decreasing end (decreasing length).
- [](const DateRangeMatch& a, const DateRangeMatch& b) {
- return (a.begin < b.begin || (a.begin == b.begin && a.end > b.end));
- });
- RemoveDuplicatedDates(&date_range_matches);
- }
-
- if (!date_matches.empty()) {
- MergeDateRangeAndDate(unilib, text, options.ignored_spans, date_matches,
- &date_range_matches);
- RemoveOverlappedDateByRange(date_range_matches, &date_matches);
- }
- FillDateRangeInstances(date_range_matches, &datetime_parse_result_spans);
- }
-
- if (!date_matches.empty()) {
- FillDateInstances(unilib, text, options, &date_matches,
- &datetime_parse_result_spans);
- }
- return datetime_parse_result_spans;
-}
-
-} // namespace
-
-std::vector<DatetimeParseResultSpan> DateParser::Parse(
- StringPiece text, const std::vector<Token>& tokens,
- const std::vector<Locale>& locales,
- const DateAnnotationOptions& options) const {
- std::vector<UnicodeText::const_iterator> codepoint_offsets;
- const UnicodeText text_unicode = UTF8ToUnicodeText(text,
- /*do_copy=*/false);
- for (auto it = text_unicode.begin(); it != text_unicode.end(); it++) {
- codepoint_offsets.push_back(it);
- }
- codepoint_offsets.push_back(text_unicode.end());
- DateExtractor extractor(codepoint_offsets, options, datetime_rules_);
- // Select locale matching rules.
- // Only use a shard if locales match or the shard doesn't specify a locale
- // restriction.
- std::vector<const grammar::RulesSet_::Rules*> locale_rules =
- SelectLocaleMatchingShards(datetime_rules_->rules(), rules_locales_,
- locales);
- if (locale_rules.empty()) {
- return {};
- }
- grammar::Matcher matcher(&unilib_, datetime_rules_->rules(), locale_rules,
- &extractor);
- lexer_.Process(text_unicode, tokens, /*annotations=*/nullptr, &matcher);
- return GetOutputAsAnnotationList(unilib_, extractor, codepoint_offsets,
- options);
-}
-
-} // namespace libtextclassifier3::dates
diff --git a/native/annotator/grammar/dates/parser.h b/native/annotator/grammar/dates/parser.h
deleted file mode 100644
index be919df..0000000
--- a/native/annotator/grammar/dates/parser.h
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_PARSER_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_PARSER_H_
-
-#include <vector>
-
-#include "annotator/grammar/dates/annotations/annotation-options.h"
-#include "annotator/grammar/dates/annotations/annotation.h"
-#include "annotator/grammar/dates/dates_generated.h"
-#include "annotator/grammar/dates/utils/date-match.h"
-#include "utils/grammar/lexer.h"
-#include "utils/grammar/rules-utils.h"
-#include "utils/i18n/locale.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/utf8/unilib.h"
-
-namespace libtextclassifier3::dates {
-
-// Parses datetime expressions in the input with the datetime grammar and
-// constructs, validates, deduplicates and normalizes date time annotations.
-class DateParser {
- public:
- explicit DateParser(const UniLib* unilib, const DatetimeRules* datetime_rules)
- : unilib_(*unilib),
- lexer_(unilib, datetime_rules->rules()),
- datetime_rules_(datetime_rules),
- rules_locales_(ParseRulesLocales(datetime_rules->rules())) {}
-
- // Parses the dates in the input. Makes sure that the results do not
- // overlap.
- std::vector<DatetimeParseResultSpan> Parse(
- StringPiece text, const std::vector<Token>& tokens,
- const std::vector<Locale>& locales,
- const DateAnnotationOptions& options) const;
-
- private:
- const UniLib& unilib_;
- const grammar::Lexer lexer_;
-
- // The datetime grammar.
- const DatetimeRules* datetime_rules_;
-
- // Pre-parsed locales of the rules.
- const std::vector<std::vector<Locale>> rules_locales_;
-};
-
-} // namespace libtextclassifier3::dates
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_PARSER_H_
diff --git a/native/annotator/grammar/dates/timezone-code.fbs b/native/annotator/grammar/dates/timezone-code.fbs
deleted file mode 100755
index ff615ee..0000000
--- a/native/annotator/grammar/dates/timezone-code.fbs
+++ /dev/null
@@ -1,593 +0,0 @@
-//
-// Copyright (C) 2018 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-namespace libtextclassifier3.dates;
-enum TimezoneCode : int {
- TIMEZONE_CODE_NONE = -1,
- ETC_UNKNOWN = 0,
- PST8PDT = 1,
- // Delegate.
-
- AFRICA_ABIDJAN = 2,
- AFRICA_ACCRA = 3,
- AFRICA_ADDIS_ABABA = 4,
- AFRICA_ALGIERS = 5,
- AFRICA_ASMARA = 6,
- AFRICA_BAMAKO = 7,
- // Delegate.
-
- AFRICA_BANGUI = 8,
- AFRICA_BANJUL = 9,
- AFRICA_BISSAU = 10,
- AFRICA_BLANTYRE = 11,
- AFRICA_BRAZZAVILLE = 12,
- AFRICA_BUJUMBURA = 13,
- EGYPT = 14,
- // Delegate.
-
- AFRICA_CASABLANCA = 15,
- AFRICA_CEUTA = 16,
- AFRICA_CONAKRY = 17,
- AFRICA_DAKAR = 18,
- AFRICA_DAR_ES_SALAAM = 19,
- AFRICA_DJIBOUTI = 20,
- AFRICA_DOUALA = 21,
- AFRICA_EL_AAIUN = 22,
- AFRICA_FREETOWN = 23,
- AFRICA_GABORONE = 24,
- AFRICA_HARARE = 25,
- AFRICA_JOHANNESBURG = 26,
- AFRICA_KAMPALA = 27,
- AFRICA_KHARTOUM = 28,
- AFRICA_KIGALI = 29,
- AFRICA_KINSHASA = 30,
- AFRICA_LAGOS = 31,
- AFRICA_LIBREVILLE = 32,
- AFRICA_LOME = 33,
- AFRICA_LUANDA = 34,
- AFRICA_LUBUMBASHI = 35,
- AFRICA_LUSAKA = 36,
- AFRICA_MALABO = 37,
- AFRICA_MAPUTO = 38,
- AFRICA_MASERU = 39,
- AFRICA_MBABANE = 40,
- AFRICA_MOGADISHU = 41,
- AFRICA_MONROVIA = 42,
- AFRICA_NAIROBI = 43,
- AFRICA_NDJAMENA = 44,
- AFRICA_NIAMEY = 45,
- AFRICA_NOUAKCHOTT = 46,
- AFRICA_OUAGADOUGOU = 47,
- AFRICA_PORTO_NOVO = 48,
- AFRICA_SAO_TOME = 49,
- LIBYA = 51,
- // Delegate.
-
- AFRICA_TUNIS = 52,
- AFRICA_WINDHOEK = 53,
- US_ALEUTIAN = 54,
- // Delegate.
-
- US_ALASKA = 55,
- // Delegate.
-
- AMERICA_ANGUILLA = 56,
- AMERICA_ANTIGUA = 57,
- AMERICA_ARAGUAINA = 58,
- AMERICA_BUENOS_AIRES = 59,
- AMERICA_CATAMARCA = 60,
- AMERICA_CORDOBA = 62,
- AMERICA_JUJUY = 63,
- AMERICA_ARGENTINA_LA_RIOJA = 64,
- AMERICA_MENDOZA = 65,
- AMERICA_ARGENTINA_RIO_GALLEGOS = 66,
- AMERICA_ARGENTINA_SAN_JUAN = 67,
- AMERICA_ARGENTINA_TUCUMAN = 68,
- AMERICA_ARGENTINA_USHUAIA = 69,
- AMERICA_ARUBA = 70,
- AMERICA_ASUNCION = 71,
- AMERICA_BAHIA = 72,
- AMERICA_BARBADOS = 73,
- AMERICA_BELEM = 74,
- AMERICA_BELIZE = 75,
- AMERICA_BOA_VISTA = 76,
- AMERICA_BOGOTA = 77,
- AMERICA_BOISE = 78,
- AMERICA_CAMBRIDGE_BAY = 79,
- AMERICA_CAMPO_GRANDE = 80,
- AMERICA_CANCUN = 81,
- AMERICA_CARACAS = 82,
- AMERICA_CAYENNE = 83,
- AMERICA_CAYMAN = 84,
- CST6CDT = 85,
- // Delegate.
-
- AMERICA_CHIHUAHUA = 86,
- AMERICA_COSTA_RICA = 87,
- AMERICA_CUIABA = 88,
- AMERICA_CURACAO = 89,
- AMERICA_DANMARKSHAVN = 90,
- AMERICA_DAWSON = 91,
- AMERICA_DAWSON_CREEK = 92,
- NAVAJO = 93,
- // Delegate.
-
- US_MICHIGAN = 94,
- // Delegate.
-
- AMERICA_DOMINICA = 95,
- CANADA_MOUNTAIN = 96,
- // Delegate.
-
- AMERICA_EIRUNEPE = 97,
- AMERICA_EL_SALVADOR = 98,
- AMERICA_FORTALEZA = 99,
- AMERICA_GLACE_BAY = 100,
- AMERICA_GODTHAB = 101,
- AMERICA_GOOSE_BAY = 102,
- AMERICA_GRAND_TURK = 103,
- AMERICA_GRENADA = 104,
- AMERICA_GUADELOUPE = 105,
- AMERICA_GUATEMALA = 106,
- AMERICA_GUAYAQUIL = 107,
- AMERICA_GUYANA = 108,
- AMERICA_HALIFAX = 109,
- // Delegate.
-
- CUBA = 110,
- // Delegate.
-
- AMERICA_HERMOSILLO = 111,
- AMERICA_KNOX_IN = 113,
- // Delegate.
-
- AMERICA_INDIANA_MARENGO = 114,
- US_EAST_INDIANA = 115,
- AMERICA_INDIANA_VEVAY = 116,
- AMERICA_INUVIK = 117,
- AMERICA_IQALUIT = 118,
- JAMAICA = 119,
- // Delegate.
-
- AMERICA_JUNEAU = 120,
- AMERICA_KENTUCKY_MONTICELLO = 122,
- AMERICA_LA_PAZ = 123,
- AMERICA_LIMA = 124,
- AMERICA_LOUISVILLE = 125,
- AMERICA_MACEIO = 126,
- AMERICA_MANAGUA = 127,
- BRAZIL_WEST = 128,
- // Delegate.
-
- AMERICA_MARTINIQUE = 129,
- MEXICO_BAJASUR = 130,
- // Delegate.
-
- AMERICA_MENOMINEE = 131,
- AMERICA_MERIDA = 132,
- MEXICO_GENERAL = 133,
- // Delegate.
-
- AMERICA_MIQUELON = 134,
- AMERICA_MONTERREY = 135,
- AMERICA_MONTEVIDEO = 136,
- AMERICA_MONTREAL = 137,
- AMERICA_MONTSERRAT = 138,
- AMERICA_NASSAU = 139,
- EST5EDT = 140,
- // Delegate.
-
- AMERICA_NIPIGON = 141,
- AMERICA_NOME = 142,
- AMERICA_NORONHA = 143,
- // Delegate.
-
- AMERICA_NORTH_DAKOTA_CENTER = 144,
- AMERICA_PANAMA = 145,
- AMERICA_PANGNIRTUNG = 146,
- AMERICA_PARAMARIBO = 147,
- US_ARIZONA = 148,
- // Delegate.
-
- AMERICA_PORT_AU_PRINCE = 149,
- AMERICA_PORT_OF_SPAIN = 150,
- AMERICA_PORTO_VELHO = 151,
- AMERICA_PUERTO_RICO = 152,
- AMERICA_RAINY_RIVER = 153,
- AMERICA_RANKIN_INLET = 154,
- AMERICA_RECIFE = 155,
- AMERICA_REGINA = 156,
- // Delegate.
-
- BRAZIL_ACRE = 157,
- AMERICA_SANTIAGO = 158,
- // Delegate.
-
- AMERICA_SANTO_DOMINGO = 159,
- BRAZIL_EAST = 160,
- // Delegate.
-
- AMERICA_SCORESBYSUND = 161,
- AMERICA_ST_JOHNS = 163,
- // Delegate.
-
- AMERICA_ST_KITTS = 164,
- AMERICA_ST_LUCIA = 165,
- AMERICA_VIRGIN = 166,
- // Delegate.
-
- AMERICA_ST_VINCENT = 167,
- AMERICA_SWIFT_CURRENT = 168,
- AMERICA_TEGUCIGALPA = 169,
- AMERICA_THULE = 170,
- AMERICA_THUNDER_BAY = 171,
- AMERICA_TIJUANA = 172,
- CANADA_EASTERN = 173,
- // Delegate.
-
- AMERICA_TORTOLA = 174,
- CANADA_PACIFIC = 175,
- // Delegate.
-
- CANADA_YUKON = 176,
- // Delegate.
-
- CANADA_CENTRAL = 177,
- // Delegate.
-
- AMERICA_YAKUTAT = 178,
- AMERICA_YELLOWKNIFE = 179,
- ANTARCTICA_CASEY = 180,
- ANTARCTICA_DAVIS = 181,
- ANTARCTICA_DUMONTDURVILLE = 182,
- ANTARCTICA_MAWSON = 183,
- ANTARCTICA_MCMURDO = 184,
- ANTARCTICA_PALMER = 185,
- ANTARCTICA_ROTHERA = 186,
- ANTARCTICA_SYOWA = 188,
- ANTARCTICA_VOSTOK = 189,
- ATLANTIC_JAN_MAYEN = 190,
- // Delegate.
-
- ASIA_ADEN = 191,
- ASIA_ALMATY = 192,
- ASIA_AMMAN = 193,
- ASIA_ANADYR = 194,
- ASIA_AQTAU = 195,
- ASIA_AQTOBE = 196,
- ASIA_ASHGABAT = 197,
- // Delegate.
-
- ASIA_BAGHDAD = 198,
- ASIA_BAHRAIN = 199,
- ASIA_BAKU = 200,
- ASIA_BANGKOK = 201,
- ASIA_BEIRUT = 202,
- ASIA_BISHKEK = 203,
- ASIA_BRUNEI = 204,
- ASIA_KOLKATA = 205,
- // Delegate.
-
- ASIA_CHOIBALSAN = 206,
- ASIA_COLOMBO = 208,
- ASIA_DAMASCUS = 209,
- ASIA_DACCA = 210,
- ASIA_DILI = 211,
- ASIA_DUBAI = 212,
- ASIA_DUSHANBE = 213,
- ASIA_GAZA = 214,
- HONGKONG = 216,
- // Delegate.
-
- ASIA_HOVD = 217,
- ASIA_IRKUTSK = 218,
- ASIA_JAKARTA = 220,
- ASIA_JAYAPURA = 221,
- ISRAEL = 222,
- // Delegate.
-
- ASIA_KABUL = 223,
- ASIA_KAMCHATKA = 224,
- ASIA_KARACHI = 225,
- ASIA_KATMANDU = 227,
- ASIA_KRASNOYARSK = 228,
- ASIA_KUALA_LUMPUR = 229,
- ASIA_KUCHING = 230,
- ASIA_KUWAIT = 231,
- ASIA_MACAO = 232,
- ASIA_MAGADAN = 233,
- ASIA_MAKASSAR = 234,
- // Delegate.
-
- ASIA_MANILA = 235,
- ASIA_MUSCAT = 236,
- ASIA_NICOSIA = 237,
- // Delegate.
-
- ASIA_NOVOSIBIRSK = 238,
- ASIA_OMSK = 239,
- ASIA_ORAL = 240,
- ASIA_PHNOM_PENH = 241,
- ASIA_PONTIANAK = 242,
- ASIA_PYONGYANG = 243,
- ASIA_QATAR = 244,
- ASIA_QYZYLORDA = 245,
- ASIA_RANGOON = 246,
- ASIA_RIYADH = 247,
- ASIA_SAIGON = 248,
- ASIA_SAKHALIN = 249,
- ASIA_SAMARKAND = 250,
- ROK = 251,
- // Delegate.
-
- PRC = 252,
- SINGAPORE = 253,
- // Delegate.
-
- ROC = 254,
- // Delegate.
-
- ASIA_TASHKENT = 255,
- ASIA_TBILISI = 256,
- IRAN = 257,
- // Delegate.
-
- ASIA_THIMBU = 258,
- JAPAN = 259,
- // Delegate.
-
- ASIA_ULAN_BATOR = 260,
- // Delegate.
-
- ASIA_URUMQI = 261,
- ASIA_VIENTIANE = 262,
- ASIA_VLADIVOSTOK = 263,
- ASIA_YAKUTSK = 264,
- ASIA_YEKATERINBURG = 265,
- ASIA_YEREVAN = 266,
- ATLANTIC_AZORES = 267,
- ATLANTIC_BERMUDA = 268,
- ATLANTIC_CANARY = 269,
- ATLANTIC_CAPE_VERDE = 270,
- ATLANTIC_FAROE = 271,
- // Delegate.
-
- ATLANTIC_MADEIRA = 273,
- ICELAND = 274,
- // Delegate.
-
- ATLANTIC_SOUTH_GEORGIA = 275,
- ATLANTIC_STANLEY = 276,
- ATLANTIC_ST_HELENA = 277,
- AUSTRALIA_SOUTH = 278,
- // Delegate.
-
- AUSTRALIA_BRISBANE = 279,
- // Delegate.
-
- AUSTRALIA_YANCOWINNA = 280,
- // Delegate.
-
- AUSTRALIA_NORTH = 281,
- // Delegate.
-
- AUSTRALIA_HOBART = 282,
- // Delegate.
-
- AUSTRALIA_LINDEMAN = 283,
- AUSTRALIA_LHI = 284,
- AUSTRALIA_VICTORIA = 285,
- // Delegate.
-
- AUSTRALIA_WEST = 286,
- // Delegate.
-
- AUSTRALIA_ACT = 287,
- EUROPE_AMSTERDAM = 288,
- EUROPE_ANDORRA = 289,
- EUROPE_ATHENS = 290,
- EUROPE_BELGRADE = 292,
- EUROPE_BERLIN = 293,
- EUROPE_BRATISLAVA = 294,
- EUROPE_BRUSSELS = 295,
- EUROPE_BUCHAREST = 296,
- EUROPE_BUDAPEST = 297,
- EUROPE_CHISINAU = 298,
- // Delegate.
-
- EUROPE_COPENHAGEN = 299,
- EIRE = 300,
- EUROPE_GIBRALTAR = 301,
- EUROPE_HELSINKI = 302,
- TURKEY = 303,
- EUROPE_KALININGRAD = 304,
- EUROPE_KIEV = 305,
- PORTUGAL = 306,
- // Delegate.
-
- EUROPE_LJUBLJANA = 307,
- GB = 308,
- EUROPE_LUXEMBOURG = 309,
- EUROPE_MADRID = 310,
- EUROPE_MALTA = 311,
- EUROPE_MARIEHAMN = 312,
- EUROPE_MINSK = 313,
- EUROPE_MONACO = 314,
- W_SU = 315,
- // Delegate.
-
- EUROPE_OSLO = 317,
- EUROPE_PARIS = 318,
- EUROPE_PRAGUE = 319,
- EUROPE_RIGA = 320,
- EUROPE_ROME = 321,
- EUROPE_SAMARA = 322,
- EUROPE_SAN_MARINO = 323,
- EUROPE_SARAJEVO = 324,
- EUROPE_SIMFEROPOL = 325,
- EUROPE_SKOPJE = 326,
- EUROPE_SOFIA = 327,
- EUROPE_STOCKHOLM = 328,
- EUROPE_TALLINN = 329,
- EUROPE_TIRANE = 330,
- EUROPE_UZHGOROD = 331,
- EUROPE_VADUZ = 332,
- EUROPE_VATICAN = 333,
- EUROPE_VIENNA = 334,
- EUROPE_VILNIUS = 335,
- POLAND = 336,
- // Delegate.
-
- EUROPE_ZAGREB = 337,
- EUROPE_ZAPOROZHYE = 338,
- EUROPE_ZURICH = 339,
- INDIAN_ANTANANARIVO = 340,
- INDIAN_CHAGOS = 341,
- INDIAN_CHRISTMAS = 342,
- INDIAN_COCOS = 343,
- INDIAN_COMORO = 344,
- INDIAN_KERGUELEN = 345,
- INDIAN_MAHE = 346,
- INDIAN_MALDIVES = 347,
- INDIAN_MAURITIUS = 348,
- INDIAN_MAYOTTE = 349,
- INDIAN_REUNION = 350,
- PACIFIC_APIA = 351,
- NZ = 352,
- NZ_CHAT = 353,
- PACIFIC_EASTER = 354,
- PACIFIC_EFATE = 355,
- PACIFIC_ENDERBURY = 356,
- PACIFIC_FAKAOFO = 357,
- PACIFIC_FIJI = 358,
- PACIFIC_FUNAFUTI = 359,
- PACIFIC_GALAPAGOS = 360,
- PACIFIC_GAMBIER = 361,
- PACIFIC_GUADALCANAL = 362,
- PACIFIC_GUAM = 363,
- US_HAWAII = 364,
- // Delegate.
-
- PACIFIC_JOHNSTON = 365,
- PACIFIC_KIRITIMATI = 366,
- PACIFIC_KOSRAE = 367,
- KWAJALEIN = 368,
- PACIFIC_MAJURO = 369,
- PACIFIC_MARQUESAS = 370,
- PACIFIC_MIDWAY = 371,
- PACIFIC_NAURU = 372,
- PACIFIC_NIUE = 373,
- PACIFIC_NORFOLK = 374,
- PACIFIC_NOUMEA = 375,
- US_SAMOA = 376,
- // Delegate.
-
- PACIFIC_PALAU = 377,
- PACIFIC_PITCAIRN = 378,
- PACIFIC_PONAPE = 379,
- PACIFIC_PORT_MORESBY = 380,
- PACIFIC_RAROTONGA = 381,
- PACIFIC_SAIPAN = 382,
- PACIFIC_TAHITI = 383,
- PACIFIC_TARAWA = 384,
- PACIFIC_TONGATAPU = 385,
- PACIFIC_YAP = 386,
- PACIFIC_WAKE = 387,
- PACIFIC_WALLIS = 388,
- AMERICA_ATIKOKAN = 390,
- AUSTRALIA_CURRIE = 391,
- ETC_GMT_EAST_14 = 392,
- ETC_GMT_EAST_13 = 393,
- ETC_GMT_EAST_12 = 394,
- ETC_GMT_EAST_11 = 395,
- ETC_GMT_EAST_10 = 396,
- ETC_GMT_EAST_9 = 397,
- ETC_GMT_EAST_8 = 398,
- ETC_GMT_EAST_7 = 399,
- ETC_GMT_EAST_6 = 400,
- ETC_GMT_EAST_5 = 401,
- ETC_GMT_EAST_4 = 402,
- ETC_GMT_EAST_3 = 403,
- ETC_GMT_EAST_2 = 404,
- ETC_GMT_EAST_1 = 405,
- GMT = 406,
- // Delegate.
-
- ETC_GMT_WEST_1 = 407,
- ETC_GMT_WEST_2 = 408,
- ETC_GMT_WEST_3 = 409,
- SYSTEMV_AST4 = 410,
- // Delegate.
-
- EST = 411,
- SYSTEMV_CST6 = 412,
- // Delegate.
-
- MST = 413,
- // Delegate.
-
- SYSTEMV_PST8 = 414,
- // Delegate.
-
- SYSTEMV_YST9 = 415,
- // Delegate.
-
- HST = 416,
- // Delegate.
-
- ETC_GMT_WEST_11 = 417,
- ETC_GMT_WEST_12 = 418,
- AMERICA_NORTH_DAKOTA_NEW_SALEM = 419,
- AMERICA_INDIANA_PETERSBURG = 420,
- AMERICA_INDIANA_VINCENNES = 421,
- AMERICA_MONCTON = 422,
- AMERICA_BLANC_SABLON = 423,
- EUROPE_GUERNSEY = 424,
- EUROPE_ISLE_OF_MAN = 425,
- EUROPE_JERSEY = 426,
- EUROPE_PODGORICA = 427,
- EUROPE_VOLGOGRAD = 428,
- AMERICA_INDIANA_WINAMAC = 429,
- AUSTRALIA_EUCLA = 430,
- AMERICA_INDIANA_TELL_CITY = 431,
- AMERICA_RESOLUTE = 432,
- AMERICA_ARGENTINA_SAN_LUIS = 433,
- AMERICA_SANTAREM = 434,
- AMERICA_ARGENTINA_SALTA = 435,
- AMERICA_BAHIA_BANDERAS = 436,
- AMERICA_MARIGOT = 437,
- AMERICA_MATAMOROS = 438,
- AMERICA_OJINAGA = 439,
- AMERICA_SANTA_ISABEL = 440,
- AMERICA_ST_BARTHELEMY = 441,
- ANTARCTICA_MACQUARIE = 442,
- ASIA_NOVOKUZNETSK = 443,
- AFRICA_JUBA = 444,
- AMERICA_METLAKATLA = 445,
- AMERICA_NORTH_DAKOTA_BEULAH = 446,
- AMERICA_SITKA = 447,
- ASIA_HEBRON = 448,
- AMERICA_CRESTON = 449,
- AMERICA_KRALENDIJK = 450,
- AMERICA_LOWER_PRINCES = 451,
- ANTARCTICA_TROLL = 452,
- ASIA_KHANDYGA = 453,
- ASIA_UST_NERA = 454,
- EUROPE_BUSINGEN = 455,
- ASIA_CHITA = 456,
- ASIA_SREDNEKOLYMSK = 457,
-}
-
diff --git a/native/annotator/grammar/dates/utils/annotation-keys.cc b/native/annotator/grammar/dates/utils/annotation-keys.cc
deleted file mode 100644
index 3438c6d..0000000
--- a/native/annotator/grammar/dates/utils/annotation-keys.cc
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/grammar/dates/utils/annotation-keys.h"
-
-namespace libtextclassifier3 {
-namespace dates {
-const char* const kDateTimeType = "dateTime";
-const char* const kDateTimeRangeType = "dateTimeRange";
-const char* const kDateTime = "dateTime";
-const char* const kDateTimeSupplementary = "dateTimeSupplementary";
-const char* const kDateTimeRelative = "dateTimeRelative";
-const char* const kDateTimeRangeFrom = "dateTimeRangeFrom";
-const char* const kDateTimeRangeTo = "dateTimeRangeTo";
-} // namespace dates
-} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/dates/utils/annotation-keys.h b/native/annotator/grammar/dates/utils/annotation-keys.h
deleted file mode 100644
index f970a51..0000000
--- a/native/annotator/grammar/dates/utils/annotation-keys.h
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_ANNOTATION_KEYS_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_ANNOTATION_KEYS_H_
-
-namespace libtextclassifier3 {
-namespace dates {
-
-// Date time specific constants not defined in standard schemas.
-//
-// Date annotator output two type of annotation. One is date&time like "May 1",
-// "12:20pm", etc. Another is range like "2pm - 3pm". The two string identify
-// the type of annotation and are used as type in Thing proto.
-extern const char* const kDateTimeType;
-extern const char* const kDateTimeRangeType;
-
-// kDateTime contains most common field for date time. It's integer array and
-// the format is (year, month, day, hour, minute, second, fraction_sec,
-// day_of_week). All eight fields must be provided. If the field is not
-// extracted, the value is -1 in the array.
-extern const char* const kDateTime;
-
-// kDateTimeSupplementary contains uncommon field like timespan, timezone. It's
-// integer array and the format is (bc_ad, timespan_code, timezone_code,
-// timezone_offset). Al four fields must be provided. If the field is not
-// extracted, the value is -1 in the array.
-extern const char* const kDateTimeSupplementary;
-
-// kDateTimeRelative contains fields for relative date time. It's integer
-// array and the format is (is_future, year, month, day, week, hour, minute,
-// second, day_of_week, dow_interpretation*). The first nine fields must be
-// provided and dow_interpretation could have zero or multiple values.
-// If the field is not extracted, the value is -1 in the array.
-extern const char* const kDateTimeRelative;
-
-// Date time range specific constants not defined in standard schemas.
-// kDateTimeRangeFrom and kDateTimeRangeTo define the from/to of a date/time
-// range. The value is thing object which contains a date time.
-extern const char* const kDateTimeRangeFrom;
-extern const char* const kDateTimeRangeTo;
-
-} // namespace dates
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_ANNOTATION_KEYS_H_
diff --git a/native/annotator/grammar/dates/utils/date-match.cc b/native/annotator/grammar/dates/utils/date-match.cc
deleted file mode 100644
index d9fca52..0000000
--- a/native/annotator/grammar/dates/utils/date-match.cc
+++ /dev/null
@@ -1,440 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/grammar/dates/utils/date-match.h"
-
-#include <algorithm>
-
-#include "annotator/grammar/dates/utils/date-utils.h"
-#include "annotator/types.h"
-#include "utils/strings/append.h"
-
-static const int kAM = 0;
-static const int kPM = 1;
-
-namespace libtextclassifier3 {
-namespace dates {
-
-namespace {
-static int GetMeridiemValue(const TimespanCode& timespan_code) {
- switch (timespan_code) {
- case TimespanCode_AM:
- case TimespanCode_MIDNIGHT:
- // MIDNIGHT [3] -> AM
- return kAM;
- case TimespanCode_TONIGHT:
- // TONIGHT [11] -> PM
- case TimespanCode_NOON:
- // NOON [2] -> PM
- case TimespanCode_PM:
- return kPM;
- case TimespanCode_TIMESPAN_CODE_NONE:
- default:
- TC3_LOG(WARNING) << "Failed to extract time span code.";
- }
- return NO_VAL;
-}
-
-static int GetRelativeCount(const RelativeParameter* relative_parameter) {
- for (const int interpretation :
- *relative_parameter->day_of_week_interpretation()) {
- switch (interpretation) {
- case RelativeParameter_::Interpretation_NEAREST_LAST:
- case RelativeParameter_::Interpretation_PREVIOUS:
- return -1;
- case RelativeParameter_::Interpretation_SECOND_LAST:
- return -2;
- case RelativeParameter_::Interpretation_SECOND_NEXT:
- return 2;
- case RelativeParameter_::Interpretation_COMING:
- case RelativeParameter_::Interpretation_SOME:
- case RelativeParameter_::Interpretation_NEAREST:
- case RelativeParameter_::Interpretation_NEAREST_NEXT:
- return 1;
- case RelativeParameter_::Interpretation_CURRENT:
- return 0;
- }
- }
- return 0;
-}
-} // namespace
-
-using strings::JoinStrings;
-using strings::SStringAppendF;
-
-std::string DateMatch::DebugString() const {
- std::string res;
-#if !defined(NDEBUG)
- if (begin >= 0 && end >= 0) {
- SStringAppendF(&res, 0, "[%u,%u)", begin, end);
- }
-
- if (HasDayOfWeek()) {
- SStringAppendF(&res, 0, "%u", day_of_week);
- }
-
- if (HasYear()) {
- int year_output = year;
- if (HasBcAd() && bc_ad == BCAD_BC) {
- year_output = -year;
- }
- SStringAppendF(&res, 0, "%u/", year_output);
- } else {
- SStringAppendF(&res, 0, "____/");
- }
-
- if (HasMonth()) {
- SStringAppendF(&res, 0, "%u/", month);
- } else {
- SStringAppendF(&res, 0, "__/");
- }
-
- if (HasDay()) {
- SStringAppendF(&res, 0, "%u ", day);
- } else {
- SStringAppendF(&res, 0, "__ ");
- }
-
- if (HasHour()) {
- SStringAppendF(&res, 0, "%u:", hour);
- } else {
- SStringAppendF(&res, 0, "__:");
- }
-
- if (HasMinute()) {
- SStringAppendF(&res, 0, "%u:", minute);
- } else {
- SStringAppendF(&res, 0, "__:");
- }
-
- if (HasSecond()) {
- if (HasFractionSecond()) {
- SStringAppendF(&res, 0, "%u.%lf ", second, fraction_second);
- } else {
- SStringAppendF(&res, 0, "%u ", second);
- }
- } else {
- SStringAppendF(&res, 0, "__ ");
- }
-
- if (HasTimeSpanCode() && TimespanCode_TIMESPAN_CODE_NONE < time_span_code &&
- time_span_code <= TimespanCode_MAX) {
- SStringAppendF(&res, 0, "TS=%u ", time_span_code);
- }
-
- if (HasTimeZoneCode() && time_zone_code != -1) {
- SStringAppendF(&res, 0, "TZ= %u ", time_zone_code);
- }
-
- if (HasTimeZoneOffset()) {
- SStringAppendF(&res, 0, "TZO=%u ", time_zone_offset);
- }
-
- if (HasRelativeDate()) {
- const RelativeMatch* rm = relative_match;
- SStringAppendF(&res, 0, (rm->is_future_date ? "future " : "past "));
- if (rm->day_of_week != NO_VAL) {
- SStringAppendF(&res, 0, "DOW:%d ", rm->day_of_week);
- }
- if (rm->year != NO_VAL) {
- SStringAppendF(&res, 0, "Y:%d ", rm->year);
- }
- if (rm->month != NO_VAL) {
- SStringAppendF(&res, 0, "M:%d ", rm->month);
- }
- if (rm->day != NO_VAL) {
- SStringAppendF(&res, 0, "D:%d ", rm->day);
- }
- if (rm->week != NO_VAL) {
- SStringAppendF(&res, 0, "W:%d ", rm->week);
- }
- if (rm->hour != NO_VAL) {
- SStringAppendF(&res, 0, "H:%d ", rm->hour);
- }
- if (rm->minute != NO_VAL) {
- SStringAppendF(&res, 0, "M:%d ", rm->minute);
- }
- if (rm->second != NO_VAL) {
- SStringAppendF(&res, 0, "S:%d ", rm->second);
- }
- }
-
- SStringAppendF(&res, 0, "prio=%d ", priority);
- SStringAppendF(&res, 0, "conf-score=%lf ", annotator_priority_score);
-
- if (IsHourAmbiguous()) {
- std::vector<int8> values;
- GetPossibleHourValues(&values);
- std::string str_values;
-
- for (unsigned int i = 0; i < values.size(); ++i) {
- SStringAppendF(&str_values, 0, "%u,", values[i]);
- }
- SStringAppendF(&res, 0, "amb=%s ", str_values.c_str());
- }
-
- std::vector<std::string> tags;
- if (is_inferred) {
- tags.push_back("inferred");
- }
- if (!tags.empty()) {
- SStringAppendF(&res, 0, "tag=%s ", JoinStrings(",", tags).c_str());
- }
-#endif // !defined(NDEBUG)
- return res;
-}
-
-void DateMatch::GetPossibleHourValues(std::vector<int8>* values) const {
- TC3_CHECK(values != nullptr);
- values->clear();
- if (HasHour()) {
- int8 possible_hour = hour;
- values->push_back(possible_hour);
- for (int count = 1; count < ambiguous_hour_count; ++count) {
- possible_hour += ambiguous_hour_interval;
- if (possible_hour >= 24) {
- possible_hour -= 24;
- }
- values->push_back(possible_hour);
- }
- }
-}
-
-DatetimeComponent::RelativeQualifier DateMatch::GetRelativeQualifier() const {
- if (HasRelativeDate()) {
- if (relative_match->existing & RelativeMatch::HAS_IS_FUTURE) {
- if (!relative_match->is_future_date) {
- return DatetimeComponent::RelativeQualifier::PAST;
- }
- }
- return DatetimeComponent::RelativeQualifier::FUTURE;
- }
- return DatetimeComponent::RelativeQualifier::UNSPECIFIED;
-}
-
-// Embed RelativeQualifier information of DatetimeComponent as a sign of
-// relative counter field of datetime component i.e. relative counter is
-// negative when relative qualifier RelativeQualifier::PAST.
-int GetAdjustedRelativeCounter(
- const DatetimeComponent::RelativeQualifier& relative_qualifier,
- const int relative_counter) {
- if (DatetimeComponent::RelativeQualifier::PAST == relative_qualifier) {
- return -relative_counter;
- }
- return relative_counter;
-}
-
-Optional<DatetimeComponent> CreateDatetimeComponent(
- const DatetimeComponent::ComponentType& component_type,
- const DatetimeComponent::RelativeQualifier& relative_qualifier,
- const int absolute_value, const int relative_value) {
- if (absolute_value == NO_VAL && relative_value == NO_VAL) {
- return Optional<DatetimeComponent>();
- }
- return Optional<DatetimeComponent>(DatetimeComponent(
- component_type,
- (relative_value != NO_VAL)
- ? relative_qualifier
- : DatetimeComponent::RelativeQualifier::UNSPECIFIED,
- (absolute_value != NO_VAL) ? absolute_value : 0,
- (relative_value != NO_VAL)
- ? GetAdjustedRelativeCounter(relative_qualifier, relative_value)
- : 0));
-}
-
-Optional<DatetimeComponent> CreateDayOfWeekComponent(
- const RelativeMatch* relative_match,
- const DatetimeComponent::RelativeQualifier& relative_qualifier,
- const DayOfWeek& absolute_day_of_week) {
- DatetimeComponent::RelativeQualifier updated_relative_qualifier =
- relative_qualifier;
- int absolute_value = absolute_day_of_week;
- int relative_value = NO_VAL;
- if (relative_match) {
- relative_value = relative_match->day_of_week;
- if (relative_match->existing & RelativeMatch::HAS_DAY_OF_WEEK) {
- if (relative_match->IsStandaloneRelativeDayOfWeek() &&
- absolute_day_of_week == DayOfWeek_DOW_NONE) {
- absolute_value = relative_match->day_of_week;
- }
- // Check if the relative date has day of week with week period.
- if (relative_match->existing & RelativeMatch::HAS_WEEK) {
- relative_value = 1;
- } else {
- const NonterminalValue* nonterminal =
- relative_match->day_of_week_nonterminal;
- TC3_CHECK(nonterminal != nullptr);
- TC3_CHECK(nonterminal->relative_parameter());
- const RelativeParameter* rp = nonterminal->relative_parameter();
- if (rp->day_of_week_interpretation()) {
- relative_value = GetRelativeCount(rp);
- if (relative_value < 0) {
- relative_value = abs(relative_value);
- updated_relative_qualifier =
- DatetimeComponent::RelativeQualifier::PAST;
- } else if (relative_value > 0) {
- updated_relative_qualifier =
- DatetimeComponent::RelativeQualifier::FUTURE;
- }
- }
- }
- }
- }
- return CreateDatetimeComponent(DatetimeComponent::ComponentType::DAY_OF_WEEK,
- updated_relative_qualifier, absolute_value,
- relative_value);
-}
-
-// Resolve the year’s ambiguity.
-// If the year in the date has 4 digits i.e. DD/MM/YYYY then there is no
-// ambiguity, the year value is YYYY but certain format i.e. MM/DD/YY is
-// ambiguous e.g. in {April/23/15} year value can be 15 or 1915 or 2015.
-// Following heuristic is used to resolve the ambiguity.
-// - For YYYY there is nothing to resolve.
-// - For all YY years
-// - Value less than 50 will be resolved to 20YY
-// - Value greater or equal 50 will be resolved to 19YY
-static int InterpretYear(int parsed_year) {
- if (parsed_year == NO_VAL) {
- return parsed_year;
- }
- if (parsed_year < 100) {
- if (parsed_year < 50) {
- return parsed_year + 2000;
- }
- return parsed_year + 1900;
- }
- return parsed_year;
-}
-
-Optional<DatetimeComponent> DateMatch::GetDatetimeComponent(
- const DatetimeComponent::ComponentType& component_type) const {
- switch (component_type) {
- case DatetimeComponent::ComponentType::YEAR:
- return CreateDatetimeComponent(
- component_type, GetRelativeQualifier(), InterpretYear(year),
- (relative_match != nullptr) ? relative_match->year : NO_VAL);
- case DatetimeComponent::ComponentType::MONTH:
- return CreateDatetimeComponent(
- component_type, GetRelativeQualifier(), month,
- (relative_match != nullptr) ? relative_match->month : NO_VAL);
- case DatetimeComponent::ComponentType::DAY_OF_MONTH:
- return CreateDatetimeComponent(
- component_type, GetRelativeQualifier(), day,
- (relative_match != nullptr) ? relative_match->day : NO_VAL);
- case DatetimeComponent::ComponentType::HOUR:
- return CreateDatetimeComponent(
- component_type, GetRelativeQualifier(), hour,
- (relative_match != nullptr) ? relative_match->hour : NO_VAL);
- case DatetimeComponent::ComponentType::MINUTE:
- return CreateDatetimeComponent(
- component_type, GetRelativeQualifier(), minute,
- (relative_match != nullptr) ? relative_match->minute : NO_VAL);
- case DatetimeComponent::ComponentType::SECOND:
- return CreateDatetimeComponent(
- component_type, GetRelativeQualifier(), second,
- (relative_match != nullptr) ? relative_match->second : NO_VAL);
- case DatetimeComponent::ComponentType::DAY_OF_WEEK:
- return CreateDayOfWeekComponent(relative_match, GetRelativeQualifier(),
- day_of_week);
- case DatetimeComponent::ComponentType::MERIDIEM:
- return CreateDatetimeComponent(component_type, GetRelativeQualifier(),
- GetMeridiemValue(time_span_code), NO_VAL);
- case DatetimeComponent::ComponentType::ZONE_OFFSET:
- if (HasTimeZoneOffset()) {
- return Optional<DatetimeComponent>(DatetimeComponent(
- component_type, DatetimeComponent::RelativeQualifier::UNSPECIFIED,
- time_zone_offset, /*arg_relative_count=*/0));
- }
- return Optional<DatetimeComponent>();
- case DatetimeComponent::ComponentType::WEEK:
- return CreateDatetimeComponent(
- component_type, GetRelativeQualifier(), NO_VAL,
- HasRelativeDate() ? relative_match->week : NO_VAL);
- default:
- return Optional<DatetimeComponent>();
- }
-}
-
-bool DateMatch::IsValid() const {
- if (!HasYear() && HasBcAd()) {
- return false;
- }
- if (!HasMonth() && HasYear() && (HasDay() || HasDayOfWeek())) {
- return false;
- }
- if (!HasDay() && HasDayOfWeek() && (HasYear() || HasMonth())) {
- return false;
- }
- if (!HasDay() && !HasDayOfWeek() && HasHour() && (HasYear() || HasMonth())) {
- return false;
- }
- if (!HasHour() && (HasMinute() || HasSecond() || HasFractionSecond())) {
- return false;
- }
- if (!HasMinute() && (HasSecond() || HasFractionSecond())) {
- return false;
- }
- if (!HasSecond() && HasFractionSecond()) {
- return false;
- }
- // Check whether day exists in a month, to exclude cases like "April 31".
- if (HasDay() && HasMonth() && day > GetLastDayOfMonth(year, month)) {
- return false;
- }
- return (HasDateFields() || HasTimeFields() || HasRelativeDate());
-}
-
-void DateMatch::FillDatetimeComponents(
- std::vector<DatetimeComponent>* datetime_component) const {
- static const std::vector<DatetimeComponent::ComponentType>*
- kDatetimeComponents = new std::vector<DatetimeComponent::ComponentType>{
- DatetimeComponent::ComponentType::ZONE_OFFSET,
- DatetimeComponent::ComponentType::MERIDIEM,
- DatetimeComponent::ComponentType::SECOND,
- DatetimeComponent::ComponentType::MINUTE,
- DatetimeComponent::ComponentType::HOUR,
- DatetimeComponent::ComponentType::DAY_OF_MONTH,
- DatetimeComponent::ComponentType::DAY_OF_WEEK,
- DatetimeComponent::ComponentType::WEEK,
- DatetimeComponent::ComponentType::MONTH,
- DatetimeComponent::ComponentType::YEAR};
-
- for (const DatetimeComponent::ComponentType& component_type :
- *kDatetimeComponents) {
- Optional<DatetimeComponent> date_time =
- GetDatetimeComponent(component_type);
- if (date_time.has_value()) {
- datetime_component->emplace_back(date_time.value());
- }
- }
-}
-
-std::string DateRangeMatch::DebugString() const {
- std::string res;
- // The method is only called for debugging purposes.
-#if !defined(NDEBUG)
- if (begin >= 0 && end >= 0) {
- SStringAppendF(&res, 0, "[%u,%u)\n", begin, end);
- }
- SStringAppendF(&res, 0, "from: %s \n", from.DebugString().c_str());
- SStringAppendF(&res, 0, "to: %s\n", to.DebugString().c_str());
-#endif // !defined(NDEBUG)
- return res;
-}
-
-} // namespace dates
-} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/dates/utils/date-match.h b/native/annotator/grammar/dates/utils/date-match.h
deleted file mode 100644
index 285e9b3..0000000
--- a/native/annotator/grammar/dates/utils/date-match.h
+++ /dev/null
@@ -1,537 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_MATCH_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_MATCH_H_
-
-#include <stddef.h>
-#include <stdint.h>
-
-#include <algorithm>
-#include <vector>
-
-#include "annotator/grammar/dates/dates_generated.h"
-#include "annotator/grammar/dates/timezone-code_generated.h"
-#include "utils/grammar/match.h"
-
-namespace libtextclassifier3 {
-namespace dates {
-
-static constexpr int NO_VAL = -1;
-
-// POD match data structure.
-struct MatchBase : public grammar::Match {
- void Reset() { type = MatchType::MatchType_UNKNOWN; }
-};
-
-struct ExtractionMatch : public MatchBase {
- const ExtractionRuleParameter* extraction_rule;
-
- void Reset() {
- MatchBase::Reset();
- type = MatchType::MatchType_DATETIME_RULE;
- extraction_rule = nullptr;
- }
-};
-
-struct TermValueMatch : public MatchBase {
- const TermValue* term_value;
-
- void Reset() {
- MatchBase::Reset();
- type = MatchType::MatchType_TERM_VALUE;
- term_value = nullptr;
- }
-};
-
-struct NonterminalMatch : public MatchBase {
- const NonterminalValue* nonterminal;
-
- void Reset() {
- MatchBase::Reset();
- type = MatchType::MatchType_NONTERMINAL;
- nonterminal = nullptr;
- }
-};
-
-struct IntegerMatch : public NonterminalMatch {
- int value;
- int8 count_of_digits; // When expression is in digits format.
- bool is_zero_prefixed; // When expression is in digits format.
-
- void Reset() {
- NonterminalMatch::Reset();
- value = NO_VAL;
- count_of_digits = 0;
- is_zero_prefixed = false;
- }
-};
-
-struct DigitsMatch : public IntegerMatch {
- void Reset() {
- IntegerMatch::Reset();
- type = MatchType::MatchType_DIGITS;
- }
-
- static bool IsValid(int x) { return true; }
-};
-
-struct YearMatch : public IntegerMatch {
- void Reset() {
- IntegerMatch::Reset();
- type = MatchType::MatchType_YEAR;
- }
-
- static bool IsValid(int x) { return x >= 1; }
-};
-
-struct MonthMatch : public IntegerMatch {
- void Reset() {
- IntegerMatch::Reset();
- type = MatchType::MatchType_MONTH;
- }
-
- static bool IsValid(int x) { return (x >= 1 && x <= 12); }
-};
-
-struct DayMatch : public IntegerMatch {
- void Reset() {
- IntegerMatch::Reset();
- type = MatchType::MatchType_DAY;
- }
-
- static bool IsValid(int x) { return (x >= 1 && x <= 31); }
-};
-
-struct HourMatch : public IntegerMatch {
- void Reset() {
- IntegerMatch::Reset();
- type = MatchType::MatchType_HOUR;
- }
-
- static bool IsValid(int x) { return (x >= 0 && x <= 24); }
-};
-
-struct MinuteMatch : public IntegerMatch {
- void Reset() {
- IntegerMatch::Reset();
- type = MatchType::MatchType_MINUTE;
- }
-
- static bool IsValid(int x) { return (x >= 0 && x <= 59); }
-};
-
-struct SecondMatch : public IntegerMatch {
- void Reset() {
- IntegerMatch::Reset();
- type = MatchType::MatchType_SECOND;
- }
-
- static bool IsValid(int x) { return (x >= 0 && x <= 60); }
-};
-
-struct DecimalMatch : public NonterminalMatch {
- double value;
- int8 count_of_digits; // When expression is in digits format.
-
- void Reset() {
- NonterminalMatch::Reset();
- value = NO_VAL;
- count_of_digits = 0;
- }
-};
-
-struct FractionSecondMatch : public DecimalMatch {
- void Reset() {
- DecimalMatch::Reset();
- type = MatchType::MatchType_FRACTION_SECOND;
- }
-
- static bool IsValid(double x) { return (x >= 0.0 && x < 1.0); }
-};
-
-// CombinedIntegersMatch<N> is used for expressions containing multiple (up
-// to N) matches of integers without delimeters between them (because
-// CFG-grammar is based on tokenizer, it could not split a token into several
-// pieces like using regular-expression). For example, "1130" contains "11"
-// and "30" meaning November 30.
-template <int N>
-struct CombinedIntegersMatch : public NonterminalMatch {
- enum {
- SIZE = N,
- };
-
- int values[SIZE];
- int8 count_of_digits; // When expression is in digits format.
- bool is_zero_prefixed; // When expression is in digits format.
-
- void Reset() {
- NonterminalMatch::Reset();
- for (int i = 0; i < SIZE; ++i) {
- values[i] = NO_VAL;
- }
- count_of_digits = 0;
- is_zero_prefixed = false;
- }
-};
-
-struct CombinedDigitsMatch : public CombinedIntegersMatch<6> {
- enum Index {
- INDEX_YEAR = 0,
- INDEX_MONTH = 1,
- INDEX_DAY = 2,
- INDEX_HOUR = 3,
- INDEX_MINUTE = 4,
- INDEX_SECOND = 5,
- };
-
- bool HasYear() const { return values[INDEX_YEAR] != NO_VAL; }
- bool HasMonth() const { return values[INDEX_MONTH] != NO_VAL; }
- bool HasDay() const { return values[INDEX_DAY] != NO_VAL; }
- bool HasHour() const { return values[INDEX_HOUR] != NO_VAL; }
- bool HasMinute() const { return values[INDEX_MINUTE] != NO_VAL; }
- bool HasSecond() const { return values[INDEX_SECOND] != NO_VAL; }
-
- int GetYear() const { return values[INDEX_YEAR]; }
- int GetMonth() const { return values[INDEX_MONTH]; }
- int GetDay() const { return values[INDEX_DAY]; }
- int GetHour() const { return values[INDEX_HOUR]; }
- int GetMinute() const { return values[INDEX_MINUTE]; }
- int GetSecond() const { return values[INDEX_SECOND]; }
-
- void Reset() {
- CombinedIntegersMatch<SIZE>::Reset();
- type = MatchType::MatchType_COMBINED_DIGITS;
- }
-
- static bool IsValid(int i, int x) {
- switch (i) {
- case INDEX_YEAR:
- return YearMatch::IsValid(x);
- case INDEX_MONTH:
- return MonthMatch::IsValid(x);
- case INDEX_DAY:
- return DayMatch::IsValid(x);
- case INDEX_HOUR:
- return HourMatch::IsValid(x);
- case INDEX_MINUTE:
- return MinuteMatch::IsValid(x);
- case INDEX_SECOND:
- return SecondMatch::IsValid(x);
- default:
- return false;
- }
- }
-};
-
-struct TimeValueMatch : public NonterminalMatch {
- const HourMatch* hour_match;
- const MinuteMatch* minute_match;
- const SecondMatch* second_match;
- const FractionSecondMatch* fraction_second_match;
-
- bool is_hour_zero_prefixed : 1;
- bool is_minute_one_digit : 1;
- bool is_second_one_digit : 1;
-
- int8 hour;
- int8 minute;
- int8 second;
- double fraction_second;
-
- void Reset() {
- NonterminalMatch::Reset();
- type = MatchType::MatchType_TIME_VALUE;
- hour_match = nullptr;
- minute_match = nullptr;
- second_match = nullptr;
- fraction_second_match = nullptr;
- is_hour_zero_prefixed = false;
- is_minute_one_digit = false;
- is_second_one_digit = false;
- hour = NO_VAL;
- minute = NO_VAL;
- second = NO_VAL;
- fraction_second = NO_VAL;
- }
-};
-
-struct TimeSpanMatch : public NonterminalMatch {
- const TimeSpanSpec* time_span_spec;
- TimespanCode time_span_code;
-
- void Reset() {
- NonterminalMatch::Reset();
- type = MatchType::MatchType_TIME_SPAN;
- time_span_spec = nullptr;
- time_span_code = TimespanCode_TIMESPAN_CODE_NONE;
- }
-};
-
-struct TimeZoneNameMatch : public NonterminalMatch {
- const TimeZoneNameSpec* time_zone_name_spec;
- TimezoneCode time_zone_code;
-
- void Reset() {
- NonterminalMatch::Reset();
- type = MatchType::MatchType_TIME_ZONE_NAME;
- time_zone_name_spec = nullptr;
- time_zone_code = TimezoneCode_TIMEZONE_CODE_NONE;
- }
-};
-
-struct TimeZoneOffsetMatch : public NonterminalMatch {
- const TimeZoneOffsetParameter* time_zone_offset_param;
- int16 time_zone_offset;
-
- void Reset() {
- NonterminalMatch::Reset();
- type = MatchType::MatchType_TIME_ZONE_OFFSET;
- time_zone_offset_param = nullptr;
- time_zone_offset = 0;
- }
-};
-
-struct DayOfWeekMatch : public IntegerMatch {
- void Reset() {
- IntegerMatch::Reset();
- type = MatchType::MatchType_DAY_OF_WEEK;
- }
-
- static bool IsValid(int x) {
- return (x > DayOfWeek_DOW_NONE && x <= DayOfWeek_MAX);
- }
-};
-
-struct TimePeriodMatch : public NonterminalMatch {
- int value;
-
- void Reset() {
- NonterminalMatch::Reset();
- type = MatchType::MatchType_TIME_PERIOD;
- value = NO_VAL;
- }
-};
-
-struct RelativeMatch : public NonterminalMatch {
- enum {
- HAS_NONE = 0,
- HAS_YEAR = 1 << 0,
- HAS_MONTH = 1 << 1,
- HAS_DAY = 1 << 2,
- HAS_WEEK = 1 << 3,
- HAS_HOUR = 1 << 4,
- HAS_MINUTE = 1 << 5,
- HAS_SECOND = 1 << 6,
- HAS_DAY_OF_WEEK = 1 << 7,
- HAS_IS_FUTURE = 1 << 31,
- };
- uint32 existing;
-
- int year;
- int month;
- int day;
- int week;
- int hour;
- int minute;
- int second;
- const NonterminalValue* day_of_week_nonterminal;
- int8 day_of_week;
- bool is_future_date;
-
- bool HasDay() const { return existing & HAS_DAY; }
-
- bool HasDayFields() const { return existing & (HAS_DAY | HAS_DAY_OF_WEEK); }
-
- bool HasTimeValueFields() const {
- return existing & (HAS_HOUR | HAS_MINUTE | HAS_SECOND);
- }
-
- bool IsStandaloneRelativeDayOfWeek() const {
- return (existing & HAS_DAY_OF_WEEK) && (existing & ~HAS_DAY_OF_WEEK) == 0;
- }
-
- void Reset() {
- NonterminalMatch::Reset();
- type = MatchType::MatchType_RELATIVE_DATE;
- existing = HAS_NONE;
- year = NO_VAL;
- month = NO_VAL;
- day = NO_VAL;
- week = NO_VAL;
- hour = NO_VAL;
- minute = NO_VAL;
- second = NO_VAL;
- day_of_week = NO_VAL;
- is_future_date = false;
- }
-};
-
-// This is not necessarily POD, it is used to keep the final matched result.
-struct DateMatch {
- // Sub-matches in the date match.
- const YearMatch* year_match = nullptr;
- const MonthMatch* month_match = nullptr;
- const DayMatch* day_match = nullptr;
- const DayOfWeekMatch* day_of_week_match = nullptr;
- const TimeValueMatch* time_value_match = nullptr;
- const TimeSpanMatch* time_span_match = nullptr;
- const TimeZoneNameMatch* time_zone_name_match = nullptr;
- const TimeZoneOffsetMatch* time_zone_offset_match = nullptr;
- const RelativeMatch* relative_match = nullptr;
- const CombinedDigitsMatch* combined_digits_match = nullptr;
-
- // [begin, end) indicates the Document position where the date or date range
- // was found.
- int begin = -1;
- int end = -1;
- int priority = 0;
- float annotator_priority_score = 0.0;
-
- int year = NO_VAL;
- int8 month = NO_VAL;
- int8 day = NO_VAL;
- DayOfWeek day_of_week = DayOfWeek_DOW_NONE;
- BCAD bc_ad = BCAD_BCAD_NONE;
- int8 hour = NO_VAL;
- int8 minute = NO_VAL;
- int8 second = NO_VAL;
- double fraction_second = NO_VAL;
- TimespanCode time_span_code = TimespanCode_TIMESPAN_CODE_NONE;
- int time_zone_code = TimezoneCode_TIMEZONE_CODE_NONE;
- int16 time_zone_offset = std::numeric_limits<int16>::min();
-
- // Fields about ambiguous hours. These fields are used to interpret the
- // possible values of ambiguous hours. Since all kinds of known ambiguities
- // are in the form of arithmetic progression (starting from .hour field),
- // we can use "ambiguous_hour_count" to denote the count of ambiguous hours,
- // and use "ambiguous_hour_interval" to denote the distance between a pair
- // of adjacent possible hours. Values in the arithmetic progression are
- // shrunk into [0, 23] (MOD 24). One can use the GetPossibleHourValues()
- // method for the complete list of possible hours.
- uint8 ambiguous_hour_count = 0;
- uint8 ambiguous_hour_interval = 0;
-
- bool is_inferred = false;
-
- // This field is set in function PerformRefinements to remove some DateMatch
- // like overlapped, duplicated, etc.
- bool is_removed = false;
-
- std::string DebugString() const;
-
- bool HasYear() const { return year != NO_VAL; }
- bool HasMonth() const { return month != NO_VAL; }
- bool HasDay() const { return day != NO_VAL; }
- bool HasDayOfWeek() const { return day_of_week != DayOfWeek_DOW_NONE; }
- bool HasBcAd() const { return bc_ad != BCAD_BCAD_NONE; }
- bool HasHour() const { return hour != NO_VAL; }
- bool HasMinute() const { return minute != NO_VAL; }
- bool HasSecond() const { return second != NO_VAL; }
- bool HasFractionSecond() const { return fraction_second != NO_VAL; }
- bool HasTimeSpanCode() const {
- return time_span_code != TimespanCode_TIMESPAN_CODE_NONE;
- }
- bool HasTimeZoneCode() const {
- return time_zone_code != TimezoneCode_TIMEZONE_CODE_NONE;
- }
- bool HasTimeZoneOffset() const {
- return time_zone_offset != std::numeric_limits<int16>::min();
- }
-
- bool HasRelativeDate() const { return relative_match != nullptr; }
-
- bool IsHourAmbiguous() const { return ambiguous_hour_count >= 2; }
-
- bool IsStandaloneTime() const {
- return (HasHour() || HasMinute()) && !HasDayOfWeek() && !HasDay() &&
- !HasMonth() && !HasYear();
- }
-
- void SetAmbiguousHourProperties(uint8 count, uint8 interval) {
- ambiguous_hour_count = count;
- ambiguous_hour_interval = interval;
- }
-
- // Outputs all the possible hour values. If current DateMatch does not
- // contain an hour, nothing will be output. If the hour is not ambiguous,
- // only one value (= .hour) will be output. This method clears the vector
- // "values" first, and it is not guaranteed that the values in the vector
- // are in a sorted order.
- void GetPossibleHourValues(std::vector<int8>* values) const;
-
- int GetPriority() const { return priority; }
-
- float GetAnnotatorPriorityScore() const { return annotator_priority_score; }
-
- bool IsStandaloneRelativeDayOfWeek() const {
- return (HasRelativeDate() &&
- relative_match->IsStandaloneRelativeDayOfWeek() &&
- !HasDateFields() && !HasTimeFields() && !HasTimeSpanCode());
- }
-
- bool HasDateFields() const {
- return (HasYear() || HasMonth() || HasDay() || HasDayOfWeek() || HasBcAd());
- }
- bool HasTimeValueFields() const {
- return (HasHour() || HasMinute() || HasSecond() || HasFractionSecond());
- }
- bool HasTimeSpanFields() const { return HasTimeSpanCode(); }
- bool HasTimeZoneFields() const {
- return (HasTimeZoneCode() || HasTimeZoneOffset());
- }
- bool HasTimeFields() const {
- return (HasTimeValueFields() || HasTimeSpanFields() || HasTimeZoneFields());
- }
-
- bool IsValid() const;
-
- // Overall relative qualifier of the DateMatch e.g. 2 year ago is 'PAST' and
- // next week is 'FUTURE'.
- DatetimeComponent::RelativeQualifier GetRelativeQualifier() const;
-
- // Getter method to get the 'DatetimeComponent' of given 'ComponentType'.
- Optional<DatetimeComponent> GetDatetimeComponent(
- const DatetimeComponent::ComponentType& component_type) const;
-
- void FillDatetimeComponents(
- std::vector<DatetimeComponent>* datetime_component) const;
-};
-
-// Represent a matched date range which includes the from and to matched date.
-struct DateRangeMatch {
- int begin = -1;
- int end = -1;
-
- DateMatch from;
- DateMatch to;
-
- std::string DebugString() const;
-
- int GetPriority() const {
- return std::max(from.GetPriority(), to.GetPriority());
- }
-
- float GetAnnotatorPriorityScore() const {
- return std::max(from.GetAnnotatorPriorityScore(),
- to.GetAnnotatorPriorityScore());
- }
-};
-
-} // namespace dates
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_MATCH_H_
diff --git a/native/annotator/grammar/dates/utils/date-match_test.cc b/native/annotator/grammar/dates/utils/date-match_test.cc
deleted file mode 100644
index f10f32a..0000000
--- a/native/annotator/grammar/dates/utils/date-match_test.cc
+++ /dev/null
@@ -1,397 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/grammar/dates/utils/date-match.h"
-
-#include <stdint.h>
-
-#include <string>
-
-#include "annotator/grammar/dates/dates_generated.h"
-#include "annotator/grammar/dates/timezone-code_generated.h"
-#include "annotator/grammar/dates/utils/date-utils.h"
-#include "utils/strings/append.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace dates {
-namespace {
-
-class DateMatchTest : public ::testing::Test {
- protected:
- enum {
- X = NO_VAL,
- };
-
- static DayOfWeek DOW_X() { return DayOfWeek_DOW_NONE; }
- static DayOfWeek SUN() { return DayOfWeek_SUNDAY; }
-
- static BCAD BCAD_X() { return BCAD_BCAD_NONE; }
- static BCAD BC() { return BCAD_BC; }
-
- DateMatch& SetDate(DateMatch* date, int year, int8 month, int8 day,
- DayOfWeek day_of_week = DOW_X(), BCAD bc_ad = BCAD_X()) {
- date->year = year;
- date->month = month;
- date->day = day;
- date->day_of_week = day_of_week;
- date->bc_ad = bc_ad;
- return *date;
- }
-
- DateMatch& SetTimeValue(DateMatch* date, int8 hour, int8 minute = X,
- int8 second = X, double fraction_second = X) {
- date->hour = hour;
- date->minute = minute;
- date->second = second;
- date->fraction_second = fraction_second;
- return *date;
- }
-
- DateMatch& SetTimeSpan(DateMatch* date, TimespanCode time_span_code) {
- date->time_span_code = time_span_code;
- return *date;
- }
-
- DateMatch& SetTimeZone(DateMatch* date, TimezoneCode time_zone_code,
- int16 time_zone_offset = INT16_MIN) {
- date->time_zone_code = time_zone_code;
- date->time_zone_offset = time_zone_offset;
- return *date;
- }
-
- bool SameDate(const DateMatch& a, const DateMatch& b) {
- return (a.day == b.day && a.month == b.month && a.year == b.year &&
- a.day_of_week == b.day_of_week);
- }
-
- DateMatch& SetDayOfWeek(DateMatch* date, DayOfWeek dow) {
- date->day_of_week = dow;
- return *date;
- }
-};
-
-TEST_F(DateMatchTest, BitFieldWidth) {
- // For DateMatch::day_of_week (:8).
- EXPECT_GE(DayOfWeek_MIN, INT8_MIN);
- EXPECT_LE(DayOfWeek_MAX, INT8_MAX);
-
- // For DateMatch::bc_ad (:8).
- EXPECT_GE(BCAD_MIN, INT8_MIN);
- EXPECT_LE(BCAD_MAX, INT8_MAX);
-
- // For DateMatch::time_span_code (:16).
- EXPECT_GE(TimespanCode_MIN, INT16_MIN);
- EXPECT_LE(TimespanCode_MAX, INT16_MAX);
-}
-
-TEST_F(DateMatchTest, IsValid) {
- // Valid: dates.
- {
- DateMatch d;
- SetDate(&d, 2014, 1, 26);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, 2014, 1, X);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, 2014, X, X);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, 1, 26);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, 1, X);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, X, 26);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, 2014, 1, 26, SUN());
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, 1, 26, SUN());
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, X, 26, SUN());
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, 2014, 1, 26, DOW_X(), BC());
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- // Valid: times.
- {
- DateMatch d;
- SetTimeValue(&d, 12, 30, 59, 0.99);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetTimeValue(&d, 12, 30, 59);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetTimeValue(&d, 12, 30);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetTimeValue(&d, 12);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- // Valid: mixed.
- {
- DateMatch d;
- SetDate(&d, 2014, 1, 26);
- SetTimeValue(&d, 12, 30, 59, 0.99);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, 1, 26);
- SetTimeValue(&d, 12, 30, 59);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, X, X, SUN());
- SetTimeValue(&d, 12, 30);
- EXPECT_TRUE(d.IsValid()) << d.DebugString();
- }
- // Invalid: dates.
- {
- DateMatch d;
- SetDate(&d, X, 1, 26, DOW_X(), BC());
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, 2014, X, 26);
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, 2014, X, X, SUN());
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetDate(&d, X, 1, X, SUN());
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- // Invalid: times.
- {
- DateMatch d;
- SetTimeValue(&d, 12, X, 59);
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetTimeValue(&d, 12, X, X, 0.99);
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetTimeValue(&d, 12, 30, X, 0.99);
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- {
- DateMatch d;
- SetTimeValue(&d, X, 30);
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- // Invalid: mixed.
- {
- DateMatch d;
- SetDate(&d, 2014, 1, X);
- SetTimeValue(&d, 12);
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
- // Invalid: empty.
- {
- DateMatch d;
- EXPECT_FALSE(d.IsValid()) << d.DebugString();
- }
-}
-
-std::string DebugStrings(const std::vector<DateMatch>& instances) {
- std::string res;
- for (int i = 0; i < instances.size(); ++i) {
- ::libtextclassifier3::strings::SStringAppendF(
- &res, 0, "[%d] == %s\n", i, instances[i].DebugString().c_str());
- }
- return res;
-}
-
-TEST_F(DateMatchTest, IsRefinement) {
- {
- DateMatch a;
- SetDate(&a, 2014, 2, X);
- DateMatch b;
- SetDate(&b, 2014, X, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- DateMatch b;
- SetDate(&b, 2014, 2, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- DateMatch b;
- SetDate(&b, X, 2, 24);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, X, X);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, 0, X);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- SetTimeValue(&b, 9, X, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, 0, 0);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- SetTimeValue(&b, 9, 0, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, X, X);
- SetTimeSpan(&a, TimespanCode_AM);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- SetTimeValue(&b, 9, X, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, X, X);
- SetTimeZone(&a, TimezoneCode_PST8PDT);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- SetTimeValue(&b, 9, X, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, X, X);
- a.priority += 10;
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- SetTimeValue(&b, 9, X, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, X, X);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- SetTimeValue(&b, 9, X, X);
- EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, 2014, 2, 24);
- SetTimeValue(&a, 9, X, X);
- DateMatch b;
- SetDate(&b, X, 2, 24);
- SetTimeValue(&b, 9, 0, X);
- EXPECT_FALSE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetDate(&a, X, 2, 24);
- SetTimeValue(&a, 9, X, X);
- DateMatch b;
- SetDate(&b, 2014, 2, 24);
- EXPECT_FALSE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
- {
- DateMatch a;
- SetTimeValue(&a, 9, 0, 0);
- DateMatch b;
- SetTimeValue(&b, 9, X, X);
- SetTimeSpan(&b, TimespanCode_AM);
- EXPECT_FALSE(IsRefinement(a, b)) << DebugStrings({a, b});
- }
-}
-
-TEST_F(DateMatchTest, FillDateInstance_AnnotatorPriorityScore) {
- DateMatch date_match;
- SetDate(&date_match, 2014, 2, X);
- date_match.annotator_priority_score = 0.5;
- DatetimeParseResultSpan datetime_parse_result_span;
- FillDateInstance(date_match, &datetime_parse_result_span);
- EXPECT_FLOAT_EQ(datetime_parse_result_span.priority_score, 0.5)
- << DebugStrings({date_match});
-}
-
-TEST_F(DateMatchTest, MergeDateMatch_AnnotatorPriorityScore) {
- DateMatch a;
- SetDate(&a, 2014, 2, 4);
- a.annotator_priority_score = 0.5;
-
- DateMatch b;
- SetTimeValue(&b, 10, 45, 23);
- b.annotator_priority_score = 1.0;
-
- MergeDateMatch(b, &a, false);
- EXPECT_FLOAT_EQ(a.annotator_priority_score, 1.0);
-}
-
-} // namespace
-} // namespace dates
-} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/dates/utils/date-utils.cc b/native/annotator/grammar/dates/utils/date-utils.cc
deleted file mode 100644
index ea8015d..0000000
--- a/native/annotator/grammar/dates/utils/date-utils.cc
+++ /dev/null
@@ -1,399 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/grammar/dates/utils/date-utils.h"
-
-#include <algorithm>
-#include <ctime>
-
-#include "annotator/grammar/dates/annotations/annotation-util.h"
-#include "annotator/grammar/dates/dates_generated.h"
-#include "annotator/grammar/dates/utils/annotation-keys.h"
-#include "annotator/grammar/dates/utils/date-match.h"
-#include "annotator/types.h"
-#include "utils/base/macros.h"
-
-namespace libtextclassifier3 {
-namespace dates {
-
-bool IsLeapYear(int year) {
- // For the sake of completeness, we want to be able to decide
- // whether a year is a leap year all the way back to 0 Julian, or
- // 4714 BCE. But we don't want to take the modulus of a negative
- // number, because this may not be very well-defined or portable. So
- // we increment the year by some large multiple of 400, which is the
- // periodicity of this leap-year calculation.
- if (year < 0) {
- year += 8000;
- }
- return ((year) % 4 == 0 && ((year) % 100 != 0 || (year) % 400 == 0));
-}
-
-namespace {
-#define SECSPERMIN (60)
-#define MINSPERHOUR (60)
-#define HOURSPERDAY (24)
-#define DAYSPERWEEK (7)
-#define DAYSPERNYEAR (365)
-#define DAYSPERLYEAR (366)
-#define MONSPERYEAR (12)
-
-const int8 kDaysPerMonth[2][1 + MONSPERYEAR] = {
- {-1, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31},
- {-1, 31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31},
-};
-} // namespace
-
-int8 GetLastDayOfMonth(int year, int month) {
- if (year == 0) { // No year specified
- return kDaysPerMonth[1][month];
- }
- return kDaysPerMonth[IsLeapYear(year)][month];
-}
-
-namespace {
-inline bool IsHourInSegment(const TimeSpanSpec_::Segment* segment, int8 hour,
- bool is_exact) {
- return (hour >= segment->begin() &&
- (hour < segment->end() ||
- (hour == segment->end() && is_exact && segment->is_closed())));
-}
-
-Property* FindOrCreateDefaultDateTime(AnnotationData* inst) {
- // Refer comments for kDateTime in annotation-keys.h to see the format.
- static constexpr int kDefault[] = {-1, -1, -1, -1, -1, -1, -1, -1};
-
- int idx = GetPropertyIndex(kDateTime, *inst);
- if (idx < 0) {
- idx = AddRepeatedIntProperty(kDateTime, kDefault, TC3_ARRAYSIZE(kDefault),
- inst);
- }
- return &inst->properties[idx];
-}
-
-void IncrementDayOfWeek(DayOfWeek* dow) {
- static const DayOfWeek dow_ring[] = {DayOfWeek_MONDAY, DayOfWeek_TUESDAY,
- DayOfWeek_WEDNESDAY, DayOfWeek_THURSDAY,
- DayOfWeek_FRIDAY, DayOfWeek_SATURDAY,
- DayOfWeek_SUNDAY, DayOfWeek_MONDAY};
- const auto& cur_dow =
- std::find(std::begin(dow_ring), std::end(dow_ring), *dow);
- if (cur_dow != std::end(dow_ring)) {
- *dow = *std::next(cur_dow);
- }
-}
-} // namespace
-
-bool NormalizeHourByTimeSpan(const TimeSpanSpec* ts_spec, DateMatch* date) {
- if (ts_spec->segment() == nullptr) {
- return false;
- }
- if (date->HasHour()) {
- const bool is_exact =
- (!date->HasMinute() ||
- (date->minute == 0 &&
- (!date->HasSecond() ||
- (date->second == 0 &&
- (!date->HasFractionSecond() || date->fraction_second == 0.0)))));
- for (const TimeSpanSpec_::Segment* segment : *ts_spec->segment()) {
- if (IsHourInSegment(segment, date->hour + segment->offset(), is_exact)) {
- date->hour += segment->offset();
- return true;
- }
- if (!segment->is_strict() &&
- IsHourInSegment(segment, date->hour, is_exact)) {
- return true;
- }
- }
- } else {
- for (const TimeSpanSpec_::Segment* segment : *ts_spec->segment()) {
- if (segment->is_stand_alone()) {
- if (segment->begin() == segment->end()) {
- date->hour = segment->begin();
- }
- // Allow stand-alone time-span points and ranges.
- return true;
- }
- }
- }
- return false;
-}
-
-bool IsRefinement(const DateMatch& a, const DateMatch& b) {
- int count = 0;
- if (b.HasBcAd()) {
- if (!a.HasBcAd() || a.bc_ad != b.bc_ad) return false;
- } else if (a.HasBcAd()) {
- if (a.bc_ad == BCAD_BC) return false;
- ++count;
- }
- if (b.HasYear()) {
- if (!a.HasYear() || a.year != b.year) return false;
- } else if (a.HasYear()) {
- ++count;
- }
- if (b.HasMonth()) {
- if (!a.HasMonth() || a.month != b.month) return false;
- } else if (a.HasMonth()) {
- ++count;
- }
- if (b.HasDay()) {
- if (!a.HasDay() || a.day != b.day) return false;
- } else if (a.HasDay()) {
- ++count;
- }
- if (b.HasDayOfWeek()) {
- if (!a.HasDayOfWeek() || a.day_of_week != b.day_of_week) return false;
- } else if (a.HasDayOfWeek()) {
- ++count;
- }
- if (b.HasHour()) {
- if (!a.HasHour()) return false;
- std::vector<int8> possible_hours;
- b.GetPossibleHourValues(&possible_hours);
- if (std::find(possible_hours.begin(), possible_hours.end(), a.hour) ==
- possible_hours.end()) {
- return false;
- }
- } else if (a.HasHour()) {
- ++count;
- }
- if (b.HasMinute()) {
- if (!a.HasMinute() || a.minute != b.minute) return false;
- } else if (a.HasMinute()) {
- ++count;
- }
- if (b.HasSecond()) {
- if (!a.HasSecond() || a.second != b.second) return false;
- } else if (a.HasSecond()) {
- ++count;
- }
- if (b.HasFractionSecond()) {
- if (!a.HasFractionSecond() || a.fraction_second != b.fraction_second)
- return false;
- } else if (a.HasFractionSecond()) {
- ++count;
- }
- if (b.HasTimeSpanCode()) {
- if (!a.HasTimeSpanCode() || a.time_span_code != b.time_span_code)
- return false;
- } else if (a.HasTimeSpanCode()) {
- ++count;
- }
- if (b.HasTimeZoneCode()) {
- if (!a.HasTimeZoneCode() || a.time_zone_code != b.time_zone_code)
- return false;
- } else if (a.HasTimeZoneCode()) {
- ++count;
- }
- if (b.HasTimeZoneOffset()) {
- if (!a.HasTimeZoneOffset() || a.time_zone_offset != b.time_zone_offset)
- return false;
- } else if (a.HasTimeZoneOffset()) {
- ++count;
- }
- return (count > 0 || a.priority >= b.priority);
-}
-
-bool IsRefinement(const DateRangeMatch& a, const DateRangeMatch& b) {
- return false;
-}
-
-bool IsPrecedent(const DateMatch& a, const DateMatch& b) {
- if (a.HasYear() && b.HasYear()) {
- if (a.year < b.year) return true;
- if (a.year > b.year) return false;
- }
-
- if (a.HasMonth() && b.HasMonth()) {
- if (a.month < b.month) return true;
- if (a.month > b.month) return false;
- }
-
- if (a.HasDay() && b.HasDay()) {
- if (a.day < b.day) return true;
- if (a.day > b.day) return false;
- }
-
- if (a.HasHour() && b.HasHour()) {
- if (a.hour < b.hour) return true;
- if (a.hour > b.hour) return false;
- }
-
- if (a.HasMinute() && b.HasHour()) {
- if (a.minute < b.hour) return true;
- if (a.minute > b.hour) return false;
- }
-
- if (a.HasSecond() && b.HasSecond()) {
- if (a.second < b.hour) return true;
- if (a.second > b.hour) return false;
- }
-
- return false;
-}
-
-void FillDateInstance(const DateMatch& date,
- DatetimeParseResultSpan* instance) {
- instance->span.first = date.begin;
- instance->span.second = date.end;
- instance->priority_score = date.GetAnnotatorPriorityScore();
- DatetimeParseResult datetime_parse_result;
- date.FillDatetimeComponents(&datetime_parse_result.datetime_components);
- instance->data.emplace_back(datetime_parse_result);
-}
-
-void FillDateRangeInstance(const DateRangeMatch& range,
- DatetimeParseResultSpan* instance) {
- instance->span.first = range.begin;
- instance->span.second = range.end;
- instance->priority_score = range.GetAnnotatorPriorityScore();
-
- // Filling from DatetimeParseResult.
- instance->data.emplace_back();
- range.from.FillDatetimeComponents(&instance->data.back().datetime_components);
-
- // Filling to DatetimeParseResult.
- instance->data.emplace_back();
- range.to.FillDatetimeComponents(&instance->data.back().datetime_components);
-}
-
-namespace {
-bool AnyOverlappedField(const DateMatch& prev, const DateMatch& next) {
-#define Field(f) \
- if (prev.f && next.f) return true
- Field(year_match);
- Field(month_match);
- Field(day_match);
- Field(day_of_week_match);
- Field(time_value_match);
- Field(time_span_match);
- Field(time_zone_name_match);
- Field(time_zone_offset_match);
- Field(relative_match);
- Field(combined_digits_match);
-#undef Field
- return false;
-}
-
-void MergeDateMatchImpl(const DateMatch& prev, DateMatch* next,
- bool update_span) {
-#define RM(f) \
- if (!next->f) next->f = prev.f
- RM(year_match);
- RM(month_match);
- RM(day_match);
- RM(day_of_week_match);
- RM(time_value_match);
- RM(time_span_match);
- RM(time_zone_name_match);
- RM(time_zone_offset_match);
- RM(relative_match);
- RM(combined_digits_match);
-#undef RM
-
-#define RV(f) \
- if (next->f == NO_VAL) next->f = prev.f
- RV(year);
- RV(month);
- RV(day);
- RV(hour);
- RV(minute);
- RV(second);
- RV(fraction_second);
-#undef RV
-
-#define RE(f, v) \
- if (next->f == v) next->f = prev.f
- RE(day_of_week, DayOfWeek_DOW_NONE);
- RE(bc_ad, BCAD_BCAD_NONE);
- RE(time_span_code, TimespanCode_TIMESPAN_CODE_NONE);
- RE(time_zone_code, TimezoneCode_TIMEZONE_CODE_NONE);
-#undef RE
-
- if (next->time_zone_offset == std::numeric_limits<int16>::min()) {
- next->time_zone_offset = prev.time_zone_offset;
- }
-
- next->priority = std::max(next->priority, prev.priority);
- next->annotator_priority_score =
- std::max(next->annotator_priority_score, prev.annotator_priority_score);
- if (update_span) {
- next->begin = std::min(next->begin, prev.begin);
- next->end = std::max(next->end, prev.end);
- }
-}
-} // namespace
-
-bool IsDateMatchMergeable(const DateMatch& prev, const DateMatch& next) {
- // Do not merge if they share the same field.
- if (AnyOverlappedField(prev, next)) {
- return false;
- }
-
- // It's impossible that both prev and next have relative date since it's
- // excluded by overlapping check before.
- if (prev.HasRelativeDate() || next.HasRelativeDate()) {
- // If one of them is relative date, then we merge:
- // - if relative match shouldn't have time, and always has DOW or day.
- // - if not both relative match and non relative match has day.
- // - if non relative match has time or day.
- const DateMatch* rm = &prev;
- const DateMatch* non_rm = &prev;
- if (prev.HasRelativeDate()) {
- non_rm = &next;
- } else {
- rm = &next;
- }
-
- const RelativeMatch* relative_match = rm->relative_match;
- // Relative Match should have day or DOW but no time.
- if (!relative_match->HasDayFields() ||
- relative_match->HasTimeValueFields()) {
- return false;
- }
- // Check if both relative match and non relative match has day.
- if (non_rm->HasDateFields() && relative_match->HasDay()) {
- return false;
- }
- // Non relative match should have either hour (time) or day (date).
- if (!non_rm->HasHour() && !non_rm->HasDay()) {
- return false;
- }
- } else {
- // Only one match has date and another has time.
- if ((prev.HasDateFields() && next.HasDateFields()) ||
- (prev.HasTimeFields() && next.HasTimeFields())) {
- return false;
- }
- // DOW never be extracted as a single DateMatch except in RelativeMatch. So
- // here, we always merge one with day and another one with hour.
- if (!(prev.HasDay() || next.HasDay()) ||
- !(prev.HasHour() || next.HasHour())) {
- return false;
- }
- }
- return true;
-}
-
-void MergeDateMatch(const DateMatch& prev, DateMatch* next, bool update_span) {
- if (IsDateMatchMergeable(prev, *next)) {
- MergeDateMatchImpl(prev, next, update_span);
- }
-}
-
-} // namespace dates
-} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/dates/utils/date-utils.h b/native/annotator/grammar/dates/utils/date-utils.h
deleted file mode 100644
index 2fcda92..0000000
--- a/native/annotator/grammar/dates/utils/date-utils.h
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_UTILS_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_UTILS_H_
-
-#include <stddef.h>
-#include <stdint.h>
-
-#include <ctime>
-#include <vector>
-
-#include "annotator/grammar/dates/annotations/annotation.h"
-#include "annotator/grammar/dates/utils/date-match.h"
-#include "utils/base/casts.h"
-
-namespace libtextclassifier3 {
-namespace dates {
-
-bool IsLeapYear(int year);
-
-int8 GetLastDayOfMonth(int year, int month);
-
-// Normalizes hour value of the specified date using the specified time-span
-// specification. Returns true if the original hour value (can be no-value)
-// is compatible with the time-span and gets normalized successfully, or
-// false otherwise.
-bool NormalizeHourByTimeSpan(const TimeSpanSpec* ts_spec, DateMatch* date);
-
-// Returns true iff "a" is considered as a refinement of "b". For example,
-// besides fully compatible fields, having more fields or higher priority.
-bool IsRefinement(const DateMatch& a, const DateMatch& b);
-bool IsRefinement(const DateRangeMatch& a, const DateRangeMatch& b);
-
-// Returns true iff "a" occurs strictly before "b"
-bool IsPrecedent(const DateMatch& a, const DateMatch& b);
-
-// Fill DatetimeParseResult based on DateMatch object which is created from
-// matched rule. The matched string is extracted from tokenizer which provides
-// an interface to access the clean text based on the matched range.
-void FillDateInstance(const DateMatch& date, DatetimeParseResult* instance);
-
-// Fill DatetimeParseResultSpan based on DateMatch object which is created from
-// matched rule. The matched string is extracted from tokenizer which provides
-// an interface to access the clean text based on the matched range.
-void FillDateInstance(const DateMatch& date, DatetimeParseResultSpan* instance);
-
-// Fill DatetimeParseResultSpan based on DateRangeMatch object which i screated
-// from matched rule.
-void FillDateRangeInstance(const DateRangeMatch& range,
- DatetimeParseResultSpan* instance);
-
-// Merge the fields in DateMatch prev to next if there is no overlapped field.
-// If update_span is true, the span of next is also updated.
-// e.g.: prev is 11am, next is: May 1, then the merged next is May 1, 11am
-void MergeDateMatch(const DateMatch& prev, DateMatch* next, bool update_span);
-
-// If DateMatches have no overlapped field, then they could be merged as the
-// following rules:
-// -- If both don't have relative match and one DateMatch has day but another
-// DateMatch has hour.
-// -- If one have relative match then follow the rules in code.
-// It's impossible to get DateMatch which only has DOW and not in relative
-// match according to current rules.
-bool IsDateMatchMergeable(const DateMatch& prev, const DateMatch& next);
-} // namespace dates
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_UTILS_H_
diff --git a/native/annotator/grammar/grammar-annotator.cc b/native/annotator/grammar/grammar-annotator.cc
index 756fcd5..cf36454 100644
--- a/native/annotator/grammar/grammar-annotator.cc
+++ b/native/annotator/grammar/grammar-annotator.cc
@@ -19,12 +19,8 @@
#include "annotator/feature-processor.h"
#include "annotator/grammar/utils.h"
#include "annotator/types.h"
+#include "utils/base/arena.h"
#include "utils/base/logging.h"
-#include "utils/grammar/callback-delegate.h"
-#include "utils/grammar/match.h"
-#include "utils/grammar/matcher.h"
-#include "utils/grammar/rules-utils.h"
-#include "utils/grammar/types.h"
#include "utils/normalization.h"
#include "utils/optional.h"
#include "utils/utf8/unicodetext.h"
@@ -32,447 +28,296 @@
namespace libtextclassifier3 {
namespace {
-// Returns the unicode codepoint offsets in a utf8 encoded text.
-std::vector<UnicodeText::const_iterator> UnicodeCodepointOffsets(
- const UnicodeText& text) {
- std::vector<UnicodeText::const_iterator> offsets;
- for (auto it = text.begin(); it != text.end(); it++) {
- offsets.push_back(it);
+// Retrieves all capturing nodes from a parse tree.
+std::unordered_map<uint16, const grammar::ParseTree*> GetCapturingNodes(
+ const grammar::ParseTree* parse_tree) {
+ std::unordered_map<uint16, const grammar::ParseTree*> capturing_nodes;
+ for (const grammar::MappingNode* mapping_node :
+ grammar::SelectAllOfType<grammar::MappingNode>(
+ parse_tree, grammar::ParseTree::Type::kMapping)) {
+ capturing_nodes[mapping_node->id] = mapping_node;
}
- offsets.push_back(text.end());
- return offsets;
+ return capturing_nodes;
+}
+
+// Computes the selection boundaries from a parse tree.
+CodepointSpan MatchSelectionBoundaries(
+ const grammar::ParseTree* parse_tree,
+ const GrammarModel_::RuleClassificationResult* classification) {
+ if (classification->capturing_group() == nullptr) {
+ // Use full match as selection span.
+ return parse_tree->codepoint_span;
+ }
+
+ // Set information from capturing matches.
+ CodepointSpan span{kInvalidIndex, kInvalidIndex};
+ std::unordered_map<uint16, const grammar::ParseTree*> capturing_nodes =
+ GetCapturingNodes(parse_tree);
+
+ // Compute span boundaries.
+ for (int i = 0; i < classification->capturing_group()->size(); i++) {
+ auto it = capturing_nodes.find(i);
+ if (it == capturing_nodes.end()) {
+ // Capturing group is not active, skip.
+ continue;
+ }
+ const CapturingGroup* group = classification->capturing_group()->Get(i);
+ if (group->extend_selection()) {
+ if (span.first == kInvalidIndex) {
+ span = it->second->codepoint_span;
+ } else {
+ span.first = std::min(span.first, it->second->codepoint_span.first);
+ span.second = std::max(span.second, it->second->codepoint_span.second);
+ }
+ }
+ }
+ return span;
}
} // namespace
-class GrammarAnnotatorCallbackDelegate : public grammar::CallbackDelegate {
- public:
- explicit GrammarAnnotatorCallbackDelegate(
- const UniLib* unilib, const GrammarModel* model,
- const MutableFlatbufferBuilder* entity_data_builder, const ModeFlag mode)
- : unilib_(*unilib),
- model_(model),
- entity_data_builder_(entity_data_builder),
- mode_(mode) {}
-
- // Handles a grammar rule match in the annotator grammar.
- void MatchFound(const grammar::Match* match, grammar::CallbackId type,
- int64 value, grammar::Matcher* matcher) override {
- switch (static_cast<GrammarAnnotator::Callback>(type)) {
- case GrammarAnnotator::Callback::kRuleMatch: {
- HandleRuleMatch(match, /*rule_id=*/value);
- return;
- }
- default:
- grammar::CallbackDelegate::MatchFound(match, type, value, matcher);
- }
- }
-
- // Deduplicate and populate annotations from grammar matches.
- bool GetAnnotations(const std::vector<UnicodeText::const_iterator>& text,
- std::vector<AnnotatedSpan>* annotations) const {
- for (const grammar::Derivation& candidate :
- grammar::DeduplicateDerivations(candidates_)) {
- // Check that assertions are fulfilled.
- if (!grammar::VerifyAssertions(candidate.match)) {
- continue;
- }
- if (!AddAnnotatedSpanFromMatch(text, candidate, annotations)) {
- return false;
- }
- }
- return true;
- }
-
- bool GetTextSelection(const std::vector<UnicodeText::const_iterator>& text,
- const CodepointSpan& selection, AnnotatedSpan* result) {
- std::vector<grammar::Derivation> selection_candidates;
- // Deduplicate and verify matches.
- auto maybe_interpretation = GetBestValidInterpretation(
- grammar::DeduplicateDerivations(GetOverlappingRuleMatches(
- selection, candidates_, /*only_exact_overlap=*/false)));
- if (!maybe_interpretation.has_value()) {
- return false;
- }
- const GrammarModel_::RuleClassificationResult* interpretation;
- const grammar::Match* match;
- std::tie(interpretation, match) = maybe_interpretation.value();
- return InstantiateAnnotatedSpanFromInterpretation(text, interpretation,
- match, result);
- }
-
- // Provides a classification results from the grammar matches.
- bool GetClassification(const std::vector<UnicodeText::const_iterator>& text,
- const CodepointSpan& selection,
- ClassificationResult* classification) const {
- // Deduplicate and verify matches.
- auto maybe_interpretation = GetBestValidInterpretation(
- grammar::DeduplicateDerivations(GetOverlappingRuleMatches(
- selection, candidates_, /*only_exact_overlap=*/true)));
- if (!maybe_interpretation.has_value()) {
- return false;
- }
-
- // Instantiate result.
- const GrammarModel_::RuleClassificationResult* interpretation;
- const grammar::Match* match;
- std::tie(interpretation, match) = maybe_interpretation.value();
- return InstantiateClassificationInterpretation(text, interpretation, match,
- classification);
- }
-
- private:
- // Handles annotation/selection/classification rule matches.
- void HandleRuleMatch(const grammar::Match* match, const int64 rule_id) {
- if ((model_->rule_classification_result()->Get(rule_id)->enabled_modes() &
- mode_) != 0) {
- candidates_.push_back(grammar::Derivation{match, rule_id});
- }
- }
-
- // Computes the selection boundaries from a grammar match.
- CodepointSpan MatchSelectionBoundaries(
- const grammar::Match* match,
- const GrammarModel_::RuleClassificationResult* classification) const {
- if (classification->capturing_group() == nullptr) {
- // Use full match as selection span.
- return match->codepoint_span;
- }
-
- // Set information from capturing matches.
- CodepointSpan span{kInvalidIndex, kInvalidIndex};
- // Gather active capturing matches.
- std::unordered_map<uint16, const grammar::Match*> capturing_matches;
- for (const grammar::MappingMatch* match :
- grammar::SelectAllOfType<grammar::MappingMatch>(
- match, grammar::Match::kMappingMatch)) {
- capturing_matches[match->id] = match;
- }
-
- // Compute span boundaries.
- for (int i = 0; i < classification->capturing_group()->size(); i++) {
- auto it = capturing_matches.find(i);
- if (it == capturing_matches.end()) {
- // Capturing group is not active, skip.
- continue;
- }
- const CapturingGroup* group = classification->capturing_group()->Get(i);
- if (group->extend_selection()) {
- if (span.first == kInvalidIndex) {
- span = it->second->codepoint_span;
- } else {
- span.first = std::min(span.first, it->second->codepoint_span.first);
- span.second =
- std::max(span.second, it->second->codepoint_span.second);
- }
- }
- }
- return span;
- }
-
- // Filters out results that do not overlap with a reference span.
- std::vector<grammar::Derivation> GetOverlappingRuleMatches(
- const CodepointSpan& selection,
- const std::vector<grammar::Derivation>& candidates,
- const bool only_exact_overlap) const {
- std::vector<grammar::Derivation> result;
- for (const grammar::Derivation& candidate : candidates) {
- // Discard matches that do not match the selection.
- // Simple check.
- if (!SpansOverlap(selection, candidate.match->codepoint_span)) {
- continue;
- }
-
- // Compute exact selection boundaries (without assertions and
- // non-capturing parts).
- const CodepointSpan span = MatchSelectionBoundaries(
- candidate.match,
- model_->rule_classification_result()->Get(candidate.rule_id));
- if (!SpansOverlap(selection, span) ||
- (only_exact_overlap && span != selection)) {
- continue;
- }
- result.push_back(candidate);
- }
- return result;
- }
-
- // Returns the best valid interpretation of a set of candidate matches.
- Optional<std::pair<const GrammarModel_::RuleClassificationResult*,
- const grammar::Match*>>
- GetBestValidInterpretation(
- const std::vector<grammar::Derivation>& candidates) const {
- const GrammarModel_::RuleClassificationResult* best_interpretation =
- nullptr;
- const grammar::Match* best_match = nullptr;
- for (const grammar::Derivation& candidate : candidates) {
- if (!grammar::VerifyAssertions(candidate.match)) {
- continue;
- }
- const GrammarModel_::RuleClassificationResult*
- rule_classification_result =
- model_->rule_classification_result()->Get(candidate.rule_id);
- if (best_interpretation == nullptr ||
- best_interpretation->priority_score() <
- rule_classification_result->priority_score()) {
- best_interpretation = rule_classification_result;
- best_match = candidate.match;
- }
- }
-
- // No valid interpretation found.
- Optional<std::pair<const GrammarModel_::RuleClassificationResult*,
- const grammar::Match*>>
- result;
- if (best_interpretation != nullptr) {
- result = {best_interpretation, best_match};
- }
- return result;
- }
-
- // Instantiates an annotated span from a rule match and appends it to the
- // result.
- bool AddAnnotatedSpanFromMatch(
- const std::vector<UnicodeText::const_iterator>& text,
- const grammar::Derivation& candidate,
- std::vector<AnnotatedSpan>* result) const {
- if (candidate.rule_id < 0 ||
- candidate.rule_id >= model_->rule_classification_result()->size()) {
- TC3_LOG(INFO) << "Invalid rule id.";
- return false;
- }
- const GrammarModel_::RuleClassificationResult* interpretation =
- model_->rule_classification_result()->Get(candidate.rule_id);
- result->emplace_back();
- return InstantiateAnnotatedSpanFromInterpretation(
- text, interpretation, candidate.match, &result->back());
- }
-
- bool InstantiateAnnotatedSpanFromInterpretation(
- const std::vector<UnicodeText::const_iterator>& text,
- const GrammarModel_::RuleClassificationResult* interpretation,
- const grammar::Match* match, AnnotatedSpan* result) const {
- result->span = MatchSelectionBoundaries(match, interpretation);
- ClassificationResult classification;
- if (!InstantiateClassificationInterpretation(text, interpretation, match,
- &classification)) {
- return false;
- }
- result->classification.push_back(classification);
- return true;
- }
-
- // Instantiates a classification result from a rule match.
- bool InstantiateClassificationInterpretation(
- const std::vector<UnicodeText::const_iterator>& text,
- const GrammarModel_::RuleClassificationResult* interpretation,
- const grammar::Match* match, ClassificationResult* classification) const {
- classification->collection = interpretation->collection_name()->str();
- classification->score = interpretation->target_classification_score();
- classification->priority_score = interpretation->priority_score();
-
- // Assemble entity data.
- if (entity_data_builder_ == nullptr) {
- return true;
- }
- std::unique_ptr<MutableFlatbuffer> entity_data =
- entity_data_builder_->NewRoot();
- if (interpretation->serialized_entity_data() != nullptr) {
- entity_data->MergeFromSerializedFlatbuffer(
- StringPiece(interpretation->serialized_entity_data()->data(),
- interpretation->serialized_entity_data()->size()));
- }
- if (interpretation->entity_data() != nullptr) {
- entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
- interpretation->entity_data()));
- }
-
- // Populate entity data from the capturing matches.
- if (interpretation->capturing_group() != nullptr) {
- // Gather active capturing matches.
- std::unordered_map<uint16, const grammar::Match*> capturing_matches;
- for (const grammar::MappingMatch* match :
- grammar::SelectAllOfType<grammar::MappingMatch>(
- match, grammar::Match::kMappingMatch)) {
- capturing_matches[match->id] = match;
- }
- for (int i = 0; i < interpretation->capturing_group()->size(); i++) {
- auto it = capturing_matches.find(i);
- if (it == capturing_matches.end()) {
- // Capturing group is not active, skip.
- continue;
- }
- const CapturingGroup* group = interpretation->capturing_group()->Get(i);
-
- // Add static entity data.
- if (group->serialized_entity_data() != nullptr) {
- entity_data->MergeFromSerializedFlatbuffer(
- StringPiece(interpretation->serialized_entity_data()->data(),
- interpretation->serialized_entity_data()->size()));
- }
-
- // Set entity field from captured text.
- if (group->entity_field_path() != nullptr) {
- const grammar::Match* capturing_match = it->second;
- StringPiece group_text = StringPiece(
- text[capturing_match->codepoint_span.first].utf8_data(),
- text[capturing_match->codepoint_span.second].utf8_data() -
- text[capturing_match->codepoint_span.first].utf8_data());
- UnicodeText normalized_group_text =
- UTF8ToUnicodeText(group_text, /*do_copy=*/false);
- if (group->normalization_options() != nullptr) {
- normalized_group_text = NormalizeText(
- unilib_, group->normalization_options(), normalized_group_text);
- }
- if (!entity_data->ParseAndSet(group->entity_field_path(),
- normalized_group_text.ToUTF8String())) {
- TC3_LOG(ERROR) << "Could not set entity data from capturing match.";
- return false;
- }
- }
- }
- }
-
- if (entity_data && entity_data->HasExplicitlySetFields()) {
- classification->serialized_entity_data = entity_data->Serialize();
- }
- return true;
- }
-
- const UniLib& unilib_;
- const GrammarModel* model_;
- const MutableFlatbufferBuilder* entity_data_builder_;
- const ModeFlag mode_;
-
- // All annotation/selection/classification rule match candidates.
- // Grammar rule matches are recorded, deduplicated and then instantiated.
- std::vector<grammar::Derivation> candidates_;
-};
-
GrammarAnnotator::GrammarAnnotator(
const UniLib* unilib, const GrammarModel* model,
const MutableFlatbufferBuilder* entity_data_builder)
: unilib_(*unilib),
model_(model),
- lexer_(unilib, model->rules()),
tokenizer_(BuildTokenizer(unilib, model->tokenizer_options())),
entity_data_builder_(entity_data_builder),
- rules_locales_(grammar::ParseRulesLocales(model->rules())) {}
+ analyzer_(unilib, model->rules(), &tokenizer_) {}
+
+// Filters out results that do not overlap with a reference span.
+std::vector<grammar::Derivation> GrammarAnnotator::OverlappingDerivations(
+ const CodepointSpan& selection,
+ const std::vector<grammar::Derivation>& derivations,
+ const bool only_exact_overlap) const {
+ std::vector<grammar::Derivation> result;
+ for (const grammar::Derivation& derivation : derivations) {
+ // Discard matches that do not match the selection.
+ // Simple check.
+ if (!SpansOverlap(selection, derivation.parse_tree->codepoint_span)) {
+ continue;
+ }
+
+ // Compute exact selection boundaries (without assertions and
+ // non-capturing parts).
+ const CodepointSpan span = MatchSelectionBoundaries(
+ derivation.parse_tree,
+ model_->rule_classification_result()->Get(derivation.rule_id));
+ if (!SpansOverlap(selection, span) ||
+ (only_exact_overlap && span != selection)) {
+ continue;
+ }
+ result.push_back(derivation);
+ }
+ return result;
+}
+
+bool GrammarAnnotator::InstantiateAnnotatedSpanFromDerivation(
+ const grammar::TextContext& input_context,
+ const grammar::ParseTree* parse_tree,
+ const GrammarModel_::RuleClassificationResult* interpretation,
+ AnnotatedSpan* result) const {
+ result->span = MatchSelectionBoundaries(parse_tree, interpretation);
+ ClassificationResult classification;
+ if (!InstantiateClassificationFromDerivation(
+ input_context, parse_tree, interpretation, &classification)) {
+ return false;
+ }
+ result->classification.push_back(classification);
+ return true;
+}
+
+// Instantiates a classification result from a rule match.
+bool GrammarAnnotator::InstantiateClassificationFromDerivation(
+ const grammar::TextContext& input_context,
+ const grammar::ParseTree* parse_tree,
+ const GrammarModel_::RuleClassificationResult* interpretation,
+ ClassificationResult* classification) const {
+ classification->collection = interpretation->collection_name()->str();
+ classification->score = interpretation->target_classification_score();
+ classification->priority_score = interpretation->priority_score();
+
+ // Assemble entity data.
+ if (entity_data_builder_ == nullptr) {
+ return true;
+ }
+ std::unique_ptr<MutableFlatbuffer> entity_data =
+ entity_data_builder_->NewRoot();
+ if (interpretation->serialized_entity_data() != nullptr) {
+ entity_data->MergeFromSerializedFlatbuffer(
+ StringPiece(interpretation->serialized_entity_data()->data(),
+ interpretation->serialized_entity_data()->size()));
+ }
+ if (interpretation->entity_data() != nullptr) {
+ entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
+ interpretation->entity_data()));
+ }
+
+ // Populate entity data from the capturing matches.
+ if (interpretation->capturing_group() != nullptr) {
+ // Gather active capturing matches.
+ std::unordered_map<uint16, const grammar::ParseTree*> capturing_nodes =
+ GetCapturingNodes(parse_tree);
+
+ for (int i = 0; i < interpretation->capturing_group()->size(); i++) {
+ auto it = capturing_nodes.find(i);
+ if (it == capturing_nodes.end()) {
+ // Capturing group is not active, skip.
+ continue;
+ }
+ const CapturingGroup* group = interpretation->capturing_group()->Get(i);
+
+ // Add static entity data.
+ if (group->serialized_entity_data() != nullptr) {
+ entity_data->MergeFromSerializedFlatbuffer(
+ StringPiece(interpretation->serialized_entity_data()->data(),
+ interpretation->serialized_entity_data()->size()));
+ }
+
+ // Set entity field from captured text.
+ if (group->entity_field_path() != nullptr) {
+ const grammar::ParseTree* capturing_match = it->second;
+ UnicodeText match_text =
+ input_context.Span(capturing_match->codepoint_span);
+ if (group->normalization_options() != nullptr) {
+ match_text = NormalizeText(unilib_, group->normalization_options(),
+ match_text);
+ }
+ if (!entity_data->ParseAndSet(group->entity_field_path(),
+ match_text.ToUTF8String())) {
+ TC3_LOG(ERROR) << "Could not set entity data from capturing match.";
+ return false;
+ }
+ }
+ }
+ }
+
+ if (entity_data && entity_data->HasExplicitlySetFields()) {
+ classification->serialized_entity_data = entity_data->Serialize();
+ }
+ return true;
+}
bool GrammarAnnotator::Annotate(const std::vector<Locale>& locales,
const UnicodeText& text,
std::vector<AnnotatedSpan>* result) const {
- if (model_ == nullptr || model_->rules() == nullptr) {
- // Nothing to do.
- return true;
+ grammar::TextContext input_context =
+ analyzer_.BuildTextContextForInput(text, locales);
+
+ UnsafeArena arena(/*block_size=*/16 << 10);
+
+ for (const grammar::Derivation& derivation : ValidDeduplicatedDerivations(
+ analyzer_.parser().Parse(input_context, &arena))) {
+ const GrammarModel_::RuleClassificationResult* interpretation =
+ model_->rule_classification_result()->Get(derivation.rule_id);
+ if ((interpretation->enabled_modes() & ModeFlag_ANNOTATION) == 0) {
+ continue;
+ }
+ result->emplace_back();
+ if (!InstantiateAnnotatedSpanFromDerivation(
+ input_context, derivation.parse_tree, interpretation,
+ &result->back())) {
+ return false;
+ }
}
- // Select locale matching rules.
- std::vector<const grammar::RulesSet_::Rules*> locale_rules =
- SelectLocaleMatchingShards(model_->rules(), rules_locales_, locales);
- if (locale_rules.empty()) {
- // Nothing to do.
- return true;
- }
-
- // Run the grammar.
- GrammarAnnotatorCallbackDelegate callback_handler(
- &unilib_, model_, entity_data_builder_,
- /*mode=*/ModeFlag_ANNOTATION);
- grammar::Matcher matcher(&unilib_, model_->rules(), locale_rules,
- &callback_handler);
- lexer_.Process(text, tokenizer_.Tokenize(text), /*annotations=*/nullptr,
- &matcher);
-
- // Populate results.
- return callback_handler.GetAnnotations(UnicodeCodepointOffsets(text), result);
+ return true;
}
bool GrammarAnnotator::SuggestSelection(const std::vector<Locale>& locales,
const UnicodeText& text,
const CodepointSpan& selection,
AnnotatedSpan* result) const {
- if (model_ == nullptr || model_->rules() == nullptr ||
- selection == CodepointSpan{kInvalidIndex, kInvalidIndex}) {
- // Nothing to do.
+ if (!selection.IsValid() || selection.IsEmpty()) {
return false;
}
- // Select locale matching rules.
- std::vector<const grammar::RulesSet_::Rules*> locale_rules =
- SelectLocaleMatchingShards(model_->rules(), rules_locales_, locales);
- if (locale_rules.empty()) {
- // Nothing to do.
- return true;
+ grammar::TextContext input_context =
+ analyzer_.BuildTextContextForInput(text, locales);
+
+ UnsafeArena arena(/*block_size=*/16 << 10);
+
+ const GrammarModel_::RuleClassificationResult* best_interpretation = nullptr;
+ const grammar::ParseTree* best_match = nullptr;
+ for (const grammar::Derivation& derivation :
+ ValidDeduplicatedDerivations(OverlappingDerivations(
+ selection, analyzer_.parser().Parse(input_context, &arena),
+ /*only_exact_overlap=*/false))) {
+ const GrammarModel_::RuleClassificationResult* interpretation =
+ model_->rule_classification_result()->Get(derivation.rule_id);
+ if ((interpretation->enabled_modes() & ModeFlag_SELECTION) == 0) {
+ continue;
+ }
+ if (best_interpretation == nullptr ||
+ interpretation->priority_score() >
+ best_interpretation->priority_score()) {
+ best_interpretation = interpretation;
+ best_match = derivation.parse_tree;
+ }
}
- // Run the grammar.
- GrammarAnnotatorCallbackDelegate callback_handler(
- &unilib_, model_, entity_data_builder_,
- /*mode=*/ModeFlag_SELECTION);
- grammar::Matcher matcher(&unilib_, model_->rules(), locale_rules,
- &callback_handler);
- lexer_.Process(text, tokenizer_.Tokenize(text), /*annotations=*/nullptr,
- &matcher);
+ if (best_interpretation == nullptr) {
+ return false;
+ }
- // Populate the result.
- return callback_handler.GetTextSelection(UnicodeCodepointOffsets(text),
- selection, result);
+ return InstantiateAnnotatedSpanFromDerivation(input_context, best_match,
+ best_interpretation, result);
}
bool GrammarAnnotator::ClassifyText(
const std::vector<Locale>& locales, const UnicodeText& text,
const CodepointSpan& selection,
ClassificationResult* classification_result) const {
- if (model_ == nullptr || model_->rules() == nullptr ||
- selection == CodepointSpan{kInvalidIndex, kInvalidIndex}) {
+ if (!selection.IsValid() || selection.IsEmpty()) {
// Nothing to do.
return false;
}
- // Select locale matching rules.
- std::vector<const grammar::RulesSet_::Rules*> locale_rules =
- SelectLocaleMatchingShards(model_->rules(), rules_locales_, locales);
- if (locale_rules.empty()) {
- // Nothing to do.
+ grammar::TextContext input_context =
+ analyzer_.BuildTextContextForInput(text, locales);
+
+ if (const TokenSpan context_span = CodepointSpanToTokenSpan(
+ input_context.tokens, selection,
+ /*snap_boundaries_to_containing_tokens=*/true);
+ context_span.IsValid()) {
+ if (model_->context_left_num_tokens() != kInvalidIndex) {
+ input_context.context_span.first =
+ std::max(0, context_span.first - model_->context_left_num_tokens());
+ }
+ if (model_->context_right_num_tokens() != kInvalidIndex) {
+ input_context.context_span.second =
+ std::min(static_cast<int>(input_context.tokens.size()),
+ context_span.second + model_->context_right_num_tokens());
+ }
+ }
+
+ UnsafeArena arena(/*block_size=*/16 << 10);
+
+ const GrammarModel_::RuleClassificationResult* best_interpretation = nullptr;
+ const grammar::ParseTree* best_match = nullptr;
+ for (const grammar::Derivation& derivation :
+ ValidDeduplicatedDerivations(OverlappingDerivations(
+ selection, analyzer_.parser().Parse(input_context, &arena),
+ /*only_exact_overlap=*/true))) {
+ const GrammarModel_::RuleClassificationResult* interpretation =
+ model_->rule_classification_result()->Get(derivation.rule_id);
+ if ((interpretation->enabled_modes() & ModeFlag_CLASSIFICATION) == 0) {
+ continue;
+ }
+ if (best_interpretation == nullptr ||
+ interpretation->priority_score() >
+ best_interpretation->priority_score()) {
+ best_interpretation = interpretation;
+ best_match = derivation.parse_tree;
+ }
+ }
+
+ if (best_interpretation == nullptr) {
return false;
}
- // Run the grammar.
- GrammarAnnotatorCallbackDelegate callback_handler(
- &unilib_, model_, entity_data_builder_,
- /*mode=*/ModeFlag_CLASSIFICATION);
- grammar::Matcher matcher(&unilib_, model_->rules(), locale_rules,
- &callback_handler);
-
- const std::vector<Token> tokens = tokenizer_.Tokenize(text);
- if (model_->context_left_num_tokens() == -1 &&
- model_->context_right_num_tokens() == -1) {
- // Use all tokens.
- lexer_.Process(text, tokens, /*annotations=*/{}, &matcher);
- } else {
- TokenSpan context_span = CodepointSpanToTokenSpan(
- tokens, selection, /*snap_boundaries_to_containing_tokens=*/true);
- std::vector<Token>::const_iterator begin = tokens.begin();
- std::vector<Token>::const_iterator end = tokens.begin();
- if (model_->context_left_num_tokens() != -1) {
- std::advance(begin, std::max(0, context_span.first -
- model_->context_left_num_tokens()));
- }
- if (model_->context_right_num_tokens() == -1) {
- end = tokens.end();
- } else {
- std::advance(end, std::min(static_cast<int>(tokens.size()),
- context_span.second +
- model_->context_right_num_tokens()));
- }
- lexer_.Process(text, begin, end,
- /*annotations=*/nullptr, &matcher);
- }
-
- // Populate result.
- return callback_handler.GetClassification(UnicodeCodepointOffsets(text),
- selection, classification_result);
+ return InstantiateClassificationFromDerivation(
+ input_context, best_match, best_interpretation, classification_result);
}
} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/grammar-annotator.h b/native/annotator/grammar/grammar-annotator.h
index b9ef62c..251b557 100644
--- a/native/annotator/grammar/grammar-annotator.h
+++ b/native/annotator/grammar/grammar-annotator.h
@@ -22,7 +22,9 @@
#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/flatbuffers/mutable.h"
-#include "utils/grammar/lexer.h"
+#include "utils/grammar/analyzer.h"
+#include "utils/grammar/evaluated-derivation.h"
+#include "utils/grammar/text-context.h"
#include "utils/i18n/locale.h"
#include "utils/tokenizer.h"
#include "utils/utf8/unicodetext.h"
@@ -33,10 +35,6 @@
// Grammar backed annotator.
class GrammarAnnotator {
public:
- enum class Callback : grammar::CallbackId {
- kRuleMatch = 1,
- };
-
explicit GrammarAnnotator(
const UniLib* unilib, const GrammarModel* model,
const MutableFlatbufferBuilder* entity_data_builder);
@@ -59,14 +57,31 @@
AnnotatedSpan* result) const;
private:
+ // Filters out derivations that do not overlap with a reference span.
+ std::vector<grammar::Derivation> OverlappingDerivations(
+ const CodepointSpan& selection,
+ const std::vector<grammar::Derivation>& derivations,
+ const bool only_exact_overlap) const;
+
+ // Fills out an annotated span from a grammar match result.
+ bool InstantiateAnnotatedSpanFromDerivation(
+ const grammar::TextContext& input_context,
+ const grammar::ParseTree* parse_tree,
+ const GrammarModel_::RuleClassificationResult* interpretation,
+ AnnotatedSpan* result) const;
+
+ // Instantiates a classification result from a rule match.
+ bool InstantiateClassificationFromDerivation(
+ const grammar::TextContext& input_context,
+ const grammar::ParseTree* parse_tree,
+ const GrammarModel_::RuleClassificationResult* interpretation,
+ ClassificationResult* classification) const;
+
const UniLib& unilib_;
const GrammarModel* model_;
- const grammar::Lexer lexer_;
const Tokenizer tokenizer_;
const MutableFlatbufferBuilder* entity_data_builder_;
-
- // Pre-parsed locales of the rules.
- const std::vector<std::vector<Locale>> rules_locales_;
+ const grammar::Analyzer analyzer_;
};
} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/grammar-annotator_test.cc b/native/annotator/grammar/grammar-annotator_test.cc
index 39ee950..b2084cb 100644
--- a/native/annotator/grammar/grammar-annotator_test.cc
+++ b/native/annotator/grammar/grammar-annotator_test.cc
@@ -54,7 +54,7 @@
rules.Add(
"<flight>", {"<carrier>", "<flight_code>"},
/*callback=*/
- static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
/*callback_param=*/
AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model));
rules.Finalize().Serialize(/*include_debug_information=*/false,
@@ -90,7 +90,7 @@
rules.Add(
"<flight>", {"<carrier>", "<flight_code>", "<context_assertion>?"},
/*callback=*/
- static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
/*callback_param=*/
AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model));
@@ -136,7 +136,7 @@
rules.Add(
"<phone>", {"please", "call", "<low_confidence_phone>"},
/*callback=*/
- static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
/*callback_param=*/classification_result_id);
rules.Finalize().Serialize(/*include_debug_information=*/false,
@@ -166,7 +166,7 @@
rules.Add(
"<flight>", {"<carrier>", "<flight_code>"},
/*callback=*/
- static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
/*callback_param=*/
AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model));
rules.Finalize().Serialize(/*include_debug_information=*/false,
@@ -209,7 +209,7 @@
rules.Add(
"<flight>", {"<flight_selection>", "<context_assertion>?"},
/*callback=*/
- static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
/*callback_param=*/
classification_result_id);
@@ -271,7 +271,7 @@
"<parcel_tracking>",
{"<parcel_tracking_trigger>", "<captured_tracking_number>"},
/*callback=*/
- static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
/*callback_param=*/
classification_result_id);
@@ -315,7 +315,7 @@
rules.Add(
"<flight>", {"<carrier>", "<flight_code>"},
/*callback=*/
- static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
/*callback_param=*/
AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model));
rules.Finalize().Serialize(/*include_debug_information=*/false,
@@ -344,7 +344,7 @@
rules.Add(
"<person>", {"barack", "obama"},
/*callback=*/
- static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
/*callback_param=*/person_result);
// Add test entity data.
@@ -393,7 +393,7 @@
rules.Add(
"<test>", {"<captured_person>"},
/*callback=*/
- static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
/*callback_param=*/person_result);
// Set capturing group entity data information.
@@ -451,13 +451,13 @@
rules.Add(
"<flight>", {"<annotation_carrier>", "<flight_code>"},
/*callback=*/
- static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
/*callback_param=*/
AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model));
rules.Add(
"<flight>", {"<selection_carrier>", "<flight_code>"},
/*callback=*/
- static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
/*callback_param=*/
AddRuleClassificationResult("flight",
ModeFlag_CLASSIFICATION_AND_SELECTION, 1.0,
@@ -465,7 +465,7 @@
rules.Add(
"<flight>", {"<classification_carrier>", "<flight_code>"},
/*callback=*/
- static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
/*callback_param=*/
AddRuleClassificationResult("flight", ModeFlag_CLASSIFICATION, 1.0,
&grammar_model));
diff --git a/native/annotator/knowledge/knowledge-engine-dummy.h b/native/annotator/knowledge/knowledge-engine-dummy.h
index 2a53288..615ad06 100644
--- a/native/annotator/knowledge/knowledge-engine-dummy.h
+++ b/native/annotator/knowledge/knowledge-engine-dummy.h
@@ -52,12 +52,13 @@
return true;
}
- Status ChunkMultipleSpans(const std::vector<std::string>& text_fragments,
- AnnotationUsecase annotation_usecase,
- const Optional<LocationContext>& location_context,
- const Permissions& permissions,
- const AnnotateMode annotate_mode,
- Annotations* results) const {
+ Status ChunkMultipleSpans(
+ const std::vector<std::string>& text_fragments,
+ const std::vector<FragmentMetadata>& fragment_metadata,
+ AnnotationUsecase annotation_usecase,
+ const Optional<LocationContext>& location_context,
+ const Permissions& permissions, const AnnotateMode annotate_mode,
+ Annotations* results) const {
return Status::OK;
}
diff --git a/native/annotator/knowledge/knowledge-engine-types.h b/native/annotator/knowledge/knowledge-engine-types.h
index 9508c7b..04b71cb 100644
--- a/native/annotator/knowledge/knowledge-engine-types.h
+++ b/native/annotator/knowledge/knowledge-engine-types.h
@@ -21,6 +21,11 @@
enum AnnotateMode { kEntityAnnotation, kEntityAndTopicalityAnnotation };
+struct FragmentMetadata {
+ float relative_bounding_box_top;
+ float relative_bounding_box_height;
+};
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_TYPES_H_
diff --git a/native/annotator/model.fbs b/native/annotator/model.fbs
index b279cc5..dbbb422 100755
--- a/native/annotator/model.fbs
+++ b/native/annotator/model.fbs
@@ -15,17 +15,16 @@
//
include "utils/intents/intent-config.fbs";
-include "annotator/grammar/dates/dates.fbs";
+include "annotator/experimental/experimental.fbs";
+include "annotator/entity-data.fbs";
+include "utils/grammar/rules.fbs";
include "utils/normalization.fbs";
include "utils/tokenizer.fbs";
-include "utils/grammar/rules.fbs";
include "utils/resources.fbs";
-include "utils/zlib/buffer.fbs";
-include "utils/container/bit-vector.fbs";
-include "annotator/entity-data.fbs";
-include "annotator/experimental/experimental.fbs";
include "utils/codepoint-range.fbs";
include "utils/flatbuffers/flatbuffers.fbs";
+include "utils/zlib/buffer.fbs";
+include "utils/container/bit-vector.fbs";
file_identifier "TC2 ";
@@ -374,85 +373,6 @@
tokenize_on_script_change:bool = false;
}
-// Options for grammar date/datetime/date range annotations.
-namespace libtextclassifier3.GrammarDatetimeModel_;
-table AnnotationOptions {
- // If enabled, extract special day offset like today, yesterday, etc.
- enable_special_day_offset:bool = true;
-
- // If true, merge the adjacent day of week, time and date. e.g.
- // "20/2/2016 at 8pm" is extracted as a single instance instead of two
- // instance: "20/2/2016" and "8pm".
- merge_adjacent_components:bool = true;
-
- // List the extra id of requested dates.
- extra_requested_dates:[string];
-
- // If true, try to include preposition to the extracted annotation. e.g.
- // "at 6pm". if it's false, only 6pm is included. offline-actions has
- // special requirements to include preposition.
- include_preposition:bool = true;
-
- // If enabled, extract range in date annotator.
- // input: Monday, 5-6pm
- // If the flag is true, The extracted annotation only contains 1 range
- // instance which is from Monday 5pm to 6pm.
- // If the flag is false, The extracted annotation contains two date
- // instance: "Monday" and "6pm".
- enable_date_range:bool = true;
- reserved_6:int16 (deprecated);
-
- // If enabled, the rule priority score is used to set the priority score of
- // the annotation.
- // In case of false the annotation priority score is set from
- // GrammarDatetimeModel's priority_score
- use_rule_priority_score:bool = false;
-
- // If enabled, annotator will try to resolve the ambiguity by generating
- // possible alternative interpretations of the input text
- // e.g. '9:45' will be resolved to '9:45 AM' and '9:45 PM'.
- generate_alternative_interpretations_when_ambiguous:bool;
-
- // List of spans which grammar will ignore during the match e.g. if
- // “@” is in the allowed span list and input is “12 March @ 12PM” then “@”
- // will be ignored and 12 March @ 12PM will be translate to
- // {Day:12 Month: March Hour: 12 MERIDIAN: PM}.
- // This can also be achieved by adding additional rules e.g.
- // <Digit_Day> <Month> <Time>
- // <Digit_Day> <Month> @ <Time>
- // Though this is doable in the grammar but requires multiple rules, this
- // list enables the rule to represent multiple rules.
- ignored_spans:[string];
-}
-
-namespace libtextclassifier3;
-table GrammarDatetimeModel {
- // List of BCP 47 locale strings representing all locales supported by the
- // model.
- locales:[string];
-
- // If true, will give only future dates (when the day is not specified).
- prefer_future_for_unspecified_date:bool = false;
-
- // Grammar specific tokenizer options.
- grammar_tokenizer_options:GrammarTokenizerOptions;
-
- // The modes for which to apply the grammars.
- enabled_modes:ModeFlag = ALL;
-
- // The datetime grammar rules.
- datetime_rules:dates.DatetimeRules;
-
- // The final score to assign to the results of grammar model
- target_classification_score:float = 1;
-
- // The priority score used for conflict resolution with the other models.
- priority_score:float = 0;
-
- // Options for grammar annotations.
- annotation_options:GrammarDatetimeModel_.AnnotationOptions;
-}
-
namespace libtextclassifier3.DatetimeModelLibrary_;
table Item {
key:string (shared);
@@ -667,7 +587,7 @@
triggering_locales:string (shared);
embedding_pruning_mask:Model_.EmbeddingPruningMask;
- grammar_datetime_model:GrammarDatetimeModel;
+ reserved_25:int16 (deprecated);
contact_annotator_options:ContactAnnotatorOptions;
money_parsing_options:MoneyParsingOptions;
translate_annotator_options:TranslateAnnotatorOptions;
@@ -1002,6 +922,18 @@
backoff_options:TranslateAnnotatorOptions_.BackoffOptions;
}
+namespace libtextclassifier3.PodNerModel_;
+table Collection {
+ // Collection's name (e.g., "location", "person").
+ name:string (shared);
+
+ // Priority scores used for conflict resolution with the other annotators
+ // when the annotation is made over a single/multi token text.
+ single_token_priority_score:float;
+
+ multi_token_priority_score:float;
+}
+
namespace libtextclassifier3.PodNerModel_.Label_;
enum BoiseType : int {
NONE = 0,
@@ -1043,7 +975,8 @@
// end in punctuation.
append_final_period:bool = false;
- // Priority score used for conflict resolution with the other models.
+ // Priority score used for conflict resolution with the other models. Used
+ // only if collections_array is empty.
priority_score:float = 0;
// Maximum number of wordpieces supported by the model.
@@ -1054,9 +987,7 @@
// wordpieces between two consecutive windows. This overlap enables context
// for each word NER annotates.
sliding_window_num_wordpieces_overlap:int = 20;
-
- // Possible collections for labeled entities, e.g., "location", "person".
- collections:[string];
+ reserved_9:int16 (deprecated);
// The possible labels the ner model can output. If empty the default labels
// will be used.
@@ -1065,14 +996,25 @@
// If the ratio of unknown wordpieces in the input text is greater than this
// maximum, the text won't be annotated.
max_ratio_unknown_wordpieces:float = 0.1;
+
+ // Possible collections for labeled entities.
+ collections:[PodNerModel_.Collection];
+
+ // Minimum word-length and wordpieces-length required for the text to be
+ // annotated.
+ min_number_of_tokens:int = 1;
+
+ min_number_of_wordpieces:int = 1;
}
namespace libtextclassifier3;
table VocabModel {
// A trie that stores a list of vocabs that triggers "Define". A id is
// returned when looking up a vocab from the trie and the id can be used
- // to access more information about that vocab.
- vocab_trie:[ubyte];
+ // to access more information about that vocab. The marisa trie library
+ // requires 8-byte alignment because the first thing in a marisa trie is a
+ // 64-bit integer.
+ vocab_trie:[ubyte] (force_align: 8);
// A bit vector that tells if the vocab should trigger "Define" for users of
// beginner proficiency only. To look up the bit vector, use the id returned
diff --git a/native/annotator/pod_ner/pod-ner-dummy.h b/native/annotator/pod_ner/pod-ner-dummy.h
index 2f6cd41..c2ee00f 100644
--- a/native/annotator/pod_ner/pod-ner-dummy.h
+++ b/native/annotator/pod_ner/pod-ner-dummy.h
@@ -40,8 +40,8 @@
return true;
}
- AnnotatedSpan SuggestSelection(const UnicodeText &context,
- CodepointSpan click) const {
+ bool SuggestSelection(const UnicodeText &context, CodepointSpan click,
+ AnnotatedSpan *result) const {
return {};
}
@@ -49,6 +49,8 @@
ClassificationResult *result) const {
return false;
}
+
+ std::vector<std::string> GetSupportedCollections() const { return {}; }
};
} // namespace libtextclassifier3
diff --git a/native/annotator/test-utils.h b/native/annotator/test-utils.h
index a86302c..d63e66e 100644
--- a/native/annotator/test-utils.h
+++ b/native/annotator/test-utils.h
@@ -33,6 +33,13 @@
Value(first_result, best_class);
}
+MATCHER_P(IsAnnotationWithType, best_class, "") {
+ const std::string first_result = arg.classification.empty()
+ ? "<INVALID RESULTS>"
+ : arg.classification[0].collection;
+ return Value(first_result, best_class);
+}
+
MATCHER_P2(IsDateResult, time_ms_utc, granularity, "") {
return Value(arg.collection, "date") &&
Value(arg.datetime_parse_result.time_ms_utc, time_ms_utc) &&
diff --git a/native/annotator/test_data/test_grammar_model.fb b/native/annotator/test_data/test_grammar_model.fb
deleted file mode 100644
index 30f133e..0000000
--- a/native/annotator/test_data/test_grammar_model.fb
+++ /dev/null
Binary files differ
diff --git a/native/annotator/test_data/test_model.fb b/native/annotator/test_data/test_model.fb
index 55f55c9..64b3ac0 100644
--- a/native/annotator/test_data/test_model.fb
+++ b/native/annotator/test_data/test_model.fb
Binary files differ
diff --git a/native/annotator/test_data/test_vocab_model.fb b/native/annotator/test_data/test_vocab_model.fb
new file mode 100644
index 0000000..74b7631
--- /dev/null
+++ b/native/annotator/test_data/test_vocab_model.fb
Binary files differ
diff --git a/native/annotator/test_data/wrong_embeddings.fb b/native/annotator/test_data/wrong_embeddings.fb
index abe3fb0..dfb7369 100644
--- a/native/annotator/test_data/wrong_embeddings.fb
+++ b/native/annotator/test_data/wrong_embeddings.fb
Binary files differ
diff --git a/native/annotator/types.h b/native/annotator/types.h
index f7e1143..3063838 100644
--- a/native/annotator/types.h
+++ b/native/annotator/types.h
@@ -83,7 +83,8 @@
}
bool IsValid() const {
- return this->first != kInvalidIndex && this->second != kInvalidIndex;
+ return this->first != kInvalidIndex && this->second != kInvalidIndex &&
+ this->first <= this->second && this->first >= 0;
}
bool IsEmpty() const { return this->first == this->second; }
@@ -282,9 +283,9 @@
SECOND = 8,
// Meridiem field where 0 == AM, 1 == PM.
MERIDIEM = 9,
- // Number of hours offset from UTC this date time is in.
+ // Offset in number of minutes from UTC this date time is in.
ZONE_OFFSET = 10,
- // Number of hours offest for DST.
+ // Offset in number of hours for DST.
DST_OFFSET = 11,
};
@@ -430,7 +431,8 @@
std::string serialized_knowledge_result;
ContactPointer contact_pointer;
std::string contact_name, contact_given_name, contact_family_name,
- contact_nickname, contact_email_address, contact_phone_number, contact_id;
+ contact_nickname, contact_email_address, contact_phone_number,
+ contact_account_type, contact_account_name, contact_id;
std::string app_name, app_package_name;
int64 numeric_value;
double numeric_double_value;
@@ -524,6 +526,10 @@
// If true, the POD NER annotator is used.
bool use_pod_ner = true;
+ // If true and the model file supports that, the new vocab annotator is used
+ // to annotate "Dictionary". Otherwise, we use the FFModel to do so.
+ bool use_vocab_annotator = true;
+
bool operator==(const BaseOptions& other) const {
bool location_context_equality = this->location_context.has_value() ==
other.location_context.has_value();
@@ -536,7 +542,9 @@
this->annotation_usecase == other.annotation_usecase &&
this->detected_text_language_tags ==
other.detected_text_language_tags &&
- location_context_equality;
+ location_context_equality &&
+ this->use_pod_ner == other.use_pod_ner &&
+ this->use_vocab_annotator == other.use_vocab_annotator;
}
};
@@ -677,6 +685,8 @@
struct InputFragment {
std::string text;
+ float bounding_box_top;
+ float bounding_box_height;
// If present will override the AnnotationOptions reference time and timezone
// when annotating this specific string fragment.
diff --git a/native/annotator/vocab/test_data/test.model b/native/annotator/vocab/test_data/test.model
deleted file mode 100644
index 06b189d..0000000
--- a/native/annotator/vocab/test_data/test.model
+++ /dev/null
Binary files differ
diff --git a/native/lang_id/common/file/mmap.cc b/native/lang_id/common/file/mmap.cc
index 3dcdd3b..19afcc4 100644
--- a/native/lang_id/common/file/mmap.cc
+++ b/native/lang_id/common/file/mmap.cc
@@ -160,6 +160,7 @@
SAFTM_LOG(ERROR) << "Error closing file descriptor: " << last_error;
}
}
+
private:
const int fd_;
@@ -195,13 +196,23 @@
size_t file_size_in_bytes = static_cast<size_t>(sb.st_size);
// Perform actual mmap.
+ return MmapFile(fd, /*offset_in_bytes=*/0, file_size_in_bytes);
+}
+
+MmapHandle MmapFile(int fd, size_t offset_in_bytes, size_t size_in_bytes) {
+ // Make sure the offset is a multiple of the page size, as returned by
+ // sysconf(_SC_PAGE_SIZE); this is required by the man-page for mmap.
+ static const size_t kPageSize = sysconf(_SC_PAGE_SIZE);
+ const size_t aligned_offset = (offset_in_bytes / kPageSize) * kPageSize;
+ const size_t alignment_shift = offset_in_bytes - aligned_offset;
+ const size_t aligned_length = size_in_bytes + alignment_shift;
+
void *mmap_addr = mmap(
// Let system pick address for mmapp-ed data.
nullptr,
- // Mmap all bytes from the file.
- file_size_in_bytes,
+ aligned_length,
// One can read / write the mapped data (but see MAP_PRIVATE below).
// Normally, we expect only to read it, but in the future, we may want to
@@ -215,16 +226,15 @@
// Descriptor of file to mmap.
fd,
- // Map bytes right from the beginning of the file. This, and
- // file_size_in_bytes (2nd argument) means we map all bytes from the file.
- 0);
+ aligned_offset);
if (mmap_addr == MAP_FAILED) {
const std::string last_error = GetLastSystemError();
SAFTM_LOG(ERROR) << "Error while mmapping: " << last_error;
return GetErrorMmapHandle();
}
- return MmapHandle(mmap_addr, file_size_in_bytes);
+ return MmapHandle(static_cast<char *>(mmap_addr) + alignment_shift,
+ size_in_bytes);
}
bool Unmap(MmapHandle mmap_handle) {
diff --git a/native/lang_id/common/file/mmap.h b/native/lang_id/common/file/mmap.h
index f785465..923751a 100644
--- a/native/lang_id/common/file/mmap.h
+++ b/native/lang_id/common/file/mmap.h
@@ -19,6 +19,7 @@
#include <stddef.h>
+#include <cstddef>
#include <string>
#include "lang_id/common/lite_strings/stringpiece.h"
@@ -97,8 +98,15 @@
#endif
// Like MmapFile(const std::string &filename), but uses a file descriptor.
+// This function maps the entire file content.
MmapHandle MmapFile(FileDescriptorOrHandle fd);
+// Like MmapFile(const std::string &filename), but uses a file descriptor,
+// with an offset relative to the file start and a specified size, such that we
+// consider only a range of the file content.
+MmapHandle MmapFile(FileDescriptorOrHandle fd, size_t offset_in_bytes,
+ size_t size_in_bytes);
+
// Unmaps a file mapped using MmapFile. Returns true on success, false
// otherwise.
bool Unmap(MmapHandle mmap_handle);
@@ -112,6 +120,10 @@
explicit ScopedMmap(FileDescriptorOrHandle fd) : handle_(MmapFile(fd)) {}
+ explicit ScopedMmap(FileDescriptorOrHandle fd, size_t offset_in_bytes,
+ size_t size_in_bytes)
+ : handle_(MmapFile(fd, offset_in_bytes, size_in_bytes)) {}
+
~ScopedMmap() {
if (handle_.ok()) {
Unmap(handle_);
diff --git a/native/lang_id/fb_model/lang-id-from-fb.cc b/native/lang_id/fb_model/lang-id-from-fb.cc
index b2163eb..dc36fb7 100644
--- a/native/lang_id/fb_model/lang-id-from-fb.cc
+++ b/native/lang_id/fb_model/lang-id-from-fb.cc
@@ -44,6 +44,16 @@
new LangId(std::move(model_provider)));
}
+std::unique_ptr<LangId> GetLangIdFromFlatbufferFileDescriptor(
+ FileDescriptorOrHandle fd, size_t offset, size_t num_bytes) {
+ std::unique_ptr<ModelProvider> model_provider(
+ new ModelProviderFromFlatbuffer(fd, offset, num_bytes));
+
+ // NOTE: we avoid absl (including absl::make_unique), due to b/113350902
+ return std::unique_ptr<LangId>( // NOLINT
+ new LangId(std::move(model_provider)));
+}
+
std::unique_ptr<LangId> GetLangIdFromFlatbufferBytes(const char *data,
size_t num_bytes) {
std::unique_ptr<ModelProvider> model_provider(
diff --git a/native/lang_id/fb_model/lang-id-from-fb.h b/native/lang_id/fb_model/lang-id-from-fb.h
index 061247b..eed843d 100644
--- a/native/lang_id/fb_model/lang-id-from-fb.h
+++ b/native/lang_id/fb_model/lang-id-from-fb.h
@@ -40,6 +40,11 @@
FileDescriptorOrHandle fd);
// Returns a LangId built using the SAFT model in flatbuffer format from
+// given file descriptor, staring at |offset| and of size |num_bytes|.
+std::unique_ptr<LangId> GetLangIdFromFlatbufferFileDescriptor(
+ FileDescriptorOrHandle fd, size_t offset, size_t num_bytes);
+
+// Returns a LangId built using the SAFT model in flatbuffer format from
// the |num_bytes| bytes that start at address |data|.
//
// IMPORTANT: the model bytes must be alive during the lifetime of the returned
diff --git a/native/lang_id/fb_model/model-provider-from-fb.cc b/native/lang_id/fb_model/model-provider-from-fb.cc
index c81b116..43bf860 100644
--- a/native/lang_id/fb_model/model-provider-from-fb.cc
+++ b/native/lang_id/fb_model/model-provider-from-fb.cc
@@ -48,6 +48,16 @@
Initialize(scoped_mmap_->handle().to_stringpiece());
}
+ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(
+ FileDescriptorOrHandle fd, std::size_t offset, std::size_t size)
+
+ // Using mmap as a fast way to read the model bytes. As the file is
+ // unmapped only when the field scoped_mmap_ is destructed, the model bytes
+ // stay alive for the entire lifetime of this object.
+ : scoped_mmap_(new ScopedMmap(fd, offset, size)) {
+ Initialize(scoped_mmap_->handle().to_stringpiece());
+}
+
void ModelProviderFromFlatbuffer::Initialize(StringPiece model_bytes) {
// Note: valid_ was initialized to false. In the code below, we set valid_ to
// true only if all initialization steps completed successfully. Otherwise,
diff --git a/native/lang_id/fb_model/model-provider-from-fb.h b/native/lang_id/fb_model/model-provider-from-fb.h
index c3def49..55e631c 100644
--- a/native/lang_id/fb_model/model-provider-from-fb.h
+++ b/native/lang_id/fb_model/model-provider-from-fb.h
@@ -43,6 +43,11 @@
// file descriptor |fd|.
explicit ModelProviderFromFlatbuffer(FileDescriptorOrHandle fd);
+ // Constructs a model provider based on a flatbuffer-format SAFT model from
+ // file descriptor |fd|.
+ ModelProviderFromFlatbuffer(FileDescriptorOrHandle fd, std::size_t offset,
+ std::size_t size);
+
// Constructs a model provider from a flatbuffer-format SAFT model the bytes
// of which are already in RAM (size bytes starting from address data).
// Useful if you "transport" these bytes otherwise than via a normal file
diff --git a/native/lang_id/lang-id_jni.cc b/native/lang_id/lang-id_jni.cc
index e4bb5d8..e86f198 100644
--- a/native/lang_id/lang-id_jni.cc
+++ b/native/lang_id/lang-id_jni.cc
@@ -93,6 +93,16 @@
return reinterpret_cast<jlong>(lang_id.release());
}
+TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
+ std::unique_ptr<LangId> lang_id =
+ GetLangIdFromFlatbufferFileDescriptor(fd, offset, size);
+ if (!lang_id->is_valid()) {
+ return reinterpret_cast<jlong>(nullptr);
+ }
+ return reinterpret_cast<jlong>(lang_id.release());
+}
+
TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
(JNIEnv* env, jobject thiz, jlong ptr, jstring text) {
LangId* model = reinterpret_cast<LangId*>(ptr);
@@ -166,3 +176,13 @@
LangId* model = reinterpret_cast<LangId*>(ptr);
return model->GetFloatProperty("min_text_size_in_bytes", 0);
}
+
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
+ std::unique_ptr<LangId> lang_id =
+ GetLangIdFromFlatbufferFileDescriptor(fd, offset, size);
+ if (!lang_id->is_valid()) {
+ return -1;
+ }
+ return lang_id->GetModelVersion();
+}
diff --git a/native/lang_id/lang-id_jni.h b/native/lang_id/lang-id_jni.h
index 5eb2b00..e917197 100644
--- a/native/lang_id/lang-id_jni.h
+++ b/native/lang_id/lang-id_jni.h
@@ -20,7 +20,9 @@
#define LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_JNI_H_
#include <jni.h>
+
#include <string>
+
#include "utils/java/jni-base.h"
#ifndef TC3_LANG_ID_CLASS_NAME
@@ -39,6 +41,9 @@
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewFromPath)
(JNIEnv* env, jobject clazz, jstring path);
+TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
+
TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
(JNIEnv* env, jobject thiz, jlong ptr, jstring text);
@@ -60,6 +65,9 @@
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetMinTextSizeInBytes)
(JNIEnv* env, jobject thizz, jlong ptr);
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
+
#ifdef __cplusplus
}
#endif
diff --git a/native/lang_id/script/tiny-script-detector.h b/native/lang_id/script/tiny-script-detector.h
index a55da04..d08270c 100644
--- a/native/lang_id/script/tiny-script-detector.h
+++ b/native/lang_id/script/tiny-script-detector.h
@@ -74,12 +74,12 @@
// CPU, so it's better to use than int32.
static const unsigned int kGreekStart = 0x370;
- // Commented out (unsued in the code): kGreekEnd = 0x3FF;
+ // Commented out (unused in the code): kGreekEnd = 0x3FF;
static const unsigned int kCyrillicStart = 0x400;
static const unsigned int kCyrillicEnd = 0x4FF;
static const unsigned int kHebrewStart = 0x590;
- // Commented out (unsued in the code): kHebrewEnd = 0x5FF;
+ // Commented out (unused in the code): kHebrewEnd = 0x5FF;
static const unsigned int kArabicStart = 0x600;
static const unsigned int kArabicEnd = 0x6FF;
const unsigned int codepoint = ((p[0] & 0x1F) << 6) | (p[1] & 0x3F);
@@ -117,7 +117,7 @@
static const unsigned int kHiraganaStart = 0x3041;
static const unsigned int kHiraganaEnd = 0x309F;
- // Commented out (unsued in the code): kKatakanaStart = 0x30A0;
+ // Commented out (unused in the code): kKatakanaStart = 0x30A0;
static const unsigned int kKatakanaEnd = 0x30FF;
const unsigned int codepoint =
((p[0] & 0x0F) << 12) | ((p[1] & 0x3F) << 6) | (p[2] & 0x3F);
diff --git a/native/models/textclassifier.ar.model b/native/models/textclassifier.ar.model
index 87a442a..ff460e6 100755
--- a/native/models/textclassifier.ar.model
+++ b/native/models/textclassifier.ar.model
Binary files differ
diff --git a/native/models/textclassifier.en.model b/native/models/textclassifier.en.model
index 70e3cd2..9eca5dd 100755
--- a/native/models/textclassifier.en.model
+++ b/native/models/textclassifier.en.model
Binary files differ
diff --git a/native/models/textclassifier.es.model b/native/models/textclassifier.es.model
index 8ea0938..c25fef1 100755
--- a/native/models/textclassifier.es.model
+++ b/native/models/textclassifier.es.model
Binary files differ
diff --git a/native/models/textclassifier.fr.model b/native/models/textclassifier.fr.model
index 3ed3172..b98c075 100755
--- a/native/models/textclassifier.fr.model
+++ b/native/models/textclassifier.fr.model
Binary files differ
diff --git a/native/models/textclassifier.it.model b/native/models/textclassifier.it.model
index 4381909..5bb5a21 100755
--- a/native/models/textclassifier.it.model
+++ b/native/models/textclassifier.it.model
Binary files differ
diff --git a/native/models/textclassifier.ja.model b/native/models/textclassifier.ja.model
index 5db8e14..8851b7c 100755
--- a/native/models/textclassifier.ja.model
+++ b/native/models/textclassifier.ja.model
Binary files differ
diff --git a/native/models/textclassifier.ko.model b/native/models/textclassifier.ko.model
index a0d37ff..7b1b26a 100755
--- a/native/models/textclassifier.ko.model
+++ b/native/models/textclassifier.ko.model
Binary files differ
diff --git a/native/models/textclassifier.nl.model b/native/models/textclassifier.nl.model
index 5e627e0..7005cf4 100755
--- a/native/models/textclassifier.nl.model
+++ b/native/models/textclassifier.nl.model
Binary files differ
diff --git a/native/models/textclassifier.pl.model b/native/models/textclassifier.pl.model
index 7c43109..9d3b7e3 100755
--- a/native/models/textclassifier.pl.model
+++ b/native/models/textclassifier.pl.model
Binary files differ
diff --git a/native/models/textclassifier.pt.model b/native/models/textclassifier.pt.model
index b3b2232..4af2b0d 100755
--- a/native/models/textclassifier.pt.model
+++ b/native/models/textclassifier.pt.model
Binary files differ
diff --git a/native/models/textclassifier.ru.model b/native/models/textclassifier.ru.model
index 722afbe..fda7a7c 100755
--- a/native/models/textclassifier.ru.model
+++ b/native/models/textclassifier.ru.model
Binary files differ
diff --git a/native/models/textclassifier.th.model b/native/models/textclassifier.th.model
index b156ed7..f3b6ce5 100755
--- a/native/models/textclassifier.th.model
+++ b/native/models/textclassifier.th.model
Binary files differ
diff --git a/native/models/textclassifier.tr.model b/native/models/textclassifier.tr.model
index 5a66a11..8e34988 100755
--- a/native/models/textclassifier.tr.model
+++ b/native/models/textclassifier.tr.model
Binary files differ
diff --git a/native/models/textclassifier.universal.model b/native/models/textclassifier.universal.model
index 83704c3..09f1e0b 100755
--- a/native/models/textclassifier.universal.model
+++ b/native/models/textclassifier.universal.model
Binary files differ
diff --git a/native/models/textclassifier.zh.model b/native/models/textclassifier.zh.model
index 946d188..f664882 100755
--- a/native/models/textclassifier.zh.model
+++ b/native/models/textclassifier.zh.model
Binary files differ
diff --git a/native/utils/base/arena_test.cc b/native/utils/base/arena_test.cc
index d5e9bf3..a84190d 100644
--- a/native/utils/base/arena_test.cc
+++ b/native/utils/base/arena_test.cc
@@ -1,3 +1,19 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
#include "utils/base/arena.h"
#include "utils/base/logging.h"
diff --git a/native/utils/calendar/calendar_test-include.h b/native/utils/calendar/calendar_test-include.h
deleted file mode 100644
index 504d67e..0000000
--- a/native/utils/calendar/calendar_test-include.h
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// This is a shared test between icu and javaicu calendar implementations.
-// It is meant to be #include'd.
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_TEST_INCLUDE_H_
-#define LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_TEST_INCLUDE_H_
-#include "utils/jvm-test-utils.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace test_internal {
-
-class CalendarTest : public ::testing::Test {
- protected:
- CalendarTest()
- : calendarlib_(libtextclassifier3::CreateCalendarLibForTesting()) {}
- std::unique_ptr<CalendarLib> calendarlib_;
-};
-
-} // namespace test_internal
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_TEST_INCLUDE_H_
diff --git a/native/utils/calendar/calendar_test-include.cc b/native/utils/calendar/calendar_test.cc
similarity index 86%
rename from native/utils/calendar/calendar_test-include.cc
rename to native/utils/calendar/calendar_test.cc
index 36b9778..b94813c 100644
--- a/native/utils/calendar/calendar_test-include.cc
+++ b/native/utils/calendar/calendar_test.cc
@@ -14,12 +14,20 @@
* limitations under the License.
*/
-#include "utils/calendar/calendar_test-include.h"
+#include "utils/jvm-test-utils.h"
+#include "gtest/gtest.h"
namespace libtextclassifier3 {
-namespace test_internal {
+namespace {
-static constexpr int kWednesday = 4;
+class CalendarTest : public ::testing::Test {
+ protected:
+ CalendarTest()
+ : calendarlib_(libtextclassifier3::CreateCalendarLibForTesting()) {}
+
+ static constexpr int kWednesday = 4;
+ std::unique_ptr<CalendarLib> calendarlib_;
+};
TEST_F(CalendarTest, Interface) {
int64 time;
@@ -54,6 +62,43 @@
EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);
}
+TEST_F(CalendarTest, SetsTimeZone) {
+ int64 time;
+ DatetimeGranularity granularity;
+ DatetimeParsedData data;
+ data.SetAbsoluteValue(DatetimeComponent::ComponentType::YEAR, 2018);
+ data.SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, 7);
+ data.SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE, 30);
+ data.SetAbsoluteValue(DatetimeComponent::ComponentType::SECOND, 10);
+
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
+ data,
+ /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-CH",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
+ EXPECT_EQ(time, 1514788210000L /* Jan 01 2018 07:30:10 GMT+01:00 */);
+
+ data.SetAbsoluteValue(DatetimeComponent::ComponentType::ZONE_OFFSET,
+ 60); // GMT+01:00
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
+ data,
+ /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-CH",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
+ EXPECT_EQ(time, 1514788210000L /* Jan 01 2018 07:30:10 GMT+01:00 */);
+
+ // Now the hour is in terms of GMT+02:00 which is one hour ahead of
+ // GMT+01:00.
+ data.SetAbsoluteValue(DatetimeComponent::ComponentType::ZONE_OFFSET,
+ 120); // GMT+02:00
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
+ data,
+ /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-CH",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
+ EXPECT_EQ(time, 1514784610000L /* Jan 01 2018 06:30:10 GMT+01:00 */);
+}
+
TEST_F(CalendarTest, RoundingToGranularityBasic) {
int64 time;
DatetimeGranularity granularity;
@@ -320,5 +365,5 @@
EXPECT_EQ(time, 1567321800000L /* Sept 01 2019 09:10:00 */);
}
-} // namespace test_internal
+} // namespace
} // namespace libtextclassifier3
diff --git a/native/utils/flatbuffers/flatbuffers_test.bfbs b/native/utils/flatbuffers/flatbuffers_test.bfbs
index 725e512..519550f 100644
--- a/native/utils/flatbuffers/flatbuffers_test.bfbs
+++ b/native/utils/flatbuffers/flatbuffers_test.bfbs
Binary files differ
diff --git a/native/utils/flatbuffers/flatbuffers_test.fbs b/native/utils/flatbuffers/flatbuffers_test.fbs
index 70b164a..f501e13 100644
--- a/native/utils/flatbuffers/flatbuffers_test.fbs
+++ b/native/utils/flatbuffers/flatbuffers_test.fbs
@@ -36,6 +36,7 @@
table NestedA {
nestedb: NestedB;
value: string;
+ repeated_str: [string];
}
table NestedB {
diff --git a/native/utils/flatbuffers/flatbuffers_test_extended.bfbs b/native/utils/flatbuffers/flatbuffers_test_extended.bfbs
index ea8b7d2..fec4363 100644
--- a/native/utils/flatbuffers/flatbuffers_test_extended.bfbs
+++ b/native/utils/flatbuffers/flatbuffers_test_extended.bfbs
Binary files differ
diff --git a/native/utils/flatbuffers/flatbuffers_test_extended.fbs b/native/utils/flatbuffers/flatbuffers_test_extended.fbs
index 6ce9973..0410874 100644
--- a/native/utils/flatbuffers/flatbuffers_test_extended.fbs
+++ b/native/utils/flatbuffers/flatbuffers_test_extended.fbs
@@ -36,6 +36,7 @@
table NestedA {
nestedb: NestedB;
value: string;
+ repeated_str: [string];
}
table NestedB {
diff --git a/native/utils/flatbuffers/mutable.cc b/native/utils/flatbuffers/mutable.cc
index ca3f1b0..0f425eb 100644
--- a/native/utils/flatbuffers/mutable.cc
+++ b/native/utils/flatbuffers/mutable.cc
@@ -295,6 +295,15 @@
return it->second.get();
}
+RepeatedField* MutableFlatbuffer::Repeated(const FlatbufferFieldPath* path) {
+ MutableFlatbuffer* parent;
+ const reflection::Field* field;
+ if (!GetFieldWithParent(path, &parent, &field)) {
+ return nullptr;
+ }
+ return parent->Repeated(field);
+}
+
flatbuffers::uoffset_t MutableFlatbuffer::Serialize(
flatbuffers::FlatBufferBuilder* builder) const {
// Build all children before we can start with this table.
@@ -391,43 +400,6 @@
builder.GetSize());
}
-template <>
-bool MutableFlatbuffer::AppendFromVector<std::string>(
- const flatbuffers::Table* from, const reflection::Field* field) {
- auto* from_vector = from->GetPointer<
- const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*>(
- field->offset());
- if (from_vector == nullptr) {
- return false;
- }
-
- RepeatedField* to_repeated = Repeated(field);
- for (const flatbuffers::String* element : *from_vector) {
- to_repeated->Add(element->str());
- }
- return true;
-}
-
-template <>
-bool MutableFlatbuffer::AppendFromVector<MutableFlatbuffer>(
- const flatbuffers::Table* from, const reflection::Field* field) {
- auto* from_vector = from->GetPointer<const flatbuffers::Vector<
- flatbuffers::Offset<const flatbuffers::Table>>*>(field->offset());
- if (from_vector == nullptr) {
- return false;
- }
-
- RepeatedField* to_repeated = Repeated(field);
- for (const flatbuffers::Table* const from_element : *from_vector) {
- MutableFlatbuffer* to_element = to_repeated->Add();
- if (to_element == nullptr) {
- return false;
- }
- to_element->MergeFrom(from_element);
- }
- return true;
-}
-
bool MutableFlatbuffer::MergeFrom(const flatbuffers::Table* from) {
// No fields to set.
if (type_->fields() == nullptr) {
@@ -491,46 +463,13 @@
return false;
}
break;
- case reflection::Vector:
- switch (field->type()->element()) {
- case reflection::Int:
- AppendFromVector<int32>(from, field);
- break;
- case reflection::UInt:
- AppendFromVector<uint>(from, field);
- break;
- case reflection::Long:
- AppendFromVector<int64>(from, field);
- break;
- case reflection::ULong:
- AppendFromVector<uint64>(from, field);
- break;
- case reflection::Byte:
- AppendFromVector<int8_t>(from, field);
- break;
- case reflection::UByte:
- AppendFromVector<uint8_t>(from, field);
- break;
- case reflection::String:
- AppendFromVector<std::string>(from, field);
- break;
- case reflection::Obj:
- AppendFromVector<MutableFlatbuffer>(from, field);
- break;
- case reflection::Double:
- AppendFromVector<double>(from, field);
- break;
- case reflection::Float:
- AppendFromVector<float>(from, field);
- break;
- default:
- TC3_LOG(ERROR) << "Repeated unsupported type: "
- << field->type()->element()
- << " for field: " << field->name()->str();
- return false;
- break;
+ case reflection::Vector: {
+ if (RepeatedField* repeated_field = Repeated(field);
+ repeated_field == nullptr || !repeated_field->Extend(from)) {
+ return false;
}
break;
+ }
default:
TC3_LOG(ERROR) << "Unsupported type: " << type
<< " for field: " << field->name()->str();
@@ -561,6 +500,22 @@
}
}
+std::string RepeatedField::ToTextProto() const {
+ std::string result = " [";
+ std::string current_field_separator;
+ for (int index = 0; index < Size(); index++) {
+ if (is_primitive_) {
+ result.append(current_field_separator + items_.at(index).ToString());
+ } else {
+ result.append(current_field_separator + "{" +
+ Get<MutableFlatbuffer*>(index)->ToTextProto() + "}");
+ }
+ current_field_separator = ", ";
+ }
+ result.append("] ");
+ return result;
+}
+
std::string MutableFlatbuffer::ToTextProto() const {
std::string result;
std::string current_field_separator;
@@ -577,6 +532,14 @@
current_field_separator = ", ";
}
+ // Add repeated message
+ for (const auto& repeated_fb_pair : repeated_fields_) {
+ result.append(current_field_separator +
+ repeated_fb_pair.first->name()->c_str() + ": " +
+ repeated_fb_pair.second->ToTextProto());
+ current_field_separator = ", ";
+ }
+
// Add nested messages.
for (const auto& field_flatbuffer_pair : children_) {
const std::string field_name = field_flatbuffer_pair.first->name()->str();
@@ -618,6 +581,46 @@
} // namespace
+bool RepeatedField::Extend(const flatbuffers::Table* from) {
+ switch (field_->type()->element()) {
+ case reflection::Int:
+ AppendFromVector<int32>(from);
+ return true;
+ case reflection::UInt:
+ AppendFromVector<uint>(from);
+ return true;
+ case reflection::Long:
+ AppendFromVector<int64>(from);
+ return true;
+ case reflection::ULong:
+ AppendFromVector<uint64>(from);
+ return true;
+ case reflection::Byte:
+ AppendFromVector<int8_t>(from);
+ return true;
+ case reflection::UByte:
+ AppendFromVector<uint8_t>(from);
+ return true;
+ case reflection::String:
+ AppendFromVector<std::string>(from);
+ return true;
+ case reflection::Obj:
+ AppendFromVector<MutableFlatbuffer>(from);
+ return true;
+ case reflection::Double:
+ AppendFromVector<double>(from);
+ return true;
+ case reflection::Float:
+ AppendFromVector<float>(from);
+ return true;
+ default:
+ TC3_LOG(ERROR) << "Repeated unsupported type: "
+ << field_->type()->element()
+ << " for field: " << field_->name()->str();
+ return false;
+ }
+}
+
flatbuffers::uoffset_t RepeatedField::Serialize(
flatbuffers::FlatBufferBuilder* builder) const {
switch (field_->type()->element()) {
diff --git a/native/utils/flatbuffers/mutable.h b/native/utils/flatbuffers/mutable.h
index 8210e2a..90f6baa 100644
--- a/native/utils/flatbuffers/mutable.h
+++ b/native/utils/flatbuffers/mutable.h
@@ -144,6 +144,11 @@
RepeatedField* Repeated(StringPiece field_name);
RepeatedField* Repeated(const reflection::Field* field);
+ // Gets a repeated field specified by path.
+ // Returns nullptr if the field was not found, or the field
+ // type was not a repeated field.
+ RepeatedField* Repeated(const FlatbufferFieldPath* path);
+
// Serializes the flatbuffer.
flatbuffers::uoffset_t Serialize(
flatbuffers::FlatBufferBuilder* builder) const;
@@ -273,10 +278,17 @@
}
}
+ bool Extend(const flatbuffers::Table* from);
+
flatbuffers::uoffset_t Serialize(
flatbuffers::FlatBufferBuilder* builder) const;
+ std::string ToTextProto() const;
+
private:
+ template <typename T>
+ bool AppendFromVector(const flatbuffers::Table* from);
+
flatbuffers::uoffset_t SerializeString(
flatbuffers::FlatBufferBuilder* builder) const;
flatbuffers::uoffset_t SerializeObject(
@@ -314,7 +326,8 @@
Variant variant_value(value);
if (!IsMatchingType<T>(field->type()->base_type())) {
TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str()
- << "`, expected: " << field->type()->base_type()
+ << "`, expected: "
+ << EnumNameBaseType(field->type()->base_type())
<< ", got: " << variant_value.GetType();
return false;
}
@@ -366,17 +379,47 @@
}
template <typename T>
-bool MutableFlatbuffer::AppendFromVector(const flatbuffers::Table* from,
- const reflection::Field* field) {
- const flatbuffers::Vector<T>* from_vector =
- from->GetPointer<const flatbuffers::Vector<T>*>(field->offset());
- if (from_vector == nullptr) {
+bool RepeatedField::AppendFromVector(const flatbuffers::Table* from) {
+ const flatbuffers::Vector<T>* values =
+ from->GetPointer<const flatbuffers::Vector<T>*>(field_->offset());
+ if (values == nullptr) {
return false;
}
+ for (const T element : *values) {
+ Add(element);
+ }
+ return true;
+}
- RepeatedField* to_repeated = Repeated(field);
- for (const T element : *from_vector) {
- to_repeated->Add(element);
+template <>
+inline bool RepeatedField::AppendFromVector<std::string>(
+ const flatbuffers::Table* from) {
+ auto* values = from->GetPointer<
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*>(
+ field_->offset());
+ if (values == nullptr) {
+ return false;
+ }
+ for (const flatbuffers::String* element : *values) {
+ Add(element->str());
+ }
+ return true;
+}
+
+template <>
+inline bool RepeatedField::AppendFromVector<MutableFlatbuffer>(
+ const flatbuffers::Table* from) {
+ auto* values = from->GetPointer<const flatbuffers::Vector<
+ flatbuffers::Offset<const flatbuffers::Table>>*>(field_->offset());
+ if (values == nullptr) {
+ return false;
+ }
+ for (const flatbuffers::Table* const from_element : *values) {
+ MutableFlatbuffer* to_element = Add();
+ if (to_element == nullptr) {
+ return false;
+ }
+ to_element->MergeFrom(from_element);
}
return true;
}
diff --git a/native/utils/flatbuffers/mutable_test.cc b/native/utils/flatbuffers/mutable_test.cc
index a119f1f..8fefc07 100644
--- a/native/utils/flatbuffers/mutable_test.cc
+++ b/native/utils/flatbuffers/mutable_test.cc
@@ -96,22 +96,13 @@
}
TEST_F(MutableFlatbufferTest, HandlesNestedFields) {
- FlatbufferFieldPathT path;
- path.field.emplace_back(new FlatbufferFieldT);
- path.field.back()->field_name = "flight_number";
- path.field.emplace_back(new FlatbufferFieldT);
- path.field.back()->field_name = "carrier_code";
- flatbuffers::FlatBufferBuilder path_builder;
- path_builder.Finish(FlatbufferFieldPath::Pack(path_builder, &path));
-
+ OwnedFlatbuffer<FlatbufferFieldPath, std::string> path =
+ CreateFieldPath({"flight_number", "carrier_code"});
std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
MutableFlatbuffer* parent = nullptr;
reflection::Field const* field = nullptr;
- EXPECT_TRUE(
- buffer->GetFieldWithParent(flatbuffers::GetRoot<FlatbufferFieldPath>(
- path_builder.GetBufferPointer()),
- &parent, &field));
+ EXPECT_TRUE(buffer->GetFieldWithParent(path.get(), &parent, &field));
EXPECT_EQ(parent, buffer->Mutable("flight_number"));
EXPECT_EQ(field,
buffer->Mutable("flight_number")->GetFieldOrNull("carrier_code"));
@@ -245,10 +236,8 @@
std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
// Set a multiply nested field.
- std::unique_ptr<FlatbufferFieldPathT> field_path_t =
- CreateFieldPath("nested.nestedb.nesteda.nestedb.nesteda");
- OwnedFlatbuffer<FlatbufferFieldPath, std::string> field_path(
- PackFlatbuffer<FlatbufferFieldPath>(field_path_t.get()));
+ OwnedFlatbuffer<FlatbufferFieldPath, std::string> field_path =
+ CreateFieldPath({"nested", "nestedb", "nesteda", "nestedb", "nesteda"});
buffer->Mutable(field_path.get())->Set("value", "le value");
std::unique_ptr<test::EntityDataT> entity_data =
@@ -283,9 +272,22 @@
flight_info->Set("carrier_code", "LX");
flight_info->Set("flight_code", 38);
+ // Add non primitive type.
+ auto reminders = buffer->Repeated("reminders");
+ auto foo_reminder = reminders->Add();
+ foo_reminder->Set("title", "foo reminder");
+ auto bar_reminder = reminders->Add();
+ bar_reminder->Set("title", "bar reminder");
+
+ // Add primitive type.
+ EXPECT_TRUE(buffer->Repeated("numbers")->Add(static_cast<int>(111)));
+ EXPECT_TRUE(buffer->Repeated("numbers")->Add(static_cast<int>(222)));
+ EXPECT_TRUE(buffer->Repeated("numbers")->Add(static_cast<int>(333)));
+
EXPECT_EQ(buffer->ToTextProto(),
- "a_long_field: 84, an_int_field: 42, flight_number "
- "{flight_code: 38, carrier_code: 'LX'}");
+ "a_long_field: 84, an_int_field: 42, numbers: [111, 222, 333] , "
+ "reminders: [{title: 'foo reminder'}, {title: 'bar reminder'}] , "
+ "flight_number {flight_code: 38, carrier_code: 'LX'}");
}
TEST_F(MutableFlatbufferTest, RepeatedFieldSetThroughReflectionCanBeRead) {
@@ -346,5 +348,20 @@
EXPECT_EQ(buffer->Repeated("numbers")->Get<int>(2), 3);
}
+TEST_F(MutableFlatbufferTest, GetsRepeatedFieldFromPath) {
+ std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+ OwnedFlatbuffer<FlatbufferFieldPath, std::string> notes =
+ CreateFieldPath({"nested", "repeated_str"});
+
+ EXPECT_TRUE(buffer->Repeated(notes.get())->Add("a"));
+ EXPECT_TRUE(buffer->Repeated(notes.get())->Add("test"));
+
+ std::unique_ptr<test::EntityDataT> entity_data =
+ LoadAndVerifyMutableFlatbuffer<test::EntityData>(buffer->Serialize());
+ ASSERT_NE(entity_data, nullptr);
+ EXPECT_THAT(entity_data->nested->repeated_str, SizeIs(2));
+ EXPECT_THAT(entity_data->nested->repeated_str, ElementsAre("a", "test"));
+}
+
} // namespace
} // namespace libtextclassifier3
diff --git a/native/utils/flatbuffers/reflection.h b/native/utils/flatbuffers/reflection.h
index 8650a95..9a0fec7 100644
--- a/native/utils/flatbuffers/reflection.h
+++ b/native/utils/flatbuffers/reflection.h
@@ -85,6 +85,64 @@
inline const reflection::BaseType flatbuffers_base_type<StringPiece>::value =
reflection::String;
+template <reflection::BaseType>
+struct flatbuffers_cpp_type;
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::Bool> {
+ using value = bool;
+};
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::Byte> {
+ using value = int8;
+};
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::UByte> {
+ using value = uint8;
+};
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::Short> {
+ using value = int16;
+};
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::UShort> {
+ using value = uint16;
+};
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::Int> {
+ using value = int32;
+};
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::UInt> {
+ using value = uint32;
+};
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::Long> {
+ using value = int64;
+};
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::ULong> {
+ using value = uint64;
+};
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::Float> {
+ using value = float;
+};
+
+template <>
+struct flatbuffers_cpp_type<reflection::BaseType::Double> {
+ using value = double;
+};
+
// Gets the field information for a field name, returns nullptr if the
// field was not defined.
const reflection::Field* GetFieldOrNull(const reflection::Object* type,
diff --git a/native/utils/flatbuffers/test-utils.h b/native/utils/flatbuffers/test-utils.h
index 5fd5b24..dbfc732 100644
--- a/native/utils/flatbuffers/test-utils.h
+++ b/native/utils/flatbuffers/test-utils.h
@@ -22,9 +22,8 @@
#include <fstream>
#include <string>
+#include "utils/flatbuffers/flatbuffers.h"
#include "utils/flatbuffers/flatbuffers_generated.h"
-#include "utils/strings/split.h"
-#include "utils/strings/stringpiece.h"
#include "utils/test-data-test-utils.h"
#include "gtest/gtest.h"
@@ -38,14 +37,21 @@
}
// Creates a flatbuffer field path from a dot separated field path string.
-inline std::unique_ptr<FlatbufferFieldPathT> CreateFieldPath(
- const StringPiece path) {
- std::unique_ptr<FlatbufferFieldPathT> field_path(new FlatbufferFieldPathT);
- for (const StringPiece field : strings::Split(path, '.')) {
- field_path->field.emplace_back(new FlatbufferFieldT);
- field_path->field.back()->field_name = field.ToString();
+inline std::unique_ptr<FlatbufferFieldPathT> CreateUnpackedFieldPath(
+ const std::vector<std::string>& fields) {
+ std::unique_ptr<FlatbufferFieldPathT> path(new FlatbufferFieldPathT);
+ for (const std::string& field : fields) {
+ path->field.emplace_back(new FlatbufferFieldT);
+ path->field.back()->field_name = field;
}
- return field_path;
+ return path;
+}
+
+inline OwnedFlatbuffer<FlatbufferFieldPath, std::string> CreateFieldPath(
+ const std::vector<std::string>& fields) {
+ std::unique_ptr<FlatbufferFieldPathT> path = CreateUnpackedFieldPath(fields);
+ return OwnedFlatbuffer<FlatbufferFieldPath, std::string>(
+ PackFlatbuffer<FlatbufferFieldPath>(path.get()));
}
} // namespace libtextclassifier3
diff --git a/native/utils/grammar/analyzer.cc b/native/utils/grammar/analyzer.cc
new file mode 100644
index 0000000..fcba217
--- /dev/null
+++ b/native/utils/grammar/analyzer.cc
@@ -0,0 +1,82 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/analyzer.h"
+
+#include "utils/base/status_macros.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3::grammar {
+
+Analyzer::Analyzer(const UniLib* unilib, const RulesSet* rules_set)
+ // TODO(smillius): Add tokenizer options to `RulesSet`.
+ : owned_tokenizer_(new Tokenizer(libtextclassifier3::TokenizationType_ICU,
+ unilib,
+ /*codepoint_ranges=*/{},
+ /*internal_tokenizer_codepoint_ranges=*/{},
+ /*split_on_script_change=*/false,
+ /*icu_preserve_whitespace_tokens=*/false)),
+ tokenizer_(owned_tokenizer_.get()),
+ parser_(unilib, rules_set),
+ semantic_evaluator_(rules_set->semantic_values_schema() != nullptr
+ ? flatbuffers::GetRoot<reflection::Schema>(
+ rules_set->semantic_values_schema()->data())
+ : nullptr) {}
+
+Analyzer::Analyzer(const UniLib* unilib, const RulesSet* rules_set,
+ const Tokenizer* tokenizer)
+ : tokenizer_(tokenizer),
+ parser_(unilib, rules_set),
+ semantic_evaluator_(rules_set->semantic_values_schema() != nullptr
+ ? flatbuffers::GetRoot<reflection::Schema>(
+ rules_set->semantic_values_schema()->data())
+ : nullptr) {}
+
+StatusOr<std::vector<EvaluatedDerivation>> Analyzer::Parse(
+ const TextContext& input, UnsafeArena* arena) const {
+ std::vector<EvaluatedDerivation> result;
+
+ // Evaluate each derivation.
+ for (const Derivation& derivation :
+ ValidDeduplicatedDerivations(parser_.Parse(input, arena))) {
+ TC3_ASSIGN_OR_RETURN(const SemanticValue* value,
+ semantic_evaluator_.Eval(input, derivation, arena));
+ result.emplace_back(EvaluatedDerivation{std::move(derivation), value});
+ }
+
+ return result;
+}
+
+StatusOr<std::vector<EvaluatedDerivation>> Analyzer::Parse(
+ const UnicodeText& text, const std::vector<Locale>& locales,
+ UnsafeArena* arena) const {
+ return Parse(BuildTextContextForInput(text, locales), arena);
+}
+
+TextContext Analyzer::BuildTextContextForInput(
+ const UnicodeText& text, const std::vector<Locale>& locales) const {
+ TextContext context;
+ context.text = UnicodeText(text, /*do_copy=*/false);
+ context.tokens = tokenizer_->Tokenize(context.text);
+ context.codepoints = context.text.Codepoints();
+ context.codepoints.push_back(context.text.end());
+ context.locales = locales;
+ context.context_span.first = 0;
+ context.context_span.second = context.tokens.size();
+ return context;
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/analyzer.h b/native/utils/grammar/analyzer.h
new file mode 100644
index 0000000..c83c622
--- /dev/null
+++ b/native/utils/grammar/analyzer.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_ANALYZER_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_ANALYZER_H_
+
+#include "utils/base/arena.h"
+#include "utils/base/statusor.h"
+#include "utils/grammar/evaluated-derivation.h"
+#include "utils/grammar/parsing/parser.h"
+#include "utils/grammar/semantics/composer.h"
+#include "utils/grammar/text-context.h"
+#include "utils/i18n/locale.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3::grammar {
+
+// An analyzer that parses and semantically evaluates an input text with a
+// grammar.
+class Analyzer {
+ public:
+ explicit Analyzer(const UniLib* unilib, const RulesSet* rules_set);
+ explicit Analyzer(const UniLib* unilib, const RulesSet* rules_set,
+ const Tokenizer* tokenizer);
+
+ // Parses and evaluates an input.
+ StatusOr<std::vector<EvaluatedDerivation>> Parse(const TextContext& input,
+ UnsafeArena* arena) const;
+ StatusOr<std::vector<EvaluatedDerivation>> Parse(
+ const UnicodeText& text, const std::vector<Locale>& locales,
+ UnsafeArena* arena) const;
+
+ // Pre-processes an input text for parsing.
+ TextContext BuildTextContextForInput(
+ const UnicodeText& text, const std::vector<Locale>& locales = {}) const;
+
+ const Parser& parser() const { return parser_; }
+
+ private:
+ std::unique_ptr<Tokenizer> owned_tokenizer_;
+ const Tokenizer* tokenizer_;
+ Parser parser_;
+ SemanticComposer semantic_evaluator_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_ANALYZER_H_
diff --git a/native/utils/grammar/analyzer_test.cc b/native/utils/grammar/analyzer_test.cc
new file mode 100644
index 0000000..9f71efe
--- /dev/null
+++ b/native/utils/grammar/analyzer_test.cc
@@ -0,0 +1,98 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/analyzer.h"
+
+#include "utils/grammar/testing/utils.h"
+#include "utils/grammar/types.h"
+#include "utils/grammar/utils/rules.h"
+#include "utils/utf8/unicodetext.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+using ::testing::SizeIs;
+
+class AnalyzerTest : public GrammarTest {};
+
+TEST_F(AnalyzerTest, ParsesTextWithGrammar) {
+ RulesSetT model;
+
+ // Add semantic values schema.
+ model.semantic_values_schema.assign(semantic_values_schema_.buffer().begin(),
+ semantic_values_schema_.buffer().end());
+
+ // Define rules and semantics.
+ Rules rules;
+ rules.Add("<month>", {"january"},
+ static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
+ /*callback_param=*/model.semantic_expression.size());
+ model.semantic_expression.push_back(CreatePrimitiveConstExpression(1));
+
+ rules.Add("<month>", {"february"},
+ static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
+ /*callback_param=*/model.semantic_expression.size());
+ model.semantic_expression.push_back(CreatePrimitiveConstExpression(2));
+
+ const int kMonth = 0;
+ rules.Add("<month_rule>", {"<month>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kMonth);
+ rules.Finalize().Serialize(/*include_debug_information=*/false, &model);
+ const std::string model_buffer = PackFlatbuffer<RulesSet>(&model);
+
+ Analyzer analyzer(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(model_buffer.data()));
+
+ {
+ auto maybe_results = analyzer.Parse(
+ UTF8ToUnicodeText("The month is January 2020", /*do_copy=*/false),
+ /*locales=*/{}, &arena_);
+ EXPECT_TRUE(maybe_results.ok());
+
+ const std::vector<EvaluatedDerivation> results = maybe_results.ValueOrDie();
+ EXPECT_THAT(results, SizeIs(1));
+
+ // Check parse tree.
+ EXPECT_THAT(
+ results[0].derivation,
+ IsDerivation(kMonth /* rule_id */, 13 /* begin */, 20 /* end */));
+
+ // Check semantic result.
+ EXPECT_EQ(results[0].value->Value<int32>(), 1);
+ }
+
+ {
+ auto maybe_results =
+ analyzer.Parse(UTF8ToUnicodeText("february", /*do_copy=*/false),
+ /*locales=*/{}, &arena_);
+ EXPECT_TRUE(maybe_results.ok());
+
+ const std::vector<EvaluatedDerivation> results = maybe_results.ValueOrDie();
+ EXPECT_THAT(results, SizeIs(1));
+
+ // Check parse tree.
+ EXPECT_THAT(results[0].derivation,
+ IsDerivation(kMonth /* rule_id */, 0 /* begin */, 8 /* end */));
+
+ // Check semantic result.
+ EXPECT_EQ(results[0].value->Value<int32>(), 2);
+ }
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/callback-delegate.h b/native/utils/grammar/callback-delegate.h
deleted file mode 100644
index a5424dd..0000000
--- a/native/utils/grammar/callback-delegate.h
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_CALLBACK_DELEGATE_H_
-#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_CALLBACK_DELEGATE_H_
-
-#include "utils/base/integral_types.h"
-#include "utils/grammar/match.h"
-#include "utils/grammar/rules_generated.h"
-#include "utils/grammar/types.h"
-
-namespace libtextclassifier3::grammar {
-
-class Matcher;
-
-// CallbackDelegate is an interface and default implementation used by the
-// grammar matcher to dispatch rule matches.
-class CallbackDelegate {
- public:
- virtual ~CallbackDelegate() = default;
-
- // This is called by the matcher whenever it finds a match for a rule to
- // which a callback is attached.
- virtual void MatchFound(const Match* match, const CallbackId callback_id,
- const int64 callback_param, Matcher* matcher) {}
-};
-
-} // namespace libtextclassifier3::grammar
-
-#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_CALLBACK_DELEGATE_H_
diff --git a/native/utils/grammar/evaluated-derivation.h b/native/utils/grammar/evaluated-derivation.h
new file mode 100644
index 0000000..bac252a
--- /dev/null
+++ b/native/utils/grammar/evaluated-derivation.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_EVALUATED_DERIVATION_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_EVALUATED_DERIVATION_H_
+
+#include "utils/grammar/parsing/derivation.h"
+#include "utils/grammar/semantics/value.h"
+
+namespace libtextclassifier3::grammar {
+
+// A parse tree for a root rule and its semantic value.
+struct EvaluatedDerivation {
+ Derivation derivation;
+ const SemanticValue* value;
+};
+
+}; // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_EVALUATED_DERIVATION_H_
diff --git a/native/utils/grammar/lexer.cc b/native/utils/grammar/lexer.cc
deleted file mode 100644
index 3a2d0d3..0000000
--- a/native/utils/grammar/lexer.cc
+++ /dev/null
@@ -1,321 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/grammar/lexer.h"
-
-#include <unordered_map>
-
-#include "annotator/types.h"
-#include "utils/zlib/zlib.h"
-#include "utils/zlib/zlib_regex.h"
-
-namespace libtextclassifier3::grammar {
-namespace {
-
-inline bool CheckMemoryUsage(const Matcher* matcher) {
- // The maximum memory usage for matching.
- constexpr int kMaxMemoryUsage = 1 << 20;
- return matcher->ArenaSize() <= kMaxMemoryUsage;
-}
-
-Match* CheckedAddMatch(const Nonterm nonterm,
- const CodepointSpan codepoint_span,
- const int match_offset, const int16 type,
- Matcher* matcher) {
- if (nonterm == kUnassignedNonterm || !CheckMemoryUsage(matcher)) {
- return nullptr;
- }
- return matcher->AllocateAndInitMatch<Match>(nonterm, codepoint_span,
- match_offset, type);
-}
-
-void CheckedEmit(const Nonterm nonterm, const CodepointSpan codepoint_span,
- const int match_offset, int16 type, Matcher* matcher) {
- if (nonterm != kUnassignedNonterm && CheckMemoryUsage(matcher)) {
- matcher->AddMatch(matcher->AllocateAndInitMatch<Match>(
- nonterm, codepoint_span, match_offset, type));
- }
-}
-
-int MapCodepointToTokenPaddingIfPresent(
- const std::unordered_map<CodepointIndex, CodepointIndex>& token_alignment,
- const int start) {
- const auto it = token_alignment.find(start);
- if (it != token_alignment.end()) {
- return it->second;
- }
- return start;
-}
-
-} // namespace
-
-Lexer::Lexer(const UniLib* unilib, const RulesSet* rules)
- : unilib_(*unilib),
- rules_(rules),
- regex_annotators_(BuildRegexAnnotator(unilib_, rules)) {}
-
-std::vector<Lexer::RegexAnnotator> Lexer::BuildRegexAnnotator(
- const UniLib& unilib, const RulesSet* rules) const {
- std::vector<Lexer::RegexAnnotator> result;
- if (rules->regex_annotator() != nullptr) {
- std::unique_ptr<ZlibDecompressor> decompressor =
- ZlibDecompressor::Instance();
- result.reserve(rules->regex_annotator()->size());
- for (const RulesSet_::RegexAnnotator* regex_annotator :
- *rules->regex_annotator()) {
- result.push_back(
- {UncompressMakeRegexPattern(unilib_, regex_annotator->pattern(),
- regex_annotator->compressed_pattern(),
- rules->lazy_regex_compilation(),
- decompressor.get()),
- regex_annotator->nonterminal()});
- }
- }
- return result;
-}
-
-void Lexer::Emit(const Symbol& symbol, const RulesSet_::Nonterminals* nonterms,
- Matcher* matcher) const {
- switch (symbol.type) {
- case Symbol::Type::TYPE_MATCH: {
- // Just emit the match.
- matcher->AddMatch(symbol.match);
- return;
- }
- case Symbol::Type::TYPE_DIGITS: {
- // Emit <digits> if used by the rules.
- CheckedEmit(nonterms->digits_nt(), symbol.codepoint_span,
- symbol.match_offset, Match::kDigitsType, matcher);
-
- // Emit <n_digits> if used by the rules.
- if (nonterms->n_digits_nt() != nullptr) {
- const int num_digits =
- symbol.codepoint_span.second - symbol.codepoint_span.first;
- if (num_digits <= nonterms->n_digits_nt()->size()) {
- CheckedEmit(nonterms->n_digits_nt()->Get(num_digits - 1),
- symbol.codepoint_span, symbol.match_offset,
- Match::kDigitsType, matcher);
- }
- }
- break;
- }
- case Symbol::Type::TYPE_TERM: {
- // Emit <uppercase_token> if used by the rules.
- if (nonterms->uppercase_token_nt() != 0 &&
- unilib_.IsUpperText(
- UTF8ToUnicodeText(symbol.lexeme, /*do_copy=*/false))) {
- CheckedEmit(nonterms->uppercase_token_nt(), symbol.codepoint_span,
- symbol.match_offset, Match::kTokenType, matcher);
- }
- break;
- }
- default:
- break;
- }
-
- // Emit the token as terminal.
- if (CheckMemoryUsage(matcher)) {
- matcher->AddTerminal(symbol.codepoint_span, symbol.match_offset,
- symbol.lexeme);
- }
-
- // Emit <token> if used by rules.
- CheckedEmit(nonterms->token_nt(), symbol.codepoint_span, symbol.match_offset,
- Match::kTokenType, matcher);
-}
-
-Lexer::Symbol::Type Lexer::GetSymbolType(
- const UnicodeText::const_iterator& it) const {
- if (unilib_.IsPunctuation(*it)) {
- return Symbol::Type::TYPE_PUNCTUATION;
- } else if (unilib_.IsDigit(*it)) {
- return Symbol::Type::TYPE_DIGITS;
- } else {
- return Symbol::Type::TYPE_TERM;
- }
-}
-
-void Lexer::ProcessToken(const StringPiece value, const int prev_token_end,
- const CodepointSpan codepoint_span,
- std::vector<Lexer::Symbol>* symbols) const {
- // Possibly split token.
- UnicodeText token_unicode = UTF8ToUnicodeText(value.data(), value.size(),
- /*do_copy=*/false);
- int last_end = prev_token_end;
- auto token_end = token_unicode.end();
- auto it = token_unicode.begin();
- Symbol::Type type = GetSymbolType(it);
- CodepointIndex sub_token_start = codepoint_span.first;
- while (it != token_end) {
- auto next = std::next(it);
- int num_codepoints = 1;
- Symbol::Type next_type;
- while (next != token_end) {
- next_type = GetSymbolType(next);
- if (type == Symbol::Type::TYPE_PUNCTUATION || next_type != type) {
- break;
- }
- ++next;
- ++num_codepoints;
- }
- symbols->push_back(Symbol{
- type, CodepointSpan{sub_token_start, sub_token_start + num_codepoints},
- /*match_offset=*/last_end,
- /*lexeme=*/
- StringPiece(it.utf8_data(), next.utf8_data() - it.utf8_data())});
- last_end = sub_token_start + num_codepoints;
- it = next;
- type = next_type;
- sub_token_start = last_end;
- }
-}
-
-void Lexer::Process(const UnicodeText& text, const std::vector<Token>& tokens,
- const std::vector<AnnotatedSpan>* annotations,
- Matcher* matcher) const {
- return Process(text, tokens.begin(), tokens.end(), annotations, matcher);
-}
-
-void Lexer::Process(const UnicodeText& text,
- const std::vector<Token>::const_iterator& begin,
- const std::vector<Token>::const_iterator& end,
- const std::vector<AnnotatedSpan>* annotations,
- Matcher* matcher) const {
- if (begin == end) {
- return;
- }
-
- const RulesSet_::Nonterminals* nonterminals = rules_->nonterminals();
-
- // Initialize processing of new text.
- CodepointIndex prev_token_end = 0;
- std::vector<Symbol> symbols;
- matcher->Reset();
-
- // The matcher expects the terminals and non-terminals it received to be in
- // non-decreasing end-position order. The sorting above makes sure the
- // pre-defined matches adhere to that order.
- // Ideally, we would just have to emit a predefined match whenever we see that
- // the next token we feed would be ending later.
- // But as we implicitly ignore whitespace, we have to merge preceding
- // whitespace to the match start so that tokens and non-terminals fed appear
- // as next to each other without whitespace.
- // We keep track of real token starts and precending whitespace in
- // `token_match_start`, so that we can extend a predefined match's start to
- // include the preceding whitespace.
- std::unordered_map<CodepointIndex, CodepointIndex> token_match_start;
-
- // Add start symbols.
- if (Match* match =
- CheckedAddMatch(nonterminals->start_nt(), CodepointSpan{0, 0},
- /*match_offset=*/0, Match::kBreakType, matcher)) {
- symbols.push_back(Symbol(match));
- }
- if (Match* match =
- CheckedAddMatch(nonterminals->wordbreak_nt(), CodepointSpan{0, 0},
- /*match_offset=*/0, Match::kBreakType, matcher)) {
- symbols.push_back(Symbol(match));
- }
-
- for (auto token_it = begin; token_it != end; token_it++) {
- const Token& token = *token_it;
-
- // Record match starts for token boundaries, so that we can snap pre-defined
- // matches to it.
- if (prev_token_end != token.start) {
- token_match_start[token.start] = prev_token_end;
- }
-
- ProcessToken(token.value,
- /*prev_token_end=*/prev_token_end,
- CodepointSpan{token.start, token.end}, &symbols);
- prev_token_end = token.end;
-
- // Add word break symbol if used by the grammar.
- if (Match* match = CheckedAddMatch(
- nonterminals->wordbreak_nt(), CodepointSpan{token.end, token.end},
- /*match_offset=*/token.end, Match::kBreakType, matcher)) {
- symbols.push_back(Symbol(match));
- }
- }
-
- // Add end symbol if used by the grammar.
- if (Match* match = CheckedAddMatch(
- nonterminals->end_nt(), CodepointSpan{prev_token_end, prev_token_end},
- /*match_offset=*/prev_token_end, Match::kBreakType, matcher)) {
- symbols.push_back(Symbol(match));
- }
-
- // Add matches based on annotations.
- auto annotation_nonterminals = nonterminals->annotation_nt();
- if (annotation_nonterminals != nullptr && annotations != nullptr) {
- for (const AnnotatedSpan& annotated_span : *annotations) {
- const ClassificationResult& classification =
- annotated_span.classification.front();
- if (auto entry = annotation_nonterminals->LookupByKey(
- classification.collection.c_str())) {
- AnnotationMatch* match = matcher->AllocateAndInitMatch<AnnotationMatch>(
- entry->value(), annotated_span.span,
- /*match_offset=*/
- MapCodepointToTokenPaddingIfPresent(token_match_start,
- annotated_span.span.first),
- Match::kAnnotationMatch);
- match->annotation = &classification;
- symbols.push_back(Symbol(match));
- }
- }
- }
-
- // Add regex annotator matches for the range covered by the tokens.
- for (const RegexAnnotator& regex_annotator : regex_annotators_) {
- std::unique_ptr<UniLib::RegexMatcher> regex_matcher =
- regex_annotator.pattern->Matcher(UnicodeText::Substring(
- text, begin->start, prev_token_end, /*do_copy=*/false));
- int status = UniLib::RegexMatcher::kNoError;
- while (regex_matcher->Find(&status) &&
- status == UniLib::RegexMatcher::kNoError) {
- const CodepointSpan span = {
- regex_matcher->Start(0, &status) + begin->start,
- regex_matcher->End(0, &status) + begin->start};
- if (Match* match =
- CheckedAddMatch(regex_annotator.nonterm, span, /*match_offset=*/
- MapCodepointToTokenPaddingIfPresent(
- token_match_start, span.first),
- Match::kUnknownType, matcher)) {
- symbols.push_back(Symbol(match));
- }
- }
- }
-
- std::sort(symbols.begin(), symbols.end(),
- [](const Symbol& a, const Symbol& b) {
- // Sort by increasing (end, start) position to guarantee the
- // matcher requirement that the tokens are fed in non-decreasing
- // end position order.
- return std::tie(a.codepoint_span.second, a.codepoint_span.first) <
- std::tie(b.codepoint_span.second, b.codepoint_span.first);
- });
-
- // Emit symbols to matcher.
- for (const Symbol& symbol : symbols) {
- Emit(symbol, nonterminals, matcher);
- }
-
- // Finish the matching.
- matcher->Finish();
-}
-
-} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/lexer.h b/native/utils/grammar/lexer.h
deleted file mode 100644
index ca31c25..0000000
--- a/native/utils/grammar/lexer.h
+++ /dev/null
@@ -1,176 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// This is a lexer that runs off the tokenizer and outputs the tokens to a
-// grammar matcher. The tokens it forwards are the same as the ones produced
-// by the tokenizer, but possibly further split and normalized (downcased).
-// Examples:
-//
-// - single character tokens for punctuation (e.g., AddTerminal("?"))
-//
-// - a string of letters (e.g., "Foo" -- it calls AddTerminal() on "foo")
-//
-// - a string of digits (e.g., AddTerminal("37"))
-//
-// In addition to the terminal tokens above, it also outputs certain
-// special nonterminals:
-//
-// - a <token> nonterminal, which it outputs in addition to the
-// regular AddTerminal() call for every token
-//
-// - a <digits> nonterminal, which it outputs in addition to
-// the regular AddTerminal() call for each string of digits
-//
-// - <N_digits> nonterminals, where N is the length of the string of
-// digits. By default the maximum N that will be output is 20. This
-// may be changed at compile time by kMaxNDigitsLength. For instance,
-// "123" will produce a <3_digits> nonterminal, "1234567" will produce
-// a <7_digits> nonterminal.
-//
-// It does not output any whitespace. Instead, whitespace gets absorbed into
-// the token that follows them in the text.
-// For example, if the text contains:
-//
-// ...hello there world...
-// | | |
-// offset=16 39 52
-//
-// then the output will be:
-//
-// "hello" [?, 16)
-// "there" [16, 44) <-- note "16" NOT "39"
-// "world" [44, ?) <-- note "44" NOT "52"
-//
-// This makes it appear to the Matcher as if the tokens are adjacent -- so
-// whitespace is simply ignored.
-//
-// A minor optimization: We don't bother to output nonterminals if the grammar
-// rules don't reference them.
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_LEXER_H_
-#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_LEXER_H_
-
-#include "annotator/types.h"
-#include "utils/grammar/matcher.h"
-#include "utils/grammar/rules_generated.h"
-#include "utils/grammar/types.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/utf8/unicodetext.h"
-#include "utils/utf8/unilib.h"
-
-namespace libtextclassifier3::grammar {
-
-class Lexer {
- public:
- explicit Lexer(const UniLib* unilib, const RulesSet* rules);
-
- // Processes a tokenized text. Classifies the tokens and feeds them to the
- // matcher.
- // The provided annotations will be fed to the matcher alongside the tokens.
- // NOTE: The `annotations` need to outlive any dependent processing.
- void Process(const UnicodeText& text, const std::vector<Token>& tokens,
- const std::vector<AnnotatedSpan>* annotations,
- Matcher* matcher) const;
- void Process(const UnicodeText& text,
- const std::vector<Token>::const_iterator& begin,
- const std::vector<Token>::const_iterator& end,
- const std::vector<AnnotatedSpan>* annotations,
- Matcher* matcher) const;
-
- private:
- // A lexical symbol with an identified meaning that represents raw tokens,
- // token categories or predefined text matches.
- // It is the unit fed to the grammar matcher.
- struct Symbol {
- // The type of the lexical symbol.
- enum class Type {
- // A raw token.
- TYPE_TERM,
-
- // A symbol representing a string of digits.
- TYPE_DIGITS,
-
- // Punctuation characters.
- TYPE_PUNCTUATION,
-
- // A predefined match.
- TYPE_MATCH
- };
-
- explicit Symbol() = default;
-
- // Constructs a symbol of a given type with an anchor in the text.
- Symbol(const Type type, const CodepointSpan codepoint_span,
- const int match_offset, StringPiece lexeme)
- : type(type),
- codepoint_span(codepoint_span),
- match_offset(match_offset),
- lexeme(lexeme) {}
-
- // Constructs a symbol from a pre-defined match.
- explicit Symbol(Match* match)
- : type(Type::TYPE_MATCH),
- codepoint_span(match->codepoint_span),
- match_offset(match->match_offset),
- match(match) {}
-
- // The type of the symbole.
- Type type;
-
- // The span in the text as codepoint offsets.
- CodepointSpan codepoint_span;
-
- // The match start offset (including preceding whitespace) as codepoint
- // offset.
- int match_offset;
-
- // The symbol text value.
- StringPiece lexeme;
-
- // The predefined match.
- Match* match;
- };
-
- // Processes a single token: the token is split and classified into symbols.
- void ProcessToken(const StringPiece value, const int prev_token_end,
- const CodepointSpan codepoint_span,
- std::vector<Symbol>* symbols) const;
-
- // Emits a token to the matcher.
- void Emit(const Symbol& symbol, const RulesSet_::Nonterminals* nonterms,
- Matcher* matcher) const;
-
- // Gets the type of a character.
- Symbol::Type GetSymbolType(const UnicodeText::const_iterator& it) const;
-
- private:
- struct RegexAnnotator {
- std::unique_ptr<UniLib::RegexPattern> pattern;
- Nonterm nonterm;
- };
-
- // Uncompress and build the defined regex annotators.
- std::vector<RegexAnnotator> BuildRegexAnnotator(const UniLib& unilib,
- const RulesSet* rules) const;
-
- const UniLib& unilib_;
- const RulesSet* rules_;
- std::vector<RegexAnnotator> regex_annotators_;
-};
-
-} // namespace libtextclassifier3::grammar
-
-#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_LEXER_H_
diff --git a/native/utils/grammar/match.cc b/native/utils/grammar/match.cc
deleted file mode 100644
index ecf9874..0000000
--- a/native/utils/grammar/match.cc
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/grammar/match.h"
-
-#include <algorithm>
-#include <stack>
-
-namespace libtextclassifier3::grammar {
-
-void Traverse(const Match* root,
- const std::function<bool(const Match*)>& node_fn) {
- std::stack<const Match*> open;
- open.push(root);
-
- while (!open.empty()) {
- const Match* node = open.top();
- open.pop();
- if (!node_fn(node) || node->IsLeaf()) {
- continue;
- }
- open.push(node->rhs2);
- if (node->rhs1 != nullptr) {
- open.push(node->rhs1);
- }
- }
-}
-
-const Match* SelectFirst(const Match* root,
- const std::function<bool(const Match*)>& pred_fn) {
- std::stack<const Match*> open;
- open.push(root);
-
- while (!open.empty()) {
- const Match* node = open.top();
- open.pop();
- if (pred_fn(node)) {
- return node;
- }
- if (node->IsLeaf()) {
- continue;
- }
- open.push(node->rhs2);
- if (node->rhs1 != nullptr) {
- open.push(node->rhs1);
- }
- }
-
- return nullptr;
-}
-
-std::vector<const Match*> SelectAll(
- const Match* root, const std::function<bool(const Match*)>& pred_fn) {
- std::vector<const Match*> result;
- Traverse(root, [&result, pred_fn](const Match* node) {
- if (pred_fn(node)) {
- result.push_back(node);
- }
- return true;
- });
- return result;
-}
-
-} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/match.h b/native/utils/grammar/match.h
deleted file mode 100644
index f96703d..0000000
--- a/native/utils/grammar/match.h
+++ /dev/null
@@ -1,172 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCH_H_
-#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCH_H_
-
-#include <functional>
-#include <vector>
-
-#include "annotator/types.h"
-#include "utils/grammar/types.h"
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3::grammar {
-
-// Represents a single match that was found for a particular nonterminal.
-// Instances should be created by calling Matcher::AllocateMatch().
-// This uses an arena to allocate matches (and subclasses thereof).
-struct Match {
- static constexpr int16 kUnknownType = 0;
- static constexpr int16 kTokenType = -1;
- static constexpr int16 kDigitsType = -2;
- static constexpr int16 kBreakType = -3;
- static constexpr int16 kAssertionMatch = -4;
- static constexpr int16 kMappingMatch = -5;
- static constexpr int16 kExclusionMatch = -6;
- static constexpr int16 kAnnotationMatch = -7;
-
- void Init(const Nonterm arg_lhs, const CodepointSpan arg_codepoint_span,
- const int arg_match_offset, const int arg_type = kUnknownType) {
- lhs = arg_lhs;
- codepoint_span = arg_codepoint_span;
- match_offset = arg_match_offset;
- type = arg_type;
- rhs1 = nullptr;
- rhs2 = nullptr;
- }
-
- void Init(const Match& other) { *this = other; }
-
- // For binary rule matches: rhs1 != NULL and rhs2 != NULL
- // unary rule matches: rhs1 == NULL and rhs2 != NULL
- // terminal rule matches: rhs1 != NULL and rhs2 == NULL
- // custom leaves: rhs1 == NULL and rhs2 == NULL
- bool IsInteriorNode() const { return rhs2 != nullptr; }
- bool IsLeaf() const { return !rhs2; }
-
- bool IsBinaryRule() const { return rhs1 && rhs2; }
- bool IsUnaryRule() const { return !rhs1 && rhs2; }
- bool IsTerminalRule() const { return rhs1 && !rhs2; }
- bool HasLeadingWhitespace() const {
- return codepoint_span.first != match_offset;
- }
-
- const Match* unary_rule_rhs() const { return rhs2; }
-
- // Used in singly-linked queue of matches for processing.
- Match* next = nullptr;
-
- // Nonterminal we found a match for.
- Nonterm lhs = kUnassignedNonterm;
-
- // Type of the match.
- int16 type = kUnknownType;
-
- // The span in codepoints.
- CodepointSpan codepoint_span = CodepointSpan::kInvalid;
-
- // The begin codepoint offset used during matching.
- // This is usually including any prefix whitespace.
- int match_offset;
-
- union {
- // The first sub match for binary rules.
- const Match* rhs1 = nullptr;
-
- // The terminal, for terminal rules.
- const char* terminal;
- };
- // First or second sub-match for interior nodes.
- const Match* rhs2 = nullptr;
-};
-
-// Match type to keep track of associated values.
-struct MappingMatch : public Match {
- // The associated id or value.
- int64 id;
-};
-
-// Match type to keep track of assertions.
-struct AssertionMatch : public Match {
- // If true, the assertion is negative and will be valid if the input doesn't
- // match.
- bool negative;
-};
-
-// Match type to define exclusions.
-struct ExclusionMatch : public Match {
- // The nonterminal that denotes matches to exclude from a successful match.
- // So the match is only valid if there is no match of `exclusion_nonterm`
- // spanning the same text range.
- Nonterm exclusion_nonterm;
-};
-
-// Match to represent an annotator annotated span in the grammar.
-struct AnnotationMatch : public Match {
- const ClassificationResult* annotation;
-};
-
-// Utility functions for parse tree traversal.
-
-// Does a preorder traversal, calling `node_fn` on each node.
-// `node_fn` is expected to return whether to continue expanding a node.
-void Traverse(const Match* root,
- const std::function<bool(const Match*)>& node_fn);
-
-// Does a preorder traversal, calling `pred_fn` and returns the first node
-// on which `pred_fn` returns true.
-const Match* SelectFirst(const Match* root,
- const std::function<bool(const Match*)>& pred_fn);
-
-// Does a preorder traversal, selecting all nodes where `pred_fn` returns true.
-std::vector<const Match*> SelectAll(
- const Match* root, const std::function<bool(const Match*)>& pred_fn);
-
-// Selects all terminals from a parse tree.
-inline std::vector<const Match*> SelectTerminals(const Match* root) {
- return SelectAll(root, &Match::IsTerminalRule);
-}
-
-// Selects all leaves from a parse tree.
-inline std::vector<const Match*> SelectLeaves(const Match* root) {
- return SelectAll(root, &Match::IsLeaf);
-}
-
-// Retrieves the first child node of a given type.
-template <typename T>
-const T* SelectFirstOfType(const Match* root, const int16 type) {
- return static_cast<const T*>(SelectFirst(
- root, [type](const Match* node) { return node->type == type; }));
-}
-
-// Retrieves all nodes of a given type.
-template <typename T>
-const std::vector<const T*> SelectAllOfType(const Match* root,
- const int16 type) {
- std::vector<const T*> result;
- Traverse(root, [&result, type](const Match* node) {
- if (node->type == type) {
- result.push_back(static_cast<const T*>(node));
- }
- return true;
- });
- return result;
-}
-
-} // namespace libtextclassifier3::grammar
-
-#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCH_H_
diff --git a/native/utils/grammar/matcher.h b/native/utils/grammar/matcher.h
deleted file mode 100644
index 47bac43..0000000
--- a/native/utils/grammar/matcher.h
+++ /dev/null
@@ -1,246 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// A token matcher based on context-free grammars.
-//
-// A lexer passes token to the matcher: literal terminal strings and token
-// types. It passes tokens to the matcher by calling AddTerminal() and
-// AddMatch() for literal terminals and token types, respectively.
-// The lexer passes each token along with the [begin, end) position range
-// in which it occurs. So for an input string "Groundhog February 2, 2007", the
-// lexer would tell the matcher that:
-//
-// "Groundhog" occurs at [0, 9)
-// <space> occurs at [9, 10)
-// "February" occurs at [10, 18)
-// <space> occurs at [18, 19)
-// <string_of_digits> occurs at [19, 20)
-// "," occurs at [20, 21)
-// <space> occurs at [21, 22)
-// <string_of_digits> occurs at [22, 26)
-//
-// The lexer passes tokens to the matcher by calling AddTerminal() and
-// AddMatch() for literal terminals and token types, respectively.
-//
-// Although it is unnecessary for this example grammar, a lexer can
-// output multiple tokens for the same input range. So our lexer could
-// additionally output:
-// "2" occurs at [19, 20) // a second token for [19, 20)
-// "2007" occurs at [22, 26)
-// <syllable> occurs at [0, 6) // overlaps with (Groundhog [0, 9))
-// <syllable> occurs at [6, 9)
-// The only constraint on the lexer's output is that it has to pass tokens
-// to the matcher in left-to-right order, strictly speaking, their "end"
-// positions must be nondecreasing. (This constraint allows a more
-// efficient matching algorithm.) The "begin" positions can be in any
-// order.
-//
-// There are two kinds of supported callbacks:
-// (1) OUTPUT: Callbacks are the only output mechanism a matcher has. For each
-// "top-level" rule in your grammar, like the rule for <date> above -- something
-// you're trying to find instances of -- you use a callback which the matcher
-// will invoke every time it finds an instance of <date>.
-// (2) FILTERS:
-// Callbacks allow you to put extra conditions on when a grammar rule
-// applies. In the example grammar, the rule
-//
-// <day> ::= <string_of_digits> // must be between 1 and 31
-//
-// should only apply for *some* <string_of_digits> tokens, not others.
-// By using a filter callback on this rule, you can tell the matcher that
-// an instance of the rule's RHS is only *sometimes* considered an
-// instance of its LHS. The filter callback will get invoked whenever
-// the matcher finds an instance of <string_of_digits>. The callback can
-// look at the digits and decide whether they represent a number between
-// 1 and 31. If so, the callback calls Matcher::AddMatch() to tell the
-// matcher there's a <day> there. If not, the callback simply exits
-// without calling AddMatch().
-//
-// Technically, a FILTER callback can make any number of calls to
-// AddMatch() or even AddTerminal(). But the expected usage is to just
-// make zero or one call to AddMatch(). OUTPUT callbacks are not expected
-// to call either of these -- output callbacks are invoked merely as a
-// side-effect, not in order to decide whether a rule applies or not.
-//
-// In the above example, you would probably use three callbacks. Filter
-// callbacks on the rules for <day> and <year> would check the numeric
-// value of the <string_of_digits>. An output callback on the rule for
-// <date> would simply increment the counter of dates found on the page.
-//
-// Note that callbacks are attached to rules, not to nonterminals. You
-// could have two alternative rules for <date> and use a different
-// callback for each one.
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCHER_H_
-#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCHER_H_
-
-#include <array>
-#include <functional>
-#include <vector>
-
-#include "annotator/types.h"
-#include "utils/base/arena.h"
-#include "utils/grammar/callback-delegate.h"
-#include "utils/grammar/match.h"
-#include "utils/grammar/rules_generated.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/utf8/unilib.h"
-
-namespace libtextclassifier3::grammar {
-
-class Matcher {
- public:
- explicit Matcher(const UniLib* unilib, const RulesSet* rules,
- const std::vector<const RulesSet_::Rules*> rules_shards,
- CallbackDelegate* delegate)
- : state_(STATE_DEFAULT),
- unilib_(*unilib),
- arena_(kBlocksize),
- rules_(rules),
- rules_shards_(rules_shards),
- delegate_(delegate) {
- TC3_CHECK(rules_ != nullptr);
- Reset();
- }
- explicit Matcher(const UniLib* unilib, const RulesSet* rules,
- CallbackDelegate* delegate)
- : Matcher(unilib, rules, {}, delegate) {
- rules_shards_.reserve(rules->rules()->size());
- rules_shards_.insert(rules_shards_.end(), rules->rules()->begin(),
- rules->rules()->end());
- }
-
- // Resets the matcher.
- void Reset();
-
- // Finish the matching.
- void Finish();
-
- // Tells the matcher that the given terminal was found occupying position
- // range [begin, end) in the input.
- // The matcher may invoke callback functions before returning, if this
- // terminal triggers any new matches for rules in the grammar.
- // Calls to AddTerminal() and AddMatch() must be in left-to-right order,
- // that is, the sequence of `end` values must be non-decreasing.
- void AddTerminal(const CodepointSpan codepoint_span, const int match_offset,
- StringPiece terminal);
- void AddTerminal(const CodepointIndex begin, const CodepointIndex end,
- StringPiece terminal) {
- AddTerminal(CodepointSpan{begin, end}, begin, terminal);
- }
-
- // Adds a nonterminal match to the chart.
- // This can be invoked by the lexer if the lexer needs to add nonterminals to
- // the chart.
- void AddMatch(Match* match);
-
- // Allocates memory from an area for a new match.
- // The `size` parameter is there to allow subclassing of the match object
- // with additional fields.
- Match* AllocateMatch(const size_t size) {
- return reinterpret_cast<Match*>(arena_.Alloc(size));
- }
-
- template <typename T>
- T* AllocateMatch() {
- return reinterpret_cast<T*>(arena_.Alloc(sizeof(T)));
- }
-
- template <typename T, typename... Args>
- T* AllocateAndInitMatch(Args... args) {
- T* match = AllocateMatch<T>();
- match->Init(args...);
- return match;
- }
-
- // Returns the current number of bytes allocated for all match objects.
- size_t ArenaSize() const { return arena_.status().bytes_allocated(); }
-
- private:
- static constexpr int kBlocksize = 16 << 10;
-
- // The state of the matcher.
- enum State {
- // The matcher is in the default state.
- STATE_DEFAULT = 0,
-
- // The matcher is currently processing queued match items.
- STATE_PROCESSING = 1,
- };
- State state_;
-
- // Process matches from lhs set.
- void ExecuteLhsSet(const CodepointSpan codepoint_span, const int match_offset,
- const int whitespace_gap,
- const std::function<void(Match*)>& initializer,
- const RulesSet_::LhsSet* lhs_set,
- CallbackDelegate* delegate);
-
- // Queues a newly created match item.
- void QueueForProcessing(Match* item);
-
- // Queues a match item for later post checking of the exclusion condition.
- // For exclusions we need to check that the `item->excluded_nonterminal`
- // doesn't match the same span. As we cannot know which matches have already
- // been added, we queue the item for later post checking - once all matches
- // up to `item->codepoint_span.second` have been added.
- void QueueForPostCheck(ExclusionMatch* item);
-
- // Adds pending items to the chart, possibly generating new matches as a
- // result.
- void ProcessPendingSet();
-
- // Returns whether the chart contains a match for a given nonterminal.
- bool ContainsMatch(const Nonterm nonterm, const CodepointSpan& span) const;
-
- // Checks all pending exclusion matches that their exclusion condition is
- // fulfilled.
- void ProcessPendingExclusionMatches();
-
- UniLib unilib_;
-
- // Memory arena for match allocation.
- UnsafeArena arena_;
-
- // The end position of the most recent match or terminal, for sanity
- // checking.
- int last_end_;
-
- // Rules.
- const RulesSet* rules_;
-
- // The set of items pending to be added to the chart as a singly-linked list.
- Match* pending_items_;
-
- // The set of items pending to be post-checked as a singly-linked list.
- ExclusionMatch* pending_exclusion_items_;
-
- // The chart data structure: a hashtable containing all matches, indexed by
- // their end positions.
- static constexpr int kChartHashTableNumBuckets = 1 << 8;
- static constexpr int kChartHashTableBitmask = kChartHashTableNumBuckets - 1;
- std::array<Match*, kChartHashTableNumBuckets> chart_;
-
- // The active rule shards.
- std::vector<const RulesSet_::Rules*> rules_shards_;
-
- // The callback handler.
- CallbackDelegate* delegate_;
-};
-
-} // namespace libtextclassifier3::grammar
-
-#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCHER_H_
diff --git a/native/utils/grammar/parsing/chart.h b/native/utils/grammar/parsing/chart.h
new file mode 100644
index 0000000..4ec05d7
--- /dev/null
+++ b/native/utils/grammar/parsing/chart.h
@@ -0,0 +1,108 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_CHART_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_CHART_H_
+
+#include <array>
+
+#include "annotator/types.h"
+#include "utils/grammar/parsing/derivation.h"
+#include "utils/grammar/parsing/parse-tree.h"
+
+namespace libtextclassifier3::grammar {
+
+// Chart is a hashtable container for use with a CYK style parser.
+// The hashtable contains all matches, indexed by their end positions.
+template <int NumBuckets = 1 << 8>
+class Chart {
+ public:
+ explicit Chart() { std::fill(chart_.begin(), chart_.end(), nullptr); }
+
+ // Iterator that allows iterating through recorded matches that end at a given
+ // match offset.
+ class Iterator {
+ public:
+ explicit Iterator(const int match_offset, const ParseTree* value)
+ : match_offset_(match_offset), value_(value) {}
+
+ bool Done() const {
+ return value_ == nullptr ||
+ (value_->codepoint_span.second < match_offset_);
+ }
+ const ParseTree* Item() const { return value_; }
+ void Next() {
+ TC3_DCHECK(!Done());
+ value_ = value_->next;
+ }
+
+ private:
+ const int match_offset_;
+ const ParseTree* value_;
+ };
+
+ // Returns whether the chart contains a match for a given nonterminal.
+ bool HasMatch(const Nonterm nonterm, const CodepointSpan& span) const;
+
+ // Adds a match to the chart.
+ void Add(ParseTree* item) {
+ item->next = chart_[item->codepoint_span.second & kChartHashTableBitmask];
+ chart_[item->codepoint_span.second & kChartHashTableBitmask] = item;
+ }
+
+ // Records a derivation of a root rule.
+ void AddDerivation(const Derivation& derivation) {
+ root_derivations_.push_back(derivation);
+ }
+
+ // Returns an iterator through all matches ending at `match_offset`.
+ Iterator MatchesEndingAt(const int match_offset) const {
+ const ParseTree* value = chart_[match_offset & kChartHashTableBitmask];
+ // The chain of items is in decreasing `end` order.
+ // Find the ones that have prev->end == item->begin.
+ while (value != nullptr && (value->codepoint_span.second > match_offset)) {
+ value = value->next;
+ }
+ return Iterator(match_offset, value);
+ }
+
+ const std::vector<Derivation> derivations() const {
+ return root_derivations_;
+ }
+
+ private:
+ static constexpr int kChartHashTableBitmask = NumBuckets - 1;
+ std::array<ParseTree*, NumBuckets> chart_;
+ std::vector<Derivation> root_derivations_;
+};
+
+template <int NumBuckets>
+bool Chart<NumBuckets>::HasMatch(const Nonterm nonterm,
+ const CodepointSpan& span) const {
+ // Lookup by end.
+ for (Chart<NumBuckets>::Iterator it = MatchesEndingAt(span.second);
+ !it.Done(); it.Next()) {
+ if (it.Item()->lhs == nonterm &&
+ it.Item()->codepoint_span.first == span.first) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_CHART_H_
diff --git a/native/utils/grammar/parsing/chart_test.cc b/native/utils/grammar/parsing/chart_test.cc
new file mode 100644
index 0000000..e4ec72f
--- /dev/null
+++ b/native/utils/grammar/parsing/chart_test.cc
@@ -0,0 +1,106 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/parsing/chart.h"
+
+#include "annotator/types.h"
+#include "utils/base/arena.h"
+#include "utils/grammar/parsing/parse-tree.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+using ::testing::IsEmpty;
+
+class ChartTest : public testing::Test {
+ protected:
+ explicit ChartTest() : arena_(/*block_size=*/16 << 10) {}
+ UnsafeArena arena_;
+};
+
+TEST_F(ChartTest, IsEmptyByDefault) {
+ Chart<> chart;
+
+ EXPECT_THAT(chart.derivations(), IsEmpty());
+ EXPECT_TRUE(chart.MatchesEndingAt(0).Done());
+}
+
+TEST_F(ChartTest, IteratesThroughCell) {
+ Chart<> chart;
+ ParseTree* m0 = arena_.AllocAndInit<ParseTree>(/*lhs=*/0, CodepointSpan{0, 1},
+ /*match_offset=*/0,
+ ParseTree::Type::kDefault);
+ chart.Add(m0);
+ ParseTree* m1 = arena_.AllocAndInit<ParseTree>(/*lhs=*/1, CodepointSpan{0, 2},
+ /*match_offset=*/0,
+ ParseTree::Type::kDefault);
+ chart.Add(m1);
+ ParseTree* m2 = arena_.AllocAndInit<ParseTree>(/*lhs=*/2, CodepointSpan{0, 2},
+ /*match_offset=*/0,
+ ParseTree::Type::kDefault);
+ chart.Add(m2);
+
+ // Position 0 should be empty.
+ EXPECT_TRUE(chart.MatchesEndingAt(0).Done());
+
+ // Position 1 should contain m0.
+ {
+ Chart<>::Iterator it = chart.MatchesEndingAt(1);
+ ASSERT_FALSE(it.Done());
+ EXPECT_EQ(it.Item(), m0);
+ it.Next();
+ EXPECT_TRUE(it.Done());
+ }
+
+ // Position 2 should contain m1 and m2.
+ {
+ Chart<>::Iterator it = chart.MatchesEndingAt(2);
+ ASSERT_FALSE(it.Done());
+ EXPECT_EQ(it.Item(), m2);
+ it.Next();
+ ASSERT_FALSE(it.Done());
+ EXPECT_EQ(it.Item(), m1);
+ it.Next();
+ EXPECT_TRUE(it.Done());
+ }
+}
+
+TEST_F(ChartTest, ChecksExistingMatches) {
+ Chart<> chart;
+ ParseTree* m0 = arena_.AllocAndInit<ParseTree>(/*lhs=*/0, CodepointSpan{0, 1},
+ /*match_offset=*/0,
+ ParseTree::Type::kDefault);
+ chart.Add(m0);
+ ParseTree* m1 = arena_.AllocAndInit<ParseTree>(/*lhs=*/1, CodepointSpan{0, 2},
+ /*match_offset=*/0,
+ ParseTree::Type::kDefault);
+ chart.Add(m1);
+ ParseTree* m2 = arena_.AllocAndInit<ParseTree>(/*lhs=*/2, CodepointSpan{0, 2},
+ /*match_offset=*/0,
+ ParseTree::Type::kDefault);
+ chart.Add(m2);
+
+ EXPECT_TRUE(chart.HasMatch(0, CodepointSpan{0, 1}));
+ EXPECT_FALSE(chart.HasMatch(0, CodepointSpan{0, 2}));
+ EXPECT_TRUE(chart.HasMatch(1, CodepointSpan{0, 2}));
+ EXPECT_TRUE(chart.HasMatch(2, CodepointSpan{0, 2}));
+ EXPECT_FALSE(chart.HasMatch(0, CodepointSpan{0, 2}));
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/parsing/derivation.cc b/native/utils/grammar/parsing/derivation.cc
new file mode 100644
index 0000000..6618654
--- /dev/null
+++ b/native/utils/grammar/parsing/derivation.cc
@@ -0,0 +1,101 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/parsing/derivation.h"
+
+#include <algorithm>
+
+namespace libtextclassifier3::grammar {
+
+bool Derivation::IsValid() const {
+ bool result = true;
+ Traverse(parse_tree, [&result](const ParseTree* node) {
+ if (node->type != ParseTree::Type::kAssertion) {
+      // Only validate if all checks so far passed.
+ return result;
+ }
+ // Positive assertions are by definition fulfilled,
+ // fail if the assertion is negative.
+ if (static_cast<const AssertionNode*>(node)->negative) {
+ result = false;
+ }
+ return result;
+ });
+ return result;
+}
+
+std::vector<Derivation> DeduplicateDerivations(
+ const std::vector<Derivation>& derivations) {
+ std::vector<Derivation> sorted_candidates = derivations;
+ std::stable_sort(sorted_candidates.begin(), sorted_candidates.end(),
+ [](const Derivation& a, const Derivation& b) {
+ // Sort by id.
+ if (a.rule_id != b.rule_id) {
+ return a.rule_id < b.rule_id;
+ }
+
+ // Sort by increasing start.
+ if (a.parse_tree->codepoint_span.first !=
+ b.parse_tree->codepoint_span.first) {
+ return a.parse_tree->codepoint_span.first <
+ b.parse_tree->codepoint_span.first;
+ }
+
+ // Sort by decreasing end.
+ return a.parse_tree->codepoint_span.second >
+ b.parse_tree->codepoint_span.second;
+ });
+
+ // Deduplicate by overlap.
+ std::vector<Derivation> result;
+ for (int i = 0; i < sorted_candidates.size(); i++) {
+ const Derivation& candidate = sorted_candidates[i];
+ bool eliminated = false;
+
+ // Due to the sorting above, the candidate can only be completely
+ // intersected by a match before it in the sorted order.
+ for (int j = i - 1; j >= 0; j--) {
+ if (sorted_candidates[j].rule_id != candidate.rule_id) {
+ break;
+ }
+ if (sorted_candidates[j].parse_tree->codepoint_span.first <=
+ candidate.parse_tree->codepoint_span.first &&
+ sorted_candidates[j].parse_tree->codepoint_span.second >=
+ candidate.parse_tree->codepoint_span.second) {
+ eliminated = true;
+ break;
+ }
+ }
+ if (!eliminated) {
+ result.push_back(candidate);
+ }
+ }
+ return result;
+}
+
+std::vector<Derivation> ValidDeduplicatedDerivations(
+ const std::vector<Derivation>& derivations) {
+ std::vector<Derivation> result;
+ for (const Derivation& derivation : DeduplicateDerivations(derivations)) {
+ // Check that asserts are fulfilled.
+ if (derivation.IsValid()) {
+ result.push_back(derivation);
+ }
+ }
+ return result;
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/parsing/derivation.h b/native/utils/grammar/parsing/derivation.h
new file mode 100644
index 0000000..70e169d
--- /dev/null
+++ b/native/utils/grammar/parsing/derivation.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_DERIVATION_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_DERIVATION_H_
+
+#include <vector>
+
+#include "utils/grammar/parsing/parse-tree.h"
+
+namespace libtextclassifier3::grammar {
+
+// A parse tree for a root rule.
+struct Derivation {
+ const ParseTree* parse_tree;
+ int64 rule_id;
+
+ // Checks that all assertions are fulfilled.
+ bool IsValid() const;
+};
+
+// Deduplicates rule derivations by containing overlap.
+// The grammar system can output multiple candidates for optional parts.
+// For example if a rule has an optional suffix, we
+// will get two rule derivations when the suffix is present: one with and one
+// without the suffix. We therefore deduplicate by containing overlap, viz. from
+// two candidates we keep the longer one if it completely contains the shorter.
+std::vector<Derivation> DeduplicateDerivations(
+ const std::vector<Derivation>& derivations);
+
+// Deduplicates and validates rule derivations.
+std::vector<Derivation> ValidDeduplicatedDerivations(
+ const std::vector<Derivation>& derivations);
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_DERIVATION_H_
diff --git a/native/utils/grammar/parsing/lexer.cc b/native/utils/grammar/parsing/lexer.cc
new file mode 100644
index 0000000..79e92e1
--- /dev/null
+++ b/native/utils/grammar/parsing/lexer.cc
@@ -0,0 +1,66 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/parsing/lexer.h"
+
+namespace libtextclassifier3::grammar {
+
+Symbol::Type Lexer::GetSymbolType(const UnicodeText::const_iterator& it) const {
+ if (unilib_.IsPunctuation(*it)) {
+ return Symbol::Type::TYPE_PUNCTUATION;
+ } else if (unilib_.IsDigit(*it)) {
+ return Symbol::Type::TYPE_DIGITS;
+ } else {
+ return Symbol::Type::TYPE_TERM;
+ }
+}
+
+void Lexer::AppendTokenSymbols(const StringPiece value, int match_offset,
+ const CodepointSpan codepoint_span,
+ std::vector<Symbol>* symbols) const {
+ // Possibly split token.
+ UnicodeText token_unicode = UTF8ToUnicodeText(value.data(), value.size(),
+ /*do_copy=*/false);
+ int next_match_offset = match_offset;
+ auto token_end = token_unicode.end();
+ auto it = token_unicode.begin();
+ Symbol::Type type = GetSymbolType(it);
+ CodepointIndex sub_token_start = codepoint_span.first;
+ while (it != token_end) {
+ auto next = std::next(it);
+ int num_codepoints = 1;
+ Symbol::Type next_type;
+ while (next != token_end) {
+ next_type = GetSymbolType(next);
+ if (type == Symbol::Type::TYPE_PUNCTUATION || next_type != type) {
+ break;
+ }
+ ++next;
+ ++num_codepoints;
+ }
+ symbols->emplace_back(
+ type, CodepointSpan{sub_token_start, sub_token_start + num_codepoints},
+ /*match_offset=*/next_match_offset,
+ /*lexeme=*/
+ StringPiece(it.utf8_data(), next.utf8_data() - it.utf8_data()));
+ next_match_offset = sub_token_start + num_codepoints;
+ it = next;
+ type = next_type;
+ sub_token_start = next_match_offset;
+ }
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/parsing/lexer.h b/native/utils/grammar/parsing/lexer.h
new file mode 100644
index 0000000..f902fbd
--- /dev/null
+++ b/native/utils/grammar/parsing/lexer.h
@@ -0,0 +1,120 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// A lexer that (possibly) splits and classifies tokens.
+//
+// Any whitespace gets absorbed into the token that follows them in the text.
+// For example, if the text contains:
+//
+// ...hello there world...
+// | | |
+// offset=16 39 52
+//
+// then the output will be:
+//
+// "hello" [?, 16)
+// "there" [16, 44) <-- note "16" NOT "39"
+// "world" [44, ?) <-- note "44" NOT "52"
+//
+// This makes it appear to the Matcher as if the tokens are adjacent.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_LEXER_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_LEXER_H_
+
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/grammar/parsing/parse-tree.h"
+#include "utils/grammar/types.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3::grammar {
+
+// A lexical symbol with an identified meaning that represents raw tokens,
+// token categories or predefined text matches.
+// It is the unit fed to the grammar matcher.
+struct Symbol {
+ // The type of the lexical symbol.
+ enum class Type {
+ // A raw token.
+ TYPE_TERM,
+
+ // A symbol representing a string of digits.
+ TYPE_DIGITS,
+
+ // Punctuation characters.
+ TYPE_PUNCTUATION,
+
+ // A predefined parse tree.
+ TYPE_PARSE_TREE
+ };
+
+ explicit Symbol() = default;
+
+ // Constructs a symbol of a given type with an anchor in the text.
+ Symbol(const Type type, const CodepointSpan codepoint_span,
+ const int match_offset, StringPiece lexeme)
+ : type(type),
+ codepoint_span(codepoint_span),
+ match_offset(match_offset),
+ lexeme(lexeme) {}
+
+ // Constructs a symbol from a pre-defined parse tree.
+ explicit Symbol(ParseTree* parse_tree)
+ : type(Type::TYPE_PARSE_TREE),
+ codepoint_span(parse_tree->codepoint_span),
+ match_offset(parse_tree->match_offset),
+ parse_tree(parse_tree) {}
+
+ // The type of the symbol.
+ Type type;
+
+ // The span in the text as codepoint offsets.
+ CodepointSpan codepoint_span;
+
+ // The match start offset (including preceding whitespace) as codepoint
+ // offset.
+ int match_offset;
+
+ // The symbol text value.
+ StringPiece lexeme;
+
+ // The predefined parse tree.
+ ParseTree* parse_tree;
+};
+
+class Lexer {
+ public:
+  // Does not take ownership of `unilib`; the instance must outlive this
+  // lexer (only a reference is stored).
+  explicit Lexer(const UniLib* unilib) : unilib_(*unilib) {}
+
+  // Processes a single token.
+  // Splits a token into classified symbols.
+  void AppendTokenSymbols(const StringPiece value, int match_offset,
+                          const CodepointSpan codepoint_span,
+                          std::vector<Symbol>* symbols) const;
+
+ private:
+  // Gets the type of a character.
+  Symbol::Type GetSymbolType(const UnicodeText::const_iterator& it) const;
+
+  const UniLib& unilib_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_LEXER_H_
diff --git a/native/utils/grammar/parsing/lexer_test.cc b/native/utils/grammar/parsing/lexer_test.cc
new file mode 100644
index 0000000..dad3b8e
--- /dev/null
+++ b/native/utils/grammar/parsing/lexer_test.cc
@@ -0,0 +1,170 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Unit tests for the lexer.
+
+#include "utils/grammar/parsing/lexer.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/jvm-test-utils.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3::grammar {
+
+// Human-readable form of a Symbol for gUnit failure messages.
+std::ostream& operator<<(std::ostream& os, const Symbol& symbol) {
+  os << "Symbol(type=" << static_cast<int>(symbol.type);
+  os << ", span=(" << symbol.codepoint_span.first << ", "
+     << symbol.codepoint_span.second << ")";
+  os << ", lexeme=" << symbol.lexeme.ToString() << ")";
+  return os;
+}
+
+namespace {
+
+using ::testing::DescribeMatcher;
+using ::testing::ElementsAre;
+using ::testing::ExplainMatchResult;
+
+// Superclass of all tests here.
+class LexerTest : public testing::Test {
+ protected:
+ explicit LexerTest()
+ : unilib_(libtextclassifier3::CreateUniLibForTesting()),
+ tokenizer_(TokenizationType_ICU, unilib_.get(),
+ /*codepoint_ranges=*/{},
+ /*internal_tokenizer_codepoint_ranges=*/{},
+ /*split_on_script_change=*/false,
+ /*icu_preserve_whitespace_tokens=*/false),
+ lexer_(unilib_.get()) {}
+
+ std::vector<Symbol> SymbolsForTokens(const std::vector<Token>& tokens) const {
+ std::vector<Symbol> symbols;
+ for (const Token& token : tokens) {
+ lexer_.AppendTokenSymbols(token.value, token.start,
+ CodepointSpan{token.start, token.end},
+ &symbols);
+ }
+ return symbols;
+ }
+
+ std::unique_ptr<UniLib> unilib_;
+ Tokenizer tokenizer_;
+ Lexer lexer_;
+};
+
+// Matches a Symbol by type, codepoint span and lexeme text.
+// Note: the symbol's `match_offset` is not part of the comparison.
+MATCHER_P4(IsSymbol, type, begin, end, terminal,
+           "is symbol with type that " +
+               DescribeMatcher<Symbol::Type>(type, negation) + ", begin that " +
+               DescribeMatcher<int>(begin, negation) + ", end that " +
+               DescribeMatcher<int>(end, negation) + ", value that " +
+               DescribeMatcher<std::string>(terminal, negation)) {
+  return ExplainMatchResult(type, arg.type, result_listener) &&
+         ExplainMatchResult(CodepointSpan(begin, end), arg.codepoint_span,
+                            result_listener) &&
+         ExplainMatchResult(terminal, arg.lexeme.ToString(), result_listener);
+}
+
+// Plain alphabetic tokens map 1:1 to TYPE_TERM symbols; the spans reflect
+// the tokens' positions in the input (e.g. "is" starts at codepoint 5).
+TEST_F(LexerTest, HandlesSimpleWords) {
+  std::vector<Token> tokens = tokenizer_.Tokenize("This is a word");
+  EXPECT_THAT(SymbolsForTokens(tokens),
+              ElementsAre(IsSymbol(Symbol::Type::TYPE_TERM, 0, 4, "This"),
+                          IsSymbol(Symbol::Type::TYPE_TERM, 5, 7, "is"),
+                          IsSymbol(Symbol::Type::TYPE_TERM, 8, 9, "a"),
+                          IsSymbol(Symbol::Type::TYPE_TERM, 10, 14, "word")));
+}
+
+// Digit runs and letter runs inside one token are split into separate
+// TYPE_DIGITS / TYPE_TERM symbols ("1234This" -> "1234" + "This").
+TEST_F(LexerTest, SplitsConcatedLettersAndDigit) {
+  std::vector<Token> tokens = tokenizer_.Tokenize("1234This a4321cde");
+  EXPECT_THAT(SymbolsForTokens(tokens),
+              ElementsAre(IsSymbol(Symbol::Type::TYPE_DIGITS, 0, 4, "1234"),
+                          IsSymbol(Symbol::Type::TYPE_TERM, 4, 8, "This"),
+                          IsSymbol(Symbol::Type::TYPE_TERM, 9, 10, "a"),
+                          IsSymbol(Symbol::Type::TYPE_DIGITS, 10, 14, "4321"),
+                          IsSymbol(Symbol::Type::TYPE_TERM, 14, 17, "cde")));
+}
+
+// Each punctuation codepoint becomes its own single-codepoint symbol, so a
+// date like "10/18/2014" yields alternating DIGITS and PUNCTUATION symbols.
+TEST_F(LexerTest, SplitsPunctuation) {
+  std::vector<Token> tokens = tokenizer_.Tokenize("10/18/2014");
+  EXPECT_THAT(SymbolsForTokens(tokens),
+              ElementsAre(IsSymbol(Symbol::Type::TYPE_DIGITS, 0, 2, "10"),
+                          IsSymbol(Symbol::Type::TYPE_PUNCTUATION, 2, 3, "/"),
+                          IsSymbol(Symbol::Type::TYPE_DIGITS, 3, 5, "18"),
+                          IsSymbol(Symbol::Type::TYPE_PUNCTUATION, 5, 6, "/"),
+                          IsSymbol(Symbol::Type::TYPE_DIGITS, 6, 10, "2014")));
+}
+
+// Non-ASCII punctuation (fullwidth colon, em dash, fullwidth parenthesis) is
+// split like ASCII punctuation; spans are codepoint offsets, not bytes.
+TEST_F(LexerTest, SplitsUTF8Punctuation) {
+  std::vector<Token> tokens = tokenizer_.Tokenize("电话:0871—6857(曹");
+  EXPECT_THAT(
+      SymbolsForTokens(tokens),
+      ElementsAre(IsSymbol(Symbol::Type::TYPE_TERM, 0, 2, "电话"),
+                  IsSymbol(Symbol::Type::TYPE_PUNCTUATION, 2, 3, ":"),
+                  IsSymbol(Symbol::Type::TYPE_DIGITS, 3, 7, "0871"),
+                  IsSymbol(Symbol::Type::TYPE_PUNCTUATION, 7, 8, "—"),
+                  IsSymbol(Symbol::Type::TYPE_DIGITS, 8, 12, "6857"),
+                  IsSymbol(Symbol::Type::TYPE_PUNCTUATION, 12, 13, "("),
+                  IsSymbol(Symbol::Type::TYPE_TERM, 13, 14, "曹")));
+}
+
+// Same input as SplitsUTF8Punctuation but with a space before the colon:
+// all spans after the space shift by one codepoint.
+TEST_F(LexerTest, HandlesMixedPunctuation) {
+  std::vector<Token> tokens = tokenizer_.Tokenize("电话 :0871—6857(曹");
+  EXPECT_THAT(
+      SymbolsForTokens(tokens),
+      ElementsAre(IsSymbol(Symbol::Type::TYPE_TERM, 0, 2, "电话"),
+                  IsSymbol(Symbol::Type::TYPE_PUNCTUATION, 3, 4, ":"),
+                  IsSymbol(Symbol::Type::TYPE_DIGITS, 4, 8, "0871"),
+                  IsSymbol(Symbol::Type::TYPE_PUNCTUATION, 8, 9, "—"),
+                  IsSymbol(Symbol::Type::TYPE_DIGITS, 9, 13, "6857"),
+                  IsSymbol(Symbol::Type::TYPE_PUNCTUATION, 13, 14, "("),
+                  IsSymbol(Symbol::Type::TYPE_TERM, 14, 15, "曹")));
+}
+
+// Mixed words/digits/punctuation across whitespace and newlines;
+// "\xE2\x80\x94" is the UTF-8 encoding of the em dash "—".
+TEST_F(LexerTest, HandlesTokensWithDigits) {
+  std::vector<Token> tokens =
+      tokenizer_.Tokenize("The.qUIck\n brown2345fox88 \xE2\x80\x94 the");
+  EXPECT_THAT(SymbolsForTokens(tokens),
+              ElementsAre(IsSymbol(Symbol::Type::TYPE_TERM, 0, 3, "The"),
+                          IsSymbol(Symbol::Type::TYPE_PUNCTUATION, 3, 4, "."),
+                          IsSymbol(Symbol::Type::TYPE_TERM, 4, 9, "qUIck"),
+                          IsSymbol(Symbol::Type::TYPE_TERM, 11, 16, "brown"),
+                          IsSymbol(Symbol::Type::TYPE_DIGITS, 16, 20, "2345"),
+                          IsSymbol(Symbol::Type::TYPE_TERM, 20, 23, "fox"),
+                          IsSymbol(Symbol::Type::TYPE_DIGITS, 23, 25, "88"),
+                          IsSymbol(Symbol::Type::TYPE_PUNCTUATION, 26, 27, "—"),
+                          IsSymbol(Symbol::Type::TYPE_TERM, 28, 31, "the")));
+}
+
+// '+' is evidently not punctuation for the UniLib used here: the tokenizer
+// splits at it and each '+' surfaces as its own TYPE_TERM symbol
+// (TODO confirm against UniLib::IsPunctuation for U+002B).
+TEST_F(LexerTest, SplitsPlusSigns) {
+  std::vector<Token> tokens = tokenizer_.Tokenize("The+2345++the +");
+  EXPECT_THAT(SymbolsForTokens(tokens),
+              ElementsAre(IsSymbol(Symbol::Type::TYPE_TERM, 0, 3, "The"),
+                          IsSymbol(Symbol::Type::TYPE_TERM, 3, 4, "+"),
+                          IsSymbol(Symbol::Type::TYPE_DIGITS, 4, 8, "2345"),
+                          IsSymbol(Symbol::Type::TYPE_TERM, 8, 9, "+"),
+                          IsSymbol(Symbol::Type::TYPE_TERM, 9, 10, "+"),
+                          IsSymbol(Symbol::Type::TYPE_TERM, 10, 13, "the"),
+                          IsSymbol(Symbol::Type::TYPE_TERM, 14, 15, "+")));
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/matcher.cc b/native/utils/grammar/parsing/matcher.cc
similarity index 67%
rename from native/utils/grammar/matcher.cc
rename to native/utils/grammar/parsing/matcher.cc
index a8ebba5..fa0ea0a 100644
--- a/native/utils/grammar/matcher.cc
+++ b/native/utils/grammar/parsing/matcher.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "utils/grammar/matcher.h"
+#include "utils/grammar/parsing/matcher.h"
#include <iostream>
#include <limits>
@@ -58,10 +58,13 @@
// Queue next character.
if (buffer_pos >= buffer_size) {
buffer_pos = 0;
- // Lower-case the next character.
+
+ // Lower-case the next character. The character and its lower-cased
+ // counterpart may be represented with a different number of bytes in
+ // utf8.
buffer_size =
ValidRuneToChar(unilib.ToLower(ValidCharToRune(data)), buffer);
- data += buffer_size;
+ data += GetNumBytesForUTF8Char(data);
}
TC3_DCHECK_LT(buffer_pos, buffer_size);
return buffer[buffer_pos++];
@@ -130,7 +133,7 @@
}
++match_length;
- // By the loop variant and due to the fact that the strings are sorted,
+ // By the loop invariant and due to the fact that the strings are sorted,
// a matching string will be at `left` now.
if (!input_iterator.HasNext()) {
const int string_offset = LittleEndian::ToHost32(offsets[left]);
@@ -217,7 +220,7 @@
}
inline void GetLhs(const RulesSet* rules_set, const int lhs_entry,
- Nonterm* nonterminal, CallbackId* callback, uint64* param,
+ Nonterm* nonterminal, CallbackId* callback, int64* param,
int8* max_whitespace_gap) {
if (lhs_entry > 0) {
// Direct encoding of the nonterminal.
@@ -236,27 +239,18 @@
} // namespace
-void Matcher::Reset() {
- state_ = STATE_DEFAULT;
- arena_.Reset();
- pending_items_ = nullptr;
- pending_exclusion_items_ = nullptr;
- std::fill(chart_.begin(), chart_.end(), nullptr);
- last_end_ = std::numeric_limits<int>().lowest();
-}
-
void Matcher::Finish() {
// Check any pending items.
ProcessPendingExclusionMatches();
}
-void Matcher::QueueForProcessing(Match* item) {
+void Matcher::QueueForProcessing(ParseTree* item) {
// Push element to the front.
item->next = pending_items_;
pending_items_ = item;
}
-void Matcher::QueueForPostCheck(ExclusionMatch* item) {
+void Matcher::QueueForPostCheck(ExclusionNode* item) {
// Push element to the front.
item->next = pending_exclusion_items_;
pending_exclusion_items_ = item;
@@ -282,11 +276,11 @@
ExecuteLhsSet(
codepoint_span, match_offset,
/*whitespace_gap=*/(codepoint_span.first - match_offset),
- [terminal](Match* match) {
- match->terminal = terminal.data();
- match->rhs2 = nullptr;
+ [terminal](ParseTree* parse_tree) {
+ parse_tree->terminal = terminal.data();
+ parse_tree->rhs2 = nullptr;
},
- lhs_set, delegate_);
+ lhs_set);
}
// Try case-insensitive matches.
@@ -298,42 +292,41 @@
ExecuteLhsSet(
codepoint_span, match_offset,
/*whitespace_gap=*/(codepoint_span.first - match_offset),
- [terminal](Match* match) {
- match->terminal = terminal.data();
- match->rhs2 = nullptr;
+ [terminal](ParseTree* parse_tree) {
+ parse_tree->terminal = terminal.data();
+ parse_tree->rhs2 = nullptr;
},
- lhs_set, delegate_);
+ lhs_set);
}
}
ProcessPendingSet();
}
-void Matcher::AddMatch(Match* match) {
- TC3_CHECK_GE(match->codepoint_span.second, last_end_);
+void Matcher::AddParseTree(ParseTree* parse_tree) {
+ TC3_CHECK_GE(parse_tree->codepoint_span.second, last_end_);
// Finish any pending post-checks.
- if (match->codepoint_span.second > last_end_) {
+ if (parse_tree->codepoint_span.second > last_end_) {
ProcessPendingExclusionMatches();
}
- last_end_ = match->codepoint_span.second;
- QueueForProcessing(match);
+ last_end_ = parse_tree->codepoint_span.second;
+ QueueForProcessing(parse_tree);
ProcessPendingSet();
}
-void Matcher::ExecuteLhsSet(const CodepointSpan codepoint_span,
- const int match_offset_bytes,
- const int whitespace_gap,
- const std::function<void(Match*)>& initializer,
- const RulesSet_::LhsSet* lhs_set,
- CallbackDelegate* delegate) {
+void Matcher::ExecuteLhsSet(
+ const CodepointSpan codepoint_span, const int match_offset_bytes,
+ const int whitespace_gap,
+ const std::function<void(ParseTree*)>& initializer_fn,
+ const RulesSet_::LhsSet* lhs_set) {
TC3_CHECK(lhs_set);
- Match* match = nullptr;
+ ParseTree* parse_tree = nullptr;
Nonterm prev_lhs = kUnassignedNonterm;
for (const int32 lhs_entry : *lhs_set->lhs()) {
Nonterm lhs;
CallbackId callback_id;
- uint64 callback_param;
+ int64 callback_param;
int8 max_whitespace_gap;
GetLhs(rules_, lhs_entry, &lhs, &callback_id, &callback_param,
&max_whitespace_gap);
@@ -343,91 +336,70 @@
continue;
}
- // Handle default callbacks.
+ // Handle callbacks.
switch (static_cast<DefaultCallback>(callback_id)) {
- case DefaultCallback::kSetType: {
- Match* typed_match = AllocateAndInitMatch<Match>(lhs, codepoint_span,
- match_offset_bytes);
- initializer(typed_match);
- typed_match->type = callback_param;
- QueueForProcessing(typed_match);
- continue;
- }
case DefaultCallback::kAssertion: {
- AssertionMatch* assertion_match = AllocateAndInitMatch<AssertionMatch>(
- lhs, codepoint_span, match_offset_bytes);
- initializer(assertion_match);
- assertion_match->type = Match::kAssertionMatch;
- assertion_match->negative = (callback_param != 0);
- QueueForProcessing(assertion_match);
+ AssertionNode* assertion_node = arena_->AllocAndInit<AssertionNode>(
+ lhs, codepoint_span, match_offset_bytes,
+ /*negative=*/(callback_param != 0));
+ initializer_fn(assertion_node);
+ QueueForProcessing(assertion_node);
continue;
}
case DefaultCallback::kMapping: {
- MappingMatch* mapping_match = AllocateAndInitMatch<MappingMatch>(
- lhs, codepoint_span, match_offset_bytes);
- initializer(mapping_match);
- mapping_match->type = Match::kMappingMatch;
- mapping_match->id = callback_param;
- QueueForProcessing(mapping_match);
+ MappingNode* mapping_node = arena_->AllocAndInit<MappingNode>(
+ lhs, codepoint_span, match_offset_bytes, /*id=*/callback_param);
+ initializer_fn(mapping_node);
+ QueueForProcessing(mapping_node);
continue;
}
case DefaultCallback::kExclusion: {
// We can only check the exclusion once all matches up to this position
// have been processed. Schedule and post check later.
- ExclusionMatch* exclusion_match = AllocateAndInitMatch<ExclusionMatch>(
- lhs, codepoint_span, match_offset_bytes);
- initializer(exclusion_match);
- exclusion_match->exclusion_nonterm = callback_param;
- QueueForPostCheck(exclusion_match);
+ ExclusionNode* exclusion_node = arena_->AllocAndInit<ExclusionNode>(
+ lhs, codepoint_span, match_offset_bytes,
+ /*exclusion_nonterm=*/callback_param);
+ initializer_fn(exclusion_node);
+ QueueForPostCheck(exclusion_node);
+ continue;
+ }
+ case DefaultCallback::kSemanticExpression: {
+ SemanticExpressionNode* expression_node =
+ arena_->AllocAndInit<SemanticExpressionNode>(
+ lhs, codepoint_span, match_offset_bytes,
+ /*expression=*/
+ rules_->semantic_expression()->Get(callback_param));
+ initializer_fn(expression_node);
+ QueueForProcessing(expression_node);
continue;
}
default:
break;
}
- if (callback_id != kNoCallback && rules_->callback() != nullptr) {
- const RulesSet_::CallbackEntry* callback_info =
- rules_->callback()->LookupByKey(callback_id);
- if (callback_info && callback_info->value().is_filter()) {
- // Filter callback.
- Match candidate;
- candidate.Init(lhs, codepoint_span, match_offset_bytes);
- initializer(&candidate);
- delegate->MatchFound(&candidate, callback_id, callback_param, this);
- continue;
- }
- }
-
if (prev_lhs != lhs) {
prev_lhs = lhs;
- match =
- AllocateAndInitMatch<Match>(lhs, codepoint_span, match_offset_bytes);
- initializer(match);
- QueueForProcessing(match);
+ parse_tree = arena_->AllocAndInit<ParseTree>(
+ lhs, codepoint_span, match_offset_bytes, ParseTree::Type::kDefault);
+ initializer_fn(parse_tree);
+ QueueForProcessing(parse_tree);
}
- if (callback_id != kNoCallback) {
- // This is an output callback.
- delegate->MatchFound(match, callback_id, callback_param, this);
+ if (static_cast<DefaultCallback>(callback_id) ==
+ DefaultCallback::kRootRule) {
+ chart_.AddDerivation(Derivation{parse_tree, /*rule_id=*/callback_param});
}
}
}
void Matcher::ProcessPendingSet() {
- // Avoid recursion caused by:
- // ProcessPendingSet --> callback --> AddMatch --> ProcessPendingSet --> ...
- if (state_ == STATE_PROCESSING) {
- return;
- }
- state_ = STATE_PROCESSING;
while (pending_items_) {
// Process.
- Match* item = pending_items_;
+ ParseTree* item = pending_items_;
pending_items_ = pending_items_->next;
// Add it to the chart.
- item->next = chart_[item->codepoint_span.second & kChartHashTableBitmask];
- chart_[item->codepoint_span.second & kChartHashTableBitmask] = item;
+ chart_.Add(item);
// Check unary rules that trigger.
for (const RulesSet_::Rules* shard : rules_shards_) {
@@ -437,26 +409,19 @@
item->codepoint_span, item->match_offset,
/*whitespace_gap=*/
(item->codepoint_span.first - item->match_offset),
- [item](Match* match) {
- match->rhs1 = nullptr;
- match->rhs2 = item;
+ [item](ParseTree* parse_tree) {
+ parse_tree->rhs1 = nullptr;
+ parse_tree->rhs2 = item;
},
- lhs_set, delegate_);
+ lhs_set);
}
}
// Check binary rules that trigger.
// Lookup by begin.
- Match* prev = chart_[item->match_offset & kChartHashTableBitmask];
- // The chain of items is in decreasing `end` order.
- // Find the ones that have prev->end == item->begin.
- while (prev != nullptr &&
- (prev->codepoint_span.second > item->match_offset)) {
- prev = prev->next;
- }
- for (;
- prev != nullptr && (prev->codepoint_span.second == item->match_offset);
- prev = prev->next) {
+ for (Chart<>::Iterator it = chart_.MatchesEndingAt(item->match_offset);
+ !it.Done(); it.Next()) {
+ const ParseTree* prev = it.Item();
for (const RulesSet_::Rules* shard : rules_shards_) {
if (const RulesSet_::LhsSet* lhs_set =
FindBinaryRulesMatches(rules_, shard, {prev->lhs, item->lhs})) {
@@ -468,45 +433,27 @@
(item->codepoint_span.first -
item->match_offset), // Whitespace gap is the gap
// between the two parts.
- [prev, item](Match* match) {
- match->rhs1 = prev;
- match->rhs2 = item;
+ [prev, item](ParseTree* parse_tree) {
+ parse_tree->rhs1 = prev;
+ parse_tree->rhs2 = item;
},
- lhs_set, delegate_);
+ lhs_set);
}
}
}
}
- state_ = STATE_DEFAULT;
}
void Matcher::ProcessPendingExclusionMatches() {
while (pending_exclusion_items_) {
- ExclusionMatch* item = pending_exclusion_items_;
- pending_exclusion_items_ = static_cast<ExclusionMatch*>(item->next);
+ ExclusionNode* item = pending_exclusion_items_;
+ pending_exclusion_items_ = static_cast<ExclusionNode*>(item->next);
// Check that the exclusion condition is fulfilled.
- if (!ContainsMatch(item->exclusion_nonterm, item->codepoint_span)) {
- AddMatch(item);
+ if (!chart_.HasMatch(item->exclusion_nonterm, item->codepoint_span)) {
+ AddParseTree(item);
}
}
}
-bool Matcher::ContainsMatch(const Nonterm nonterm,
- const CodepointSpan& span) const {
- // Lookup by end.
- Match* match = chart_[span.second & kChartHashTableBitmask];
- // The chain of items is in decreasing `end` order.
- while (match != nullptr && match->codepoint_span.second > span.second) {
- match = match->next;
- }
- while (match != nullptr && match->codepoint_span.second == span.second) {
- if (match->lhs == nonterm && match->codepoint_span.first == span.first) {
- return true;
- }
- match = match->next;
- }
- return false;
-}
-
} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/parsing/matcher.h b/native/utils/grammar/parsing/matcher.h
new file mode 100644
index 0000000..f12a6a5
--- /dev/null
+++ b/native/utils/grammar/parsing/matcher.h
@@ -0,0 +1,151 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// A token based context-free grammar matcher.
+//
+// A parser passes token to the matcher: literal terminal strings and token
+// types.
+// The parser passes each token along with the [begin, end) position range
+// in which it occurs. So for an input string "Groundhog February 2, 2007", the
+// parser would tell the matcher that:
+//
+// "Groundhog" occurs at [0, 9)
+// "February" occurs at [9, 18)
+// <digits> occurs at [18, 20)
+// "," occurs at [20, 21)
+// <digits> occurs at [21, 26)
+//
+// Multiple overlapping symbols can be passed.
+// The only constraint on symbol order is that they have to be passed in
+// left-to-right order, strictly speaking, their "end" positions must be
+// nondecreasing. This constraint allows a more efficient matching algorithm.
+// The "begin" positions can be in any order.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_MATCHER_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_MATCHER_H_
+
+#include <array>
+#include <functional>
+#include <limits>
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/base/arena.h"
+#include "utils/grammar/parsing/chart.h"
+#include "utils/grammar/parsing/derivation.h"
+#include "utils/grammar/parsing/parse-tree.h"
+#include "utils/grammar/rules_generated.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3::grammar {
+
+// Incrementally matches the grammar rules in `rules` against terminals and
+// parse trees fed in left-to-right order, recording results in a chart.
+class Matcher {
+ public:
+  explicit Matcher(const UniLib* unilib, const RulesSet* rules,
+                   const std::vector<const RulesSet_::Rules*> rules_shards,
+                   UnsafeArena* arena)
+      : unilib_(*unilib),
+        arena_(arena),
+        last_end_(std::numeric_limits<int>::lowest()),
+        rules_(rules),
+        rules_shards_(rules_shards),
+        pending_items_(nullptr),
+        pending_exclusion_items_(nullptr) {
+    TC3_CHECK_NE(rules, nullptr);
+  }
+
+  // Convenience constructor that activates all rule shards in `rules`.
+  explicit Matcher(const UniLib* unilib, const RulesSet* rules,
+                   UnsafeArena* arena)
+      : Matcher(unilib, rules, {}, arena) {
+    rules_shards_.reserve(rules->rules()->size());
+    rules_shards_.insert(rules_shards_.end(), rules->rules()->begin(),
+                         rules->rules()->end());
+  }
+
+  // Finish the matching.
+  void Finish();
+
+  // Tells the matcher that the given terminal was found occupying position
+  // range [begin, end) in the input.
+  // The matcher may invoke callback functions before returning, if this
+  // terminal triggers any new matches for rules in the grammar.
+  // Calls to AddTerminal() and AddParseTree() must be in left-to-right order,
+  // that is, the sequence of `end` values must be non-decreasing.
+  void AddTerminal(const CodepointSpan codepoint_span, const int match_offset,
+                   StringPiece terminal);
+  void AddTerminal(const CodepointIndex begin, const CodepointIndex end,
+                   StringPiece terminal) {
+    AddTerminal(CodepointSpan{begin, end}, begin, terminal);
+  }
+
+  // Adds predefined parse tree.
+  void AddParseTree(ParseTree* parse_tree);
+
+  // Read-only access to the chart holding all added matches. Returned by
+  // reference to avoid copying the chart's hashtable on every call.
+  const Chart<>& chart() const { return chart_; }
+
+ private:
+  // Process matches from lhs set.
+  void ExecuteLhsSet(const CodepointSpan codepoint_span, const int match_offset,
+                     const int whitespace_gap,
+                     const std::function<void(ParseTree*)>& initializer_fn,
+                     const RulesSet_::LhsSet* lhs_set);
+
+  // Queues a newly created match item.
+  void QueueForProcessing(ParseTree* item);
+
+  // Queues a match item for later post checking of the exclusion condition.
+  // For exclusions we need to check that the `item->excluded_nonterminal`
+  // doesn't match the same span. As we cannot know which matches have already
+  // been added, we queue the item for later post checking - once all matches
+  // up to `item->codepoint_span.second` have been added.
+  void QueueForPostCheck(ExclusionNode* item);
+
+  // Adds pending items to the chart, possibly generating new matches as a
+  // result.
+  void ProcessPendingSet();
+
+  // Checks all pending exclusion matches that their exclusion condition is
+  // fulfilled.
+  void ProcessPendingExclusionMatches();
+
+  // NOTE(review): holds a copy of `*unilib`; confirm UniLib is cheap to
+  // copy, otherwise prefer storing a const reference (as Lexer does).
+  UniLib unilib_;
+
+  // Memory arena for match allocation.
+  UnsafeArena* arena_;
+
+  // The end position of the most recent match or terminal, for sanity
+  // checking.
+  int last_end_;
+
+  // Rules.
+  const RulesSet* rules_;
+  // The active rule shards.
+  std::vector<const RulesSet_::Rules*> rules_shards_;
+
+  // The set of items pending to be added to the chart as a singly-linked list.
+  ParseTree* pending_items_;
+
+  // The set of items pending to be post-checked as a singly-linked list.
+  ExclusionNode* pending_exclusion_items_;
+
+  // The chart data structure: a hashtable containing all matches, indexed by
+  // their end positions.
+  Chart<> chart_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_MATCHER_H_
diff --git a/native/utils/grammar/parsing/matcher_test.cc b/native/utils/grammar/parsing/matcher_test.cc
new file mode 100644
index 0000000..8528009
--- /dev/null
+++ b/native/utils/grammar/parsing/matcher_test.cc
@@ -0,0 +1,428 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/parsing/matcher.h"
+
+#include <string>
+#include <vector>
+
+#include "utils/base/arena.h"
+#include "utils/grammar/rules_generated.h"
+#include "utils/grammar/types.h"
+#include "utils/grammar/utils/rules.h"
+#include "utils/strings/append.h"
+#include "utils/utf8/unilib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+using ::testing::DescribeMatcher;
+using ::testing::ElementsAre;
+using ::testing::ExplainMatchResult;
+using ::testing::IsEmpty;
+
+struct TestMatchResult {
+ CodepointSpan codepoint_span;
+ std::string terminal;
+ std::string nonterminal;
+ int rule_id;
+
+ friend std::ostream& operator<<(std::ostream& os,
+ const TestMatchResult& match) {
+ return os << "Result(rule_id=" << match.rule_id
+ << ", begin=" << match.codepoint_span.first
+ << ", end=" << match.codepoint_span.second
+ << ", terminal=" << match.terminal
+ << ", nonterminal=" << match.nonterminal << ")";
+ }
+};
+
+MATCHER_P3(IsTerminal, begin, end, terminal,
+ "is terminal with begin that " +
+ DescribeMatcher<int>(begin, negation) + ", end that " +
+ DescribeMatcher<int>(end, negation) + ", value that " +
+ DescribeMatcher<std::string>(terminal, negation)) {
+ return ExplainMatchResult(CodepointSpan(begin, end), arg.codepoint_span,
+ result_listener) &&
+ ExplainMatchResult(terminal, arg.terminal, result_listener);
+}
+
+MATCHER_P3(IsNonterminal, begin, end, name,
+ "is nonterminal with begin that " +
+ DescribeMatcher<int>(begin, negation) + ", end that " +
+ DescribeMatcher<int>(end, negation) + ", name that " +
+ DescribeMatcher<std::string>(name, negation)) {
+ return ExplainMatchResult(CodepointSpan(begin, end), arg.codepoint_span,
+ result_listener) &&
+ ExplainMatchResult(name, arg.nonterminal, result_listener);
+}
+
+MATCHER_P4(IsDerivation, begin, end, name, rule_id,
+ "is derivation of rule that " +
+ DescribeMatcher<int>(rule_id, negation) + ", begin that " +
+ DescribeMatcher<int>(begin, negation) + ", end that " +
+ DescribeMatcher<int>(end, negation) + ", name that " +
+ DescribeMatcher<std::string>(name, negation)) {
+ return ExplainMatchResult(IsNonterminal(begin, end, name), arg,
+ result_listener) &&
+ ExplainMatchResult(rule_id, arg.rule_id, result_listener);
+}
+
+// Superclass of all tests.
+class MatcherTest : public testing::Test {
+ protected:
+ MatcherTest()
+ : INIT_UNILIB_FOR_TESTING(unilib_), arena_(/*block_size=*/16 << 10) {}
+
+ std::string GetNonterminalName(
+ const RulesSet_::DebugInformation* debug_information,
+ const Nonterm nonterminal) const {
+ if (const RulesSet_::DebugInformation_::NonterminalNamesEntry* entry =
+ debug_information->nonterminal_names()->LookupByKey(nonterminal)) {
+ return entry->value()->str();
+ }
+ // Unnamed Nonterm.
+ return "()";
+ }
+
+ std::vector<TestMatchResult> GetMatchResults(
+ const Chart<>& chart,
+ const RulesSet_::DebugInformation* debug_information) {
+ std::vector<TestMatchResult> result;
+ for (const Derivation& derivation : chart.derivations()) {
+ result.emplace_back();
+ result.back().rule_id = derivation.rule_id;
+ result.back().codepoint_span = derivation.parse_tree->codepoint_span;
+ result.back().nonterminal =
+ GetNonterminalName(debug_information, derivation.parse_tree->lhs);
+ if (derivation.parse_tree->IsTerminalRule()) {
+ result.back().terminal = derivation.parse_tree->terminal;
+ }
+ }
+ return result;
+ }
+
+ UniLib unilib_;
+ UnsafeArena arena_;
+};
+
+TEST_F(MatcherTest, HandlesBasicOperations) {
+ // Create an example grammar.
+ Rules rules;
+ rules.Add("<test>", {"the", "quick", "brown", "fox"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule));
+ rules.Add("<action>", {"<test>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule));
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
+ /*include_debug_information=*/true);
+ const RulesSet* rules_set =
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
+ Matcher matcher(&unilib_, rules_set, &arena_);
+
+ matcher.AddTerminal(0, 1, "the");
+ matcher.AddTerminal(1, 2, "quick");
+ matcher.AddTerminal(2, 3, "brown");
+ matcher.AddTerminal(3, 4, "fox");
+
+ EXPECT_THAT(GetMatchResults(matcher.chart(), rules_set->debug_information()),
+ ElementsAre(IsNonterminal(0, 4, "<test>"),
+ IsNonterminal(0, 4, "<action>")));
+}
+
+std::string CreateTestGrammar() {
+ // Create an example grammar.
+ Rules rules;
+
+ // Callbacks on terminal rules.
+ rules.Add("<output_5>", {"quick"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), 6);
+ rules.Add("<output_0>", {"the"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), 1);
+
+ // Callbacks on non-terminal rules.
+ rules.Add("<output_1>", {"the", "quick", "brown", "fox"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), 2);
+ rules.Add("<output_2>", {"the", "quick"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), 3);
+ rules.Add("<output_3>", {"brown", "fox"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), 4);
+
+ // Now a complex thing: "the* brown fox".
+ rules.Add("<thestarbrownfox>", {"brown", "fox"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), 5);
+ rules.Add("<thestarbrownfox>", {"the", "<thestarbrownfox>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), 5);
+
+ return rules.Finalize().SerializeAsFlatbuffer(
+ /*include_debug_information=*/true);
+}
+
+Nonterm FindNontermForName(const RulesSet* rules,
+ const std::string& nonterminal_name) {
+ for (const RulesSet_::DebugInformation_::NonterminalNamesEntry* entry :
+ *rules->debug_information()->nonterminal_names()) {
+ if (entry->value()->str() == nonterminal_name) {
+ return entry->key();
+ }
+ }
+ return kUnassignedNonterm;
+}
+
+TEST_F(MatcherTest, HandlesDerivationsOfRules) {
+ const std::string rules_buffer = CreateTestGrammar();
+ const RulesSet* rules_set =
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
+ Matcher matcher(&unilib_, rules_set, &arena_);
+
+ matcher.AddTerminal(0, 1, "the");
+ matcher.AddTerminal(1, 2, "quick");
+ matcher.AddTerminal(2, 3, "brown");
+ matcher.AddTerminal(3, 4, "fox");
+ matcher.AddTerminal(3, 5, "fox");
+ matcher.AddTerminal(4, 6, "fox"); // Not adjacent to "brown".
+
+ EXPECT_THAT(GetMatchResults(matcher.chart(), rules_set->debug_information()),
+ ElementsAre(
+ // the
+ IsDerivation(0, 1, "<output_0>", 1),
+
+ // quick
+ IsDerivation(1, 2, "<output_5>", 6),
+ IsDerivation(0, 2, "<output_2>", 3),
+
+ // brown
+
+ // fox
+ IsDerivation(0, 4, "<output_1>", 2),
+ IsDerivation(2, 4, "<output_3>", 4),
+ IsDerivation(2, 4, "<thestarbrownfox>", 5),
+
+ // fox
+ IsDerivation(0, 5, "<output_1>", 2),
+ IsDerivation(2, 5, "<output_3>", 4),
+ IsDerivation(2, 5, "<thestarbrownfox>", 5)));
+}
+
+TEST_F(MatcherTest, HandlesRecursiveRules) {
+ const std::string rules_buffer = CreateTestGrammar();
+ const RulesSet* rules_set =
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
+ Matcher matcher(&unilib_, rules_set, &arena_);
+
+ matcher.AddTerminal(0, 1, "the");
+ matcher.AddTerminal(1, 2, "the");
+ matcher.AddTerminal(2, 4, "the");
+ matcher.AddTerminal(3, 4, "the");
+ matcher.AddTerminal(4, 5, "brown");
+ matcher.AddTerminal(5, 6, "fox"); // Generates 5 of <thestarbrownfox>
+
+ EXPECT_THAT(GetMatchResults(matcher.chart(), rules_set->debug_information()),
+ ElementsAre(IsTerminal(0, 1, "the"), IsTerminal(1, 2, "the"),
+ IsTerminal(2, 4, "the"), IsTerminal(3, 4, "the"),
+ IsNonterminal(4, 6, "<output_3>"),
+ IsNonterminal(4, 6, "<thestarbrownfox>"),
+ IsNonterminal(3, 6, "<thestarbrownfox>"),
+ IsNonterminal(2, 6, "<thestarbrownfox>"),
+ IsNonterminal(1, 6, "<thestarbrownfox>"),
+ IsNonterminal(0, 6, "<thestarbrownfox>")));
+}
+
+TEST_F(MatcherTest, HandlesManualAddParseTreeCalls) {
+ const std::string rules_buffer = CreateTestGrammar();
+ const RulesSet* rules_set =
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
+ Matcher matcher(&unilib_, rules_set, &arena_);
+
+ // Test having the lexer call AddParseTree() instead of AddTerminal()
+ matcher.AddTerminal(-4, 37, "the");
+ matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
+ FindNontermForName(rules_set, "<thestarbrownfox>"), CodepointSpan{37, 42},
+ /*match_offset=*/37, ParseTree::Type::kDefault));
+
+ EXPECT_THAT(GetMatchResults(matcher.chart(), rules_set->debug_information()),
+ ElementsAre(IsTerminal(-4, 37, "the"),
+ IsNonterminal(-4, 42, "<thestarbrownfox>")));
+}
+
+TEST_F(MatcherTest, HandlesOptionalRuleElements) {
+ Rules rules;
+ rules.Add("<output_0>", {"a?", "b?", "c?", "d?", "e"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule));
+ rules.Add("<output_1>", {"a", "b?", "c", "d?", "e"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule));
+ rules.Add("<output_2>", {"a", "b?", "c", "d", "e?"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule));
+
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
+ /*include_debug_information=*/true);
+
+ const RulesSet* rules_set =
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
+ Matcher matcher(&unilib_, rules_set, &arena_);
+
+ // Run the matcher on "a b c d e".
+ matcher.AddTerminal(0, 1, "a");
+ matcher.AddTerminal(1, 2, "b");
+ matcher.AddTerminal(2, 3, "c");
+ matcher.AddTerminal(3, 4, "d");
+ matcher.AddTerminal(4, 5, "e");
+
+ EXPECT_THAT(
+ GetMatchResults(matcher.chart(), rules_set->debug_information()),
+ ElementsAre(
+ IsNonterminal(0, 4, "<output_2>"), IsTerminal(4, 5, "e"),
+ IsNonterminal(0, 5, "<output_0>"), IsNonterminal(0, 5, "<output_1>"),
+ IsNonterminal(0, 5, "<output_2>"), IsNonterminal(1, 5, "<output_0>"),
+ IsNonterminal(2, 5, "<output_0>"),
+ IsNonterminal(3, 5, "<output_0>")));
+}
+
+TEST_F(MatcherTest, HandlesWhitespaceGapLimits) {
+ Rules rules;
+ rules.Add("<iata>", {"lx"});
+ rules.Add("<iata>", {"aa"});
+ // Require no whitespace between code and flight number.
+ rules.Add("<flight_number>", {"<iata>", "<4_digits>"},
+ /*callback=*/static_cast<CallbackId>(DefaultCallback::kRootRule), 0,
+ /*max_whitespace_gap=*/0);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
+ /*include_debug_information=*/true);
+ const RulesSet* rules_set =
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
+
+ // Check that the grammar triggers on LX1138.
+ {
+ Matcher matcher(&unilib_, rules_set, &arena_);
+ matcher.AddTerminal(0, 2, "LX");
+ matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
+ rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
+ CodepointSpan{2, 6}, /*match_offset=*/2, ParseTree::Type::kDefault));
+ EXPECT_THAT(
+ GetMatchResults(matcher.chart(), rules_set->debug_information()),
+ ElementsAre(IsNonterminal(0, 6, "<flight_number>")));
+ }
+
+ // Check that the grammar doesn't trigger on LX 1138.
+ {
+ Matcher matcher(&unilib_, rules_set, &arena_);
+ matcher.AddTerminal(6, 8, "LX");
+ matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
+ rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
+ CodepointSpan{9, 13}, /*match_offset=*/8, ParseTree::Type::kDefault));
+ EXPECT_THAT(
+ GetMatchResults(matcher.chart(), rules_set->debug_information()),
+ IsEmpty());
+ }
+}
+
+TEST_F(MatcherTest, HandlesCaseSensitiveTerminals) {
+ Rules rules;
+ rules.Add("<iata>", {"LX"}, /*callback=*/kNoCallback, 0,
+ /*max_whitespace_gap*/ -1, /*case_sensitive=*/true);
+ rules.Add("<iata>", {"AA"}, /*callback=*/kNoCallback, 0,
+ /*max_whitespace_gap*/ -1, /*case_sensitive=*/true);
+ rules.Add("<iata>", {"dl"}, /*callback=*/kNoCallback, 0,
+ /*max_whitespace_gap*/ -1, /*case_sensitive=*/false);
+ // Require no whitespace between code and flight number.
+ rules.Add("<flight_number>", {"<iata>", "<4_digits>"},
+ /*callback=*/static_cast<CallbackId>(DefaultCallback::kRootRule), 0,
+ /*max_whitespace_gap=*/0);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
+ /*include_debug_information=*/true);
+ const RulesSet* rules_set =
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
+
+ // Check that the grammar triggers on LX1138.
+ {
+ Matcher matcher(&unilib_, rules_set, &arena_);
+ matcher.AddTerminal(0, 2, "LX");
+ matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
+ rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
+ CodepointSpan{2, 6}, /*match_offset=*/2, ParseTree::Type::kDefault));
+ EXPECT_THAT(
+ GetMatchResults(matcher.chart(), rules_set->debug_information()),
+ ElementsAre(IsNonterminal(0, 6, "<flight_number>")));
+ }
+
+ // Check that the grammar doesn't trigger on lx1138.
+ {
+ Matcher matcher(&unilib_, rules_set, &arena_);
+ matcher.AddTerminal(6, 8, "lx");
+ matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
+ rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
+ CodepointSpan{8, 12}, /*match_offset=*/8, ParseTree::Type::kDefault));
+ EXPECT_THAT(matcher.chart().derivations(), IsEmpty());
+ }
+
+ // Check that the grammar does trigger on dl1138.
+ {
+ Matcher matcher(&unilib_, rules_set, &arena_);
+ matcher.AddTerminal(12, 14, "dl");
+ matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
+ rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
+ CodepointSpan{14, 18}, /*match_offset=*/14, ParseTree::Type::kDefault));
+ EXPECT_THAT(
+ GetMatchResults(matcher.chart(), rules_set->debug_information()),
+ ElementsAre(IsNonterminal(12, 18, "<flight_number>")));
+ }
+}
+
+TEST_F(MatcherTest, HandlesExclusions) {
+ Rules rules;
+ rules.Add("<all_zeros>", {"0000"});
+ rules.AddWithExclusion("<flight_code>", {"<4_digits>"},
+ /*excluded_nonterminal=*/"<all_zeros>");
+ rules.Add("<iata>", {"lx"});
+ rules.Add("<iata>", {"aa"});
+ rules.Add("<iata>", {"dl"});
+ // Require no whitespace between code and flight number.
+ rules.Add("<flight_number>", {"<iata>", "<flight_code>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule));
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
+ /*include_debug_information=*/true);
+ const RulesSet* rules_set =
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
+
+ // Check that the grammar triggers on LX1138.
+ {
+ Matcher matcher(&unilib_, rules_set, &arena_);
+ matcher.AddTerminal(0, 2, "LX");
+ matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
+ rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
+ CodepointSpan{2, 6}, /*match_offset=*/2, ParseTree::Type::kDefault));
+ matcher.Finish();
+ EXPECT_THAT(
+ GetMatchResults(matcher.chart(), rules_set->debug_information()),
+ ElementsAre(IsNonterminal(0, 6, "<flight_number>")));
+ }
+
+ // Check that the grammar doesn't trigger on LX0000.
+ {
+ Matcher matcher(&unilib_, rules_set, &arena_);
+ matcher.AddTerminal(6, 8, "LX");
+ matcher.AddTerminal(8, 12, "0000");
+ matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
+ rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
+ CodepointSpan{8, 12}, /*match_offset=*/8, ParseTree::Type::kDefault));
+ matcher.Finish();
+ EXPECT_THAT(matcher.chart().derivations(), IsEmpty());
+ }
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/parsing/parse-tree.cc b/native/utils/grammar/parsing/parse-tree.cc
new file mode 100644
index 0000000..8a53173
--- /dev/null
+++ b/native/utils/grammar/parsing/parse-tree.cc
@@ -0,0 +1,55 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/parsing/parse-tree.h"
+
+#include <algorithm>
+#include <stack>
+
+namespace libtextclassifier3::grammar {
+
+void Traverse(const ParseTree* root,
+ const std::function<bool(const ParseTree*)>& node_fn) {
+ std::stack<const ParseTree*> open;
+ open.push(root);
+
+ while (!open.empty()) {
+ const ParseTree* node = open.top();
+ open.pop();
+ if (!node_fn(node) || node->IsLeaf()) {
+ continue;
+ }
+ open.push(node->rhs2);
+ if (node->rhs1 != nullptr) {
+ open.push(node->rhs1);
+ }
+ }
+}
+
+std::vector<const ParseTree*> SelectAll(
+ const ParseTree* root,
+ const std::function<bool(const ParseTree*)>& pred_fn) {
+ std::vector<const ParseTree*> result;
+ Traverse(root, [&result, pred_fn](const ParseTree* node) {
+ if (pred_fn(node)) {
+ result.push_back(node);
+ }
+ return true;
+ });
+ return result;
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/parsing/parse-tree.h b/native/utils/grammar/parsing/parse-tree.h
new file mode 100644
index 0000000..d3075d8
--- /dev/null
+++ b/native/utils/grammar/parsing/parse-tree.h
@@ -0,0 +1,195 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_PARSE_TREE_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_PARSE_TREE_H_
+
+#include <functional>
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/types.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3::grammar {
+
+// Represents a parse tree for a match that was found for a nonterminal.
+struct ParseTree {
+ enum class Type : int8 {
+ // Default, untyped match.
+ kDefault = 0,
+
+ // An assertion match (see: AssertionNode).
+ kAssertion = 1,
+
+ // A value mapping match (see: MappingNode).
+ kMapping = 2,
+
+ // An exclusion match (see: ExclusionNode).
+ kExclusion = 3,
+
+ // A match for an annotation (see: AnnotationNode).
+ kAnnotation = 4,
+
+ // A match for a semantic annotation (see: SemanticExpressionNode).
+ kExpression = 5,
+ };
+
+ explicit ParseTree() = default;
+ explicit ParseTree(const Nonterm lhs, const CodepointSpan& codepoint_span,
+ const int match_offset, const Type type)
+ : lhs(lhs),
+ type(type),
+ codepoint_span(codepoint_span),
+ match_offset(match_offset) {}
+
+ // For binary rule matches: rhs1 != NULL and rhs2 != NULL
+ // unary rule matches: rhs1 == NULL and rhs2 != NULL
+ // terminal rule matches: rhs1 != NULL and rhs2 == NULL
+ // custom leaves: rhs1 == NULL and rhs2 == NULL
+ bool IsInteriorNode() const { return rhs2 != nullptr; }
+ bool IsLeaf() const { return !rhs2; }
+
+ bool IsBinaryRule() const { return rhs1 && rhs2; }
+ bool IsUnaryRule() const { return !rhs1 && rhs2; }
+ bool IsTerminalRule() const { return rhs1 && !rhs2; }
+ bool HasLeadingWhitespace() const {
+ return codepoint_span.first != match_offset;
+ }
+
+ const ParseTree* unary_rule_rhs() const { return rhs2; }
+
+ // Used in singly-linked queue of matches for processing.
+ ParseTree* next = nullptr;
+
+ // Nonterminal we found a match for.
+ Nonterm lhs = kUnassignedNonterm;
+
+ // Type of the match.
+ Type type = Type::kDefault;
+
+ // The span in codepoints.
+ CodepointSpan codepoint_span;
+
+ // The begin codepoint offset used during matching.
+ // This is usually including any prefix whitespace.
+ int match_offset;
+
+ union {
+ // The first sub match for binary rules.
+ const ParseTree* rhs1 = nullptr;
+
+ // The terminal, for terminal rules.
+ const char* terminal;
+ };
+ // First or second sub-match for interior nodes.
+ const ParseTree* rhs2 = nullptr;
+};
+
+// Node type to keep track of associated values.
+struct MappingNode : public ParseTree {
+ explicit MappingNode(const Nonterm arg_lhs,
+ const CodepointSpan arg_codepoint_span,
+ const int arg_match_offset, const int64 arg_value)
+ : ParseTree(arg_lhs, arg_codepoint_span, arg_match_offset,
+ Type::kMapping),
+ id(arg_value) {}
+ // The associated id or value.
+ int64 id;
+};
+
+// Node type to keep track of assertions.
+struct AssertionNode : public ParseTree {
+ explicit AssertionNode(const Nonterm arg_lhs,
+ const CodepointSpan arg_codepoint_span,
+ const int arg_match_offset, const bool arg_negative)
+ : ParseTree(arg_lhs, arg_codepoint_span, arg_match_offset,
+ Type::kAssertion),
+ negative(arg_negative) {}
+ // If true, the assertion is negative and will be valid if the input doesn't
+ // match.
+ bool negative;
+};
+
+// Node type to define exclusions.
+struct ExclusionNode : public ParseTree {
+ explicit ExclusionNode(const Nonterm arg_lhs,
+ const CodepointSpan arg_codepoint_span,
+ const int arg_match_offset,
+ const Nonterm arg_exclusion_nonterm)
+ : ParseTree(arg_lhs, arg_codepoint_span, arg_match_offset,
+ Type::kExclusion),
+ exclusion_nonterm(arg_exclusion_nonterm) {}
+ // The nonterminal that denotes matches to exclude from a successful match.
+ // So the match is only valid if there is no match of `exclusion_nonterm`
+ // spanning the same text range.
+ Nonterm exclusion_nonterm;
+};
+
+// Match to represent an annotator annotated span in the grammar.
+struct AnnotationNode : public ParseTree {
+ explicit AnnotationNode(const Nonterm arg_lhs,
+ const CodepointSpan arg_codepoint_span,
+ const int arg_match_offset,
+ const ClassificationResult* arg_annotation)
+ : ParseTree(arg_lhs, arg_codepoint_span, arg_match_offset,
+ Type::kAnnotation),
+ annotation(arg_annotation) {}
+ const ClassificationResult* annotation;
+};
+
+// Node type to represent an associated semantic expression.
+struct SemanticExpressionNode : public ParseTree {
+ explicit SemanticExpressionNode(const Nonterm arg_lhs,
+ const CodepointSpan arg_codepoint_span,
+ const int arg_match_offset,
+ const SemanticExpression* arg_expression)
+ : ParseTree(arg_lhs, arg_codepoint_span, arg_match_offset,
+ Type::kExpression),
+ expression(arg_expression) {}
+ const SemanticExpression* expression;
+};
+
+// Utility functions for parse tree traversal.
+
+// Does a preorder traversal, calling `node_fn` on each node.
+// `node_fn` is expected to return whether to continue expanding a node.
+void Traverse(const ParseTree* root,
+ const std::function<bool(const ParseTree*)>& node_fn);
+
+// Does a preorder traversal, selecting all nodes where `pred_fn` returns true.
+std::vector<const ParseTree*> SelectAll(
+ const ParseTree* root,
+ const std::function<bool(const ParseTree*)>& pred_fn);
+
+// Retrieves all nodes of a given type.
+template <typename T>
+const std::vector<const T*> SelectAllOfType(const ParseTree* root,
+ const ParseTree::Type type) {
+ std::vector<const T*> result;
+ Traverse(root, [&result, type](const ParseTree* node) {
+ if (node->type == type) {
+ result.push_back(static_cast<const T*>(node));
+ }
+ return true;
+ });
+ return result;
+}
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_PARSE_TREE_H_
diff --git a/native/utils/grammar/parsing/parser.cc b/native/utils/grammar/parsing/parser.cc
new file mode 100644
index 0000000..4e39a98
--- /dev/null
+++ b/native/utils/grammar/parsing/parser.cc
@@ -0,0 +1,278 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/parsing/parser.h"
+
+#include <unordered_map>
+
+#include "utils/grammar/parsing/parse-tree.h"
+#include "utils/grammar/rules-utils.h"
+#include "utils/grammar/types.h"
+#include "utils/zlib/zlib.h"
+#include "utils/zlib/zlib_regex.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+inline bool CheckMemoryUsage(const UnsafeArena* arena) {
+ // The maximum memory usage for matching.
+ constexpr int kMaxMemoryUsage = 1 << 20;
+ return arena->status().bytes_allocated() <= kMaxMemoryUsage;
+}
+
+// Maps a codepoint to include the token padding if it aligns with a token
+// start. Whitespace is ignored when symbols are fed to the matcher. Preceding
+// whitespace is merged to the match start so that tokens and non-terminals
+// appear next to each other without whitespace. For text or regex annotations,
+// we therefore merge the whitespace padding to the start if the annotation
+// starts at a token.
+int MapCodepointToTokenPaddingIfPresent(
+ const std::unordered_map<CodepointIndex, CodepointIndex>& token_alignment,
+ const int start) {
+ const auto it = token_alignment.find(start);
+ if (it != token_alignment.end()) {
+ return it->second;
+ }
+ return start;
+}
+
+} // namespace
+
+Parser::Parser(const UniLib* unilib, const RulesSet* rules)
+ : unilib_(*unilib),
+ rules_(rules),
+ lexer_(unilib),
+ nonterminals_(rules_->nonterminals()),
+ rules_locales_(ParseRulesLocales(rules_)),
+ regex_annotators_(BuildRegexAnnotators()) {}
+
+// Uncompresses and builds the defined regex annotators.
+std::vector<Parser::RegexAnnotator> Parser::BuildRegexAnnotators() const {
+ std::vector<RegexAnnotator> result;
+ if (rules_->regex_annotator() != nullptr) {
+ std::unique_ptr<ZlibDecompressor> decompressor =
+ ZlibDecompressor::Instance();
+ result.reserve(rules_->regex_annotator()->size());
+ for (const RulesSet_::RegexAnnotator* regex_annotator :
+ *rules_->regex_annotator()) {
+ result.push_back(
+ {UncompressMakeRegexPattern(unilib_, regex_annotator->pattern(),
+ regex_annotator->compressed_pattern(),
+ rules_->lazy_regex_compilation(),
+ decompressor.get()),
+ regex_annotator->nonterminal()});
+ }
+ }
+ return result;
+}
+
+std::vector<Symbol> Parser::SortedSymbolsForInput(const TextContext& input,
+ UnsafeArena* arena) const {
+ // Whitespace is ignored when symbols are fed to the matcher.
+ // For regex matches and existing text annotations we therefore have to merge
+ // preceding whitespace to the match start so that tokens and non-terminals
+  // appear next to each other without whitespace. We keep track of real
+  // token starts and preceding whitespace in `token_match_start`, so that we
+ // can extend a match's start to include the preceding whitespace.
+ std::unordered_map<CodepointIndex, CodepointIndex> token_match_start;
+ for (int i = input.context_span.first + 1; i < input.context_span.second;
+ i++) {
+ const CodepointIndex token_start = input.tokens[i].start;
+ const CodepointIndex prev_token_end = input.tokens[i - 1].end;
+ if (token_start != prev_token_end) {
+ token_match_start[token_start] = prev_token_end;
+ }
+ }
+
+ std::vector<Symbol> symbols;
+ CodepointIndex match_offset = input.tokens[input.context_span.first].start;
+
+ // Add start symbol.
+ if (input.context_span.first == 0 &&
+ nonterminals_->start_nt() != kUnassignedNonterm) {
+ match_offset = 0;
+ symbols.emplace_back(arena->AllocAndInit<ParseTree>(
+ nonterminals_->start_nt(), CodepointSpan{0, 0},
+ /*match_offset=*/0, ParseTree::Type::kDefault));
+ }
+
+ if (nonterminals_->wordbreak_nt() != kUnassignedNonterm) {
+ symbols.emplace_back(arena->AllocAndInit<ParseTree>(
+ nonterminals_->wordbreak_nt(),
+ CodepointSpan{match_offset, match_offset},
+ /*match_offset=*/match_offset, ParseTree::Type::kDefault));
+ }
+
+ // Add symbols from tokens.
+ for (int i = input.context_span.first; i < input.context_span.second; i++) {
+ const Token& token = input.tokens[i];
+ lexer_.AppendTokenSymbols(token.value, /*match_offset=*/match_offset,
+ CodepointSpan{token.start, token.end}, &symbols);
+ match_offset = token.end;
+
+ // Add word break symbol.
+ if (nonterminals_->wordbreak_nt() != kUnassignedNonterm) {
+ symbols.emplace_back(arena->AllocAndInit<ParseTree>(
+ nonterminals_->wordbreak_nt(),
+ CodepointSpan{match_offset, match_offset},
+ /*match_offset=*/match_offset, ParseTree::Type::kDefault));
+ }
+ }
+
+ // Add end symbol if used by the grammar.
+ if (input.context_span.second == input.tokens.size() &&
+ nonterminals_->end_nt() != kUnassignedNonterm) {
+ symbols.emplace_back(arena->AllocAndInit<ParseTree>(
+ nonterminals_->end_nt(), CodepointSpan{match_offset, match_offset},
+ /*match_offset=*/match_offset, ParseTree::Type::kDefault));
+ }
+
+ // Add symbols from the regex annotators.
+ const CodepointIndex context_start =
+ input.tokens[input.context_span.first].start;
+ const CodepointIndex context_end =
+ input.tokens[input.context_span.second - 1].end;
+ for (const RegexAnnotator& regex_annotator : regex_annotators_) {
+ std::unique_ptr<UniLib::RegexMatcher> regex_matcher =
+ regex_annotator.pattern->Matcher(UnicodeText::Substring(
+ input.text, context_start, context_end, /*do_copy=*/false));
+ int status = UniLib::RegexMatcher::kNoError;
+ while (regex_matcher->Find(&status) &&
+ status == UniLib::RegexMatcher::kNoError) {
+ const CodepointSpan span{regex_matcher->Start(0, &status) + context_start,
+ regex_matcher->End(0, &status) + context_start};
+ symbols.emplace_back(arena->AllocAndInit<ParseTree>(
+ regex_annotator.nonterm, span, /*match_offset=*/
+ MapCodepointToTokenPaddingIfPresent(token_match_start, span.first),
+ ParseTree::Type::kDefault));
+ }
+ }
+
+ // Add symbols based on annotations.
+ if (auto annotation_nonterminals = nonterminals_->annotation_nt()) {
+ for (const AnnotatedSpan& annotated_span : input.annotations) {
+ const ClassificationResult& classification =
+ annotated_span.classification.front();
+ if (auto entry = annotation_nonterminals->LookupByKey(
+ classification.collection.c_str())) {
+ symbols.emplace_back(arena->AllocAndInit<AnnotationNode>(
+ entry->value(), annotated_span.span, /*match_offset=*/
+ MapCodepointToTokenPaddingIfPresent(token_match_start,
+ annotated_span.span.first),
+ &classification));
+ }
+ }
+ }
+
+ std::sort(symbols.begin(), symbols.end(),
+ [](const Symbol& a, const Symbol& b) {
+ // Sort by increasing (end, start) position to guarantee the
+ // matcher requirement that the tokens are fed in non-decreasing
+ // end position order.
+ return std::tie(a.codepoint_span.second, a.codepoint_span.first) <
+ std::tie(b.codepoint_span.second, b.codepoint_span.first);
+ });
+
+ return symbols;
+}
+
+void Parser::EmitSymbol(const Symbol& symbol, UnsafeArena* arena,
+ Matcher* matcher) const {
+ if (!CheckMemoryUsage(arena)) {
+ return;
+ }
+ switch (symbol.type) {
+ case Symbol::Type::TYPE_PARSE_TREE: {
+ // Just emit the parse tree.
+ matcher->AddParseTree(symbol.parse_tree);
+ return;
+ }
+ case Symbol::Type::TYPE_DIGITS: {
+ // Emit <digits> if used by the rules.
+ if (nonterminals_->digits_nt() != kUnassignedNonterm) {
+ matcher->AddParseTree(arena->AllocAndInit<ParseTree>(
+ nonterminals_->digits_nt(), symbol.codepoint_span,
+ symbol.match_offset, ParseTree::Type::kDefault));
+ }
+
+ // Emit <n_digits> if used by the rules.
+ if (nonterminals_->n_digits_nt() != nullptr) {
+ const int num_digits =
+ symbol.codepoint_span.second - symbol.codepoint_span.first;
+ if (num_digits <= nonterminals_->n_digits_nt()->size()) {
+ const Nonterm n_digits_nt =
+ nonterminals_->n_digits_nt()->Get(num_digits - 1);
+ if (n_digits_nt != kUnassignedNonterm) {
+ matcher->AddParseTree(arena->AllocAndInit<ParseTree>(
+ nonterminals_->n_digits_nt()->Get(num_digits - 1),
+ symbol.codepoint_span, symbol.match_offset,
+ ParseTree::Type::kDefault));
+ }
+ }
+ }
+ break;
+ }
+ case Symbol::Type::TYPE_TERM: {
+ // Emit <uppercase_token> if used by the rules.
+ if (nonterminals_->uppercase_token_nt() != 0 &&
+ unilib_.IsUpperText(
+ UTF8ToUnicodeText(symbol.lexeme, /*do_copy=*/false))) {
+ matcher->AddParseTree(arena->AllocAndInit<ParseTree>(
+ nonterminals_->uppercase_token_nt(), symbol.codepoint_span,
+ symbol.match_offset, ParseTree::Type::kDefault));
+ }
+ break;
+ }
+ default:
+ break;
+ }
+
+ // Emit the token as terminal.
+ matcher->AddTerminal(symbol.codepoint_span, symbol.match_offset,
+ symbol.lexeme);
+
+ // Emit <token> if used by rules.
+ matcher->AddParseTree(arena->AllocAndInit<ParseTree>(
+ nonterminals_->token_nt(), symbol.codepoint_span, symbol.match_offset,
+ ParseTree::Type::kDefault));
+}
+
+// Parses an input text and returns the root rule derivations.
+std::vector<Derivation> Parser::Parse(const TextContext& input,
+ UnsafeArena* arena) const {
+ // Check the tokens, input can be non-empty (whitespace) but have no tokens.
+ if (input.tokens.empty()) {
+ return {};
+ }
+
+ // Select locale matching rules.
+ std::vector<const RulesSet_::Rules*> locale_rules =
+ SelectLocaleMatchingShards(rules_, rules_locales_, input.locales);
+
+ if (locale_rules.empty()) {
+ // Nothing to do.
+ return {};
+ }
+
+ Matcher matcher(&unilib_, rules_, locale_rules, arena);
+ for (const Symbol& symbol : SortedSymbolsForInput(input, arena)) {
+ EmitSymbol(symbol, arena, &matcher);
+ }
+ matcher.Finish();
+ return matcher.chart().derivations();
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/parsing/parser.h b/native/utils/grammar/parsing/parser.h
new file mode 100644
index 0000000..0b320a0
--- /dev/null
+++ b/native/utils/grammar/parsing/parser.h
@@ -0,0 +1,82 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_PARSER_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_PARSER_H_
+
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/base/arena.h"
+#include "utils/grammar/parsing/derivation.h"
+#include "utils/grammar/parsing/lexer.h"
+#include "utils/grammar/parsing/matcher.h"
+#include "utils/grammar/rules_generated.h"
+#include "utils/grammar/text-context.h"
+#include "utils/i18n/locale.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3::grammar {
+
+// Syntactic parsing pass.
+// The parser validates and deduplicates candidates produced by the grammar
+// matcher. It augments the parse trees with derivation information for semantic
+// evaluation.
+class Parser {
+ public:
+ explicit Parser(const UniLib* unilib, const RulesSet* rules);
+
+ // Parses an input text and returns the root rule derivations.
+ std::vector<Derivation> Parse(const TextContext& input,
+ UnsafeArena* arena) const;
+
+ private:
+ struct RegexAnnotator {
+ std::unique_ptr<UniLib::RegexPattern> pattern;
+ Nonterm nonterm;
+ };
+
+ // Uncompresses and builds the defined regex annotators.
+ std::vector<RegexAnnotator> BuildRegexAnnotators() const;
+
+ // Produces symbols for a text input to feed to a matcher.
+ // These are symbols for each token from the lexer, existing text annotations
+ // and regex annotations.
+ // The symbols are sorted with increasing end-positions to satisfy the matcher
+ // requirements.
+ std::vector<Symbol> SortedSymbolsForInput(const TextContext& input,
+ UnsafeArena* arena) const;
+
+ // Emits a symbol to the matcher.
+ void EmitSymbol(const Symbol& symbol, UnsafeArena* arena,
+ Matcher* matcher) const;
+
+ const UniLib& unilib_;
+ const RulesSet* rules_;
+ const Lexer lexer_;
+
+ // Pre-defined nonterminals.
+ const RulesSet_::Nonterminals* nonterminals_;
+
+ // Pre-parsed locales of the rules.
+ const std::vector<std::vector<Locale>> rules_locales_;
+
+ std::vector<RegexAnnotator> regex_annotators_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_PARSER_H_
diff --git a/native/utils/grammar/parsing/parser_test.cc b/native/utils/grammar/parsing/parser_test.cc
new file mode 100644
index 0000000..cf8310b
--- /dev/null
+++ b/native/utils/grammar/parsing/parser_test.cc
@@ -0,0 +1,296 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/parsing/parser.h"
+
+#include <string>
+#include <vector>
+
+#include "utils/grammar/parsing/derivation.h"
+#include "utils/grammar/rules_generated.h"
+#include "utils/grammar/testing/utils.h"
+#include "utils/grammar/types.h"
+#include "utils/grammar/utils/ir.h"
+#include "utils/grammar/utils/rules.h"
+#include "utils/i18n/locale.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::IsEmpty;
+
+class ParserTest : public GrammarTest {};
+
+TEST_F(ParserTest, ParsesSimpleRules) {
+ Rules rules;
+ rules.Add("<day>", {"<2_digits>"});
+ rules.Add("<month>", {"<2_digits>"});
+ rules.Add("<year>", {"<4_digits>"});
+ constexpr int kDate = 0;
+ rules.Add("<date>", {"<year>", "/", "<month>", "/", "<day>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kDate);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("Event: 2020/05/08"), &arena_)),
+ ElementsAre(IsDerivation(kDate, 7, 17)));
+}
+
+TEST_F(ParserTest, HandlesEmptyInput) {
+ Rules rules;
+ constexpr int kTest = 0;
+ rules.Add("<test>", {"test"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kTest);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(
+ parser.Parse(TextContextForText("Event: test"), &arena_)),
+ ElementsAre(IsDerivation(kTest, 7, 11)));
+
+ // Check that we bail out in case of empty input.
+ EXPECT_THAT(ValidDeduplicatedDerivations(
+ parser.Parse(TextContextForText(""), &arena_)),
+ IsEmpty());
+ EXPECT_THAT(ValidDeduplicatedDerivations(
+ parser.Parse(TextContextForText(" "), &arena_)),
+ IsEmpty());
+}
+
+TEST_F(ParserTest, HandlesUppercaseTokens) {
+ Rules rules;
+ constexpr int kScriptedReply = 0;
+ rules.Add("<test>", {"please?", "reply", "<uppercase_token>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule),
+ kScriptedReply);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("Reply STOP to cancel."), &arena_)),
+ ElementsAre(IsDerivation(kScriptedReply, 0, 10)));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("Reply stop to cancel."), &arena_)),
+ IsEmpty());
+}
+
+TEST_F(ParserTest, HandlesAnchors) {
+ Rules rules;
+ constexpr int kScriptedReply = 0;
+ rules.Add("<test>", {"<^>", "reply", "<uppercase_token>", "<$>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule),
+ kScriptedReply);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(
+ parser.Parse(TextContextForText("Reply STOP"), &arena_)),
+ ElementsAre(IsDerivation(kScriptedReply, 0, 10)));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("Please reply STOP to cancel."), &arena_)),
+ IsEmpty());
+}
+
+TEST_F(ParserTest, HandlesWordBreaks) {
+ Rules rules;
+ rules.Add("<carrier>", {"lx"});
+ rules.Add("<carrier>", {"aa"});
+ constexpr int kFlight = 0;
+ rules.Add("<flight>", {"<carrier>", "<digits>", "<\b>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kFlight);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ // Make sure the grammar recognizes "LX 38".
+ EXPECT_THAT(
+ ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("My flight is: LX 38. Arriving later"), &arena_)),
+ ElementsAre(IsDerivation(kFlight, 14, 19)));
+
+ // Make sure the grammar doesn't trigger on "LX 38.00".
+ EXPECT_THAT(ValidDeduplicatedDerivations(
+ parser.Parse(TextContextForText("LX 38.00"), &arena_)),
+ IsEmpty());
+}
+
+TEST_F(ParserTest, HandlesAnnotations) {
+ Rules rules;
+ constexpr int kCallPhone = 0;
+ rules.Add("<flight>", {"dial", "<phone>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kCallPhone);
+ rules.BindAnnotation("<phone>", "phone");
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ TextContext context = TextContextForText("Please dial 911");
+
+ // Sanity check that we don't trigger if we don't feed the correct
+ // annotations.
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(context, &arena_)),
+ IsEmpty());
+
+ // Create a phone annotation.
+ AnnotatedSpan phone_span;
+ phone_span.span = CodepointSpan{12, 15};
+ phone_span.classification.emplace_back("phone", 1.0);
+ context.annotations.push_back(phone_span);
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(context, &arena_)),
+ ElementsAre(IsDerivation(kCallPhone, 7, 15)));
+}
+
+TEST_F(ParserTest, HandlesRegexAnnotators) {
+ Rules rules;
+ rules.AddRegex("<code>",
+ "(\"([A-Za-z]+)\"|\\b\"?(?:[A-Z]+[0-9]*|[0-9])\"?\\b)");
+ constexpr int kScriptedReply = 0;
+ rules.Add("<test>", {"please?", "reply", "<code>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule),
+ kScriptedReply);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("Reply STOP to cancel."), &arena_)),
+ ElementsAre(IsDerivation(kScriptedReply, 0, 10)));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("Reply Stop to cancel."), &arena_)),
+ IsEmpty());
+}
+
+TEST_F(ParserTest, HandlesExclusions) {
+ Rules rules;
+ rules.Add("<excluded>", {"be", "safe"});
+ rules.AddWithExclusion("<tokens_but_not_excluded>", {"<token>", "<token>"},
+ /*excluded_nonterminal=*/"<excluded>");
+ constexpr int kSetReminder = 0;
+ rules.Add("<set_reminder>",
+ {"do", "not", "forget", "to", "<tokens_but_not_excluded>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kSetReminder);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("do not forget to be there"), &arena_)),
+ ElementsAre(IsDerivation(kSetReminder, 0, 25)));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("do not forget to be safe"), &arena_)),
+ IsEmpty());
+}
+
+TEST_F(ParserTest, HandlesFillers) {
+ Rules rules;
+ constexpr int kSetReminder = 0;
+ rules.Add("<set_reminder>", {"do", "not", "forget", "to", "<filler>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kSetReminder);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("do not forget to be there"), &arena_)),
+ ElementsAre(IsDerivation(kSetReminder, 0, 25)));
+}
+
+TEST_F(ParserTest, HandlesAssertions) {
+ Rules rules;
+ rules.Add("<carrier>", {"lx"});
+ rules.Add("<carrier>", {"aa"});
+ rules.Add("<flight_code>", {"<2_digits>"});
+ rules.Add("<flight_code>", {"<3_digits>"});
+ rules.Add("<flight_code>", {"<4_digits>"});
+ // Flight: carrier + flight code and check right context.
+ constexpr int kFlight = 0;
+ rules.Add("<track_flight>",
+ {"<carrier>", "<flight_code>", "<context_assertion>?"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kFlight);
+ // Exclude matches like: LX 38.00 etc.
+ rules.AddAssertion("<context_assertion>", {".?", "<digits>"},
+ /*negative=*/true);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ EXPECT_THAT(
+ ValidDeduplicatedDerivations(
+ parser.Parse(TextContextForText("LX38 aa 44 LX 38.38"), &arena_)),
+ ElementsAre(IsDerivation(kFlight, 0, 4), IsDerivation(kFlight, 5, 10)));
+}
+
+TEST_F(ParserTest, HandlesWhitespaceGapLimit) {
+ Rules rules;
+ rules.Add("<carrier>", {"lx"});
+ rules.Add("<carrier>", {"aa"});
+ rules.Add("<flight_code>", {"<2_digits>"});
+ rules.Add("<flight_code>", {"<3_digits>"});
+ rules.Add("<flight_code>", {"<4_digits>"});
+ // Flight: carrier + flight code and check right context.
+ constexpr int kFlight = 0;
+ rules.Add("<track_flight>", {"<carrier>", "<flight_code>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kFlight,
+ /*max_whitespace_gap=*/0);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ EXPECT_THAT(ValidDeduplicatedDerivations(parser.Parse(
+ TextContextForText("LX38 aa 44 LX 38"), &arena_)),
+ ElementsAre(IsDerivation(kFlight, 0, 4)));
+}
+
+TEST_F(ParserTest, HandlesCaseSensitiveMatching) {
+ Rules rules;
+ rules.Add("<carrier>", {"Lx"}, /*callback=*/kNoCallback, /*callback_param=*/0,
+ /*max_whitespace_gap=*/-1, /*case_sensitive=*/true);
+ rules.Add("<carrier>", {"AA"}, /*callback=*/kNoCallback, /*callback_param=*/0,
+ /*max_whitespace_gap=*/-1, /*case_sensitive=*/true);
+ rules.Add("<flight_code>", {"<2_digits>"});
+ rules.Add("<flight_code>", {"<3_digits>"});
+ rules.Add("<flight_code>", {"<4_digits>"});
+ // Flight: carrier + flight code and check right context.
+ constexpr int kFlight = 0;
+ rules.Add("<track_flight>", {"<carrier>", "<flight_code>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kFlight);
+ const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer();
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(rules_buffer.data()));
+
+ EXPECT_THAT(
+ ValidDeduplicatedDerivations(
+ parser.Parse(TextContextForText("Lx38 AA 44 LX 38"), &arena_)),
+ ElementsAre(IsDerivation(kFlight, 0, 4), IsDerivation(kFlight, 5, 10)));
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/rules-utils.cc b/native/utils/grammar/rules-utils.cc
index 56c928a..5e8c189 100644
--- a/native/utils/grammar/rules-utils.cc
+++ b/native/utils/grammar/rules-utils.cc
@@ -54,70 +54,4 @@
return shards;
}
-std::vector<Derivation> DeduplicateDerivations(
- const std::vector<Derivation>& derivations) {
- std::vector<Derivation> sorted_candidates = derivations;
- std::stable_sort(
- sorted_candidates.begin(), sorted_candidates.end(),
- [](const Derivation& a, const Derivation& b) {
- // Sort by id.
- if (a.rule_id != b.rule_id) {
- return a.rule_id < b.rule_id;
- }
-
- // Sort by increasing start.
- if (a.match->codepoint_span.first != b.match->codepoint_span.first) {
- return a.match->codepoint_span.first < b.match->codepoint_span.first;
- }
-
- // Sort by decreasing end.
- return a.match->codepoint_span.second > b.match->codepoint_span.second;
- });
-
- // Deduplicate by overlap.
- std::vector<Derivation> result;
- for (int i = 0; i < sorted_candidates.size(); i++) {
- const Derivation& candidate = sorted_candidates[i];
- bool eliminated = false;
-
- // Due to the sorting above, the candidate can only be completely
- // intersected by a match before it in the sorted order.
- for (int j = i - 1; j >= 0; j--) {
- if (sorted_candidates[j].rule_id != candidate.rule_id) {
- break;
- }
- if (sorted_candidates[j].match->codepoint_span.first <=
- candidate.match->codepoint_span.first &&
- sorted_candidates[j].match->codepoint_span.second >=
- candidate.match->codepoint_span.second) {
- eliminated = true;
- break;
- }
- }
-
- if (!eliminated) {
- result.push_back(candidate);
- }
- }
- return result;
-}
-
-bool VerifyAssertions(const Match* match) {
- bool result = true;
- grammar::Traverse(match, [&result](const Match* node) {
- if (node->type != Match::kAssertionMatch) {
- // Only validation if all checks so far passed.
- return result;
- }
-
- // Positive assertions are by definition fulfilled,
- // fail if the assertion is negative.
- if (static_cast<const AssertionMatch*>(node)->negative) {
- result = false;
- }
- return result;
- });
- return result;
-}
-
} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/rules-utils.h b/native/utils/grammar/rules-utils.h
index e6ac541..64e8245 100644
--- a/native/utils/grammar/rules-utils.h
+++ b/native/utils/grammar/rules-utils.h
@@ -19,10 +19,8 @@
#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_RULES_UTILS_H_
#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_RULES_UTILS_H_
-#include <unordered_map>
#include <vector>
-#include "utils/grammar/match.h"
#include "utils/grammar/rules_generated.h"
#include "utils/i18n/locale.h"
@@ -37,22 +35,6 @@
const std::vector<std::vector<Locale>>& shard_locales,
const std::vector<Locale>& locales);
-// Deduplicates rule derivations by containing overlap.
-// The grammar system can output multiple candidates for optional parts.
-// For example if a rule has an optional suffix, we
-// will get two rule derivations when the suffix is present: one with and one
-// without the suffix. We therefore deduplicate by containing overlap, viz. from
-// two candidates we keep the longer one if it completely contains the shorter.
-struct Derivation {
- const Match* match;
- int64 rule_id;
-};
-std::vector<Derivation> DeduplicateDerivations(
- const std::vector<Derivation>& derivations);
-
-// Checks that all assertions of a match tree are fulfilled.
-bool VerifyAssertions(const Match* match);
-
} // namespace libtextclassifier3::grammar
#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_RULES_UTILS_H_
diff --git a/native/utils/grammar/rules-utils_test.cc b/native/utils/grammar/rules-utils_test.cc
deleted file mode 100644
index 6391be1..0000000
--- a/native/utils/grammar/rules-utils_test.cc
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/grammar/rules-utils.h"
-
-#include <vector>
-
-#include "utils/grammar/match.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3::grammar {
-namespace {
-
-using testing::ElementsAre;
-using testing::Value;
-
-// Create test match object.
-Match CreateMatch(const CodepointIndex begin, const CodepointIndex end) {
- Match match;
- match.Init(0, CodepointSpan{begin, end},
- /*arg_match_offset=*/begin);
- return match;
-}
-
-MATCHER_P(IsDerivation, candidate, "") {
- return Value(arg.rule_id, candidate.rule_id) &&
- Value(arg.match, candidate.match);
-}
-
-TEST(UtilsTest, DeduplicatesMatches) {
- // Overlapping matches from the same rule.
- Match matches[] = {CreateMatch(0, 1), CreateMatch(1, 2), CreateMatch(0, 2)};
- const std::vector<Derivation> candidates = {{&matches[0], /*rule_id=*/0},
- {&matches[1], /*rule_id=*/0},
- {&matches[2], /*rule_id=*/0}};
-
- // Keep longest.
- EXPECT_THAT(DeduplicateDerivations(candidates),
- ElementsAre(IsDerivation(candidates[2])));
-}
-
-TEST(UtilsTest, DeduplicatesMatchesPerRule) {
- // Overlapping matches from different rules.
- Match matches[] = {CreateMatch(0, 1), CreateMatch(1, 2), CreateMatch(0, 2)};
- const std::vector<Derivation> candidates = {{&matches[0], /*rule_id=*/0},
- {&matches[1], /*rule_id=*/0},
- {&matches[2], /*rule_id=*/0},
- {&matches[0], /*rule_id=*/1}};
-
- // Keep longest for rule 0, but also keep match from rule 1.
- EXPECT_THAT(
- DeduplicateDerivations(candidates),
- ElementsAre(IsDerivation(candidates[2]), IsDerivation(candidates[3])));
-}
-
-TEST(UtilsTest, KeepNonoverlapping) {
- // Non-overlapping matches.
- Match matches[] = {CreateMatch(0, 1), CreateMatch(1, 2), CreateMatch(2, 3)};
- const std::vector<Derivation> candidates = {{&matches[0], /*rule_id=*/0},
- {&matches[1], /*rule_id=*/0},
- {&matches[2], /*rule_id=*/0}};
-
- // Keep all matches.
- EXPECT_THAT(
- DeduplicateDerivations(candidates),
- ElementsAre(IsDerivation(candidates[0]), IsDerivation(candidates[1]),
- IsDerivation(candidates[2])));
-}
-
-} // namespace
-} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/rules.fbs b/native/utils/grammar/rules.fbs
index 71e23f8..3225892 100755
--- a/native/utils/grammar/rules.fbs
+++ b/native/utils/grammar/rules.fbs
@@ -16,6 +16,7 @@
include "utils/i18n/language-tag.fbs";
include "utils/zlib/buffer.fbs";
+include "utils/grammar/semantics/expression.fbs";
// The terminal rules map as sorted strings table.
// The sorted terminal strings table is represented as offsets into the
@@ -147,19 +148,6 @@
annotation_nt:[Nonterminals_.AnnotationNtEntry];
}
-// Callback information.
-namespace libtextclassifier3.grammar.RulesSet_;
-struct Callback {
- // Whether the callback is a filter.
- is_filter:bool;
-}
-
-namespace libtextclassifier3.grammar.RulesSet_;
-struct CallbackEntry {
- key:uint (key);
- value:Callback;
-}
-
namespace libtextclassifier3.grammar.RulesSet_.DebugInformation_;
table NonterminalNamesEntry {
key:int (key);
@@ -205,13 +193,15 @@
terminals:string (shared);
nonterminals:RulesSet_.Nonterminals;
- callback:[RulesSet_.CallbackEntry];
+ reserved_6:int16 (deprecated);
debug_information:RulesSet_.DebugInformation;
regex_annotator:[RulesSet_.RegexAnnotator];
// If true, will compile the regexes only on first use.
lazy_regex_compilation:bool;
- reserved_10:int16 (deprecated);
+
+ // The semantic expressions associated with rule matches.
+ semantic_expression:[SemanticExpression];
// The schema defining the semantic results.
semantic_values_schema:[ubyte];
diff --git a/native/utils/grammar/semantics/composer.cc b/native/utils/grammar/semantics/composer.cc
new file mode 100644
index 0000000..2d69049
--- /dev/null
+++ b/native/utils/grammar/semantics/composer.cc
@@ -0,0 +1,132 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/composer.h"
+
+#include "utils/base/status_macros.h"
+#include "utils/grammar/semantics/evaluators/arithmetic-eval.h"
+#include "utils/grammar/semantics/evaluators/compose-eval.h"
+#include "utils/grammar/semantics/evaluators/const-eval.h"
+#include "utils/grammar/semantics/evaluators/constituent-eval.h"
+#include "utils/grammar/semantics/evaluators/merge-values-eval.h"
+#include "utils/grammar/semantics/evaluators/parse-number-eval.h"
+#include "utils/grammar/semantics/evaluators/span-eval.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+// Gathers all constituents of a rule and indexes them.
+// The constituents are numbered in the rule construction. But constituents
+// could be in optional parts of the rule and might not be present in a match.
+// This finds all constituents that are present in a match and allows
+// retrieving them by their index.
+std::unordered_map<int, const ParseTree*> GatherConstituents(
+ const ParseTree* root) {
+ std::unordered_map<int, const ParseTree*> constituents;
+ Traverse(root, [root, &constituents](const ParseTree* node) {
+ switch (node->type) {
+ case ParseTree::Type::kMapping:
+ TC3_CHECK(node->IsUnaryRule());
+ constituents[static_cast<const MappingNode*>(node)->id] =
+ node->unary_rule_rhs();
+ return false;
+ case ParseTree::Type::kDefault:
+ // Continue traversal.
+ return true;
+ default:
+ // Don't continue the traversal if we are not at the root node.
+ // This could e.g. be an assertion node.
+ return (node == root);
+ }
+ });
+ return constituents;
+}
+
+} // namespace
+
+SemanticComposer::SemanticComposer(
+ const reflection::Schema* semantic_values_schema) {
+ evaluators_.emplace(SemanticExpression_::Expression_ArithmeticExpression,
+ std::make_unique<ArithmeticExpressionEvaluator>(this));
+ evaluators_.emplace(SemanticExpression_::Expression_ConstituentExpression,
+ std::make_unique<ConstituentEvaluator>());
+ evaluators_.emplace(SemanticExpression_::Expression_ParseNumberExpression,
+ std::make_unique<ParseNumberEvaluator>(this));
+ evaluators_.emplace(SemanticExpression_::Expression_SpanAsStringExpression,
+ std::make_unique<SpanAsStringEvaluator>());
+ if (semantic_values_schema != nullptr) {
+ // Register semantic functions.
+ evaluators_.emplace(
+ SemanticExpression_::Expression_ComposeExpression,
+ std::make_unique<ComposeEvaluator>(this, semantic_values_schema));
+ evaluators_.emplace(
+ SemanticExpression_::Expression_ConstValueExpression,
+ std::make_unique<ConstEvaluator>(semantic_values_schema));
+ evaluators_.emplace(
+ SemanticExpression_::Expression_MergeValueExpression,
+ std::make_unique<MergeValuesEvaluator>(this, semantic_values_schema));
+ }
+}
+
+StatusOr<const SemanticValue*> SemanticComposer::Eval(
+ const TextContext& text_context, const Derivation& derivation,
+ UnsafeArena* arena) const {
+ if (!derivation.parse_tree->IsUnaryRule() ||
+ derivation.parse_tree->unary_rule_rhs()->type !=
+ ParseTree::Type::kExpression) {
+ return nullptr;
+ }
+ return Eval(text_context,
+ static_cast<const SemanticExpressionNode*>(
+ derivation.parse_tree->unary_rule_rhs()),
+ arena);
+}
+
+StatusOr<const SemanticValue*> SemanticComposer::Eval(
+ const TextContext& text_context, const SemanticExpressionNode* derivation,
+ UnsafeArena* arena) const {
+ // Evaluate constituents.
+ EvalContext context{&text_context, derivation};
+ for (const auto& [constituent_index, constituent] :
+ GatherConstituents(derivation)) {
+ if (constituent->type == ParseTree::Type::kExpression) {
+ TC3_ASSIGN_OR_RETURN(
+ context.rule_constituents[constituent_index],
+ Eval(text_context,
+ static_cast<const SemanticExpressionNode*>(constituent), arena));
+ } else {
+ // Just use the text of the constituent if no semantic expression was
+ // defined.
+ context.rule_constituents[constituent_index] = SemanticValue::Create(
+ text_context.Span(constituent->codepoint_span), arena);
+ }
+ }
+ return Apply(context, derivation->expression, arena);
+}
+
+StatusOr<const SemanticValue*> SemanticComposer::Apply(
+ const EvalContext& context, const SemanticExpression* expression,
+ UnsafeArena* arena) const {
+ const auto handler_it = evaluators_.find(expression->expression_type());
+ if (handler_it == evaluators_.end()) {
+ return Status(StatusCode::INVALID_ARGUMENT,
+ std::string("Unhandled expression type: ") +
+ EnumNameExpression(expression->expression_type()));
+ }
+ return handler_it->second->Apply(context, expression, arena);
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/composer.h b/native/utils/grammar/semantics/composer.h
new file mode 100644
index 0000000..135f7d6
--- /dev/null
+++ b/native/utils/grammar/semantics/composer.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_COMPOSER_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_COMPOSER_H_
+
+#include <unordered_map>
+#include <vector>
+
+#include "utils/base/arena.h"
+#include "utils/base/status.h"
+#include "utils/base/statusor.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/grammar/parsing/derivation.h"
+#include "utils/grammar/parsing/parse-tree.h"
+#include "utils/grammar/semantics/eval-context.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/semantics/value.h"
+#include "utils/grammar/text-context.h"
+
+namespace libtextclassifier3::grammar {
+
+// Semantic value composer.
+// It evaluates a semantic expression of a syntactic parse tree as a semantic
+// value.
+// It evaluates the constituents of a rule match and applies them to the
+// semantic expression, calling out to semantic functions that implement the
+// basic building blocks.
+class SemanticComposer : public SemanticExpressionEvaluator {
+ public:
+ // Expects a flatbuffer schema that describes the possible result values of
+ // an evaluation.
+ explicit SemanticComposer(const reflection::Schema* semantic_values_schema);
+
+ // Evaluates a semantic expression that is associated with the root of a parse
+ // tree.
+ StatusOr<const SemanticValue*> Eval(const TextContext& text_context,
+ const Derivation& derivation,
+ UnsafeArena* arena) const;
+
+ // Applies a semantic expression to a list of constituents and
+ // produces an output semantic value.
+ StatusOr<const SemanticValue*> Apply(const EvalContext& context,
+ const SemanticExpression* expression,
+ UnsafeArena* arena) const override;
+
+ private:
+ // Evaluates a semantic expression against a parse tree.
+ StatusOr<const SemanticValue*> Eval(const TextContext& text_context,
+ const SemanticExpressionNode* derivation,
+ UnsafeArena* arena) const;
+
+ std::unordered_map<SemanticExpression_::Expression,
+ std::unique_ptr<SemanticExpressionEvaluator>>
+ evaluators_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_COMPOSER_H_
diff --git a/native/utils/grammar/semantics/composer_test.cc b/native/utils/grammar/semantics/composer_test.cc
new file mode 100644
index 0000000..95b0759
--- /dev/null
+++ b/native/utils/grammar/semantics/composer_test.cc
@@ -0,0 +1,173 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/composer.h"
+
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/reflection.h"
+#include "utils/grammar/parsing/derivation.h"
+#include "utils/grammar/parsing/parser.h"
+#include "utils/grammar/rules_generated.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/testing/utils.h"
+#include "utils/grammar/testing/value_generated.h"
+#include "utils/grammar/types.h"
+#include "utils/grammar/utils/rules.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+using ::testing::ElementsAre;
+
+class SemanticComposerTest : public GrammarTest {};
+
+// Checks that a terminal rule annotated with a constant semantic expression
+// evaluates to that constant.
+TEST_F(SemanticComposerTest, EvaluatesSimpleMapping) {
+ RulesSetT model;
+ Rules rules;
+ // Type id of the TestValue table in the test schema; used as the result
+ // type of the constant expressions below.
+ const int test_value_type =
+ TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value();
+ // "january" maps to a constant TestValue with value == 1.
+ {
+ rules.Add("<month>", {"january"},
+ static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
+ /*callback_param=*/model.semantic_expression.size());
+ TestValueT value;
+ value.value = 1;
+ const std::string serialized_value = PackFlatbuffer<TestValue>(&value);
+ ConstValueExpressionT const_value;
+ const_value.base_type = reflection::BaseType::Obj;
+ const_value.type = test_value_type;
+ const_value.value.assign(serialized_value.begin(), serialized_value.end());
+ model.semantic_expression.emplace_back(new SemanticExpressionT);
+ model.semantic_expression.back()->expression.Set(const_value);
+ }
+ // "february" maps to a constant TestValue with value == 2.
+ {
+ rules.Add("<month>", {"february"},
+ static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
+ /*callback_param=*/model.semantic_expression.size());
+ TestValueT value;
+ value.value = 2;
+ const std::string serialized_value = PackFlatbuffer<TestValue>(&value);
+ ConstValueExpressionT const_value;
+ const_value.base_type = reflection::BaseType::Obj;
+ const_value.type = test_value_type;
+ const_value.value.assign(serialized_value.begin(), serialized_value.end());
+ model.semantic_expression.emplace_back(new SemanticExpressionT);
+ model.semantic_expression.back()->expression.Set(const_value);
+ }
+ const int kMonth = 0;
+ rules.Add("<month_rule>", {"<month>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule), kMonth);
+ rules.Finalize().Serialize(/*include_debug_information=*/false, &model);
+ const std::string model_buffer = PackFlatbuffer<RulesSet>(&model);
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(model_buffer.data()));
+ SemanticComposer composer(semantic_values_schema_.get());
+
+ // "January" matches at codepoints [7, 14) and evaluates to 1.
+ {
+ const TextContext text = TextContextForText("Month: January");
+ const std::vector<Derivation> derivations = parser.Parse(text, &arena_);
+ EXPECT_THAT(derivations, ElementsAre(IsDerivation(kMonth, 7, 14)));
+
+ StatusOr<const SemanticValue*> maybe_value =
+ composer.Eval(text, derivations.front(), &arena_);
+ EXPECT_TRUE(maybe_value.ok());
+
+ const TestValue* value = maybe_value.ValueOrDie()->Table<TestValue>();
+ EXPECT_EQ(value->value(), 1);
+ }
+
+ // "February" matches at codepoints [7, 15) and evaluates to 2.
+ {
+ const TextContext text = TextContextForText("Month: February");
+ const std::vector<Derivation> derivations = parser.Parse(text, &arena_);
+ EXPECT_THAT(derivations, ElementsAre(IsDerivation(kMonth, 7, 15)));
+
+ StatusOr<const SemanticValue*> maybe_value =
+ composer.Eval(text, derivations.front(), &arena_);
+ EXPECT_TRUE(maybe_value.ok());
+
+ const TestValue* value = maybe_value.ValueOrDie()->Table<TestValue>();
+ EXPECT_EQ(value->value(), 2);
+ }
+}
+
+// Checks that evaluating a rule recursively evaluates the semantic values of
+// its (named) constituents.
+TEST_F(SemanticComposerTest, RecursivelyEvaluatesConstituents) {
+ RulesSetT model;
+ Rules rules;
+ const int test_value_type =
+ TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value();
+ constexpr int kDateRule = 0;
+ // "january" maps to a constant TestValue with value == 42.
+ {
+ rules.Add("<month>", {"january"},
+ static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
+ /*callback_param=*/model.semantic_expression.size());
+ TestValueT value;
+ value.value = 42;
+ const std::string serialized_value = PackFlatbuffer<TestValue>(&value);
+ ConstValueExpressionT const_value;
+ const_value.type = test_value_type;
+ const_value.base_type = reflection::BaseType::Obj;
+ const_value.value.assign(serialized_value.begin(), serialized_value.end());
+ model.semantic_expression.emplace_back(new SemanticExpressionT);
+ model.semantic_expression.back()->expression.Set(const_value);
+ }
+ {
+ // Define constituents of the rule.
+ // TODO(smillius): Add support in the rules builder to directly specify
+ // constituent ids in the rule, e.g. `<date> ::= <month>@0? <4_digits>`.
+ rules.Add("<date_@0>", {"<month>"},
+ static_cast<CallbackId>(DefaultCallback::kMapping),
+ /*callback_param=*/1);
+ rules.Add("<date>", {"<date_@0>?", "<4_digits>"},
+ static_cast<CallbackId>(DefaultCallback::kSemanticExpression),
+ /*callback_param=*/model.semantic_expression.size());
+ // The <date> expression forwards the semantic value of the constituent
+ // with id 1 (the optional month).
+ ConstituentExpressionT constituent;
+ constituent.id = 1;
+ model.semantic_expression.emplace_back(new SemanticExpressionT);
+ model.semantic_expression.back()->expression.Set(constituent);
+ rules.Add("<date_rule>", {"<date>"},
+ static_cast<CallbackId>(DefaultCallback::kRootRule),
+ /*callback_param=*/kDateRule);
+ }
+
+ rules.Finalize().Serialize(/*include_debug_information=*/false, &model);
+ const std::string model_buffer = PackFlatbuffer<RulesSet>(&model);
+ Parser parser(unilib_.get(),
+ flatbuffers::GetRoot<RulesSet>(model_buffer.data()));
+ SemanticComposer composer(semantic_values_schema_.get());
+
+ // The date value is expected to be the month constituent's value (42).
+ {
+ const TextContext text = TextContextForText("Event: January 2020");
+ const std::vector<Derivation> derivations =
+ ValidDeduplicatedDerivations(parser.Parse(text, &arena_));
+ EXPECT_THAT(derivations, ElementsAre(IsDerivation(kDateRule, 7, 19)));
+
+ StatusOr<const SemanticValue*> maybe_value =
+ composer.Eval(text, derivations.front(), &arena_);
+ EXPECT_TRUE(maybe_value.ok());
+
+ const TestValue* value = maybe_value.ValueOrDie()->Table<TestValue>();
+ EXPECT_EQ(value->value(), 42);
+ }
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/eval-context.h b/native/utils/grammar/semantics/eval-context.h
new file mode 100644
index 0000000..aab878a
--- /dev/null
+++ b/native/utils/grammar/semantics/eval-context.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVAL_CONTEXT_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVAL_CONTEXT_H_
+
+#include <unordered_map>
+
+#include "utils/grammar/parsing/parse-tree.h"
+#include "utils/grammar/semantics/value.h"
+#include "utils/grammar/text-context.h"
+
+namespace libtextclassifier3::grammar {
+
+// Context for the evaluation of the semantic expression of a rule parse tree.
+// This contains data about the evaluated constituents (named parts) of a rule
+// and its match.
+struct EvalContext {
+ // The input text.
+ const TextContext* text_context = nullptr;
+
+ // The syntactic parse tree that is being evaluated.
+ const ParseTree* parse_tree = nullptr;
+
+ // A map of an id of a rule constituent (named part of a rule match) to its
+ // evaluated semantic value.
+ std::unordered_map<int, const SemanticValue*> rule_constituents;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVAL_CONTEXT_H_
diff --git a/native/utils/grammar/semantics/evaluator.h b/native/utils/grammar/semantics/evaluator.h
new file mode 100644
index 0000000..7b6bf90
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluator.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATOR_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATOR_H_
+
+#include "utils/base/arena.h"
+#include "utils/base/statusor.h"
+#include "utils/grammar/semantics/eval-context.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/semantics/value.h"
+
+namespace libtextclassifier3::grammar {
+
+// Interface for a semantic function that evaluates an expression and returns
+// a semantic value.
+class SemanticExpressionEvaluator {
+ public:
+ virtual ~SemanticExpressionEvaluator() = default;
+
+ // Applies `expression` to the `context` to produce a semantic value.
+ // The returned value is allocated on `arena`; implementations may return a
+ // null value (with ok status) when the expression yields no value.
+ virtual StatusOr<const SemanticValue*> Apply(
+ const EvalContext& context, const SemanticExpression* expression,
+ UnsafeArena* arena) const = 0;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATOR_H_
diff --git a/native/utils/grammar/semantics/evaluators/arithmetic-eval.cc b/native/utils/grammar/semantics/evaluators/arithmetic-eval.cc
new file mode 100644
index 0000000..76b72c6
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/arithmetic-eval.cc
@@ -0,0 +1,134 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/evaluators/arithmetic-eval.h"
+
+#include <limits>
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+// Reduces the argument expressions of `expression` with the expression's
+// arithmetic operator into a single scalar of type `T`.
+// With no arguments, the operator's identity element is returned. Arguments
+// that evaluate to a null semantic value are skipped; arguments of a
+// different scalar type are an error.
+template <typename T>
+StatusOr<const SemanticValue*> Reduce(
+ const SemanticExpressionEvaluator* composer, const EvalContext& context,
+ const ArithmeticExpression* expression, UnsafeArena* arena) {
+ T result;
+ switch (expression->op()) {
+ case ArithmeticExpression_::Operator_OP_ADD: {
+ result = 0;
+ break;
+ }
+ case ArithmeticExpression_::Operator_OP_MUL: {
+ result = 1;
+ break;
+ }
+ case ArithmeticExpression_::Operator_OP_MIN: {
+ result = std::numeric_limits<T>::max();
+ break;
+ }
+ case ArithmeticExpression_::Operator_OP_MAX: {
+ // Use lowest() rather than min(): for floating-point types min() is
+ // the smallest *positive* normalized value, which would make OP_MAX
+ // produce a wrong result when all arguments are negative. For integral
+ // types lowest() == min(), so behavior is unchanged.
+ result = std::numeric_limits<T>::lowest();
+ break;
+ }
+ default: {
+ return Status(StatusCode::INVALID_ARGUMENT,
+ "Unexpected op: " +
+ std::string(ArithmeticExpression_::EnumNameOperator(
+ expression->op())));
+ }
+ }
+ if (expression->values() != nullptr) {
+ for (const SemanticExpression* semantic_expression :
+ *expression->values()) {
+ TC3_ASSIGN_OR_RETURN(
+ const SemanticValue* value,
+ composer->Apply(context, semantic_expression, arena));
+ // Null values (e.g. unmatched optional constituents) are skipped.
+ if (value == nullptr) {
+ continue;
+ }
+ if (!value->Has<T>()) {
+ return Status(
+ StatusCode::INVALID_ARGUMENT,
+ "Argument didn't evaluate as expected type: " +
+ std::string(reflection::EnumNameBaseType(value->base_type())));
+ }
+ const T scalar_value = value->Value<T>();
+ switch (expression->op()) {
+ case ArithmeticExpression_::Operator_OP_ADD: {
+ result += scalar_value;
+ break;
+ }
+ case ArithmeticExpression_::Operator_OP_MUL: {
+ result *= scalar_value;
+ break;
+ }
+ case ArithmeticExpression_::Operator_OP_MIN: {
+ result = std::min(result, scalar_value);
+ break;
+ }
+ case ArithmeticExpression_::Operator_OP_MAX: {
+ result = std::max(result, scalar_value);
+ break;
+ }
+ default: {
+ // Already rejected above; nothing to do.
+ break;
+ }
+ }
+ }
+ }
+ return SemanticValue::Create(result, arena);
+}
+
+} // namespace
+
+StatusOr<const SemanticValue*> ArithmeticExpressionEvaluator::Apply(
+ const EvalContext& context, const SemanticExpression* expression,
+ UnsafeArena* arena) const {
+ TC3_DCHECK_EQ(expression->expression_type(),
+ SemanticExpression_::Expression_ArithmeticExpression);
+ const ArithmeticExpression* arithmetic_expression =
+ expression->expression_as_ArithmeticExpression();
+ // Dispatch to the Reduce<T> instantiation that matches the declared scalar
+ // result type of the expression.
+ switch (arithmetic_expression->base_type()) {
+ case reflection::BaseType::Byte:
+ return Reduce<int8>(composer_, context, arithmetic_expression, arena);
+ case reflection::BaseType::UByte:
+ return Reduce<uint8>(composer_, context, arithmetic_expression, arena);
+ case reflection::BaseType::Short:
+ return Reduce<int16>(composer_, context, arithmetic_expression, arena);
+ case reflection::BaseType::UShort:
+ return Reduce<uint16>(composer_, context, arithmetic_expression, arena);
+ case reflection::BaseType::Int:
+ return Reduce<int32>(composer_, context, arithmetic_expression, arena);
+ case reflection::BaseType::UInt:
+ return Reduce<uint32>(composer_, context, arithmetic_expression, arena);
+ case reflection::BaseType::Long:
+ return Reduce<int64>(composer_, context, arithmetic_expression, arena);
+ case reflection::BaseType::ULong:
+ return Reduce<uint64>(composer_, context, arithmetic_expression, arena);
+ case reflection::BaseType::Float:
+ return Reduce<float>(composer_, context, arithmetic_expression, arena);
+ case reflection::BaseType::Double:
+ return Reduce<double>(composer_, context, arithmetic_expression, arena);
+ default:
+ // Non-scalar result types (strings, tables, ...) are not supported.
+ return Status(StatusCode::INVALID_ARGUMENT,
+ "Unsupported for ArithmeticExpression: " +
+ std::string(reflection::EnumNameBaseType(
+ static_cast<reflection::BaseType>(
+ arithmetic_expression->base_type()))));
+ }
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/evaluators/arithmetic-eval.h b/native/utils/grammar/semantics/evaluators/arithmetic-eval.h
new file mode 100644
index 0000000..38efc57
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/arithmetic-eval.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_ARITHMETIC_EVAL_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_ARITHMETIC_EVAL_H_
+
+#include "utils/base/arena.h"
+#include "utils/grammar/semantics/eval-context.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/semantics/value.h"
+
+namespace libtextclassifier3::grammar {
+
+// Evaluates an arithmetic expression.
+// Expects zero or more arguments and produces either sum, product, minimum or
+// maximum of its arguments. If no arguments are specified, each operator
+// returns its identity value.
+class ArithmeticExpressionEvaluator : public SemanticExpressionEvaluator {
+ public:
+ // `composer` is used to evaluate the argument expressions; it is not owned
+ // and must outlive this evaluator.
+ explicit ArithmeticExpressionEvaluator(
+ const SemanticExpressionEvaluator* composer)
+ : composer_(composer) {}
+
+ StatusOr<const SemanticValue*> Apply(const EvalContext& context,
+ const SemanticExpression* expression,
+ UnsafeArena* arena) const override;
+
+ private:
+ const SemanticExpressionEvaluator* composer_;
+};
+
+} // namespace libtextclassifier3::grammar
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_ARITHMETIC_EVAL_H_
diff --git a/native/utils/grammar/semantics/evaluators/arithmetic-eval_test.cc b/native/utils/grammar/semantics/evaluators/arithmetic-eval_test.cc
new file mode 100644
index 0000000..5385fc1
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/arithmetic-eval_test.cc
@@ -0,0 +1,83 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/evaluators/arithmetic-eval.h"
+
+#include <vector>
+
+#include "utils/base/statusor.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/reflection.h"
+#include "utils/flatbuffers/test-utils.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/evaluators/const-eval.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/testing/utils.h"
+#include "utils/grammar/testing/value_generated.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+// Typed test fixture: evaluates an arithmetic expression of scalar type `T`
+// over the constant arguments 1, 2, 3, 4.
+template <typename T>
+class ArithmeticExpressionEvaluatorTest : public GrammarTest {
+ protected:
+ // Builds `op(1, 2, 3, 4)` as an ArithmeticExpression of type `T`, runs the
+ // evaluator and returns the resulting scalar.
+ T Eval(const ArithmeticExpression_::Operator op) {
+ ArithmeticExpressionT arithmetic_expression;
+ arithmetic_expression.base_type = flatbuffers_base_type<T>::value;
+ arithmetic_expression.op = op;
+ arithmetic_expression.values.push_back(
+ CreatePrimitiveConstExpression<T>(1));
+ arithmetic_expression.values.push_back(
+ CreatePrimitiveConstExpression<T>(2));
+ arithmetic_expression.values.push_back(
+ CreatePrimitiveConstExpression<T>(3));
+ arithmetic_expression.values.push_back(
+ CreatePrimitiveConstExpression<T>(4));
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(std::move(arithmetic_expression));
+
+ // Setup evaluators.
+ ConstEvaluator const_eval(semantic_values_schema_.get());
+ ArithmeticExpressionEvaluator arithmetic_eval(&const_eval);
+
+ // Run evaluator.
+ StatusOr<const SemanticValue*> result =
+ arithmetic_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ // Check result.
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ EXPECT_TRUE(result_value != nullptr);
+ return result_value->Value<T>();
+ }
+};
+
+using NumberTypes = ::testing::Types<int8, uint8, int16, uint16, int32, uint32,
+ int64, uint64, double, float>;
+TYPED_TEST_SUITE(ArithmeticExpressionEvaluatorTest, NumberTypes);
+
+// Renamed from `ParsesNumber`: nothing is parsed here — the test evaluates
+// the four arithmetic operators over the constant arguments 1..4.
+TYPED_TEST(ArithmeticExpressionEvaluatorTest, EvaluatesArithmeticOperators) {
+ // All inputs are positive so the expectations hold for unsigned types too.
+ // TODO: add negative inputs for signed types to exercise OP_MIN/OP_MAX
+ // around zero.
+ EXPECT_EQ(this->Eval(ArithmeticExpression_::Operator_OP_ADD), 10);
+ EXPECT_EQ(this->Eval(ArithmeticExpression_::Operator_OP_MUL), 24);
+ EXPECT_EQ(this->Eval(ArithmeticExpression_::Operator_OP_MIN), 1);
+ EXPECT_EQ(this->Eval(ArithmeticExpression_::Operator_OP_MAX), 4);
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/evaluators/compose-eval.cc b/native/utils/grammar/semantics/evaluators/compose-eval.cc
new file mode 100644
index 0000000..09bbf5c
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/compose-eval.cc
@@ -0,0 +1,183 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/evaluators/compose-eval.h"
+
+#include "utils/base/status_macros.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+// Tries setting a singular field.
+// Returns INVALID_ARGUMENT if the field could not be set (e.g. type
+// mismatch).
+template <typename T>
+Status TrySetField(const reflection::Field* field, const SemanticValue* value,
+ MutableFlatbuffer* result) {
+ if (!result->Set<T>(field, value->Value<T>())) {
+ return Status(StatusCode::INVALID_ARGUMENT, "Could not set field.");
+ }
+ return Status::OK;
+}
+
+// Specialization for table-valued fields: merges the value's table into the
+// corresponding sub-message of `result`.
+template <>
+Status TrySetField<flatbuffers::Table>(const reflection::Field* field,
+ const SemanticValue* value,
+ MutableFlatbuffer* result) {
+ if (!result->Mutable(field)->MergeFrom(value->Table())) {
+ return Status(StatusCode::INVALID_ARGUMENT,
+ "Could not set sub-field in result.");
+ }
+ return Status::OK;
+}
+
+// Tries adding a value to a repeated field.
+// Returns INVALID_ARGUMENT if the value could not be appended.
+template <typename T>
+Status TryAddField(const reflection::Field* field, const SemanticValue* value,
+ MutableFlatbuffer* result) {
+ if (!result->Repeated(field)->Add(value->Value<T>())) {
+ return Status(StatusCode::INVALID_ARGUMENT, "Could not add field.");
+ }
+ return Status::OK;
+}
+
+// Specialization for table-valued fields: appends a new element and merges
+// the value's table into it.
+template <>
+Status TryAddField<flatbuffers::Table>(const reflection::Field* field,
+ const SemanticValue* value,
+ MutableFlatbuffer* result) {
+ if (!result->Repeated(field)->Add()->MergeFrom(value->Table())) {
+ return Status(StatusCode::INVALID_ARGUMENT,
+ "Could not add message to repeated field.");
+ }
+ return Status::OK;
+}
+
+// Tries adding or setting a value for a field.
+// Resolves `field_path` relative to `result`, then appends to the field if it
+// is repeated (a vector) and sets it otherwise.
+template <typename T>
+Status TrySetOrAddValue(const FlatbufferFieldPath* field_path,
+ const SemanticValue* value, MutableFlatbuffer* result) {
+ MutableFlatbuffer* parent;
+ const reflection::Field* field;
+ if (!result->GetFieldWithParent(field_path, &parent, &field)) {
+ return Status(StatusCode::INVALID_ARGUMENT, "Could not get field.");
+ }
+ if (field->type()->base_type() == reflection::Vector) {
+ return TryAddField<T>(field, value, parent);
+ } else {
+ return TrySetField<T>(field, value, parent);
+ }
+}
+
+} // namespace
+
+// Builds a new semantic value of the expression's declared result type and
+// fills its fields from the evaluated field expressions.
+StatusOr<const SemanticValue*> ComposeEvaluator::Apply(
+ const EvalContext& context, const SemanticExpression* expression,
+ UnsafeArena* arena) const {
+ const ComposeExpression* compose_expression =
+ expression->expression_as_ComposeExpression();
+ // Guard against a mis-typed expression: expression_as_ComposeExpression()
+ // returns null when the union holds a different expression type, and the
+ // unchecked dereference below would crash.
+ if (compose_expression == nullptr) {
+ return Status(StatusCode::INVALID_ARGUMENT,
+ "Expression is not a ComposeExpression.");
+ }
+ std::unique_ptr<MutableFlatbuffer> result =
+ semantic_value_builder_.NewTable(compose_expression->type());
+
+ if (result == nullptr) {
+ return Status(StatusCode::INVALID_ARGUMENT, "Invalid result type.");
+ }
+
+ // Evaluate and set fields.
+ if (compose_expression->fields() != nullptr) {
+ for (const ComposeExpression_::Field* field :
+ *compose_expression->fields()) {
+ // Evaluate argument.
+ TC3_ASSIGN_OR_RETURN(const SemanticValue* value,
+ composer_->Apply(context, field->value(), arena));
+ // Null values (e.g. unmatched optional constituents) leave the field
+ // unset.
+ if (value == nullptr) {
+ continue;
+ }
+
+ // Dispatch on the value's scalar/table type to set or append it at the
+ // field path.
+ switch (value->base_type()) {
+ case reflection::BaseType::Bool: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<bool>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::Byte: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<int8>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::UByte: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<uint8>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::Short: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<int16>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::UShort: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<uint16>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::Int: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<int32>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::UInt: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<uint32>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::Long: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<int64>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::ULong: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<uint64>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::Float: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<float>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::Double: {
+ TC3_RETURN_IF_ERROR(
+ TrySetOrAddValue<double>(field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::String: {
+ TC3_RETURN_IF_ERROR(TrySetOrAddValue<StringPiece>(
+ field->path(), value, result.get()));
+ break;
+ }
+ case reflection::BaseType::Obj: {
+ TC3_RETURN_IF_ERROR(TrySetOrAddValue<flatbuffers::Table>(
+ field->path(), value, result.get()));
+ break;
+ }
+ default:
+ return Status(StatusCode::INVALID_ARGUMENT, "Unhandled type.");
+ }
+ }
+ }
+
+ return SemanticValue::Create<const MutableFlatbuffer*>(result.get(), arena);
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/evaluators/compose-eval.h b/native/utils/grammar/semantics/evaluators/compose-eval.h
new file mode 100644
index 0000000..ba3b6f9
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/compose-eval.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_COMPOSE_EVAL_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_COMPOSE_EVAL_H_
+
+#include "utils/base/arena.h"
+#include "utils/flatbuffers/mutable.h"
+#include "utils/grammar/semantics/eval-context.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/semantics/value.h"
+
+namespace libtextclassifier3::grammar {
+
+// Combines arguments to a result type.
+// Evaluates each field expression with `composer` and writes the results into
+// a newly built flatbuffer of the expression's declared type.
+class ComposeEvaluator : public SemanticExpressionEvaluator {
+ public:
+ // `composer` evaluates the field expressions; not owned, must outlive this
+ // evaluator. `semantic_values_schema` describes the possible result types.
+ explicit ComposeEvaluator(const SemanticExpressionEvaluator* composer,
+ const reflection::Schema* semantic_values_schema)
+ : composer_(composer), semantic_value_builder_(semantic_values_schema) {}
+
+ StatusOr<const SemanticValue*> Apply(const EvalContext& context,
+ const SemanticExpression* expression,
+ UnsafeArena* arena) const override;
+
+ private:
+ const SemanticExpressionEvaluator* composer_;
+ const MutableFlatbufferBuilder semantic_value_builder_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_COMPOSE_EVAL_H_
diff --git a/native/utils/grammar/semantics/evaluators/compose-eval_test.cc b/native/utils/grammar/semantics/evaluators/compose-eval_test.cc
new file mode 100644
index 0000000..f26042a
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/compose-eval_test.cc
@@ -0,0 +1,289 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/evaluators/compose-eval.h"
+
+#include <vector>
+
+#include "utils/base/statusor.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/reflection.h"
+#include "utils/flatbuffers/test-utils.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/evaluators/const-eval.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/testing/utils.h"
+#include "utils/grammar/testing/value_generated.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+class ComposeEvaluatorTest : public GrammarTest {
+ protected:
+ explicit ComposeEvaluatorTest()
+ : const_eval_(semantic_values_schema_.get()) {}
+
+ // Evaluator that just returns a constant value.
+ ConstEvaluator const_eval_;
+};
+
+TEST_F(ComposeEvaluatorTest, SetsSingleField) {
+ TestDateT date;
+ date.day = 1;
+ date.month = 2;
+ date.year = 2020;
+ ComposeExpressionT compose_expression;
+ compose_expression.type =
+ TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value();
+ compose_expression.fields.emplace_back(new ComposeExpression_::FieldT);
+ compose_expression.fields.back()->path = CreateUnpackedFieldPath({"date"});
+ compose_expression.fields.back()->value = CreateConstDateExpression(date);
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(std::move(compose_expression));
+
+ ComposeEvaluator compose_eval(&const_eval_, semantic_values_schema_.get());
+
+ StatusOr<const SemanticValue*> result =
+ compose_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ ASSERT_NE(result_value, nullptr);
+ EXPECT_EQ(result_value->type()->name()->str(),
+ "libtextclassifier3.grammar.TestValue");
+ const TestValue* result_test_value = result_value->Table<TestValue>();
+ EXPECT_EQ(result_test_value->date()->day(), 1);
+ EXPECT_EQ(result_test_value->date()->month(), 2);
+ EXPECT_EQ(result_test_value->date()->year(), 2020);
+}
+
+TEST_F(ComposeEvaluatorTest, SetsStringField) {
+ ComposeExpressionT compose_expression;
+ compose_expression.type =
+ TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value();
+ compose_expression.fields.emplace_back(new ComposeExpression_::FieldT);
+ compose_expression.fields.back()->path =
+ CreateUnpackedFieldPath({"test_string"});
+ compose_expression.fields.back()->value =
+ CreatePrimitiveConstExpression<StringPiece>("this is a test");
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(std::move(compose_expression));
+
+ ComposeEvaluator compose_eval(&const_eval_, semantic_values_schema_.get());
+
+ StatusOr<const SemanticValue*> result =
+ compose_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ ASSERT_NE(result_value, nullptr);
+ EXPECT_EQ(result_value->type()->name()->str(),
+ "libtextclassifier3.grammar.TestValue");
+ const TestValue* result_test_value = result_value->Table<TestValue>();
+ EXPECT_EQ(result_test_value->test_string()->str(), "this is a test");
+}
+
+TEST_F(ComposeEvaluatorTest, SetsPrimitiveField) {
+ ComposeExpressionT compose_expression;
+ compose_expression.type = TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestDate")
+ .value();
+ compose_expression.fields.emplace_back(new ComposeExpression_::FieldT);
+ compose_expression.fields.back()->path = CreateUnpackedFieldPath({"day"});
+ compose_expression.fields.back()->value =
+ CreatePrimitiveConstExpression<int>(1);
+
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(std::move(compose_expression));
+
+ ComposeEvaluator compose_eval(&const_eval_, semantic_values_schema_.get());
+
+ StatusOr<const SemanticValue*> result =
+ compose_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ ASSERT_NE(result_value, nullptr);
+ EXPECT_EQ(result_value->type()->name()->str(),
+ "libtextclassifier3.grammar.TestDate");
+ const TestDate* result_date = result_value->Table<TestDate>();
+ EXPECT_EQ(result_date->day(), 1);
+}
+
+TEST_F(ComposeEvaluatorTest, MergesMultipleField) {
+ TestDateT day;
+ day.day = 1;
+
+ TestDateT month;
+ month.month = 2;
+
+ TestDateT year;
+ year.year = 2020;
+
+ ComposeExpressionT compose_expression;
+ compose_expression.type =
+ TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value();
+ for (const TestDateT& component : std::vector<TestDateT>{day, month, year}) {
+ compose_expression.fields.emplace_back(new ComposeExpression_::FieldT);
+ compose_expression.fields.back()->path = CreateUnpackedFieldPath({"date"});
+ compose_expression.fields.back()->value =
+ CreateConstDateExpression(component);
+ }
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(std::move(compose_expression));
+
+ ComposeEvaluator compose_eval(&const_eval_, semantic_values_schema_.get());
+
+ StatusOr<const SemanticValue*> result =
+ compose_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ ASSERT_NE(result_value, nullptr);
+ EXPECT_EQ(result_value->type()->name()->str(),
+ "libtextclassifier3.grammar.TestValue");
+ const TestValue* result_test_value = result_value->Table<TestValue>();
+ EXPECT_EQ(result_test_value->date()->day(), 1);
+ EXPECT_EQ(result_test_value->date()->month(), 2);
+ EXPECT_EQ(result_test_value->date()->year(), 2020);
+}
+
+TEST_F(ComposeEvaluatorTest, SucceedsEvenWhenEmpty) {
+ ComposeExpressionT compose_expression;
+ compose_expression.type =
+ TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value();
+ compose_expression.fields.emplace_back(new ComposeExpression_::FieldT);
+ compose_expression.fields.back()->path = CreateUnpackedFieldPath({"date"});
+ compose_expression.fields.back()->value.reset(new SemanticExpressionT);
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(std::move(compose_expression));
+
+ // Just return null value.
+ struct : public SemanticExpressionEvaluator {
+ StatusOr<const SemanticValue*> Apply(const EvalContext&,
+ const SemanticExpression*,
+ UnsafeArena*) const override {
+ return nullptr;
+ }
+ } null_eval;
+
+ ComposeEvaluator compose_eval(&null_eval, semantic_values_schema_.get());
+
+ StatusOr<const SemanticValue*> result =
+ compose_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+}
+
+TEST_F(ComposeEvaluatorTest, AddsRepeatedPrimitiveField) {
+ ComposeExpressionT compose_expression;
+ compose_expression.type =
+ TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value();
+ compose_expression.fields.emplace_back(new ComposeExpression_::FieldT);
+ compose_expression.fields.back()->path =
+ CreateUnpackedFieldPath({"repeated_enum"});
+ compose_expression.fields.back()->value =
+ CreatePrimitiveConstExpression<int>(TestEnum_ENUM_1);
+ compose_expression.fields.emplace_back(new ComposeExpression_::FieldT);
+ compose_expression.fields.back()->path =
+ CreateUnpackedFieldPath({"repeated_enum"});
+ compose_expression.fields.back()->value =
+ CreatePrimitiveConstExpression<int>(TestEnum_ENUM_2);
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(std::move(compose_expression));
+
+ ComposeEvaluator compose_eval(&const_eval_, semantic_values_schema_.get());
+
+ StatusOr<const SemanticValue*> result =
+ compose_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ ASSERT_NE(result_value, nullptr);
+ EXPECT_EQ(result_value->type()->name()->str(),
+ "libtextclassifier3.grammar.TestValue");
+ const TestValue* result_test_value = result_value->Table<TestValue>();
+ EXPECT_EQ(result_test_value->repeated_enum()->size(), 2);
+ EXPECT_EQ(result_test_value->repeated_enum()->Get(0), TestEnum_ENUM_1);
+ EXPECT_EQ(result_test_value->repeated_enum()->Get(1), TestEnum_ENUM_2);
+}
+
+TEST_F(ComposeEvaluatorTest, AddsRepeatedSubmessage) {
+ ComposeExpressionT compose_expression;
+ compose_expression.type =
+ TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value();
+ {
+ compose_expression.fields.emplace_back(new ComposeExpression_::FieldT);
+ compose_expression.fields.back()->path =
+ CreateUnpackedFieldPath({"repeated_date"});
+ TestDateT date;
+ date.day = 1;
+ date.month = 2;
+ date.year = 2020;
+ compose_expression.fields.back()->value = CreateConstDateExpression(date);
+ }
+
+ {
+ compose_expression.fields.emplace_back(new ComposeExpression_::FieldT);
+ compose_expression.fields.back()->path =
+ CreateUnpackedFieldPath({"repeated_date"});
+ TestDateT date;
+ date.day = 3;
+ date.month = 4;
+ date.year = 2021;
+ compose_expression.fields.back()->value = CreateConstDateExpression(date);
+ }
+
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(std::move(compose_expression));
+
+ ComposeEvaluator compose_eval(&const_eval_, semantic_values_schema_.get());
+
+ StatusOr<const SemanticValue*> result =
+ compose_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ ASSERT_NE(result_value, nullptr);
+ EXPECT_EQ(result_value->type()->name()->str(),
+ "libtextclassifier3.grammar.TestValue");
+ const TestValue* result_test_value = result_value->Table<TestValue>();
+ EXPECT_EQ(result_test_value->repeated_date()->size(), 2);
+ EXPECT_EQ(result_test_value->repeated_date()->Get(0)->day(), 1);
+ EXPECT_EQ(result_test_value->repeated_date()->Get(0)->month(), 2);
+ EXPECT_EQ(result_test_value->repeated_date()->Get(0)->year(), 2020);
+ EXPECT_EQ(result_test_value->repeated_date()->Get(1)->day(), 3);
+ EXPECT_EQ(result_test_value->repeated_date()->Get(1)->month(), 4);
+ EXPECT_EQ(result_test_value->repeated_date()->Get(1)->year(), 2021);
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/evaluators/const-eval.h b/native/utils/grammar/semantics/evaluators/const-eval.h
new file mode 100644
index 0000000..67a4c54
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/const-eval.h
@@ -0,0 +1,68 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_CONST_EVAL_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_CONST_EVAL_H_
+
+#include "utils/base/arena.h"
+#include "utils/grammar/semantics/eval-context.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/semantics/value.h"
+
+namespace libtextclassifier3::grammar {
+
+// Returns a constant value of a given type.
+class ConstEvaluator : public SemanticExpressionEvaluator {
+ public:
+ explicit ConstEvaluator(const reflection::Schema* semantic_values_schema)
+ : semantic_values_schema_(semantic_values_schema) {}
+
+ StatusOr<const SemanticValue*> Apply(const EvalContext&,
+ const SemanticExpression* expression,
+ UnsafeArena* arena) const override {
+ TC3_DCHECK_EQ(expression->expression_type(),
+ SemanticExpression_::Expression_ConstValueExpression);
+ const ConstValueExpression* const_value_expression =
+ expression->expression_as_ConstValueExpression();
+ const reflection::BaseType base_type =
+ static_cast<reflection::BaseType>(const_value_expression->base_type());
+ const StringPiece data = StringPiece(
+ reinterpret_cast<const char*>(const_value_expression->value()->data()),
+ const_value_expression->value()->size());
+
+ if (base_type == reflection::BaseType::Obj) {
+ // Resolve the object type.
+ const int type_id = const_value_expression->type();
+ if (type_id < 0 ||
+ type_id >= semantic_values_schema_->objects()->size()) {
+ return Status(StatusCode::INVALID_ARGUMENT, "Invalid type.");
+ }
+ return SemanticValue::Create(semantic_values_schema_->objects()->Get(
+ const_value_expression->type()),
+ data, arena);
+ } else {
+ return SemanticValue::Create(base_type, data, arena);
+ }
+ }
+
+ private:
+ const reflection::Schema* semantic_values_schema_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_CONST_EVAL_H_
diff --git a/native/utils/grammar/semantics/evaluators/const-eval_test.cc b/native/utils/grammar/semantics/evaluators/const-eval_test.cc
new file mode 100644
index 0000000..02eea5d
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/const-eval_test.cc
@@ -0,0 +1,105 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/evaluators/const-eval.h"
+
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/testing/utils.h"
+#include "utils/grammar/testing/value_generated.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+class ConstEvaluatorTest : public GrammarTest {
+ protected:
+ explicit ConstEvaluatorTest() : const_eval_(semantic_values_schema_.get()) {}
+
+ const ConstEvaluator const_eval_;
+};
+
+TEST_F(ConstEvaluatorTest, CreatesConstantSemanticValues) {
+ TestValueT value;
+ value.a_float_value = 64.42;
+ value.test_string = "test string";
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateAndPackConstExpression(value);
+
+ StatusOr<const SemanticValue*> result =
+ const_eval_.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ ASSERT_NE(result_value, nullptr);
+ EXPECT_EQ(result_value->type()->name()->str(),
+ "libtextclassifier3.grammar.TestValue");
+ const TestValue* result_test_value = result_value->Table<TestValue>();
+ EXPECT_EQ(result_test_value->test_string()->str(), "test string");
+ EXPECT_FLOAT_EQ(result_test_value->a_float_value(), 64.42);
+}
+
+template <typename T>
+class PrimitiveValueTest : public ConstEvaluatorTest {
+ protected:
+ T Eval(const T value) {
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateAndPackPrimitiveConstExpression<T>(value);
+ StatusOr<const SemanticValue*> result =
+ const_eval_.Apply(/*context=*/{}, expression.get(), &arena_);
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ EXPECT_NE(result_value, nullptr);
+ return result_value->Value<T>();
+ }
+};
+
+using PrimitiveTypes = ::testing::Types<int8, uint8, int16, uint16, int32,
+ uint32, int64, uint64, double, float>;
+TYPED_TEST_SUITE(PrimitiveValueTest, PrimitiveTypes);
+
+TYPED_TEST(PrimitiveValueTest, CreatesConstantPrimitiveValues) {
+ EXPECT_EQ(this->Eval(42), 42);
+}
+
+TEST_F(ConstEvaluatorTest, CreatesStringValues) {
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateAndPackPrimitiveConstExpression<StringPiece>("this is a test.");
+ StatusOr<const SemanticValue*> result =
+ const_eval_.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ ASSERT_NE(result_value, nullptr);
+ EXPECT_EQ(result_value->Value<StringPiece>().ToString(), "this is a test.");
+}
+
+TEST_F(ConstEvaluatorTest, CreatesBoolValues) {
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateAndPackPrimitiveConstExpression<bool>(true);
+ StatusOr<const SemanticValue*> result =
+ const_eval_.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ ASSERT_NE(result_value, nullptr);
+ EXPECT_TRUE(result_value->Value<bool>());
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/evaluators/constituent-eval.h b/native/utils/grammar/semantics/evaluators/constituent-eval.h
new file mode 100644
index 0000000..4b877fe
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/constituent-eval.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_CONSTITUENT_EVAL_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_CONSTITUENT_EVAL_H_
+
+#include "utils/base/arena.h"
+#include "utils/grammar/semantics/eval-context.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/semantics/value.h"
+
+namespace libtextclassifier3::grammar {
+
+// Returns the semantic value of an evaluated constituent.
+class ConstituentEvaluator : public SemanticExpressionEvaluator {
+ public:
+ StatusOr<const SemanticValue*> Apply(const EvalContext& context,
+ const SemanticExpression* expression,
+ UnsafeArena*) const override {
+ TC3_DCHECK_EQ(expression->expression_type(),
+ SemanticExpression_::Expression_ConstituentExpression);
+ const ConstituentExpression* constituent_expression =
+ expression->expression_as_ConstituentExpression();
+ const auto constituent_it =
+ context.rule_constituents.find(constituent_expression->id());
+ if (constituent_it != context.rule_constituents.end()) {
+ return constituent_it->second;
+ }
+ // The constituent was not present in the rule parse tree, return a
+ // null value for it.
+ return nullptr;
+ }
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_CONSTITUENT_EVAL_H_
diff --git a/native/utils/grammar/semantics/evaluators/constituent-eval_test.cc b/native/utils/grammar/semantics/evaluators/constituent-eval_test.cc
new file mode 100644
index 0000000..c40d1cc
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/constituent-eval_test.cc
@@ -0,0 +1,79 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/evaluators/constituent-eval.h"
+
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/grammar/semantics/eval-context.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/testing/utils.h"
+#include "utils/grammar/testing/value_generated.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+class ConstituentEvaluatorTest : public GrammarTest {
+ protected:
+ explicit ConstituentEvaluatorTest() {}
+
+ OwnedFlatbuffer<SemanticExpression> CreateConstituentExpression(
+ const int id) {
+ ConstituentExpressionT constituent_expression;
+ constituent_expression.id = id;
+ return CreateExpression(constituent_expression);
+ }
+
+ const ConstituentEvaluator constituent_eval_;
+};
+
+TEST_F(ConstituentEvaluatorTest, HandlesNotDefinedConstituents) {
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateConstituentExpression(/*id=*/42);
+
+ StatusOr<const SemanticValue*> result = constituent_eval_.Apply(
+ /*context=*/{}, expression.get(), /*arena=*/nullptr);
+
+ EXPECT_TRUE(result.ok());
+ EXPECT_EQ(result.ValueOrDie(), nullptr);
+}
+
+TEST_F(ConstituentEvaluatorTest, ForwardsConstituentSemanticValues) {
+ // Create example values for constituents.
+ EvalContext context;
+ TestValueT value_0;
+ value_0.test_string = "constituent 0 value";
+ context.rule_constituents[0] = CreateSemanticValue(value_0);
+
+ TestValueT value_42;
+ value_42.test_string = "constituent 42 value";
+ context.rule_constituents[42] = CreateSemanticValue(value_42);
+
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateConstituentExpression(/*id=*/42);
+
+ StatusOr<const SemanticValue*> result =
+ constituent_eval_.Apply(context, expression.get(), /*arena=*/nullptr);
+
+ EXPECT_TRUE(result.ok());
+ const TestValue* result_value = result.ValueOrDie()->Table<TestValue>();
+ EXPECT_EQ(result_value->test_string()->str(), "constituent 42 value");
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/evaluators/merge-values-eval.cc b/native/utils/grammar/semantics/evaluators/merge-values-eval.cc
new file mode 100644
index 0000000..d9bf544
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/merge-values-eval.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/evaluators/merge-values-eval.h"
+
+namespace libtextclassifier3::grammar {
+
+StatusOr<const SemanticValue*> MergeValuesEvaluator::Apply(
+ const EvalContext& context, const SemanticExpression* expression,
+ UnsafeArena* arena) const {
+ const MergeValueExpression* merge_value_expression =
+ expression->expression_as_MergeValueExpression();
+ std::unique_ptr<MutableFlatbuffer> result =
+ semantic_value_builder_.NewTable(merge_value_expression->type());
+
+ if (result == nullptr) {
+ return Status(StatusCode::INVALID_ARGUMENT, "Invalid result type.");
+ }
+
+ for (const SemanticExpression* semantic_expression :
+ *merge_value_expression->values()) {
+ TC3_ASSIGN_OR_RETURN(const SemanticValue* value,
+ composer_->Apply(context, semantic_expression, arena));
+ if (value == nullptr) {
+ continue;
+ }
+ if ((value->type() != result->type()) ||
+ !result->MergeFrom(value->Table())) {
+ return Status(StatusCode::INVALID_ARGUMENT,
+ "Could not merge the results.");
+ }
+ }
+ return SemanticValue::Create<const MutableFlatbuffer*>(result.get(), arena);
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/evaluators/merge-values-eval.h b/native/utils/grammar/semantics/evaluators/merge-values-eval.h
new file mode 100644
index 0000000..8fe49e3
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/merge-values-eval.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_MERGE_VALUES_EVAL_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_MERGE_VALUES_EVAL_H_
+
+#include "utils/base/arena.h"
+#include "utils/base/status_macros.h"
+#include "utils/flatbuffers/mutable.h"
+#include "utils/grammar/semantics/eval-context.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/semantics/value.h"
+
+namespace libtextclassifier3::grammar {
+
+// Evaluate the “merge” semantic function expression.
+// Conceptually, the way this merge evaluator works is that each of the
+// arguments (semantic value) is merged into a return type semantic value.
+class MergeValuesEvaluator : public SemanticExpressionEvaluator {
+ public:
+ explicit MergeValuesEvaluator(
+ const SemanticExpressionEvaluator* composer,
+ const reflection::Schema* semantic_values_schema)
+ : composer_(composer), semantic_value_builder_(semantic_values_schema) {}
+
+ StatusOr<const SemanticValue*> Apply(const EvalContext& context,
+ const SemanticExpression* expression,
+ UnsafeArena* arena) const override;
+
+ private:
+ const SemanticExpressionEvaluator* composer_;
+ const MutableFlatbufferBuilder semantic_value_builder_;
+};
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_MERGE_VALUES_EVAL_H_
diff --git a/native/utils/grammar/semantics/evaluators/merge-values-eval_test.cc b/native/utils/grammar/semantics/evaluators/merge-values-eval_test.cc
new file mode 100644
index 0000000..8d3d70f
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/merge-values-eval_test.cc
@@ -0,0 +1,77 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/evaluators/merge-values-eval.h"
+
+#include <vector>
+
+#include "utils/base/statusor.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/reflection.h"
+#include "utils/flatbuffers/test-utils.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/evaluators/const-eval.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/testing/utils.h"
+#include "utils/grammar/testing/value_generated.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+class MergeValuesEvaluatorTest : public GrammarTest {
+ protected:
+ explicit MergeValuesEvaluatorTest()
+ : const_eval_(semantic_values_schema_.get()) {}
+
+ // Evaluator that just returns a constant value.
+ ConstEvaluator const_eval_;
+};
+
+TEST_F(MergeValuesEvaluatorTest, MergeSemanticValues) {
+ // Setup the data
+ TestDateT date_value_day;
+ date_value_day.day = 23;
+ TestDateT date_value_month;
+ date_value_month.month = 9;
+ TestDateT date_value_year;
+ date_value_year.year = 2019;
+
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateAndPackMergeValuesExpression(
+ {date_value_day, date_value_month, date_value_year});
+
+ MergeValuesEvaluator merge_values_eval(&const_eval_,
+ semantic_values_schema_.get());
+
+ StatusOr<const SemanticValue*> result =
+ merge_values_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ ASSERT_NE(result_value, nullptr);
+ EXPECT_EQ(result_value->type()->name()->str(),
+ "libtextclassifier3.grammar.TestDate");
+ const TestDate* result_test_date = result_value->Table<TestDate>();
+ EXPECT_EQ(result_test_date->day(), 23);
+ EXPECT_EQ(result_test_date->month(), 9);
+ EXPECT_EQ(result_test_date->year(), 2019);
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/evaluators/parse-number-eval.h b/native/utils/grammar/semantics/evaluators/parse-number-eval.h
new file mode 100644
index 0000000..9171c65
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/parse-number-eval.h
@@ -0,0 +1,110 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_PARSE_NUMBER_EVAL_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_PARSE_NUMBER_EVAL_H_
+
+#include <string>
+
+#include "utils/base/arena.h"
+#include "utils/base/statusor.h"
+#include "utils/grammar/semantics/eval-context.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/semantics/value.h"
+#include "utils/strings/numbers.h"
+
+namespace libtextclassifier3::grammar {
+
+// Parses a string as a number.
+class ParseNumberEvaluator : public SemanticExpressionEvaluator {
+ public:
+ explicit ParseNumberEvaluator(const SemanticExpressionEvaluator* composer)
+ : composer_(composer) {}
+
+ StatusOr<const SemanticValue*> Apply(const EvalContext& context,
+ const SemanticExpression* expression,
+ UnsafeArena* arena) const override {
+ TC3_DCHECK_EQ(expression->expression_type(),
+ SemanticExpression_::Expression_ParseNumberExpression);
+ const ParseNumberExpression* parse_number_expression =
+ expression->expression_as_ParseNumberExpression();
+
+ // Evaluate argument.
+ TC3_ASSIGN_OR_RETURN(
+ const SemanticValue* value,
+ composer_->Apply(context, parse_number_expression->value(), arena));
+ if (value == nullptr) {
+ return nullptr;
+ }
+ if (!value->Has<StringPiece>()) {
+ return Status(StatusCode::INVALID_ARGUMENT,
+ "Argument didn't evaluate as a string value.");
+ }
+ const std::string data = value->Value<std::string>();
+
+ // Parse the string data as a number.
+ const reflection::BaseType type =
+ static_cast<reflection::BaseType>(parse_number_expression->base_type());
+ if (flatbuffers::IsLong(type)) {
+ TC3_ASSIGN_OR_RETURN(const int64 value, TryParse<int64>(data));
+ return SemanticValue::Create(type, value, arena);
+ } else if (flatbuffers::IsInteger(type)) {
+ TC3_ASSIGN_OR_RETURN(const int32 value, TryParse<int32>(data));
+ return SemanticValue::Create(type, value, arena);
+ } else if (flatbuffers::IsFloat(type)) {
+ TC3_ASSIGN_OR_RETURN(const double value, TryParse<double>(data));
+ return SemanticValue::Create(type, value, arena);
+ } else {
+ return Status(StatusCode::INVALID_ARGUMENT,
+ "Unsupported type: " + std::to_string(type));
+ }
+ }
+
+ private:
+ template <typename T>
+ bool Parse(const std::string& data, T* value) const;
+
+ template <>
+ bool Parse(const std::string& data, int32* value) const {
+ return ParseInt32(data.data(), value);
+ }
+
+ template <>
+ bool Parse(const std::string& data, int64* value) const {
+ return ParseInt64(data.data(), value);
+ }
+
+ template <>
+ bool Parse(const std::string& data, double* value) const {
+ return ParseDouble(data.data(), value);
+ }
+
+ template <typename T>
+ StatusOr<T> TryParse(const std::string& data) const {
+ T result;
+ if (!Parse<T>(data, &result)) {
+ return Status(StatusCode::INVALID_ARGUMENT, "Could not parse value.");
+ }
+ return result;
+ }
+
+ const SemanticExpressionEvaluator* composer_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_PARSE_NUMBER_EVAL_H_
diff --git a/native/utils/grammar/semantics/evaluators/parse-number-eval_test.cc b/native/utils/grammar/semantics/evaluators/parse-number-eval_test.cc
new file mode 100644
index 0000000..e9f21d9
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/parse-number-eval_test.cc
@@ -0,0 +1,86 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/evaluators/parse-number-eval.h"
+
+#include <vector>
+
+#include "utils/base/statusor.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/reflection.h"
+#include "utils/flatbuffers/test-utils.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/evaluators/const-eval.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/testing/utils.h"
+#include "utils/grammar/testing/value_generated.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+template <typename T>
+class ParseNumberEvaluatorTest : public GrammarTest {
+ protected:
+ T Eval(const StringPiece value) {
+ ParseNumberExpressionT parse_number_expression;
+ parse_number_expression.base_type = flatbuffers_base_type<T>::value;
+ parse_number_expression.value =
+ CreatePrimitiveConstExpression<StringPiece>(value);
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(std::move(parse_number_expression));
+
+ ConstEvaluator const_eval(semantic_values_schema_.get());
+ ParseNumberEvaluator parse_number_eval(&const_eval);
+
+ StatusOr<const SemanticValue*> result =
+ parse_number_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_TRUE(result.ok());
+ const SemanticValue* result_value = result.ValueOrDie();
+ EXPECT_NE(result_value, nullptr);
+ return result_value->Value<T>();
+ }
+};
+
+using NumberTypes = ::testing::Types<int8, uint8, int16, uint16, int32, uint32,
+ int64, uint64, double, float>;
+TYPED_TEST_SUITE(ParseNumberEvaluatorTest, NumberTypes);
+
+TYPED_TEST(ParseNumberEvaluatorTest, ParsesNumber) {
+ EXPECT_EQ(this->Eval("42"), 42);
+}
+
+TEST_F(GrammarTest, FailsOnInvalidArgument) {
+ ParseNumberExpressionT parse_number_expression;
+ parse_number_expression.base_type = flatbuffers_base_type<int32>::value;
+ parse_number_expression.value = CreatePrimitiveConstExpression<int32>(42);
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(std::move(parse_number_expression));
+
+ ConstEvaluator const_eval(semantic_values_schema_.get());
+ ParseNumberEvaluator parse_number_eval(&const_eval);
+
+ StatusOr<const SemanticValue*> result =
+ parse_number_eval.Apply(/*context=*/{}, expression.get(), &arena_);
+
+ EXPECT_FALSE(result.ok());
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/semantics/evaluators/span-eval.h b/native/utils/grammar/semantics/evaluators/span-eval.h
new file mode 100644
index 0000000..f8a5d5b
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/span-eval.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_SPAN_EVAL_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_SPAN_EVAL_H_
+
+#include "annotator/types.h"
+#include "utils/base/arena.h"
+#include "utils/base/statusor.h"
+#include "utils/grammar/semantics/eval-context.h"
+#include "utils/grammar/semantics/evaluator.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/semantics/value.h"
+
+namespace libtextclassifier3::grammar {
+
+// Lifts the text of the matched span out of the parse tree as a string value.
+class SpanAsStringEvaluator : public SemanticExpressionEvaluator {
+ public:
+ StatusOr<const SemanticValue*> Apply(const EvalContext& context,
+ const SemanticExpression* expression,
+ UnsafeArena* arena) const override {
+ TC3_DCHECK_EQ(expression->expression_type(),
+ SemanticExpression_::Expression_SpanAsStringExpression);
+ return SemanticValue::Create(
+ context.text_context->Span(context.parse_tree->codepoint_span), arena);
+ }
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_SPAN_EVAL_H_
diff --git a/native/utils/grammar/semantics/evaluators/span-eval_test.cc b/native/utils/grammar/semantics/evaluators/span-eval_test.cc
new file mode 100644
index 0000000..daba860
--- /dev/null
+++ b/native/utils/grammar/semantics/evaluators/span-eval_test.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/semantics/evaluators/span-eval.h"
+
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/grammar/semantics/expression_generated.h"
+#include "utils/grammar/testing/utils.h"
+#include "utils/grammar/testing/value_generated.h"
+#include "utils/grammar/types.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+class SpanTextEvaluatorTest : public GrammarTest {};
+
+TEST_F(SpanTextEvaluatorTest, CreatesSpanTextValues) {
+ OwnedFlatbuffer<SemanticExpression> expression =
+ CreateExpression(SpanAsStringExpressionT());
+ SpanAsStringEvaluator span_eval;
+ TextContext text = TextContextForText("This a test.");
+ ParseTree derivation(/*lhs=*/kUnassignedNonterm, CodepointSpan{5, 11},
+ /*match_offset=*/0, /*type=*/ParseTree::Type::kDefault);
+
+ StatusOr<const SemanticValue*> result = span_eval.Apply(
+ /*context=*/{&text, &derivation}, expression.get(), &arena_);
+
+ ASSERT_TRUE(result.ok());
+ EXPECT_EQ(result.ValueOrDie()->Value<StringPiece>().ToString(), "a test");
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/next/semantics/expression.fbs b/native/utils/grammar/semantics/expression.fbs
similarity index 71%
rename from native/utils/grammar/next/semantics/expression.fbs
rename to native/utils/grammar/semantics/expression.fbs
index 40c1eb1..5397407 100755
--- a/native/utils/grammar/next/semantics/expression.fbs
+++ b/native/utils/grammar/semantics/expression.fbs
@@ -16,7 +16,7 @@
include "utils/flatbuffers/flatbuffers.fbs";
-namespace libtextclassifier3.grammar.next.SemanticExpression_;
+namespace libtextclassifier3.grammar.SemanticExpression_;
union Expression {
ConstValueExpression,
ConstituentExpression,
@@ -24,16 +24,17 @@
SpanAsStringExpression,
ParseNumberExpression,
MergeValueExpression,
+ ArithmeticExpression,
}
// A semantic expression.
-namespace libtextclassifier3.grammar.next;
+namespace libtextclassifier3.grammar;
table SemanticExpression {
expression:SemanticExpression_.Expression;
}
// A constant flatbuffer value.
-namespace libtextclassifier3.grammar.next;
+namespace libtextclassifier3.grammar;
table ConstValueExpression {
// The base type of the value.
base_type:int;
@@ -47,31 +48,25 @@
}
// The value of a rule constituent.
-namespace libtextclassifier3.grammar.next;
+namespace libtextclassifier3.grammar;
table ConstituentExpression {
// The id of the constituent.
id:ushort;
}
// The fields to set.
-namespace libtextclassifier3.grammar.next.ComposeExpression_;
+namespace libtextclassifier3.grammar.ComposeExpression_;
table Field {
// The field to set.
path:libtextclassifier3.FlatbufferFieldPath;
// The value.
value:SemanticExpression;
-
- // Whether the field can be absent: If set to true, evaluation to null will
- // not be treated as an error.
- // A value of null represents a non-present value that can e.g. arise from
- // optional parts of a rule that might not be present in a match.
- optional:bool;
}
// A combination: Compose a result from arguments.
// https://mitpress.mit.edu/sites/default/files/sicp/full-text/book/book-Z-H-4.html#%_toc_%_sec_1.1.1
-namespace libtextclassifier3.grammar.next;
+namespace libtextclassifier3.grammar;
table ComposeExpression {
// The id of the type of the result.
type:int;
@@ -80,12 +75,12 @@
}
// Lifts a span as a value.
-namespace libtextclassifier3.grammar.next;
+namespace libtextclassifier3.grammar;
table SpanAsStringExpression {
}
// Parses a string as a number.
-namespace libtextclassifier3.grammar.next;
+namespace libtextclassifier3.grammar;
table ParseNumberExpression {
// The base type of the value.
base_type:int;
@@ -94,7 +89,7 @@
}
// Merge the semantic expressions.
-namespace libtextclassifier3.grammar.next;
+namespace libtextclassifier3.grammar;
table MergeValueExpression {
// The id of the type of the result.
type:int;
@@ -102,3 +97,23 @@
values:[SemanticExpression];
}
+// The operator of the arithmetic expression.
+namespace libtextclassifier3.grammar.ArithmeticExpression_;
+enum Operator : int {
+ NO_OP = 0,
+ OP_ADD = 1,
+ OP_MUL = 2,
+ OP_MAX = 3,
+ OP_MIN = 4,
+}
+
+// Simple arithmetic expression.
+namespace libtextclassifier3.grammar;
+table ArithmeticExpression {
+ // The base type of the operation.
+ base_type:int;
+
+ op:ArithmeticExpression_.Operator;
+ values:[SemanticExpression];
+}
+
diff --git a/native/utils/grammar/semantics/value.h b/native/utils/grammar/semantics/value.h
new file mode 100644
index 0000000..abf5eaf
--- /dev/null
+++ b/native/utils/grammar/semantics/value.h
@@ -0,0 +1,218 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_VALUE_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_VALUE_H_
+
+#include "utils/base/arena.h"
+#include "utils/base/logging.h"
+#include "utils/flatbuffers/mutable.h"
+#include "utils/flatbuffers/reflection.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unicodetext.h"
+#include "flatbuffers/base.h"
+#include "flatbuffers/reflection.h"
+
+namespace libtextclassifier3::grammar {
+
+// A semantic value as a typed, arena-allocated flatbuffer.
+// This denotes the possible results of the evaluation of a semantic expression.
+class SemanticValue {
+ public:
+ // Creates an arena allocated semantic value.
+ template <typename T>
+ static const SemanticValue* Create(const T value, UnsafeArena* arena) {
+ static_assert(!std::is_pointer<T>() && std::is_scalar<T>());
+ if (char* buffer = reinterpret_cast<char*>(
+ arena->AllocAligned(sizeof(T), alignof(T)))) {
+ flatbuffers::WriteScalar<T>(buffer, value);
+ return arena->AllocAndInit<SemanticValue>(
+ libtextclassifier3::flatbuffers_base_type<T>::value,
+ StringPiece(buffer, sizeof(T)));
+ }
+ return nullptr;
+ }
+
+ template <>
+ const SemanticValue* Create(const StringPiece value, UnsafeArena* arena) {
+ return arena->AllocAndInit<SemanticValue>(reflection::BaseType::String,
+ value);
+ }
+
+ template <>
+ const SemanticValue* Create(const UnicodeText value, UnsafeArena* arena) {
+ return arena->AllocAndInit<SemanticValue>(
+ reflection::BaseType::String,
+ StringPiece(value.data(), value.size_bytes()));
+ }
+
+ template <>
+ const SemanticValue* Create(const MutableFlatbuffer* value,
+ UnsafeArena* arena) {
+ const std::string buffer = value->Serialize();
+ return Create(
+ value->type(),
+ StringPiece(arena->Memdup(buffer.data(), buffer.size()), buffer.size()),
+ arena);
+ }
+
+ static const SemanticValue* Create(const reflection::Object* type,
+ const StringPiece data,
+ UnsafeArena* arena) {
+ return arena->AllocAndInit<SemanticValue>(type, data);
+ }
+
+ static const SemanticValue* Create(const reflection::BaseType base_type,
+ const StringPiece data,
+ UnsafeArena* arena) {
+ return arena->AllocAndInit<SemanticValue>(base_type, data);
+ }
+
+ template <typename T>
+ static const SemanticValue* Create(const reflection::BaseType base_type,
+ const T value, UnsafeArena* arena) {
+ switch (base_type) {
+ case reflection::BaseType::Bool:
+ return Create(
+ static_cast<
+ flatbuffers_cpp_type<reflection::BaseType::Bool>::value>(value),
+ arena);
+ case reflection::BaseType::Byte:
+ return Create(
+ static_cast<
+ flatbuffers_cpp_type<reflection::BaseType::Byte>::value>(value),
+ arena);
+ case reflection::BaseType::UByte:
+ return Create(
+ static_cast<
+ flatbuffers_cpp_type<reflection::BaseType::UByte>::value>(
+ value),
+ arena);
+ case reflection::BaseType::Short:
+ return Create(
+ static_cast<
+ flatbuffers_cpp_type<reflection::BaseType::Short>::value>(
+ value),
+ arena);
+ case reflection::BaseType::UShort:
+ return Create(
+ static_cast<
+ flatbuffers_cpp_type<reflection::BaseType::UShort>::value>(
+ value),
+ arena);
+ case reflection::BaseType::Int:
+ return Create(
+ static_cast<flatbuffers_cpp_type<reflection::BaseType::Int>::value>(
+ value),
+ arena);
+ case reflection::BaseType::UInt:
+ return Create(
+ static_cast<
+ flatbuffers_cpp_type<reflection::BaseType::UInt>::value>(value),
+ arena);
+ case reflection::BaseType::Long:
+ return Create(
+ static_cast<
+ flatbuffers_cpp_type<reflection::BaseType::Long>::value>(value),
+ arena);
+ case reflection::BaseType::ULong:
+ return Create(
+ static_cast<
+ flatbuffers_cpp_type<reflection::BaseType::ULong>::value>(
+ value),
+ arena);
+ case reflection::BaseType::Float:
+ return Create(
+ static_cast<
+ flatbuffers_cpp_type<reflection::BaseType::Float>::value>(
+ value),
+ arena);
+ case reflection::BaseType::Double:
+ return Create(
+ static_cast<
+ flatbuffers_cpp_type<reflection::BaseType::Double>::value>(
+ value),
+ arena);
+ default: {
+ TC3_LOG(ERROR) << "Unhandled type: " << base_type;
+ return nullptr;
+ }
+ }
+ }
+
+ explicit SemanticValue(const reflection::BaseType base_type,
+ const StringPiece data)
+ : base_type_(base_type), type_(nullptr), data_(data) {}
+ explicit SemanticValue(const reflection::Object* type, const StringPiece data)
+ : base_type_(reflection::BaseType::Obj), type_(type), data_(data) {}
+
+ template <typename T>
+ bool Has() const {
+ return base_type_ == libtextclassifier3::flatbuffers_base_type<T>::value;
+ }
+
+ template <>
+ bool Has<flatbuffers::Table>() const {
+ return base_type_ == reflection::BaseType::Obj;
+ }
+
+ template <typename T = flatbuffers::Table>
+ const T* Table() const {
+ TC3_CHECK(Has<flatbuffers::Table>());
+ return flatbuffers::GetRoot<T>(
+ reinterpret_cast<const unsigned char*>(data_.data()));
+ }
+
+ template <typename T>
+ const T Value() const {
+ TC3_CHECK(Has<T>());
+ return flatbuffers::ReadScalar<T>(data_.data());
+ }
+
+ template <>
+ const StringPiece Value<StringPiece>() const {
+ TC3_CHECK(Has<StringPiece>());
+ return data_;
+ }
+
+ template <>
+ const std::string Value<std::string>() const {
+ TC3_CHECK(Has<StringPiece>());
+ return data_.ToString();
+ }
+
+ template <>
+ const UnicodeText Value<UnicodeText>() const {
+ TC3_CHECK(Has<StringPiece>());
+ return UTF8ToUnicodeText(data_, /*do_copy=*/false);
+ }
+
+ const reflection::BaseType base_type() const { return base_type_; }
+ const reflection::Object* type() const { return type_; }
+
+ private:
+ // The base type.
+ const reflection::BaseType base_type_;
+
+ // The object type of the value.
+ const reflection::Object* type_;
+
+ StringPiece data_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_VALUE_H_
diff --git a/native/utils/grammar/testing/utils.h b/native/utils/grammar/testing/utils.h
new file mode 100644
index 0000000..709b94a
--- /dev/null
+++ b/native/utils/grammar/testing/utils.h
@@ -0,0 +1,239 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TESTING_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TESTING_UTILS_H_
+
+#include <memory>
+#include <vector>
+
+#include "utils/base/arena.h"
+#include "utils/flatbuffers/reflection.h"
+#include "utils/grammar/parsing/derivation.h"
+#include "utils/grammar/parsing/parse-tree.h"
+#include "utils/grammar/semantics/value.h"
+#include "utils/grammar/testing/value_generated.h"
+#include "utils/grammar/text-context.h"
+#include "utils/i18n/locale.h"
+#include "utils/jvm-test-utils.h"
+#include "utils/test-data-test-utils.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unilib.h"
+#include "gmock/gmock.h"
+#include "flatbuffers/base.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3::grammar {
+
+inline std::ostream& operator<<(std::ostream& os, const ParseTree* parse_tree) {
+ return os << "ParseTree(lhs=" << parse_tree->lhs
+ << ", begin=" << parse_tree->codepoint_span.first
+ << ", end=" << parse_tree->codepoint_span.second << ")";
+}
+
+inline std::ostream& operator<<(std::ostream& os,
+ const Derivation& derivation) {
+ return os << "Derivation(rule_id=" << derivation.rule_id << ", "
+ << "parse_tree=" << derivation.parse_tree << ")";
+}
+
+MATCHER_P3(IsDerivation, rule_id, begin, end,
+ "is derivation of rule that " +
+ ::testing::DescribeMatcher<int>(rule_id, negation) +
+ ", begin that " +
+ ::testing::DescribeMatcher<int>(begin, negation) +
+ ", end that " + ::testing::DescribeMatcher<int>(end, negation)) {
+ return ::testing::ExplainMatchResult(CodepointSpan(begin, end),
+ arg.parse_tree->codepoint_span,
+ result_listener) &&
+ ::testing::ExplainMatchResult(rule_id, arg.rule_id, result_listener);
+}
+
+// A test fixture with common auxiliary test methods.
+class GrammarTest : public testing::Test {
+ protected:
+ explicit GrammarTest()
+ : unilib_(CreateUniLibForTesting()),
+ arena_(/*block_size=*/16 << 10),
+ semantic_values_schema_(
+ GetTestFileContent("utils/grammar/testing/value.bfbs")),
+ tokenizer_(libtextclassifier3::TokenizationType_ICU, unilib_.get(),
+ /*codepoint_ranges=*/{},
+ /*internal_tokenizer_codepoint_ranges=*/{},
+ /*split_on_script_change=*/false,
+ /*icu_preserve_whitespace_tokens=*/false) {}
+
+ TextContext TextContextForText(const std::string& text) {
+ TextContext context;
+ context.text = UTF8ToUnicodeText(text);
+ context.tokens = tokenizer_.Tokenize(context.text);
+ context.codepoints = context.text.Codepoints();
+ context.codepoints.push_back(context.text.end());
+ context.locales = {Locale::FromBCP47("en")};
+ context.context_span.first = 0;
+ context.context_span.second = context.tokens.size();
+ return context;
+ }
+
+ // Creates a semantic expression union.
+ template <typename T>
+ SemanticExpressionT AsSemanticExpressionUnion(T&& expression) {
+ SemanticExpressionT semantic_expression;
+ semantic_expression.expression.Set(std::forward<T>(expression));
+ return semantic_expression;
+ }
+
+ template <typename T>
+ OwnedFlatbuffer<SemanticExpression> CreateExpression(T&& expression) {
+ return Pack<SemanticExpression>(
+ AsSemanticExpressionUnion(std::forward<T>(expression)));
+ }
+
+ OwnedFlatbuffer<SemanticExpression> CreateEmptyExpression() {
+ return Pack<SemanticExpression>(SemanticExpressionT());
+ }
+
+ // Packs a flatbuffer.
+ template <typename T>
+ OwnedFlatbuffer<T> Pack(const typename T::NativeTableType&& value) {
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(T::Pack(builder, &value));
+ return OwnedFlatbuffer<T>(builder.Release());
+ }
+
+ // Creates a test semantic value.
+ const SemanticValue* CreateSemanticValue(const TestValueT& value) {
+ const std::string value_buffer = PackFlatbuffer<TestValue>(&value);
+ return arena_.AllocAndInit<SemanticValue>(
+ semantic_values_schema_->objects()->Get(
+ TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value()),
+ StringPiece(arena_.Memdup(value_buffer.data(), value_buffer.size()),
+ value_buffer.size()));
+ }
+
+ // Creates a primitive semantic value.
+ template <typename T>
+ const SemanticValue* CreatePrimitiveSemanticValue(const T value) {
+ return arena_.AllocAndInit<SemanticValue>(value);
+ }
+
+ std::unique_ptr<SemanticExpressionT> CreateConstExpression(
+ const TestValueT& value) {
+ ConstValueExpressionT const_value;
+ const_value.base_type = reflection::BaseType::Obj;
+ const_value.type = TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value();
+ const std::string value_buffer = PackFlatbuffer<TestValue>(&value);
+ const_value.value.assign(value_buffer.begin(), value_buffer.end());
+ auto semantic_expression = std::make_unique<SemanticExpressionT>();
+ semantic_expression->expression.Set(const_value);
+ return semantic_expression;
+ }
+
+ OwnedFlatbuffer<SemanticExpression> CreateAndPackConstExpression(
+ const TestValueT& value) {
+ ConstValueExpressionT const_value;
+ const_value.base_type = reflection::BaseType::Obj;
+ const_value.type = TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestValue")
+ .value();
+ const std::string value_buffer = PackFlatbuffer<TestValue>(&value);
+ const_value.value.assign(value_buffer.begin(), value_buffer.end());
+ return CreateExpression(const_value);
+ }
+
+ std::unique_ptr<SemanticExpressionT> CreateConstDateExpression(
+ const TestDateT& value) {
+ ConstValueExpressionT const_value;
+ const_value.base_type = reflection::BaseType::Obj;
+ const_value.type = TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestDate")
+ .value();
+ const std::string value_buffer = PackFlatbuffer<TestDate>(&value);
+ const_value.value.assign(value_buffer.begin(), value_buffer.end());
+ auto semantic_expression = std::make_unique<SemanticExpressionT>();
+ semantic_expression->expression.Set(const_value);
+ return semantic_expression;
+ }
+
+ OwnedFlatbuffer<SemanticExpression> CreateAndPackMergeValuesExpression(
+ const std::vector<TestDateT>& values) {
+ MergeValueExpressionT merge_expression;
+ merge_expression.type = TypeIdForName(semantic_values_schema_.get(),
+ "libtextclassifier3.grammar.TestDate")
+ .value();
+ for (const TestDateT& test_date : values) {
+ merge_expression.values.emplace_back(new SemanticExpressionT);
+ merge_expression.values.back() = CreateConstDateExpression(test_date);
+ }
+ return CreateExpression(std::move(merge_expression));
+ }
+
+ template <typename T>
+ std::unique_ptr<SemanticExpressionT> CreatePrimitiveConstExpression(
+ const T value) {
+ ConstValueExpressionT const_value;
+ const_value.base_type = flatbuffers_base_type<T>::value;
+ const_value.value.resize(sizeof(T));
+ flatbuffers::WriteScalar(const_value.value.data(), value);
+ auto semantic_expression = std::make_unique<SemanticExpressionT>();
+ semantic_expression->expression.Set(const_value);
+ return semantic_expression;
+ }
+
+ template <typename T>
+ OwnedFlatbuffer<SemanticExpression> CreateAndPackPrimitiveConstExpression(
+ const T value) {
+ ConstValueExpressionT const_value;
+ const_value.base_type = flatbuffers_base_type<T>::value;
+ const_value.value.resize(sizeof(T));
+ flatbuffers::WriteScalar(const_value.value.data(), value);
+ return CreateExpression(const_value);
+ }
+
+ template <>
+ OwnedFlatbuffer<SemanticExpression> CreateAndPackPrimitiveConstExpression(
+ const StringPiece value) {
+ ConstValueExpressionT const_value;
+ const_value.base_type = reflection::BaseType::String;
+ const_value.value.assign(value.data(), value.data() + value.size());
+ return CreateExpression(const_value);
+ }
+
+ template <>
+ std::unique_ptr<SemanticExpressionT> CreatePrimitiveConstExpression(
+ const StringPiece value) {
+ ConstValueExpressionT const_value;
+ const_value.base_type = reflection::BaseType::String;
+ const_value.value.assign(value.data(), value.data() + value.size());
+ auto semantic_expression = std::make_unique<SemanticExpressionT>();
+ semantic_expression->expression.Set(const_value);
+ return semantic_expression;
+ }
+
+ const std::unique_ptr<UniLib> unilib_;
+ UnsafeArena arena_;
+ const OwnedFlatbuffer<reflection::Schema, std::string>
+ semantic_values_schema_;
+ const Tokenizer tokenizer_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TESTING_UTILS_H_
diff --git a/native/utils/grammar/testing/value.bfbs b/native/utils/grammar/testing/value.bfbs
new file mode 100644
index 0000000..6dd8538
--- /dev/null
+++ b/native/utils/grammar/testing/value.bfbs
Binary files differ
diff --git a/native/utils/grammar/testing/value.fbs b/native/utils/grammar/testing/value.fbs
new file mode 100755
index 0000000..0429491
--- /dev/null
+++ b/native/utils/grammar/testing/value.fbs
@@ -0,0 +1,44 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// A test enum used as a field type in the test semantic values below.
+namespace libtextclassifier3.grammar;
+enum TestEnum : int {
+ UNSPECIFIED = 0,
+ ENUM_1 = 1,
+ ENUM_2 = 2,
+}
+
+// A test semantic value result.
+namespace libtextclassifier3.grammar;
+table TestValue {
+ value:int;
+ a_float_value:double;
+ test_string:string (shared);
+ date:TestDate;
+ enum_value:TestEnum;
+ repeated_enum:[TestEnum];
+ repeated_date:[TestDate];
+}
+
+// A test date value result.
+namespace libtextclassifier3.grammar;
+table TestDate {
+ day:int;
+ month:int;
+ year:int;
+}
+
diff --git a/native/utils/grammar/text-context.h b/native/utils/grammar/text-context.h
new file mode 100644
index 0000000..53e5f8b
--- /dev/null
+++ b/native/utils/grammar/text-context.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TEXT_CONTEXT_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TEXT_CONTEXT_H_
+
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/i18n/locale.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3::grammar {
+
+// Input to the parser.
+struct TextContext {
+ // Returns a view on a span of the text.
+ const UnicodeText Span(const CodepointSpan& span) const {
+ return text.Substring(codepoints[span.first], codepoints[span.second],
+ /*do_copy=*/false);
+ }
+
+ // The input text.
+ UnicodeText text;
+
+ // Pre-enumerated codepoints for fast substring extraction.
+ std::vector<UnicodeText::const_iterator> codepoints;
+
+ // The tokenized input text.
+ std::vector<Token> tokens;
+
+ // Locales of the input text.
+ std::vector<Locale> locales;
+
+ // Text annotations.
+ std::vector<AnnotatedSpan> annotations;
+
+ // The span of tokens to consider.
+ TokenSpan context_span;
+};
+
+}; // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TEXT_CONTEXT_H_
diff --git a/native/utils/grammar/utils/ir.cc b/native/utils/grammar/utils/ir.cc
index fc5c28e..49135bf 100644
--- a/native/utils/grammar/utils/ir.cc
+++ b/native/utils/grammar/utils/ir.cc
@@ -192,15 +192,6 @@
continue;
}
- // If either callback is a filter, we can't share as we must always run
- // both filters.
- if ((lhs.callback.id != kNoCallback &&
- filters_.find(lhs.callback.id) != filters_.end()) ||
- (candidate->callback.id != kNoCallback &&
- filters_.find(candidate->callback.id) != filters_.end())) {
- continue;
- }
-
// If the nonterminal is already defined, it must match for sharing.
if (lhs.nonterminal != kUnassignedNonterm &&
lhs.nonterminal != candidate->nonterminal) {
@@ -406,13 +397,6 @@
void Ir::Serialize(const bool include_debug_information,
RulesSetT* output) const {
- // Set callback information.
- for (const CallbackId filter_callback_id : filters_) {
- output->callback.push_back(RulesSet_::CallbackEntry(
- filter_callback_id, RulesSet_::Callback(/*is_filter=*/true)));
- }
- SortStructsForBinarySearchLookup(&output->callback);
-
// Add information about predefined nonterminal classes.
output->nonterminals.reset(new RulesSet_::NonterminalsT);
output->nonterminals->start_nt = GetNonterminalForName(kStartNonterm);
diff --git a/native/utils/grammar/utils/ir.h b/native/utils/grammar/utils/ir.h
index ac15a44..adafa66 100644
--- a/native/utils/grammar/utils/ir.h
+++ b/native/utils/grammar/utils/ir.h
@@ -96,9 +96,8 @@
std::unordered_map<TwoNonterms, LhsSet, BinaryRuleHasher> binary_rules;
};
- explicit Ir(const std::unordered_set<CallbackId>& filters = {},
- const int num_shards = 1)
- : num_nonterminals_(0), filters_(filters), shards_(num_shards) {}
+ explicit Ir(const int num_shards = 1)
+ : num_nonterminals_(0), shards_(num_shards) {}
// Adds a new non-terminal.
Nonterm AddNonterminal(const std::string& name = "") {
@@ -225,9 +224,6 @@
Nonterm num_nonterminals_;
std::unordered_set<Nonterm> nonshareable_;
- // The set of callbacks that should be treated as filters.
- std::unordered_set<CallbackId> filters_;
-
// The sharded rules.
std::vector<RulesShard> shards_;
diff --git a/native/utils/grammar/utils/ir_test.cc b/native/utils/grammar/utils/ir_test.cc
index 4d12e76..279d99a 100644
--- a/native/utils/grammar/utils/ir_test.cc
+++ b/native/utils/grammar/utils/ir_test.cc
@@ -97,47 +97,21 @@
// Test sharing in the presence of callbacks.
constexpr CallbackId kOutput1 = 1;
constexpr CallbackId kOutput2 = 2;
- constexpr CallbackId kFilter1 = 3;
- constexpr CallbackId kFilter2 = 4;
- Ir ir(/*filters=*/{kFilter1, kFilter2});
+ Ir ir;
const Nonterm x1 = ir.Add(kUnassignedNonterm, "hello");
const Nonterm x2 =
ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput1, 0}}, "hello");
const Nonterm x3 =
- ir.Add(Ir::Lhs{kUnassignedNonterm, {kFilter1, 0}}, "hello");
- const Nonterm x4 =
ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput2, 0}}, "hello");
- const Nonterm x5 =
- ir.Add(Ir::Lhs{kUnassignedNonterm, {kFilter2, 0}}, "hello");
// Duplicate entry.
- const Nonterm x6 =
+ const Nonterm x4 =
ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput2, 0}}, "hello");
EXPECT_THAT(x2, Eq(x1));
- EXPECT_THAT(x3, Ne(x1));
+ EXPECT_THAT(x3, Eq(x1));
EXPECT_THAT(x4, Eq(x1));
- EXPECT_THAT(x5, Ne(x1));
- EXPECT_THAT(x5, Ne(x3));
- EXPECT_THAT(x6, Ne(x3));
-}
-
-TEST(IrTest, HandlesSharingWithCallbacksWithDifferentParameters) {
- // Test sharing in the presence of callbacks.
- constexpr CallbackId kOutput = 1;
- constexpr CallbackId kFilter = 2;
- Ir ir(/*filters=*/{kFilter});
-
- const Nonterm x1 = ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput, 0}}, "world");
- const Nonterm x2 = ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput, 1}}, "world");
- const Nonterm x3 = ir.Add(Ir::Lhs{kUnassignedNonterm, {kFilter, 0}}, "world");
- const Nonterm x4 = ir.Add(Ir::Lhs{kUnassignedNonterm, {kFilter, 1}}, "world");
-
- EXPECT_THAT(x2, Eq(x1));
- EXPECT_THAT(x3, Ne(x1));
- EXPECT_THAT(x4, Ne(x1));
- EXPECT_THAT(x4, Ne(x3));
}
TEST(IrTest, SerializesRulesToFlatbufferFormat) {
@@ -181,7 +155,7 @@
}
TEST(IrTest, HandlesRulesSharding) {
- Ir ir(/*filters=*/{}, /*num_shards=*/2);
+ Ir ir(/*num_shards=*/2);
const Nonterm verb = ir.AddUnshareableNonterminal();
const Nonterm set_reminder = ir.AddUnshareableNonterminal();
diff --git a/native/utils/grammar/utils/rules.cc b/native/utils/grammar/utils/rules.cc
index c988194..623124a 100644
--- a/native/utils/grammar/utils/rules.cc
+++ b/native/utils/grammar/utils/rules.cc
@@ -161,10 +161,16 @@
void Rules::AddAlias(const std::string& nonterminal_name,
const std::string& alias) {
+#ifndef TC3_USE_CXX14
TC3_CHECK_EQ(nonterminal_alias_.insert_or_assign(alias, nonterminal_name)
.first->second,
nonterminal_name)
<< "Cannot redefine alias: " << alias;
+#else
+ nonterminal_alias_[alias] = nonterminal_name;
+ TC3_CHECK_EQ(nonterminal_alias_[alias], nonterminal_name)
+ << "Cannot redefine alias: " << alias;
+#endif
}
// Defines a nonterminal for an externally provided annotation.
@@ -408,7 +414,7 @@
}
Ir Rules::Finalize(const std::set<std::string>& predefined_nonterminals) const {
- Ir rules(filters_, num_shards_);
+ Ir rules(num_shards_);
std::unordered_map<int, Nonterm> nonterminal_ids;
// Pending rules to process.
@@ -424,7 +430,7 @@
}
// Assign (unmergeable) Nonterm values to any nonterminals that have
- // multiple rules or that have a filter callback on some rule.
+ // multiple rules.
for (int i = 0; i < nonterminals_.size(); i++) {
const NontermInfo& nonterminal = nonterminals_[i];
@@ -437,15 +443,8 @@
(nonterminal.from_annotation || nonterminal.rules.size() > 1 ||
!nonterminal.regex_rules.empty());
for (const int rule_index : nonterminal.rules) {
- const Rule& rule = rules_[rule_index];
-
// Schedule rule.
scheduled_rules.insert({i, rule_index});
-
- if (rule.callback != kNoCallback &&
- filters_.find(rule.callback) != filters_.end()) {
- unmergeable = true;
- }
}
if (unmergeable) {
diff --git a/native/utils/grammar/utils/rules.h b/native/utils/grammar/utils/rules.h
index b818d39..a6851f3 100644
--- a/native/utils/grammar/utils/rules.h
+++ b/native/utils/grammar/utils/rules.h
@@ -34,19 +34,15 @@
// All rules for a grammar will be collected in a rules object.
//
// Rules r;
-// CallbackId date_output_callback = 1;
-// CallbackId day_filter_callback = 2; r.DefineFilter(day_filter_callback);
-// CallbackId year_filter_callback = 3; r.DefineFilter(year_filter_callback);
-// r.Add("<date>", {"<monthname>", "<day>", <year>"},
-// date_output_callback);
+// r.Add("<date>", {"<monthname>", "<day>", <year>"});
// r.Add("<monthname>", {"January"});
// ...
// r.Add("<monthname>", {"December"});
-// r.Add("<day>", {"<string_of_digits>"}, day_filter_callback);
-// r.Add("<year>", {"<string_of_digits>"}, year_filter_callback);
+// r.Add("<day>", {"<string_of_digits>"});
+// r.Add("<year>", {"<string_of_digits>"});
//
-// The Add() method adds a rule with a given lhs, rhs, and (optionally)
-// callback. The rhs is just a list of terminals and nonterminals. Anything
+// The Add() method adds a rule with a given lhs, rhs/
+// The rhs is just a list of terminals and nonterminals. Anything
// surrounded in angle brackets is considered a nonterminal. A "?" can follow
// any element of the RHS, like this:
//
@@ -55,9 +51,8 @@
// This indicates that the <day> and "," parts of the rhs are optional.
// (This is just notational shorthand for adding a bunch of rules.)
//
-// Once you're done adding rules and callbacks to the Rules object,
-// call r.Finalize() on it. This lowers the rule set into an internal
-// representation.
+// Once you're done adding rules, r.Finalize() lowers the rule set into an
+// internal representation.
class Rules {
public:
explicit Rules(const int num_shards = 1) : num_shards_(num_shards) {}
@@ -173,9 +168,6 @@
// nonterminal.
void AddAlias(const std::string& nonterminal_name, const std::string& alias);
- // Defines a new filter id.
- void DefineFilter(const CallbackId filter_id) { filters_.insert(filter_id); }
-
// Lowers the rule set into the intermediate representation.
// Treats nonterminals given by the argument `predefined_nonterminals` as
// defined externally. This allows to define rules that are dependent on
@@ -183,6 +175,9 @@
// fed to the matcher by the lexer.
Ir Finalize(const std::set<std::string>& predefined_nonterminals = {}) const;
+ const std::vector<NontermInfo>& nonterminals() const { return nonterminals_; }
+ const std::vector<Rule>& rules() const { return rules_; }
+
private:
void ExpandOptionals(
int lhs, const std::vector<RhsElement>& rhs, CallbackId callback,
@@ -230,9 +225,6 @@
// Rules.
std::vector<Rule> rules_;
std::vector<std::string> regex_rules_;
-
- // Ids of callbacks that should be treated as filters.
- std::unordered_set<CallbackId> filters_;
};
} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/utils/rules_test.cc b/native/utils/grammar/utils/rules_test.cc
index 30be704..8db88ab 100644
--- a/native/utils/grammar/utils/rules_test.cc
+++ b/native/utils/grammar/utils/rules_test.cc
@@ -51,14 +51,12 @@
TEST(SerializeRulesTest, HandlesRulesSetWithCallbacks) {
Rules rules;
const CallbackId output = 1;
- const CallbackId filter = 2;
- rules.DefineFilter(filter);
rules.Add("<verb>", {"buy"});
- rules.Add("<verb>", {"bring"}, output, 0);
- rules.Add("<verb>", {"remind"}, output, 0);
+ rules.Add("<verb>", {"bring"});
+ rules.Add("<verb>", {"remind"});
rules.Add("<reminder>", {"remind", "me", "to", "<verb>"});
- rules.Add("<action>", {"<reminder>"}, filter, 0);
+ rules.Add("<action>", {"<reminder>"}, output, 0);
const Ir ir = rules.Finalize();
RulesSetT frozen_rules;
@@ -68,9 +66,7 @@
EXPECT_EQ(frozen_rules.terminals,
std::string("bring\0buy\0me\0remind\0to\0", 23));
- // We have two identical output calls and one filter call in the rule set
- // definition above.
- EXPECT_THAT(frozen_rules.lhs, SizeIs(2));
+ EXPECT_THAT(frozen_rules.lhs, SizeIs(1));
EXPECT_THAT(frozen_rules.rules.front()->binary_rules, SizeIs(3));
EXPECT_THAT(frozen_rules.rules.front()->unary_rules, SizeIs(1));
diff --git a/native/utils/i18n/locale-list.cc b/native/utils/i18n/locale-list.cc
new file mode 100644
index 0000000..a0be5ac
--- /dev/null
+++ b/native/utils/i18n/locale-list.cc
@@ -0,0 +1,44 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/i18n/locale-list.h"
+
+#include <string>
+
+namespace libtextclassifier3 {
+
+LocaleList LocaleList::ParseFrom(const std::string& locale_tags) {
+ std::vector<StringPiece> split_locales = strings::Split(locale_tags, ',');
+ std::string reference_locale;
+ if (!split_locales.empty()) {
+ // Assigns the first parsed locale to reference_locale.
+ reference_locale = split_locales[0].ToString();
+ } else {
+ reference_locale = "";
+ }
+ std::vector<Locale> locales;
+ for (const StringPiece& locale_str : split_locales) {
+ const Locale locale = Locale::FromBCP47(locale_str.ToString());
+ if (!locale.IsValid()) {
+ TC3_LOG(WARNING) << "Failed to parse the detected_text_language_tag: "
+ << locale_str.ToString();
+ }
+ locales.push_back(locale);
+ }
+ return LocaleList(locales, split_locales, reference_locale);
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/i18n/locale-list.h b/native/utils/i18n/locale-list.h
new file mode 100644
index 0000000..cf2e06d
--- /dev/null
+++ b/native/utils/i18n/locale-list.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_LIST_H_
+#define LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_LIST_H_
+
+#include <string>
+
+#include "utils/i18n/locale.h"
+#include "utils/strings/split.h"
+
+namespace libtextclassifier3 {
+
+// Parses and hold data about locales (combined by delimiter ',').
+class LocaleList {
+ public:
+ // Constructs the
+ // - Collection of locale tag from local_tags
+ // - Collection of Locale objects from a valid BCP47 tag. (If the tag is
+ // invalid, an object is created but return false for IsInvalid() call.
+ // - Assigns the first parsed locale to reference_locale.
+ static LocaleList ParseFrom(const std::string& locale_tags);
+
+ std::vector<Locale> GetLocales() const { return locales_; }
+ std::vector<StringPiece> GetLocaleTags() const { return split_locales_; }
+ std::string GetReferenceLocale() const { return reference_locale_; }
+
+ private:
+ LocaleList(const std::vector<Locale>& locales,
+ const std::vector<StringPiece>& split_locales,
+ const StringPiece& reference_locale)
+ : locales_(locales),
+ split_locales_(split_locales),
+ reference_locale_(reference_locale.ToString()) {}
+
+ const std::vector<Locale> locales_;
+ const std::vector<StringPiece> split_locales_;
+ const std::string reference_locale_;
+};
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_LIST_H_
diff --git a/native/utils/i18n/locale-list_test.cc b/native/utils/i18n/locale-list_test.cc
new file mode 100644
index 0000000..d7cfd17
--- /dev/null
+++ b/native/utils/i18n/locale-list_test.cc
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/i18n/locale-list.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using ::testing::SizeIs;
+
+namespace libtextclassifier3 {
+namespace {
+
+TEST(LocaleTest, ParsedLocalesSanityCheck) {
+ LocaleList locale_list = LocaleList::ParseFrom("en-US,zh-CN,ar,en");
+ EXPECT_THAT(locale_list.GetLocales(), SizeIs(4));
+ EXPECT_THAT(locale_list.GetLocaleTags(), SizeIs(4));
+ EXPECT_EQ(locale_list.GetReferenceLocale(), "en-US");
+}
+
+TEST(LocaleTest, ParsedLocalesEmpty) {
+ LocaleList locale_list = LocaleList::ParseFrom("");
+ EXPECT_THAT(locale_list.GetLocales(), SizeIs(0));
+ EXPECT_THAT(locale_list.GetLocaleTags(), SizeIs(0));
+ EXPECT_EQ(locale_list.GetReferenceLocale(), "");
+}
+
+TEST(LocaleTest, ParsedLocalesIvalid) {
+ LocaleList locale_list = LocaleList::ParseFrom("en,invalid");
+ EXPECT_THAT(locale_list.GetLocales(), SizeIs(2));
+ EXPECT_THAT(locale_list.GetLocaleTags(), SizeIs(2));
+ EXPECT_EQ(locale_list.GetReferenceLocale(), "en");
+ EXPECT_TRUE(locale_list.GetLocales()[0].IsValid());
+ EXPECT_FALSE(locale_list.GetLocales()[1].IsValid());
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/java/jni-base.cc b/native/utils/java/jni-base.cc
index 42de67e..39ade45 100644
--- a/native/utils/java/jni-base.cc
+++ b/native/utils/java/jni-base.cc
@@ -24,11 +24,13 @@
return env->EnsureLocalCapacity(capacity) == JNI_OK;
}
-bool JniExceptionCheckAndClear(JNIEnv* env) {
+bool JniExceptionCheckAndClear(JNIEnv* env, bool print_exception_on_error) {
TC3_CHECK(env != nullptr);
const bool result = env->ExceptionCheck();
if (result) {
- env->ExceptionDescribe();
+ if (print_exception_on_error) {
+ env->ExceptionDescribe();
+ }
env->ExceptionClear();
}
return result;
diff --git a/native/utils/java/jni-base.h b/native/utils/java/jni-base.h
index 0bc46fa..211000a 100644
--- a/native/utils/java/jni-base.h
+++ b/native/utils/java/jni-base.h
@@ -65,7 +65,8 @@
bool EnsureLocalCapacity(JNIEnv* env, int capacity);
// Returns true if there was an exception. Also it clears the exception.
-bool JniExceptionCheckAndClear(JNIEnv* env);
+bool JniExceptionCheckAndClear(JNIEnv* env,
+ bool print_exception_on_error = true);
// A deleter to be used with std::unique_ptr to delete JNI global references.
class GlobalRefDeleter {
diff --git a/native/utils/java/jni-helper.h b/native/utils/java/jni-helper.h
index 952fe95..5ac60ef 100644
--- a/native/utils/java/jni-helper.h
+++ b/native/utils/java/jni-helper.h
@@ -152,8 +152,10 @@
jmethodID method_id, ...);
template <class T>
- static StatusOr<T> CallStaticIntMethod(JNIEnv* env, jclass clazz,
- jmethodID method_id, ...);
+ static StatusOr<T> CallStaticIntMethod(JNIEnv* env,
+ bool print_exception_on_error,
+ jclass clazz, jmethodID method_id,
+ ...);
};
template <typename T>
@@ -169,14 +171,19 @@
}
template <class T>
-StatusOr<T> JniHelper::CallStaticIntMethod(JNIEnv* env, jclass clazz,
- jmethodID method_id, ...) {
+StatusOr<T> JniHelper::CallStaticIntMethod(JNIEnv* env,
+ bool print_exception_on_error,
+ jclass clazz, jmethodID method_id,
+ ...) {
va_list args;
va_start(args, method_id);
jint result = env->CallStaticIntMethodV(clazz, method_id, args);
va_end(args);
- TC3_NO_EXCEPTION_OR_RETURN;
+ if (JniExceptionCheckAndClear(env, print_exception_on_error)) {
+ return {Status::UNKNOWN};
+ }
+
return result;
}
diff --git a/native/utils/regex-match_test.cc b/native/utils/regex-match_test.cc
index c45fb29..c7a7740 100644
--- a/native/utils/regex-match_test.cc
+++ b/native/utils/regex-match_test.cc
@@ -18,6 +18,7 @@
#include <memory>
+#include "utils/jvm-test-utils.h"
#include "utils/utf8/unicodetext.h"
#include "utils/utf8/unilib.h"
#include "gmock/gmock.h"
@@ -28,11 +29,10 @@
class RegexMatchTest : public testing::Test {
protected:
- RegexMatchTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
- UniLib unilib_;
+ RegexMatchTest() : unilib_(libtextclassifier3::CreateUniLibForTesting()) {}
+ std::unique_ptr<UniLib> unilib_;
};
-#ifdef TC3_UNILIB_ICU
#ifndef TC3_DISABLE_LUA
TEST_F(RegexMatchTest, HandlesSimpleVerification) {
EXPECT_TRUE(VerifyMatch(/*context=*/"", /*matcher=*/nullptr, "return true;"));
@@ -65,7 +65,7 @@
return luhn(match[1].text);
)";
const std::unique_ptr<UniLib::RegexPattern> regex_pattern =
- unilib_.CreateRegexPattern(pattern);
+ unilib_->CreateRegexPattern(pattern);
ASSERT_TRUE(regex_pattern != nullptr);
const std::unique_ptr<UniLib::RegexMatcher> matcher =
regex_pattern->Matcher(message);
@@ -83,7 +83,7 @@
UTF8ToUnicodeText("never gonna (?:give (you) up|let (you) down)",
/*do_copy=*/true);
const std::unique_ptr<UniLib::RegexPattern> regex_pattern =
- unilib_.CreateRegexPattern(pattern);
+ unilib_->CreateRegexPattern(pattern);
ASSERT_TRUE(regex_pattern != nullptr);
UnicodeText message =
UTF8ToUnicodeText("never gonna give you up - never gonna let you down");
@@ -108,7 +108,6 @@
EXPECT_THAT(GetCapturingGroupText(matcher.get(), 2).value(),
testing::Eq("you"));
}
-#endif
} // namespace
} // namespace libtextclassifier3
diff --git a/native/utils/resources.fbs b/native/utils/resources.fbs
index 0a05718..b4d9b83 100755
--- a/native/utils/resources.fbs
+++ b/native/utils/resources.fbs
@@ -14,8 +14,8 @@
// limitations under the License.
//
-include "utils/zlib/buffer.fbs";
include "utils/i18n/language-tag.fbs";
+include "utils/zlib/buffer.fbs";
namespace libtextclassifier3;
table Resource {
diff --git a/native/utils/tokenizer.cc b/native/utils/tokenizer.cc
index da66ff6..20f72c4 100644
--- a/native/utils/tokenizer.cc
+++ b/native/utils/tokenizer.cc
@@ -50,6 +50,10 @@
SortCodepointRanges(internal_tokenizer_codepoint_ranges,
&internal_tokenizer_codepoint_ranges_);
+ if (type_ == TokenizationType_MIXED && split_on_script_change) {
+ TC3_LOG(ERROR) << "The option `split_on_script_change` is unavailable for "
+ "the selected tokenizer type (mixed).";
+ }
}
const TokenizationCodepointRangeT* Tokenizer::FindTokenizationRange(
@@ -234,15 +238,20 @@
if (!break_iterator) {
return false;
}
+ const int context_unicode_size = context_unicode.size_codepoints();
int last_unicode_index = 0;
int unicode_index = 0;
auto token_begin_it = context_unicode.begin();
while ((unicode_index = break_iterator->Next()) !=
UniLib::BreakIterator::kDone) {
const int token_length = unicode_index - last_unicode_index;
+ if (token_length + last_unicode_index > context_unicode_size) {
+ return false;
+ }
auto token_end_it = token_begin_it;
std::advance(token_end_it, token_length);
+ TC3_CHECK(token_end_it <= context_unicode.end());
// Determine if the whole token is whitespace.
bool is_whitespace = true;
diff --git a/native/utils/utf8/unicodetext.cc b/native/utils/utf8/unicodetext.cc
index 7b56ce2..2ddd38c 100644
--- a/native/utils/utf8/unicodetext.cc
+++ b/native/utils/utf8/unicodetext.cc
@@ -202,6 +202,14 @@
return IsValidUTF8(repr_.data_, repr_.size_);
}
+std::vector<UnicodeText::const_iterator> UnicodeText::Codepoints() const {
+ std::vector<UnicodeText::const_iterator> codepoints;
+ for (auto it = begin(); it != end(); it++) {
+ codepoints.push_back(it);
+ }
+ return codepoints;
+}
+
bool UnicodeText::operator==(const UnicodeText& other) const {
if (repr_.size_ != other.repr_.size_) {
return false;
diff --git a/native/utils/utf8/unicodetext.h b/native/utils/utf8/unicodetext.h
index 9810480..4ca0dd2 100644
--- a/native/utils/utf8/unicodetext.h
+++ b/native/utils/utf8/unicodetext.h
@@ -20,6 +20,7 @@
#include <iterator>
#include <string>
#include <utility>
+#include <vector>
#include "utils/base/integral_types.h"
#include "utils/base/logging.h"
@@ -174,6 +175,9 @@
UnicodeText& push_back(char32 ch);
void clear();
+ // Returns an iterator for each codepoint.
+ std::vector<const_iterator> Codepoints() const;
+
std::string ToUTF8String() const;
std::string UTF8Substring(int begin_codepoint, int end_codepoint) const;
static std::string UTF8Substring(const const_iterator& it_begin,
diff --git a/native/utils/utf8/unilib-common.cc b/native/utils/utf8/unilib-common.cc
index 7423cf3..70b8fec 100644
--- a/native/utils/utf8/unilib-common.cc
+++ b/native/utils/utf8/unilib-common.cc
@@ -407,6 +407,10 @@
0x275E, 0x276E, 0x276F, 0x2E42, 0x301D, 0x301E, 0x301F, 0xFF02};
constexpr int kNumQuotation = ARRAYSIZE(kQuotation);
+// Source: https://unicode-search.net/unicode-namesearch.pl?term=ampersand
+constexpr char32 kAmpersand[] = {0x0026, 0xFE60, 0xFF06, 0x1F674, 0x1F675};
+constexpr int kNumAmpersand = ARRAYSIZE(kAmpersand);
+
#undef ARRAYSIZE
static_assert(kNumOpeningBrackets == kNumClosingBrackets,
@@ -596,6 +600,10 @@
return GetMatchIndex(kQuotation, kNumQuotation, codepoint) >= 0;
}
+bool IsAmpersand(char32 codepoint) {
+ return GetMatchIndex(kAmpersand, kNumAmpersand, codepoint) >= 0;
+}
+
bool IsLatinLetter(char32 codepoint) {
return (GetOverlappingRangeIndex(
kLatinLettersRangesStart, kLatinLettersRangesEnd,
diff --git a/native/utils/utf8/unilib-common.h b/native/utils/utf8/unilib-common.h
index 2788f3c..b192034 100644
--- a/native/utils/utf8/unilib-common.h
+++ b/native/utils/utf8/unilib-common.h
@@ -37,6 +37,7 @@
bool IsDot(char32 codepoint);
bool IsApostrophe(char32 codepoint);
bool IsQuotation(char32 codepoint);
+bool IsAmpersand(char32 codepoint);
bool IsLatinLetter(char32 codepoint);
bool IsArabicLetter(char32 codepoint);
@@ -52,6 +53,23 @@
char32 ToUpper(char32 codepoint);
char32 GetPairedBracket(char32 codepoint);
+// Checks if the text format is not likely to be a number. Used to avoid most of
+// the java exceptions thrown when fail to parse.
+template <class T>
+bool PassesIntPreChesks(const UnicodeText& text, const T result) {
+ if (text.empty() ||
+ (std::is_same<T, int32>::value && text.size_codepoints() > 10) ||
+ (std::is_same<T, int64>::value && text.size_codepoints() > 19)) {
+ return false;
+ }
+ for (auto it = text.begin(); it != text.end(); ++it) {
+ if (!IsDigit(*it)) {
+ return false;
+ }
+ }
+ return true;
+}
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_COMMON_H_
diff --git a/native/utils/utf8/unilib-javaicu.cc b/native/utils/utf8/unilib-javaicu.cc
index a0f4cb7..befe639 100644
--- a/native/utils/utf8/unilib-javaicu.cc
+++ b/native/utils/utf8/unilib-javaicu.cc
@@ -25,8 +25,8 @@
#include "utils/base/logging.h"
#include "utils/base/statusor.h"
#include "utils/java/jni-base.h"
+#include "utils/java/jni-helper.h"
#include "utils/utf8/unicodetext.h"
-#include "utils/utf8/unilib-common.h"
namespace libtextclassifier3 {
@@ -81,6 +81,20 @@
// Implementations that call out to JVM. Behold the beauty.
// -----------------------------------------------------------------------------
+StatusOr<int32> UniLibBase::Length(const UnicodeText& text) const {
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> text_java,
+ jni_cache_->ConvertToJavaString(text));
+
+ JNIEnv* jenv = jni_cache_->GetEnv();
+ TC3_ASSIGN_OR_RETURN(int utf16_length,
+ JniHelper::CallIntMethod(jenv, text_java.get(),
+ jni_cache_->string_length));
+
+ return JniHelper::CallIntMethod(jenv, text_java.get(),
+ jni_cache_->string_code_point_count, 0,
+ utf16_length);
+}
+
bool UniLibBase::ParseInt32(const UnicodeText& text, int32* result) const {
return ParseInt(text, result);
}
@@ -94,29 +108,23 @@
return false;
}
- JNIEnv* env = jni_cache_->GetEnv();
auto it_dot = text.begin();
for (; it_dot != text.end() && !IsDot(*it_dot); it_dot++) {
}
- int64 integer_part;
+ int32 integer_part;
if (!ParseInt(UnicodeText::Substring(text.begin(), it_dot, /*do_copy=*/false),
&integer_part)) {
return false;
}
- int64 fractional_part = 0;
+ int32 fractional_part = 0;
if (it_dot != text.end()) {
- std::string fractional_part_str =
- UnicodeText::UTF8Substring(++it_dot, text.end());
- TC3_ASSIGN_OR_RETURN_FALSE(
- const ScopedLocalRef<jstring> fractional_text_java,
- jni_cache_->ConvertToJavaString(fractional_part_str));
- TC3_ASSIGN_OR_RETURN_FALSE(
- fractional_part,
- JniHelper::CallStaticIntMethod<int64>(
- env, jni_cache_->integer_class.get(), jni_cache_->integer_parse_int,
- fractional_text_java.get()));
+ if (!ParseInt(
+ UnicodeText::Substring(++it_dot, text.end(), /*do_copy=*/false),
+ &fractional_part)) {
+ return false;
+ }
}
double factional_part_double = fractional_part;
diff --git a/native/utils/utf8/unilib-javaicu.h b/native/utils/utf8/unilib-javaicu.h
index 4845de0..8b04789 100644
--- a/native/utils/utf8/unilib-javaicu.h
+++ b/native/utils/utf8/unilib-javaicu.h
@@ -32,6 +32,7 @@
#include "utils/java/jni-cache.h"
#include "utils/java/jni-helper.h"
#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib-common.h"
namespace libtextclassifier3 {
@@ -56,6 +57,8 @@
char32 ToUpper(char32 codepoint) const;
char32 GetPairedBracket(char32 codepoint) const;
+ StatusOr<int32> Length(const UnicodeText& text) const;
+
// Forward declaration for friend.
class RegexPattern;
@@ -197,13 +200,21 @@
return false;
}
+ // Avoid throwing exceptions when the text is unlikely to be a number.
+ int32 result32 = 0;
+ if (!PassesIntPreChesks(text, result32)) {
+ return false;
+ }
+
JNIEnv* env = jni_cache_->GetEnv();
TC3_ASSIGN_OR_RETURN_FALSE(const ScopedLocalRef<jstring> text_java,
jni_cache_->ConvertToJavaString(text));
TC3_ASSIGN_OR_RETURN_FALSE(
- *result, JniHelper::CallStaticIntMethod<T>(
- env, jni_cache_->integer_class.get(),
- jni_cache_->integer_parse_int, text_java.get()));
+ *result,
+ JniHelper::CallStaticIntMethod<T>(
+ env,
+ /*print_exception_on_error=*/false, jni_cache_->integer_class.get(),
+ jni_cache_->integer_parse_int, text_java.get()));
return true;
}
diff --git a/native/utils/utf8/unilib.h b/native/utils/utf8/unilib.h
index 18cc261..ffda7d9 100644
--- a/native/utils/utf8/unilib.h
+++ b/native/utils/utf8/unilib.h
@@ -30,9 +30,6 @@
#elif defined TC3_UNILIB_APPLE
#include "utils/utf8/unilib-apple.h"
#define INIT_UNILIB_FOR_TESTING(VAR) VAR()
-#elif defined TC3_UNILIB_DUMMY
-#include "utils/utf8/unilib-dummy.h"
-#define INIT_UNILIB_FOR_TESTING(VAR) VAR()
#else
#error No TC3_UNILIB implementation specified.
#endif
@@ -116,6 +113,10 @@
return libtextclassifier3::IsQuotation(codepoint);
}
+ bool IsAmpersand(char32 codepoint) const {
+ return libtextclassifier3::IsAmpersand(codepoint);
+ }
+
bool IsLatinLetter(char32 codepoint) const {
return libtextclassifier3::IsLatinLetter(codepoint);
}
@@ -151,6 +152,31 @@
bool IsLetter(char32 codepoint) const {
return libtextclassifier3::IsLetter(codepoint);
}
+
+ bool IsValidUtf8(const UnicodeText& text) const {
+ // Basic check of structural validity of UTF8.
+ if (!text.is_valid()) {
+ return false;
+ }
+ // In addition to that, we declare that a valid UTF8 is when the number of
+ // codepoints in the string as measured by ICU is the same as the number of
+ // codepoints as measured by UnicodeText. Because if we don't do this check,
+ // the indices might differ, and cause trouble, because the assumption
+ // throughout the code is that ICU indices and UnicodeText indices are the
+ // same.
+ // NOTE: This is not perfect, as this doesn't check the alignment of the
+ // codepoints, but for the practical purposes should be enough.
+ const StatusOr<int32> icu_length = Length(text);
+ if (!icu_length.ok()) {
+ return false;
+ }
+
+ if (icu_length.ValueOrDie() != text.size_codepoints()) {
+ return false;
+ }
+
+ return true;
+ }
};
} // namespace libtextclassifier3
diff --git a/native/utils/utf8/unilib_test-include.cc b/native/utils/utf8/unilib_test-include.cc
index 518f4c8..ed0f184 100644
--- a/native/utils/utf8/unilib_test-include.cc
+++ b/native/utils/utf8/unilib_test-include.cc
@@ -58,6 +58,9 @@
EXPECT_TRUE(unilib_->IsApostrophe(u'ߴ'));
EXPECT_TRUE(unilib_->IsQuotation(u'"'));
EXPECT_TRUE(unilib_->IsQuotation(u'”'));
+ EXPECT_TRUE(unilib_->IsAmpersand(u'&'));
+ EXPECT_TRUE(unilib_->IsAmpersand(u'﹠'));
+ EXPECT_TRUE(unilib_->IsAmpersand(u'&'));
EXPECT_TRUE(unilib_->IsLatinLetter('A'));
EXPECT_TRUE(unilib_->IsArabicLetter(u'ب')); // ARABIC LETTER BEH
@@ -360,6 +363,12 @@
EXPECT_EQ(result, 1000000000);
}
+TEST_F(UniLibTest, Integer32ParseOverflowNumber) {
+ int32 result;
+ EXPECT_FALSE(unilib_->ParseInt32(
+ UTF8ToUnicodeText("9123456789", /*do_copy=*/false), &result));
+}
+
TEST_F(UniLibTest, Integer32ParseEmptyString) {
int result;
EXPECT_FALSE(
@@ -518,5 +527,22 @@
UTF8ToUnicodeText("Information", /*do_copy=*/false), &result));
}
+TEST_F(UniLibTest, Length) {
+ EXPECT_EQ(unilib_->Length(UTF8ToUnicodeText("hello", /*do_copy=*/false))
+ .ValueOrDie(),
+ 5);
+ EXPECT_EQ(unilib_->Length(UTF8ToUnicodeText("ěščřž", /*do_copy=*/false))
+ .ValueOrDie(),
+ 5);
+ // Test Invalid UTF8.
+ // This testing condition needs to be != 1, as Apple character counting seems
+ // to return 0 when the input is invalid UTF8, while ICU will treat the
+ // invalid codepoint as 3 separate bytes.
+ EXPECT_NE(
+ unilib_->Length(UTF8ToUnicodeText("\xed\xa0\x80", /*do_copy=*/false))
+ .ValueOrDie(),
+ 1);
+}
+
} // namespace test_internal
} // namespace libtextclassifier3
diff --git a/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java b/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
index 0fee3b3..44af148 100644
--- a/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
+++ b/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
@@ -312,7 +312,15 @@
*/
public void onNotificationExpansionChanged(
StatusBarNotification statusBarNotification, boolean isExpanded) {
- SmartSuggestionsLogSession session = sessionCache.get(statusBarNotification.getKey());
+ onNotificationExpansionChanged(statusBarNotification.getKey(), isExpanded);
+ }
+
+ /**
+ * Similar to {@link #onNotificationExpansionChanged(StatusBarNotification, boolean)}, except that
+ * this takes the notification key as input.
+ */
+ public void onNotificationExpansionChanged(String key, boolean isExpanded) {
+ SmartSuggestionsLogSession session = sessionCache.get(key);
if (session == null) {
return;
}