Snap for 6916203 from d4d720b1095e3bd6b1531b6553e6fe20eec82cf3 to mainline-release
Change-Id: I113a2500715e7f7a9301e753a9b1df7148246c4f
diff --git a/TEST_MAPPING b/TEST_MAPPING
index 3c8e10b..cbcd136 100644
--- a/TEST_MAPPING
+++ b/TEST_MAPPING
@@ -10,6 +10,9 @@
},
{
"name": "libtextclassifier_tests"
+ },
+ {
+ "name": "libtextclassifier_java_tests"
}
]
}
\ No newline at end of file
diff --git a/coverage/Android.bp b/coverage/Android.bp
new file mode 100644
index 0000000..699c629
--- /dev/null
+++ b/coverage/Android.bp
@@ -0,0 +1,12 @@
+android_library {
+ name: "TextClassifierCoverageLib",
+
+ srcs: ["src/**/*.java"],
+
+ static_libs: [
+ "androidx.test.ext.junit",
+ "androidx.test.rules",
+ ],
+
+ sdk_version: "current",
+}
diff --git a/coverage/AndroidManifest.xml b/coverage/AndroidManifest.xml
new file mode 100644
index 0000000..1897421
--- /dev/null
+++ b/coverage/AndroidManifest.xml
@@ -0,0 +1,9 @@
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="com.android.textclassifier.testing">
+
+ <uses-sdk android:minSdkVersion="29" />
+
+ <application>
+ </application>
+
+</manifest>
\ No newline at end of file
diff --git a/coverage/src/com/android/textclassifier/testing/SignalMaskInfo.java b/coverage/src/com/android/textclassifier/testing/SignalMaskInfo.java
new file mode 100644
index 0000000..73ed185
--- /dev/null
+++ b/coverage/src/com/android/textclassifier/testing/SignalMaskInfo.java
@@ -0,0 +1,162 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+package com.android.textclassifier.testing;
+
+import java.io.BufferedReader;
+import java.io.FileReader;
+import java.io.IOException;
+import java.util.EnumMap;
+import java.util.Map;
+import java.util.OptionalLong;
+
+/**
+ * Class for reading a process' signal masks from the /proc filesystem. Looks for the
+ * BLOCKED, CAUGHT, IGNORED and PENDING masks from /proc/self/status, each of which is a
+ * 64 bit bitmask with one bit per signal (signal {@code n} is bit {@code n - 1}).
+ *
+ * <p>Maintains a map from SignalMaskInfo.Type to the bitmask. The {@code isValid} method
+ * will only return true if all 4 masks were successfully parsed. Provides lookup
+ * methods per signal, e.g. {@code isPending(signum)} which will throw
+ * {@code IllegalStateException} if the current data is not valid, or
+ * {@code IllegalArgumentException} if {@code signum} is not in the range [1, 64].
+ */
+public class SignalMaskInfo {
+    private enum Type {
+        BLOCKED("SigBlk"),
+        CAUGHT("SigCgt"),
+        IGNORED("SigIgn"),
+        PENDING("SigPnd");
+
+        // The tag for this mask in /proc/self/status, including the ":\t" separator.
+        private final String tag;
+
+        Type(String tag) {
+            this.tag = tag + ":\t";
+        }
+
+        public String getTag() {
+            return tag;
+        }
+
+        /**
+         * Reads every mask found in the status file at {@code path}.
+         *
+         * <p>I/O errors are swallowed: the returned map is then partial (possibly empty), which
+         * {@link SignalMaskInfo#isValid()} reports as invalid.
+         */
+        public static Map<Type, Long> parseProcinfo(String path) {
+            Map<Type, Long> map = new EnumMap<>(Type.class);
+            try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
+                String line;
+                while ((line = reader.readLine()) != null) {
+                    for (Type mask : values()) {
+                        OptionalLong value = mask.tryToParse(line);
+                        if (value.isPresent()) {
+                            map.put(mask, value.getAsLong());
+                        }
+                    }
+                }
+            } catch (IOException e) {
+                // Ignored - the map will end up being invalid instead.
+            }
+            return map;
+        }
+
+        /**
+         * Returns the mask on {@code line} if the line carries this type's tag, otherwise empty.
+         *
+         * <p>Masks are 64-bit values printed in hex, so a mask containing signal 64 has the top
+         * bit set; parsing as unsigned keeps such values from overflowing a signed {@code long}.
+         * A malformed value is treated as absent instead of aborting the whole scan.
+         */
+        private OptionalLong tryToParse(String line) {
+            if (!line.startsWith(tag)) {
+                return OptionalLong.empty();
+            }
+            try {
+                String hex = line.substring(tag.length()).trim();
+                return OptionalLong.of(Long.parseUnsignedLong(hex, 16));
+            } catch (NumberFormatException e) {
+                return OptionalLong.empty();
+            }
+        }
+    }
+
+    private static final String PROCFS_PATH = "/proc/self/status";
+    // Each mask has one bit per signal, so only signals 1..64 can be looked up.
+    private static final int MAX_SIGNAL = Long.SIZE;
+
+    private Map<Type, Long> maskMap = null;
+
+    SignalMaskInfo() {
+        refresh();
+    }
+
+    /** Re-reads the masks from /proc/self/status, replacing any previously read values. */
+    public void refresh() {
+        maskMap = Type.parseProcinfo(PROCFS_PATH);
+    }
+
+    /** Returns true if all four masks were parsed by the most recent refresh. */
+    public boolean isValid() {
+        return maskMap != null && maskMap.size() == Type.values().length;
+    }
+
+    /** Returns true if a handler is installed for {@code signal} (SigCgt). */
+    public boolean isCaught(int signal) {
+        return isSignalInMask(signal, Type.CAUGHT);
+    }
+
+    /** Returns true if {@code signal} is currently blocked (SigBlk). */
+    public boolean isBlocked(int signal) {
+        return isSignalInMask(signal, Type.BLOCKED);
+    }
+
+    /** Returns true if {@code signal} is pending (SigPnd). */
+    public boolean isPending(int signal) {
+        return isSignalInMask(signal, Type.PENDING);
+    }
+
+    /** Returns true if {@code signal} is ignored (SigIgn). */
+    public boolean isIgnored(int signal) {
+        return isSignalInMask(signal, Type.IGNORED);
+    }
+
+    private void checkValid() {
+        if (!isValid()) {
+            throw new IllegalStateException("Signal masks could not be read from " + PROCFS_PATH);
+        }
+    }
+
+    private boolean isSignalInMask(int signal, Type mask) {
+        if (signal < 1 || signal > MAX_SIGNAL) {
+            throw new IllegalArgumentException(
+                    "Signal " + signal + " is outside the range [1, " + MAX_SIGNAL + "]");
+        }
+        // Signal n is stored in bit n - 1.
+        long bit = 1L << (signal - 1);
+        return (getSignalMask(mask) & bit) != 0;
+    }
+
+    private long getSignalMask(Type mask) {
+        checkValid();
+        Long value = maskMap.get(mask);
+        if (value == null) {
+            throw new IllegalStateException("Missing " + mask + " signal mask");
+        }
+        return value;
+    }
+}
diff --git a/coverage/src/com/android/textclassifier/testing/TextClassifierInstrumentationListener.java b/coverage/src/com/android/textclassifier/testing/TextClassifierInstrumentationListener.java
new file mode 100644
index 0000000..a78ca98
--- /dev/null
+++ b/coverage/src/com/android/textclassifier/testing/TextClassifierInstrumentationListener.java
@@ -0,0 +1,156 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+
+package com.android.textclassifier.testing;
+
+import android.os.SystemClock;
+import android.system.ErrnoException;
+import android.system.Os;
+import android.util.Log;
+
+import androidx.test.internal.runner.listener.InstrumentationRunListener;
+
+import java.util.Locale;
+
+import org.junit.runner.Result;
+
+/**
+ * Instrumentation run listener that makes native coverage measurement possible by asking the
+ * process to flush its Clang coverage data before the test run finishes.
+ */
+public class TextClassifierInstrumentationListener extends InstrumentationRunListener {
+    private static final String LOG_TAG = "androidtc";
+    // Signal used to trigger a dump of Clang coverage information.
+    // See {@code maybeDumpNativeCoverage} below.
+    private static final int COVERAGE_SIGNAL = 37;
+    // Upper bound on how long to wait for the coverage handler to finish.
+    private static final long COVERAGE_DUMP_TIMEOUT_MS = 60 * 1000L;
+    // Poll interval while the coverage signal is blocked, i.e. the handler is presumed running.
+    private static final long HANDLER_RUNNING_POLL_MS = 2000L;
+    // Poll interval while the handler has not (yet) been observed running.
+    private static final long HANDLER_PENDING_POLL_MS = 100L;
+
+    @Override
+    public void testRunFinished(Result result) throws Exception {
+        try {
+            maybeDumpNativeCoverage();
+        } finally {
+            // Report the run as finished even if the coverage dump fails unexpectedly.
+            super.testRunFinished(result);
+        }
+    }
+
+    /**
+     * If this test process is instrumented for native coverage, then trigger a dump
+     * of the coverage data and wait until either we detect the dumping has finished or 60 seconds
+     * have elapsed, whichever comes first.
+     *
+     * <p>Background: Coverage builds install a signal handler for signal 37 which flushes coverage
+     * data to disk, which may take a few seconds. Tests running as an app process will get
+     * killed with SIGKILL once the app code exits, even if the coverage handler is still running.
+     *
+     * <p>Method: If a handler is installed for signal 37, then assume this is a coverage run and
+     * send signal 37. The handler is non-reentrant and so signal 37 will then be blocked until
+     * the handler completes. So after we send the signal, we loop checking the blocked status
+     * for signal 37 until we hit the 60 second deadline. If the signal is blocked then sleep for
+     * 2 seconds, and if it becomes unblocked then the handler exited so we can return early.
+     * If the signal is not blocked at the start of the loop then most likely the handler has
+     * not yet been invoked. This should almost never happen as it should get blocked on delivery
+     * when we call {@code Os.kill()}, so sleep for a shorter duration (100ms) and try again. There
+     * is a race condition here where the handler is delayed but then runs for less than 100ms and
+     * gets missed, in which case this method will loop with 100ms sleeps until the deadline.
+     *
+     * <p>In the case where the handler runs for more than 60 seconds, the test process will be
+     * allowed to exit so coverage information may be incomplete.
+     *
+     * <p>There is no API for determining signal dispositions, so this method uses the
+     * {@link SignalMaskInfo} class to read the data from /proc. If the initial read fails, no
+     * signal is sent; if a later read fails, this method keeps polling until the deadline passes.
+     *
+     * <p>Interrupts do not cut the wait short, since exiting early would truncate the coverage
+     * data; the thread's interrupt status is restored before returning instead.
+     */
+    private void maybeDumpNativeCoverage() {
+        SignalMaskInfo siginfo = new SignalMaskInfo();
+        if (!siginfo.isValid()) {
+            Log.e(LOG_TAG, "Invalid signal info");
+            return;
+        }
+
+        if (!siginfo.isCaught(COVERAGE_SIGNAL)) {
+            // Process is not instrumented for coverage
+            Log.i(LOG_TAG, "Not dumping coverage, no handler installed");
+            return;
+        }
+
+        Log.i(LOG_TAG,
+                String.format(Locale.US, "Sending coverage dump signal %d to pid %d uid %d",
+                        COVERAGE_SIGNAL, Os.getpid(), Os.getuid()));
+        try {
+            Os.kill(Os.getpid(), COVERAGE_SIGNAL);
+        } catch (ErrnoException e) {
+            Log.e(LOG_TAG, "Unable to send coverage signal", e);
+            return;
+        }
+
+        // elapsedRealtime() is monotonic, unlike System.currentTimeMillis() which jumps if the
+        // wall clock is changed while we wait.
+        long start = SystemClock.elapsedRealtime();
+        long deadline = start + COVERAGE_DUMP_TIMEOUT_MS;
+        boolean interrupted = false;
+        try {
+            while (SystemClock.elapsedRealtime() < deadline) {
+                siginfo.refresh();
+                if (siginfo.isValid() && siginfo.isBlocked(COVERAGE_SIGNAL)) {
+                    // Signal is currently blocked so assume a handler is running
+                    interrupted |= sleepRecordingInterrupt(HANDLER_RUNNING_POLL_MS);
+                    siginfo.refresh();
+                    if (siginfo.isValid() && !siginfo.isBlocked(COVERAGE_SIGNAL)) {
+                        // Coverage handler exited while we were asleep
+                        Log.i(LOG_TAG,
+                                String.format(Locale.US,
+                                        "Coverage dump detected finished after %dms",
+                                        SystemClock.elapsedRealtime() - start));
+                        return;
+                    }
+                } else {
+                    // Coverage signal handler not yet started or invalid siginfo
+                    interrupted |= sleepRecordingInterrupt(HANDLER_PENDING_POLL_MS);
+                }
+            }
+            Log.w(LOG_TAG, "Coverage dump not detected as finished before the deadline");
+        } finally {
+            if (interrupted) {
+                // Restore the interrupt status that was consumed while sleeping.
+                Thread.currentThread().interrupt();
+            }
+        }
+    }
+
+    /**
+     * Sleeps for {@code millis} milliseconds and returns whether the sleep was interrupted.
+     * Catching the interruption clears the thread's interrupt status, so callers are responsible
+     * for restoring it once they stop waiting.
+     */
+    private static boolean sleepRecordingInterrupt(long millis) {
+        try {
+            Thread.sleep(millis);
+            return false;
+        } catch (InterruptedException e) {
+            return true;
+        }
+    }
+}
\ No newline at end of file
diff --git a/java/Android.bp b/java/Android.bp
index 26efacd..f7ff63b 100644
--- a/java/Android.bp
+++ b/java/Android.bp
@@ -65,6 +65,6 @@
genrule {
name: "statslog-textclassifier-java-gen",
tools: ["stats-log-api-gen"],
- cmd: "$(location stats-log-api-gen) --java $(out) --module textclassifier --javaPackage com.android.textclassifier --javaClass TextClassifierStatsLog",
- out: ["com/android/textclassifier/TextClassifierStatsLog.java"],
+ cmd: "$(location stats-log-api-gen) --java $(out) --module textclassifier --javaPackage com.android.textclassifier.common.statsd --javaClass TextClassifierStatsLog",
+ out: ["com/android/textclassifier/common/statsd/TextClassifierStatsLog.java"],
}
diff --git a/java/AndroidManifest.xml b/java/AndroidManifest.xml
index 9f02689..7c251e4 100644
--- a/java/AndroidManifest.xml
+++ b/java/AndroidManifest.xml
@@ -30,6 +30,8 @@
<uses-sdk android:minSdkVersion="29" android:targetSdkVersion="29"/>
<uses-permission android:name="android.permission.ACCESS_COARSE_LOCATION" />
+ <uses-permission android:name="android.permission.RECEIVE_BOOT_COMPLETED" />
+ <uses-permission android:name="android.permission.ACCESS_NETWORK_STATE"/>
<application android:label="@string/tcs_app_name"
android:icon="@drawable/tcs_app_icon"
diff --git a/java/src/com/android/textclassifier/DefaultTextClassifierService.java b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
index d2c1e38..3d3b359 100644
--- a/java/src/com/android/textclassifier/DefaultTextClassifierService.java
+++ b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
@@ -66,7 +66,15 @@
@Override
public void onCreate() {
super.onCreate();
- textClassifier = new TextClassifierImpl(this, new TextClassifierSettings());
+
+ TextClassifierSettings settings = new TextClassifierSettings();
+ ModelFileManager modelFileManager = new ModelFileManager(this, settings);
+ textClassifier = new TextClassifierImpl(this, settings, modelFileManager);
+ }
+
+ @Override
+ public void onDestroy() {
+ super.onDestroy();
}
@Override
diff --git a/java/src/com/android/textclassifier/ModelFileManager.java b/java/src/com/android/textclassifier/ModelFileManager.java
index a6f64d8..0552ad2 100644
--- a/java/src/com/android/textclassifier/ModelFileManager.java
+++ b/java/src/com/android/textclassifier/ModelFileManager.java
@@ -16,22 +16,33 @@
package com.android.textclassifier;
+import android.content.Context;
import android.os.LocaleList;
import android.os.ParcelFileDescriptor;
import android.text.TextUtils;
import androidx.annotation.GuardedBy;
+import androidx.annotation.StringDef;
+import com.android.textclassifier.ModelFileManager.ModelFile;
+import com.android.textclassifier.ModelFileManager.ModelFile.ModelType;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
+import com.android.textclassifier.utils.IndentingPrintWriter;
+import com.google.android.textclassifier.ActionsSuggestionsModel;
+import com.google.android.textclassifier.AnnotatorModel;
+import com.google.android.textclassifier.LangIdModel;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
@@ -42,27 +53,65 @@
import java.util.stream.Collectors;
import javax.annotation.Nullable;
-/** Manages model files that are listed by the model files supplier. */
+/**
+ * Manages all model files in storage. {@link TextClassifierImpl} depends on this class to get the
+ * model files to load.
+ */
final class ModelFileManager {
private static final String TAG = "ModelFileManager";
+ private static final String DOWNLOAD_SUB_DIR_NAME = "textclassifier/downloads/models/";
- private final Supplier<ImmutableList<ModelFile>> modelFileSupplier;
+ private final File downloadModelDir;
+ private final ImmutableMap<String, Supplier<ImmutableList<ModelFile>>> modelFileSuppliers;
- public ModelFileManager(Supplier<ImmutableList<ModelFile>> modelFileSupplier) {
- this.modelFileSupplier = Preconditions.checkNotNull(modelFileSupplier);
+ /** Create a ModelFileManager based on hardcoded model file locations. */
+ public ModelFileManager(Context context, TextClassifierSettings settings) {
+ Preconditions.checkNotNull(context);
+ Preconditions.checkNotNull(settings);
+ this.downloadModelDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
+ if (!downloadModelDir.exists()) {
+ downloadModelDir.mkdirs();
+ }
+
+ ImmutableMap.Builder<String, Supplier<ImmutableList<ModelFile>>> suppliersBuilder =
+ ImmutableMap.builder();
+ for (String modelType : ModelType.values()) {
+ suppliersBuilder.put(
+ modelType, new ModelFileSupplierImpl(settings, modelType, downloadModelDir));
+ }
+ this.modelFileSuppliers = suppliersBuilder.build();
}
- /** Returns an immutable list of model files listed by the given model files supplier. */
- public ImmutableList<ModelFile> listModelFiles() {
- return modelFileSupplier.get();
+ @VisibleForTesting
+ ModelFileManager(
+ File downloadModelDir,
+ ImmutableMap<String, Supplier<ImmutableList<ModelFile>>> modelFileSuppliers) {
+ this.downloadModelDir = Preconditions.checkNotNull(downloadModelDir);
+ this.modelFileSuppliers = Preconditions.checkNotNull(modelFileSuppliers);
+ }
+
+ /**
+ * Returns an immutable list of model files listed by the given model files supplier.
+ *
+ * @param modelType which type of model files to look for
+ */
+ public ImmutableList<ModelFile> listModelFiles(@ModelType.ModelTypeDef String modelType) {
+ if (modelFileSuppliers.containsKey(modelType)) {
+ return modelFileSuppliers.get(modelType).get();
+ }
+ return ImmutableList.of();
}
/**
* Returns the best model file for the given localelist, {@code null} if nothing is found.
*
- * @param localeList the required locales, use {@code null} if there is no preference.
+ * @param modelType the type of model to look up (e.g. annotator, lang_id, etc.)
+ * @param localeList an ordered list of user preferences for locales, use {@code null} if there is
+ * no preference.
*/
- public ModelFile findBestModelFile(@Nullable LocaleList localeList) {
+ @Nullable
+ public ModelFile findBestModelFile(
+ @ModelType.ModelTypeDef String modelType, @Nullable LocaleList localeList) {
final String languages =
localeList == null || localeList.isEmpty()
? LocaleList.getDefault().toLanguageTags()
@@ -70,7 +119,7 @@
final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
ModelFile bestModel = null;
- for (ModelFile model : listModelFiles()) {
+ for (ModelFile model : listModelFiles(modelType)) {
if (model.isAnyLanguageSupported(languageRangeList)) {
if (model.isPreferredTo(bestModel)) {
bestModel = model;
@@ -80,9 +129,108 @@
return bestModel;
}
+ /**
+ * Returns a {@link File} that represents the destination to download a model.
+ *
+ * <p>Each model file's name is uniquely formatted based on its unique remote URL address.
+ *
+ * <p>{@link ModelDownloadManager} needs to call this to get the right location and file name.
+ *
+ * @param modelType the type of the model image to download
+ * @param url the unique remote url of the model image
+ */
+ public File getDownloadTargetFile(@ModelType.ModelTypeDef String modelType, String url) {
+ String fileName = String.format("%s.%d.model", modelType, url.hashCode());
+ return new File(downloadModelDir, fileName);
+ }
+
+ /**
+ * Dumps the internal state for debugging.
+ *
+ * @param printWriter writer to write dumped states
+ */
+ public void dump(IndentingPrintWriter printWriter) {
+ printWriter.println("ModelFileManager:");
+ printWriter.increaseIndent();
+ for (@ModelType.ModelTypeDef String modelType : ModelType.values()) {
+ printWriter.println(modelType + " model file(s):");
+ printWriter.increaseIndent();
+ for (ModelFile modelFile : listModelFiles(modelType)) {
+ printWriter.println(modelFile.toString());
+ }
+ printWriter.decreaseIndent();
+ }
+ printWriter.decreaseIndent();
+ }
+
/** Default implementation of the model file supplier. */
- public static final class ModelFileSupplierImpl implements Supplier<ImmutableList<ModelFile>> {
- private final File updatedModelFile;
+ @VisibleForTesting
+ static final class ModelFileSupplierImpl implements Supplier<ImmutableList<ModelFile>> {
+ private static final String FACTORY_MODEL_DIR = "/etc/textclassifier/";
+
+ private static final class ModelFileInfo {
+ private final String modelNameRegex;
+ private final String configUpdaterModelPath;
+ private final Function<Integer, Integer> versionSupplier;
+ private final Function<Integer, String> supportedLocalesSupplier;
+
+ public ModelFileInfo(
+ String modelNameRegex,
+ String configUpdaterModelPath,
+ Function<Integer, Integer> versionSupplier,
+ Function<Integer, String> supportedLocalesSupplier) {
+ this.modelNameRegex = Preconditions.checkNotNull(modelNameRegex);
+ this.configUpdaterModelPath = Preconditions.checkNotNull(configUpdaterModelPath);
+ this.versionSupplier = Preconditions.checkNotNull(versionSupplier);
+ this.supportedLocalesSupplier = Preconditions.checkNotNull(supportedLocalesSupplier);
+ }
+
+ public String getModelNameRegex() {
+ return modelNameRegex;
+ }
+
+ public String getConfigUpdaterModelPath() {
+ return configUpdaterModelPath;
+ }
+
+ public Function<Integer, Integer> getVersionSupplier() {
+ return versionSupplier;
+ }
+
+ public Function<Integer, String> getSupportedLocalesSupplier() {
+ return supportedLocalesSupplier;
+ }
+ }
+
+ private static final ImmutableMap<String, ModelFileInfo> MODEL_FILE_INFO_MAP =
+ ImmutableMap.<String, ModelFileInfo>builder()
+ .put(
+ ModelType.ANNOTATOR,
+ new ModelFileInfo(
+ "(annotator|textclassifier)\\.(.*)\\.model",
+ "/data/misc/textclassifier/textclassifier.model",
+ AnnotatorModel::getVersion,
+ AnnotatorModel::getLocales))
+ .put(
+ ModelType.LANG_ID,
+ new ModelFileInfo(
+ "lang_id.model",
+ "/data/misc/textclassifier/lang_id.model",
+ LangIdModel::getVersion,
+ fd -> ModelFile.LANGUAGE_INDEPENDENT))
+ .put(
+ ModelType.ACTIONS_SUGGESTIONS,
+ new ModelFileInfo(
+ "actions_suggestions\\.(.*)\\.model",
+ "/data/misc/textclassifier/actions_suggestions.model",
+ ActionsSuggestionsModel::getVersion,
+ ActionsSuggestionsModel::getLocales))
+ .build();
+
+ private final TextClassifierSettings settings;
+ @ModelType.ModelTypeDef private final String modelType;
+ private final File configUpdaterModelFile;
+ private final File downloaderModelDir;
private final File factoryModelDir;
private final Pattern modelFilenamePattern;
private final Function<Integer, Integer> versionSupplier;
@@ -93,14 +241,36 @@
private ImmutableList<ModelFile> factoryModels;
public ModelFileSupplierImpl(
+ TextClassifierSettings settings,
+ @ModelType.ModelTypeDef String modelType,
+ File downloaderModelDir) {
+ this(
+ settings,
+ modelType,
+ new File(FACTORY_MODEL_DIR),
+ MODEL_FILE_INFO_MAP.get(modelType).getModelNameRegex(),
+ new File(MODEL_FILE_INFO_MAP.get(modelType).getConfigUpdaterModelPath()),
+ downloaderModelDir,
+ MODEL_FILE_INFO_MAP.get(modelType).getVersionSupplier(),
+ MODEL_FILE_INFO_MAP.get(modelType).getSupportedLocalesSupplier());
+ }
+
+ @VisibleForTesting
+ ModelFileSupplierImpl(
+ TextClassifierSettings settings,
+ @ModelType.ModelTypeDef String modelType,
File factoryModelDir,
- String factoryModelFileNameRegex,
- File updatedModelFile,
+ String modelFileNameRegex,
+ File configUpdaterModelFile,
+ File downloaderModelDir,
Function<Integer, Integer> versionSupplier,
Function<Integer, String> supportedLocalesSupplier) {
- this.updatedModelFile = Preconditions.checkNotNull(updatedModelFile);
+ this.settings = Preconditions.checkNotNull(settings);
+ this.modelType = Preconditions.checkNotNull(modelType);
this.factoryModelDir = Preconditions.checkNotNull(factoryModelDir);
- modelFilenamePattern = Pattern.compile(Preconditions.checkNotNull(factoryModelFileNameRegex));
+ this.modelFilenamePattern = Pattern.compile(Preconditions.checkNotNull(modelFileNameRegex));
+ this.configUpdaterModelFile = Preconditions.checkNotNull(configUpdaterModelFile);
+ this.downloaderModelDir = Preconditions.checkNotNull(downloaderModelDir);
this.versionSupplier = Preconditions.checkNotNull(versionSupplier);
this.supportedLocalesSupplier = Preconditions.checkNotNull(supportedLocalesSupplier);
}
@@ -108,9 +278,12 @@
@Override
public ImmutableList<ModelFile> get() {
final List<ModelFile> modelFiles = new ArrayList<>();
- // The update model has the highest precedence.
- if (updatedModelFile.exists()) {
- final ModelFile updatedModel = createModelFile(updatedModelFile);
+ // The downloader and config updater models have higher precedence.
+ if (downloaderModelDir.exists() && settings.isModelDownloadManagerEnabled()) {
+ modelFiles.addAll(getMatchedModelFiles(downloaderModelDir));
+ }
+ if (configUpdaterModelFile.exists()) {
+ final ModelFile updatedModel = createModelFile(configUpdaterModelFile);
if (updatedModel != null) {
modelFiles.add(updatedModel);
}
@@ -118,28 +291,28 @@
// Factory models should never have overlapping locales, so the order doesn't matter.
synchronized (lock) {
if (factoryModels == null) {
- factoryModels = getFactoryModels();
+ factoryModels = getMatchedModelFiles(factoryModelDir);
}
modelFiles.addAll(factoryModels);
}
return ImmutableList.copyOf(modelFiles);
}
- private ImmutableList<ModelFile> getFactoryModels() {
- List<ModelFile> factoryModelFiles = new ArrayList<>();
- if (factoryModelDir.exists() && factoryModelDir.isDirectory()) {
- final File[] files = factoryModelDir.listFiles();
+ private ImmutableList<ModelFile> getMatchedModelFiles(File parentDir) {
+ ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
+ if (parentDir.exists() && parentDir.isDirectory()) {
+ final File[] files = parentDir.listFiles();
for (File file : files) {
final Matcher matcher = modelFilenamePattern.matcher(file.getName());
if (matcher.matches() && file.isFile()) {
final ModelFile model = createModelFile(file);
if (model != null) {
- factoryModelFiles.add(model);
+ modelFilesBuilder.add(model);
}
}
}
}
- return ImmutableList.copyOf(factoryModelFiles);
+ return modelFilesBuilder.build();
}
/** Returns null if the path did not point to a compatible model. */
@@ -166,6 +339,7 @@
supportedLocales.add(Locale.forLanguageTag(langTag));
}
return new ModelFile(
+ modelType,
file,
version,
supportedLocales,
@@ -196,6 +370,7 @@
public static final class ModelFile {
public static final String LANGUAGE_INDEPENDENT = "*";
+ @ModelType.ModelTypeDef private final String modelType;
private final File file;
private final int version;
private final List<Locale> supportedLocales;
@@ -203,11 +378,13 @@
private final boolean languageIndependent;
public ModelFile(
+ @ModelType.ModelTypeDef String modelType,
File file,
int version,
List<Locale> supportedLocales,
String supportedLocalesStr,
boolean languageIndependent) {
+ this.modelType = Preconditions.checkNotNull(modelType);
this.file = Preconditions.checkNotNull(file);
this.version = version;
this.supportedLocales = Preconditions.checkNotNull(supportedLocales);
@@ -215,6 +392,12 @@
this.languageIndependent = languageIndependent;
}
+ /** Returns the type of this model, defined in {@link ModelType}. */
+ @ModelType.ModelTypeDef
+ public String getModelType() {
+ return modelType;
+ }
+
/** Returns the absolute path to the model file. */
public String getPath() {
return file.getAbsolutePath();
@@ -236,16 +419,6 @@
return languageIndependent || Locale.lookup(languageRanges, supportedLocales) != null;
}
- /** Returns an immutable lists of supported locales. */
- public List<Locale> getSupportedLocales() {
- return Collections.unmodifiableList(supportedLocales);
- }
-
- /** Returns the original supported locals string read from the model file. */
- public String getSupportedLocalesStr() {
- return supportedLocalesStr;
- }
-
/** Returns if this model file is preferred to the given one. */
public boolean isPreferredTo(@Nullable ModelFile model) {
// A model is preferred to no model.
@@ -294,7 +467,8 @@
public String toString() {
return String.format(
Locale.US,
- "ModelFile { path=%s name=%s version=%d locales=%s }",
+ "ModelFile { type=%s path=%s name=%s version=%d locales=%s }",
+ modelType,
getPath(),
getName(),
version,
@@ -307,5 +481,25 @@
.map(modelFile -> modelFile.transform(ModelFile::toModelInfo))
.collect(Collectors.collectingAndThen(Collectors.toList(), ImmutableList::copyOf));
}
+
+ /** Effectively an enum class to represent types of models. */
+ public static final class ModelType {
+ @Retention(RetentionPolicy.SOURCE)
+ @StringDef({ANNOTATOR, LANG_ID, ACTIONS_SUGGESTIONS})
+ public @interface ModelTypeDef {}
+
+ public static final String ANNOTATOR = "annotator";
+ public static final String LANG_ID = "lang_id";
+ public static final String ACTIONS_SUGGESTIONS = "actions_suggestions";
+
+ public static final ImmutableList<String> VALUES =
+ ImmutableList.of(ANNOTATOR, LANG_ID, ACTIONS_SUGGESTIONS);
+
+ public static ImmutableList<String> values() {
+ return VALUES;
+ }
+
+ private ModelType() {}
+ }
}
}
diff --git a/java/src/com/android/textclassifier/TextClassifierImpl.java b/java/src/com/android/textclassifier/TextClassifierImpl.java
index 429ed6b..7f297f6 100644
--- a/java/src/com/android/textclassifier/TextClassifierImpl.java
+++ b/java/src/com/android/textclassifier/TextClassifierImpl.java
@@ -43,6 +43,7 @@
import androidx.annotation.WorkerThread;
import androidx.core.util.Pair;
import com.android.textclassifier.ModelFileManager.ModelFile;
+import com.android.textclassifier.ModelFileManager.ModelFile.ModelType;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.intent.LabeledIntent;
import com.android.textclassifier.common.intent.TemplateIntentFactory;
@@ -84,25 +85,8 @@
private static final String TAG = "TextClassifierImpl";
- private static final File FACTORY_MODEL_DIR = new File("/etc/textclassifier/");
- // Annotator
- private static final String ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX =
- "textclassifier\\.(.*)\\.model";
- private static final File ANNOTATOR_UPDATED_MODEL_FILE =
- new File("/data/misc/textclassifier/textclassifier.model");
-
- // LangIdModel
- private static final String LANG_ID_FACTORY_MODEL_FILENAME_REGEX = "lang_id.model";
- private static final File UPDATED_LANG_ID_MODEL_FILE =
- new File("/data/misc/textclassifier/lang_id.model");
-
- // Actions
- private static final String ACTIONS_FACTORY_MODEL_FILENAME_REGEX =
- "actions_suggestions\\.(.*)\\.model";
- private static final File UPDATED_ACTIONS_MODEL =
- new File("/data/misc/textclassifier/actions_suggestions.model");
-
private final Context context;
+ private final ModelFileManager modelFileManager;
private final TextClassifier fallback;
private final GenerateLinksLogger generateLinksLogger;
@@ -131,46 +115,25 @@
private final TextClassifierSettings settings;
- private final ModelFileManager annotatorModelFileManager;
- private final ModelFileManager langIdModelFileManager;
- private final ModelFileManager actionsModelFileManager;
private final TemplateIntentFactory templateIntentFactory;
- TextClassifierImpl(Context context, TextClassifierSettings settings, TextClassifier fallback) {
+ TextClassifierImpl(
+ Context context,
+ TextClassifierSettings settings,
+ ModelFileManager modelFileManager,
+ TextClassifier fallback) {
this.context = Preconditions.checkNotNull(context);
- this.fallback = Preconditions.checkNotNull(fallback);
this.settings = Preconditions.checkNotNull(settings);
- generateLinksLogger = new GenerateLinksLogger(this.settings.getGenerateLinksLogSampleRate());
- annotatorModelFileManager =
- new ModelFileManager(
- new ModelFileManager.ModelFileSupplierImpl(
- FACTORY_MODEL_DIR,
- ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX,
- ANNOTATOR_UPDATED_MODEL_FILE,
- AnnotatorModel::getVersion,
- AnnotatorModel::getLocales));
- langIdModelFileManager =
- new ModelFileManager(
- new ModelFileManager.ModelFileSupplierImpl(
- FACTORY_MODEL_DIR,
- LANG_ID_FACTORY_MODEL_FILENAME_REGEX,
- UPDATED_LANG_ID_MODEL_FILE,
- LangIdModel::getVersion,
- fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT));
- actionsModelFileManager =
- new ModelFileManager(
- new ModelFileManager.ModelFileSupplierImpl(
- FACTORY_MODEL_DIR,
- ACTIONS_FACTORY_MODEL_FILENAME_REGEX,
- UPDATED_ACTIONS_MODEL,
- ActionsSuggestionsModel::getVersion,
- ActionsSuggestionsModel::getLocales));
+ this.modelFileManager = Preconditions.checkNotNull(modelFileManager);
+ this.fallback = Preconditions.checkNotNull(fallback);
+ generateLinksLogger = new GenerateLinksLogger(this.settings.getGenerateLinksLogSampleRate());
templateIntentFactory = new TemplateIntentFactory();
}
- TextClassifierImpl(Context context, TextClassifierSettings settings) {
- this(context, settings, TextClassifier.NO_OP);
+ TextClassifierImpl(
+ Context context, TextClassifierSettings settings, ModelFileManager modelFileManager) {
+ this(context, settings, modelFileManager, TextClassifier.NO_OP);
}
@WorkerThread
@@ -522,7 +485,7 @@
synchronized (lock) {
localeList = localeList == null ? LocaleList.getDefault() : localeList;
final ModelFileManager.ModelFile bestModel =
- annotatorModelFileManager.findBestModelFile(localeList);
+ modelFileManager.findBestModelFile(ModelType.ANNOTATOR, localeList);
if (bestModel == null) {
throw new FileNotFoundException("No annotator model for " + localeList.toLanguageTags());
}
@@ -553,7 +516,8 @@
private Optional<LangIdModel> getLangIdImpl() {
synchronized (lock) {
- final ModelFileManager.ModelFile bestModel = langIdModelFileManager.findBestModelFile(null);
+ final ModelFileManager.ModelFile bestModel =
+ modelFileManager.findBestModelFile(ModelType.LANG_ID, /* localeList= */ null);
if (bestModel == null) {
return Optional.absent();
}
@@ -586,7 +550,8 @@
synchronized (lock) {
// TODO: Use LangID to determine the locale we should use here?
final ModelFileManager.ModelFile bestModel =
- actionsModelFileManager.findBestModelFile(LocaleList.getDefault());
+ modelFileManager.findBestModelFile(
+ ModelType.ACTIONS_SUGGESTIONS, LocaleList.getDefault());
if (bestModel == null) {
return null;
}
@@ -765,27 +730,12 @@
void dump(IndentingPrintWriter printWriter) {
synchronized (lock) {
printWriter.println("TextClassifierImpl:");
+
printWriter.increaseIndent();
- printWriter.println("Annotator model file(s):");
- printWriter.increaseIndent();
- for (ModelFileManager.ModelFile modelFile : annotatorModelFileManager.listModelFiles()) {
- printWriter.println(modelFile.toString());
- }
- printWriter.decreaseIndent();
- printWriter.println("LangID model file(s):");
- printWriter.increaseIndent();
- for (ModelFileManager.ModelFile modelFile : langIdModelFileManager.listModelFiles()) {
- printWriter.println(modelFile.toString());
- }
- printWriter.decreaseIndent();
- printWriter.println("Actions model file(s):");
- printWriter.increaseIndent();
- for (ModelFileManager.ModelFile modelFile : actionsModelFileManager.listModelFiles()) {
- printWriter.println(modelFile.toString());
- }
- printWriter.decreaseIndent();
+ modelFileManager.dump(printWriter);
printWriter.printPair("mFallback", fallback);
printWriter.decreaseIndent();
+
printWriter.println();
settings.dump(printWriter);
}
diff --git a/java/src/com/android/textclassifier/TextClassifierSettings.java b/java/src/com/android/textclassifier/TextClassifierSettings.java
index 3decd38..005bd7c 100644
--- a/java/src/com/android/textclassifier/TextClassifierSettings.java
+++ b/java/src/com/android/textclassifier/TextClassifierSettings.java
@@ -19,6 +19,8 @@
import android.provider.DeviceConfig;
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.TextClassifier;
+import androidx.annotation.NonNull;
+import com.android.textclassifier.ModelFileManager.ModelFile.ModelType;
import com.android.textclassifier.utils.IndentingPrintWriter;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Splitter;
@@ -42,6 +44,8 @@
* @see android.provider.DeviceConfig#NAMESPACE_TEXTCLASSIFIER
*/
public final class TextClassifierSettings {
+ public static final String NAMESPACE = DeviceConfig.NAMESPACE_TEXTCLASSIFIER;
+
private static final String DELIMITER = ":";
/** Whether the user language profile feature is enabled. */
@@ -101,6 +105,27 @@
*/
private static final String DETECT_LANGUAGES_FROM_TEXT_ENABLED =
"detect_languages_from_text_enabled";
+
+ /** Whether to enable model downloading with ModelDownloadManager. */
+ @VisibleForTesting
+ static final String MODEL_DOWNLOAD_MANAGER_ENABLED = "model_download_manager_enabled";
+ /**
+ * Flag holding the URL prefix used to download annotator models, e.g.
+ * https://www.gstatic.com/android/.
+ */
+ @VisibleForTesting static final String ANNOTATOR_URL_PREFIX = "annotator_url_prefix";
+
+ /** Flag holding the URL prefix used to download LangID models. */
+ @VisibleForTesting static final String LANG_ID_URL_PREFIX = "lang_id_url_prefix";
+
+ /** Flag holding the URL prefix used to download actions-suggestions models. */
+ @VisibleForTesting
+ static final String ACTIONS_SUGGESTIONS_URL_PREFIX = "actions_suggestions_url_prefix";
+ /** Flag holding the URL suffix of the primary annotator model, e.g. q/711/en.fb. */
+ @VisibleForTesting
+ static final String PRIMARY_ANNOTATOR_URL_SUFFIX = "primary_annotator_url_suffix";
+
+ /** Flag holding the URL suffix of the primary LangID model. */
+ @VisibleForTesting static final String PRIMARY_LANG_ID_URL_SUFFIX = "primary_lang_id_url_suffix";
+
+ /** Flag holding the URL suffix of the primary actions-suggestions model. */
+ @VisibleForTesting
+ static final String PRIMARY_ACTIONS_SUGGESTIONS_URL_SUFFIX =
+ "primary_actions_suggestions_url_suffix";
+
/**
* A colon(:) separated string that specifies the configuration to use when including surrounding
* context text in language detection queries.
@@ -164,34 +189,96 @@
private static final boolean TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT = true;
private static final boolean TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT = true;
private static final boolean DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT = true;
+ // Model downloading stays off unless it is enabled through DeviceConfig.
+ private static final boolean MODEL_DOWNLOAD_MANAGER_ENABLED_DEFAULT = false;
+ // Default download locations per model type. NOTE(review): a model URL is presumably
+ // prefix + primary suffix (e.g. .../text_classifier/ + q/711/en.fb); confirm in the downloader.
+ private static final String ANNOTATOR_URL_PREFIX_DEFAULT =
+ "https://www.gstatic.com/android/text_classifier/";
+ private static final String LANG_ID_URL_PREFIX_DEFAULT =
+ "https://www.gstatic.com/android/text_classifier/langid/";
+ private static final String ACTIONS_SUGGESTIONS_URL_PREFIX_DEFAULT =
+ "https://www.gstatic.com/android/text_classifier/actions/";
+ // Empty by default, i.e. presumably no primary model is configured until DeviceConfig sets one.
+ private static final String PRIMARY_ANNOTATOR_URL_SUFFIX_DEFAULT = "";
+ private static final String PRIMARY_LANG_ID_URL_SUFFIX_DEFAULT = "";
+ private static final String PRIMARY_ACTIONS_SUGGESTIONS_URL_SUFFIX_DEFAULT = "";
private static final float[] LANG_ID_CONTEXT_SETTINGS_DEFAULT = new float[] {20f, 1.0f, 0.4f};
+ /**
+ * Flag-lookup abstraction over {@link DeviceConfig}, so that tests can inject flag values
+ * through the {@code TextClassifierSettings(IDeviceConfig)} constructor.
+ *
+ * <p>Every default method ignores {@code namespace} and {@code name} and returns {@code
+ * defaultValue}. An implementation only needs to override the lookups it cares about.
+ */
+ @VisibleForTesting
+ interface IDeviceConfig {
+ default int getInt(@NonNull String namespace, @NonNull String name, int defaultValue) {
+ return defaultValue;
+ }
+
+ default float getFloat(@NonNull String namespace, @NonNull String name, float defaultValue) {
+ return defaultValue;
+ }
+
+ /** Returns the flag value; may be {@code null} when {@code defaultValue} is {@code null}. */
+ @Nullable
+ default String getString(
+ @NonNull String namespace, @NonNull String name, @Nullable String defaultValue) {
+ return defaultValue;
+ }
+
+ default boolean getBoolean(
+ @NonNull String namespace, @NonNull String name, boolean defaultValue) {
+ return defaultValue;
+ }
+ }
+
+ /** Production {@link IDeviceConfig} that forwards every lookup to {@link DeviceConfig}. */
+ private static final IDeviceConfig DEFAULT_DEVICE_CONFIG =
+ new IDeviceConfig() {
+ @Override
+ public int getInt(@NonNull String namespace, @NonNull String name, int defaultValue) {
+ return DeviceConfig.getInt(namespace, name, defaultValue);
+ }
+
+ @Override
+ public float getFloat(
+ @NonNull String namespace, @NonNull String name, float defaultValue) {
+ return DeviceConfig.getFloat(namespace, name, defaultValue);
+ }
+
+ // defaultValue must be @Nullable, as in the interface: getDeviceConfigStringList and
+ // getDeviceConfigFloatArray pass null so they can tell when a flag is unset.
+ @Override
+ @Nullable
+ public String getString(
+ @NonNull String namespace, @NonNull String name, @Nullable String defaultValue) {
+ return DeviceConfig.getString(namespace, name, defaultValue);
+ }
+
+ @Override
+ public boolean getBoolean(
+ @NonNull String namespace, @NonNull String name, boolean defaultValue) {
+ return DeviceConfig.getBoolean(namespace, name, defaultValue);
+ }
+ };
+
+ /** Source of every flag value read by the getters of this class. */
+ private final IDeviceConfig deviceConfig;
+
+ /** Creates settings whose flags are read from {@link DeviceConfig}. */
+ public TextClassifierSettings() {
+ this(DEFAULT_DEVICE_CONFIG);
+ }
+
+ /**
+ * Creates settings whose flags are read from {@code deviceConfig}, e.g. a test double.
+ * NOTE(review): {@code deviceConfig} is stored without a null check.
+ */
+ @VisibleForTesting
+ TextClassifierSettings(IDeviceConfig deviceConfig) {
+ this.deviceConfig = deviceConfig;
+ }
+
+ /** Returns the {@code SUGGEST_SELECTION_MAX_RANGE_LENGTH} flag, or its default when unset. */
public int getSuggestSelectionMaxRangeLength() {
- return DeviceConfig.getInt(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- SUGGEST_SELECTION_MAX_RANGE_LENGTH,
- SUGGEST_SELECTION_MAX_RANGE_LENGTH_DEFAULT);
+ return deviceConfig.getInt(
+ NAMESPACE, SUGGEST_SELECTION_MAX_RANGE_LENGTH, SUGGEST_SELECTION_MAX_RANGE_LENGTH_DEFAULT);
}
+ /** Returns the {@code CLASSIFY_TEXT_MAX_RANGE_LENGTH} flag, or its default when unset. */
public int getClassifyTextMaxRangeLength() {
- return DeviceConfig.getInt(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- CLASSIFY_TEXT_MAX_RANGE_LENGTH,
- CLASSIFY_TEXT_MAX_RANGE_LENGTH_DEFAULT);
+ return deviceConfig.getInt(
+ NAMESPACE, CLASSIFY_TEXT_MAX_RANGE_LENGTH, CLASSIFY_TEXT_MAX_RANGE_LENGTH_DEFAULT);
}
+ /** Returns the {@code GENERATE_LINKS_MAX_TEXT_LENGTH} flag, or its default when unset. */
public int getGenerateLinksMaxTextLength() {
- return DeviceConfig.getInt(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- GENERATE_LINKS_MAX_TEXT_LENGTH,
- GENERATE_LINKS_MAX_TEXT_LENGTH_DEFAULT);
+ return deviceConfig.getInt(
+ NAMESPACE, GENERATE_LINKS_MAX_TEXT_LENGTH, GENERATE_LINKS_MAX_TEXT_LENGTH_DEFAULT);
}
+ /** Returns the {@code GENERATE_LINKS_LOG_SAMPLE_RATE} flag, or its default when unset. */
public int getGenerateLinksLogSampleRate() {
- return DeviceConfig.getInt(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- GENERATE_LINKS_LOG_SAMPLE_RATE,
- GENERATE_LINKS_LOG_SAMPLE_RATE_DEFAULT);
+ return deviceConfig.getInt(
+ NAMESPACE, GENERATE_LINKS_LOG_SAMPLE_RATE, GENERATE_LINKS_LOG_SAMPLE_RATE_DEFAULT);
}
public List<String> getEntityListDefault() {
@@ -217,51 +304,79 @@
}
+ /** Returns the {@code LANG_ID_THRESHOLD_OVERRIDE} flag, or its default when unset. */
public float getLangIdThresholdOverride() {
- return DeviceConfig.getFloat(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- LANG_ID_THRESHOLD_OVERRIDE,
- LANG_ID_THRESHOLD_OVERRIDE_DEFAULT);
+ return deviceConfig.getFloat(
+ NAMESPACE, LANG_ID_THRESHOLD_OVERRIDE, LANG_ID_THRESHOLD_OVERRIDE_DEFAULT);
}
+ /** Returns the {@code TRANSLATE_ACTION_THRESHOLD} flag, or its default when unset. */
public float getTranslateActionThreshold() {
- return DeviceConfig.getFloat(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- TRANSLATE_ACTION_THRESHOLD,
- TRANSLATE_ACTION_THRESHOLD_DEFAULT);
+ return deviceConfig.getFloat(
+ NAMESPACE, TRANSLATE_ACTION_THRESHOLD, TRANSLATE_ACTION_THRESHOLD_DEFAULT);
}
+ /** Returns whether the user language profile feature is enabled. */
public boolean isUserLanguageProfileEnabled() {
- return DeviceConfig.getBoolean(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- USER_LANGUAGE_PROFILE_ENABLED,
- USER_LANGUAGE_PROFILE_ENABLED_DEFAULT);
+ return deviceConfig.getBoolean(
+ NAMESPACE, USER_LANGUAGE_PROFILE_ENABLED, USER_LANGUAGE_PROFILE_ENABLED_DEFAULT);
}
+ /** Returns the {@code TEMPLATE_INTENT_FACTORY_ENABLED} flag (default {@code true}). */
public boolean isTemplateIntentFactoryEnabled() {
- return DeviceConfig.getBoolean(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- TEMPLATE_INTENT_FACTORY_ENABLED,
- TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT);
+ return deviceConfig.getBoolean(
+ NAMESPACE, TEMPLATE_INTENT_FACTORY_ENABLED, TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT);
}
+ /** Returns the {@code TRANSLATE_IN_CLASSIFICATION_ENABLED} flag (default {@code true}). */
public boolean isTranslateInClassificationEnabled() {
- return DeviceConfig.getBoolean(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ return deviceConfig.getBoolean(
+ NAMESPACE,
TRANSLATE_IN_CLASSIFICATION_ENABLED,
TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT);
}
+ /** Returns the {@code DETECT_LANGUAGES_FROM_TEXT_ENABLED} flag (default {@code true}). */
public boolean isDetectLanguagesFromTextEnabled() {
- return DeviceConfig.getBoolean(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- DETECT_LANGUAGES_FROM_TEXT_ENABLED,
- DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT);
+ return deviceConfig.getBoolean(
+ NAMESPACE, DETECT_LANGUAGES_FROM_TEXT_ENABLED, DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT);
}
+ /**
+ * Returns the colon-separated {@code LANG_ID_CONTEXT_SETTINGS} flag parsed into floats, or
+ * {@code LANG_ID_CONTEXT_SETTINGS_DEFAULT} ({20, 1.0, 0.4}) when it is unset.
+ * NOTE(review): fallback on malformed values depends on parse(); confirm.
+ */
public float[] getLangIdContextSettings() {
return getDeviceConfigFloatArray(LANG_ID_CONTEXT_SETTINGS, LANG_ID_CONTEXT_SETTINGS_DEFAULT);
}
+ /** Returns whether models may be downloaded with ModelDownloadManager; {@code false} by default. */
+ public boolean isModelDownloadManagerEnabled() {
+ return deviceConfig.getBoolean(
+ NAMESPACE, MODEL_DOWNLOAD_MANAGER_ENABLED, MODEL_DOWNLOAD_MANAGER_ENABLED_DEFAULT);
+ }
+
+ /**
+ * Returns the URL prefix from which {@code modelType} models are downloaded. For an unknown
+ * model type, returns an empty string without consulting the device config.
+ */
+ public String getModelURLPrefix(@ModelType.ModelTypeDef String modelType) {
+ final String flagName;
+ final String fallbackValue;
+ switch (modelType) {
+ case ModelType.ANNOTATOR:
+ flagName = ANNOTATOR_URL_PREFIX;
+ fallbackValue = ANNOTATOR_URL_PREFIX_DEFAULT;
+ break;
+ case ModelType.LANG_ID:
+ flagName = LANG_ID_URL_PREFIX;
+ fallbackValue = LANG_ID_URL_PREFIX_DEFAULT;
+ break;
+ case ModelType.ACTIONS_SUGGESTIONS:
+ flagName = ACTIONS_SUGGESTIONS_URL_PREFIX;
+ fallbackValue = ACTIONS_SUGGESTIONS_URL_PREFIX_DEFAULT;
+ break;
+ default:
+ return "";
+ }
+ return deviceConfig.getString(NAMESPACE, flagName, fallbackValue);
+ }
+
+ /**
+ * Returns the URL suffix of the primary {@code modelType} model (empty by default). For an
+ * unknown model type, returns an empty string without consulting the device config.
+ */
+ public String getPrimaryModelURLSuffix(@ModelType.ModelTypeDef String modelType) {
+ final String flagName;
+ final String fallbackValue;
+ switch (modelType) {
+ case ModelType.ANNOTATOR:
+ flagName = PRIMARY_ANNOTATOR_URL_SUFFIX;
+ fallbackValue = PRIMARY_ANNOTATOR_URL_SUFFIX_DEFAULT;
+ break;
+ case ModelType.LANG_ID:
+ flagName = PRIMARY_LANG_ID_URL_SUFFIX;
+ fallbackValue = PRIMARY_LANG_ID_URL_SUFFIX_DEFAULT;
+ break;
+ case ModelType.ACTIONS_SUGGESTIONS:
+ flagName = PRIMARY_ACTIONS_SUGGESTIONS_URL_SUFFIX;
+ fallbackValue = PRIMARY_ACTIONS_SUGGESTIONS_URL_SUFFIX_DEFAULT;
+ break;
+ default:
+ return "";
+ }
+ return deviceConfig.getString(NAMESPACE, flagName, fallbackValue);
+ }
+
void dump(IndentingPrintWriter pw) {
pw.println("TextClassifierSettings:");
pw.increaseIndent();
@@ -282,17 +397,26 @@
pw.printPair("user_language_profile_enabled", isUserLanguageProfileEnabled());
pw.printPair("template_intent_factory_enabled", isTemplateIntentFactoryEnabled());
pw.printPair("translate_in_classification_enabled", isTranslateInClassificationEnabled());
+ pw.printPair("model_download_manager_enabled", isModelDownloadManagerEnabled());
+ pw.printPair("annotator_url_prefix", getModelURLPrefix(ModelType.ANNOTATOR));
+ pw.printPair("lang_id_url_prefix", getModelURLPrefix(ModelType.LANG_ID));
+ pw.printPair(
+ "actions_suggestions_url_prefix", getModelURLPrefix(ModelType.ACTIONS_SUGGESTIONS));
+ pw.printPair("primary_annotator_url_suffix", getPrimaryModelURLSuffix(ModelType.ANNOTATOR));
+ pw.printPair("primary_lang_id_url_suffix", getPrimaryModelURLSuffix(ModelType.LANG_ID));
+ pw.printPair(
+ "primary_actions_suggestions_url_suffix",
+ getPrimaryModelURLSuffix(ModelType.ACTIONS_SUGGESTIONS));
+ // Exactly one decreaseIndent() here, balancing the increaseIndent() after the header.
pw.decreaseIndent();
}
- private static List<String> getDeviceConfigStringList(String key, List<String> defaultValue) {
- return parse(
- DeviceConfig.getString(DeviceConfig.NAMESPACE_TEXTCLASSIFIER, key, null), defaultValue);
+ /**
+ * Reads flag {@code key} as a string and parses it into a list. {@code null} is passed as the
+ * lookup default so that parse() can fall back to {@code defaultValue} when the flag is unset
+ * (presumably; confirm in parse()).
+ */
+ private List<String> getDeviceConfigStringList(String key, List<String> defaultValue) {
+ return parse(deviceConfig.getString(NAMESPACE, key, null), defaultValue);
}
- private static float[] getDeviceConfigFloatArray(String key, float[] defaultValue) {
- return parse(
- DeviceConfig.getString(DeviceConfig.NAMESPACE_TEXTCLASSIFIER, key, null), defaultValue);
+ /** Float-array counterpart of getDeviceConfigStringList; uses the same null-default lookup. */
+ private float[] getDeviceConfigFloatArray(String key, float[] defaultValue) {
+ return parse(deviceConfig.getString(NAMESPACE, key, null), defaultValue);
}
private static List<String> parse(@Nullable String listStr, List<String> defaultValue) {
diff --git a/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java b/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java
index c132749..45785f1 100644
--- a/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java
+++ b/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java
@@ -16,8 +16,6 @@
package com.android.textclassifier.common.statsd;
-import android.util.StatsEvent;
-import android.util.StatsLog;
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextLinks;
import androidx.collection.ArrayMap;
@@ -142,24 +140,20 @@
Optional<ModelInfo> langIdModel) {
String annotatorModelName = annotatorModel.transform(ModelInfo::toModelName).or("");
String langIdModelName = langIdModel.transform(ModelInfo::toModelName).or("");
- StatsEvent statsEvent =
- StatsEvent.newBuilder()
- .setAtomId(TextClassifierEventLogger.TEXT_LINKIFY_EVENT_ATOM_ID)
- .writeString(callId)
- .writeInt(TextClassifierEvent.TYPE_LINKS_GENERATED)
- .writeString(annotatorModelName)
- .writeInt(TextClassifierEventLogger.WidgetType.WIDGET_TYPE_UNKNOWN)
- .writeInt(/* eventIndex */ 0)
- .writeString(entityType)
- .writeInt(stats.numLinks)
- .writeInt(stats.numLinksTextLength)
- .writeInt(text.length())
- .writeLong(latencyMs)
- .writeString(callingPackageName)
- .writeString(langIdModelName)
- .usePooledBuffer()
- .build();
- StatsLog.write(statsEvent);
+ TextClassifierStatsLog.write(
+ TextClassifierEventLogger.TEXT_LINKIFY_EVENT_ATOM_ID,
+ callId,
+ TextClassifierEvent.TYPE_LINKS_GENERATED,
+ annotatorModelName,
+ TextClassifierEventLogger.WidgetType.WIDGET_TYPE_UNKNOWN,
+ /* eventIndex */ 0,
+ entityType,
+ stats.numLinks,
+ stats.numLinksTextLength,
+ text.length(),
+ latencyMs,
+ callingPackageName,
+ langIdModelName);
if (TcLog.ENABLE_FULL_LOGGING) {
TcLog.v(
diff --git a/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java b/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java
index 41f546c..307be6b 100644
--- a/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java
+++ b/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java
@@ -19,8 +19,6 @@
import static com.google.common.base.Charsets.UTF_8;
import static com.google.common.base.Strings.nullToEmpty;
-import android.util.StatsEvent;
-import android.util.StatsLog;
import android.view.textclassifier.TextClassifier;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.logging.ResultIdUtils;
@@ -69,24 +67,20 @@
@Nullable TextClassificationSessionId sessionId,
TextClassifierEvent.TextSelectionEvent event) {
ImmutableList<String> modelNames = getModelNames(event);
- StatsEvent statsEvent =
- StatsEvent.newBuilder()
- .setAtomId(TEXT_SELECTION_EVENT_ATOM_ID)
- .writeString(sessionId == null ? null : sessionId.getValue())
- .writeInt(getEventType(event))
- .writeString(getItemAt(modelNames, /* index= */ 0, /* defaultValue= */ null))
- .writeInt(getWidgetType(event))
- .writeInt(event.getEventIndex())
- .writeString(getItemAt(event.getEntityTypes(), /* index= */ 0))
- .writeInt(event.getRelativeWordStartIndex())
- .writeInt(event.getRelativeWordEndIndex())
- .writeInt(event.getRelativeSuggestedWordStartIndex())
- .writeInt(event.getRelativeSuggestedWordEndIndex())
- .writeString(getPackageName(event))
- .writeString(getItemAt(modelNames, /* index= */ 1, /* defaultValue= */ null))
- .usePooledBuffer()
- .build();
- StatsLog.write(statsEvent);
+ TextClassifierStatsLog.write(
+ TEXT_SELECTION_EVENT_ATOM_ID,
+ sessionId == null ? null : sessionId.getValue(),
+ getEventType(event),
+ getItemAt(modelNames, /* index= */ 0, /* defaultValue= */ null),
+ getWidgetType(event),
+ event.getEventIndex(),
+ getItemAt(event.getEntityTypes(), /* index= */ 0),
+ event.getRelativeWordStartIndex(),
+ event.getRelativeWordEndIndex(),
+ event.getRelativeSuggestedWordStartIndex(),
+ event.getRelativeSuggestedWordEndIndex(),
+ getPackageName(event),
+ getItemAt(modelNames, /* index= */ 1, /* defaultValue= */ null));
}
private static int getEventType(TextClassifierEvent.TextSelectionEvent event) {
@@ -103,24 +97,20 @@
private static void logTextLinkifyEvent(
TextClassificationSessionId sessionId, TextClassifierEvent.TextLinkifyEvent event) {
ImmutableList<String> modelNames = getModelNames(event);
- StatsEvent statsEvent =
- StatsEvent.newBuilder()
- .setAtomId(TEXT_LINKIFY_EVENT_ATOM_ID)
- .writeString(sessionId == null ? null : sessionId.getValue())
- .writeInt(event.getEventType())
- .writeString(getItemAt(modelNames, /* index= */ 0, /* defaultValue= */ null))
- .writeInt(getWidgetType(event))
- .writeInt(event.getEventIndex())
- .writeString(getItemAt(event.getEntityTypes(), /* index= */ 0))
- .writeInt(/* numOfLinks */ 0)
- .writeInt(/* linkedTextLength */ 0)
- .writeInt(/* textLength */ 0)
- .writeLong(/* latencyInMillis */ 0L)
- .writeString(getPackageName(event))
- .writeString(getItemAt(modelNames, /* index= */ 1, /* defaultValue= */ null))
- .usePooledBuffer()
- .build();
- StatsLog.write(statsEvent);
+ TextClassifierStatsLog.write(
+ TEXT_LINKIFY_EVENT_ATOM_ID,
+ sessionId == null ? null : sessionId.getValue(),
+ event.getEventType(),
+ getItemAt(modelNames, /* index= */ 0, /* defaultValue= */ null),
+ getWidgetType(event),
+ event.getEventIndex(),
+ getItemAt(event.getEntityTypes(), /* index= */ 0),
+ /* numOfLinks */ 0,
+ /* linkedTextLength */ 0,
+ /* textLength */ 0,
+ /* latencyInMillis */ 0L,
+ getPackageName(event),
+ getItemAt(modelNames, /* index= */ 1, /* defaultValue= */ null));
}
private static void logConversationActionsEvent(
@@ -128,46 +118,37 @@
TextClassifierEvent.ConversationActionsEvent event) {
String resultId = nullToEmpty(event.getResultId());
ImmutableList<String> modelNames = ResultIdUtils.getModelNames(resultId);
- StatsEvent statsEvent =
- StatsEvent.newBuilder()
- .setAtomId(CONVERSATION_ACTIONS_EVENT_ATOM_ID)
- // TODO: Update ExtServices to set the session id.
- .writeString(
- sessionId == null
- ? Hashing.goodFastHash(64).hashString(resultId, UTF_8).toString()
- : sessionId.getValue())
- .writeInt(event.getEventType())
- .writeString(getItemAt(modelNames, /* index= */ 0, /* defaultValue= */ null))
- .writeInt(getWidgetType(event))
- .writeString(getItemAt(event.getEntityTypes(), /* index= */ 0))
- .writeString(getItemAt(event.getEntityTypes(), /* index= */ 1))
- .writeString(getItemAt(event.getEntityTypes(), /* index= */ 2))
- .writeFloat(getFloatAt(event.getScores(), /* index= */ 0))
- .writeString(getPackageName(event))
- .writeString(getItemAt(modelNames, /* index= */ 1, /* defaultValue= */ null))
- .writeString(getItemAt(modelNames, /* index= */ 2, /* defaultValue= */ null))
- .usePooledBuffer()
- .build();
- StatsLog.write(statsEvent);
+ TextClassifierStatsLog.write(
+ CONVERSATION_ACTIONS_EVENT_ATOM_ID,
+ // TODO: Update ExtServices to set the session id.
+ sessionId == null
+ ? Hashing.goodFastHash(64).hashString(resultId, UTF_8).toString()
+ : sessionId.getValue(),
+ event.getEventType(),
+ getItemAt(modelNames, /* index= */ 0, /* defaultValue= */ null),
+ getWidgetType(event),
+ getItemAt(event.getEntityTypes(), /* index= */ 0),
+ getItemAt(event.getEntityTypes(), /* index= */ 1),
+ getItemAt(event.getEntityTypes(), /* index= */ 2),
+ getFloatAt(event.getScores(), /* index= */ 0),
+ getPackageName(event),
+ getItemAt(modelNames, /* index= */ 1, /* defaultValue= */ null),
+ getItemAt(modelNames, /* index= */ 2, /* defaultValue= */ null));
}
+ /**
+ * Writes {@code event} to the LANGUAGE_DETECTION_EVENT_ATOM_ID atom. Only the first model name,
+ * entity type, score and action index are logged; a null {@code sessionId} is logged as null.
+ */
private static void logLanguageDetectionEvent(
@Nullable TextClassificationSessionId sessionId,
TextClassifierEvent.LanguageDetectionEvent event) {
- StatsEvent statsEvent =
- StatsEvent.newBuilder()
- .setAtomId(LANGUAGE_DETECTION_EVENT_ATOM_ID)
- .writeString(sessionId == null ? null : sessionId.getValue())
- .writeInt(event.getEventType())
- .writeString(getItemAt(getModelNames(event), /* index= */ 0, /* defaultValue= */ null))
- .writeInt(getWidgetType(event))
- .writeString(getItemAt(event.getEntityTypes(), /* index= */ 0))
- .writeFloat(getFloatAt(event.getScores(), /* index= */ 0))
- .writeInt(getIntAt(event.getActionIndices(), /* index= */ 0))
- .writeString(getPackageName(event))
- .usePooledBuffer()
- .build();
- StatsLog.write(statsEvent);
+ TextClassifierStatsLog.write(
+ LANGUAGE_DETECTION_EVENT_ATOM_ID,
+ sessionId == null ? null : sessionId.getValue(),
+ event.getEventType(),
+ getItemAt(getModelNames(event), /* index= */ 0, /* defaultValue= */ null),
+ getWidgetType(event),
+ getItemAt(event.getEntityTypes(), /* index= */ 0),
+ getFloatAt(event.getScores(), /* index= */ 0),
+ getIntAt(event.getActionIndices(), /* index= */ 0),
+ getPackageName(event));
}
@Nullable
diff --git a/java/tests/instrumentation/Android.bp b/java/tests/instrumentation/Android.bp
index 15ec570..a0cd0ec 100644
--- a/java/tests/instrumentation/Android.bp
+++ b/java/tests/instrumentation/Android.bp
@@ -36,6 +36,7 @@
"TextClassifierServiceLib",
"statsdprotolite",
"textclassifierprotoslite",
+ "TextClassifierCoverageLib"
],
jni_libs: [
@@ -48,9 +49,11 @@
],
plugins: ["androidx.room_room-compiler-plugin",],
- platform_apis: true,
+ sdk_version: "system_current",
use_embedded_native_libs: true,
compile_multilib: "both",
instrumentation_for: "TextClassifierService",
+
+ data: ["testdata/*"]
}
\ No newline at end of file
diff --git a/java/tests/instrumentation/AndroidManifest.xml b/java/tests/instrumentation/AndroidManifest.xml
index 4964caf..5fde758 100644
--- a/java/tests/instrumentation/AndroidManifest.xml
+++ b/java/tests/instrumentation/AndroidManifest.xml
@@ -4,6 +4,7 @@
<uses-sdk android:minSdkVersion="29" android:targetSdkVersion="30"/>
<uses-permission android:name="android.permission.QUERY_ALL_PACKAGES" />
+ <uses-permission android:name="android.permission.MANAGE_EXTERNAL_STORAGE"/>
<application>
<uses-library android:name="android.test.runner"/>
diff --git a/java/tests/instrumentation/AndroidTest.xml b/java/tests/instrumentation/AndroidTest.xml
index e02a338..48a3f09 100644
--- a/java/tests/instrumentation/AndroidTest.xml
+++ b/java/tests/instrumentation/AndroidTest.xml
@@ -22,9 +22,15 @@
<option name="test-file-name" value="TextClassifierServiceTest.apk" />
</target_preparer>
+ <target_preparer class="com.android.compatibility.common.tradefed.targetprep.FilePusher">
+ <option name="cleanup" value="true" />
+ <option name="push" value="testdata->/data/local/tmp/TextClassifierServiceTest/testdata" />
+ </target_preparer>
+
<test class="com.android.tradefed.testtype.AndroidJUnitTest" >
<option name="package" value="com.android.textclassifier.tests" />
<option name="runner" value="androidx.test.runner.AndroidJUnitRunner" />
+ <option name="device-listeners" value="com.android.textclassifier.testing.TextClassifierInstrumentationListener" />
</test>
<object type="module_controller" class="com.android.tradefed.testtype.suite.module.MainlineTestModuleController">
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
index 06d47d6..8ef3908 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
@@ -17,6 +17,8 @@
package com.android.textclassifier;
import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.when;
import android.os.LocaleList;
@@ -27,6 +29,7 @@
import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
import com.google.common.base.Optional;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
import java.io.File;
import java.io.IOException;
import java.util.Collections;
@@ -43,33 +46,54 @@
@SmallTest
@RunWith(AndroidJUnit4.class)
-public class ModelFileManagerTest {
+public final class ModelFileManagerTest {
private static final Locale DEFAULT_LOCALE = Locale.forLanguageTag("en-US");
+ // Two download URLs that differ only in the q/711 vs q/712 path segment.
+ private static final String URL = "http://www.gstatic.com/android/text_classifier/q/711/en.fb";
+ private static final String URL_2 = "http://www.gstatic.com/android/text_classifier/q/712/en.fb";
+
+ // The only model type registered with modelFileManager in setup().
+ @ModelFile.ModelType.ModelTypeDef
+ private static final String MODEL_TYPE = ModelFile.ModelType.ANNOTATOR;
+
+ // A second model type, not registered with modelFileManager in setup().
+ @ModelFile.ModelType.ModelTypeDef
+ private static final String MODEL_TYPE_2 = ModelFile.ModelType.LANG_ID;
+
@Mock private Supplier<ImmutableList<ModelFile>> modelFileSupplier;
- private ModelFileManager.ModelFileSupplierImpl modelFileSupplierImpl;
- private ModelFileManager modelFileManager;
+ @Mock private TextClassifierSettings.IDeviceConfig mockDeviceConfig;
+
private File rootTestDir;
private File factoryModelDir;
- private File updatedModelFile;
+ private File configUpdaterModelFile;
+ private File downloaderModelDir;
+
+ private ModelFileManager modelFileManager;
+ private ModelFileManager.ModelFileSupplierImpl modelFileSupplierImpl;
@Before
public void setup() {
MockitoAnnotations.initMocks(this);
- modelFileManager = new ModelFileManager(modelFileSupplier);
- rootTestDir = ApplicationProvider.getApplicationContext().getCacheDir();
- factoryModelDir = new File(rootTestDir, "factory");
- updatedModelFile = new File(rootTestDir, "updated.model");
+ // Keep every test file under a dedicated subdirectory of the app's cache dir.
+ rootTestDir =
+ new File(ApplicationProvider.getApplicationContext().getCacheDir(), "rootTestDir");
+ factoryModelDir = new File(rootTestDir, "factory");
+ configUpdaterModelFile = new File(rootTestDir, "configupdater.model");
+ downloaderModelDir = new File(rootTestDir, "downloader");
+
+ // Only MODEL_TYPE is registered; it is backed by the mocked supplier the tests stub.
+ modelFileManager =
+ new ModelFileManager(downloaderModelDir, ImmutableMap.of(MODEL_TYPE, modelFileSupplier));
modelFileSupplierImpl =
new ModelFileManager.ModelFileSupplierImpl(
+ // NOTE(review): mockDeviceConfig is a plain Mockito mock, so unstubbed lookups return
+ // Mockito's default answers rather than running IDeviceConfig's default methods.
+ new TextClassifierSettings(mockDeviceConfig),
+ MODEL_TYPE,
factoryModelDir,
"test\\d.model",
- updatedModelFile,
+ configUpdaterModelFile,
+ downloaderModelDir,
fd -> 1,
fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT);
rootTestDir.mkdirs();
factoryModelDir.mkdirs();
+ downloaderModelDir.mkdirs();
Locale.setDefault(DEFAULT_LOCALE);
}
@@ -82,10 +106,11 @@
@Test
public void get() {
ModelFileManager.ModelFile modelFile =
- new ModelFileManager.ModelFile(new File("/path/a"), 1, ImmutableList.of(), "", true);
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(), "", true);
when(modelFileSupplier.get()).thenReturn(ImmutableList.of(modelFile));
- List<ModelFileManager.ModelFile> modelFiles = modelFileManager.listModelFiles();
+ List<ModelFileManager.ModelFile> modelFiles = modelFileManager.listModelFiles(MODEL_TYPE);
assertThat(modelFiles).hasSize(1);
assertThat(modelFiles.get(0)).isEqualTo(modelFile);
@@ -94,14 +119,16 @@
@Test
public void findBestModel_versionCode() {
ModelFileManager.ModelFile olderModelFile =
- new ModelFileManager.ModelFile(new File("/path/a"), 1, ImmutableList.of(), "", true);
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(), "", true);
ModelFileManager.ModelFile newerModelFile =
- new ModelFileManager.ModelFile(new File("/path/b"), 2, ImmutableList.of(), "", true);
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, new File("/path/b"), 2, ImmutableList.of(), "", true);
when(modelFileSupplier.get()).thenReturn(ImmutableList.of(olderModelFile, newerModelFile));
ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(LocaleList.getEmptyLocaleList());
+ modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.getEmptyLocaleList());
assertThat(bestModelFile).isEqualTo(newerModelFile);
}
@@ -110,10 +137,12 @@
public void findBestModel_languageDependentModelIsPreferred() {
Locale locale = Locale.forLanguageTag("ja");
ModelFileManager.ModelFile languageIndependentModelFile =
- new ModelFileManager.ModelFile(new File("/path/a"), 1, ImmutableList.of(), "", true);
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(), "", true);
ModelFileManager.ModelFile languageDependentModelFile =
new ModelFileManager.ModelFile(
+ MODEL_TYPE,
new File("/path/b"),
1,
Collections.singletonList(locale),
@@ -123,7 +152,8 @@
.thenReturn(ImmutableList.of(languageIndependentModelFile, languageDependentModelFile));
ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(LocaleList.forLanguageTags(locale.toLanguageTag()));
+ modelFileManager.findBestModelFile(
+ MODEL_TYPE, LocaleList.forLanguageTags(locale.toLanguageTag()));
assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
}
@@ -131,10 +161,12 @@
public void findBestModel_noMatchedLanguageModel() {
Locale locale = Locale.forLanguageTag("ja");
ModelFileManager.ModelFile languageIndependentModelFile =
- new ModelFileManager.ModelFile(new File("/path/a"), 1, Collections.emptyList(), "", true);
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(), "", true);
ModelFileManager.ModelFile languageDependentModelFile =
new ModelFileManager.ModelFile(
+ MODEL_TYPE,
new File("/path/b"),
1,
Collections.singletonList(locale),
@@ -145,17 +177,19 @@
.thenReturn(ImmutableList.of(languageIndependentModelFile, languageDependentModelFile));
ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(LocaleList.forLanguageTags("zh-hk"));
+ modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("zh-hk"));
assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
}
@Test
public void findBestModel_noMatchedLanguageModel_defaultLocaleModelExists() {
ModelFileManager.ModelFile languageIndependentModelFile =
- new ModelFileManager.ModelFile(new File("/path/a"), 1, ImmutableList.of(), "", true);
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(), "", true);
ModelFileManager.ModelFile languageDependentModelFile =
new ModelFileManager.ModelFile(
+ MODEL_TYPE,
new File("/path/b"),
1,
Collections.singletonList(DEFAULT_LOCALE),
@@ -166,7 +200,7 @@
.thenReturn(ImmutableList.of(languageIndependentModelFile, languageDependentModelFile));
ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(LocaleList.forLanguageTags("zh-hk"));
+ modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("zh-hk"));
assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
}
@@ -174,6 +208,7 @@
public void findBestModel_languageIsMoreImportantThanVersion() {
ModelFileManager.ModelFile matchButOlderModel =
new ModelFileManager.ModelFile(
+ MODEL_TYPE,
new File("/path/a"),
1,
Collections.singletonList(Locale.forLanguageTag("fr")),
@@ -182,6 +217,7 @@
ModelFileManager.ModelFile mismatchButNewerModel =
new ModelFileManager.ModelFile(
+ MODEL_TYPE,
new File("/path/b"),
2,
Collections.singletonList(Locale.forLanguageTag("ja")),
@@ -192,7 +228,7 @@
.thenReturn(ImmutableList.of(matchButOlderModel, mismatchButNewerModel));
ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(LocaleList.forLanguageTags("fr"));
+ modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("fr"));
assertThat(bestModelFile).isEqualTo(matchButOlderModel);
}
@@ -200,6 +236,7 @@
public void findBestModel_languageIsMoreImportantThanVersion_bestModelComesFirst() {
ModelFileManager.ModelFile matchLocaleModel =
new ModelFileManager.ModelFile(
+ MODEL_TYPE,
new File("/path/b"),
1,
Collections.singletonList(Locale.forLanguageTag("ja")),
@@ -207,20 +244,41 @@
false);
ModelFileManager.ModelFile languageIndependentModel =
- new ModelFileManager.ModelFile(new File("/path/a"), 2, ImmutableList.of(), "", true);
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, new File("/path/a"), 2, ImmutableList.of(), "", true);
when(modelFileSupplier.get())
.thenReturn(ImmutableList.of(matchLocaleModel, languageIndependentModel));
ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(LocaleList.forLanguageTags("ja"));
+ modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("ja"));
assertThat(bestModelFile).isEqualTo(matchLocaleModel);
}
@Test
+ public void getDownloadTargetFile_targetFileInCorrectDir() {
+ File targetFile = modelFileManager.getDownloadTargetFile(MODEL_TYPE, URL);
+ assertThat(targetFile.getParentFile()).isEqualTo(downloaderModelDir);
+ }
+
+ @Test
+ public void getDownloadTargetFile_filePathIsUnique() {
+ File targetFileOne = modelFileManager.getDownloadTargetFile(MODEL_TYPE, URL);
+ File targetFileTwo = modelFileManager.getDownloadTargetFile(MODEL_TYPE, URL);
+ File targetFileThree = modelFileManager.getDownloadTargetFile(MODEL_TYPE, URL_2);
+ File targetFileFour = modelFileManager.getDownloadTargetFile(MODEL_TYPE_2, URL);
+
+ assertThat(targetFileOne.getAbsolutePath()).isEqualTo(targetFileTwo.getAbsolutePath());
+ assertThat(targetFileOne.getAbsolutePath()).isNotEqualTo(targetFileThree.getAbsolutePath());
+ assertThat(targetFileOne.getAbsolutePath()).isNotEqualTo(targetFileFour.getAbsolutePath());
+ assertThat(targetFileThree.getAbsolutePath()).isNotEqualTo(targetFileFour.getAbsolutePath());
+ }
+
+ @Test
public void modelFileEquals() {
ModelFileManager.ModelFile modelA =
new ModelFileManager.ModelFile(
+ MODEL_TYPE,
new File("/path/a"),
1,
Collections.singletonList(Locale.forLanguageTag("ja")),
@@ -229,6 +287,7 @@
ModelFileManager.ModelFile modelB =
new ModelFileManager.ModelFile(
+ MODEL_TYPE,
new File("/path/a"),
1,
Collections.singletonList(Locale.forLanguageTag("ja")),
@@ -242,6 +301,7 @@
public void modelFile_different() {
ModelFileManager.ModelFile modelA =
new ModelFileManager.ModelFile(
+ MODEL_TYPE,
new File("/path/a"),
1,
Collections.singletonList(Locale.forLanguageTag("ja")),
@@ -250,6 +310,7 @@
ModelFileManager.ModelFile modelB =
new ModelFileManager.ModelFile(
+ MODEL_TYPE,
new File("/path/b"),
1,
Collections.singletonList(Locale.forLanguageTag("ja")),
@@ -263,6 +324,7 @@
public void modelFile_getPath() {
ModelFileManager.ModelFile modelA =
new ModelFileManager.ModelFile(
+ MODEL_TYPE,
new File("/path/a"),
1,
Collections.singletonList(Locale.forLanguageTag("ja")),
@@ -276,6 +338,7 @@
public void modelFile_getName() {
ModelFileManager.ModelFile modelA =
new ModelFileManager.ModelFile(
+ MODEL_TYPE,
new File("/path/a"),
1,
Collections.singletonList(Locale.forLanguageTag("ja")),
@@ -289,6 +352,7 @@
public void modelFile_isPreferredTo_languageDependentIsBetter() {
ModelFileManager.ModelFile modelA =
new ModelFileManager.ModelFile(
+ MODEL_TYPE,
new File("/path/a"),
1,
Collections.singletonList(Locale.forLanguageTag("ja")),
@@ -296,7 +360,8 @@
false);
ModelFileManager.ModelFile modelB =
- new ModelFileManager.ModelFile(new File("/path/b"), 2, ImmutableList.of(), "", true);
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, new File("/path/b"), 2, ImmutableList.of(), "", true);
assertThat(modelA.isPreferredTo(modelB)).isTrue();
}
@@ -305,6 +370,7 @@
public void modelFile_isPreferredTo_version() {
ModelFileManager.ModelFile modelA =
new ModelFileManager.ModelFile(
+ MODEL_TYPE,
new File("/path/a"),
2,
Collections.singletonList(Locale.forLanguageTag("ja")),
@@ -312,7 +378,8 @@
false);
ModelFileManager.ModelFile modelB =
- new ModelFileManager.ModelFile(new File("/path/b"), 1, Collections.emptyList(), "", false);
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, new File("/path/b"), 1, ImmutableList.of(), "", false);
assertThat(modelA.isPreferredTo(modelB)).isTrue();
}
@@ -321,7 +388,7 @@
public void modelFile_toModelInfo() {
ModelFileManager.ModelFile modelFile =
new ModelFileManager.ModelFile(
- new File("/path/a"), 2, ImmutableList.of(Locale.JAPANESE), "ja", false);
+ MODEL_TYPE, new File("/path/a"), 2, ImmutableList.of(Locale.JAPANESE), "ja", false);
ModelInfo modelInfo = modelFile.toModelInfo();
@@ -331,9 +398,11 @@
@Test
public void modelFile_toModelInfos() {
ModelFile englishModelFile =
- new ModelFile(new File("/path/a"), 1, ImmutableList.of(Locale.ENGLISH), "en", false);
+ new ModelFile(
+ MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(Locale.ENGLISH), "en", false);
ModelFile japaneseModelFile =
- new ModelFile(new File("/path/a"), 2, ImmutableList.of(Locale.JAPANESE), "ja", false);
+ new ModelFile(
+ MODEL_TYPE, new File("/path/a"), 2, ImmutableList.of(Locale.JAPANESE), "ja", false);
ImmutableList<Optional<ModelInfo>> modelInfos =
ModelFileManager.ModelFile.toModelInfos(
@@ -349,7 +418,14 @@
@Test
public void testFileSupplierImpl_updatedFileOnly() throws IOException {
- updatedModelFile.createNewFile();
+ when(mockDeviceConfig.getBoolean(
+ eq(TextClassifierSettings.NAMESPACE),
+ eq(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED),
+ anyBoolean()))
+ .thenReturn(false);
+ configUpdaterModelFile.createNewFile();
+ File downloaderModelFile = new File(downloaderModelDir, "test0.model");
+ downloaderModelFile.createNewFile();
File model1 = new File(factoryModelDir, "test1.model");
model1.createNewFile();
File model2 = new File(factoryModelDir, "test2.model");
@@ -363,7 +439,34 @@
assertThat(modelFiles).hasSize(3);
assertThat(modelFilePaths)
.containsExactly(
- updatedModelFile.getAbsolutePath(), model1.getAbsolutePath(), model2.getAbsolutePath());
+ configUpdaterModelFile.getAbsolutePath(),
+ model1.getAbsolutePath(),
+ model2.getAbsolutePath());
+ }
+
+ @Test
+ public void testFileSupplierImpl_includeDownloaderFile() throws IOException {
+ when(mockDeviceConfig.getBoolean(
+ eq(TextClassifierSettings.NAMESPACE),
+ eq(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED),
+ anyBoolean()))
+ .thenReturn(true);
+ configUpdaterModelFile.createNewFile();
+ File downloaderModelFile = new File(downloaderModelDir, "test0.model");
+ downloaderModelFile.createNewFile();
+ File factoryModelFile = new File(factoryModelDir, "test1.model");
+ factoryModelFile.createNewFile();
+
+ List<ModelFileManager.ModelFile> modelFiles = modelFileSupplierImpl.get();
+ List<String> modelFilePaths =
+ modelFiles.stream().map(ModelFile::getPath).collect(Collectors.toList());
+
+ assertThat(modelFiles).hasSize(3);
+ assertThat(modelFilePaths)
+ .containsExactly(
+ configUpdaterModelFile.getAbsolutePath(),
+ downloaderModelFile.getAbsolutePath(),
+ factoryModelFile.getAbsolutePath());
}
@Test
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
new file mode 100644
index 0000000..88c0ac8
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import java.io.File;
+
+/** Utils to access test data files. */
+public final class TestDataUtils {
+ /** Returns the root folder that contains the test data. */
+ public static File getTestDataFolder() {
+ return new File("/data/local/tmp/TextClassifierServiceTest/");
+ }
+
+ private TestDataUtils() {}
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
index 6d80673..22674dd 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
@@ -40,17 +40,22 @@
import android.view.textclassifier.TextSelection;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
+import com.android.textclassifier.ModelFileManager.ModelFile;
+import com.android.textclassifier.ModelFileManager.ModelFile.ModelType;
import com.android.textclassifier.testing.FakeContextBuilder;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import java.util.Locale;
+import java.util.function.Supplier;
import org.hamcrest.BaseMatcher;
import org.hamcrest.Description;
import org.hamcrest.Matcher;
import org.junit.Before;
-import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -63,6 +68,51 @@
private static final String NO_TYPE = null;
private TextClassifierImpl classifier;
+ private static final ImmutableMap<String, Supplier<ImmutableList<ModelFile>>>
+ MODEL_FILES_SUPPLIER =
+ new ImmutableMap.Builder<String, Supplier<ImmutableList<ModelFile>>>()
+ .put(
+ ModelType.ANNOTATOR,
+ () ->
+ ImmutableList.of(
+ new ModelFile(
+ ModelType.ANNOTATOR,
+ new File(
+ TestDataUtils.getTestDataFolder(), "testdata/annotator.model"),
+ 711,
+ ImmutableList.of(Locale.ENGLISH),
+ "en",
+ false)))
+ .put(
+ ModelType.ACTIONS_SUGGESTIONS,
+ (Supplier<ImmutableList<ModelFile>>)
+ () ->
+ ImmutableList.of(
+ new ModelFile(
+ ModelType.ACTIONS_SUGGESTIONS,
+ new File(
+ TestDataUtils.getTestDataFolder(), "testdata/actions.model"),
+ 104,
+ ImmutableList.of(Locale.ENGLISH),
+ "en",
+ false)))
+ .put(
+ ModelType.LANG_ID,
+ (Supplier<ImmutableList<ModelFile>>)
+ () ->
+ ImmutableList.of(
+ new ModelFile(
+ ModelType.LANG_ID,
+ new File(
+ TestDataUtils.getTestDataFolder(), "testdata/langid.model"),
+ 1,
+ ImmutableList.of(),
+ "*",
+ true)))
+ .build();
+ private final ModelFileManager modelFileManager =
+ new ModelFileManager(
+ /* downloadModelDir= */ TestDataUtils.getTestDataFolder(), MODEL_FILES_SUPPLIER);
@Before
public void setup() {
@@ -71,7 +121,8 @@
.setAllIntentComponent(FakeContextBuilder.DEFAULT_COMPONENT)
.setAppLabel(FakeContextBuilder.DEFAULT_COMPONENT.getPackageName(), "Test app")
.build();
- classifier = new TextClassifierImpl(context, new TextClassifierSettings());
+ TextClassifierSettings settings = new TextClassifierSettings();
+ classifier = new TextClassifierImpl(context, settings, modelFileManager);
}
@Test
@@ -221,9 +272,6 @@
assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME));
}
- // TODO(tonymak): Enable it once we drop the v8 image to Android. I have already run this test
- // after pushing a test model to a device manually.
- @Ignore
@Test
public void testClassifyText_foreignText() {
LocaleList originalLocales = LocaleList.getDefault();
@@ -381,7 +429,6 @@
assertThat(textLanguage, isTextLanguage("ja"));
}
- @Ignore // Doesn't work without a language-based model.
@Test
public void testSuggestConversationActions_textReplyOnly_maxOne() {
ConversationActions.Message message =
@@ -406,7 +453,6 @@
assertThat(conversationAction.getTextReply()).isNotNull();
}
- @Ignore // Doesn't work without a language-based model.
@Test
public void testSuggestConversationActions_textReplyOnly_noMax() {
ConversationActions.Message message =
@@ -457,7 +503,6 @@
assertNoPackageInfoInExtras(actionIntent);
}
- @Ignore // Doesn't work without a language-based model.
@Test
public void testSuggestConversationActions_copy() {
ConversationActions.Message message =
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java
index 21ed0b6..c0a823e 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java
@@ -22,6 +22,7 @@
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
import androidx.test.platform.app.InstrumentationRegistry;
+import com.android.textclassifier.ModelFileManager.ModelFile.ModelType;
import java.util.function.Consumer;
import org.junit.After;
import org.junit.Before;
@@ -92,6 +93,49 @@
.inOrder());
}
+ @Test
+ public void modelURLPrefixSetting() {
+ assertSettings(
+ TextClassifierSettings.ANNOTATOR_URL_PREFIX,
+ "prefix:annotator",
+ settings ->
+ assertThat(settings.getModelURLPrefix(ModelType.ANNOTATOR))
+ .isEqualTo("prefix:annotator"));
+ assertSettings(
+ TextClassifierSettings.LANG_ID_URL_PREFIX,
+ "prefix:lang_id",
+ settings ->
+ assertThat(settings.getModelURLPrefix(ModelType.LANG_ID)).isEqualTo("prefix:lang_id"));
+ assertSettings(
+ TextClassifierSettings.ACTIONS_SUGGESTIONS_URL_PREFIX,
+ "prefix:actions_suggestions",
+ settings ->
+ assertThat(settings.getModelURLPrefix(ModelType.ACTIONS_SUGGESTIONS))
+ .isEqualTo("prefix:actions_suggestions"));
+ }
+
+ @Test
+ public void primaryModelURLSuffixSetting() {
+ assertSettings(
+ TextClassifierSettings.PRIMARY_ANNOTATOR_URL_SUFFIX,
+ "suffix:annotator",
+ settings ->
+ assertThat(settings.getPrimaryModelURLSuffix(ModelType.ANNOTATOR))
+ .isEqualTo("suffix:annotator"));
+ assertSettings(
+ TextClassifierSettings.PRIMARY_LANG_ID_URL_SUFFIX,
+ "suffix:lang_id",
+ settings ->
+ assertThat(settings.getPrimaryModelURLSuffix(ModelType.LANG_ID))
+ .isEqualTo("suffix:lang_id"));
+ assertSettings(
+ TextClassifierSettings.PRIMARY_ACTIONS_SUGGESTIONS_URL_SUFFIX,
+ "suffix:actions_suggestions",
+ settings ->
+ assertThat(settings.getPrimaryModelURLSuffix(ModelType.ACTIONS_SUGGESTIONS))
+ .isEqualTo("suffix:actions_suggestions"));
+ }
+
private static void assertSettings(
String key, String value, Consumer<TextClassifierSettings> settingsConsumer) {
final String originalValue =
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
index ab241c5..216cd5d 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
@@ -107,7 +107,7 @@
assertThat(intent.getData().toString()).isEqualTo(DATA);
assertThat(intent.getType()).isEqualTo(TYPE);
assertThat(intent.getFlags()).isEqualTo(FLAG);
- assertThat(intent.getCategories()).containsExactly((Object[]) CATEGORY);
+ assertThat(intent.getCategories()).containsExactlyElementsIn(CATEGORY);
assertThat(intent.getPackage()).isNull();
assertThat(intent.getStringExtra(KEY_ONE)).isEqualTo(VALUE_ONE);
assertThat(intent.getIntExtra(KEY_TWO, 0)).isEqualTo(VALUE_TWO);
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/subjects/EntitySubject.java b/java/tests/instrumentation/src/com/android/textclassifier/subjects/EntitySubject.java
index 6d01a64..3585f87 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/subjects/EntitySubject.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/subjects/EntitySubject.java
@@ -20,7 +20,6 @@
import com.android.textclassifier.Entity;
import com.google.common.truth.FailureMetadata;
-import com.google.common.truth.MathUtil;
import com.google.common.truth.Subject;
import javax.annotation.Nullable;
@@ -44,8 +43,9 @@
if (!entity.getEntityType().equals(this.entity.getEntityType())) {
failWithActual("expected to have type", entity.getEntityType());
}
- if (!MathUtil.equalWithinTolerance(entity.getScore(), this.entity.getScore(), TOLERANCE)) {
- failWithActual("expected to have confidence score", entity.getScore());
- }
+ check("expected to have confidence score")
+ .that(entity.getScore())
+ .isWithin(TOLERANCE)
+ .of(this.entity.getScore());
}
}
diff --git a/java/tests/instrumentation/testdata/actions.model b/java/tests/instrumentation/testdata/actions.model
new file mode 100755
index 0000000..74422f6
--- /dev/null
+++ b/java/tests/instrumentation/testdata/actions.model
Binary files differ
diff --git a/java/tests/instrumentation/testdata/annotator.model b/java/tests/instrumentation/testdata/annotator.model
new file mode 100755
index 0000000..f5fcc23
--- /dev/null
+++ b/java/tests/instrumentation/testdata/annotator.model
Binary files differ
diff --git a/java/tests/instrumentation/testdata/langid.model b/java/tests/instrumentation/testdata/langid.model
new file mode 100755
index 0000000..e94dada
--- /dev/null
+++ b/java/tests/instrumentation/testdata/langid.model
Binary files differ
diff --git a/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java b/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
index 8b6cf2e..0ddb01c 100644
--- a/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
+++ b/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
@@ -295,6 +295,14 @@
public ActionSuggestionOptions() {}
}
+ /**
+ * Retrieves the pointer to the native object. Note: Need to keep the {@code
+ * ActionsSuggestionsModel} alive as long as the pointer is used.
+ */
+ long getNativeModelPointer() {
+ return nativeGetNativeModelPtr(actionsModelPtr);
+ }
+
private static native long nativeNewActionsModel(int fd, byte[] serializedPreconditions);
private static native long nativeNewActionsModelFromPath(
@@ -325,4 +333,6 @@
boolean generateAndroidIntents);
private native void nativeCloseActionsModel(long ptr);
+
+ private native long nativeGetNativeModelPtr(long context);
}
diff --git a/jni/com/google/android/textclassifier/AnnotatorModel.java b/jni/com/google/android/textclassifier/AnnotatorModel.java
index a116f0a..d2001ed 100644
--- a/jni/com/google/android/textclassifier/AnnotatorModel.java
+++ b/jni/com/google/android/textclassifier/AnnotatorModel.java
@@ -711,6 +711,7 @@
private final float userLocationAccuracyMeters;
private final String userFamiliarLanguageTags;
private final boolean usePodNer;
+ private final boolean triggerDictionaryOnBeginnerWords;
private ClassificationOptions(
long referenceTimeMsUtc,
@@ -722,7 +723,8 @@
double userLocationLng,
float userLocationAccuracyMeters,
String userFamiliarLanguageTags,
- boolean usePodNer) {
+ boolean usePodNer,
+ boolean triggerDictionaryOnBeginnerWords) {
this.referenceTimeMsUtc = referenceTimeMsUtc;
this.referenceTimezone = referenceTimezone;
this.locales = locales;
@@ -733,6 +735,7 @@
this.userLocationAccuracyMeters = userLocationAccuracyMeters;
this.userFamiliarLanguageTags = userFamiliarLanguageTags;
this.usePodNer = usePodNer;
+ this.triggerDictionaryOnBeginnerWords = triggerDictionaryOnBeginnerWords;
}
/** Can be used to build a ClassificationOptions instance. */
@@ -747,6 +750,7 @@
private float userLocationAccuracyMeters = INVALID_LOCATION_ACCURACY_METERS;
private String userFamiliarLanguageTags = "";
private boolean usePodNer = true;
+ private boolean triggerDictionaryOnBeginnerWords = false;
public Builder setReferenceTimeMsUtc(long referenceTimeMsUtc) {
this.referenceTimeMsUtc = referenceTimeMsUtc;
@@ -798,6 +802,12 @@
return this;
}
+ public Builder setTrigerringDictionaryOnBeginnerWords(
+ boolean triggerDictionaryOnBeginnerWords) {
+ this.triggerDictionaryOnBeginnerWords = triggerDictionaryOnBeginnerWords;
+ return this;
+ }
+
public ClassificationOptions build() {
return new ClassificationOptions(
referenceTimeMsUtc,
@@ -809,7 +819,8 @@
userLocationLng,
userLocationAccuracyMeters,
userFamiliarLanguageTags,
- usePodNer);
+ usePodNer,
+ triggerDictionaryOnBeginnerWords);
}
}
@@ -859,6 +870,10 @@
public boolean getUsePodNer() {
return usePodNer;
}
+
+ public boolean getTriggerDictionaryOnBeginnerWords() {
+ return triggerDictionaryOnBeginnerWords;
+ }
}
/** Represents options for the annotate call. */
@@ -877,6 +892,7 @@
private final double userLocationLng;
private final float userLocationAccuracyMeters;
private final boolean usePodNer;
+ private final boolean triggerDictionaryOnBeginnerWords;
private AnnotationOptions(
long referenceTimeMsUtc,
@@ -892,7 +908,8 @@
double userLocationLat,
double userLocationLng,
float userLocationAccuracyMeters,
- boolean usePodNer) {
+ boolean usePodNer,
+ boolean triggerDictionaryOnBeginnerWords) {
this.referenceTimeMsUtc = referenceTimeMsUtc;
this.referenceTimezone = referenceTimezone;
this.locales = locales;
@@ -907,6 +924,7 @@
this.hasLocationPermission = hasLocationPermission;
this.hasPersonalizationPermission = hasPersonalizationPermission;
this.usePodNer = usePodNer;
+ this.triggerDictionaryOnBeginnerWords = triggerDictionaryOnBeginnerWords;
}
/** Can be used to build an AnnotationOptions instance. */
@@ -925,6 +943,7 @@
private double userLocationLng = INVALID_LONGITUDE;
private float userLocationAccuracyMeters = INVALID_LOCATION_ACCURACY_METERS;
private boolean usePodNer = true;
+ private boolean triggerDictionaryOnBeginnerWords = false;
public Builder setReferenceTimeMsUtc(long referenceTimeMsUtc) {
this.referenceTimeMsUtc = referenceTimeMsUtc;
@@ -996,6 +1015,11 @@
return this;
}
+ public Builder setTriggerDictionaryOnBeginnerWords(boolean triggerDictionaryOnBeginnerWords) {
+ this.triggerDictionaryOnBeginnerWords = triggerDictionaryOnBeginnerWords;
+ return this;
+ }
+
public AnnotationOptions build() {
return new AnnotationOptions(
referenceTimeMsUtc,
@@ -1011,7 +1035,8 @@
userLocationLat,
userLocationLng,
userLocationAccuracyMeters,
- usePodNer);
+ usePodNer,
+ triggerDictionaryOnBeginnerWords);
}
}
@@ -1077,6 +1102,10 @@
public boolean getUsePodNer() {
return usePodNer;
}
+
+ public boolean getTriggerDictionaryOnBeginnerWords() {
+ return triggerDictionaryOnBeginnerWords;
+ }
}
/**
diff --git a/native/Android.bp b/native/Android.bp
index 461b9a0..b120565 100644
--- a/native/Android.bp
+++ b/native/Android.bp
@@ -81,13 +81,15 @@
"-funsigned-char",
"-fvisibility=hidden",
+
"-DLIBTEXTCLASSIFIER_UNILIB_ICU",
"-DZLIB_CONST",
"-DSAFTM_COMPACT_LOGGING",
"-DTC3_WITH_ACTIONS_OPS",
"-DTC3_UNILIB_JAVAICU",
"-DTC3_CALENDAR_JAVAICU",
- "-DTC3_AOSP"
+ "-DTC3_AOSP",
+ "-DTC3_VOCAB_ANNOTATOR_DUMMY"
],
product_variables: {
@@ -127,6 +129,20 @@
depfile: true,
}
+genrule {
+ name: "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers_test",
+ srcs: ["utils/flatbuffers/flatbuffers_test.fbs"],
+ out: ["utils/flatbuffers/flatbuffers_test_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_utils_lua_utils_tests",
+ srcs: ["utils/lua_utils_tests.fbs"],
+ out: ["utils/lua_utils_tests_generated.h"],
+ defaults: ["fbgen"],
+}
+
// -----------------
// libtextclassifier
// -----------------
@@ -166,6 +182,7 @@
data: [
"**/test_data/*",
+ "**/*.bfbs",
],
srcs: ["**/*.cc"],
@@ -176,10 +193,17 @@
static_libs: [
"libgmock_ndk",
"libgtest_ndk_c++",
- "libbase_ndk"
+ "libbase_ndk",
+ ],
+
+ generated_headers: [
+ "libtextclassifier_fbgen_utils_flatbuffers_flatbuffers_test",
+ "libtextclassifier_fbgen_utils_lua_utils_tests",
],
compile_multilib: "prefer32",
+
+ sdk_variant_only: true
}
// ------------------------------------
@@ -188,15 +212,14 @@
cc_test_library {
name: "libjvm_test_launcher",
defaults: ["libtextclassifier_defaults"],
-
srcs: [
":libtextclassifier_java_test_sources",
+ "actions/test-utils.cc",
+ "utils/testing/annotator.cc",
"utils/testing/logging_event_listener.cc",
- "testing/jvm_test_launcher.cc"
+ "testing/jvm_test_launcher.cc",
],
-
version_script: "jni.lds",
-
static_libs: [
"libgmock_ndk",
"libgtest_ndk_c++",
@@ -218,11 +241,15 @@
"androidx.test.espresso.core",
"androidx.test.ext.truth",
"truth-prebuilt",
+ "TextClassifierCoverageLib",
],
jni_libs: [
"libjvm_test_launcher",
],
jni_uses_sdk_apis: true,
+ data: [
+ "**/test_data/*",
+ ],
test_config: "JavaTest.xml",
compile_multilib: "both",
}
diff --git a/native/JavaTest.xml b/native/JavaTest.xml
index 2f4a0c1..5393fd8 100644
--- a/native/JavaTest.xml
+++ b/native/JavaTest.xml
@@ -21,10 +21,17 @@
<option name="cleanup-apks" value="true" />
<option name="test-file-name" value="libtextclassifier_java_tests.apk" />
</target_preparer>
+ <target_preparer class="com.android.compatibility.common.tradefed.targetprep.FilePusher">
+ <option name="cleanup" value="true" />
+ <option name="push" value="actions->/data/local/tmp/actions" />
+ <option name="push" value="annotator->/data/local/tmp/annotator" />
+ <option name="push" value="utils->/data/local/tmp/utils" />
+ </target_preparer>
<test class="com.android.tradefed.testtype.AndroidJUnitTest" >
<option name="package" value="com.google.android.textclassifier.tests" />
<option name="runner" value="androidx.test.runner.AndroidJUnitRunner" />
+ <option name="device-listeners" value="com.android.textclassifier.testing.TextClassifierInstrumentationListener" />
</test>
<object type="module_controller" class="com.android.tradefed.testtype.suite.module.MainlineTestModuleController">
diff --git a/native/JavaTests.bp b/native/JavaTests.bp
index 425920f..af2ae1c 100644
--- a/native/JavaTests.bp
+++ b/native/JavaTests.bp
@@ -17,9 +17,14 @@
filegroup {
name: "libtextclassifier_java_test_sources",
srcs: [
+ "actions/grammar-actions_test.cc",
+ "annotator/datetime/parser_test.cc",
+ "utils/intents/intent-generator-test-lib.cc",
+ "annotator/grammar/grammar-annotator_test.cc",
+ "annotator/grammar/test-utils.cc",
"annotator/number/number_test-include.cc",
+ "annotator/annotator_test-include.cc",
"utils/utf8/unilib_test-include.cc",
"utils/calendar/calendar_test-include.cc",
- "utils/intents/intent-generator-test-lib.cc",
],
}
diff --git a/native/actions/actions-entity-data.bfbs b/native/actions/actions-entity-data.bfbs
new file mode 100644
index 0000000..d3f13e4
--- /dev/null
+++ b/native/actions/actions-entity-data.bfbs
Binary files differ
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index 93ef544..f550cc7 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -18,13 +18,17 @@
#include <memory>
+#if !defined(TC3_DISABLE_LUA)
#include "actions/lua-actions.h"
+#endif
#include "actions/types.h"
#include "actions/utils.h"
#include "actions/zlib-utils.h"
#include "annotator/collections.h"
#include "utils/base/logging.h"
+#if !defined(TC3_DISABLE_LUA)
#include "utils/lua-utils.h"
+#endif
#include "utils/normalization.h"
#include "utils/optional.h"
#include "utils/strings/split.h"
@@ -35,23 +39,6 @@
namespace libtextclassifier3 {
-const std::string& ActionsSuggestions::kViewCalendarType =
- *[]() { return new std::string("view_calendar"); }();
-const std::string& ActionsSuggestions::kViewMapType =
- *[]() { return new std::string("view_map"); }();
-const std::string& ActionsSuggestions::kTrackFlightType =
- *[]() { return new std::string("track_flight"); }();
-const std::string& ActionsSuggestions::kOpenUrlType =
- *[]() { return new std::string("open_url"); }();
-const std::string& ActionsSuggestions::kSendSmsType =
- *[]() { return new std::string("send_sms"); }();
-const std::string& ActionsSuggestions::kCallPhoneType =
- *[]() { return new std::string("call_phone"); }();
-const std::string& ActionsSuggestions::kSendEmailType =
- *[]() { return new std::string("send_email"); }();
-const std::string& ActionsSuggestions::kShareLocation =
- *[]() { return new std::string("share_location"); }();
-
constexpr float kDefaultFloat = 0.0;
constexpr bool kDefaultBool = false;
constexpr int kDefaultInt = 1;
@@ -317,6 +304,7 @@
}
}
+#if !defined(TC3_DISABLE_LUA)
std::string actions_script;
if (GetUncompressedString(model_->lua_actions_script(),
model_->compressed_lua_actions_script(),
@@ -327,6 +315,7 @@
return false;
}
}
+#endif // TC3_DISABLE_LUA
if (!(ranker_ = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
model_->ranking_options(), decompressor.get(),
@@ -1193,6 +1182,7 @@
return result;
}
+#if !defined(TC3_DISABLE_LUA)
bool ActionsSuggestions::SuggestActionsFromLua(
const Conversation& conversation, const TfLiteModelExecutor* model_executor,
const tflite::Interpreter* interpreter,
@@ -1211,6 +1201,15 @@
}
return lua_actions->SuggestActions(actions);
}
+#else
+bool ActionsSuggestions::SuggestActionsFromLua(
+ const Conversation& conversation, const TfLiteModelExecutor* model_executor,
+ const tflite::Interpreter* interpreter,
+ const reflection::Schema* annotation_entity_data_schema,
+ std::vector<ActionSuggestion>* actions) const {
+ return true;
+}
+#endif
bool ActionsSuggestions::GatherActionsSuggestions(
const Conversation& conversation, const Annotator* annotator,
diff --git a/native/actions/actions-suggestions.h b/native/actions/actions-suggestions.h
index 1fee9a1..04c8aa7 100644
--- a/native/actions/actions-suggestions.h
+++ b/native/actions/actions-suggestions.h
@@ -115,17 +115,6 @@
static constexpr int kLocalUserId = 0;
- // Should be in sync with those defined in Android.
- // android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
- static const std::string& kViewCalendarType;
- static const std::string& kViewMapType;
- static const std::string& kTrackFlightType;
- static const std::string& kOpenUrlType;
- static const std::string& kSendSmsType;
- static const std::string& kCallPhoneType;
- static const std::string& kSendEmailType;
- static const std::string& kShareLocation;
-
protected:
// Exposed for testing.
bool EmbedTokenId(const int32 token_id, std::vector<float>* embedding) const;
@@ -298,6 +287,72 @@
return function(model);
}
+class ActionsSuggestionsTypes {
+ public:
+  // Action type names (each allocated once, never freed). Keep in sync with
+  // android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
+  static const std::string& ViewCalendar() {
+    static const std::string* const kValue =
+        new std::string("view_calendar");
+    return *kValue;
+  }
+  static const std::string& ViewMap() {
+    static const std::string* const kValue =
+        new std::string("view_map");
+    return *kValue;
+  }
+  static const std::string& TrackFlight() {
+    static const std::string* const kValue =
+        new std::string("track_flight");
+    return *kValue;
+  }
+  static const std::string& OpenUrl() {
+    static const std::string* const kValue =
+        new std::string("open_url");
+    return *kValue;
+  }
+  static const std::string& SendSms() {
+    static const std::string* const kValue =
+        new std::string("send_sms");
+    return *kValue;
+  }
+  static const std::string& CallPhone() {
+    static const std::string* const kValue =
+        new std::string("call_phone");
+    return *kValue;
+  }
+  static const std::string& SendEmail() {
+    static const std::string* const kValue =
+        new std::string("send_email");
+    return *kValue;
+  }
+  static const std::string& ShareLocation() {
+    static const std::string* const kValue =
+        new std::string("share_location");
+    return *kValue;
+  }
+  static const std::string& CreateReminder() {
+    static const std::string* const kValue =
+        new std::string("create_reminder");
+    return *kValue;
+  }
+  static const std::string& TextReply() {
+    static const std::string* const kValue =
+        new std::string("text_reply");
+    return *kValue;
+  }
+  static const std::string& AddContact() {
+    static const std::string* const kValue =
+        new std::string("add_contact");
+    return *kValue;
+  }
+  static const std::string& Copy() {
+    static const std::string* const kValue =
+        new std::string("copy");
+    return *kValue;
+  }
+};
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc
index 9b49635..ed92981 100644
--- a/native/actions/actions-suggestions_test.cc
+++ b/native/actions/actions-suggestions_test.cc
@@ -1106,7 +1106,7 @@
// Check that the location sharing model triggered.
bool has_location_sharing_action = false;
for (const ActionSuggestion& action : response.actions) {
- if (action.type == ActionsSuggestions::kShareLocation) {
+ if (action.type == ActionsSuggestionsTypes::ShareLocation()) {
has_location_sharing_action = true;
break;
}
@@ -1132,7 +1132,7 @@
ActionSuggestionSpecT* action =
actions_model->rules->regex_rule.back()->actions.back()->action.get();
action->score = 1.0f;
- action->type = ActionsSuggestions::kShareLocation;
+ action->type = ActionsSuggestionsTypes::ShareLocation();
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder,
diff --git a/native/actions/actions_jni.cc b/native/actions/actions_jni.cc
index 1648fb3..1d5c2fb 100644
--- a/native/actions/actions_jni.cc
+++ b/native/actions/actions_jni.cc
@@ -524,3 +524,12 @@
new libtextclassifier3::ScopedMmap(fd, offset, size));
return libtextclassifier3::GetVersionFromMmap(env, mmap.get());
}
+
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeGetNativeModelPtr)
+(JNIEnv* env, jobject thiz, jlong ptr) {
+  if (ptr == 0) {  // No native context: there is no model to expose.
+    return 0;
+  }
+  auto* const context = reinterpret_cast<ActionsSuggestionsJniContext*>(ptr);
+  return reinterpret_cast<jlong>(context->model());
+}
diff --git a/native/actions/actions_jni.h b/native/actions/actions_jni.h
index 75e2e67..5d6a79d 100644
--- a/native/actions/actions_jni.h
+++ b/native/actions/actions_jni.h
@@ -67,6 +67,9 @@
TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersionWithOffset)
(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeGetNativeModelPtr)
+(JNIEnv* env, jobject thiz, jlong ptr); // Model pointer of context ptr, or 0.
+
#ifdef __cplusplus
}
#endif
diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs
index c394e5b..7d626a8 100755
--- a/native/actions/actions_model.fbs
+++ b/native/actions/actions_model.fbs
@@ -14,16 +14,16 @@
// limitations under the License.
//
-include "utils/codepoint-range.fbs";
-include "utils/grammar/rules.fbs";
-include "actions/actions-entity-data.fbs";
-include "utils/flatbuffers/flatbuffers.fbs";
-include "utils/resources.fbs";
-include "utils/tokenizer.fbs";
-include "utils/intents/intent-config.fbs";
include "annotator/model.fbs";
include "utils/normalization.fbs";
+include "actions/actions-entity-data.fbs";
+include "utils/codepoint-range.fbs";
+include "utils/resources.fbs";
include "utils/zlib/buffer.fbs";
+include "utils/grammar/rules.fbs";
+include "utils/tokenizer.fbs";
+include "utils/intents/intent-config.fbs";
+include "utils/flatbuffers/flatbuffers.fbs";
file_identifier "TC3A";
diff --git a/native/actions/grammar-actions_test.cc b/native/actions/grammar-actions_test.cc
new file mode 100644
index 0000000..e738dee
--- /dev/null
+++ b/native/actions/grammar-actions_test.cc
@@ -0,0 +1,726 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/grammar-actions.h"
+
+#include <iostream>
+#include <memory>
+
+#include "actions/actions_model_generated.h"
+#include "actions/test-utils.h"
+#include "actions/types.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
+#include "utils/grammar/rules_generated.h"
+#include "utils/grammar/utils/rules.h"
+#include "utils/jvm-test-utils.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::IsEmpty;
+
+class TestGrammarActions : public GrammarActions { // Smart reply type "text_reply".
+ public:
+ explicit TestGrammarActions(
+ const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
+ const MutableFlatbufferBuilder* entity_data_builder = nullptr)
+ : GrammarActions(unilib, grammar_rules, entity_data_builder,
+ // Same type string that GrammarActionsTest::AddSmartReplySpec() uses.
+ /*smart_reply_action_type=*/"text_reply") {}
+};
+
+class GrammarActionsTest : public testing::Test { // Grammar rule builders.
+ protected:
+ struct AnnotationSpec { // Annotation to create for one capturing group.
+ int group_id = 0;
+ std::string annotation_name = "";
+ bool use_annotation_match = false;
+ };
+
+ GrammarActionsTest()
+ : unilib_(CreateUniLibForTesting()),
+ serialized_entity_data_schema_(TestEntityDataSchema()),
+ entity_data_builder_(new MutableFlatbufferBuilder(
+ flatbuffers::GetRoot<reflection::Schema>(
+ serialized_entity_data_schema_.data()))) {}
+ // Selects the ICU tokenizer without preserving whitespace tokens.
+ void SetTokenizerOptions(
+ RulesModel_::GrammarRulesT* action_grammar_rules) const {
+ action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
+ action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
+ action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
+ false;
+ }
+ // Serializes the unpacked grammar rules into a standalone flatbuffer.
+ flatbuffers::DetachedBuffer PackRules(
+ const RulesModel_::GrammarRulesT& action_grammar_rules) const {
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(
+ RulesModel_::GrammarRules::Pack(builder, &action_grammar_rules));
+ return builder.Release();
+ }
+ // Appends an action spec with the given annotations; returns its index.
+ int AddActionSpec(const std::string& type, const std::string& response_text,
+ const std::vector<AnnotationSpec>& annotations,
+ RulesModel_::GrammarRulesT* action_grammar_rules) const {
+ const int action_id = action_grammar_rules->actions.size();
+ action_grammar_rules->actions.emplace_back(
+ new RulesModel_::RuleActionSpecT);
+ RulesModel_::RuleActionSpecT* actions_spec =
+ action_grammar_rules->actions.back().get();
+ actions_spec->action.reset(new ActionSuggestionSpecT);
+ actions_spec->action->response_text = response_text;
+ actions_spec->action->priority_score = 1.0;
+ actions_spec->action->score = 1.0;
+ actions_spec->action->type = type;
+ // Create annotations for specified capturing groups.
+ for (const AnnotationSpec& annotation : annotations) {
+ actions_spec->capturing_group.emplace_back(
+ new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
+ actions_spec->capturing_group.back()->group_id = annotation.group_id;
+ actions_spec->capturing_group.back()->annotation_name =
+ annotation.annotation_name;
+ actions_spec->capturing_group.back()->annotation_type =
+ annotation.annotation_name; // The type reuses the annotation name.
+ actions_spec->capturing_group.back()->use_annotation_match =
+ annotation.use_annotation_match;
+ }
+
+ return action_id;
+ }
+ // Appends a "text_reply" action with a fixed response; returns its index.
+ int AddSmartReplySpec(
+ const std::string& response_text,
+ RulesModel_::GrammarRulesT* action_grammar_rules) const {
+ return AddActionSpec("text_reply", response_text, {}, action_grammar_rules);
+ }
+ // Appends an action whose text reply comes from capturing group match_id.
+ int AddCapturingMatchSmartReplySpec(
+ const int match_id,
+ RulesModel_::GrammarRulesT* action_grammar_rules) const {
+ const int action_id = action_grammar_rules->actions.size();
+ action_grammar_rules->actions.emplace_back(
+ new RulesModel_::RuleActionSpecT);
+ RulesModel_::RuleActionSpecT* actions_spec =
+ action_grammar_rules->actions.back().get();
+ actions_spec->capturing_group.emplace_back(
+ new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
+ actions_spec->capturing_group.back()->group_id = match_id;
+ actions_spec->capturing_group.back()->text_reply.reset(
+ new ActionSuggestionSpecT);
+ actions_spec->capturing_group.back()->text_reply->priority_score = 1.0;
+ actions_spec->capturing_group.back()->text_reply->score = 1.0;
+ return action_id;
+ }
+ // Appends a rule match that triggers action_ids; returns its index.
+ int AddRuleMatch(const std::vector<int>& action_ids,
+ RulesModel_::GrammarRulesT* action_grammar_rules) const {
+ const int rule_match_id = action_grammar_rules->rule_match.size();
+ action_grammar_rules->rule_match.emplace_back(
+ new RulesModel_::GrammarRules_::RuleMatchT);
+ action_grammar_rules->rule_match.back()->action_id.insert(
+ action_grammar_rules->rule_match.back()->action_id.end(),
+ action_ids.begin(), action_ids.end());
+ return rule_match_id;
+ }
+
+ std::unique_ptr<UniLib> unilib_;
+ const std::string serialized_entity_data_schema_;
+ std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
+};
+
+TEST_F(GrammarActionsTest, ProducesSmartReplies) {
+ // Create test rules: a single rule that suggests two smart replies.
+ // Rule: ^knock knock.?$ -> "Who's there?", "Yes?"
+ RulesModel_::GrammarRulesT action_grammar_rules;
+ SetTokenizerOptions(&action_grammar_rules);
+ action_grammar_rules.rules.reset(new grammar::RulesSetT);
+ grammar::Rules rules;
+ rules.Add(
+ "<knock>", {"<^>", "knock", "knock", ".?", "<$>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(
+ GrammarActions::Callback::kActionRuleMatch),
+ /*callback_param=*/
+ AddRuleMatch({AddSmartReplySpec("Who's there?", &action_grammar_rules),
+ AddSmartReplySpec("Yes?", &action_grammar_rules)},
+ &action_grammar_rules));
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ action_grammar_rules.rules.get());
+ flatbuffers::DetachedBuffer serialized_rules =
+ PackRules(action_grammar_rules);
+ TestGrammarActions grammar_actions(
+ unilib_.get(),
+ flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()));
+
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"Knock knock"}}}, &result));
+
+ EXPECT_THAT(result,
+ ElementsAre(IsSmartReply("Who's there?"), IsSmartReply("Yes?")));
+}
+
+TEST_F(GrammarActionsTest, ProducesSmartRepliesFromCapturingMatches) {
+ // Create test rules that reply with the captured <reply> text (match 0).
+ // Rule: ^Text <reply> to <command>
+ RulesModel_::GrammarRulesT action_grammar_rules;
+ SetTokenizerOptions(&action_grammar_rules);
+ action_grammar_rules.rules.reset(new grammar::RulesSetT);
+ grammar::Rules rules;
+
+ rules.Add("<scripted_reply>",
+ {"<^>", "text", "<captured_reply>", "to", "<command>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(
+ GrammarActions::Callback::kActionRuleMatch),
+ /*callback_param=*/
+ AddRuleMatch({AddCapturingMatchSmartReplySpec(
+ /*match_id=*/0, &action_grammar_rules)},
+ &action_grammar_rules));
+
+ // <command> ::= unsubscribe | cancel | confirm | receive
+ rules.Add("<command>", {"unsubscribe"});
+ rules.Add("<command>", {"cancel"});
+ rules.Add("<command>", {"confirm"});
+ rules.Add("<command>", {"receive"});
+
+ // <reply> ::= help | stop | cancel | yes
+ rules.Add("<reply>", {"help"});
+ rules.Add("<reply>", {"stop"});
+ rules.Add("<reply>", {"cancel"});
+ rules.Add("<reply>", {"yes"});
+ rules.AddValueMapping("<captured_reply>", {"<reply>"},
+ /*value=*/0);
+
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ action_grammar_rules.rules.get());
+ flatbuffers::DetachedBuffer serialized_rules =
+ PackRules(action_grammar_rules);
+ TestGrammarActions grammar_actions(
+ unilib_.get(),
+ flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()));
+
+ {
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0,
+ /*text=*/"Text YES to confirm your subscription"}}},
+ &result));
+ EXPECT_THAT(result, ElementsAre(IsSmartReply("YES")));
+ }
+
+ {
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0,
+ /*text=*/"text Stop to cancel your order"}}},
+ &result));
+ EXPECT_THAT(result, ElementsAre(IsSmartReply("Stop")));
+ }
+}
+
+TEST_F(GrammarActionsTest, ProducesAnnotationsForActions) {
+ // Create test rules: a call_phone action annotating the dialed number.
+ // Rule: please dial <phone>
+ RulesModel_::GrammarRulesT action_grammar_rules;
+ SetTokenizerOptions(&action_grammar_rules);
+ action_grammar_rules.rules.reset(new grammar::RulesSetT);
+ grammar::Rules rules;
+
+ rules.Add(
+ "<call_phone>", {"please", "dial", "<phone>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(
+ GrammarActions::Callback::kActionRuleMatch),
+ /*callback_param=*/
+ AddRuleMatch({AddActionSpec("call_phone", /*response_text=*/"",
+ /*annotations=*/{{0 /*value*/, "phone"}},
+ &action_grammar_rules)},
+ &action_grammar_rules));
+ // phone ::= +00 00 000 00 00
+ rules.AddValueMapping("<phone>",
+ {"+", "<2_digits>", "<2_digits>", "<3_digits>",
+ "<2_digits>", "<2_digits>"},
+ /*value=*/0);
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ action_grammar_rules.rules.get());
+ flatbuffers::DetachedBuffer serialized_rules =
+ PackRules(action_grammar_rules);
+ TestGrammarActions grammar_actions(
+ unilib_.get(),
+ flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()));
+
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"Please dial +41 79 123 45 67"}}},
+ &result));
+
+ EXPECT_THAT(result, ElementsAre(IsActionOfType("call_phone")));
+ EXPECT_THAT(result.front().annotations,
+ ElementsAre(IsActionSuggestionAnnotation(
+ "phone", "+41 79 123 45 67", CodepointSpan{12, 28})));
+}
+
+TEST_F(GrammarActionsTest, HandlesLocales) {
+ // Create test rules; the second rule is in shard 1, tagged "fr" below.
+ // Rule: ^knock knock.?$ -> "Who's there?"
+ RulesModel_::GrammarRulesT action_grammar_rules;
+ SetTokenizerOptions(&action_grammar_rules);
+ action_grammar_rules.rules.reset(new grammar::RulesSetT);
+ grammar::Rules rules(/*num_shards=*/2);
+ rules.Add(
+ "<knock>", {"<^>", "knock", "knock", ".?", "<$>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(
+ GrammarActions::Callback::kActionRuleMatch),
+ /*callback_param=*/
+ AddRuleMatch({AddSmartReplySpec("Who's there?", &action_grammar_rules)},
+ &action_grammar_rules));
+ rules.Add(
+ "<toc>", {"<knock>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(
+ GrammarActions::Callback::kActionRuleMatch),
+ /*callback_param=*/
+ AddRuleMatch({AddSmartReplySpec("Qui est là?", &action_grammar_rules)},
+ &action_grammar_rules),
+ /*max_whitespace_gap=*/-1,
+ /*case_sensitive=*/false,
+ /*shard=*/1);
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ action_grammar_rules.rules.get());
+ // Set locales for rules.
+ action_grammar_rules.rules->rules.back()->locale.emplace_back(
+ new LanguageTagT);
+ action_grammar_rules.rules->rules.back()->locale.back()->language = "fr";
+
+ flatbuffers::DetachedBuffer serialized_rules =
+ PackRules(action_grammar_rules);
+ TestGrammarActions grammar_actions(
+ unilib_.get(),
+ flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()));
+
+ // Check default.
+ {
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"knock knock",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"UTC", /*annotations=*/{},
+ /*detected_text_language_tags=*/"en"}}},
+ &result));
+
+ EXPECT_THAT(result, ElementsAre(IsSmartReply("Who's there?")));
+ }
+
+ // Check fr.
+ {
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"knock knock",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"UTC", /*annotations=*/{},
+ /*detected_text_language_tags=*/"fr-CH"}}},
+ &result));
+
+ EXPECT_THAT(result, ElementsAre(IsSmartReply("Who's there?"),
+ IsSmartReply("Qui est là?")));
+ }
+}
+
+TEST_F(GrammarActionsTest, HandlesAssertions) {
+ // Create test rules with a negative right-context assertion.
+ // Rule: <flight> -> Track flight.
+ RulesModel_::GrammarRulesT action_grammar_rules;
+ SetTokenizerOptions(&action_grammar_rules);
+ action_grammar_rules.rules.reset(new grammar::RulesSetT);
+ grammar::Rules rules;
+ rules.Add("<carrier>", {"lx"});
+ rules.Add("<carrier>", {"aa"});
+ rules.Add("<flight_code>", {"<2_digits>"});
+ rules.Add("<flight_code>", {"<3_digits>"});
+ rules.Add("<flight_code>", {"<4_digits>"});
+
+ // Capture flight code.
+ rules.AddValueMapping("<flight>", {"<carrier>", "<flight_code>"},
+ /*value=*/0);
+
+ // Flight: carrier + flight code and check right context.
+ rules.Add(
+ "<track_flight>", {"<flight>", "<context_assertion>?"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(
+ GrammarActions::Callback::kActionRuleMatch),
+ /*callback_param=*/
+ AddRuleMatch({AddActionSpec("track_flight", /*response_text=*/"",
+ /*annotations=*/{{0 /*value*/, "flight"}},
+ &action_grammar_rules)},
+ &action_grammar_rules));
+
+ // Exclude matches like: LX 38.00 etc.
+ rules.AddAssertion("<context_assertion>", {".?", "<digits>"},
+ /*negative=*/true);
+
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ action_grammar_rules.rules.get());
+
+ flatbuffers::DetachedBuffer serialized_rules =
+ PackRules(action_grammar_rules);
+ TestGrammarActions grammar_actions(
+ unilib_.get(),
+ flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()));
+
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"LX38 aa 44 LX 38.38"}}},
+ &result));
+
+ EXPECT_THAT(result, ElementsAre(IsActionOfType("track_flight"),
+ IsActionOfType("track_flight")));
+ EXPECT_THAT(result[0].annotations,
+ ElementsAre(IsActionSuggestionAnnotation("flight", "LX38",
+ CodepointSpan{0, 4})));
+ EXPECT_THAT(result[1].annotations,
+ ElementsAre(IsActionSuggestionAnnotation("flight", "aa 44",
+ CodepointSpan{5, 10})));
+}
+
+TEST_F(GrammarActionsTest, SetsFixedEntityData) {
+ // Create test rules whose action carries fixed entity data.
+ // Rule: ^hello there$
+ RulesModel_::GrammarRulesT action_grammar_rules;
+ SetTokenizerOptions(&action_grammar_rules);
+ action_grammar_rules.rules.reset(new grammar::RulesSetT);
+ grammar::Rules rules;
+
+ // Create smart reply and static entity data.
+ const int spec_id =
+ AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
+ std::unique_ptr<MutableFlatbuffer> entity_data =
+ entity_data_builder_->NewRoot();
+ entity_data->Set("person", "Kenobi");
+ action_grammar_rules.actions[spec_id]->action->serialized_entity_data =
+ entity_data->Serialize();
+ action_grammar_rules.actions[spec_id]->action->entity_data.reset(
+ new ActionsEntityDataT);
+ action_grammar_rules.actions[spec_id]->action->entity_data->text =
+ "I have the high ground.";
+
+ rules.Add("<greeting>", {"<^>", "hello", "there", "<$>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(
+ GrammarActions::Callback::kActionRuleMatch),
+ /*callback_param=*/
+ AddRuleMatch({spec_id}, &action_grammar_rules));
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ action_grammar_rules.rules.get());
+ flatbuffers::DetachedBuffer serialized_rules =
+ PackRules(action_grammar_rules);
+ TestGrammarActions grammar_actions(
+ unilib_.get(),
+ flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()),
+ entity_data_builder_.get());
+
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));
+
+ // Check the produced smart replies.
+ EXPECT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));
+
+ // Check entity data.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ result[0].serialized_entity_data.data()));
+ EXPECT_THAT(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
+ "I have the high ground.");
+ EXPECT_THAT(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
+ "Kenobi");
+}
+
+TEST_F(GrammarActionsTest, SetsEntityDataFromCapturingMatches) {
+ // Create test rules that fill entity data from capturing matches.
+ // Rule: ^hello (there|here)$
+ RulesModel_::GrammarRulesT action_grammar_rules;
+ SetTokenizerOptions(&action_grammar_rules);
+ action_grammar_rules.rules.reset(new grammar::RulesSetT);
+ grammar::Rules rules;
+
+ // Create smart reply and static entity data.
+ const int spec_id =
+ AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
+ std::unique_ptr<MutableFlatbuffer> entity_data =
+ entity_data_builder_->NewRoot();
+ entity_data->Set("person", "Kenobi");
+ action_grammar_rules.actions[spec_id]->action->serialized_entity_data =
+ entity_data->Serialize();
+
+ // Specify results for capturing matches.
+ const int greeting_match_id = 0;
+ const int location_match_id = 1;
+ {
+ action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
+ new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
+ RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
+ action_grammar_rules.actions[spec_id]->capturing_group.back().get();
+ group->group_id = greeting_match_id;
+ group->entity_field.reset(new FlatbufferFieldPathT);
+ group->entity_field->field.emplace_back(new FlatbufferFieldT);
+ group->entity_field->field.back()->field_name = "greeting";
+ }
+ {
+ action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
+ new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
+ RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
+ action_grammar_rules.actions[spec_id]->capturing_group.back().get();
+ group->group_id = location_match_id;
+ group->entity_field.reset(new FlatbufferFieldPathT);
+ group->entity_field->field.emplace_back(new FlatbufferFieldT);
+ group->entity_field->field.back()->field_name = "location";
+ }
+
+ rules.Add("<location>", {"there"});
+ rules.Add("<location>", {"here"});
+ rules.AddValueMapping("<captured_location>", {"<location>"},
+ /*value=*/location_match_id);
+ rules.AddValueMapping("<greeting>", {"hello", "<captured_location>"},
+ /*value=*/greeting_match_id);
+ rules.Add("<test>", {"<^>", "<greeting>", "<$>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(
+ GrammarActions::Callback::kActionRuleMatch),
+ /*callback_param=*/
+ AddRuleMatch({spec_id}, &action_grammar_rules));
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ action_grammar_rules.rules.get());
+ flatbuffers::DetachedBuffer serialized_rules =
+ PackRules(action_grammar_rules);
+ TestGrammarActions grammar_actions(
+ unilib_.get(),
+ flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()),
+ entity_data_builder_.get());
+
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));
+
+ // Check the produced smart replies.
+ EXPECT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));
+
+ // Check entity data.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ result[0].serialized_entity_data.data()));
+ EXPECT_THAT(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
+ "Hello there");
+ EXPECT_THAT(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
+ "there");
+ EXPECT_THAT(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
+ "Kenobi");
+}
+
+TEST_F(GrammarActionsTest, SetsFixedEntityDataFromCapturingGroups) {
+ // Create test rules with fixed entity data on capturing group 0.
+ // Rule: ^hello there$
+ RulesModel_::GrammarRulesT action_grammar_rules;
+ SetTokenizerOptions(&action_grammar_rules);
+ action_grammar_rules.rules.reset(new grammar::RulesSetT);
+ grammar::Rules rules;
+
+ // Create smart reply.
+ const int spec_id =
+ AddSmartReplySpec("General Kenobi!", &action_grammar_rules);
+ action_grammar_rules.actions[spec_id]->capturing_group.emplace_back(
+ new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
+ RulesModel_::RuleActionSpec_::RuleCapturingGroupT* group =
+ action_grammar_rules.actions[spec_id]->capturing_group.back().get();
+ group->group_id = 0;
+ group->entity_data.reset(new ActionsEntityDataT);
+ group->entity_data->text = "You are a bold one.";
+
+ rules.AddValueMapping("<greeting>", {"<^>", "hello", "there", "<$>"},
+ /*value=*/0);
+ rules.Add("<test>", {"<greeting>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(
+ GrammarActions::Callback::kActionRuleMatch),
+ /*callback_param=*/
+ AddRuleMatch({spec_id}, &action_grammar_rules));
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ action_grammar_rules.rules.get());
+ flatbuffers::DetachedBuffer serialized_rules =
+ PackRules(action_grammar_rules);
+ TestGrammarActions grammar_actions(
+ unilib_.get(),
+ flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()),
+ entity_data_builder_.get());
+
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"Hello there"}}}, &result));
+
+ // Check the produced smart replies.
+ EXPECT_THAT(result, ElementsAre(IsSmartReply("General Kenobi!")));
+
+ // Check entity data.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ result[0].serialized_entity_data.data()));
+ EXPECT_THAT(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
+ "You are a bold one.");
+}
+
+TEST_F(GrammarActionsTest, ProducesActionsWithAnnotations) {
+ // Create test rules that match on an input "phone" annotation.
+ // Rule: please dial <phone>
+ RulesModel_::GrammarRulesT action_grammar_rules;
+ SetTokenizerOptions(&action_grammar_rules);
+ action_grammar_rules.rules.reset(new grammar::RulesSetT);
+ grammar::Rules rules;
+ rules.Add("<call_phone>", {"please", "dial", "<phone>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(
+ GrammarActions::Callback::kActionRuleMatch),
+ /*callback_param=*/
+ AddRuleMatch({AddActionSpec("call_phone", /*response_text=*/"",
+ /*annotations=*/
+ {{0 /*value*/, "phone",
+ /*use_annotation_match=*/true}},
+ &action_grammar_rules)},
+ &action_grammar_rules));
+ rules.AddValueMapping("<phone>", {"<phone_annotation>"},
+ /*value=*/0);
+
+ grammar::Ir ir = rules.Finalize(
+ /*predefined_nonterminals=*/{"<phone_annotation>"});
+ ir.Serialize(/*include_debug_information=*/false,
+ action_grammar_rules.rules.get());
+
+ // Map "phone" annotation to "<phone_annotation>" nonterminal.
+ action_grammar_rules.rules->nonterminals->annotation_nt.emplace_back(
+ new grammar::RulesSet_::Nonterminals_::AnnotationNtEntryT);
+ action_grammar_rules.rules->nonterminals->annotation_nt.back()->key = "phone";
+ action_grammar_rules.rules->nonterminals->annotation_nt.back()->value =
+ ir.GetNonterminalForName("<phone_annotation>");
+
+ flatbuffers::DetachedBuffer serialized_rules =
+ PackRules(action_grammar_rules);
+ TestGrammarActions grammar_actions(
+ unilib_.get(),
+ flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()));
+
+ std::vector<ActionSuggestion> result;
+
+ // Sanity check that no results are produced when no annotations are provided.
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"Please dial +41 79 123 45 67"}}},
+ &result));
+ EXPECT_THAT(result, IsEmpty());
+
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{
+ {/*user_id=*/0,
+ /*text=*/"Please dial +41 79 123 45 67",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"UTC",
+ /*annotations=*/
+ {{CodepointSpan{12, 28}, {ClassificationResult{"phone", 1.0}}}}}}},
+ &result));
+ EXPECT_THAT(result, ElementsAre(IsActionOfType("call_phone")));
+ EXPECT_THAT(result.front().annotations,
+ ElementsAre(IsActionSuggestionAnnotation(
+ "phone", "+41 79 123 45 67", CodepointSpan{12, 28})));
+}
+
+TEST_F(GrammarActionsTest, HandlesExclusions) {
+ // Create test rules: a reminder rule whose two tokens must not be "be safe".
+ RulesModel_::GrammarRulesT action_grammar_rules;
+ SetTokenizerOptions(&action_grammar_rules);
+ action_grammar_rules.rules.reset(new grammar::RulesSetT);
+
+ grammar::Rules rules;
+ rules.Add("<excluded>", {"be", "safe"});
+ rules.AddWithExclusion("<tokens_but_not_excluded>", {"<token>", "<token>"},
+ /*excluded_nonterminal=*/"<excluded>");
+
+ rules.Add("<set_reminder>",
+ {"do", "not", "forget", "to", "<tokens_but_not_excluded>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(
+ GrammarActions::Callback::kActionRuleMatch),
+ /*callback_param=*/
+ AddRuleMatch({AddActionSpec("set_reminder", /*response_text=*/"",
+ /*annotations=*/
+ {}, &action_grammar_rules)},
+ &action_grammar_rules));
+
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ action_grammar_rules.rules.get());
+ flatbuffers::DetachedBuffer serialized_rules =
+ PackRules(action_grammar_rules);
+ TestGrammarActions grammar_actions(
+ unilib_.get(),
+ flatbuffers::GetRoot<RulesModel_::GrammarRules>(serialized_rules.data()),
+ entity_data_builder_.get());
+
+ {
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{
+ {/*user_id=*/0, /*text=*/"do not forget to bring milk"}}},
+ &result));
+ EXPECT_THAT(result, ElementsAre(IsActionOfType("set_reminder")));
+ }
+
+ {
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"do not forget to be there!"}}},
+ &result));
+ EXPECT_THAT(result, ElementsAre(IsActionOfType("set_reminder")));
+ }
+
+ {
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{
+ {/*user_id=*/0, /*text=*/"do not forget to buy safe or vault!"}}},
+ &result));
+ EXPECT_THAT(result, ElementsAre(IsActionOfType("set_reminder")));
+ }
+
+ {
+ std::vector<ActionSuggestion> result;
+ EXPECT_TRUE(grammar_actions.SuggestActions(
+ {/*messages=*/{{/*user_id=*/0, /*text=*/"do not forget to be safe!"}}},
+ &result));
+ EXPECT_THAT(result, IsEmpty());
+ }
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/actions/ranker.cc b/native/actions/ranker.cc
index 5a03da5..d52ecaa 100644
--- a/native/actions/ranker.cc
+++ b/native/actions/ranker.cc
@@ -20,11 +20,15 @@
#include <set>
#include <vector>
+#if !defined(TC3_DISABLE_LUA)
#include "actions/lua-ranker.h"
+#endif
#include "actions/zlib-utils.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
+#if !defined(TC3_DISABLE_LUA)
#include "utils/lua-utils.h"
+#endif
namespace libtextclassifier3 {
namespace {
@@ -215,6 +219,7 @@
return false;
}
+#if !defined(TC3_DISABLE_LUA)
std::string lua_ranking_script;
if (GetUncompressedString(options_->lua_ranking_script(),
options_->compressed_lua_ranking_script(),
@@ -225,6 +230,7 @@
return false;
}
}
+#endif
return true;
}
@@ -336,6 +342,7 @@
SortByScoreAndType(&response->actions);
}
+#if !defined(TC3_DISABLE_LUA)
// Run lua ranking snippet, if provided.
if (!lua_bytecode_.empty()) {
auto lua_ranker = ActionsSuggestionsLuaRanker::Create(
@@ -346,6 +353,7 @@
return false;
}
}
+#endif
return true;
}
diff --git a/native/actions/test_data/actions_suggestions_test.hashgram.model b/native/actions/test_data/actions_suggestions_test.hashgram.model
old mode 100755
new mode 100644
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model
old mode 100755
new mode 100644
index 97a0a14..5d265c1
--- a/native/actions/test_data/actions_suggestions_test.model
+++ b/native/actions/test_data/actions_suggestions_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
old mode 100755
new mode 100644
index 84be451..e6d8758
--- a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
new file mode 100644
index 0000000..708b0be
--- /dev/null
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
Binary files differ
diff --git a/native/actions/types.h b/native/actions/types.h
index c971529..8862262 100644
--- a/native/actions/types.h
+++ b/native/actions/types.h
@@ -75,7 +75,7 @@
// Extras information.
std::string serialized_entity_data;
- const ActionsEntityData* entity_data() {
+ const ActionsEntityData* entity_data() const {
return LoadAndVerifyFlatbuffer<ActionsEntityData>(
serialized_entity_data.data(), serialized_entity_data.size());
}
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc
index 2dc9b5c..a2d8281 100644
--- a/native/annotator/annotator.cc
+++ b/native/annotator/annotator.cc
@@ -190,8 +190,32 @@
return nullptr;
}
- auto classifier =
- std::unique_ptr<Annotator>(new Annotator(model, unilib, calendarlib));
+ auto classifier = std::unique_ptr<Annotator>(new Annotator());
+ unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
+ calendarlib =
+ MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
+ classifier->ValidateAndInitialize(model, unilib, calendarlib);
+ if (!classifier->IsInitialized()) {
+ return nullptr;
+ }
+
+ return classifier;
+}
+
+std::unique_ptr<Annotator> Annotator::FromString(
+ const std::string& buffer, const UniLib* unilib,
+ const CalendarLib* calendarlib) {
+ auto classifier = std::unique_ptr<Annotator>(new Annotator());
+ classifier->owned_buffer_ = buffer;
+ const Model* model = LoadAndVerifyModel(classifier->owned_buffer_.data(),
+ classifier->owned_buffer_.size());
+ if (model == nullptr) {
+ return nullptr;
+ }
+ unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
+ calendarlib =
+ MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
+ classifier->ValidateAndInitialize(model, unilib, calendarlib);
if (!classifier->IsInitialized()) {
return nullptr;
}
@@ -214,8 +238,12 @@
return nullptr;
}
- auto classifier = std::unique_ptr<Annotator>(
- new Annotator(mmap, model, unilib, calendarlib));
+ auto classifier = std::unique_ptr<Annotator>(new Annotator());
+ classifier->mmap_ = std::move(*mmap);
+ unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
+ calendarlib =
+ MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
+ classifier->ValidateAndInitialize(model, unilib, calendarlib);
if (!classifier->IsInitialized()) {
return nullptr;
}
@@ -238,8 +266,12 @@
return nullptr;
}
- auto classifier = std::unique_ptr<Annotator>(
- new Annotator(mmap, model, std::move(unilib), std::move(calendarlib)));
+ auto classifier = std::unique_ptr<Annotator>(new Annotator());
+ classifier->mmap_ = std::move(*mmap);
+ classifier->owned_unilib_ = std::move(unilib);
+ classifier->owned_calendarlib_ = std::move(calendarlib);
+ classifier->ValidateAndInitialize(model, classifier->owned_unilib_.get(),
+ classifier->owned_calendarlib_.get());
if (!classifier->IsInitialized()) {
return nullptr;
}
@@ -288,40 +320,12 @@
return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
}
-Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
- const UniLib* unilib, const CalendarLib* calendarlib)
- : model_(model),
- mmap_(std::move(*mmap)),
- owned_unilib_(nullptr),
- unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
- owned_calendarlib_(nullptr),
- calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
- ValidateAndInitialize();
-}
+void Annotator::ValidateAndInitialize(const Model* model, const UniLib* unilib,
+ const CalendarLib* calendarlib) {
+ model_ = model;
+ unilib_ = unilib;
+ calendarlib_ = calendarlib;
-Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
- std::unique_ptr<UniLib> unilib,
- std::unique_ptr<CalendarLib> calendarlib)
- : model_(model),
- mmap_(std::move(*mmap)),
- owned_unilib_(std::move(unilib)),
- unilib_(owned_unilib_.get()),
- owned_calendarlib_(std::move(calendarlib)),
- calendarlib_(owned_calendarlib_.get()) {
- ValidateAndInitialize();
-}
-
-Annotator::Annotator(const Model* model, const UniLib* unilib,
- const CalendarLib* calendarlib)
- : model_(model),
- owned_unilib_(nullptr),
- unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
- owned_calendarlib_(nullptr),
- calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
- ValidateAndInitialize();
-}
-
-void Annotator::ValidateAndInitialize() {
initialized_ = false;
if (model_ == nullptr) {
@@ -512,10 +516,20 @@
unilib_, model_->grammar_model(), entity_data_builder_.get()));
}
+ // The following #ifdef is here to aid quality evaluation of a situation, when
+ // a POD NER kill switch in AiAi is invoked, when a model that has POD NER in
+ // it.
+#if !defined(TC3_DISABLE_POD_NER)
if (model_->pod_ner_model()) {
pod_ner_annotator_ =
PodNerAnnotator::Create(model_->pod_ner_model(), *unilib_);
}
+#endif
+
+ if (model_->vocab_model()) {
+ vocab_annotator_ = VocabAnnotator::Create(
+ model_->vocab_model(), *selection_feature_processor_, *unilib_);
+ }
if (model_->entity_data_schema()) {
entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
@@ -709,7 +723,6 @@
if (ExperimentalAnnotator::IsEnabled()) {
experimental_annotator_.reset(new ExperimentalAnnotator(
model_->experimental_model(), *selection_feature_processor_, *unilib_));
-
return true;
}
return false;
@@ -1891,6 +1904,15 @@
candidates.push_back({selection_indices, {pod_ner_annotator_result}});
}
+ ClassificationResult vocab_annotator_result;
+ if (vocab_annotator_ &&
+ vocab_annotator_->ClassifyText(
+ context_unicode, selection_indices, detected_text_language_tags,
+ options.trigger_dictionary_on_beginner_words,
+ &vocab_annotator_result)) {
+ candidates.push_back({selection_indices, {vocab_annotator_result}});
+ }
+
if (experimental_annotator_) {
experimental_annotator_->ClassifyText(context_unicode, selection_indices,
candidates);
@@ -2221,6 +2243,14 @@
return Status(StatusCode::INTERNAL, "Couldn't run POD NER annotator.");
}
+ // Annotate with the vocab annotator.
+ if (vocab_annotator_ != nullptr &&
+ !vocab_annotator_->Annotate(context_unicode, detected_text_language_tags,
+ options.trigger_dictionary_on_beginner_words,
+ candidates)) {
+ return Status(StatusCode::INTERNAL, "Couldn't run vocab annotator.");
+ }
+
// Annotate with the experimental annotator.
if (experimental_annotator_ != nullptr &&
!experimental_annotator_->Annotate(context_unicode, candidates)) {
@@ -2638,13 +2668,16 @@
std::string quantity;
GetMoneyQuantityFromCapturingGroup(match, config, context_unicode,
&quantity, &quantity_exponent);
- if (quantity_exponent != 0) {
+ if ((quantity_exponent > 0 && quantity_exponent < 9) ||
+ (quantity_exponent == 9 && data->money->amount_whole_part <= 2)) {
data->money->amount_whole_part =
data->money->amount_whole_part * pow(10, quantity_exponent) +
data->money->nanos / pow(10, 9 - quantity_exponent);
data->money->nanos = data->money->nanos %
static_cast<int>(pow(10, 9 - quantity_exponent)) *
pow(10, quantity_exponent);
+ }
+ if (quantity_exponent > 0) {
data->money->unnormalized_amount = strings::JoinStrings(
" ", {data->money->unnormalized_amount, quantity});
}
diff --git a/native/annotator/annotator.h b/native/annotator/annotator.h
index 67f10d3..f55be4d 100644
--- a/native/annotator/annotator.h
+++ b/native/annotator/annotator.h
@@ -42,6 +42,7 @@
#include "annotator/strip-unpaired-brackets.h"
#include "annotator/translate/translate.h"
#include "annotator/types.h"
+#include "annotator/vocab/vocab-annotator.h"
#include "annotator/zlib-utils.h"
#include "utils/base/status.h"
#include "utils/base/statusor.h"
@@ -107,6 +108,10 @@
static std::unique_ptr<Annotator> FromUnownedBuffer(
const char* buffer, int size, const UniLib* unilib = nullptr,
const CalendarLib* calendarlib = nullptr);
+ // Copies the underlying model buffer string.
+ static std::unique_ptr<Annotator> FromString(
+ const std::string& buffer, const UniLib* unilib = nullptr,
+ const CalendarLib* calendarlib = nullptr);
// Takes ownership of the mmap.
static std::unique_ptr<Annotator> FromScopedMmap(
std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib = nullptr,
@@ -236,22 +241,14 @@
float score;
};
- // Constructs and initializes text classifier from given model.
- // Takes ownership of 'mmap', and thus owns the buffer that backs 'model'.
- Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
- const UniLib* unilib, const CalendarLib* calendarlib);
- Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
- std::unique_ptr<UniLib> unilib,
- std::unique_ptr<CalendarLib> calendarlib);
-
- // Constructs, validates and initializes text classifier from given model.
- // Does not own the buffer that backs 'model'.
- Annotator(const Model* model, const UniLib* unilib,
- const CalendarLib* calendarlib);
+ // NOTE: ValidateAndInitialize needs to be called before any other method.
+ Annotator() : initialized_(false) {}
// Checks that model contains all required fields, and initializes internal
// datastructures.
- void ValidateAndInitialize();
+ // Needs to be called before any other method is.
+ void ValidateAndInitialize(const Model* model, const UniLib* unilib,
+ const CalendarLib* calendarlib);
// Initializes regular expressions for the regex model.
bool InitializeRegexModel(ZlibDecompressor* decompressor);
@@ -447,6 +444,10 @@
std::unique_ptr<const GrammarAnnotator> grammar_annotator_;
+ std::string owned_buffer_;
+ std::unique_ptr<UniLib> owned_unilib_;
+ std::unique_ptr<CalendarLib> owned_calendarlib_;
+
private:
struct CompiledRegexPattern {
const RegexModel_::Pattern* config;
@@ -499,9 +500,7 @@
std::vector<int> annotation_regex_patterns_, classification_regex_patterns_,
selection_regex_patterns_;
- std::unique_ptr<UniLib> owned_unilib_;
const UniLib* unilib_;
- std::unique_ptr<CalendarLib> owned_calendarlib_;
const CalendarLib* calendarlib_;
std::unique_ptr<const KnowledgeEngine> knowledge_engine_;
@@ -513,6 +512,7 @@
std::unique_ptr<const TranslateAnnotator> translate_annotator_;
std::unique_ptr<const PodNerAnnotator> pod_ner_annotator_;
std::unique_ptr<const ExperimentalAnnotator> experimental_annotator_;
+ std::unique_ptr<const VocabAnnotator> vocab_annotator_;
// Builder for creating extra data.
const reflection::Schema* entity_data_schema_;
diff --git a/native/annotator/annotator_jni_common.cc b/native/annotator/annotator_jni_common.cc
index 155e038..a049a22 100644
--- a/native/annotator/annotator_jni_common.cc
+++ b/native/annotator/annotator_jni_common.cc
@@ -261,6 +261,16 @@
classifier_options.user_familiar_language_tags,
JStringToUtf8String(env, user_familiar_language_tags.get()));
+ // .getTriggerDictionaryOnBeginnerWords()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_trigger_dictionary_on_beginner_words,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getTriggerDictionaryOnBeginnerWords", "()Z"));
+ TC3_ASSIGN_OR_RETURN(
+ classifier_options.trigger_dictionary_on_beginner_words,
+ JniHelper::CallBooleanMethod(env, joptions,
+ get_trigger_dictionary_on_beginner_words));
+
return classifier_options;
}
@@ -335,6 +345,16 @@
annotation_options.permissions.has_personalization_permission =
has_personalization_permission;
annotation_options.annotate_mode = static_cast<AnnotateMode>(annotate_mode);
+
+ // .getTriggerDictionaryOnBeginnerWords()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_trigger_dictionary_on_beginner_words,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getTriggerDictionaryOnBeginnerWords", "()Z"));
+ TC3_ASSIGN_OR_RETURN(
+ annotation_options.trigger_dictionary_on_beginner_words,
+ JniHelper::CallBooleanMethod(env, joptions,
+ get_trigger_dictionary_on_beginner_words));
return annotation_options;
}
diff --git a/native/annotator/annotator_test-include.cc b/native/annotator/annotator_test-include.cc
new file mode 100644
index 0000000..a8fda33
--- /dev/null
+++ b/native/annotator/annotator_test-include.cc
@@ -0,0 +1,3012 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/annotator_test-include.h"
+
+#include <iostream>
+#include <memory>
+#include <string>
+#include <type_traits>
+
+#include "annotator/annotator.h"
+#include "annotator/model_generated.h"
+#include "annotator/test-utils.h"
+#include "annotator/types-test-util.h"
+#include "utils/grammar/utils/rules.h"
+#include "utils/testing/annotator.h"
+#include "lang_id/fb_model/lang-id-from-fb.h"
+#include "lang_id/lang-id.h"
+
+namespace libtextclassifier3 {
+namespace test_internal {
+
+using ::testing::Contains;
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+using ::testing::Eq;
+using ::testing::IsEmpty;
+using ::testing::UnorderedElementsAreArray;
+
+std::string GetTestModelPath() { return GetModelPath() + "test_model.fb"; }
+
+std::string GetModelWithGrammarPath() {
+ return GetModelPath() + "test_grammar_model.fb";
+}
+
+void FillDatetimeAnnotationOptionsToModel(
+ ModelT* unpacked_model, const DateAnnotationOptions& options) {
+ unpacked_model->grammar_datetime_model->annotation_options.reset(
+ new GrammarDatetimeModel_::AnnotationOptionsT);
+ unpacked_model->grammar_datetime_model->annotation_options
+ ->enable_date_range = options.enable_date_range;
+ unpacked_model->grammar_datetime_model->annotation_options
+ ->include_preposition = options.include_preposition;
+ unpacked_model->grammar_datetime_model->annotation_options
+ ->merge_adjacent_components = options.merge_adjacent_components;
+ unpacked_model->grammar_datetime_model->annotation_options
+ ->enable_special_day_offset = options.enable_special_day_offset;
+ for (const auto& extra_dates_rule_id : options.extra_requested_dates) {
+ unpacked_model->grammar_datetime_model->annotation_options
+ ->extra_requested_dates.push_back(extra_dates_rule_id);
+ }
+}
+
+std::string GetGrammarModel(const DateAnnotationOptions& options) {
+ const std::string test_model = ReadFile(GetModelWithGrammarPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ FillDatetimeAnnotationOptionsToModel(unpacked_model.get(), options);
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+std::string GetGrammarModel() {
+ DateAnnotationOptions options;
+ return GetGrammarModel(options);
+}
+
+void ExpectFirstEntityIsMoney(const std::vector<AnnotatedSpan>& result,
+ const std::string& currency,
+ const std::string& amount, const int whole_part,
+ const int decimal_part, const int nanos) {
+ ASSERT_GT(result.size(), 0);
+ ASSERT_GT(result[0].classification.size(), 0);
+ ASSERT_EQ(result[0].classification[0].collection, "money");
+
+ const EntityData* entity_data =
+ GetEntityData(result[0].classification[0].serialized_entity_data.data());
+ ASSERT_NE(entity_data, nullptr);
+ ASSERT_NE(entity_data->money(), nullptr);
+ EXPECT_EQ(entity_data->money()->unnormalized_currency()->str(), currency);
+ EXPECT_EQ(entity_data->money()->unnormalized_amount()->str(), amount);
+ EXPECT_EQ(entity_data->money()->amount_whole_part(), whole_part);
+ EXPECT_EQ(entity_data->money()->amount_decimal_part(), decimal_part);
+ EXPECT_EQ(entity_data->money()->nanos(), nanos);
+}
+
+TEST_F(AnnotatorTest, EmbeddingExecutorLoadingFails) {
+ std::unique_ptr<Annotator> classifier =
+ Annotator::FromPath(GetModelPath() + "wrong_embeddings.fb", unilib_.get(),
+ calendarlib_.get());
+ EXPECT_FALSE(classifier);
+}
+
+void VerifyClassifyText(const Annotator* classifier) {
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("other",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at", {15, 27})));
+ EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
+ "Call me at (800) 123-456 today", {11, 24})));
+
+ // More lines.
+ EXPECT_EQ("other",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at|Visit "
+ "www.google.com every today!|Call me at (800) 123-456 today.",
+ {15, 27})));
+ EXPECT_EQ("phone",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at|Visit "
+ "www.google.com every today!|Call me at (800) 123-456 today.",
+ {90, 103})));
+
+ // Single word.
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText("obama", {0, 5})));
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText("asdf", {0, 4})));
+
+ // Junk. These should not crash the test.
+ classifier->ClassifyText("", {0, 0});
+ classifier->ClassifyText("asdf", {0, 0});
+ classifier->ClassifyText("asdf", {0, 27});
+ classifier->ClassifyText("asdf", {-30, 300});
+ classifier->ClassifyText("asdf", {-10, -1});
+ classifier->ClassifyText("asdf", {100, 17});
+ classifier->ClassifyText("a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5});
+
+ // Test invalid utf8 input.
+ EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
+ "\xf0\x9f\x98\x8b\x8b", {0, 0})));
+}
+
+TEST_F(AnnotatorTest, ClassifyText) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ VerifyClassifyText(classifier.get());
+}
+
+TEST_F(AnnotatorTest, ClassifyTextWithGrammar) {
+ const std::string grammar_model = GetGrammarModel();
+ std::unique_ptr<Annotator> classifier =
+ Annotator::FromUnownedBuffer(grammar_model.c_str(), grammar_model.size(),
+ unilib_.get(), calendarlib_.get());
+ VerifyClassifyText(std::move(classifier.get()));
+}
+
+TEST_F(AnnotatorTest, ClassifyTextLocalesAndDictionary) {
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText("isotope", {0, 6})));
+
+ ClassificationOptions classification_options;
+ classification_options.detected_text_language_tags = "en";
+ EXPECT_EQ("dictionary", FirstResult(classifier->ClassifyText(
+ "isotope", {0, 6}, classification_options)));
+
+ classification_options.detected_text_language_tags = "uz";
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
+ "isotope", {0, 6}, classification_options)));
+}
+
+TEST_F(AnnotatorTest, ClassifyTextDisabledFail) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ TC3_CHECK(unpacked_model != nullptr);
+
+ unpacked_model->classification_model.clear();
+ unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+ unpacked_model->triggering_options->enabled_modes = ModeFlag_SELECTION;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+
+ // The classification model is still needed for selection scores.
+ ASSERT_FALSE(classifier);
+}
+
+TEST_F(AnnotatorTest, ClassifyTextDisabled) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ unpacked_model->enabled_modes = ModeFlag_ANNOTATION_AND_SELECTION;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_THAT(
+ classifier->ClassifyText("Call me at (800) 123-456 today", {11, 24}),
+ IsEmpty());
+}
+
+TEST_F(AnnotatorTest, ClassifyTextFilteredCollections) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ test_model.c_str(), test_model.size(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
+ "Call me at (800) 123-456 today", {11, 24})));
+
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ unpacked_model->output_options.reset(new OutputOptionsT);
+
+ // Disable phone classification
+ unpacked_model->output_options->filtered_collections_classification.push_back(
+ "phone");
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
+ "Call me at (800) 123-456 today", {11, 24})));
+
+ // Check that the address classification still passes.
+ EXPECT_EQ("address", FirstResult(classifier->ClassifyText(
+ "350 Third Street, Cambridge", {0, 27})));
+}
+
+TEST_F(AnnotatorTest, ClassifyTextRegularExpression) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "person", "Barack Obama", /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0));
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "flight", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 0.5));
+ std::unique_ptr<RegexModel_::PatternT> verified_pattern =
+ MakePattern("payment_card", "\\d{4}(?: \\d{4}){3}",
+ /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/false,
+ /*enabled_for_annotation=*/false, 1.0);
+ verified_pattern->verification_options.reset(new VerificationOptionsT);
+ verified_pattern->verification_options->verify_luhn_checksum = true;
+ unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("flight",
+ FirstResult(classifier->ClassifyText(
+ "Your flight LX373 is delayed by 3 hours.", {12, 17})));
+ EXPECT_EQ("person",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at", {15, 27})));
+ EXPECT_EQ("email",
+ FirstResult(classifier->ClassifyText("you@android.com", {0, 15})));
+ EXPECT_EQ("email", FirstResult(classifier->ClassifyText(
+ "Contact me at you@android.com", {14, 29})));
+
+ EXPECT_EQ("url", FirstResult(classifier->ClassifyText(
+ "Visit www.google.com every today!", {6, 20})));
+
+ EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("LX 37", {0, 5})));
+ EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("flight LX 37 abcd",
+ {7, 12})));
+ EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText(
+ "cc: 4012 8888 8888 1881", {4, 23})));
+ EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText(
+ "2221 0067 4735 6281", {0, 19})));
+ // Luhn check fails.
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText("2221 0067 4735 6282",
+ {0, 19})));
+
+ // More lines.
+ EXPECT_EQ("url",
+ FirstResult(classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at|Visit "
+ "www.google.com every today!|Call me at (800) 123-456 today.",
+ {51, 65})));
+}
+
+#ifndef TC3_DISABLE_LUA
+TEST_F(AnnotatorTest, ClassifyTextRegularExpressionLuaVerification) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ std::unique_ptr<RegexModel_::PatternT> verified_pattern =
+ MakePattern("parcel_tracking", "((\\d{2})-00-\\d{6}-\\d{8})",
+ /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/false,
+ /*enabled_for_annotation=*/false, 1.0);
+ verified_pattern->verification_options.reset(new VerificationOptionsT);
+ verified_pattern->verification_options->lua_verifier = 0;
+ unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
+ unpacked_model->regex_model->lua_verifier.push_back(
+ "return match[2].text==\"99\"");
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ // Custom rule triggers and is correctly verified.
+ EXPECT_EQ("parcel_tracking", FirstResult(classifier->ClassifyText(
+ "99-00-123456-12345678", {0, 21})));
+
+ // Custom verification fails.
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
+ "90-00-123456-12345678", {0, 21})));
+}
+#endif // TC3_DISABLE_LUA
+
+TEST_F(AnnotatorTest, ClassifyTextRegularExpressionEntityData) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add fake entity schema metadata.
+ AddTestEntitySchemaData(unpacked_model.get());
+
+ AddTestRegexModel(unpacked_model.get());
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ // Check with full name.
+ {
+ auto classifications =
+ classifier->ClassifyText("Barack Obama is 57 years old", {0, 28});
+ EXPECT_EQ(1, classifications.size());
+ EXPECT_EQ("person_with_age", classifications[0].collection);
+
+ // Check entity data.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ classifications[0].serialized_entity_data.data()));
+ EXPECT_EQ(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
+ "Barack");
+ EXPECT_EQ(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
+ "Obama");
+ // Check `age`.
+ EXPECT_EQ(entity->GetField<int>(/*field=*/10, /*defaultval=*/0), 57);
+
+ // Check `is_alive`.
+ EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));
+
+ // Check `former_us_president`.
+ EXPECT_TRUE(entity->GetField<bool>(/*field=*/12, /*defaultval=*/false));
+ }
+
+ // Check only with first name.
+ {
+ auto classifications =
+ classifier->ClassifyText("Barack is 57 years old", {0, 22});
+ EXPECT_EQ(1, classifications.size());
+ EXPECT_EQ("person_with_age", classifications[0].collection);
+
+ // Check entity data.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ classifications[0].serialized_entity_data.data()));
+ EXPECT_EQ(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
+ "Barack");
+
+ // Check `age`.
+ EXPECT_EQ(entity->GetField<int>(/*field=*/10, /*defaultval=*/0), 57);
+
+ // Check `is_alive`.
+ EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));
+
+ // Check `former_us_president`.
+ EXPECT_FALSE(entity->GetField<bool>(/*field=*/12, /*defaultval=*/false));
+ }
+}
+
+TEST_F(AnnotatorTest, ClassifyTextRegularExpressionEntityDataNormalization) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add fake entity schema metadata.
+ AddTestEntitySchemaData(unpacked_model.get());
+
+ AddTestRegexModel(unpacked_model.get());
+
+ // Upper case last name as post-processing.
+ RegexModel_::PatternT* pattern =
+ unpacked_model->regex_model->patterns.back().get();
+ pattern->capturing_group[2]->normalization_options.reset(
+ new NormalizationOptionsT);
+ pattern->capturing_group[2]
+ ->normalization_options->codepointwise_normalization =
+ NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ auto classifications =
+ classifier->ClassifyText("Barack Obama is 57 years old", {0, 28});
+ EXPECT_EQ(1, classifications.size());
+ EXPECT_EQ("person_with_age", classifications[0].collection);
+
+ // Check entity data normalization.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ classifications[0].serialized_entity_data.data()));
+ EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
+ "OBAMA");
+}
+
+TEST_F(AnnotatorTest, ClassifyTextPriorityResolution) {
+  // Two classification-only patterns match the same flight number; the one
+  // with the higher priority_score is expected to be the first result.
+  const std::string model_bytes = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());
+  TC3_CHECK(libtextclassifier3::DecompressModel(model.get()));
+
+  const std::string flight_regex = "[a-zA-Z]{2}\\d{2,4}";
+  auto& patterns = model->regex_model->patterns;
+  patterns.clear();
+  patterns.push_back(MakePattern(
+      "flight1", flight_regex, /*enabled_for_classification=*/true,
+      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false,
+      /*score=*/1.0, /*priority_score=*/1.0));
+  patterns.push_back(MakePattern(
+      "flight2", flight_regex, /*enabled_for_classification=*/true,
+      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false,
+      /*score=*/1.0, /*priority_score=*/0.0));
+
+  const std::string text = "Your flight LX373 is delayed by 3 hours.";
+  const CodepointSpan flight_span(12, 17);
+
+  // "flight1" (priority 1.0) outranks "flight2" (priority 0.0).
+  {
+    flatbuffers::FlatBufferBuilder builder;
+    FinishModelBuffer(builder, Model::Pack(builder, model.get()));
+    std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+        reinterpret_cast<const char*>(builder.GetBufferPointer()),
+        builder.GetSize(), unilib_.get(), calendarlib_.get());
+    ASSERT_TRUE(annotator);
+    EXPECT_EQ("flight1",
+              FirstResult(annotator->ClassifyText(text, flight_span)));
+  }
+
+  // Raising "flight2" to priority 3.0 flips the outcome.
+  patterns.back()->priority_score = 3.0;
+  {
+    flatbuffers::FlatBufferBuilder builder;
+    FinishModelBuffer(builder, Model::Pack(builder, model.get()));
+    std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+        reinterpret_cast<const char*>(builder.GetBufferPointer()),
+        builder.GetSize(), unilib_.get(), calendarlib_.get());
+    ASSERT_TRUE(annotator);
+    EXPECT_EQ("flight2",
+              FirstResult(annotator->ClassifyText(text, flight_span)));
+  }
+}
+
+TEST_F(AnnotatorTest, AnnotatePriorityResolution) {
+  const std::string test_model = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+  TC3_CHECK(libtextclassifier3::DecompressModel(unpacked_model.get()));
+  // Add two regex patterns that match the same text and share the collection
+  // name "flight" but differ in priority score. Because the collection names
+  // are identical, the winning pattern is identified by the priority score it
+  // reports.
+  unpacked_model->regex_model->patterns.clear();
+  const std::string flight_regex = "([a-zA-Z]{2}\\d{2,4})";
+  unpacked_model->regex_model->patterns.push_back(MakePattern(
+      "flight", flight_regex, /*enabled_for_classification=*/true,
+      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true,
+      /*score=*/1.0, /*priority_score=*/1.0));
+  unpacked_model->regex_model->patterns.push_back(MakePattern(
+      "flight", flight_regex, /*enabled_for_classification=*/true,
+      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true,
+      /*score=*/1.0, /*priority_score=*/0.0));
+
+  // Priorities 1.0 vs 0.0: the "flight" that wins should report 1.0.
+  {
+    flatbuffers::FlatBufferBuilder builder;
+    FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+    std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+        reinterpret_cast<const char*>(builder.GetBufferPointer()),
+        builder.GetSize(), unilib_.get(), calendarlib_.get());
+    ASSERT_TRUE(classifier);
+
+    const std::vector<AnnotatedSpan> results =
+        classifier->Annotate("Your flight LX373 is delayed by 3 hours.");
+    ASSERT_THAT(results, Not(IsEmpty()));
+    // ASSERT rather than EXPECT: classification[0] is read on the next line,
+    // so an empty list must abort the test instead of indexing out of bounds.
+    ASSERT_THAT(results[0].classification, Not(IsEmpty()));
+    EXPECT_GE(results[0].classification[0].priority_score, 0.9);
+  }
+
+  // Raising the second pattern to 3.0 makes it the winner, so the reported
+  // priority score should now be 3.0.
+  unpacked_model->regex_model->patterns.back()->priority_score = 3.0;
+  {
+    flatbuffers::FlatBufferBuilder builder;
+    FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+    std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+        reinterpret_cast<const char*>(builder.GetBufferPointer()),
+        builder.GetSize(), unilib_.get(), calendarlib_.get());
+    ASSERT_TRUE(classifier);
+
+    const std::vector<AnnotatedSpan> results =
+        classifier->Annotate("Your flight LX373 is delayed by 3 hours.");
+    ASSERT_THAT(results, Not(IsEmpty()));
+    // ASSERT rather than EXPECT: classification[0] is read on the next line.
+    ASSERT_THAT(results[0].classification, Not(IsEmpty()));
+    EXPECT_GE(results[0].classification[0].priority_score, 2.9);
+  }
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionRegularExpression) {
+  // Selection-only regex patterns (person, flight, Luhn-verified payment card)
+  // are appended to the test model. A partial click on each entity is
+  // expected to expand to the pattern's capturing group.
+  const std::string model_bytes = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());
+  auto& patterns = model->regex_model->patterns;
+
+  patterns.push_back(MakePattern(
+      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
+      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+
+  std::unique_ptr<RegexModel_::PatternT> flight_pattern = MakePattern(
+      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
+      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0);
+  flight_pattern->priority_score = 1.1;
+  patterns.push_back(std::move(flight_pattern));
+
+  std::unique_ptr<RegexModel_::PatternT> card_pattern =
+      MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
+                  /*enabled_for_classification=*/false,
+                  /*enabled_for_selection=*/true,
+                  /*enabled_for_annotation=*/false, 1.0);
+  card_pattern->verification_options.reset(new VerificationOptionsT);
+  card_pattern->verification_options->verify_luhn_checksum = true;
+  patterns.push_back(std::move(card_pattern));
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, model.get()));
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // Flight number with a space between airline code and digits.
+  EXPECT_EQ(annotator->SuggestSelection(
+                "Your flight MA 0123 is delayed by 3 hours.", {12, 14}),
+            CodepointSpan(12, 19));
+  // Person name; the surrounding spaces are outside the capturing group.
+  EXPECT_EQ(annotator->SuggestSelection(
+                "this afternoon Barack Obama gave a speech at", {15, 21}),
+            CodepointSpan(15, 27));
+  // Card number matched by the Luhn-verified pattern.
+  EXPECT_EQ(annotator->SuggestSelection("cc: 4012 8888 8888 1881", {9, 14}),
+            CodepointSpan(4, 23));
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionCustomSelectionBounds) {
+  const std::string model_bytes = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());
+
+  // A date-range pattern whose selection covers only the capturing groups
+  // flagged with extend_selection, not the whole match. For example, the
+  // leading "from " below is left out of the suggested selection.
+  std::unique_ptr<RegexModel_::PatternT> date_range_pattern =
+      MakePattern("date_range",
+                  "(?:(?:from )?(\\d{2}\\/\\d{2}\\/\\d{4}) to "
+                  "(\\d{2}\\/\\d{2}\\/\\d{4}))|(for ever)",
+                  /*enabled_for_classification=*/false,
+                  /*enabled_for_selection=*/true,
+                  /*enabled_for_annotation=*/false, 1.0);
+  // One entry per group: group 0 (whole match) does not extend the selection,
+  // groups 1-3 do.
+  for (const bool extend_selection : {false, true, true, true}) {
+    date_range_pattern->capturing_group.emplace_back(new CapturingGroupT);
+    date_range_pattern->capturing_group.back()->extend_selection =
+        extend_selection;
+  }
+  model->regex_model->patterns.push_back(std::move(date_range_pattern));
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, model.get()));
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  EXPECT_EQ(annotator->SuggestSelection("it's from 04/30/1789 to 03/04/1797",
+                                        {21, 23}),
+            CodepointSpan(10, 34));
+  EXPECT_EQ(annotator->SuggestSelection("it takes for ever", {9, 12}),
+            CodepointSpan(9, 17));
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionConflictsModelWins) {
+  const std::string model_bytes = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());
+
+  model->regex_model->patterns.push_back(MakePattern(
+      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
+      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+
+  // The flight pattern gets a low priority score (0.5). The expected result
+  // below is the full address span rather than the flight-number match.
+  std::unique_ptr<RegexModel_::PatternT> flight_pattern = MakePattern(
+      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
+      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0);
+  flight_pattern->priority_score = 0.5;
+  model->regex_model->patterns.push_back(std::move(flight_pattern));
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, model.get()));
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  const std::string text =
+      "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123";
+  EXPECT_EQ(annotator->SuggestSelection(text, {55, 57}),
+            CodepointSpan(26, 62));
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionConflictsRegexWins) {
+  const std::string model_bytes = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());
+
+  model->regex_model->patterns.push_back(MakePattern(
+      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
+      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
+
+  // The flight pattern gets a raised priority score (1.1). The expected result
+  // below is the flight-number match ("MA 0123") rather than the full address.
+  std::unique_ptr<RegexModel_::PatternT> flight_pattern = MakePattern(
+      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
+      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0);
+  flight_pattern->priority_score = 1.1;
+  model->regex_model->patterns.push_back(std::move(flight_pattern));
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, model.get()));
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  const std::string text =
+      "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123";
+  EXPECT_EQ(annotator->SuggestSelection(text, {55, 57}),
+            CodepointSpan(55, 62));
+}
+
+TEST_F(AnnotatorTest, AnnotateRegex) {
+  const std::string model_bytes = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());
+  auto& patterns = model->regex_model->patterns;
+
+  // Annotation-only regex patterns.
+  patterns.push_back(MakePattern(
+      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
+      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0));
+  patterns.push_back(MakePattern(
+      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
+      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 0.5));
+  std::unique_ptr<RegexModel_::PatternT> card_pattern =
+      MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
+                  /*enabled_for_classification=*/false,
+                  /*enabled_for_selection=*/false,
+                  /*enabled_for_annotation=*/true, 1.0);
+  card_pattern->verification_options.reset(new VerificationOptionsT);
+  card_pattern->verification_options->verify_luhn_checksum = true;
+  patterns.push_back(std::move(card_pattern));
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, model.get()));
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // The person and payment_card spans come from the patterns added above.
+  // Address and phone are not among those patterns, so they come from the
+  // base test model.
+  const std::string text =
+      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+      "number is 853 225 3556\nand my card is 4012 8888 8888 1881.\n";
+  EXPECT_THAT(annotator->Annotate(text),
+              ElementsAreArray({IsAnnotatedSpan(6, 18, "person"),
+                                IsAnnotatedSpan(28, 55, "address"),
+                                IsAnnotatedSpan(79, 91, "phone"),
+                                IsAnnotatedSpan(107, 126, "payment_card")}));
+}
+
+TEST_F(AnnotatorTest, AnnotatesFlightNumbers) {
+  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // ICAO codes are only accepted for selected airlines. SWR373 (offsets 15-21)
+  // is expected to be skipped; LX373, EZY1234 and U21234 are annotated.
+  const std::vector<AnnotatedSpan> annotations =
+      annotator->Annotate("flights LX373, SWR373, EZY1234, U21234");
+  EXPECT_THAT(annotations,
+              ElementsAreArray({IsAnnotatedSpan(8, 13, "flight"),
+                                IsAnnotatedSpan(23, 30, "flight"),
+                                IsAnnotatedSpan(32, 38, "flight")}));
+}
+
+#ifndef TC3_DISABLE_LUA
+TEST_F(AnnotatorTest, AnnotateRegexLuaVerification) {
+  const std::string model_bytes = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());
+
+  // Register the Lua verifier; the pattern below refers to it by index 0.
+  // It accepts a match only when capturing group 2 ("(\\d{2})") is "99".
+  model->regex_model->lua_verifier.push_back("return match[2].text==\"99\"");
+
+  std::unique_ptr<RegexModel_::PatternT> tracking_pattern =
+      MakePattern("parcel_tracking", "((\\d{2})-00-\\d{6}-\\d{8})",
+                  /*enabled_for_classification=*/true,
+                  /*enabled_for_selection=*/true,
+                  /*enabled_for_annotation=*/true, 1.0);
+  tracking_pattern->verification_options.reset(new VerificationOptionsT);
+  tracking_pattern->verification_options->lua_verifier = 0;
+  model->regex_model->patterns.push_back(std::move(tracking_pattern));
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, model.get()));
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  EXPECT_THAT(
+      annotator->Annotate("your parcel is on the way: 99-00-123456-12345678"),
+      ElementsAreArray({IsAnnotatedSpan(27, 48, "parcel_tracking")}));
+}
+#endif  // TC3_DISABLE_LUA
+
+TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityData) {
+  const std::string test_model = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+  // Add fake entity schema metadata.
+  AddTestEntitySchemaData(unpacked_model.get());
+
+  AddTestRegexModel(unpacked_model.get());
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+
+  AnnotationOptions options;
+  options.is_serialized_entity_data_enabled = true;
+  auto annotations =
+      classifier->Annotate("Barack Obama is 57 years old", options);
+  // ASSERT (not EXPECT) on the sizes: the checks below index annotations[0]
+  // and classification[0], which would be out of bounds otherwise.
+  ASSERT_EQ(1, annotations.size());
+  ASSERT_EQ(1, annotations[0].classification.size());
+  EXPECT_EQ("person_with_age", annotations[0].classification[0].collection);
+
+  // Check entity data. The buffer must be non-empty before it is read as a
+  // flatbuffer root.
+  const auto& entity_data =
+      annotations[0].classification[0].serialized_entity_data;
+  ASSERT_FALSE(entity_data.empty());
+  const flatbuffers::Table* entity = flatbuffers::GetAnyRoot(
+      reinterpret_cast<const unsigned char*>(entity_data.data()));
+
+  // GetPointer returns nullptr for an absent field; check before dereferencing.
+  const flatbuffers::String* first_name =
+      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
+  ASSERT_NE(first_name, nullptr);
+  EXPECT_EQ(first_name->str(), "Barack");
+  const flatbuffers::String* last_name =
+      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
+  ASSERT_NE(last_name, nullptr);
+  EXPECT_EQ(last_name->str(), "Obama");
+
+  // Check `age`.
+  EXPECT_EQ(entity->GetField<int>(/*field=*/10, /*defaultval=*/0), 57);
+
+  // Check `is_alive`.
+  EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));
+
+  // Check `former_us_president`.
+  EXPECT_TRUE(entity->GetField<bool>(/*field=*/12, /*defaultval=*/false));
+}
+
+TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityDataNormalization) {
+  const std::string test_model = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+  // Add fake entity schema metadata.
+  AddTestEntitySchemaData(unpacked_model.get());
+
+  AddTestRegexModel(unpacked_model.get());
+
+  // Upper case last name as post-processing.
+  RegexModel_::PatternT* pattern =
+      unpacked_model->regex_model->patterns.back().get();
+  pattern->capturing_group[2]->normalization_options.reset(
+      new NormalizationOptionsT);
+  pattern->capturing_group[2]
+      ->normalization_options->codepointwise_normalization =
+      NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+
+  AnnotationOptions options;
+  options.is_serialized_entity_data_enabled = true;
+  auto annotations =
+      classifier->Annotate("Barack Obama is 57 years old", options);
+  // ASSERT (not EXPECT) on the sizes: element [0] of each is read below.
+  ASSERT_EQ(1, annotations.size());
+  ASSERT_EQ(1, annotations[0].classification.size());
+  EXPECT_EQ("person_with_age", annotations[0].classification[0].collection);
+
+  // Check normalization. The buffer must be non-empty before it is read as a
+  // flatbuffer root, and the string field must be present before `->str()`.
+  const auto& entity_data =
+      annotations[0].classification[0].serialized_entity_data;
+  ASSERT_FALSE(entity_data.empty());
+  const flatbuffers::Table* entity = flatbuffers::GetAnyRoot(
+      reinterpret_cast<const unsigned char*>(entity_data.data()));
+  const flatbuffers::String* last_name =
+      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
+  ASSERT_NE(last_name, nullptr);
+  EXPECT_EQ(last_name->str(), "OBAMA");
+}
+
+TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityDataDisabled) {
+  const std::string test_model = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+  // Add fake entity schema metadata.
+  AddTestEntitySchemaData(unpacked_model.get());
+
+  AddTestRegexModel(unpacked_model.get());
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+
+  AnnotationOptions options;
+  options.is_serialized_entity_data_enabled = false;
+  auto annotations =
+      classifier->Annotate("Barack Obama is 57 years old", options);
+  // ASSERT (not EXPECT) on the sizes: element [0] of each is read below.
+  ASSERT_EQ(1, annotations.size());
+  ASSERT_EQ(1, annotations[0].classification.size());
+  EXPECT_EQ("person_with_age", annotations[0].classification[0].collection);
+
+  // With serialization disabled, no entity data is attached.
+  EXPECT_EQ("", annotations[0].classification[0].serialized_entity_data);
+}
+
+TEST_F(AnnotatorTest, PhoneFiltering) {
+  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  const std::string short_number = "phone: (123) 456 789";
+  const std::string with_suffix = "phone: (123) 456 789,0001112";
+
+  EXPECT_EQ("phone",
+            FirstResult(annotator->ClassifyText(short_number, {7, 20})));
+  EXPECT_EQ("phone",
+            FirstResult(annotator->ClassifyText(with_suffix, {7, 25})));
+  // Selecting through the end of the trailing digit run is expected to be
+  // classified as "other".
+  EXPECT_EQ("other",
+            FirstResult(annotator->ClassifyText(with_suffix, {7, 28})));
+}
+
+TEST_F(AnnotatorTest, SuggestSelection) {
+  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // A click on a single name token is returned as-is by this model.
+  EXPECT_EQ(annotator->SuggestSelection(
+                "this afternoon Barack Obama gave a speech at", {15, 21}),
+            CodepointSpan(15, 21));
+
+  // When more than one token is selected (here, the whole string), the
+  // selection is returned exactly as it was entered.
+  EXPECT_EQ(
+      annotator->SuggestSelection("350 Third Street, Cambridge", {0, 27}),
+      CodepointSpan(0, 27));
+
+  // Single letter, then single word.
+  EXPECT_EQ(annotator->SuggestSelection("a", {0, 1}), CodepointSpan(0, 1));
+  EXPECT_EQ(annotator->SuggestSelection("asdf", {0, 4}), CodepointSpan(0, 4));
+
+  // A click inside a phone number expands to the whole number.
+  EXPECT_EQ(
+      annotator->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
+      CodepointSpan(11, 23));
+
+  // Paired brackets are kept; unpaired ones are stripped from the result.
+  EXPECT_EQ(
+      annotator->SuggestSelection("call me at (857) 225 3556 today", {11, 16}),
+      CodepointSpan(11, 25));
+  EXPECT_EQ(annotator->SuggestSelection("call me at (857 today", {11, 15}),
+            CodepointSpan(12, 15));
+  EXPECT_EQ(annotator->SuggestSelection("call me at 3556) today", {11, 16}),
+            CodepointSpan(11, 15));
+  EXPECT_EQ(annotator->SuggestSelection("call me at )857( today", {11, 16}),
+            CodepointSpan(12, 15));
+
+  // If stripping would leave an empty selection, the original span is kept.
+  EXPECT_EQ(annotator->SuggestSelection("call me at )( today", {11, 13}),
+            CodepointSpan(11, 13));
+  EXPECT_EQ(annotator->SuggestSelection("call me at ( today", {11, 12}),
+            CodepointSpan(11, 12));
+  EXPECT_EQ(annotator->SuggestSelection("call me at ) today", {11, 12}),
+            CodepointSpan(11, 12));
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionDisabledFail) {
+  const std::string model_bytes = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());
+
+  // Drop the selection model while still requesting annotation mode.
+  model->selection_model.clear();
+  model->triggering_options.reset(new ModelTriggeringOptionsT);
+  model->triggering_options->enabled_modes = ModeFlag_ANNOTATION;
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, model.get()));
+  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+
+  // Annotation needs the selection model, so loading is expected to fail.
+  ASSERT_FALSE(annotator);
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionDisabled) {
+  const std::string model_bytes = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());
+
+  // Remove the selection model and restrict the model to classification.
+  model->selection_model.clear();
+  model->enabled_modes = ModeFlag_CLASSIFICATION;
+  model->triggering_options.reset(new ModelTriggeringOptionsT);
+  model->triggering_options->enabled_modes = ModeFlag_CLASSIFICATION;
+
+  // Without the selection model there is no feature processor, which the
+  // number annotator requires, so turn the number annotator off too.
+  model->number_annotator_options->enabled = false;
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, model.get()));
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // Selection: the click comes back unchanged.
+  EXPECT_EQ(
+      annotator->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
+      CodepointSpan(11, 14));
+
+  // Classification still works, while annotation mode is not enabled and
+  // nothing is annotated.
+  const std::string phone_text = "call me at (800) 123-456 today";
+  EXPECT_EQ("phone",
+            FirstResult(annotator->ClassifyText(phone_text, {11, 24})));
+  EXPECT_THAT(annotator->Annotate(phone_text), IsEmpty());
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionFilteredCollections) {
+  const std::string model_bytes = ReadFile(GetTestModelPath());
+  const std::string phone_text = "call me at 857 225 3556 today";
+
+  // Baseline: the unmodified model expands a click to the whole phone number.
+  std::unique_ptr<Annotator> baseline = Annotator::FromUnownedBuffer(
+      model_bytes.c_str(), model_bytes.size(), unilib_.get(),
+      calendarlib_.get());
+  ASSERT_TRUE(baseline);
+  EXPECT_EQ(baseline->SuggestSelection(phone_text, {11, 14}),
+            CodepointSpan(11, 23));
+
+  // Filter the "phone" collection out of selection results. Filtering only
+  // applies when suggested selections are also classified, so force that.
+  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());
+  model->output_options.reset(new OutputOptionsT);
+  model->output_options->filtered_collections_selection.push_back("phone");
+  model->selection_options->always_classify_suggested_selection = true;
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, model.get()));
+  std::unique_ptr<Annotator> filtered = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(filtered);
+
+  // The phone click is no longer expanded...
+  EXPECT_EQ(filtered->SuggestSelection(phone_text, {11, 14}),
+            CodepointSpan(11, 14));
+  // ...while address selection is unaffected.
+  EXPECT_EQ(filtered->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
+            CodepointSpan(0, 27));
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionsAreSymmetric) {
+  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // Clicking any of the first three tokens of the address selects all of it.
+  const std::string address = "350 Third Street, Cambridge";
+  for (const CodepointSpan& click :
+       {CodepointSpan(0, 3), CodepointSpan(4, 9), CodepointSpan(10, 16)}) {
+    EXPECT_EQ(annotator->SuggestSelection(address, click),
+              CodepointSpan(0, 27));
+  }
+
+  // Same address preceded by three short lines: the result is shifted by 6.
+  EXPECT_EQ(annotator->SuggestSelection("a\nb\nc\n350 Third Street, Cambridge",
+                                        {16, 22}),
+            CodepointSpan(6, 33));
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionWithNewLine) {
+  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // Phone number on the line after a newline, then on the line before one.
+  EXPECT_EQ(annotator->SuggestSelection("abc\n857 225 3556", {4, 7}),
+            CodepointSpan(4, 16));
+  EXPECT_EQ(annotator->SuggestSelection("857 225 3556\nabc", {0, 3}),
+            CodepointSpan(0, 12));
+
+  // With explicitly passed default options, the expansion continues across a
+  // newline inside the number.
+  SelectionOptions selection_options;
+  EXPECT_EQ(annotator->SuggestSelection("857 225\n3556\nabc", {0, 3},
+                                        selection_options),
+            CodepointSpan(0, 12));
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionWithPunctuation) {
+  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // In every case the selected word is expected to come back unchanged, with
+  // the adjacent punctuation left outside the selection.
+  struct PunctuationCase {
+    const char* text;
+    CodepointSpan selection;
+  };
+  const PunctuationCase cases[] = {
+      // Punctuation on the right.
+      {"this afternoon BarackObama, gave a speech at", CodepointSpan(15, 26)},
+      // Several punctuation characters on the right.
+      {"this afternoon BarackObama,.,.,, gave a speech at",
+       CodepointSpan(15, 26)},
+      // Several punctuation characters on the left.
+      {"this afternoon ,.,.,,BarackObama gave a speech at",
+       CodepointSpan(21, 32)},
+      // Punctuation on both sides.
+      {"this afternoon !BarackObama,- gave a speech at", CodepointSpan(16, 27)},
+  };
+  for (const PunctuationCase& test_case : cases) {
+    EXPECT_EQ(annotator->SuggestSelection(test_case.text, test_case.selection),
+              test_case.selection);
+  }
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionNoCrashWithJunk) {
+  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // Invalid selections (out of range, negative, reversed) must not crash and
+  // are returned unchanged.
+  for (const CodepointSpan& span :
+       {CodepointSpan(0, 27), CodepointSpan(-10, 27)}) {
+    EXPECT_EQ(annotator->SuggestSelection("", span), span);
+  }
+  for (const CodepointSpan& span :
+       {CodepointSpan(0, 27), CodepointSpan(-30, 300), CodepointSpan(-10, -1),
+        CodepointSpan(100, 17)}) {
+    EXPECT_EQ(annotator->SuggestSelection("Word 1 2 3 hello!", span), span);
+  }
+
+  // Invalid UTF-8 input is handled the same way.
+  const CodepointSpan no_selection(-1, -1);
+  EXPECT_EQ(
+      annotator->SuggestSelection("\xf0\x9f\x98\x8b\x8b", no_selection),
+      no_selection);
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionSelectSpace) {
+ // Checks SuggestSelection when the click covers whitespace: a space inside
+ // a phone number or address expands to the entity, while a space outside any
+ // entity (or a whitespace-only string) comes back unchanged.
+ std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+ GetTestModelPath(), unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+
+ // Space between "857" and "225": expands to the whole number.
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {14, 15}),
+ CodepointSpan(11, 23));
+ // Spaces just before and just after the number are not expanded.
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {10, 11}),
+ CodepointSpan(10, 11));
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {23, 24}),
+ CodepointSpan(23, 24));
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556, today", {23, 24}),
+ CodepointSpan(23, 24));
+ EXPECT_EQ(classifier->SuggestSelection("call me at 857 225 3556, today",
+ {14, 17}),
+ CodepointSpan(11, 25));
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857-225 3556, today", {14, 17}),
+ CodepointSpan(11, 23));
+ // Space inside an address: expands to the whole address.
+ EXPECT_EQ(
+ classifier->SuggestSelection(
+ "let's meet at 350 Third Street Cambridge and go there", {30, 31}),
+ CodepointSpan(14, 40));
+ // Spaces between ordinary words are not expanded.
+ EXPECT_EQ(classifier->SuggestSelection("call me today", {4, 5}),
+ CodepointSpan(4, 5));
+ EXPECT_EQ(classifier->SuggestSelection("call me today", {7, 8}),
+ CodepointSpan(7, 8));
+
+ // With a punctuation around the selected whitespace.
+ EXPECT_EQ(
+ classifier->SuggestSelection(
+ "let's meet at 350 Third Street, Cambridge and go there", {31, 32}),
+ CodepointSpan(14, 41));
+
+ // When all's whitespace, should return the original indices.
+ // NOTE(review): in this copy each literal below is a single space, which
+ // puts spans such as {0, 3} and {5, 6} past the end of the string. The spans
+ // suggest a longer run of spaces was intended; confirm the literal width.
+ EXPECT_EQ(classifier->SuggestSelection(" ", {0, 1}),
+ CodepointSpan(0, 1));
+ EXPECT_EQ(classifier->SuggestSelection(" ", {0, 3}),
+ CodepointSpan(0, 3));
+ EXPECT_EQ(classifier->SuggestSelection(" ", {2, 3}),
+ CodepointSpan(2, 3));
+ EXPECT_EQ(classifier->SuggestSelection(" ", {5, 6}),
+ CodepointSpan(5, 6));
+}
+
+TEST_F(AnnotatorTest, SnapLeftIfWhitespaceSelection) {
+ // Checks internal::SnapLeftIfWhitespaceSelection: a one-codepoint whitespace
+ // selection right after a non-whitespace character moves one codepoint left
+ // (e.g. {4, 5} -> {3, 4} in "abcd efgh"). In the other cases below the span
+ // is expected to come back unchanged.
+ UnicodeText text;
+
+ text = UTF8ToUnicodeText("abcd efgh", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, *unilib_),
+ CodepointSpan(3, 4));
+ text = UTF8ToUnicodeText("abcd ", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, *unilib_),
+ CodepointSpan(3, 4));
+
+ // Nothing on the left.
+ // NOTE(review): as written, " efgh" has 'h' at index 4, so {4, 5} is not a
+ // whitespace selection here; the comment suggests a longer leading run of
+ // spaces was intended. Confirm the literal width.
+ text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, *unilib_),
+ CodepointSpan(4, 5));
+ text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, *unilib_),
+ CodepointSpan(0, 1));
+
+ // Whitespace only.
+ // NOTE(review): each literal below is a single space in this copy, so spans
+ // {2, 3} and {4, 5} lie past the end of the text; confirm the literal width.
+ text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({2, 3}, text, *unilib_),
+ CodepointSpan(2, 3));
+ text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, *unilib_),
+ CodepointSpan(4, 5));
+ text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, *unilib_),
+ CodepointSpan(0, 1));
+}
+
+TEST_F(AnnotatorTest, Annotate) {
+  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // Mixed text: only the address and the phone number are annotated.
+  const std::string paragraph =
+      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+      "number is 853 225 3556";
+  EXPECT_THAT(annotator->Annotate(paragraph),
+              ElementsAreArray({IsAnnotatedSpan(28, 55, "address"),
+                                IsAnnotatedSpan(79, 91, "phone")}));
+
+  // With explicit default options, a bare number is a phone, including when a
+  // newline separates its parts.
+  AnnotationOptions default_options;
+  EXPECT_THAT(annotator->Annotate("853 225 3556", default_options),
+              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
+  EXPECT_THAT(annotator->Annotate("853 225\n3556", default_options),
+              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
+
+  // Input containing invalid UTF-8 produces no annotations at all.
+  EXPECT_THAT(annotator->Annotate("853 225 3556\n\xf0\x9f\x98\x8b\x8b",
+                                  default_options),
+              IsEmpty());
+}
+
+TEST_F(AnnotatorTest, AnnotatesWithBracketStripping) {
+  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // Each text is expected to yield exactly one phone annotation.
+  struct BracketCase {
+    const char* text;
+    int phone_begin;
+    int phone_end;
+  };
+  const BracketCase cases[] = {
+      // Paired brackets stay part of the number.
+      {"call me at (0845) 100 1000 today", 11, 26},
+      // Unpaired brackets are stripped from the annotation.
+      {"call me at (07038201818 today", 12, 23},
+      {"call me at 07038201818) today", 11, 22},
+      {"call me at )07038201818( today", 12, 23},
+  };
+  for (const BracketCase& test_case : cases) {
+    EXPECT_THAT(annotator->Annotate(test_case.text),
+                ElementsAreArray({IsAnnotatedSpan(
+                    test_case.phone_begin, test_case.phone_end, "phone")}));
+  }
+}
+
+TEST_F(AnnotatorTest, AnnotatesOverlappingNumbers) {
+  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+  AnnotationOptions raw_options;
+  raw_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
+
+  // In RAW mode the number, percentage and phone annotations are all reported,
+  // including ones that overlap each other (order does not matter).
+  const std::string text =
+      "853 225 3556 and then turn it up 99%, 99 "
+      "number, 12345.12345 float number";
+  EXPECT_THAT(annotator->Annotate(text, raw_options),
+              UnorderedElementsAreArray({
+                  // "853 225 3556": a phone plus its three digit groups.
+                  IsAnnotatedSpan(0, 12, "phone"),
+                  IsAnnotatedSpan(0, 3, "number"),
+                  IsAnnotatedSpan(4, 7, "number"),
+                  IsAnnotatedSpan(8, 12, "number"),
+                  // "99%" and its numeric part.
+                  IsAnnotatedSpan(33, 36, "percentage"),
+                  IsAnnotatedSpan(33, 35, "number"),
+                  // "99"
+                  IsAnnotatedSpan(38, 40, "number"),
+                  // "12345.12345", reported as both a number and a phone.
+                  IsAnnotatedSpan(49, 60, "number"),
+                  IsAnnotatedSpan(49, 60, "phone"),
+              }));
+}
+
+TEST_F(AnnotatorTest, DoesNotAnnotateNumbersInSmartUsecase) {
+  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+  AnnotationOptions smart_options;
+  smart_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
+  // SMART mode keeps the phone and the percentage but drops plain numbers.
+  EXPECT_THAT(annotator->Annotate(
+                  "853 225 3556 and then turn it up 99%, 99 number", smart_options),
+              ElementsAre(IsAnnotatedSpan(0, 12, "phone"),
+                          IsAnnotatedSpan(33, 36, "percentage")));
+}
+
+void VerifyAnnotatesDurationsInRawMode(const Annotator* annotator) {
+  ASSERT_TRUE(annotator);
+  AnnotationOptions raw_options;
+  raw_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
+  // "9 minutes and 7 seconds" (characters 8..31) must be reported as one
+  // duration span of 9 min 7 s, among whatever else RAW mode returns.
+  const std::vector<AnnotatedSpan> spans = annotator->Annotate(
+      "it took 9 minutes and 7 seconds to get there", raw_options);
+  EXPECT_THAT(spans, Contains(IsDurationSpan(
+                         /*start=*/8, /*end=*/31,
+                         /*duration_ms=*/9 * 60 * 1000 + 7 * 1000)));
+}
+
+TEST_F(AnnotatorTest, AnnotatesDurationsInRawMode) {
+  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  VerifyAnnotatesDurationsInRawMode(annotator.get());
+}
+
+TEST_F(AnnotatorTest, AnnotatesDurationsInRawModeWithDatetimeGrammar) {
+  const std::string grammar_model = GetGrammarModel();
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      grammar_model.c_str(), grammar_model.size(), unilib_.get(), calendarlib_.get());
+  VerifyAnnotatesDurationsInRawMode(annotator.get());
+}
+
+TEST_F(AnnotatorTest, DurationAndRelativeTimeCanOverlapInRawMode) {
+  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+  AnnotationOptions raw_options;
+  raw_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
+
+  const std::vector<AnnotatedSpan> spans =
+      annotator->Annotate("let's meet in 3 hours", raw_options);
+  // "in 3 hours" (11..21) is a datetime; the nested "3 hours" is a duration.
+  EXPECT_THAT(spans,
+              Contains(IsDatetimeSpan(/*start=*/11, /*end=*/21,
+                                      /*time_ms_utc=*/10800000L,
+                                      DatetimeGranularity::GRANULARITY_HOUR)));
+  EXPECT_THAT(spans, Contains(IsDurationSpan(
+                         /*start=*/14, /*end=*/21,
+                         /*duration_ms=*/3 * 60 * 60 * 1000)));
+}
+
+TEST_F(AnnotatorTest, AnnotateSplitLines) {
+  std::string model_buffer = ReadFile(GetTestModelPath());
+  model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
+    model->selection_feature_options->only_use_line_with_click = true;
+  });
+  std::unique_ptr<Annotator> annotator =
+      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+                                   unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // The address is the first 26 characters of `address_line`
+  // ("2000 Main Avenue, Apt #201"); the chatter line has no entities.
+  const std::string chatter_line =
+      "hey, sorry, just finished up. i didn't hear back from you in time.";
+  const std::string address_line = "2000 Main Avenue, Apt #201, San Mateo";
+  const int kAddressLength = 26;
+  EXPECT_THAT(annotator->Annotate(chatter_line), IsEmpty());
+  EXPECT_THAT(
+      annotator->Annotate(address_line),
+      ElementsAreArray({IsAnnotatedSpan(0, kAddressLength, "address")}));
+  // Joined by a newline, the address is still found, shifted past line one.
+  const std::string joined = chatter_line + "\n" + address_line;
+  EXPECT_THAT(
+      annotator->Annotate(joined),
+      ElementsAreArray({IsAnnotatedSpan(chatter_line.size() + 1,
+          chatter_line.size() + 1 + kAddressLength, "address")}));
+}
+
+TEST_F(AnnotatorTest, UsePipeAsNewLineCharacterShouldAnnotateSplitLines) {
+  std::string model_buffer = ReadFile(GetTestModelPath());
+  model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
+    model->selection_feature_options->only_use_line_with_click = true;
+    model->selection_feature_options->use_pipe_character_for_newline = true;
+  });
+  std::unique_ptr<Annotator> annotator =
+      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+                                   unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // With '|' treated as a line break, annotating the joined text must find
+  // the same phone and address spans as annotating each part on its own.
+  const std::string phone_line = "hey, this is my phone number 853 225 3556";
+  const std::string address_line = "2000 Main Avenue, Apt #201, San Mateo";
+  const std::string joined = phone_line + "|" + address_line;
+  // "853 225 3556" starts at character 29 of `phone_line`; the address
+  // starts right after `phone_line` and the '|' separator.
+  const int kPhoneLength = 12;
+  const int kAddressLength = 26;
+  const int address_start = static_cast<int>(phone_line.size()) + 1;
+
+  const std::vector<AnnotatedSpan> spans = annotator->Annotate(joined);
+  EXPECT_THAT(
+      spans,
+      ElementsAreArray(
+          {IsAnnotatedSpan(29, 29 + kPhoneLength, "phone"),
+           IsAnnotatedSpan(address_start, address_start + kAddressLength,
+                           "address")}));
+}
+
+TEST_F(AnnotatorTest,
+       NotUsingPipeAsNewLineCharacterShouldNotAnnotateSplitLines) {
+  std::string model_buffer = ReadFile(GetTestModelPath());
+  model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
+    model->selection_feature_options->only_use_line_with_click = true;
+    model->selection_feature_options->use_pipe_character_for_newline = false;
+  });
+  std::unique_ptr<Annotator> annotator =
+      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
+                                   unilib_.get(), calendarlib_.get());
+
+  ASSERT_TRUE(annotator);
+
+  // Here '|' is an ordinary character, so the text forms a single line.
+  const std::string phone_line = "hey, this is my phone number 853 225 3556";
+  const std::string address_line = "2000 Main Avenue, Apt #201, San Mateo";
+  const std::string joined = phone_line + "|" + address_line;
+  const std::vector<AnnotatedSpan> spans = annotator->Annotate(joined);
+
+  // Note: when '|' does not split lines, we only check that a single span is
+  // produced. The model is not precise on examples like this one, so the
+  // exact span it returns might change whenever the model changes.
+  EXPECT_EQ(spans.size(), 1u);
+}
+
+TEST_F(AnnotatorTest, AnnotateSmallBatches) {
+  const std::string serialized_model = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
+  // Shrink the selection batch size to 4. The annotations are expected to
+  // match those produced by the unmodified model.
+  model->selection_options->batch_size = 4;
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, model.get()));
+
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  const std::string text =
+      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+      "number is 853 225 3556";
+  EXPECT_THAT(annotator->Annotate(text),
+              ElementsAreArray({
+                  IsAnnotatedSpan(28, 55, "address"),
+                  IsAnnotatedSpan(79, 91, "phone"),
+              }));
+
+  AnnotationOptions default_options;
+  EXPECT_THAT(annotator->Annotate("853 225 3556", default_options),
+              ElementsAre(IsAnnotatedSpan(0, 12, "phone")));
+  EXPECT_THAT(annotator->Annotate("853 225\n3556", default_options),
+              ElementsAre(IsAnnotatedSpan(0, 12, "phone")));
+}
+
+TEST_F(AnnotatorTest, AnnotateFilteringDiscardAll) {
+  const std::string test_model = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+  // Add triggering options with a minimum annotation confidence of 2.0, high
+  // enough that every annotation is expected to be discarded.
+  unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+  unpacked_model->triggering_options->min_annotate_confidence = 2.f;
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+
+  const std::string test_string =
+      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+      "number is 853 225 3556";
+  // Nothing passes the threshold, so no span may be reported. IsEmpty() also
+  // avoids a signed/unsigned size() comparison and prints stray spans.
+  EXPECT_THAT(classifier->Annotate(test_string), IsEmpty());
+}
+
+TEST_F(AnnotatorTest, AnnotateFilteringKeepAll) {
+  const std::string test_model = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+  // A zero confidence threshold with every mode enabled keeps all results.
+  unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+  unpacked_model->triggering_options->min_annotate_confidence = 0.f;
+  unpacked_model->triggering_options->enabled_modes = ModeFlag_ALL;
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+
+  const std::string test_string =
+      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+      "number is 853 225 3556";
+  // Expect two spans (the address and the phone number). The unsigned
+  // literal matches size()'s signedness and avoids -Wsign-compare.
+  EXPECT_EQ(classifier->Annotate(test_string).size(), 2u);
+}
+
+TEST_F(AnnotatorTest, AnnotateDisabled) {
+  const std::string serialized_model = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
+  // Enable only classification and selection, leaving annotation mode off.
+  model->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION;
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, model.get()));
+
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  const std::string text =
+      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+      "number is 853 225 3556";
+  EXPECT_THAT(annotator->Annotate(text), IsEmpty());
+}
+
+TEST_F(AnnotatorTest, AnnotateFilteredCollections) {
+  const std::string serialized_model = ReadFile(GetTestModelPath());
+  // Baseline: the unmodified model finds both an address and a phone.
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      serialized_model.c_str(), serialized_model.size(), unilib_.get(),
+      calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  const std::string text =
+      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+      "number is 853 225 3556";
+
+  EXPECT_THAT(annotator->Annotate(text),
+              ElementsAreArray({
+                  IsAnnotatedSpan(28, 55, "address"),
+                  IsAnnotatedSpan(79, 91, "phone"),
+              }));
+
+  // Rebuild the model with "phone" listed among the collections filtered
+  // out of annotation.
+  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
+  model->output_options.reset(new OutputOptionsT);
+  model->output_options->filtered_collections_annotation.push_back("phone");
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, model.get()));
+
+  annotator = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // Only the address is reported now.
+  EXPECT_THAT(annotator->Annotate(text),
+              ElementsAreArray({
+                  IsAnnotatedSpan(28, 55, "address"),
+              }));
+}
+
+TEST_F(AnnotatorTest, AnnotateFilteredCollectionsSuppress) {
+  const std::string serialized_model = ReadFile(GetTestModelPath());
+  // Baseline: the unmodified model finds both an address and a phone.
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      serialized_model.c_str(), serialized_model.size(), unilib_.get(),
+      calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  const std::string text =
+      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+      "number is 853 225 3556";
+
+  EXPECT_THAT(annotator->Annotate(text),
+              ElementsAreArray({
+                  IsAnnotatedSpan(28, 55, "address"),
+                  IsAnnotatedSpan(79, 91, "phone"),
+              }));
+
+  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
+  model->output_options.reset(new OutputOptionsT);
+
+  // Add an annotation-only regex collection, "suppress", that wins against
+  // the phone classification on the same digits, and filter that collection
+  // out of the annotation results.
+  model->output_options->filtered_collections_annotation.push_back(
+      "suppress");
+  model->regex_model->patterns.push_back(MakePattern(
+      "suppress", "(\\d{3} ?\\d{4})",
+      /*enabled_for_classification=*/false,
+      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0));
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, model.get()));
+
+  annotator = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // The phone loses to "suppress", which is then filtered: only the address.
+  EXPECT_THAT(annotator->Annotate(text),
+              ElementsAre(IsAnnotatedSpan(28, 55, "address")));
+}
+
+void VerifyClassifyTextDateInZurichTimezone(const Annotator* classifier) {
+  // ASSERT, not EXPECT: a null annotator must not be dereferenced below.
+  ASSERT_TRUE(classifier);
+  ClassificationOptions options;
+  options.reference_timezone = "Europe/Zurich";
+  // Midnight of 2017-01-01 in Zurich (UTC+1) is 2016-12-31T23:00:00Z.
+  std::vector<ClassificationResult> result =
+      classifier->ClassifyText("january 1, 2017", {0, 15}, options);
+  EXPECT_THAT(result,
+              ElementsAre(IsDateResult(1483225200000,
+                                       DatetimeGranularity::GRANULARITY_DAY)));
+}
+
+TEST_F(AnnotatorTest, ClassifyTextDateInZurichTimezone) {
+  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  VerifyClassifyTextDateInZurichTimezone(classifier.get());
+}
+
+TEST_F(AnnotatorTest, ClassifyTextDateInZurichTimezoneWithDatetimeGrammar) {
+  const std::string grammar_model = GetGrammarModel();
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(grammar_model.c_str(), grammar_model.size(),
+                                   unilib_.get(), calendarlib_.get());
+  VerifyClassifyTextDateInZurichTimezone(classifier.get());
+}
+
+void VerifyClassifyTextDateInLATimezone(const Annotator* classifier) {
+  // ASSERT, not EXPECT: a null annotator must not be dereferenced below.
+  ASSERT_TRUE(classifier);
+  ClassificationOptions options;
+  options.reference_timezone = "America/Los_Angeles";
+  // Midnight of 2017-03-01 in Los Angeles (UTC-8) is 2017-03-01T08:00:00Z.
+  std::vector<ClassificationResult> result =
+      classifier->ClassifyText("march 1, 2017", {0, 13}, options);
+  EXPECT_THAT(result,
+              ElementsAre(IsDateResult(1488355200000,
+                                       DatetimeGranularity::GRANULARITY_DAY)));
+}
+
+TEST_F(AnnotatorTest, ClassifyTextDateInLATimezone) {
+  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  VerifyClassifyTextDateInLATimezone(classifier.get());
+}
+
+TEST_F(AnnotatorTest, ClassifyTextDateInLATimezoneWithDatetimeGrammar) {
+  const std::string grammar_model = GetGrammarModel();
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(grammar_model.c_str(), grammar_model.size(),
+                                   unilib_.get(), calendarlib_.get());
+  VerifyClassifyTextDateInLATimezone(classifier.get());
+}
+
+void VerifyClassifyTextDateOnAotherLine(const Annotator* classifier) {
+  // ASSERT, not EXPECT: a null annotator must not be dereferenced below.
+  ASSERT_TRUE(classifier);
+  ClassificationOptions options;
+  options.reference_timezone = "Europe/Zurich";
+  // The date is on the second line, at codepoints [35, 50).
+  std::vector<ClassificationResult> result = classifier->ClassifyText(
+      "hello world this is the first line\n"
+      "january 1, 2017",
+      {35, 50}, options);
+  EXPECT_THAT(result,
+              ElementsAre(IsDateResult(1483225200000,
+                                       DatetimeGranularity::GRANULARITY_DAY)));
+}
+// "Aother" in these names is a typo for "Another"; kept so test IDs are stable.
+TEST_F(AnnotatorTest, ClassifyTextDateOnAotherLine) {
+  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  VerifyClassifyTextDateOnAotherLine(classifier.get());
+}
+
+TEST_F(AnnotatorTest, ClassifyTextDateOnAotherLineWithDatetimeGrammar) {
+  const std::string grammar_model = GetGrammarModel();
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(grammar_model.c_str(), grammar_model.size(),
+                                   unilib_.get(), calendarlib_.get());
+  VerifyClassifyTextDateOnAotherLine(classifier.get());
+}
+
+void VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(
+    const Annotator* classifier) {
+  // ASSERT, not EXPECT: a null annotator must not be dereferenced below.
+  ASSERT_TRUE(classifier);
+  ClassificationOptions options;
+  options.reference_timezone = "Europe/Zurich";
+  options.locales = "en-US";
+  const std::vector<ClassificationResult> result =
+      classifier->ClassifyText("03.05.1970 00:00am", {0, 18}, options);
+  // In the US the date is read as <month>.<day> (March 5, 1970). Midnight in
+  // Zurich (UTC+1) is 1970-03-04T23:00:00Z, i.e. 5439600000 ms.
+  EXPECT_THAT(result,
+              ElementsAre(IsDatetimeResult(
+                  5439600000, DatetimeGranularity::GRANULARITY_MINUTE)));
+}
+
+TEST_F(AnnotatorTest, ClassifyTextWhenLocaleUSParsesDateAsMonthDay) {
+  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(classifier.get());
+}
+
+TEST_F(AnnotatorTest,
+       ClassifyTextWhenLocaleUSParsesDateAsMonthDayWithDatetimeGrammar) {
+  const std::string grammar_model = GetGrammarModel();
+  std::unique_ptr<Annotator> classifier =
+      Annotator::FromUnownedBuffer(grammar_model.c_str(), grammar_model.size(),
+                                   unilib_.get(), calendarlib_.get());
+  VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(classifier.get());
+}
+
+TEST_F(AnnotatorTest, ClassifyTextWhenLocaleGermanyParsesDateAsMonthDay) {
+  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);  // Stop before dereferencing a null annotator.
+  ClassificationOptions options;
+  options.reference_timezone = "Europe/Zurich";
+  options.locales = "de";
+  const std::vector<ClassificationResult> result =
+      classifier->ClassifyText("03.05.1970 00:00vorm", {0, 20}, options);
+  // In Germany the date is read as <day>.<month> (May 3, 1970). Midnight in
+  // Zurich (UTC+1, no DST in 1970) is 1970-05-02T23:00:00Z = 10537200000 ms.
+  // ("MonthDay" in the test name is a misnomer: this checks day-first order.)
+  EXPECT_THAT(result,
+              ElementsAre(IsDatetimeResult(
+                  10537200000, DatetimeGranularity::GRANULARITY_MINUTE)));
+}
+
+TEST_F(AnnotatorTest, ClassifyTextAmbiguousDatetime) {
+  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);  // Stop before dereferencing a null annotator.
+  ClassificationOptions options;
+  options.reference_timezone = "Europe/Zurich";
+  options.locales = "en-US";
+  const std::vector<ClassificationResult> result =
+      classifier->ClassifyText("set an alarm for 10:30", {17, 22}, options);
+  // Both readings of "10:30": 10:30 and 22:30 Zurich time (09:30Z, 21:30Z).
+  EXPECT_THAT(
+      result,
+      ElementsAre(
+          IsDatetimeResult(34200000, DatetimeGranularity::GRANULARITY_MINUTE),
+          IsDatetimeResult(77400000, DatetimeGranularity::GRANULARITY_MINUTE)));
+}
+
+TEST_F(AnnotatorTest, AnnotateAmbiguousDatetime) {
+  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);  // Stop before dereferencing a null annotator.
+  AnnotationOptions options;
+  options.reference_timezone = "Europe/Zurich";
+  options.locales = "en-US";
+  const std::vector<AnnotatedSpan> spans =
+      classifier->Annotate("set an alarm for 10:30", options);
+  // One span whose classification carries both readings of "10:30".
+  ASSERT_EQ(spans.size(), 1u);
+  const std::vector<ClassificationResult>& result = spans[0].classification;
+  EXPECT_THAT(
+      result,
+      ElementsAre(
+          IsDatetimeResult(34200000, DatetimeGranularity::GRANULARITY_MINUTE),
+          IsDatetimeResult(77400000, DatetimeGranularity::GRANULARITY_MINUTE)));
+}
+
+TEST_F(AnnotatorTest, SuggestTextDateDisabled) {
+  const std::string test_model = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+  // Restrict every datetime pattern to annotation and classification, which
+  // disables the patterns for selection.
+  for (const auto& pattern : unpacked_model->datetime_model->patterns) {
+    pattern->enabled_modes = ModeFlag_ANNOTATION_AND_CLASSIFICATION;
+  }
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+  // The date is still classified and annotated, but the selection is not grown.
+  EXPECT_EQ("date",
+            FirstResult(classifier->ClassifyText("january 1, 2017", {0, 15})));
+  EXPECT_EQ(classifier->SuggestSelection("january 1, 2017", {0, 7}),
+            CodepointSpan(0, 7));
+  EXPECT_THAT(classifier->Annotate("january 1, 2017"),
+              ElementsAreArray({IsAnnotatedSpan(0, 15, "date")}));
+}
+
+TEST_F(AnnotatorTest, AnnotatesWithGrammarModel) {
+  const std::string serialized_model = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
+
+  // Attach a grammar model with an ICU tokenizer (whitespace tokens dropped,
+  // tokenize_on_script_change enabled).
+  model->grammar_model.reset(new GrammarModelT);
+  GrammarModelT* grammar_model = model->grammar_model.get();
+  grammar_model->tokenizer_options.reset(new GrammarTokenizerOptionsT);
+  grammar_model->tokenizer_options->tokenization_type = TokenizationType_ICU;
+  grammar_model->tokenizer_options->icu_preserve_whitespace_tokens = false;
+  grammar_model->tokenizer_options->tokenize_on_script_change = true;
+  // <tv_detective> matches one of three names; <famous_person> wraps it and
+  // fires the rule-match callback with parameter 0 (the result index below).
+  grammar_model->rules.reset(new grammar::RulesSetT);
+  grammar::Rules rules;
+  rules.Add("<tv_detective>", {"jessica", "fletcher"});
+  rules.Add("<tv_detective>", {"columbo"});
+  rules.Add("<tv_detective>", {"magnum"});
+  rules.Add(
+      "<famous_person>", {"<tv_detective>"},
+      /*callback=*/
+      static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+      /*callback_param=*/0);
+  // Classification result 0: collection "famous person", on in every mode.
+  grammar_model->rule_classification_result.emplace_back(
+      new GrammarModel_::RuleClassificationResultT);
+  GrammarModel_::RuleClassificationResultT* famous_person_result =
+      grammar_model->rule_classification_result.back().get();
+  famous_person_result->collection_name = "famous person";
+  famous_person_result->enabled_modes = ModeFlag_ALL;
+  rules.Finalize().Serialize(/*include_debug_information=*/false,
+                             grammar_model->rules.get());
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, model.get()));
+
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  const std::string text =
+      "Did you see the Novel Connection episode where Jessica Fletcher helps "
+      "Magnum solve the case? I thought that was with Columbo ...";
+  // All three detectives are annotated as famous people.
+  EXPECT_THAT(annotator->Annotate(text),
+              ElementsAre(IsAnnotatedSpan(47, 63, "famous person"),
+                          IsAnnotatedSpan(70, 76, "famous person"),
+                          IsAnnotatedSpan(117, 124, "famous person")));
+  // The same rule also drives classification and selection.
+  EXPECT_THAT(FirstResult(annotator->ClassifyText("Jessica Fletcher",
+                                                  CodepointSpan{0, 16})),
+              Eq("famous person"));
+  EXPECT_THAT(annotator->SuggestSelection("Jessica Fletcher", {0, 7}),
+              Eq(CodepointSpan{0, 16}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsTrivial) {
+  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());
+  // A single candidate cannot conflict with anything and must be kept.
+  std::vector<AnnotatedSpan> candidates{
+      {MakeAnnotatedSpan({0, 1}, "phone", 1.0)}};
+  std::vector<Locale> locales = {Locale::FromBCP47("en")};
+  std::vector<int> chosen;
+
+  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                             locales, AnnotationUsecase_ANNOTATION_USECASE_SMART,
+                             /*interpreter_manager=*/nullptr, &chosen);
+
+  EXPECT_THAT(chosen, ElementsAreArray({0}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsSequence) {
+  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());
+  // Back-to-back spans (each ends where the next begins) do not conflict,
+  // so all five candidates are chosen.
+  std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({0, 1}, "phone", 1.0),
+      MakeAnnotatedSpan({1, 2}, "phone", 1.0),
+      MakeAnnotatedSpan({2, 3}, "phone", 1.0),
+      MakeAnnotatedSpan({3, 4}, "phone", 1.0),
+      MakeAnnotatedSpan({4, 5}, "phone", 1.0),
+  }};
+  std::vector<Locale> locales = {Locale::FromBCP47("en")};
+  std::vector<int> chosen;
+
+  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                             locales, AnnotationUsecase_ANNOTATION_USECASE_SMART,
+                             /*interpreter_manager=*/nullptr, &chosen);
+  EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsThreeSpans) {
+  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());
+  // The weak middle span overlaps both strong neighbours and is dropped. The
+  // neighbours only touch at position 3, so both are kept.
+  std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({0, 3}, "phone", 1.0),
+      MakeAnnotatedSpan({1, 5}, "phone", 0.5),  // Loser.
+      MakeAnnotatedSpan({3, 7}, "phone", 1.0),
+  }};
+  std::vector<Locale> locales = {Locale::FromBCP47("en")};
+  std::vector<int> chosen;
+
+  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                             locales, AnnotationUsecase_ANNOTATION_USECASE_SMART,
+                             /*interpreter_manager=*/nullptr, &chosen);
+  EXPECT_THAT(chosen, ElementsAreArray({0, 2}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsThreeSpansReversed) {
+  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());
+  // The strong middle span wins, and both weaker spans overlapping it are
+  // dropped.
+  std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({0, 3}, "phone", 0.5),  // Loser.
+      MakeAnnotatedSpan({1, 5}, "phone", 1.0),
+      MakeAnnotatedSpan({3, 7}, "phone", 0.6),  // Loser.
+  }};
+  std::vector<Locale> locales = {Locale::FromBCP47("en")};
+  std::vector<int> chosen;
+
+  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                             locales, AnnotationUsecase_ANNOTATION_USECASE_SMART,
+                             /*interpreter_manager=*/nullptr, &chosen);
+  EXPECT_THAT(chosen, ElementsAreArray({1}));
+}
+
+TEST_F(AnnotatorTest, DoesNotPrioritizeLongerSpanWhenDoingConflictResolution) {
+  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());
+  // All scores are equal and span length is not preferred by default, so the
+  // long url does not win. The first and last candidates, which do not
+  // overlap each other, are picked.
+  std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({3, 7}, "unit", 1),
+      MakeAnnotatedSpan({5, 13}, "unit", 1),  // Loser.
+      MakeAnnotatedSpan({5, 30}, "url", 1),   // Loser.
+      MakeAnnotatedSpan({14, 20}, "email", 1),
+  }};
+  std::vector<Locale> locales = {Locale::FromBCP47("en")};
+  std::vector<int> chosen;
+
+  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                             locales, AnnotationUsecase_ANNOTATION_USECASE_SMART,
+                             /*interpreter_manager=*/nullptr, &chosen);
+  EXPECT_THAT(chosen, ElementsAreArray({0, 3}));
+}
+
+TEST_F(AnnotatorTest, PrioritizeLongerSpanWhenDoingConflictResolution) {
+  const std::string test_model = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+  // Report setup failures as test failures instead of aborting via TC3_CHECK.
+  ASSERT_TRUE(libtextclassifier3::DecompressModel(unpacked_model.get()));
+  unpacked_model->conflict_resolution_options.reset(
+      new Model_::ConflictResolutionOptionsT);
+  unpacked_model->conflict_resolution_options->prioritize_longest_annotation =
+      true;
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+  std::unique_ptr<TestingAnnotator> classifier =
+      TestingAnnotator::FromUnownedBuffer(
+          reinterpret_cast<const char*>(builder.GetBufferPointer()),
+          builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier != nullptr);
+
+  // With prioritize_longest_annotation enabled, the longest span wins and
+  // every candidate overlapping it is dropped.
+  std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({3, 7}, "unit", 1),     // Loser.
+      MakeAnnotatedSpan({5, 13}, "unit", 1),    // Loser.
+      MakeAnnotatedSpan({5, 30}, "url", 1),     // Longest match; picked.
+      MakeAnnotatedSpan({14, 20}, "email", 1),  // Loser.
+  }};
+  std::vector<Locale> locales = {Locale::FromBCP47("en")};
+  std::vector<int> chosen;
+  classifier->ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                               locales,
+                               AnnotationUsecase_ANNOTATION_USECASE_SMART,
+                               /*interpreter_manager=*/nullptr, &chosen);
+  EXPECT_THAT(chosen, ElementsAreArray({2}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsFiveSpans) {
+  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());
+  // Expected winners: 0..3, 3..7 and 11..15. The "other" span at 1..5 and the
+  // weaker phone at 8..12 overlap winners and are dropped.
+  std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({0, 3}, "phone", 0.5),
+      MakeAnnotatedSpan({1, 5}, "other", 1.0),  // Loser.
+      MakeAnnotatedSpan({3, 7}, "phone", 0.6),
+      MakeAnnotatedSpan({8, 12}, "phone", 0.6),  // Loser.
+      MakeAnnotatedSpan({11, 15}, "phone", 0.9),
+  }};
+  std::vector<Locale> locales = {Locale::FromBCP47("en")};
+  std::vector<int> chosen;
+
+  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                             locales, AnnotationUsecase_ANNOTATION_USECASE_SMART,
+                             /*interpreter_manager=*/nullptr, &chosen);
+  EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedKnowledgeFirst) {
+  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());
+  // In RAW mode a knowledge span may overlap another span; both are kept.
+  std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({0, 15}, "entity", 0.7,
+                        AnnotatedSpan::Source::KNOWLEDGE),
+      MakeAnnotatedSpan({5, 10}, "address", 0.6),
+  }};
+  std::vector<Locale> locales = {Locale::FromBCP47("en")};
+  std::vector<int> chosen;
+
+  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                             locales, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+                             /*interpreter_manager=*/nullptr, &chosen);
+  EXPECT_THAT(chosen, ElementsAreArray({0, 1}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedKnowledgeSecond) {
+  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());
+  // The RAW-mode overlap is allowed even when the knowledge span comes second.
+  std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({0, 15}, "address", 0.7),
+      MakeAnnotatedSpan({5, 10}, "entity", 0.6,
+                        AnnotatedSpan::Source::KNOWLEDGE),
+  }};
+  std::vector<Locale> locales = {Locale::FromBCP47("en")};
+  std::vector<int> chosen;
+
+  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                             locales, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+                             /*interpreter_manager=*/nullptr, &chosen);
+  EXPECT_THAT(chosen, ElementsAreArray({0, 1}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedBothKnowledge) {
+  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());
+  // Two overlapping knowledge spans are both kept in RAW mode.
+  std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({0, 15}, "entity", 0.7,
+                        AnnotatedSpan::Source::KNOWLEDGE),
+      MakeAnnotatedSpan({5, 10}, "entity", 0.6,
+                        AnnotatedSpan::Source::KNOWLEDGE),
+  }};
+  std::vector<Locale> locales = {Locale::FromBCP47("en")};
+  std::vector<int> chosen;
+
+  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                             locales, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+                             /*interpreter_manager=*/nullptr, &chosen);
+  EXPECT_THAT(chosen, ElementsAreArray({0, 1}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsNotAllowed) {
+  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());
+  // Without a knowledge span the RAW-mode overlap is resolved: only 0 stays.
+  std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({0, 15}, "address", 0.7),
+      MakeAnnotatedSpan({5, 10}, "date", 0.6),
+  }};
+  std::vector<Locale> locales = {Locale::FromBCP47("en")};
+  std::vector<int> chosen;
+
+  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                             locales, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+                             /*interpreter_manager=*/nullptr, &chosen);
+  EXPECT_THAT(chosen, ElementsAreArray({0}));
+}
+
+TEST_F(AnnotatorTest, ResolveConflictsRawModeGeneralOverlapsAllowed) {
+  TestingAnnotator annotator(
+      unilib_.get(), calendarlib_.get(), [](ModelT* model) {
+        model->conflict_resolution_options.reset(
+            new Model_::ConflictResolutionOptionsT);
+        model->conflict_resolution_options->do_conflict_resolution_in_raw_mode =
+            false;
+      });
+  // With RAW-mode conflict resolution turned off, both overlapping spans stay.
+  std::vector<AnnotatedSpan> candidates{{
+      MakeAnnotatedSpan({0, 15}, "address", 0.7),
+      MakeAnnotatedSpan({5, 10}, "date", 0.6),
+  }};
+  std::vector<Locale> locales = {Locale::FromBCP47("en")};
+  std::vector<int> chosen;
+
+  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
+                             locales, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+                             /*interpreter_manager=*/nullptr, &chosen);
+  EXPECT_THAT(chosen, ElementsAreArray({0, 1}));
+}
+
+// Checks that each listed entity is still detected when it is surrounded by
+// 50000 spaces on either side (~100k codepoints of input), via Annotate,
+// SuggestSelection and ClassifyText.
+void VerifyLongInput(const Annotator* classifier) {
+  ASSERT_TRUE(classifier);
+
+  // Number of spaces placed before and after each value.
+  constexpr int kPadding = 50000;
+
+  for (const auto& type_value_pair :
+       std::vector<std::pair<std::string, std::string>>{
+           {"address", "350 Third Street, Cambridge"},
+           {"phone", "123 456-7890"},
+           {"url", "www.google.com"},
+           {"email", "someone@gmail.com"},
+           {"flight", "LX 38"},
+           {"date", "September 1, 2018"}}) {
+    // Makes failure messages name the entity type that was being checked.
+    SCOPED_TRACE(type_value_pair.first);
+
+    const std::string input_100k = std::string(kPadding, ' ') +
+                                   type_value_pair.second +
+                                   std::string(kPadding, ' ');
+    const int value_length = type_value_pair.second.size();
+
+    EXPECT_THAT(classifier->Annotate(input_100k),
+                ElementsAreArray({IsAnnotatedSpan(kPadding,
+                                                  kPadding + value_length,
+                                                  type_value_pair.first)}));
+    EXPECT_EQ(classifier->SuggestSelection(input_100k, {kPadding, kPadding + 1}),
+              CodepointSpan(kPadding, kPadding + value_length));
+    EXPECT_EQ(type_value_pair.first,
+              FirstResult(classifier->ClassifyText(
+                  input_100k, {kPadding, kPadding + value_length})));
+  }
+}
+
+TEST_F(AnnotatorTest, LongInput) {
+  // Runs the long-input checks against the test model loaded from disk.
+  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  VerifyLongInput(annotator.get());
+}
+
+TEST_F(AnnotatorTest, LongInputWithDatetimeGrammar) {
+  // The annotator does not own this buffer, so it is declared first and
+  // therefore outlives the annotator.
+  const std::string grammar_model_buffer = GetGrammarModel();
+  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      grammar_model_buffer.c_str(), grammar_model_buffer.size(), unilib_.get(),
+      calendarlib_.get());
+  VerifyLongInput(annotator.get());
+}
+
+// These coarse tests exist only to make sure that execution finishes in a
+// reasonable amount of time; their results are deliberately not checked.
+TEST_F(AnnotatorTest, LongInputNoResultCheck) {
+  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  const std::vector<std::string> values = {
+      "http://www.aaaaaaaaaaaaaaaaaaaa.com "};
+  for (const std::string& value : values) {
+    const std::string padding(50000, ' ');
+    const std::string input_100k = padding + value + padding;
+    const int value_length = value.size();
+
+    annotator->Annotate(input_100k);
+    annotator->SuggestSelection(input_100k, {50000, 50001});
+    annotator->ClassifyText(input_100k, {50000, 50000 + value_length});
+  }
+}
+
+// Checks that classification_options.max_num_tokens suppresses the
+// classification of spans that contain too many tokens.
+TEST_F(AnnotatorTest, MaxTokenLength) {
+  const std::string test_model = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+  ASSERT_TRUE(unpacked_model);
+  ASSERT_TRUE(unpacked_model->classification_options);
+
+  std::unique_ptr<Annotator> classifier;
+
+  // With an unrestricted number of tokens, classification behaves normally.
+  unpacked_model->classification_options->max_num_tokens = -1;
+
+  // The builder owns the buffer the annotator reads from, so it must outlive
+  // the annotator built from it.
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+  classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+
+  EXPECT_EQ(FirstResult(classifier->ClassifyText(
+                "I live at 350 Third Street, Cambridge.", {10, 37})),
+            "address");
+
+  // Lower the maximum number of tokens (to 3) so that the address
+  // classification is suppressed.
+  unpacked_model->classification_options->max_num_tokens = 3;
+
+  flatbuffers::FlatBufferBuilder builder2;
+  FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get()));
+  classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder2.GetBufferPointer()),
+      builder2.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+
+  EXPECT_EQ(FirstResult(classifier->ClassifyText(
+                "I live at 350 Third Street, Cambridge.", {10, 37})),
+            "other");
+}
+
+TEST_F(AnnotatorTest, MinAddressTokenLength) {
+  const std::string model_bytes = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());
+  const std::string text = "I live at 350 Third Street, Cambridge.";
+
+  std::unique_ptr<Annotator> annotator;
+
+  // Baseline: no minimum number of address tokens, so the address is found.
+  model->classification_options->address_min_num_tokens = 0;
+
+  flatbuffers::FlatBufferBuilder permissive_builder;
+  FinishModelBuffer(permissive_builder,
+                    Model::Pack(permissive_builder, model.get()));
+  annotator = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(permissive_builder.GetBufferPointer()),
+      permissive_builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  EXPECT_EQ(FirstResult(annotator->ClassifyText(text, {10, 37})), "address");
+
+  // Requiring at least 5 address tokens suppresses the address classification.
+  model->classification_options->address_min_num_tokens = 5;
+
+  flatbuffers::FlatBufferBuilder strict_builder;
+  FinishModelBuffer(strict_builder, Model::Pack(strict_builder, model.get()));
+  annotator = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(strict_builder.GetBufferPointer()),
+      strict_builder.GetSize(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  EXPECT_EQ(FirstResult(annotator->ClassifyText(text, {10, 37})), "other");
+}
+
+TEST_F(AnnotatorTest, WhenOtherCollectionPriorityHighOtherIsPreferredToFlight) {
+  const std::string model_bytes = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());
+  // A high priority score for the "other" collection; "other" is expected to
+  // win over the flight-number match.
+  model->triggering_options->other_collection_priority_score = 1.0;
+
+  flatbuffers::FlatBufferBuilder fbb;
+  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
+  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
+      unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  const std::vector<ClassificationResult> results =
+      annotator->ClassifyText("LX37", {0, 4});
+  EXPECT_EQ(FirstResult(results), "other");
+}
+
+// NOTE(review): despite "High" in the name, this test sets a very LOW "other"
+// priority score (-100); the name is kept so existing test filters still match.
+TEST_F(AnnotatorTest, WhenOtherCollectionPriorityHighFlightIsPreferredToOther) {
+  const std::string model_bytes = ReadFile(GetTestModelPath());
+  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());
+  model->triggering_options->other_collection_priority_score = -100.0;
+
+  flatbuffers::FlatBufferBuilder fbb;
+  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
+  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
+      unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  const std::vector<ClassificationResult> results =
+      annotator->ClassifyText("LX37", {0, 4});
+  EXPECT_EQ(FirstResult(results), "flight");
+}
+
+TEST_F(AnnotatorTest, VisitAnnotatorModel) {
+  // The visitor just reports whether it was handed a model at all.
+  const auto received_model = [](const Model* model) {
+    return model != nullptr;
+  };
+
+  EXPECT_TRUE(VisitAnnotatorModel<bool>(GetTestModelPath(), received_model));
+  EXPECT_FALSE(VisitAnnotatorModel<bool>(
+      GetModelPath() + "non_existing_model.fb", received_model));
+}
+
+TEST_F(AnnotatorTest, TriggersWhenNoLanguageDetected) {
+  // Restrict triggering to English and Czech.
+  std::string original_model = ReadFile(GetTestModelPath());
+  const std::string restricted_model =
+      ModifyAnnotatorModel(original_model, [](ModelT* model) {
+        model->triggering_locales = "en,cs";
+      });
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      restricted_model.data(), restricted_model.size(), unilib_.get(),
+      calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // No detected language is supplied, so the restriction must not block
+  // triggering for any of the three entry points.
+  const std::string text = "(555) 225-3556";
+  EXPECT_THAT(annotator->Annotate(text),
+              ElementsAreArray({IsAnnotatedSpan(0, 14, "phone")}));
+  EXPECT_EQ("phone", FirstResult(annotator->ClassifyText(text, {0, 14})));
+  EXPECT_EQ(annotator->SuggestSelection(text, {6, 9}), CodepointSpan(0, 14));
+}
+
+TEST_F(AnnotatorTest, AnnotateTriggersWhenSupportedLanguageDetected) {
+  std::string original_model = ReadFile(GetTestModelPath());
+  const std::string restricted_model =
+      ModifyAnnotatorModel(original_model, [](ModelT* model) {
+        model->triggering_locales = "en,cs";
+      });
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      restricted_model.data(), restricted_model.size(), unilib_.get(),
+      calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // Czech is one of the triggering locales, so annotation proceeds.
+  AnnotationOptions czech_options;
+  czech_options.detected_text_language_tags = "cs";
+  EXPECT_THAT(annotator->Annotate("(555) 225-3556", czech_options),
+              ElementsAreArray({IsAnnotatedSpan(0, 14, "phone")}));
+}
+
+TEST_F(AnnotatorTest, AnnotateDoesntTriggerWhenUnsupportedLanguageDetected) {
+  std::string original_model = ReadFile(GetTestModelPath());
+  const std::string restricted_model =
+      ModifyAnnotatorModel(original_model, [](ModelT* model) {
+        model->triggering_locales = "en,cs";
+      });
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      restricted_model.data(), restricted_model.size(), unilib_.get(),
+      calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // German is not among the triggering locales, so nothing is annotated.
+  AnnotationOptions german_options;
+  german_options.detected_text_language_tags = "de";
+  EXPECT_THAT(annotator->Annotate("(555) 225-3556", german_options),
+              IsEmpty());
+}
+
+TEST_F(AnnotatorTest, ClassifyTextTriggersWhenSupportedLanguageDetected) {
+  std::string original_model = ReadFile(GetTestModelPath());
+  const std::string restricted_model =
+      ModifyAnnotatorModel(original_model, [](ModelT* model) {
+        model->triggering_locales = "en,cs";
+      });
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      restricted_model.data(), restricted_model.size(), unilib_.get(),
+      calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // Czech is one of the triggering locales, so classification proceeds.
+  ClassificationOptions czech_options;
+  czech_options.detected_text_language_tags = "cs";
+  const std::vector<ClassificationResult> results =
+      annotator->ClassifyText("(555) 225-3556", {0, 14}, czech_options);
+  EXPECT_EQ("phone", FirstResult(results));
+}
+
+TEST_F(AnnotatorTest,
+       ClassifyTextDoesntTriggerWhenUnsupportedLanguageDetected) {
+  std::string original_model = ReadFile(GetTestModelPath());
+  const std::string restricted_model =
+      ModifyAnnotatorModel(original_model, [](ModelT* model) {
+        model->triggering_locales = "en,cs";
+      });
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      restricted_model.data(), restricted_model.size(), unilib_.get(),
+      calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // German is not among the triggering locales, so no result is returned.
+  ClassificationOptions german_options;
+  german_options.detected_text_language_tags = "de";
+  EXPECT_THAT(
+      annotator->ClassifyText("(555) 225-3556", {0, 14}, german_options),
+      IsEmpty());
+}
+
+TEST_F(AnnotatorTest, SuggestSelectionTriggersWhenSupportedLanguageDetected) {
+  std::string original_model = ReadFile(GetTestModelPath());
+  const std::string restricted_model =
+      ModifyAnnotatorModel(original_model, [](ModelT* model) {
+        model->triggering_locales = "en,cs";
+      });
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      restricted_model.data(), restricted_model.size(), unilib_.get(),
+      calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // Czech is supported, so the selection expands to the whole phone number.
+  SelectionOptions czech_options;
+  czech_options.detected_text_language_tags = "cs";
+  const CodepointSpan suggested =
+      annotator->SuggestSelection("(555) 225-3556", {6, 9}, czech_options);
+  EXPECT_EQ(suggested, CodepointSpan(0, 14));
+}
+
+TEST_F(AnnotatorTest,
+       SuggestSelectionDoesntTriggerWhenUnsupportedLanguageDetected) {
+  std::string original_model = ReadFile(GetTestModelPath());
+  const std::string restricted_model =
+      ModifyAnnotatorModel(original_model, [](ModelT* model) {
+        model->triggering_locales = "en,cs";
+      });
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      restricted_model.data(), restricted_model.size(), unilib_.get(),
+      calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // German is not supported, so the selection is returned unchanged.
+  SelectionOptions german_options;
+  german_options.detected_text_language_tags = "de";
+  const CodepointSpan suggested =
+      annotator->SuggestSelection("(555) 225-3556", {6, 9}, german_options);
+  EXPECT_EQ(suggested, CodepointSpan(6, 9));
+}
+
+TEST_F(AnnotatorTest, MlModelTriggersWhenNoLanguageDetected) {
+  // Restrict triggering to English and Czech both at the model level and in
+  // the triggering options.
+  std::string original_model = ReadFile(GetTestModelPath());
+  const std::string restricted_model =
+      ModifyAnnotatorModel(original_model, [](ModelT* model) {
+        model->triggering_locales = "en,cs";
+        model->triggering_options->locales = "en,cs";
+      });
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      restricted_model.data(), restricted_model.size(), unilib_.get(),
+      calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // No detected language is supplied, so the address is still found.
+  const std::string text = "350 Third Street, Cambridge";
+  EXPECT_THAT(annotator->Annotate(text),
+              ElementsAreArray({IsAnnotatedSpan(0, 27, "address")}));
+  EXPECT_EQ("address", FirstResult(annotator->ClassifyText(text, {0, 27})));
+  EXPECT_EQ(annotator->SuggestSelection(text, {4, 9}), CodepointSpan(0, 27));
+}
+
+TEST_F(AnnotatorTest, MlModelAnnotateTriggersWhenSupportedLanguageDetected) {
+  std::string original_model = ReadFile(GetTestModelPath());
+  const std::string restricted_model =
+      ModifyAnnotatorModel(original_model, [](ModelT* model) {
+        model->triggering_locales = "en,cs";
+        model->triggering_options->locales = "en,cs";
+      });
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      restricted_model.data(), restricted_model.size(), unilib_.get(),
+      calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // Czech is one of the triggering locales, so the address is annotated.
+  AnnotationOptions czech_options;
+  czech_options.detected_text_language_tags = "cs";
+  EXPECT_THAT(annotator->Annotate("350 Third Street, Cambridge", czech_options),
+              ElementsAreArray({IsAnnotatedSpan(0, 27, "address")}));
+}
+
+TEST_F(AnnotatorTest,
+       MlModelAnnotateDoesntTriggerWhenUnsupportedLanguageDetected) {
+  std::string original_model = ReadFile(GetTestModelPath());
+  const std::string restricted_model =
+      ModifyAnnotatorModel(original_model, [](ModelT* model) {
+        model->triggering_locales = "en,cs";
+        model->triggering_options->locales = "en,cs";
+      });
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      restricted_model.data(), restricted_model.size(), unilib_.get(),
+      calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // German is not among the triggering locales, so nothing is annotated.
+  AnnotationOptions german_options;
+  german_options.detected_text_language_tags = "de";
+  EXPECT_THAT(annotator->Annotate("350 Third Street, Cambridge", german_options),
+              IsEmpty());
+}
+
+TEST_F(AnnotatorTest,
+       MlModelClassifyTextTriggersWhenSupportedLanguageDetected) {
+  std::string original_model = ReadFile(GetTestModelPath());
+  const std::string restricted_model =
+      ModifyAnnotatorModel(original_model, [](ModelT* model) {
+        model->triggering_locales = "en,cs";
+        model->triggering_options->locales = "en,cs";
+      });
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      restricted_model.data(), restricted_model.size(), unilib_.get(),
+      calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // Czech is one of the triggering locales, so the address is classified.
+  ClassificationOptions czech_options;
+  czech_options.detected_text_language_tags = "cs";
+  const std::vector<ClassificationResult> results = annotator->ClassifyText(
+      "350 Third Street, Cambridge", {0, 27}, czech_options);
+  EXPECT_EQ("address", FirstResult(results));
+}
+
+TEST_F(AnnotatorTest,
+       MlModelClassifyTextDoesntTriggerWhenUnsupportedLanguageDetected) {
+  std::string original_model = ReadFile(GetTestModelPath());
+  const std::string restricted_model =
+      ModifyAnnotatorModel(original_model, [](ModelT* model) {
+        model->triggering_locales = "en,cs";
+        model->triggering_options->locales = "en,cs";
+      });
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      restricted_model.data(), restricted_model.size(), unilib_.get(),
+      calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // German is not among the triggering locales, so no result is returned.
+  ClassificationOptions german_options;
+  german_options.detected_text_language_tags = "de";
+  EXPECT_THAT(annotator->ClassifyText("350 Third Street, Cambridge", {0, 27},
+                                      german_options),
+              IsEmpty());
+}
+
+TEST_F(AnnotatorTest,
+       MlModelSuggestSelectionTriggersWhenSupportedLanguageDetected) {
+  std::string original_model = ReadFile(GetTestModelPath());
+  const std::string restricted_model =
+      ModifyAnnotatorModel(original_model, [](ModelT* model) {
+        model->triggering_locales = "en,cs";
+        model->triggering_options->locales = "en,cs";
+      });
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      restricted_model.data(), restricted_model.size(), unilib_.get(),
+      calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // Czech is supported, so the selection expands to the whole address.
+  SelectionOptions czech_options;
+  czech_options.detected_text_language_tags = "cs";
+  const CodepointSpan suggested = annotator->SuggestSelection(
+      "350 Third Street, Cambridge", {4, 9}, czech_options);
+  EXPECT_EQ(suggested, CodepointSpan(0, 27));
+}
+
+TEST_F(AnnotatorTest,
+       MlModelSuggestSelectionDoesntTriggerWhenUnsupportedLanguageDetected) {
+  std::string original_model = ReadFile(GetTestModelPath());
+  const std::string restricted_model =
+      ModifyAnnotatorModel(original_model, [](ModelT* model) {
+        model->triggering_locales = "en,cs";
+        model->triggering_options->locales = "en,cs";
+      });
+  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      restricted_model.data(), restricted_model.size(), unilib_.get(),
+      calendarlib_.get());
+  ASSERT_TRUE(annotator);
+
+  // German is not supported, so the selection is returned unchanged.
+  SelectionOptions german_options;
+  german_options.detected_text_language_tags = "de";
+  const CodepointSpan suggested = annotator->SuggestSelection(
+      "350 Third Street, Cambridge", {4, 9}, german_options);
+  EXPECT_EQ(suggested, CodepointSpan(4, 9));
+}
+
+// Checks that classifying "03.05.1970 00:00am" yields serialized datetime
+// entity data with the expected timestamp, MINUTE granularity and six
+// components (the expectations below read the date as March 5, 1970).
+void VerifyClassifyTextOutputsDatetimeEntityData(const Annotator* classifier) {
+  // ASSERT (not EXPECT): everything below dereferences `classifier`.
+  ASSERT_TRUE(classifier);
+  std::vector<ClassificationResult> result;
+  ClassificationOptions options;
+
+  result = classifier->ClassifyText("03.05.1970 00:00am", {0, 18}, options);
+
+  // result[0] is read below; size() is unsigned, so require non-emptiness
+  // explicitly rather than comparing against 0.
+  ASSERT_FALSE(result.empty());
+  ASSERT_FALSE(result[0].serialized_entity_data.empty());
+  const EntityData* entity_data =
+      GetEntityData(result[0].serialized_entity_data.data());
+  ASSERT_NE(entity_data, nullptr);
+  ASSERT_NE(entity_data->datetime(), nullptr);
+  EXPECT_EQ(entity_data->datetime()->time_ms_utc(), 5443200000L);
+  EXPECT_EQ(entity_data->datetime()->granularity(),
+            EntityData_::Datetime_::Granularity_GRANULARITY_MINUTE);
+  // Components 0..5 are read by index below, so the count must match exactly.
+  ASSERT_NE(entity_data->datetime()->datetime_component(), nullptr);
+  ASSERT_EQ(entity_data->datetime()->datetime_component()->size(), 6);
+
+  auto* meridiem = entity_data->datetime()->datetime_component()->Get(0);
+  EXPECT_EQ(meridiem->component_type(),
+            EntityData_::Datetime_::DatetimeComponent_::ComponentType_MERIDIEM);
+  EXPECT_EQ(meridiem->absolute_value(), 0);
+  EXPECT_EQ(meridiem->relative_count(), 0);
+  EXPECT_EQ(meridiem->relation_type(),
+            EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
+
+  auto* minute = entity_data->datetime()->datetime_component()->Get(1);
+  EXPECT_EQ(minute->component_type(),
+            EntityData_::Datetime_::DatetimeComponent_::ComponentType_MINUTE);
+  EXPECT_EQ(minute->absolute_value(), 0);
+  EXPECT_EQ(minute->relative_count(), 0);
+  EXPECT_EQ(minute->relation_type(),
+            EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
+
+  auto* hour = entity_data->datetime()->datetime_component()->Get(2);
+  EXPECT_EQ(hour->component_type(),
+            EntityData_::Datetime_::DatetimeComponent_::ComponentType_HOUR);
+  EXPECT_EQ(hour->absolute_value(), 0);
+  EXPECT_EQ(hour->relative_count(), 0);
+  EXPECT_EQ(hour->relation_type(),
+            EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
+
+  auto* day = entity_data->datetime()->datetime_component()->Get(3);
+  EXPECT_EQ(
+      day->component_type(),
+      EntityData_::Datetime_::DatetimeComponent_::ComponentType_DAY_OF_MONTH);
+  EXPECT_EQ(day->absolute_value(), 5);
+  EXPECT_EQ(day->relative_count(), 0);
+  EXPECT_EQ(day->relation_type(),
+            EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
+
+  auto* month = entity_data->datetime()->datetime_component()->Get(4);
+  EXPECT_EQ(month->component_type(),
+            EntityData_::Datetime_::DatetimeComponent_::ComponentType_MONTH);
+  EXPECT_EQ(month->absolute_value(), 3);
+  EXPECT_EQ(month->relative_count(), 0);
+  EXPECT_EQ(month->relation_type(),
+            EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
+
+  auto* year = entity_data->datetime()->datetime_component()->Get(5);
+  EXPECT_EQ(year->component_type(),
+            EntityData_::Datetime_::DatetimeComponent_::ComponentType_YEAR);
+  EXPECT_EQ(year->absolute_value(), 1970);
+  EXPECT_EQ(year->relative_count(), 0);
+  EXPECT_EQ(year->relation_type(),
+            EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
+}
+
+TEST_F(AnnotatorTest, ClassifyTextOutputsDatetimeEntityData) {
+  // Datetime entity-data checks against the test model loaded from disk.
+  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  VerifyClassifyTextOutputsDatetimeEntityData(annotator.get());
+}
+
+TEST_F(AnnotatorTest,
+       ClassifyTextOutputsDatetimeEntityDataWithDatetimeGrammar) {
+  // The buffer is not owned by the annotator; declaring it first keeps it
+  // alive for the annotator's whole lifetime.
+  const std::string grammar_model_buffer = GetGrammarModel();
+  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      grammar_model_buffer.c_str(), grammar_model_buffer.size(), unilib_.get(),
+      calendarlib_.get());
+  VerifyClassifyTextOutputsDatetimeEntityData(annotator.get());
+}
+
+// Checks that annotating "September 1, 2019" with serialized entity data
+// enabled yields a "date" span whose entity data carries the expected
+// timestamp, DAY granularity and day/month/year components.
+void VerifyAnnotateOutputsDatetimeEntityData(const Annotator* classifier) {
+  // ASSERT (not EXPECT): everything below dereferences `classifier`.
+  ASSERT_TRUE(classifier);
+  std::vector<AnnotatedSpan> result;
+  AnnotationOptions options;
+  options.is_serialized_entity_data_enabled = true;
+
+  result = classifier->Annotate("September 1, 2019", options);
+
+  // Index 0 of both vectors is read below; size() is unsigned, so require
+  // non-emptiness explicitly rather than comparing against 0.
+  ASSERT_FALSE(result.empty());
+  ASSERT_FALSE(result[0].classification.empty());
+  ASSERT_EQ(result[0].classification[0].collection, "date");
+  ASSERT_FALSE(result[0].classification[0].serialized_entity_data.empty());
+  const EntityData* entity_data =
+      GetEntityData(result[0].classification[0].serialized_entity_data.data());
+  ASSERT_NE(entity_data, nullptr);
+  ASSERT_NE(entity_data->datetime(), nullptr);
+  EXPECT_EQ(entity_data->datetime()->time_ms_utc(), 1567296000000L);
+  EXPECT_EQ(entity_data->datetime()->granularity(),
+            EntityData_::Datetime_::Granularity_GRANULARITY_DAY);
+  // Components 0..2 are read by index below, so the count must match exactly.
+  ASSERT_NE(entity_data->datetime()->datetime_component(), nullptr);
+  ASSERT_EQ(entity_data->datetime()->datetime_component()->size(), 3);
+
+  auto* day = entity_data->datetime()->datetime_component()->Get(0);
+  EXPECT_EQ(
+      day->component_type(),
+      EntityData_::Datetime_::DatetimeComponent_::ComponentType_DAY_OF_MONTH);
+  EXPECT_EQ(day->absolute_value(), 1);
+  EXPECT_EQ(day->relative_count(), 0);
+  EXPECT_EQ(day->relation_type(),
+            EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
+
+  auto* month = entity_data->datetime()->datetime_component()->Get(1);
+  EXPECT_EQ(month->component_type(),
+            EntityData_::Datetime_::DatetimeComponent_::ComponentType_MONTH);
+  EXPECT_EQ(month->absolute_value(), 9);
+  EXPECT_EQ(month->relative_count(), 0);
+  EXPECT_EQ(month->relation_type(),
+            EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
+
+  auto* year = entity_data->datetime()->datetime_component()->Get(2);
+  EXPECT_EQ(year->component_type(),
+            EntityData_::Datetime_::DatetimeComponent_::ComponentType_YEAR);
+  EXPECT_EQ(year->absolute_value(), 2019);
+  EXPECT_EQ(year->relative_count(), 0);
+  EXPECT_EQ(year->relation_type(),
+            EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
+}
+
+TEST_F(AnnotatorTest, AnnotateOutputsDatetimeEntityData) {
+  // Datetime entity-data checks for Annotate using the on-disk test model.
+  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  VerifyAnnotateOutputsDatetimeEntityData(annotator.get());
+}
+
+TEST_F(AnnotatorTest, AnnotateOutputsDatetimeEntityDataWithDatetimeGrammar) {
+  // The buffer is not owned by the annotator; declaring it first keeps it
+  // alive for the annotator's whole lifetime.
+  const std::string grammar_model_buffer = GetGrammarModel();
+  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+      grammar_model_buffer.c_str(), grammar_model_buffer.size(), unilib_.get(),
+      calendarlib_.get());
+  VerifyAnnotateOutputsDatetimeEntityData(annotator.get());
+}
+
+// Checks the serialized money entity data (currency, amount string, whole and
+// decimal parts, nanos) produced by Annotate for a range of money formats.
+TEST_F(AnnotatorTest, AnnotateOutputsMoneyEntityData) {
+  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  // ASSERT (not EXPECT): every check below dereferences `classifier`.
+  ASSERT_TRUE(classifier);
+  AnnotationOptions options;
+  options.is_serialized_entity_data_enabled = true;
+
+  ExpectFirstEntityIsMoney(classifier->Annotate("3.5 CHF", options), "CHF",
+                           /*amount=*/"3.5", /*whole_part=*/3,
+                           /*decimal_part=*/5, /*nanos=*/500000000);
+  ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3.5", options), "CHF",
+                           /*amount=*/"3.5", /*whole_part=*/3,
+                           /*decimal_part=*/5, /*nanos=*/500000000);
+  ExpectFirstEntityIsMoney(
+      classifier->Annotate("For online purchase of CHF 23.00 enter", options),
+      "CHF", /*amount=*/"23.00", /*whole_part=*/23, /*decimal_part=*/0,
+      /*nanos=*/0);
+  ExpectFirstEntityIsMoney(
+      classifier->Annotate("For online purchase of 23.00 CHF enter", options),
+      "CHF", /*amount=*/"23.00", /*whole_part=*/23, /*decimal_part=*/0,
+      /*nanos=*/0);
+  ExpectFirstEntityIsMoney(classifier->Annotate("4.8198£", options), "£",
+                           /*amount=*/"4.8198", /*whole_part=*/4,
+                           /*decimal_part=*/8198, /*nanos=*/819800000);
+  ExpectFirstEntityIsMoney(classifier->Annotate("£4.8198", options), "£",
+                           /*amount=*/"4.8198", /*whole_part=*/4,
+                           /*decimal_part=*/8198, /*nanos=*/819800000);
+  ExpectFirstEntityIsMoney(classifier->Annotate("$0.0255", options), "$",
+                           /*amount=*/"0.0255", /*whole_part=*/0,
+                           /*decimal_part=*/255, /*nanos=*/25500000);
+  ExpectFirstEntityIsMoney(
+      classifier->Annotate("for txn of INR 000.00 at RAZOR-PAY ZOMATO ONLINE "
+                           "OR on card ending 0000.",
+                           options),
+      "INR", /*amount=*/"000.00", /*whole_part=*/0, /*decimal_part=*/0,
+      /*nanos=*/0);
+  ExpectFirstEntityIsMoney(
+      classifier->Annotate("for txn of 000.00 INR at RAZOR-PAY ZOMATO ONLINE "
+                           "OR on card ending 0000.",
+                           options),
+      "INR", /*amount=*/"000.00", /*whole_part=*/0, /*decimal_part=*/0,
+      /*nanos=*/0);
+
+  ExpectFirstEntityIsMoney(classifier->Annotate("35 CHF", options), "CHF",
+                           /*amount=*/"35",
+                           /*whole_part=*/35, /*decimal_part=*/0, /*nanos=*/0);
+  ExpectFirstEntityIsMoney(classifier->Annotate("CHF 35", options), "CHF",
+                           /*amount=*/"35", /*whole_part=*/35,
+                           /*decimal_part=*/0, /*nanos=*/0);
+  ExpectFirstEntityIsMoney(
+      classifier->Annotate("and win back up to CHF 150 - with digitec",
+                           options),
+      "CHF", /*amount=*/"150", /*whole_part=*/150, /*decimal_part=*/0,
+      /*nanos=*/0);
+  ExpectFirstEntityIsMoney(
+      classifier->Annotate("and win back up to 150 CHF - with digitec",
+                           options),
+      "CHF", /*amount=*/"150", /*whole_part=*/150, /*decimal_part=*/0,
+      /*nanos=*/0);
+
+  ExpectFirstEntityIsMoney(classifier->Annotate("3.555.333 CHF", options),
+                           "CHF", /*amount=*/"3.555.333",
+                           /*whole_part=*/3555333, /*decimal_part=*/0,
+                           /*nanos=*/0);
+  ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3.555.333", options),
+                           "CHF", /*amount=*/"3.555.333",
+                           /*whole_part=*/3555333, /*decimal_part=*/0,
+                           /*nanos=*/0);
+  ExpectFirstEntityIsMoney(classifier->Annotate("10,000 CHF", options), "CHF",
+                           /*amount=*/"10,000", /*whole_part=*/10000,
+                           /*decimal_part=*/0, /*nanos=*/0);
+  ExpectFirstEntityIsMoney(classifier->Annotate("CHF 10,000", options), "CHF",
+                           /*amount=*/"10,000", /*whole_part=*/10000,
+                           /*decimal_part=*/0, /*nanos=*/0);
+
+  ExpectFirstEntityIsMoney(classifier->Annotate("3,555.33 CHF", options), "CHF",
+                           /*amount=*/"3,555.33", /*whole_part=*/3555,
+                           /*decimal_part=*/33, /*nanos=*/330000000);
+  ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3,555.33", options), "CHF",
+                           /*amount=*/"3,555.33", /*whole_part=*/3555,
+                           /*decimal_part=*/33, /*nanos=*/330000000);
+  ExpectFirstEntityIsMoney(classifier->Annotate("$3,000.00", options), "$",
+                           /*amount=*/"3,000.00", /*whole_part=*/3000,
+                           /*decimal_part=*/0, /*nanos=*/0);
+  ExpectFirstEntityIsMoney(classifier->Annotate("3,000.00$", options), "$",
+                           /*amount=*/"3,000.00", /*whole_part=*/3000,
+                           /*decimal_part=*/0, /*nanos=*/0);
+
+  ExpectFirstEntityIsMoney(classifier->Annotate("1.2 CHF", options), "CHF",
+                           /*amount=*/"1.2", /*whole_part=*/1,
+                           /*decimal_part=*/2, /*nanos=*/200000000);
+  ExpectFirstEntityIsMoney(classifier->Annotate("CHF1.2", options), "CHF",
+                           /*amount=*/"1.2", /*whole_part=*/1,
+                           /*decimal_part=*/2, /*nanos=*/200000000);
+
+  ExpectFirstEntityIsMoney(classifier->Annotate("$1.123456789", options), "$",
+                           /*amount=*/"1.123456789", /*whole_part=*/1,
+                           /*decimal_part=*/123456789, /*nanos=*/123456789);
+  ExpectFirstEntityIsMoney(classifier->Annotate("10.01 CHF", options), "CHF",
+                           /*amount=*/"10.01", /*whole_part=*/10,
+                           /*decimal_part=*/1, /*nanos=*/10000000);
+
+  ExpectFirstEntityIsMoney(classifier->Annotate("$59 Million", options), "$",
+                           /*amount=*/"59 million", /*whole_part=*/59000000,
+                           /*decimal_part=*/0, /*nanos=*/0);
+  ExpectFirstEntityIsMoney(classifier->Annotate("7.05k €", options), "€",
+                           /*amount=*/"7.05 k", /*whole_part=*/7050,
+                           /*decimal_part=*/5, /*nanos=*/0);
+  ExpectFirstEntityIsMoney(classifier->Annotate("7.123456789m €", options), "€",
+                           /*amount=*/"7.123456789 m", /*whole_part=*/7123456,
+                           /*decimal_part=*/123456789, /*nanos=*/789000000);
+  ExpectFirstEntityIsMoney(classifier->Annotate("7.000056789k €", options), "€",
+                           /*amount=*/"7.000056789 k", /*whole_part=*/7000,
+                           /*decimal_part=*/56789, /*nanos=*/56789000);
+
+  ExpectFirstEntityIsMoney(classifier->Annotate("$59.3 Billion", options), "$",
+                           /*amount=*/"59.3 billion", /*whole_part=*/59,
+                           /*decimal_part=*/3, /*nanos=*/300000000);
+  ExpectFirstEntityIsMoney(classifier->Annotate("$1.5 Billion", options), "$",
+                           /*amount=*/"1.5 billion", /*whole_part=*/1500000000,
+                           /*decimal_part=*/5, /*nanos=*/0);
+}
+
+// Checks that a word in a language the user is not familiar with is
+// classified as "translate" once a LangId model is attached.
+TEST_F(AnnotatorTest, TranslateAction) {
+  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  ASSERT_TRUE(classifier);
+  std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> langid_model =
+      libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile(
+          GetModelPath() + "lang_id.smfb");
+  ASSERT_TRUE(langid_model);
+  classifier->SetLangId(langid_model.get());
+
+  // The user is only familiar with German, while the selected word ("are")
+  // is English.
+  ClassificationOptions options;
+  options.user_familiar_language_tags = "de";
+
+  std::vector<ClassificationResult> classifications =
+      classifier->ClassifyText("hello, how are you doing?", {11, 14}, options);
+  // ASSERT (not EXPECT): classifications[0] is read on the next line.
+  ASSERT_EQ(classifications.size(), 1);
+  EXPECT_EQ(classifications[0].collection, "translate");
+}
+
+// Checks that each input fragment is annotated independently, producing one
+// list of spans per fragment (money in the first, date in the second).
+TEST_F(AnnotatorTest, AnnotateStructuredInputCallsMultipleAnnotators) {
+  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  // ASSERT (not EXPECT): `classifier` is dereferenced below.
+  ASSERT_TRUE(classifier);
+
+  std::vector<InputFragment> string_fragments = {
+      {.text = "He owes me 3.5 CHF."},
+      {.text = "...was born on 13/12/1989."},
+  };
+
+  StatusOr<Annotations> annotations_status =
+      classifier->AnnotateStructuredInput(string_fragments,
+                                          AnnotationOptions());
+  ASSERT_TRUE(annotations_status.ok());
+  Annotations annotations = annotations_status.ValueOrDie();
+  ASSERT_EQ(annotations.annotated_spans.size(), 2);
+  EXPECT_THAT(annotations.annotated_spans[0],
+              ElementsAreArray({IsAnnotatedSpan(11, 18, "money")}));
+  EXPECT_THAT(annotations.annotated_spans[1],
+              ElementsAreArray({IsAnnotatedSpan(15, 25, "date")}));
+}
+
+// Checks that a per-fragment reference time takes precedence over the
+// reference time in AnnotationOptions when resolving a relative time.
+TEST_F(AnnotatorTest, InputFragmentTimestampOverridesAnnotationOptions) {
+  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  // ASSERT (not EXPECT): `classifier` is dereferenced below.
+  ASSERT_TRUE(classifier);
+
+  AnnotationOptions annotation_options;
+  annotation_options.reference_time_ms_utc =
+      1554465190000;  // 04/05/2019 11:53 am
+  int64 fragment_reference_time = 946727580000;  // 01/01/2000 11:53 am
+  std::vector<InputFragment> string_fragments = {
+      {.text = "New event at 17:20"},
+      {
+          .text = "New event at 17:20",
+          .datetime_options = Optional<DatetimeOptions>(
+              {.reference_time_ms_utc = fragment_reference_time}),
+      }};
+  StatusOr<Annotations> annotations_status =
+      classifier->AnnotateStructuredInput(string_fragments, annotation_options);
+  ASSERT_TRUE(annotations_status.ok());
+  Annotations annotations = annotations_status.ValueOrDie();
+  ASSERT_EQ(annotations.annotated_spans.size(), 2);
+  EXPECT_THAT(annotations.annotated_spans[0],
+              ElementsAreArray({IsDatetimeSpan(
+                  /*start=*/13, /*end=*/18, /*time_ms_utc=*/1554484800000,
+                  DatetimeGranularity::GRANULARITY_MINUTE)}));
+  EXPECT_THAT(annotations.annotated_spans[1],
+              ElementsAreArray({IsDatetimeSpan(
+                  /*start=*/13, /*end=*/18, /*time_ms_utc=*/946747200000,
+                  DatetimeGranularity::GRANULARITY_MINUTE)}));
+}
+
+// Checks that a per-fragment reference timezone changes how the same
+// date/time string is resolved compared to the default options.
+TEST_F(AnnotatorTest, InputFragmentTimezoneOverridesAnnotationOptions) {
+  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+      GetTestModelPath(), unilib_.get(), calendarlib_.get());
+  // ASSERT (not EXPECT): `classifier` is dereferenced below.
+  ASSERT_TRUE(classifier);
+  std::vector<InputFragment> string_fragments = {
+      {.text = "11/12/2020 17:20"},
+      {
+          .text = "11/12/2020 17:20",
+          .datetime_options = Optional<DatetimeOptions>(
+              {.reference_timezone = "Europe/Zurich"}),
+      }};
+  StatusOr<Annotations> annotations_status =
+      classifier->AnnotateStructuredInput(string_fragments,
+                                          AnnotationOptions());
+  ASSERT_TRUE(annotations_status.ok());
+  Annotations annotations = annotations_status.ValueOrDie();
+  ASSERT_EQ(annotations.annotated_spans.size(), 2);
+  EXPECT_THAT(annotations.annotated_spans[0],
+              ElementsAreArray({IsDatetimeSpan(
+                  /*start=*/0, /*end=*/16, /*time_ms_utc=*/1605201600000,
+                  DatetimeGranularity::GRANULARITY_MINUTE)}));
+  EXPECT_THAT(annotations.annotated_spans[1],
+              ElementsAreArray({IsDatetimeSpan(
+                  /*start=*/0, /*end=*/16, /*time_ms_utc=*/1605198000000,
+                  DatetimeGranularity::GRANULARITY_MINUTE)}));
+}
+
+namespace {
+// Replaces the datetime model of `unpacked_model` with a minimal regex model:
+// one en-US pattern whose single regex matches only the literal
+// "THIS_MATCHES_IN_REGEX_MODEL".
+void AddDummyRegexDatetimeModel(ModelT* unpacked_model) {
+  unpacked_model->datetime_model.reset(new DatetimeModelT);
+  DatetimeModelT* datetime_model = unpacked_model->datetime_model.get();
+
+  // This needs to be false, otherwise we'd have to define some extractor. When
+  // it is false, the 0-th capturing group (the whole match) of the pattern is
+  // used to come up with the indices.
+  datetime_model->use_extractors_for_locating = false;
+  datetime_model->locales.push_back("en-US");
+  datetime_model->default_locales.push_back(0);  // Index of "en-US".
+
+  datetime_model->patterns.push_back(
+      std::unique_ptr<DatetimeModelPatternT>(new DatetimeModelPatternT));
+  DatetimeModelPatternT* pattern = datetime_model->patterns.back().get();
+  pattern->locales.push_back(0);  // Index of "en-US".
+
+  pattern->regexes.push_back(std::unique_ptr<DatetimeModelPattern_::RegexT>(
+      new DatetimeModelPattern_::RegexT));
+  DatetimeModelPattern_::RegexT* regex = pattern->regexes.back().get();
+  regex->pattern = "THIS_MATCHES_IN_REGEX_MODEL";
+  regex->groups.push_back(DatetimeGroupType_GROUP_UNUSED);
+}
+}  // namespace
+
+ TEST_F(AnnotatorTest, WorksWithBothRegexAndGrammarDatetimeAtOnce) {
+   // Combine the grammar test model with a deliberately broken regex datetime
+   // model that only matches the literal "THIS_MATCHES_IN_REGEX_MODEL". Seeing
+   // one match from each shows that both datetime models are consulted.
+   const std::string grammar_model_bytes = ReadFile(GetModelWithGrammarPath());
+   std::unique_ptr<ModelT> model = UnPackModel(grammar_model_bytes.c_str());
+   TC3_CHECK(libtextclassifier3::DecompressModel(model.get()));
+   AddDummyRegexDatetimeModel(model.get());
+
+   flatbuffers::FlatBufferBuilder fbb;
+   FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
+   std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+       reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
+       unilib_.get(), calendarlib_.get());
+   ASSERT_TRUE(annotator);
+
+   const std::string text =
+       "THIS_MATCHES_IN_REGEX_MODEL, and this in grammars: February 2, 2020";
+
+   EXPECT_THAT(annotator->Annotate(text),
+               ElementsAreArray({
+                   // Regex model match.
+                   IsAnnotatedSpan(0, 27, "date"),
+                   // Grammar model match.
+                   IsAnnotatedSpan(51, 67, "date"),
+               }));
+   EXPECT_EQ(FirstResult(annotator->ClassifyText(text, {0, 27})), "date");
+   EXPECT_EQ(FirstResult(annotator->ClassifyText(text, {51, 67})), "date");
+ }
+
+ TEST_F(AnnotatorTest, AnnotateFiltersOutExactDuplicates) {
+   std::unique_ptr<Annotator> classifier = Annotator::FromPath(
+       GetTestModelPath(), unilib_.get(), calendarlib_.get());
+   ASSERT_TRUE(classifier);
+
+   // This test assumes that both ML model and Regex model trigger on the
+   // following text and output "phone" annotation for it; exactly one "phone"
+   // span must remain after duplicate filtering.
+   const std::string test_string = "1000000000";
+   AnnotationOptions options;
+   options.annotation_usecase = ANNOTATION_USECASE_RAW;
+   int num_phones = 0;
+   for (const AnnotatedSpan& span : classifier->Annotate(test_string, options)) {
+     // Guard the [0] access: a span without classifications must not cause an
+     // out-of-bounds read (undefined behavior) in the test itself.
+     if (!span.classification.empty() &&
+         span.classification[0].collection == "phone") {
+       num_phones++;
+     }
+   }
+
+   EXPECT_EQ(num_phones, 1);
+ }
+
+ TEST_F(AnnotatorTest, AnnotateUsingGrammar) {
+   const std::string model_buffer = GetGrammarModel();
+   std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+       model_buffer.c_str(), model_buffer.size(), unilib_.get(),
+       calendarlib_.get());
+   ASSERT_TRUE(annotator);
+
+   // Expected: one date ("today"), one address and one phone number.
+   const std::string text =
+       "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+       "number is 853 225 3556";
+   const auto spans = annotator->Annotate(text);
+   EXPECT_THAT(spans, ElementsAreArray({
+                          IsAnnotatedSpan(19, 24, "date"),
+                          IsAnnotatedSpan(28, 55, "address"),
+                          IsAnnotatedSpan(79, 91, "phone"),
+                      }));
+ }
+
+ TEST_F(AnnotatorTest, AnnotateGrammarPriority) {
+   const std::string model_buffer = GetGrammarModel();
+   std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+       model_buffer.c_str(), model_buffer.size(), unilib_.get(),
+       calendarlib_.get());
+   ASSERT_TRUE(annotator);
+
+   // "May 8, 2015, 545" admits two overlapping annotations:
+   //   1) {May 8, 2015}, 545 -> date (higher priority score)
+   //   2) May 8, {2015, 545} -> phone number (lower priority score)
+   // Only the higher-priority date is expected in the output.
+   const std::string text =
+       "May 8, 2015, 545, shares Comment Now that Joss Whedon";
+   EXPECT_THAT(annotator->Annotate(text),
+               ElementsAreArray({IsAnnotatedSpan(0, 11, "date")}));
+ }
+
+ TEST_F(AnnotatorTest, AnnotateGrammarDatetimeRangesDisable) {
+   // Grammar model built with GetGrammarModel()'s default options: the
+   // expectations below show each side of a range reported separately.
+   const std::string model_buffer = GetGrammarModel();
+   std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+       model_buffer.c_str(), model_buffer.size(), unilib_.get(),
+       calendarlib_.get());
+   ASSERT_TRUE(annotator);
+
+   EXPECT_THAT(annotator->Annotate("11 Jan to 3 Feb"),
+               ElementsAreArray({
+                   IsAnnotatedSpan(0, 6, "date"),
+                   IsAnnotatedSpan(10, 15, "date"),
+               }));
+   EXPECT_THAT(annotator->Annotate("11 - 14 of Feb"),
+               ElementsAreArray({IsAnnotatedSpan(5, 14, "date")}));
+   EXPECT_THAT(annotator->Annotate("Monday 10 - 11pm"),
+               ElementsAreArray({
+                   IsAnnotatedSpan(0, 6, "date"),
+                   IsAnnotatedSpan(12, 16, "datetime"),
+               }));
+   EXPECT_THAT(annotator->Annotate("7:20am - 8:00pm"),
+               ElementsAreArray({
+                   IsAnnotatedSpan(0, 6, "datetime"),
+                   IsAnnotatedSpan(9, 15, "datetime"),
+               }));
+ }
+
+ TEST_F(AnnotatorTest, AnnotateGrammarDatetimeRangesEnable) {
+   // Same inputs as the "Disable" variant, but with date ranges enabled and
+   // adjacent components merged: each range is expected as a single span.
+   DateAnnotationOptions date_options;
+   date_options.merge_adjacent_components = true;
+   date_options.enable_date_range = true;
+   const std::string model_buffer = GetGrammarModel(date_options);
+   std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
+       model_buffer.c_str(), model_buffer.size(), unilib_.get(),
+       calendarlib_.get());
+   ASSERT_TRUE(annotator);
+
+   EXPECT_THAT(annotator->Annotate("11 Jan to 3 Feb"),
+               ElementsAreArray({IsAnnotatedSpan(0, 15, "date")}));
+   EXPECT_THAT(annotator->Annotate("11 - 14 of Feb"),
+               ElementsAreArray({IsAnnotatedSpan(0, 14, "date")}));
+   EXPECT_THAT(annotator->Annotate("Monday 10 - 11pm"),
+               ElementsAreArray({IsAnnotatedSpan(0, 16, "datetime")}));
+   EXPECT_THAT(annotator->Annotate("7:20am - 8:00pm"),
+               ElementsAreArray({IsAnnotatedSpan(0, 15, "datetime")}));
+ }
+
+ TEST_F(AnnotatorTest, InitializeFromString) {
+   // Builds the annotator from an in-memory copy of the test model file.
+   const std::string model_bytes = ReadFile(GetTestModelPath());
+   std::unique_ptr<Annotator> annotator =
+       Annotator::FromString(model_bytes, unilib_.get(), calendarlib_.get());
+   ASSERT_TRUE(annotator);
+
+   // A phone number should yield at least one annotation.
+   const auto spans = annotator->Annotate("(857) 225-3556");
+   EXPECT_THAT(spans, Not(IsEmpty()));
+ }
+
+} // namespace test_internal
+} // namespace libtextclassifier3
diff --git a/native/annotator/annotator_test-include.h b/native/annotator/annotator_test-include.h
new file mode 100644
index 0000000..bcbb9e9
--- /dev/null
+++ b/native/annotator/annotator_test-include.h
@@ -0,0 +1,77 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_TEST_INCLUDE_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_TEST_INCLUDE_H_
+
+#include <fstream>
+#include <string>
+
+#include "annotator/annotator.h"
+#include "utils/base/logging.h"
+#include "utils/jvm-test-utils.h"
+#include "utils/test-data-test-utils.h"
+#include "utils/testing/annotator.h"
+#include "utils/utf8/unilib.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace test_internal {
+
+ // Returns the path of the annotator test-data directory.
+ inline std::string GetModelPath() {
+   constexpr char kAnnotatorTestDataDir[] = "annotator/test_data/";
+   return GetTestDataPath(kAnnotatorTestDataDir);
+ }
+
+ // Annotator subclass used by tests: it can be built over an empty model and
+ // exposes the protected Annotator::ResolveConflicts.
+ class TestingAnnotator : public Annotator {
+  public:
+   // Builds an annotator that owns the model buffer returned by
+   // CreateEmptyModel(model_update_fn) and initializes itself from it.
+   TestingAnnotator(
+       const UniLib* unilib, const CalendarLib* calendarlib,
+       const std::function<void(ModelT*)>& model_update_fn =
+           [](ModelT* model) {}) {
+     owned_buffer_ = CreateEmptyModel(model_update_fn);
+     ValidateAndInitialize(libtextclassifier3::ViewModel(owned_buffer_.data(),
+                                                         owned_buffer_.size()),
+                           unilib, calendarlib);
+   }
+
+   // Wraps Annotator::FromUnownedBuffer and returns the result typed as a
+   // TestingAnnotator (null if the base factory returns null).
+   static std::unique_ptr<TestingAnnotator> FromUnownedBuffer(
+       const char* buffer, int size, const UniLib* unilib = nullptr,
+       const CalendarLib* calendarlib = nullptr) {
+     // The created object is a plain Annotator. The downcast relies on
+     // TestingAnnotator adding no data members. static_cast is the proper,
+     // compiler-checked cast for a base-to-derived conversion (unlike
+     // reinterpret_cast), although the downcast itself is still formally
+     // undefined behavior in standard C++.
+     return std::unique_ptr<TestingAnnotator>(static_cast<TestingAnnotator*>(
+         Annotator::FromUnownedBuffer(buffer, size, unilib, calendarlib)
+             .release()));
+   }
+
+   using Annotator::ResolveConflicts;
+ };
+
+ // Test fixture providing the UniLib and CalendarLib instances handed to every
+ // Annotator under test.
+ class AnnotatorTest : public ::testing::TestWithParam<const char*> {
+  protected:
+   AnnotatorTest() = default;
+
+   std::unique_ptr<UniLib> unilib_{CreateUniLibForTesting()};
+   std::unique_ptr<CalendarLib> calendarlib_{CreateCalendarLibForTesting()};
+ };
+
+} // namespace test_internal
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_TEST_INCLUDE_H_
diff --git a/native/annotator/datetime/parser_test.cc b/native/annotator/datetime/parser_test.cc
new file mode 100644
index 0000000..76b033d
--- /dev/null
+++ b/native/annotator/datetime/parser_test.cc
@@ -0,0 +1,1536 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/datetime/parser.h"
+
+#include <time.h>
+
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+
+#include "annotator/annotator.h"
+#include "annotator/model_generated.h"
+#include "annotator/types-test-util.h"
+#include "utils/jvm-test-utils.h"
+#include "utils/test-data-test-utils.h"
+#include "utils/testing/annotator.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using std::vector;
+using testing::ElementsAreArray;
+
+namespace libtextclassifier3 {
+namespace {
+ // Builder class to construct the DatetimeComponents and make the test readable.
+ //
+ // Add() returns a reference to the builder itself, so a chain such as
+ // DatetimeComponentsBuilder().Add(...).Add(...).Build() appends to a single
+ // vector instead of copying the accumulated components at every step.
+ class DatetimeComponentsBuilder {
+  public:
+   // Appends a component with only its type and value set.
+   DatetimeComponentsBuilder& Add(DatetimeComponent::ComponentType type,
+                                  int value) {
+     DatetimeComponent component;
+     component.component_type = type;
+     component.value = value;
+     return AddComponent(component);
+   }
+
+   // Appends a component with type and value plus a relative qualifier and
+   // relative count.
+   DatetimeComponentsBuilder& Add(
+       DatetimeComponent::ComponentType type, int value,
+       DatetimeComponent::RelativeQualifier relative_qualifier,
+       int relative_count) {
+     DatetimeComponent component;
+     component.component_type = type;
+     component.value = value;
+     component.relative_qualifier = relative_qualifier;
+     component.relative_count = relative_count;
+     return AddComponent(component);
+   }
+
+   // Returns the accumulated components and leaves the builder empty.
+   std::vector<DatetimeComponent> Build() {
+     std::vector<DatetimeComponent> result;
+     result.swap(datetime_components_);
+     return result;
+   }
+
+  private:
+   DatetimeComponentsBuilder& AddComponent(
+       const DatetimeComponent& datetime_component) {
+     datetime_components_.push_back(datetime_component);
+     return *this;
+   }
+   std::vector<DatetimeComponent> datetime_components_;
+ };
+
+ // Returns the path of the annotator test-data directory.
+ std::string GetModelPath() {
+   constexpr char kAnnotatorTestDataDir[] = "annotator/test_data/";
+   return GetTestDataPath(kAnnotatorTestDataDir);
+ }
+
+ // Returns the full contents of `file_name`, or an empty string if the file
+ // cannot be opened.
+ std::string ReadFile(const std::string& file_name) {
+   std::ifstream input(file_name);
+   std::istreambuf_iterator<char> first(input);
+   std::istreambuf_iterator<char> last;
+   return std::string(first, last);
+ }
+
+ class ParserTest : public testing::Test {
+  public:
+   void SetUp() override {
+     // Loads default unmodified model. Individual tests can call LoadModel to
+     // make changes.
+     LoadModel([](ModelT* model) {});
+   }
+
+   // Reads test_model.fb, lets `model_visitor_fn` modify it through
+   // ModifyAnnotatorModel, and rebuilds the annotator and datetime parser used
+   // by the helpers below. CHECK-fails if either cannot be created.
+   template <typename Fn>
+   void LoadModel(Fn model_visitor_fn) {
+     // The current annotator (if any) was created over model_buffer_, unilib_
+     // and calendarlib_ without owning them; destroy it before those are
+     // replaced so it never outlives the data it points into.
+     parser_ = nullptr;
+     classifier_.reset();
+
+     std::string model_buffer = ReadFile(GetModelPath() + "test_model.fb");
+     model_buffer_ = ModifyAnnotatorModel(model_buffer, model_visitor_fn);
+     unilib_ = CreateUniLibForTesting();
+     calendarlib_ = CreateCalendarLibForTesting();
+     classifier_ =
+         Annotator::FromUnownedBuffer(model_buffer_.data(), model_buffer_.size(),
+                                      unilib_.get(), calendarlib_.get());
+     TC3_CHECK(classifier_);
+     parser_ = classifier_->DatetimeParserForTests();
+     TC3_CHECK(parser_);
+   }
+
+   // Returns true if parsing `text` yields no datetime results.
+   // CHECK-fails if Parse itself reports failure.
+   bool HasNoResult(const std::string& text, bool anchor_start_end = false,
+                    const std::string& timezone = "Europe/Zurich",
+                    AnnotationUsecase annotation_usecase =
+                        AnnotationUsecase_ANNOTATION_USECASE_SMART) {
+     std::vector<DatetimeParseResultSpan> results;
+     if (!parser_->Parse(text, 0, timezone, /*locales=*/"", ModeFlag_ANNOTATION,
+                         annotation_usecase, anchor_start_end, &results)) {
+       TC3_LOG(ERROR) << text;
+       TC3_CHECK(false);
+     }
+     return results.empty();
+   }
+
+   // Parses `marked_text` with its braces removed and checks the results that
+   // overlap the braced region. The span between the first '{' and the first
+   // '}' must be reported exactly once, with one datum per entry of
+   // `expected_ms_utcs`, all with `expected_granularity`, and with
+   // `datetime_components[i]` paired with `expected_ms_utcs[i]`.
+   // `datetime_components` must have at least as many entries as
+   // `expected_ms_utcs`. Logs expected vs. actual on mismatch.
+   bool ParsesCorrectly(
+       const std::string& marked_text, const vector<int64>& expected_ms_utcs,
+       DatetimeGranularity expected_granularity,
+       const vector<vector<DatetimeComponent>>& datetime_components,
+       bool anchor_start_end = false,
+       const std::string& timezone = "Europe/Zurich",
+       const std::string& locales = "en-US",
+       AnnotationUsecase annotation_usecase =
+           AnnotationUsecase_ANNOTATION_USECASE_SMART) {
+     const UnicodeText marked_text_unicode =
+         UTF8ToUnicodeText(marked_text, /*do_copy=*/false);
+     auto brace_open_it =
+         std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '{');
+     auto brace_end_it =
+         std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '}');
+     TC3_CHECK(brace_open_it != marked_text_unicode.end());
+     TC3_CHECK(brace_end_it != marked_text_unicode.end());
+
+     std::string text;
+     text +=
+         UnicodeText::UTF8Substring(marked_text_unicode.begin(), brace_open_it);
+     text += UnicodeText::UTF8Substring(std::next(brace_open_it), brace_end_it);
+     text += UnicodeText::UTF8Substring(std::next(brace_end_it),
+                                        marked_text_unicode.end());
+
+     std::vector<DatetimeParseResultSpan> results;
+
+     if (!parser_->Parse(text, 0, timezone, locales, ModeFlag_ANNOTATION,
+                         annotation_usecase, anchor_start_end, &results)) {
+       TC3_LOG(ERROR) << text;
+       TC3_CHECK(false);
+     }
+     if (results.empty()) {
+       TC3_LOG(ERROR) << "No results.";
+       return false;
+     }
+
+     const int expected_start_index =
+         std::distance(marked_text_unicode.begin(), brace_open_it);
+     // The -1 below is to account for the opening bracket character.
+     const int expected_end_index =
+         std::distance(marked_text_unicode.begin(), brace_end_it) - 1;
+
+     std::vector<DatetimeParseResultSpan> filtered_results;
+     for (const DatetimeParseResultSpan& result : results) {
+       if (SpansOverlap(result.span,
+                        {expected_start_index, expected_end_index})) {
+         filtered_results.push_back(result);
+       }
+     }
+     std::vector<DatetimeParseResultSpan> expected{
+         {{expected_start_index, expected_end_index},
+          {},
+          /*target_classification_score=*/1.0,
+          /*priority_score=*/1.0}};
+     // Each expected timestamp is paired with datetime_components[i]; guard
+     // against an out-of-bounds read when too few component lists are given.
+     TC3_CHECK(datetime_components.size() >= expected_ms_utcs.size());
+     expected[0].data.resize(expected_ms_utcs.size());
+     for (size_t i = 0; i < expected_ms_utcs.size(); ++i) {
+       expected[0].data[i] = {expected_ms_utcs[i], expected_granularity,
+                              datetime_components[i]};
+     }
+
+     const bool matches =
+         testing::Matches(ElementsAreArray(expected))(filtered_results);
+     if (!matches) {
+       TC3_LOG(ERROR) << "Expected: " << expected[0];
+       // Only print the first actual result when there is one; indexing an
+       // empty vector here would be undefined behavior.
+       if (filtered_results.empty()) {
+         TC3_LOG(ERROR) << "But got no results.";
+       } else {
+         TC3_LOG(ERROR) << "Actual: " << filtered_results[0];
+       }
+     }
+
+     return matches;
+   }
+
+   // Single-timestamp convenience overload of ParsesCorrectly.
+   bool ParsesCorrectly(
+       const std::string& marked_text, const int64 expected_ms_utc,
+       DatetimeGranularity expected_granularity,
+       const vector<vector<DatetimeComponent>>& datetime_components,
+       bool anchor_start_end = false,
+       const std::string& timezone = "Europe/Zurich",
+       const std::string& locales = "en-US",
+       AnnotationUsecase annotation_usecase =
+           AnnotationUsecase_ANNOTATION_USECASE_SMART) {
+     return ParsesCorrectly(marked_text, vector<int64>{expected_ms_utc},
+                            expected_granularity, datetime_components,
+                            anchor_start_end, timezone, locales,
+                            annotation_usecase);
+   }
+
+   // ParsesCorrectly with the "de" locale and the Europe/Zurich timezone.
+   bool ParsesCorrectlyGerman(
+       const std::string& marked_text, const vector<int64>& expected_ms_utcs,
+       DatetimeGranularity expected_granularity,
+       const vector<vector<DatetimeComponent>>& datetime_components) {
+     return ParsesCorrectly(marked_text, expected_ms_utcs, expected_granularity,
+                            datetime_components,
+                            /*anchor_start_end=*/false,
+                            /*timezone=*/"Europe/Zurich", /*locales=*/"de");
+   }
+
+   // Single-timestamp ParsesCorrectly with the "de" locale.
+   bool ParsesCorrectlyGerman(
+       const std::string& marked_text, const int64 expected_ms_utc,
+       DatetimeGranularity expected_granularity,
+       const vector<vector<DatetimeComponent>>& datetime_components) {
+     return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
+                            datetime_components,
+                            /*anchor_start_end=*/false,
+                            /*timezone=*/"Europe/Zurich", /*locales=*/"de");
+   }
+
+   // Single-timestamp ParsesCorrectly with the "zh" locale.
+   bool ParsesCorrectlyChinese(
+       const std::string& marked_text, const int64 expected_ms_utc,
+       DatetimeGranularity expected_granularity,
+       const vector<vector<DatetimeComponent>>& datetime_components) {
+     return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
+                            datetime_components,
+                            /*anchor_start_end=*/false,
+                            /*timezone=*/"Europe/Zurich", /*locales=*/"zh");
+   }
+
+  protected:
+   // Declaration order matters: members are destroyed in reverse order, so the
+   // annotator (which borrows the buffer and both libraries) goes first.
+   std::string model_buffer_;
+   std::unique_ptr<UniLib> unilib_;
+   std::unique_ptr<CalendarLib> calendarlib_;
+   std::unique_ptr<Annotator> classifier_;
+   // Owned by classifier_.
+   const DatetimeParser* parser_ = nullptr;
+ };
+
+ // Minimal smoke test: a single well-formed date, useful for debugging general
+ // parser failures before looking at the larger Parse test.
+ TEST_F(ParserTest, ParseShort) {
+   const std::vector<DatetimeComponent> january_first_1988 =
+       DatetimeComponentsBuilder()
+           .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+           .Add(DatetimeComponent::ComponentType::MONTH, 1)
+           .Add(DatetimeComponent::ComponentType::YEAR, 1988)
+           .Build();
+   // 567990000000 ms is 1988-01-01 00:00 in the default Europe/Zurich
+   // timezone (UTC+1).
+   EXPECT_TRUE(ParsesCorrectly("{January 1, 1988}", 567990000000,
+                               GRANULARITY_DAY, {january_first_1988}));
+ }
+
+TEST_F(ParserTest, Parse) {
+ EXPECT_TRUE(ParsesCorrectly(
+ "{January 1, 1988}", 567990000000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 1988)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 31 2018}", 1517353200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 31)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "lorem {1 january 2018} ipsum", 1514761200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{09/Mar/2004 22:02:40}", 1078866160000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::SECOND, 40)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 02)
+ .Add(DatetimeComponent::ComponentType::HOUR, 22)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 9)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2004)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{Dec 2, 2010 2:39:58 AM}", 1291253998000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 58)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 39)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 2)
+ .Add(DatetimeComponent::ComponentType::MONTH, 12)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{Jun 09 2011 15:28:14}", 1307626094000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::SECOND, 14)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 28)
+ .Add(DatetimeComponent::ComponentType::HOUR, 15)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 9)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2011)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{Mar 16 08:12:04}", {6419524000, 6462724000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 4)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 12)
+ .Add(DatetimeComponent::ComponentType::HOUR, 8)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 16)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 4)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 12)
+ .Add(DatetimeComponent::ComponentType::HOUR, 8)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 16)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{2010-06-26 02:31:29}", {1277512289000, 1277555489000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 29)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 31)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 26)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 29)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 31)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 26)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{2006/01/22 04:11:05}", {1137899465000, 1137942665000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 5)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 11)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 22)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2006)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 5)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 11)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 22)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2006)
+ .Build()}));
+ EXPECT_TRUE(
+ ParsesCorrectly("{11:42:35}", {38555000, 81755000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{23/Apr 11:42:35}", {9715355000, 9758555000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{23/Apr/2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{23-Apr-2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{23 Apr 2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{04/23/15 11:42:35}", {1429782155000, 1429825355000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{04/23/2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{9/28/2011 2:23:15 PM}", 1317212595000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 23)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 28)
+ .Add(DatetimeComponent::ComponentType::MONTH, 9)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2011)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "Are sentiments apartments decisively the especially alteration. "
+ "Thrown shy denote ten ladies though ask saw. Or by to he going "
+ "think order event music. Incommode so intention defective at "
+ "convinced. Led income months itself and houses you. After nor "
+ "you leave might share court balls. {19/apr/2010 06:36:15} Are "
+ "sentiments apartments decisively the especially alteration. "
+ "Thrown shy denote ten ladies though ask saw. Or by to he going "
+ "think order event music. Incommode so intention defective at "
+ "convinced. Led income months itself and houses you. After nor "
+ "you leave might share court balls. ",
+ {1271651775000, 1271694975000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 36)
+ .Add(DatetimeComponent::ComponentType::HOUR, 6)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 36)
+ .Add(DatetimeComponent::ComponentType::HOUR, 6)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 1 2018 at 4:30}", {1514777400000, 1514820600000},
+ GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 1 2018 at 4:30 am}", 1514777400000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 1 2018 at 4pm}", 1514818800000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "{today at 0:00}", {-3600000, 39600000}, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{today at 0:00}", {-57600000, -14400000}, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build()},
+ /*anchor_start_end=*/false, "America/Los_Angeles"));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{tomorrow at 4:00}", {97200000, 140400000}, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{tomorrow at 4am}", 97200000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{wednesday at 4am}", 529200000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 4,
+ DatetimeComponent::RelativeQualifier::THIS, 0)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "last seen {today at 9:01 PM}", 72060000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 1)
+ .Add(DatetimeComponent::ComponentType::HOUR, 9)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "set an alarm for {7am tomorrow}", 108000000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 7)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+ EXPECT_TRUE(
+ ParsesCorrectly("set an alarm for {7 a.m}", 21600000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 7)
+ .Build()}));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "lets meet by {noon}", 39600000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 12)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "at {midnight}", 82800000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+}
+
+TEST_F(ParserTest, ParseWithAnchor) {  // Anchored parsing rejects embedded dates.
+ EXPECT_TRUE(ParsesCorrectly(
+ "{January 1, 1988}", 567990000000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 1988)
+ .Build()},
+ /*anchor_start_end=*/false));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{January 1, 1988}", 567990000000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 1988)
+ .Build()},
+ /*anchor_start_end=*/true));  // Whole-input date: found either way.
+ EXPECT_TRUE(ParsesCorrectly(
+ "lorem {1 january 2018} ipsum", 1514761200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()},
+ /*anchor_start_end=*/false));  // Embedded date is found when unanchored...
+ EXPECT_TRUE(HasNoResult("lorem 1 january 2018 ipsum",
+ /*anchor_start_end=*/true));  // ...but not when anchored.
+}
+
+TEST_F(ParserTest, ParseWithRawUsecase) {
+ // The following inputs are annotated for the RAW usecase.
+ EXPECT_TRUE(ParsesCorrectly(
+ "{tomorrow}", 82800000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "call me {in two hours}", 7200000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::HOUR, 0,
+ DatetimeComponent::RelativeQualifier::FUTURE, 2)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "call me {next month}", 2674800000, GRANULARITY_MONTH,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NEXT, 1)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+ EXPECT_TRUE(ParsesCorrectly(
+ "what's the time {now}", -3600000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::NOW, 0)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "call me on {Saturday}", 169200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 7,
+ DatetimeComponent::RelativeQualifier::THIS, 0)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ // The same "tomorrow" input yields no result for the SMART usecase.
+ EXPECT_TRUE(HasNoResult(
+ "{tomorrow}", /*anchor_start_end=*/false,  // NOTE(review): braces kept in HasNoResult input; confirm intended.
+ /*timezone=*/"Europe/Zurich",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_SMART));
+}
+
+// Past relative expressions; see b/155437137. Noted times are CET (Zurich).
+TEST_F(ParserTest, PastRelativeDatetime) {
+ EXPECT_TRUE(ParsesCorrectly(
+ "called you {last Saturday}",
+ -432000000 /* Sat 1969-12-27 01:00:00 CET */, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 7,
+ DatetimeComponent::RelativeQualifier::PAST, -1)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "called you {last year}", -31536000000 /* Wed 1969-01-01 01:00:00 CET */,
+ GRANULARITY_YEAR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::YEAR, 0,
+ DatetimeComponent::RelativeQualifier::PAST, -1)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "called you {last month}", -2678400000 /* Mon 1969-12-01 01:00:00 CET */,
+ GRANULARITY_MONTH,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MONTH, 0,
+ DatetimeComponent::RelativeQualifier::PAST, -1)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "called you {yesterday}", -90000000 /* Wed 1969-12-31 00:00:00 CET */,
+ GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::YESTERDAY, -1)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+}
+
+TEST_F(ParserTest, AddsADayWhenTimeInThePastAndDayNotSpecified) {
+ // ParsesCorrectly uses 0 as the reference time: "Thu Jan 01 1970 01:00:00"
+ // Zurich time. "0:30" on that day is already in the past, and the test model
+ // prefers the future for unspecified dates, so the parser moves it to the
+ // next day: "Fri Jan 02 1970 00:30:00" Zurich time (b/139112907).
+ EXPECT_TRUE(ParsesCorrectly(
+ "{0:30am}", 84600000L /* 23.5 hours from reference time */,
+ GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Build()}));
+}
+
+TEST_F(ParserTest, DoesNotAddADayWhenTimeInThePastAndDayNotSpecifiedDisabled) {
+ // ParsesCorrectly uses 0 as the reference time, which corresponds to:
+ // "Thu Jan 01 1970 01:00:00" Zurich time. So if we pass "0:30" here, it means
+ // it is in the past. The model flag prefer_future_for_unspecified_date is
+ // disabled, so the parser should annotate this to the same day: "Thu Jan 01
+ // 1970 00:30:00" Zurich time.
+ LoadModel([](ModelT* model) {
+ // The test model sets prefer_future_for_unspecified_date to true; set it to
+ // false only for this test.
+ model->datetime_model->prefer_future_for_unspecified_date = false;
+ });
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "{0:30am}", -1800000L /* -30 minutes from reference time */,
+ GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Build()}));
+}
+
+TEST_F(ParserTest, ParsesNoonAndMidnightCorrectly) {  // "12:xx am" is 00:xx; "12:xx pm" is 12:xx.
+ EXPECT_TRUE(ParsesCorrectly(
+ "{January 1, 1988 12:30am}", 567991800000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 12)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 1988)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{January 1, 1988 12:30pm}", 568035000000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 12)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 1988)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{tomorrow at 12:00 am}", 82800000, GRANULARITY_MINUTE,  // Fri Jan 02 1970 00:00 CET
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 12)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+}
+
+TEST_F(ParserTest, ParseGerman) {  // "um"=at, "nachm"=p.m., "vorm"=a.m., "morgen"=tomorrow.
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{Januar 1 2018}", 1514761200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{1/2/2018}", 1517439600000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 2)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "lorem {1 Januar 2018} ipsum", 1514761200000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{19/Apr/2010:06:36:15}", {1271651775000, 1271694975000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 36)
+ .Add(DatetimeComponent::ComponentType::HOUR, 6)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 36)
+ .Add(DatetimeComponent::ComponentType::HOUR, 6)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{09/März/2004 22:02:40}", 1078866160000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::SECOND, 40)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 2)
+ .Add(DatetimeComponent::ComponentType::HOUR, 22)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 9)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2004)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{Dez 2, 2010 2:39:58}", {1291253998000, 1291297198000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 58)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 39)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 2)
+ .Add(DatetimeComponent::ComponentType::MONTH, 12)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 58)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 39)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 2)
+ .Add(DatetimeComponent::ComponentType::MONTH, 12)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{Juni 09 2011 15:28:14}", 1307626094000, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::SECOND, 14)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 28)
+ .Add(DatetimeComponent::ComponentType::HOUR, 15)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 9)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2011)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{März 16 08:12:04}", {6419524000, 6462724000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 4)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 12)
+ .Add(DatetimeComponent::ComponentType::HOUR, 8)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 16)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 4)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 12)
+ .Add(DatetimeComponent::ComponentType::HOUR, 8)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 16)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{2010-06-26 02:31:29}", {1277512289000, 1277555489000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 29)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 31)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 26)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 29)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 31)
+ .Add(DatetimeComponent::ComponentType::HOUR, 2)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 26)
+ .Add(DatetimeComponent::ComponentType::MONTH, 6)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{2006/01/22 04:11:05}", {1137899465000, 1137942665000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 5)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 11)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 22)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2006)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 5)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 11)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 22)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2006)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{11:42:35}", {38555000, 81755000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{23/Apr 11:42:35}", {9715355000, 9758555000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{23/Apr/2015:11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{23/Apr/2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{23-Apr-2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{23 Apr 2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{04/23/15 11:42:35}", {1429782155000, 1429825355000}, GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{04/23/2015 11:42:35}", {1429782155000, 1429825355000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 35)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 42)
+ .Add(DatetimeComponent::ComponentType::HOUR, 11)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 23)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{19/apr/2010:06:36:15}", {1271651775000, 1271694975000},
+ GRANULARITY_SECOND,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 36)
+ .Add(DatetimeComponent::ComponentType::HOUR, 6)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::SECOND, 15)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 36)
+ .Add(DatetimeComponent::ComponentType::HOUR, 6)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 19)
+ .Add(DatetimeComponent::ComponentType::MONTH, 4)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2010)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{januar 1 2018 um 4:30}", {1514777400000, 1514820600000},
+ GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{januar 1 2018 um 4:30 nachm}", 1514820600000, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{januar 1 2018 um 4 nachm}", 1514818800000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{14.03.2017}", 1489446000000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 14)
+ .Add(DatetimeComponent::ComponentType::MONTH, 3)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2017)
+ .Build()}));
+
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{morgen 0:00}", {82800000, 126000000}, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{morgen um 4:00}", {97200000, 140400000}, GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{morgen um 4 vorm}", 97200000, GRANULARITY_HOUR,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+}
+
+TEST_F(ParserTest, ParseChinese) {  // Relative day + hour + meridiem, Chinese.
+ EXPECT_TRUE(ParsesCorrectlyChinese(
+ "{明天 7 上午}", 108000000, GRANULARITY_HOUR,  // Single result, hour 7.
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 7)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 0,
+ DatetimeComponent::RelativeQualifier::TOMORROW, 1)
+ .Build()}));
+}
+
+TEST_F(ParserTest, ParseNonUs) {  // Day-first: "1/5/2015" is 1 May 2015.
+ auto first_may_2015 =
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 5)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build();
+
+ EXPECT_TRUE(ParsesCorrectly("{1/5/2015}", 1430431200000, GRANULARITY_DAY,
+ {first_may_2015},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich",
+ /*locales=*/"en-GB"));
+ EXPECT_TRUE(ParsesCorrectly("{1/5/2015}", 1430431200000, GRANULARITY_DAY,
+ {first_may_2015},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en"));  // Bare "en" too.
+}
+
+TEST_F(ParserTest, ParseUs) {  // Month-first: "1/5/2015" is 5 January 2015.
+ auto five_january_2015 =
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 5)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build();
+
+ EXPECT_TRUE(ParsesCorrectly("{1/5/2015}", 1420412400000, GRANULARITY_DAY,
+ {five_january_2015},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich",
+ /*locales=*/"en-US"));
+ EXPECT_TRUE(ParsesCorrectly("{1/5/2015}", 1420412400000, GRANULARITY_DAY,
+ {five_january_2015},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich",
+ /*locales=*/"es-US"));  // Spanish (US) is month-first as well.
+}
+
+TEST_F(ParserTest, ParseUnknownLanguage) {  // Locale "xx": only the numeric date.
+ EXPECT_TRUE(ParsesCorrectly(
+ "bylo to {31. 12. 2015} v 6 hodin", 1451516400000, GRANULARITY_DAY,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 31)
+ .Add(DatetimeComponent::ComponentType::MONTH, 12)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2015)
+ .Build()},
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"xx"));
+}
+
+TEST_F(ParserTest, WhenAlternativesEnabledGeneratesAlternatives) {  // AM+PM.
+ LoadModel([](ModelT* model) {
+ model->datetime_model->generate_alternative_interpretations_when_ambiguous =
+ true;
+ });
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 1 2018 at 4:30}", {1514777400000, 1514820600000},  // No am/pm: two readings.
+ GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{monday 3pm}", 396000000, GRANULARITY_HOUR,  // Explicit "pm": one reading.
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::HOUR, 3)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 2,
+ DatetimeComponent::RelativeQualifier::THIS, 0)
+ .Build()}));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{monday 3:00}", {352800000, 396000000}, GRANULARITY_MINUTE,  // No am/pm: two readings.
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 3)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 2,
+ DatetimeComponent::RelativeQualifier::THIS, 0)
+ .Build(),
+ DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 1)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 0)
+ .Add(DatetimeComponent::ComponentType::HOUR, 3)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_WEEK, 2,
+ DatetimeComponent::RelativeQualifier::THIS, 0)
+ .Build()}));
+}
+
+TEST_F(ParserTest, WhenAlternativesDisabledDoesNotGenerateAlternatives) {
+ LoadModel([](ModelT* model) {
+ model->datetime_model->generate_alternative_interpretations_when_ambiguous =
+ false;
+ });
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "{january 1 2018 at 4:30}", 1514777400000, GRANULARITY_MINUTE,  // One result.
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)  // No MERIDIEM expected.
+ .Add(DatetimeComponent::ComponentType::HOUR, 4)
+ .Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::MONTH, 1)
+ .Add(DatetimeComponent::ComponentType::YEAR, 2018)
+ .Build()}));
+}
+
+class ParserLocaleTest : public testing::Test {  // Locale-based pattern selection.
+ public:
+ void SetUp() override;
+ bool HasResult(const std::string& input, const std::string& locales);
+
+ protected:
+ std::unique_ptr<UniLib> unilib_;
+ std::unique_ptr<CalendarLib> calendarlib_;
+ flatbuffers::FlatBufferBuilder builder_;  // Holds the model parser_ is built from.
+ std::unique_ptr<DatetimeParser> parser_;
+};
+
+void AddPattern(const std::string& regex, int locale,  // Appends a one-regex pattern.
+ std::vector<std::unique_ptr<DatetimeModelPatternT>>* patterns) {
+ patterns->emplace_back(new DatetimeModelPatternT);
+ patterns->back()->regexes.emplace_back(new DatetimeModelPattern_::RegexT);
+ patterns->back()->regexes.back()->pattern = regex;
+ patterns->back()->regexes.back()->groups.push_back(
+ DatetimeGroupType_GROUP_UNUSED);  // The single group is marked unused.
+ patterns->back()->locales.push_back(locale);  // Index into model.locales.
+}
+
+void ParserLocaleTest::SetUp() {  // One pattern per locale entry; regex = entry name.
+ DatetimeModelT model;
+ model.use_extractors_for_locating = false;
+ model.locales.clear();
+ model.locales.push_back("en-US");  // Index 0.
+ model.locales.push_back("en-CH");
+ model.locales.push_back("zh-Hant");
+ model.locales.push_back("en-*");
+ model.locales.push_back("zh-Hant-*");
+ model.locales.push_back("*-CH");
+ model.locales.push_back("default");  // Index 6.
+ model.default_locales.push_back(6);  // "default" above.
+
+ AddPattern(/*regex=*/"en-US", /*locale=*/0, &model.patterns);
+ AddPattern(/*regex=*/"en-CH", /*locale=*/1, &model.patterns);
+ AddPattern(/*regex=*/"zh-Hant", /*locale=*/2, &model.patterns);
+ AddPattern(/*regex=*/"en-all", /*locale=*/3, &model.patterns);
+ AddPattern(/*regex=*/"zh-Hant-all", /*locale=*/4, &model.patterns);
+ AddPattern(/*regex=*/"all-CH", /*locale=*/5, &model.patterns);
+ AddPattern(/*regex=*/"default", /*locale=*/6, &model.patterns);
+
+ builder_.Finish(DatetimeModel::Pack(builder_, &model));
+ const DatetimeModel* model_fb =
+ flatbuffers::GetRoot<DatetimeModel>(builder_.GetBufferPointer());
+ ASSERT_TRUE(model_fb);
+
+ unilib_ = CreateUniLibForTesting();
+ calendarlib_ = CreateCalendarLibForTesting();
+ parser_ =
+ DatetimeParser::Instance(model_fb, unilib_.get(), calendarlib_.get(),
+ /*decompressor=*/nullptr);
+ ASSERT_TRUE(parser_);
+}
+
+bool ParserLocaleTest::HasResult(const std::string& input,  // True iff exactly one span.
+ const std::string& locales) {
+ std::vector<DatetimeParseResultSpan> results;
+ EXPECT_TRUE(parser_->Parse(
+ input, /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"", locales, ModeFlag_ANNOTATION,
+ AnnotationUsecase_ANNOTATION_USECASE_SMART, false, &results));
+ return results.size() == 1;
+}
+
+TEST_F(ParserLocaleTest, English) {  // en-US / en-CH patterns stay in their locale.
+ EXPECT_TRUE(HasResult("en-US", /*locales=*/"en-US"));
+ EXPECT_FALSE(HasResult("en-CH", /*locales=*/"en-US"));
+ EXPECT_FALSE(HasResult("en-US", /*locales=*/"en-CH"));
+ EXPECT_TRUE(HasResult("en-CH", /*locales=*/"en-CH"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"en-CH"));
+}
+
+TEST_F(ParserLocaleTest, TraditionalChinese) {  // "zh-Hant-*" needs the Hant script.
+ EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant"));
+ EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant-TW"));
+ EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant-SG"));
+ EXPECT_FALSE(HasResult("zh-Hant-all", /*locales=*/"zh-SG"));
+ EXPECT_FALSE(HasResult("zh-Hant-all", /*locales=*/"zh"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"zh"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"zh-Hant-SG"));
+}
+
+TEST_F(ParserLocaleTest, SwissEnglish) {  // "*-CH" and "en-*" both match en-CH.
+ EXPECT_TRUE(HasResult("all-CH", /*locales=*/"de-CH"));
+ EXPECT_TRUE(HasResult("all-CH", /*locales=*/"en-CH"));
+ EXPECT_TRUE(HasResult("en-all", /*locales=*/"en-CH"));
+ EXPECT_FALSE(HasResult("all-CH", /*locales=*/"de-DE"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"de-CH"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"en-CH"));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/annotator/experimental/experimental-dummy.h b/native/annotator/experimental/experimental-dummy.h
index 1c57c7e..28eec5f 100644
--- a/native/annotator/experimental/experimental-dummy.h
+++ b/native/annotator/experimental/experimental-dummy.h
@@ -49,7 +49,7 @@
bool ClassifyText(const UnicodeText& context, CodepointSpan selection_indices,
std::vector<AnnotatedSpan>& candidates) const {
- return true;
+ return false;
}
};
diff --git a/native/annotator/experimental/experimental.fbs b/native/annotator/experimental/experimental.fbs
index 5d69d17..6e15d04 100755
--- a/native/annotator/experimental/experimental.fbs
+++ b/native/annotator/experimental/experimental.fbs
@@ -14,8 +14,6 @@
// limitations under the License.
//
-include "utils/container/bit-vector.fbs";
-
namespace libtextclassifier3;
table ExperimentalModel {
}
diff --git a/native/annotator/grammar/grammar-annotator_test.cc b/native/annotator/grammar/grammar-annotator_test.cc
new file mode 100644
index 0000000..39ee950
--- /dev/null
+++ b/native/annotator/grammar/grammar-annotator_test.cc
@@ -0,0 +1,533 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/grammar/grammar-annotator.h"
+
+#include <memory>
+
+#include "annotator/grammar/test-utils.h"
+#include "annotator/grammar/utils.h"
+#include "annotator/model_generated.h"
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/mutable.h"
+#include "utils/grammar/utils/rules.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unicodetext.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::ElementsAre;
+
+flatbuffers::DetachedBuffer PackModel(const GrammarModelT& model) {  // Object API -> buffer.
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(GrammarModel::Pack(builder, &model));
+ return builder.Release();  // Caller owns the finished buffer.
+}
+
+TEST_F(GrammarAnnotatorTest, AnnotesWithGrammarRules) {  // "Annotes" (sic): name kept.
+ // Create test rules: a <flight> is a <carrier> followed by a <flight_code>.
+ GrammarModelT grammar_model;
+ SetTestTokenizerOptions(&grammar_model);
+ grammar_model.rules.reset(new grammar::RulesSetT);
+ grammar::Rules rules;
+ rules.Add("<carrier>", {"lx"});
+ rules.Add("<carrier>", {"aa"});
+ rules.Add("<flight_code>", {"<2_digits>"});
+ rules.Add("<flight_code>", {"<3_digits>"});
+ rules.Add("<flight_code>", {"<4_digits>"});
+ rules.Add(
+ "<flight>", {"<carrier>", "<flight_code>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ /*callback_param=*/
+ AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model));
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ grammar_model.rules.get());
+ flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
+ GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
+
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(annotator.Annotate(
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText(
+ "My flight: LX 38 arriving at 4pm, I'll fly back on AA2014",
+ /*do_copy=*/false),
+ &result));
+
+ EXPECT_THAT(result, ElementsAre(IsAnnotatedSpan(11, 16, "flight"),  // "LX 38"
+ IsAnnotatedSpan(51, 57, "flight")));  // "AA2014"
+}
+
+TEST_F(GrammarAnnotatorTest, HandlesAssertions) {  // Negative assertions veto matches.
+ // Create test rules.
+ GrammarModelT grammar_model;
+ SetTestTokenizerOptions(&grammar_model);
+ grammar_model.rules.reset(new grammar::RulesSetT);
+ grammar::Rules rules;
+ rules.Add("<carrier>", {"lx"});
+ rules.Add("<carrier>", {"aa"});
+ rules.Add("<flight_code>", {"<2_digits>"});
+ rules.Add("<flight_code>", {"<3_digits>"});
+ rules.Add("<flight_code>", {"<4_digits>"});
+
+ // Flight: carrier + flight code, with an optional check of the right context.
+ rules.Add(
+ "<flight>", {"<carrier>", "<flight_code>", "<context_assertion>?"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ /*callback_param=*/
+ AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model));
+
+ // Exclude matches followed by an optional token and digits, e.g. LX 38.00.
+ rules.AddAssertion("<context_assertion>", {".?", "<digits>"},
+ /*negative=*/true);
+
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ grammar_model.rules.get());
+ flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
+ GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
+
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(annotator.Annotate(
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText("My flight: LX 38 arriving at 4pm, I'll fly back on "
+ "AA2014 on LX 38.00",
+ /*do_copy=*/false),
+ &result));
+
+ EXPECT_THAT(result, ElementsAre(IsAnnotatedSpan(11, 16, "flight"),
+ IsAnnotatedSpan(51, 57, "flight")));  // No "LX 38.00".
+}
+
+TEST_F(GrammarAnnotatorTest, HandlesCapturingGroups) {  // Span = captured part only.
+ // Create test rules.
+ GrammarModelT grammar_model;
+ SetTestTokenizerOptions(&grammar_model);
+ grammar_model.rules.reset(new grammar::RulesSetT);
+ grammar::Rules rules;
+ rules.AddValueMapping("<low_confidence_phone>", {"<digits>"},
+ /*value=*/0);
+
+ // Create rule result with one capturing group that extends the selection.
+ const int classification_result_id =
+ AddRuleClassificationResult("phone", ModeFlag_ALL, 1.0, &grammar_model);
+ grammar_model.rule_classification_result[classification_result_id]
+ ->capturing_group.emplace_back(new CapturingGroupT);
+ grammar_model.rule_classification_result[classification_result_id]
+ ->capturing_group.back()
+ ->extend_selection = true;
+
+ rules.Add(
+ "<phone>", {"please", "call", "<low_confidence_phone>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ /*callback_param=*/classification_result_id);
+
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ grammar_model.rules.get());
+ flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
+ GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
+
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(annotator.Annotate(
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText("Please call 911 before 10 am!", /*do_copy=*/false),
+ &result));
+ EXPECT_THAT(result, ElementsAre(IsAnnotatedSpan(12, 15, "phone")));  // "911"
+}
+
+TEST_F(GrammarAnnotatorTest, ClassifiesTextWithGrammarRules) {  // ClassifyText path.
+ // Create test rules.
+ GrammarModelT grammar_model;
+ SetTestTokenizerOptions(&grammar_model);
+ grammar_model.rules.reset(new grammar::RulesSetT);
+ grammar::Rules rules;
+ rules.Add("<carrier>", {"lx"});
+ rules.Add("<carrier>", {"aa"});
+ rules.Add("<flight_code>", {"<2_digits>"});
+ rules.Add("<flight_code>", {"<3_digits>"});
+ rules.Add("<flight_code>", {"<4_digits>"});
+ rules.Add(
+ "<flight>", {"<carrier>", "<flight_code>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ /*callback_param=*/
+ AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model));
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ grammar_model.rules.get());
+ flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
+ GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
+
+ ClassificationResult result;
+ EXPECT_TRUE(annotator.ClassifyText(
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText(
+ "My flight: LX 38 arriving at 4pm, I'll fly back on AA2014",
+ /*do_copy=*/false),
+ CodepointSpan{11, 16}, &result));  // The "LX 38" span.
+ EXPECT_THAT(result, IsClassificationResult("flight"));
+}
+
+TEST_F(GrammarAnnotatorTest, ClassifiesTextWithAssertions) {  // Assertions in ClassifyText.
+ // Create test rules.
+ GrammarModelT grammar_model;
+ SetTestTokenizerOptions(&grammar_model);
+ grammar_model.rules.reset(new grammar::RulesSetT);
+
+ // Use unbounded context.
+ grammar_model.context_left_num_tokens = -1;
+ grammar_model.context_right_num_tokens = -1;
+
+ grammar::Rules rules;
+ rules.Add("<carrier>", {"lx"});
+ rules.Add("<carrier>", {"aa"});
+ rules.Add("<flight_code>", {"<2_digits>"});
+ rules.Add("<flight_code>", {"<3_digits>"});
+ rules.Add("<flight_code>", {"<4_digits>"});
+ rules.AddValueMapping("<flight_selection>", {"<carrier>", "<flight_code>"},
+ /*value=*/0);
+
+ // Flight: carrier + flight code, with an optional check of the right context.
+ const int classification_result_id =
+ AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model);
+ rules.Add(
+ "<flight>", {"<flight_selection>", "<context_assertion>?"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ /*callback_param=*/
+ classification_result_id);
+
+ grammar_model.rule_classification_result[classification_result_id]
+ ->capturing_group.emplace_back(new CapturingGroupT);
+ grammar_model.rule_classification_result[classification_result_id]
+ ->capturing_group.back()
+ ->extend_selection = true;
+
+ // Exclude matches followed by an optional token and digits, e.g. LX 38.00.
+ rules.AddAssertion("<context_assertion>", {".?", "<digits>"},
+ /*negative=*/true);
+
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ grammar_model.rules.get());
+ flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
+ GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
+
+ EXPECT_FALSE(annotator.ClassifyText(  // Vetoed: "." then digits.
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText("See LX 38.00", /*do_copy=*/false), CodepointSpan{4, 9},
+ nullptr));
+ EXPECT_FALSE(annotator.ClassifyText(  // Vetoed: digits follow directly.
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText("See LX 38 00", /*do_copy=*/false), CodepointSpan{4, 9},
+ nullptr));
+ ClassificationResult result;
+ EXPECT_TRUE(annotator.ClassifyText(  // ", seat" is not digits: accepted.
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText("See LX 38, seat 5", /*do_copy=*/false),
+ CodepointSpan{4, 9}, &result));
+ EXPECT_THAT(result, IsClassificationResult("flight"));
+}
+
+TEST_F(GrammarAnnotatorTest, ClassifiesTextWithContext) {  // Needs left-context trigger.
+ // Create test rules.
+ GrammarModelT grammar_model;
+ SetTestTokenizerOptions(&grammar_model);
+ grammar_model.rules.reset(new grammar::RulesSetT);
+
+ // Max three tokens to the left ("tracking number: ...").
+ grammar_model.context_left_num_tokens = 3;
+ grammar_model.context_right_num_tokens = 0;
+
+ grammar::Rules rules;
+ rules.Add("<tracking_number>", {"<5_digits>"});
+ rules.Add("<tracking_number>", {"<6_digits>"});
+ rules.Add("<tracking_number>", {"<7_digits>"});
+ rules.Add("<tracking_number>", {"<8_digits>"});
+ rules.Add("<tracking_number>", {"<9_digits>"});
+ rules.Add("<tracking_number>", {"<10_digits>"});
+ rules.AddValueMapping("<captured_tracking_number>", {"<tracking_number>"},
+ /*value=*/0);
+ rules.Add("<parcel_tracking_trigger>", {"tracking", "number?", ":?"});
+
+ const int classification_result_id = AddRuleClassificationResult(
+ "parcel_tracking", ModeFlag_ALL, 1.0, &grammar_model);
+ rules.Add(
+ "<parcel_tracking>",
+ {"<parcel_tracking_trigger>", "<captured_tracking_number>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ /*callback_param=*/
+ classification_result_id);
+
+ grammar_model.rule_classification_result[classification_result_id]
+ ->capturing_group.emplace_back(new CapturingGroupT);
+ grammar_model.rule_classification_result[classification_result_id]
+ ->capturing_group.back()
+ ->extend_selection = true;
+
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ grammar_model.rules.get());
+ flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
+ GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
+
+ ClassificationResult result;
+ EXPECT_TRUE(annotator.ClassifyText(
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText("Use tracking number 012345 for live parcel tracking.",
+ /*do_copy=*/false),
+ CodepointSpan{20, 26}, &result));  // "012345", trigger to its left.
+ EXPECT_THAT(result, IsClassificationResult("parcel_tracking"));
+
+ EXPECT_FALSE(annotator.ClassifyText(  // "tracking" only appears after the number.
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText("Call phone 012345 for live parcel tracking.",
+ /*do_copy=*/false),
+ CodepointSpan{11, 17}, &result));
+}
+
+TEST_F(GrammarAnnotatorTest, SuggestsTextSelection) {  // Expands to the whole match.
+ // Create test rules.
+ GrammarModelT grammar_model;
+ SetTestTokenizerOptions(&grammar_model);
+ grammar_model.rules.reset(new grammar::RulesSetT);
+ grammar::Rules rules;
+ rules.Add("<carrier>", {"lx"});
+ rules.Add("<carrier>", {"aa"});
+ rules.Add("<flight_code>", {"<2_digits>"});
+ rules.Add("<flight_code>", {"<3_digits>"});
+ rules.Add("<flight_code>", {"<4_digits>"});
+ rules.Add(
+ "<flight>", {"<carrier>", "<flight_code>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ /*callback_param=*/
+ AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model));
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ grammar_model.rules.get());
+ flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
+ GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
+
+ AnnotatedSpan selection;
+ EXPECT_TRUE(annotator.SuggestSelection(
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText(
+ "My flight: LX 38 arriving at 4pm, I'll fly back on AA2014",
+ /*do_copy=*/false),
+ /*selection=*/CodepointSpan{14, 15}, &selection));  // The "3" in "38".
+ EXPECT_THAT(selection, IsAnnotatedSpan(11, 16, "flight"));  // "LX 38"
+}
+
+TEST_F(GrammarAnnotatorTest, SetsFixedEntityData) {  // Rule-level fixed entity data.
+ // Create test rules.
+ GrammarModelT grammar_model;
+ SetTestTokenizerOptions(&grammar_model);
+ grammar_model.rules.reset(new grammar::RulesSetT);
+ grammar::Rules rules;
+ const int person_result =
+ AddRuleClassificationResult("person", ModeFlag_ALL, 1.0, &grammar_model);
+ rules.Add(
+ "<person>", {"barack", "obama"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ /*callback_param=*/person_result);
+
+ // Add test entity data.
+ std::unique_ptr<MutableFlatbuffer> entity_data =
+ entity_data_builder_->NewRoot();
+ entity_data->Set("person", "Former President Barack Obama");
+ grammar_model.rule_classification_result[person_result]
+ ->serialized_entity_data = entity_data->Serialize();
+
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ grammar_model.rules.get());
+ flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
+ GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
+
+ std::vector<AnnotatedSpan> result;
+ ASSERT_TRUE(annotator.Annotate(
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText("I saw Barack Obama today", /*do_copy=*/false),
+ &result));
+ ASSERT_THAT(result, ElementsAre(IsAnnotatedSpan(6, 18, "person")));
+
+ // Check entity data; the ASSERTs above guarantee result.front() exists.
+ // As we don't have generated code for the ad-hoc generated entity data
+ // schema, we have to check manually using field offsets.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ result.front().classification.front().serialized_entity_data.data()));
+ EXPECT_THAT(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
+ "Former President Barack Obama");
+}
+
+TEST_F(GrammarAnnotatorTest, SetsEntityDataFromCapturingMatches) {  // Captured text -> field.
+ // Create test rules.
+ GrammarModelT grammar_model;
+ SetTestTokenizerOptions(&grammar_model);
+ grammar_model.rules.reset(new grammar::RulesSetT);
+ grammar::Rules rules;
+ const int person_result =
+ AddRuleClassificationResult("person", ModeFlag_ALL, 1.0, &grammar_model);
+
+ rules.Add("<person>", {"barack?", "obama"});
+ rules.Add("<person>", {"zapp?", "brannigan"});
+ rules.AddValueMapping("<captured_person>", {"<person>"},
+ /*value=*/0);
+ rules.Add(
+ "<test>", {"<captured_person>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ /*callback_param=*/person_result);
+
+ // Set capturing group entity data information (field "person", uppercased).
+ grammar_model.rule_classification_result[person_result]
+ ->capturing_group.emplace_back(new CapturingGroupT);
+ CapturingGroupT* group =
+ grammar_model.rule_classification_result[person_result]
+ ->capturing_group.back()
+ .get();
+ group->entity_field_path.reset(new FlatbufferFieldPathT);
+ group->entity_field_path->field.emplace_back(new FlatbufferFieldT);
+ group->entity_field_path->field.back()->field_name = "person";
+ group->normalization_options.reset(new NormalizationOptionsT);
+ group->normalization_options->codepointwise_normalization =
+ NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
+
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ grammar_model.rules.get());
+ flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
+ GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
+
+ std::vector<AnnotatedSpan> result;
+ ASSERT_TRUE(annotator.Annotate(
+ {Locale::FromBCP47("en")},
+ UTF8ToUnicodeText("I saw Zapp Brannigan today", /*do_copy=*/false),
+ &result));
+ ASSERT_THAT(result, ElementsAre(IsAnnotatedSpan(6, 20, "person")));
+
+ // Check entity data; the ASSERTs above guarantee result.front() exists.
+ // As we don't have generated code for the ad-hoc generated entity data
+ // schema, we have to check manually using field offsets.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ result.front().classification.front().serialized_entity_data.data()));
+ EXPECT_THAT(
+ entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
+ "ZAPP BRANNIGAN");
+}
+
+TEST_F(GrammarAnnotatorTest, RespectsRuleModes) {  // Rules fire only in their ModeFlag.
+ // Create test rules.
+ GrammarModelT grammar_model;
+ SetTestTokenizerOptions(&grammar_model);
+ grammar_model.rules.reset(new grammar::RulesSetT);
+ grammar::Rules rules;
+ rules.Add("<classification_carrier>", {"ei"});
+ rules.Add("<classification_carrier>", {"en"});
+ rules.Add("<selection_carrier>", {"ai"});
+ rules.Add("<selection_carrier>", {"bx"});
+ rules.Add("<annotation_carrier>", {"aa"});
+ rules.Add("<annotation_carrier>", {"lx"});
+ rules.Add("<flight_code>", {"<2_digits>"});
+ rules.Add("<flight_code>", {"<3_digits>"});
+ rules.Add("<flight_code>", {"<4_digits>"});
+ rules.Add(
+ "<flight>", {"<annotation_carrier>", "<flight_code>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ /*callback_param=*/
+ AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &grammar_model));
+ rules.Add(
+ "<flight>", {"<selection_carrier>", "<flight_code>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ /*callback_param=*/
+ AddRuleClassificationResult("flight",
+ ModeFlag_CLASSIFICATION_AND_SELECTION, 1.0,
+ &grammar_model));
+ rules.Add(
+ "<flight>", {"<classification_carrier>", "<flight_code>"},
+ /*callback=*/
+ static_cast<grammar::CallbackId>(GrammarAnnotator::Callback::kRuleMatch),
+ /*callback_param=*/
+ AddRuleClassificationResult("flight", ModeFlag_CLASSIFICATION, 1.0,
+ &grammar_model));
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ grammar_model.rules.get());
+ flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
+ GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));
+
+ const UnicodeText text = UTF8ToUnicodeText(
+ "My flight: LX 38 arriving at 4pm, I'll fly back on EI2014 but maybe "
+ "also on bx 222");
+ const std::vector<Locale> locales = {Locale::FromBCP47("en")};
+
+ // Annotation: only the ModeFlag_ALL (annotation carrier) rule fires.
+ {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(annotator.Annotate(locales, text, &result));
+ EXPECT_THAT(result, ElementsAre(IsAnnotatedSpan(11, 16, "flight")));
+ }
+
+ // Selection: the ModeFlag_ALL and CLASSIFICATION_AND_SELECTION rules fire.
+ {
+ AnnotatedSpan selection;
+
+ // Selects 'LX 38'.
+ EXPECT_TRUE(annotator.SuggestSelection(locales, text,
+ /*selection=*/CodepointSpan{14, 15},
+ &selection));
+ EXPECT_THAT(selection, IsAnnotatedSpan(11, 16, "flight"));
+
+ // Selects 'bx 222'.
+ EXPECT_TRUE(annotator.SuggestSelection(locales, text,
+ /*selection=*/CodepointSpan{76, 77},
+ &selection));
+ EXPECT_THAT(selection, IsAnnotatedSpan(76, 82, "flight"));
+
+ // Doesn't select 'EI2014' (its carrier rule is classification-only).
+ EXPECT_FALSE(annotator.SuggestSelection(locales, text,
+ /*selection=*/CodepointSpan{51, 51},
+ &selection));
+ }
+
+ // Classification: all three rules fire.
+ {
+ ClassificationResult result;
+
+ // Classifies 'LX 38'.
+ EXPECT_TRUE(
+ annotator.ClassifyText(locales, text, CodepointSpan{11, 16}, &result));
+ EXPECT_THAT(result, IsClassificationResult("flight"));
+
+ // Classifies 'EI2014'.
+ EXPECT_TRUE(
+ annotator.ClassifyText(locales, text, CodepointSpan{51, 57}, &result));
+ EXPECT_THAT(result, IsClassificationResult("flight"));
+
+ // Classifies 'bx 222'.
+ EXPECT_TRUE(
+ annotator.ClassifyText(locales, text, CodepointSpan{76, 82}, &result));
+ EXPECT_THAT(result, IsClassificationResult("flight"));
+ }
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/test-utils.cc b/native/annotator/grammar/test-utils.cc
new file mode 100644
index 0000000..45bc301
--- /dev/null
+++ b/native/annotator/grammar/test-utils.cc
@@ -0,0 +1,38 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/grammar/test-utils.h"
+
+#include "utils/tokenizer.h"
+
+namespace libtextclassifier3 {
+
+GrammarAnnotator GrammarAnnotatorTest::CreateGrammarAnnotator(  // Uses fixture unilib.
+ const ::flatbuffers::DetachedBuffer& serialized_model) {
+ return GrammarAnnotator(
+ unilib_.get(),
+ flatbuffers::GetRoot<GrammarModel>(serialized_model.data()),  // NOTE(review): points into serialized_model; presumably it must outlive the annotator — confirm.
+ entity_data_builder_.get());
+}
+
+void SetTestTokenizerOptions(GrammarModelT* model) {  // ICU tokenizer for tests.
+ model->tokenizer_options.reset(new GrammarTokenizerOptionsT);
+ model->tokenizer_options->tokenization_type = TokenizationType_ICU;
+ model->tokenizer_options->icu_preserve_whitespace_tokens = false;  // Drop whitespace tokens.
+ model->tokenizer_options->tokenize_on_script_change = true;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/test-utils.h b/native/annotator/grammar/test-utils.h
new file mode 100644
index 0000000..e6b2071
--- /dev/null
+++ b/native/annotator/grammar/test-utils.h
@@ -0,0 +1,96 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_TEST_UTILS_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_TEST_UTILS_H_
+
+#include <memory>
+#include <string>
+
+#include "actions/test-utils.h"
+#include "annotator/grammar/grammar-annotator.h"
+#include "utils/flatbuffers/mutable.h"
+#include "utils/jvm-test-utils.h"
+#include "utils/utf8/unilib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+
+// TODO(sofian): Move these matchers to a higher-level library, usable by more
+// tests in text_classifier.
+
+// Matches an AnnotatedSpan whose span equals CodepointSpan(start, end) and
+// whose first classification result has the collection `collection`. A span
+// without any classification result never matches (instead of reading past
+// the end of the empty classification vector).
+MATCHER_P3(IsAnnotatedSpan, start, end, collection,
+           "is annotated span with begin that " +
+               ::testing::DescribeMatcher<int>(start, negation) +
+               ", end that " + ::testing::DescribeMatcher<int>(end, negation) +
+               ", collection that " +
+               ::testing::DescribeMatcher<std::string>(collection, negation)) {
+  if (arg.classification.empty()) {
+    *result_listener << "which has no classification results";
+    return false;
+  }
+  return ::testing::ExplainMatchResult(CodepointSpan(start, end), arg.span,
+                                       result_listener) &&
+         ::testing::ExplainMatchResult(::testing::StrEq(collection),
+                                       arg.classification.front().collection,
+                                       result_listener);
+}
+
+// Matches a ClassificationResult whose collection equals `collection`.
+MATCHER_P(IsClassificationResult, collection,
+          "is classification result with collection that " +
+              ::testing::DescribeMatcher<std::string>(collection, negation)) {
+  return ::testing::ExplainMatchResult(::testing::StrEq(collection),
+                                       arg.collection, result_listener);
+}
+
+// Test fixture that owns the UniLib and the entity-data builder (for the
+// schema returned by TestEntityDataSchema()) needed to build GrammarAnnotator
+// instances from serialized grammar models.
+class GrammarAnnotatorTest : public ::testing::Test {
+ protected:
+  GrammarAnnotatorTest()
+      : unilib_(CreateUniLibForTesting()),
+        serialized_entity_data_schema_(TestEntityDataSchema()),
+        entity_data_builder_(new MutableFlatbufferBuilder(
+            flatbuffers::GetRoot<reflection::Schema>(
+                serialized_entity_data_schema_.data()))) {}
+
+  // Creates a GrammarAnnotator over `serialized_model`. The model is read in
+  // place from the buffer (flatbuffers::GetRoot does not copy), so the buffer
+  // presumably must outlive the returned annotator.
+  GrammarAnnotator CreateGrammarAnnotator(
+      const ::flatbuffers::DetachedBuffer& serialized_model);
+
+  std::unique_ptr<UniLib> unilib_;
+  // Must be declared before entity_data_builder_: members are initialized in
+  // declaration order, and the builder is constructed from this buffer.
+  const std::string serialized_entity_data_schema_;
+  std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
+};
+
+// Configures `model` to use ICU tokenization that does not preserve
+// whitespace tokens and splits tokens on script changes.
+void SetTestTokenizerOptions(GrammarModelT* model);
+
+}  // namespace libtextclassifier3
+
+#endif  // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_TEST_UTILS_H_
diff --git a/native/annotator/model.fbs b/native/annotator/model.fbs
index 6769f46..b279cc5 100755
--- a/native/annotator/model.fbs
+++ b/native/annotator/model.fbs
@@ -14,17 +14,18 @@
// limitations under the License.
//
-include "annotator/entity-data.fbs";
-include "annotator/grammar/dates/dates.fbs";
-include "utils/grammar/rules.fbs";
include "utils/intents/intent-config.fbs";
-include "utils/flatbuffers/flatbuffers.fbs";
-include "annotator/experimental/experimental.fbs";
+include "annotator/grammar/dates/dates.fbs";
include "utils/normalization.fbs";
include "utils/tokenizer.fbs";
-include "utils/zlib/buffer.fbs";
+include "utils/grammar/rules.fbs";
include "utils/resources.fbs";
+include "utils/zlib/buffer.fbs";
+include "utils/container/bit-vector.fbs";
+include "annotator/entity-data.fbs";
+include "annotator/experimental/experimental.fbs";
include "utils/codepoint-range.fbs";
+include "utils/flatbuffers/flatbuffers.fbs";
file_identifier "TC2 ";
@@ -674,6 +675,7 @@
conflict_resolution_options:Model_.ConflictResolutionOptions;
experimental_model:ExperimentalModel;
pod_ner_model:PodNerModel;
+ vocab_model:VocabModel;
}
// Method for selecting the center token.
@@ -1059,6 +1061,39 @@
// The possible labels the ner model can output. If empty the default labels
// will be used.
labels:[PodNerModel_.Label];
+
+ // If the ratio of unknown wordpieces in the input text is greater than this
+ // maximum, the text won't be annotated.
+ max_ratio_unknown_wordpieces:float = 0.1;
+}
+
+namespace libtextclassifier3;
+table VocabModel {
+ // A trie that stores the vocabs that trigger "Define". Looking up a vocab in
+ // the trie returns an id, which can be used to access more information
+ // about that vocab (e.g. in the bit vectors below).
+ vocab_trie:[ubyte];
+
+ // A bit vector that tells whether the vocab should trigger "Define" for
+ // users of beginner proficiency only. To look up the bit vector, use the id
+ // returned by the trie.
+ beginner_level:BitVectorData;
+
+ // A bit vector that tells whether the vocab should not trigger "Define" if
+ // its leading character is in upper case. To look up the bit vector, use
+ // the id returned by the trie.
+ do_not_trigger_in_upper_case:BitVectorData;
+
+ // Comma-separated list of locales (BCP 47 tags) that the model supports,
+ // used to prevent triggering on input in unsupported languages. If empty,
+ // the model will trigger on all inputs.
+ triggering_locales:string (shared);
+
+ // The final score to assign to the results of the vocab model.
+ target_classification_score:float = 1;
+
+ // Priority score used for conflict resolution with the other models.
+ priority_score:float = 0;
}
root_type libtextclassifier3.Model;
diff --git a/native/annotator/number/number_test-include.h b/native/annotator/number/number_test-include.h
index 1cfc74c..9de7c86 100644
--- a/native/annotator/number/number_test-include.h
+++ b/native/annotator/number/number_test-include.h
@@ -18,46 +18,21 @@
#define LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_TEST_INCLUDE_H_
#include "annotator/number/number.h"
+#include "utils/jvm-test-utils.h"
#include "gtest/gtest.h"
-// Include the version of UniLib depending on the macro.
-#if defined TC3_UNILIB_ICU
-#include "utils/utf8/unilib-icu.h"
-
-#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR()
-
-#elif defined TC3_UNILIB_JAVAICU
-#include <jni.h>
-
-#include "utils/utf8/unilib-javaicu.h"
-
-extern JNIEnv* g_jenv;
-
-#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR(JniCache::Create(g_jenv))
-
-#elif defined TC3_UNILIB_APPLE
-#include "utils/utf8/unilib-apple.h"
-
-#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR()
-
-#else
-
-#error Unsupported configuration.
-
-#endif
-
namespace libtextclassifier3 {
namespace test_internal {
class NumberAnnotatorTest : public ::testing::Test {
protected:
NumberAnnotatorTest()
- : TC3_TESTING_CREATE_UNILIB_INSTANCE(unilib_),
- number_annotator_(TestingNumberAnnotatorOptions(), &unilib_) {}
+ : unilib_(CreateUniLibForTesting()),
+ number_annotator_(TestingNumberAnnotatorOptions(), unilib_.get()) {}
const NumberAnnotatorOptions* TestingNumberAnnotatorOptions();
- UniLib unilib_;
+ std::unique_ptr<UniLib> unilib_;  // Declared before number_annotator_, whose constructor receives unilib_.get().
NumberAnnotator number_annotator_;
};
diff --git a/native/annotator/pod_ner/pod-ner-dummy.h b/native/annotator/pod_ner/pod-ner-dummy.h
index 1246ade..2f6cd41 100644
--- a/native/annotator/pod_ner/pod-ner-dummy.h
+++ b/native/annotator/pod_ner/pod-ner-dummy.h
@@ -47,7 +47,7 @@
bool ClassifyText(const UnicodeText &context, CodepointSpan click,
ClassificationResult *result) const {
- return true;
+ return false;  // The dummy never fills `result`, so report that nothing was classified.
}
};
diff --git a/native/annotator/test-utils.h b/native/annotator/test-utils.h
new file mode 100644
index 0000000..a86302c
--- /dev/null
+++ b/native/annotator/test-utils.h
@@ -0,0 +1,114 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TEST_UTILS_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_TEST_UTILS_H_
+
+#include <string>
+
+#include "annotator/types.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+
+using ::testing::Value;
+
+// Matches an AnnotatedSpan whose span equals CodepointSpan(start, end) and
+// whose first classification collection equals `best_class` (a span with no
+// classification results is compared as "<INVALID RESULTS>").
+MATCHER_P3(IsAnnotatedSpan, start, end, best_class, "") {
+  const std::string first_result = arg.classification.empty()
+                                       ? "<INVALID RESULTS>"
+                                       : arg.classification[0].collection;
+  if (!Value(arg.span, CodepointSpan(start, end))) {
+    *result_listener << "whose span is (" << arg.span.first << ", "
+                     << arg.span.second << ")";
+    return false;
+  }
+  if (!Value(first_result, best_class)) {
+    *result_listener << "whose first collection is " << first_result;
+    return false;
+  }
+  return true;
+}
+
+// Matches a ClassificationResult with collection "date" and the given parsed
+// time and granularity.
+MATCHER_P2(IsDateResult, time_ms_utc, granularity, "") {
+  return Value(arg.collection, "date") &&
+         Value(arg.datetime_parse_result.time_ms_utc, time_ms_utc) &&
+         Value(arg.datetime_parse_result.granularity, granularity);
+}
+
+// Matches a ClassificationResult with collection "datetime" and the given
+// parsed time and granularity.
+MATCHER_P2(IsDatetimeResult, time_ms_utc, granularity, "") {
+  return Value(arg.collection, "datetime") &&
+         Value(arg.datetime_parse_result.time_ms_utc, time_ms_utc) &&
+         Value(arg.datetime_parse_result.granularity, granularity);
+}
+
+// Matches an AnnotatedSpan accepted by IsAnnotatedSpan(start, end,
+// "duration") whose first classification has the given duration_ms.
+MATCHER_P3(IsDurationSpan, start, end, duration_ms, "") {
+  if (arg.classification.empty()) {
+    *result_listener << "which has no classification results";
+    return false;
+  }
+  if (!ExplainMatchResult(IsAnnotatedSpan(start, end, "duration"), arg,
+                          result_listener)) {
+    return false;
+  }
+  const auto& classification = arg.classification[0];
+  if (classification.duration_ms != duration_ms) {
+    *result_listener << "whose duration_ms is " << classification.duration_ms;
+    return false;
+  }
+  return true;
+}
+
+// Matches an AnnotatedSpan accepted by IsAnnotatedSpan(start, end,
+// "datetime") whose first classification has the given parsed time and
+// granularity.
+MATCHER_P4(IsDatetimeSpan, start, end, time_ms_utc, granularity, "") {
+  if (arg.classification.empty()) {
+    *result_listener << "which has no classification results";
+    return false;
+  }
+  if (!ExplainMatchResult(IsAnnotatedSpan(start, end, "datetime"), arg,
+                          result_listener)) {
+    return false;
+  }
+  const auto& parse_result = arg.classification[0].datetime_parse_result;
+  if (parse_result.time_ms_utc != time_ms_utc) {
+    *result_listener << "whose time_ms_utc is " << parse_result.time_ms_utc;
+    return false;
+  }
+  if (parse_result.granularity != granularity) {
+    *result_listener << "whose granularity is "
+                     << static_cast<int>(parse_result.granularity);
+    return false;
+  }
+  return true;
+}
+
+// Matches a value strictly between `low` and `high` (both bounds exclusive).
+MATCHER_P2(IsBetween, low, high, "") { return low < arg && arg < high; }
+
+}  // namespace libtextclassifier3
+
+#endif  // LIBTEXTCLASSIFIER_ANNOTATOR_TEST_UTILS_H_
diff --git a/native/annotator/test_data/test_grammar_model.fb b/native/annotator/test_data/test_grammar_model.fb
new file mode 100644
index 0000000..30f133e
--- /dev/null
+++ b/native/annotator/test_data/test_grammar_model.fb
Binary files differ
diff --git a/native/annotator/test_data/test_model.fb b/native/annotator/test_data/test_model.fb
new file mode 100644
index 0000000..55f55c9
--- /dev/null
+++ b/native/annotator/test_data/test_model.fb
Binary files differ
diff --git a/native/annotator/test_data/wrong_embeddings.fb b/native/annotator/test_data/wrong_embeddings.fb
new file mode 100644
index 0000000..abe3fb0
--- /dev/null
+++ b/native/annotator/test_data/wrong_embeddings.fb
Binary files differ
diff --git a/native/annotator/types.h b/native/annotator/types.h
index df6f676..f7e1143 100644
--- a/native/annotator/types.h
+++ b/native/annotator/types.h
@@ -60,6 +60,7 @@
struct CodepointSpan {
static const CodepointSpan kInvalid;
+ CodepointSpan() : first(kInvalidIndex), second(kInvalidIndex) {}
CodepointSpan(CodepointIndex start, CodepointIndex end)
: first(start), second(end) {}
@@ -560,10 +561,14 @@
// Comma-separated list of language tags which the user can read and
// understand (BCP 47).
std::string user_familiar_language_tags;
+ // If true, trigger dictionary on words that are of beginner level.
+ bool trigger_dictionary_on_beginner_words = false;
bool operator==(const ClassificationOptions& other) const {
return this->user_familiar_language_tags ==
other.user_familiar_language_tags &&
+ this->trigger_dictionary_on_beginner_words ==
+ other.trigger_dictionary_on_beginner_words &&
BaseOptions::operator==(other) && DatetimeOptions::operator==(other);
}
};
@@ -594,12 +599,17 @@
AnnotateMode annotate_mode = AnnotateMode::kEntityAnnotation;
+ // If true, trigger dictionary on words that are of beginner level.
+ bool trigger_dictionary_on_beginner_words = false;
+
bool operator==(const AnnotationOptions& other) const {
return this->is_serialized_entity_data_enabled ==
other.is_serialized_entity_data_enabled &&
this->permissions == other.permissions &&
this->entity_types == other.entity_types &&
this->annotate_mode == other.annotate_mode &&
+ this->trigger_dictionary_on_beginner_words ==
+ other.trigger_dictionary_on_beginner_words &&
BaseOptions::operator==(other) && DatetimeOptions::operator==(other);
}
};
diff --git a/native/actions/test_data/en.fb b/native/annotator/vocab/test_data/test.model
old mode 100755
new mode 100644
similarity index 70%
rename from native/actions/test_data/en.fb
rename to native/annotator/vocab/test_data/test.model
index fbb5a6c..06b189d
--- a/native/actions/test_data/en.fb
+++ b/native/annotator/vocab/test_data/test.model
Binary files differ
diff --git a/native/annotator/vocab/vocab-annotator-dummy.h b/native/annotator/vocab/vocab-annotator-dummy.h
new file mode 100644
index 0000000..eda8a9c
--- /dev/null
+++ b/native/annotator/vocab/vocab-annotator-dummy.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_VOCAB_VOCAB_ANNOTATOR_DUMMY_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_VOCAB_VOCAB_ANNOTATOR_DUMMY_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "annotator/feature-processor.h"
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/i18n/locale.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+// No-op stand-in for the vocab annotator, included by vocab-annotator.h when
+// TC3_VOCAB_ANNOTATOR_DUMMY is defined. Create() always returns nullptr,
+// Annotate() adds nothing and ClassifyText() always returns false.
+class VocabAnnotator {
+ public:
+  // Always returns nullptr; all arguments are ignored.
+  static std::unique_ptr<VocabAnnotator> Create(
+      const VocabModel *model, const FeatureProcessor &feature_processor,
+      const UniLib &unilib) {
+    return nullptr;
+  }
+
+  // Adds nothing to `results` and returns true.
+  bool Annotate(const UnicodeText &context,
+                const std::vector<Locale> &detected_text_language_tags,
+                bool trigger_on_beginner_words,
+                std::vector<AnnotatedSpan> *results) const {
+    return true;
+  }
+
+  // Leaves `result` untouched and returns false (nothing classified).
+  bool ClassifyText(const UnicodeText &context, CodepointSpan click,
+                    const std::vector<Locale> &detected_text_language_tags,
+                    bool trigger_on_beginner_words,
+                    ClassificationResult *result) const {
+    return false;
+  }
+};
+
+}  // namespace libtextclassifier3
+
+#endif  // LIBTEXTCLASSIFIER_ANNOTATOR_VOCAB_VOCAB_ANNOTATOR_DUMMY_H_
diff --git a/native/annotator/vocab/vocab-annotator.h b/native/annotator/vocab/vocab-annotator.h
new file mode 100644
index 0000000..35fa928
--- /dev/null
+++ b/native/annotator/vocab/vocab-annotator.h
@@ -0,0 +1,32 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_VOCAB_VOCAB_ANNOTATOR_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_VOCAB_VOCAB_ANNOTATOR_H_
+
+// Picks the vocab annotator implementation at build time:
+// TC3_VOCAB_ANNOTATOR_IMPL selects the real implementation, otherwise
+// TC3_VOCAB_ANNOTATOR_DUMMY selects the no-op stand-in; defining neither is
+// a compile error.
+#if defined(TC3_VOCAB_ANNOTATOR_IMPL)
+#include "annotator/vocab/vocab-annotator-impl.h"
+#elif defined(TC3_VOCAB_ANNOTATOR_DUMMY)
+#include "annotator/vocab/vocab-annotator-dummy.h"
+#else
+#error No vocab-annotator implementation specified.
+#endif  // defined(TC3_VOCAB_ANNOTATOR_IMPL)
+
+#endif  // LIBTEXTCLASSIFIER_ANNOTATOR_VOCAB_VOCAB_ANNOTATOR_H_
diff --git a/native/lang_id/common/flatbuffers/model-utils.cc b/native/lang_id/common/flatbuffers/model-utils.cc
index 66f7f38..8efa386 100644
--- a/native/lang_id/common/flatbuffers/model-utils.cc
+++ b/native/lang_id/common/flatbuffers/model-utils.cc
@@ -26,14 +26,6 @@
namespace libtextclassifier3 {
namespace saft_fbs {
-namespace {
-
-// Returns true if we have clear evidence that |model| fails its checksum.
-//
-// E.g., if |model| has the crc32 field, and the value of that field does not
-// match the checksum, then this function returns true. If there is no crc32
-// field, then we don't know what the original (at build time) checksum was, so
-// we don't know anything clear and this function returns false.
bool ClearlyFailsChecksum(const Model &model) {
if (!flatbuffers::IsFieldPresent(&model, Model::VT_CRC32)) {
SAFTM_LOG(WARNING)
@@ -50,7 +42,6 @@
SAFTM_DLOG(INFO) << "Successfully checked CRC32 " << actual_crc32;
return false;
}
-} // namespace
const Model *GetVerifiedModelFromBytes(const char *data, size_t num_bytes) {
if ((data == nullptr) || (num_bytes == 0)) {
diff --git a/native/lang_id/common/flatbuffers/model-utils.h b/native/lang_id/common/flatbuffers/model-utils.h
index 197e1e3..cf33dd5 100644
--- a/native/lang_id/common/flatbuffers/model-utils.h
+++ b/native/lang_id/common/flatbuffers/model-utils.h
@@ -66,6 +66,14 @@
// corruption. GetVerifiedModelFromBytes performs this check.
mobile::uint32 ComputeCrc2Checksum(const Model *model);
+// Returns true if we have clear evidence that |model| fails its checksum.
+//
+// E.g., if |model| has the crc32 field, and the value of that field does not
+// match the checksum, then this function returns true. If there is no crc32
+// field, then we don't know what the original (at build time) checksum was, so
+// we don't know anything clear and this function returns false.
+bool ClearlyFailsChecksum(const Model &model);
+
} // namespace saft_fbs
} // namespace nlp_saft
diff --git a/native/models/textclassifier.ar.model b/native/models/textclassifier.ar.model
index d9710b9..87a442a 100755
--- a/native/models/textclassifier.ar.model
+++ b/native/models/textclassifier.ar.model
Binary files differ
diff --git a/native/models/textclassifier.en.model b/native/models/textclassifier.en.model
index f5fcc23..70e3cd2 100755
--- a/native/models/textclassifier.en.model
+++ b/native/models/textclassifier.en.model
Binary files differ
diff --git a/native/models/textclassifier.es.model b/native/models/textclassifier.es.model
index 33dddec..8ea0938 100755
--- a/native/models/textclassifier.es.model
+++ b/native/models/textclassifier.es.model
Binary files differ
diff --git a/native/models/textclassifier.fr.model b/native/models/textclassifier.fr.model
index 45df2f1..3ed3172 100755
--- a/native/models/textclassifier.fr.model
+++ b/native/models/textclassifier.fr.model
Binary files differ
diff --git a/native/models/textclassifier.it.model b/native/models/textclassifier.it.model
index 70bf151..4381909 100755
--- a/native/models/textclassifier.it.model
+++ b/native/models/textclassifier.it.model
Binary files differ
diff --git a/native/models/textclassifier.ja.model b/native/models/textclassifier.ja.model
index d28801a..5db8e14 100755
--- a/native/models/textclassifier.ja.model
+++ b/native/models/textclassifier.ja.model
Binary files differ
diff --git a/native/models/textclassifier.ko.model b/native/models/textclassifier.ko.model
index c4bacdb..a0d37ff 100755
--- a/native/models/textclassifier.ko.model
+++ b/native/models/textclassifier.ko.model
Binary files differ
diff --git a/native/models/textclassifier.nl.model b/native/models/textclassifier.nl.model
index 78e2f46..5e627e0 100755
--- a/native/models/textclassifier.nl.model
+++ b/native/models/textclassifier.nl.model
Binary files differ
diff --git a/native/models/textclassifier.pl.model b/native/models/textclassifier.pl.model
index 6090f54..7c43109 100755
--- a/native/models/textclassifier.pl.model
+++ b/native/models/textclassifier.pl.model
Binary files differ
diff --git a/native/models/textclassifier.pt.model b/native/models/textclassifier.pt.model
index 7ab45d8..b3b2232 100755
--- a/native/models/textclassifier.pt.model
+++ b/native/models/textclassifier.pt.model
Binary files differ
diff --git a/native/models/textclassifier.ru.model b/native/models/textclassifier.ru.model
index 441d3fe..722afbe 100755
--- a/native/models/textclassifier.ru.model
+++ b/native/models/textclassifier.ru.model
Binary files differ
diff --git a/native/models/textclassifier.th.model b/native/models/textclassifier.th.model
index 6a0a0d8..b156ed7 100755
--- a/native/models/textclassifier.th.model
+++ b/native/models/textclassifier.th.model
Binary files differ
diff --git a/native/models/textclassifier.tr.model b/native/models/textclassifier.tr.model
index 5752ba0..5a66a11 100755
--- a/native/models/textclassifier.tr.model
+++ b/native/models/textclassifier.tr.model
Binary files differ
diff --git a/native/models/textclassifier.universal.model b/native/models/textclassifier.universal.model
index 971c2af..83704c3 100755
--- a/native/models/textclassifier.universal.model
+++ b/native/models/textclassifier.universal.model
Binary files differ
diff --git a/native/models/textclassifier.zh.model b/native/models/textclassifier.zh.model
index ef9a536..946d188 100755
--- a/native/models/textclassifier.zh.model
+++ b/native/models/textclassifier.zh.model
Binary files differ
diff --git a/native/utils/base/arena.cc b/native/utils/base/arena.cc
index fcaed8e..d1e85e8 100644
--- a/native/utils/base/arena.cc
+++ b/native/utils/base/arena.cc
@@ -29,6 +29,7 @@
namespace libtextclassifier3 {
+#ifndef __cpp_aligned_new
static void *aligned_malloc(size_t size, int minimum_alignment) {
void *ptr = nullptr;
// posix_memalign requires that the requested alignment be at least
@@ -42,6 +43,7 @@
else
return ptr;
}
+#endif // !__cpp_aligned_new
// The value here doesn't matter until page_aligned_ is supported.
static const int kPageSize = 8192; // should be getpagesize()
diff --git a/native/utils/base/arena_test.cc b/native/utils/base/arena_test.cc
new file mode 100644
index 0000000..d5e9bf3
--- /dev/null
+++ b/native/utils/base/arena_test.cc
@@ -0,0 +1,386 @@
+#include "utils/base/arena.h"
+
+#include <cstdint>
+#include <cstdio>
+#include <cstring>
+#include <string>
+
+#include "utils/base/logging.h"
+#include "utils/base/macros.h"
+#include "gtest/gtest.h"
+
+#ifdef ADDRESS_SANITIZER
+#include <sanitizer/asan_interface.h>
+#endif
+
+namespace libtextclassifier3 {
+
+//------------------------------------------------------------------------
+// Write a fixed (0xcc) pattern to allocated memory
+static void TestMemory(void* mem, int size) {
+  // Do some memory allocation to check that the arena doesn't mess up
+  // the internal memory allocator
+  char* tmp[100];
+  for (int i = 0; i < TC3_ARRAYSIZE(tmp); i++) {
+    tmp[i] = new char[i * i + 1];
+  }
+
+  memset(mem, 0xcc, size);
+
+  // Free up the allocated memory;
+  for (char* s : tmp) {
+    delete[] s;
+  }
+}
+
+//------------------------------------------------------------------------
+// Check memory ptr
+static void CheckMemory(void* mem, int size) {
+  TC3_CHECK(mem != nullptr);
+  TestMemory(mem, size);
+}
+
+//------------------------------------------------------------------------
+// Check memory ptr and alignment
+static void CheckAlignment(void* mem, int size, int alignment) {
+  TC3_CHECK(mem != nullptr);
+  ASSERT_EQ(0, (reinterpret_cast<uintptr_t>(mem) & (alignment - 1)))
+      << "mem=" << mem << " alignment=" << alignment;
+  TestMemory(mem, size);
+}
+
+//------------------------------------------------------------------------
+template <class A>
+void TestArena(const char* name, A* a, int blksize) {
+  TC3_VLOG(INFO) << "Testing arena '" << name << "': blksize = " << blksize
+                 << ": actual blksize = " << a->block_size();
+
+  int s;
+  blksize = a->block_size();
+
+  // Allocate zero bytes
+  TC3_CHECK(a->is_empty());
+  a->Alloc(0);
+  TC3_CHECK(a->is_empty());
+
+  // Allocate same as blksize
+  CheckMemory(a->Alloc(blksize), blksize);
+  TC3_CHECK(!a->is_empty());
+
+  // Allocate some chunks adding up to blksize
+  s = blksize / 4;
+  CheckMemory(a->Alloc(s), s);
+  CheckMemory(a->Alloc(s), s);
+  CheckMemory(a->Alloc(s), s);
+
+  int s2 = blksize - (s * 3);
+  CheckMemory(a->Alloc(s2), s2);
+
+  // Allocate large chunk
+  CheckMemory(a->Alloc(blksize * 2), blksize * 2);
+  CheckMemory(a->Alloc(blksize * 2 + 1), blksize * 2 + 1);
+  CheckMemory(a->Alloc(blksize * 2 + 2), blksize * 2 + 2);
+  CheckMemory(a->Alloc(blksize * 2 + 3), blksize * 2 + 3);
+
+  // Allocate aligned
+  s = blksize / 2;
+  CheckAlignment(a->AllocAligned(s, 1), s, 1);
+  CheckAlignment(a->AllocAligned(s + 1, 2), s + 1, 2);
+  CheckAlignment(a->AllocAligned(s + 2, 2), s + 2, 2);
+  CheckAlignment(a->AllocAligned(s + 3, 4), s + 3, 4);
+  CheckAlignment(a->AllocAligned(s + 4, 4), s + 4, 4);
+  CheckAlignment(a->AllocAligned(s + 5, 4), s + 5, 4);
+  CheckAlignment(a->AllocAligned(s + 6, 4), s + 6, 4);
+
+  // Free
+  for (int i = 0; i < 100; i++) {
+    int i2 = i * i;
+    a->Free(a->Alloc(i2), i2);
+  }
+
+  // Memdup
+  char mem[500];
+  for (int i = 0; i < 500; i++) mem[i] = i & 255;
+  char* mem2 = a->Memdup(mem, sizeof(mem));
+  TC3_CHECK_EQ(0, memcmp(mem, mem2, sizeof(mem)));
+
+  // MemdupPlusNUL
+  const char* msg_mpn = "won't use all this length";
+  char* msg2_mpn = a->MemdupPlusNUL(msg_mpn, 10);
+  TC3_CHECK_EQ(0, strcmp(msg2_mpn, "won't use "));
+  a->Free(msg2_mpn, 11);
+
+  // Strdup
+  const char* msg = "arena unit test is cool...";
+  char* msg2 = a->Strdup(msg);
+  TC3_CHECK_EQ(0, strcmp(msg, msg2));
+  a->Free(msg2, strlen(msg) + 1);
+
+  // Strndup
+  char* msg3 = a->Strndup(msg, 10);
+  TC3_CHECK_EQ(0, strncmp(msg3, msg, 10));
+  // Strndup presumably allocates the 10 copied characters plus a NUL
+  // terminator (as MemdupPlusNUL does above), so free all 11 bytes.
+  a->Free(msg3, 11);
+  TC3_CHECK(!a->is_empty());
+
+  // Reset
+  a->Reset();
+  TC3_CHECK(a->is_empty());
+
+  // Realloc
+  char* m1 = a->Alloc(blksize / 2);
+  CheckMemory(m1, blksize / 2);
+  TC3_CHECK(!a->is_empty());
+  CheckMemory(a->Alloc(blksize / 2), blksize / 2);  // Allocate another block
+  m1 = a->Realloc(m1, blksize / 2, blksize);
+  CheckMemory(m1, blksize);
+  m1 = a->Realloc(m1, blksize, 23456);
+  CheckMemory(m1, 23456);
+
+  // Shrink
+  m1 = a->Shrink(m1, 200);
+  CheckMemory(m1, 200);
+  m1 = a->Shrink(m1, 100);
+  CheckMemory(m1, 100);
+  m1 = a->Shrink(m1, 1);
+  CheckMemory(m1, 1);
+  a->Free(m1, 1);
+  TC3_CHECK(!a->is_empty());
+
+  // Calloc
+  char* m2 = a->Calloc(2000);
+  for (int i = 0; i < 2000; ++i) {
+    TC3_CHECK_EQ(0, m2[i]);
+  }
+
+  // bytes_until_next_allocation
+  a->Reset();
+  TC3_CHECK(a->is_empty());
+  int alignment = blksize - a->bytes_until_next_allocation();
+  TC3_VLOG(INFO) << "Alignment overhead in initial block = " << alignment;
+
+  s = a->bytes_until_next_allocation() - 1;
+  CheckMemory(a->Alloc(s), s);
+  TC3_CHECK_EQ(a->bytes_until_next_allocation(), 1);
+  CheckMemory(a->Alloc(1), 1);
+  TC3_CHECK_EQ(a->bytes_until_next_allocation(), 0);
+
+  CheckMemory(a->Alloc(2 * blksize), 2 * blksize);
+  TC3_CHECK_EQ(a->bytes_until_next_allocation(), 0);
+
+  CheckMemory(a->Alloc(1), 1);
+  TC3_CHECK_EQ(a->bytes_until_next_allocation(), blksize - 1);
+
+  s = blksize / 2;
+  char* m0 = a->Alloc(s);
+  CheckMemory(m0, s);
+  TC3_CHECK_EQ(a->bytes_until_next_allocation(), blksize - s - 1);
+  m0 = a->Shrink(m0, 1);
+  CheckMemory(m0, 1);
+  TC3_CHECK_EQ(a->bytes_until_next_allocation(), blksize - 2);
+
+  a->Reset();
+  TC3_CHECK(a->is_empty());
+  TC3_CHECK_EQ(a->bytes_until_next_allocation(), blksize - alignment);
+}
+
+static void EnsureNoAddressInRangeIsPoisoned(void* buffer, size_t range_size) {
+#ifdef ADDRESS_SANITIZER
+  TC3_CHECK_EQ(nullptr, __asan_region_is_poisoned(buffer, range_size));
+#endif
+}
+
+static void DoTest(const char* label, int blksize, char* buffer) {
+  {
+    UnsafeArena ua(buffer, blksize);
+    TestArena((std::string("UnsafeArena") + label).c_str(), &ua, blksize);
+  }
+  EnsureNoAddressInRangeIsPoisoned(buffer, blksize);
+}
+
+//------------------------------------------------------------------------
+class BasicTest : public ::testing::TestWithParam<int> {};
+
+INSTANTIATE_TEST_SUITE_P(AllSizes, BasicTest,
+                         ::testing::Values(BaseArena::kDefaultAlignment + 1, 10,
+                                           100, 1024, 12345, 123450, 131072,
+                                           1234500));
+
+TEST_P(BasicTest, DoTest) {
+  const int blksize = GetParam();
+
+  // Allocate some memory from heap first
+  char* tmp[100];
+  for (int i = 0; i < TC3_ARRAYSIZE(tmp); i++) {
+    tmp[i] = new char[i * i];
+  }
+
+  // Initial buffer for testing pre-allocated arenas
+  char* buffer = new char[blksize + BaseArena::kDefaultAlignment];
+
+  DoTest("", blksize, nullptr);
+  DoTest("(p0)", blksize, buffer + 0);
+  DoTest("(p1)", blksize, buffer + 1);
+  DoTest("(p2)", blksize, buffer + 2);
+  DoTest("(p3)", blksize, buffer + 3);
+  DoTest("(p4)", blksize, buffer + 4);
+  DoTest("(p5)", blksize, buffer + 5);
+
+  // Free up the allocated heap memory
+  for (char* s : tmp) {
+    delete[] s;
+  }
+
+  delete[] buffer;
+}
+
+//------------------------------------------------------------------------
+// NOTE: these stats will only be accurate in non-debug mode (otherwise
+// they'll all be 0). So: if you want accurate timing, run in "normal"
+// or "opt" mode. If you want accurate stats, run in "debug" mode.
+// NOTE(review): the sentences above disagree on which build mode gives
+// accurate stats; confirm. ShowStatus is not called anywhere in this file.
+void ShowStatus(const char* const header, const BaseArena::Status& status) {
+  printf("\n--- status: %s\n", header);
+  printf(" %zu bytes allocated\n", status.bytes_allocated());
+}
+
+// This just tests the arena code proper, without use of allocators of
+// gladiators or STL or anything like that
+// NOTE(review): TestArena2 is not called by any test in this file.
+void TestArena2(UnsafeArena* const arena) {
+  const char sshort[] = "This is a short string";
+  char slong[3000];
+  memset(slong, 'a', sizeof(slong));
+  slong[sizeof(slong) - 1] = '\0';
+
+  char* s1 = arena->Strdup(sshort);
+  char* s2 = arena->Strdup(slong);
+  char* s3 = arena->Strndup(sshort, 100);
+  char* s4 = arena->Strndup(slong, 100);
+  char* s5 = arena->Memdup(sshort, 10);
+  char* s6 = arena->Realloc(s5, 10, 20);
+  arena->Shrink(s5, 10);  // get s5 back to using 10 bytes again
+  char* s7 = arena->Memdup(slong, 10);
+  char* s8 = arena->Realloc(s7, 10, 5);
+  char* s9 = arena->Strdup(s1);
+  char* s10 = arena->Realloc(s4, 100, 10);
+  char* s11 = arena->Realloc(s4, 10, 100);
+  char* s12 = arena->Strdup(s9);
+  char* s13 = arena->Realloc(s9, sizeof(sshort) - 1, 100000);  // won't fit :-)
+
+  TC3_CHECK_EQ(0, strcmp(s1, sshort));
+  TC3_CHECK_EQ(0, strcmp(s2, slong));
+  TC3_CHECK_EQ(0, strcmp(s3, sshort));
+  // s4 was realloced so it is not safe to read from
+  TC3_CHECK_EQ(0, strncmp(s5, sshort, 10));
+  TC3_CHECK_EQ(0, strncmp(s6, sshort, 10));
+  TC3_CHECK_EQ(s5, s6);  // Should have room to grow here
+  // only the first 5 bytes of s7 should match; the realloc should have
+  // caused the next byte to actually go to s9
+  TC3_CHECK_EQ(0, strncmp(s7, slong, 5));
+  TC3_CHECK_EQ(s7, s8);  // Realloc-smaller should cause us to realloc in place
+  // s9 was realloced so it is not safe to read from
+  TC3_CHECK_EQ(s10, s4);  // Realloc-smaller should cause us to realloc in place
+  // Even though we're back to prev size, we had to move the pointer. Thus
+  // only the first 10 bytes are known since we grew from 10 to 100
+  TC3_CHECK_NE(s11, s4);
+  TC3_CHECK_EQ(0, strncmp(s11, slong, 10));
+  TC3_CHECK_EQ(0, strcmp(s12, s1));
+  TC3_CHECK_NE(s12, s13);  // too big to grow-in-place, so we should move
+}
+
+//--------------------------------------------------------------------
+// Test some fundamental STL containers
+// NOTE(review): test_hash, streq and strlt are not used in this file.
+
+template <typename T>
+struct test_hash {
+  int operator()(const T&) const { return 0; }
+  inline bool operator()(const T& s1, const T& s2) const { return s1 < s2; }
+};
+template <>
+struct test_hash<const char*> {
+  int operator()(const char*) const { return 0; }
+
+  inline bool operator()(const char* s1, const char* s2) const {
+    return (s1 != s2) &&
+           (s2 == nullptr || (s1 != nullptr && strcmp(s1, s2) < 0));
+  }
+};
+
+// temp definitions from strutil.h, until the compiler error
+// generated by #including that file is fixed.
+struct streq {
+  bool operator()(const char* s1, const char* s2) const {
+    return ((s1 == nullptr && s2 == nullptr) ||
+            (s1 && s2 && *s1 == *s2 && strcmp(s1, s2) == 0));
+  }
+};
+struct strlt {
+  bool operator()(const char* s1, const char* s2) const {
+    return (s1 != s2) &&
+           (s2 == nullptr || (s1 != nullptr && strcmp(s1, s2) < 0));
+  }
+};
+
+void DoPoisonTest(BaseArena* b, size_t size) {
+#ifdef ADDRESS_SANITIZER
+  TC3_LOG(INFO) << "DoPoisonTest(" << b << ", " << size << ")";
+  char* c1 = b->SlowAlloc(size);
+  char* c2 = b->SlowAlloc(size);
+  TC3_CHECK_EQ(nullptr, __asan_region_is_poisoned(c1, size));
+  TC3_CHECK_EQ(nullptr, __asan_region_is_poisoned(c2, size));
+  char* c3 = b->SlowRealloc(c2, size, size / 2);
+  TC3_CHECK_EQ(nullptr, __asan_region_is_poisoned(c3, size / 2));
+  TC3_CHECK_NE(nullptr, __asan_region_is_poisoned(c2, size));
+  b->Reset();
+  TC3_CHECK_NE(nullptr, __asan_region_is_poisoned(c1, size));
+  TC3_CHECK_NE(nullptr, __asan_region_is_poisoned(c2, size));
+  TC3_CHECK_NE(nullptr, __asan_region_is_poisoned(c3, size / 2));
+#endif
+}
+
+TEST(ArenaTest, TestPoison) {
+  {
+    UnsafeArena arena(512);
+    DoPoisonTest(&arena, 128);
+    DoPoisonTest(&arena, 256);
+    DoPoisonTest(&arena, 512);
+    DoPoisonTest(&arena, 1024);
+  }
+
+  char* buffer = new char[512];
+  {
+    UnsafeArena arena(buffer, 512);
+    DoPoisonTest(&arena, 128);
+    DoPoisonTest(&arena, 256);
+    DoPoisonTest(&arena, 512);
+    DoPoisonTest(&arena, 1024);
+  }
+  EnsureNoAddressInRangeIsPoisoned(buffer, 512);
+
+  delete[] buffer;
+}
+
+//------------------------------------------------------------------------
+
+// Checks Strndup on a source with no NUL terminator; the exact-size heap
+// copy presumably lets sanitizers catch any read past `n` bytes.
+template <class A>
+void TestStrndupUnterminated() {
+  const char kFoo[3] = {'f', 'o', 'o'};
+  char* source = new char[3];
+  memcpy(source, kFoo, sizeof(kFoo));
+  A arena(4096);
+  char* dup = arena.Strndup(source, sizeof(kFoo));
+  TC3_CHECK_EQ(0, memcmp(dup, kFoo, sizeof(kFoo)));
+  delete[] source;
+}
+
+TEST(ArenaTest, StrndupWithUnterminatedStringUnsafe) {
+  TestStrndupUnterminated<UnsafeArena>();
+}
+
+}  // namespace libtextclassifier3
diff --git a/native/utils/calendar/calendar_test-include.cc b/native/utils/calendar/calendar_test-include.cc
index b818a5a..36b9778 100644
--- a/native/utils/calendar/calendar_test-include.cc
+++ b/native/utils/calendar/calendar_test-include.cc
@@ -26,7 +26,7 @@
DatetimeGranularity granularity;
std::string timezone;
DatetimeParsedData data;
- bool result = calendarlib_.InterpretParseData(
+ bool result = calendarlib_->InterpretParseData(
data, /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Zurich",
/*reference_locale=*/"en-CH",
/*prefer_future_for_unspecified_date=*/false, &time, &granularity);
@@ -39,14 +39,14 @@
DatetimeParsedData data;
data.SetAbsoluteValue(DatetimeComponent::ComponentType::YEAR, 2018);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-CH",
/*prefer_future_for_unspecified_date=*/false, &time, &granularity));
EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
data,
/*reference_time_ms_utc=*/1L, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-CH",
@@ -60,7 +60,7 @@
DatetimeParsedData data;
data.SetAbsoluteValue(DatetimeComponent::ComponentType::YEAR, 2018);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-CH",
@@ -68,7 +68,7 @@
EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);
data.SetAbsoluteValue(DatetimeComponent::ComponentType::MONTH, 4);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-CH",
@@ -76,7 +76,7 @@
EXPECT_EQ(time, 1522533600000L /* Apr 01 2018 00:00:00 */);
data.SetAbsoluteValue(DatetimeComponent::ComponentType::DAY_OF_MONTH, 25);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-CH",
@@ -84,7 +84,7 @@
EXPECT_EQ(time, 1524607200000L /* Apr 25 2018 00:00:00 */);
data.SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, 9);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-CH",
@@ -92,7 +92,7 @@
EXPECT_EQ(time, 1524639600000L /* Apr 25 2018 09:00:00 */);
data.SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE, 33);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-CH",
@@ -100,7 +100,7 @@
EXPECT_EQ(time, 1524641580000 /* Apr 25 2018 09:33:00 */);
data.SetAbsoluteValue(DatetimeComponent::ComponentType::SECOND, 59);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-CH",
@@ -117,14 +117,14 @@
DatetimeComponent::RelativeQualifier::NEXT);
data.SetRelativeCount(DatetimeComponent::ComponentType::WEEK, 1);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"de-CH",
/*prefer_future_for_unspecified_date=*/false, &time, &granularity));
EXPECT_EQ(time, 342000000L /* Mon Jan 05 1970 00:00:00 */);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-US",
@@ -147,7 +147,7 @@
future_wed_parse.SetAbsoluteValue(
DatetimeComponent::ComponentType::DAY_OF_WEEK, kWednesday);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
future_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-US",
/*prefer_future_for_unspecified_date=*/false, &time, &granularity));
@@ -163,7 +163,7 @@
next_wed_parse.SetRelativeCount(DatetimeComponent::ComponentType::DAY_OF_WEEK,
1);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
next_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-US",
/*prefer_future_for_unspecified_date=*/false, &time, &granularity));
@@ -179,7 +179,7 @@
same_wed_parse.SetRelativeCount(DatetimeComponent::ComponentType::DAY_OF_WEEK,
1);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
same_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-US",
/*prefer_future_for_unspecified_date=*/false, &time, &granularity));
@@ -195,7 +195,7 @@
last_wed_parse.SetRelativeCount(DatetimeComponent::ComponentType::DAY_OF_WEEK,
1);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
last_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-US",
/*prefer_future_for_unspecified_date=*/false, &time, &granularity));
@@ -211,7 +211,7 @@
past_wed_parse.SetRelativeCount(DatetimeComponent::ComponentType::DAY_OF_WEEK,
-2);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
past_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-US",
/*prefer_future_for_unspecified_date=*/false, &time, &granularity));
@@ -225,7 +225,7 @@
DatetimeComponent::RelativeQualifier::FUTURE);
in_3_hours_parse.SetRelativeCount(DatetimeComponent::ComponentType::HOUR, 3);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
in_3_hours_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-US",
/*prefer_future_for_unspecified_date=*/false, &time, &granularity));
@@ -240,7 +240,7 @@
in_5_minutes_parse.SetRelativeCount(DatetimeComponent::ComponentType::MINUTE,
5);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
in_5_minutes_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-US",
/*prefer_future_for_unspecified_date=*/false, &time, &granularity));
@@ -255,7 +255,7 @@
in_10_seconds_parse.SetRelativeCount(DatetimeComponent::ComponentType::SECOND,
10);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
in_10_seconds_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
/*reference_locale=*/"en-US",
/*prefer_future_for_unspecified_date=*/false, &time, &granularity));
@@ -270,7 +270,7 @@
data.SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, 7);
data.SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE, 10);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
data,
/*reference_time_ms_utc=*/1567317600000L /* Sept 01 2019 00:00:00 */,
/*reference_timezone=*/"Europe/Zurich",
@@ -287,7 +287,7 @@
data.SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, 7);
data.SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE, 10);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
data,
/*reference_time_ms_utc=*/1567317600000L /* Sept 01 2019 00:00:00 */,
/*reference_timezone=*/"Europe/Zurich",
@@ -303,7 +303,7 @@
data.SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, 9);
data.SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE, 10);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
data,
/*reference_time_ms_utc=*/1567317600000L /* Sept 01 2019 00:00:00 */,
/*reference_timezone=*/"Europe/Zurich",
@@ -311,7 +311,7 @@
&time, &granularity));
EXPECT_EQ(time, 1567321800000L /* Sept 01 2019 09:10:00 */);
- ASSERT_TRUE(calendarlib_.InterpretParseData(
+ ASSERT_TRUE(calendarlib_->InterpretParseData(
data,
/*reference_time_ms_utc=*/1567317600000L /* Sept 01 2019 00:00:00 */,
/*reference_timezone=*/"Europe/Zurich",
diff --git a/native/utils/calendar/calendar_test-include.h b/native/utils/calendar/calendar_test-include.h
index 58ad6e0..504d67e 100644
--- a/native/utils/calendar/calendar_test-include.h
+++ b/native/utils/calendar/calendar_test-include.h
@@ -19,38 +19,17 @@
#ifndef LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_TEST_INCLUDE_H_
#define LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_TEST_INCLUDE_H_
-
+#include "utils/jvm-test-utils.h"
#include "gtest/gtest.h"
-#if defined TC3_CALENDAR_ICU
-#include "utils/calendar/calendar-icu.h"
-#define TC3_TESTING_CREATE_CALENDARLIB_INSTANCE(VAR) VAR()
-#elif defined TC3_CALENDAR_APPLE
-#include "utils/calendar/calendar-apple.h"
-#define TC3_TESTING_CREATE_CALENDARLIB_INSTANCE(VAR) VAR()
-#elif defined TC3_CALENDAR_JAVAICU
-#include <jni.h>
-extern JNIEnv* g_jenv;
-#define TC3_TESTING_CREATE_CALENDARLIB_INSTANCE(VAR) \
- VAR(JniCache::Create(g_jenv))
-#include "utils/calendar/calendar-javaicu.h"
-#else
-#error Unsupported calendar implementation.
-#endif
-
-// This can get overridden in the javaicu version which needs to pass an JNIEnv*
-// argument to the constructor.
-#ifndef TC3_TESTING_CREATE_CALENDARLIB_INSTANCE
-
-#endif
-
namespace libtextclassifier3 {
namespace test_internal {
class CalendarTest : public ::testing::Test {
protected:
- CalendarTest() : TC3_TESTING_CREATE_CALENDARLIB_INSTANCE(calendarlib_) {}
- CalendarLib calendarlib_;
+  CalendarTest()
+      : calendarlib_{libtextclassifier3::CreateCalendarLibForTesting()} {}
+  std::unique_ptr<CalendarLib> calendarlib_;  // Calendar backend under test.
};
} // namespace test_internal
diff --git a/native/utils/flatbuffers/flatbuffers.h b/native/utils/flatbuffers/flatbuffers.h
index f76e12d..1bb739b 100644
--- a/native/utils/flatbuffers/flatbuffers.h
+++ b/native/utils/flatbuffers/flatbuffers.h
@@ -103,6 +103,8 @@
// Cast as flatbuffer type.
const T* get() const { return flatbuffers::GetRoot<T>(buffer_.data()); }
+ const B& buffer() const { return buffer_; }
+
const T* operator->() const {
return flatbuffers::GetRoot<T>(buffer_.data());
}
diff --git a/native/utils/flatbuffers/flatbuffers_test.bfbs b/native/utils/flatbuffers/flatbuffers_test.bfbs
new file mode 100644
index 0000000..725e512
--- /dev/null
+++ b/native/utils/flatbuffers/flatbuffers_test.bfbs
Binary files differ
diff --git a/native/utils/flatbuffers/flatbuffers_test.fbs b/native/utils/flatbuffers/flatbuffers_test.fbs
new file mode 100644
index 0000000..70b164a
--- /dev/null
+++ b/native/utils/flatbuffers/flatbuffers_test.fbs
@@ -0,0 +1,66 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+namespace libtextclassifier3.test;  // Schema for the flatbuffers utility tests.
+
+table FlightNumberInfo {
+ carrier_code: string;
+ flight_code: int;
+}
+
+table ContactInfo {
+ first_name: string;
+ last_name: string;
+ phone_number: string;
+ score: float;
+}
+
+table Reminder {
+ title: string;
+ notes: [string];
+}
+
+table NestedA {  // Mutually recursive with NestedB to allow deep nesting.
+ nestedb: NestedB;
+ value: string;
+}
+
+table NestedB {
+ nesteda: NestedA;
+}
+
+enum EnumValue : short {  // Stored as a 16-bit integer.
+ VALUE_0 = 0,
+ VALUE_1 = 1,
+ VALUE_2 = 2,
+}
+
+table EntityData {  // Root table; field ids follow declaration order.
+ an_int_field: int;
+ a_long_field: int64;
+ a_bool_field: bool;
+ a_float_field: float;
+ a_double_field: double;
+ flight_number: FlightNumberInfo;  // Field id 5, i.e. vtable offset 14.
+ contact_info: ContactInfo;
+ reminders: [Reminder];
+ numbers: [int];
+ strings: [string];
+ nested: NestedA;
+ enum_value: EnumValue;
+}
+
+root_type libtextclassifier3.test.EntityData;
diff --git a/native/utils/flatbuffers/flatbuffers_test_extended.bfbs b/native/utils/flatbuffers/flatbuffers_test_extended.bfbs
new file mode 100644
index 0000000..ea8b7d2
--- /dev/null
+++ b/native/utils/flatbuffers/flatbuffers_test_extended.bfbs
Binary files differ
diff --git a/native/utils/flatbuffers/flatbuffers_test_extended.fbs b/native/utils/flatbuffers/flatbuffers_test_extended.fbs
new file mode 100644
index 0000000..6ce9973
--- /dev/null
+++ b/native/utils/flatbuffers/flatbuffers_test_extended.fbs
@@ -0,0 +1,67 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+namespace libtextclassifier3.test;  // Superset of flatbuffers_test.fbs.
+
+table FlightNumberInfo {
+ carrier_code: string;
+ flight_code: int;
+}
+
+table ContactInfo {
+ first_name: string;
+ last_name: string;
+ phone_number: string;
+ score: float;
+}
+
+table Reminder {
+ title: string;
+ notes: [string];
+}
+
+table NestedA {  // Mutually recursive with NestedB to allow deep nesting.
+ nestedb: NestedB;
+ value: string;
+}
+
+table NestedB {
+ nesteda: NestedA;
+}
+
+enum EnumValue : short {  // Stored as a 16-bit integer.
+ VALUE_0 = 0,
+ VALUE_1 = 1,
+ VALUE_2 = 2,
+}
+
+table EntityData {  // As in flatbuffers_test.fbs, plus one trailing field.
+ an_int_field: int;
+ a_long_field: int64;
+ a_bool_field: bool;
+ a_float_field: float;
+ a_double_field: double;
+ flight_number: FlightNumberInfo;  // Field id 5, i.e. vtable offset 14.
+ contact_info: ContactInfo;
+ reminders: [Reminder];
+ numbers: [int];
+ strings: [string];
+ nested: NestedA;
+ enum_value: EnumValue;
+ mystic: string; // Extra field, absent from flatbuffers_test.fbs.
+}
+
+root_type libtextclassifier3.test.EntityData;
diff --git a/native/utils/flatbuffers/mutable.cc b/native/utils/flatbuffers/mutable.cc
index f5d298d..ca3f1b0 100644
--- a/native/utils/flatbuffers/mutable.cc
+++ b/native/utils/flatbuffers/mutable.cc
@@ -137,50 +137,12 @@
return libtextclassifier3::GetFieldOrNull(type_, field_offset);
}
-Variant MutableFlatbuffer::ParseEnumValue(const reflection::Type* type,
- StringPiece value) const {
- TC3_DCHECK(IsEnum(type));
- TC3_CHECK_NE(schema_->enums(), nullptr);
- const auto* enum_values = schema_->enums()->Get(type->index())->values();
- if (enum_values == nullptr) {
- TC3_LOG(ERROR) << "Enum has no specified values.";
- return Variant();
- }
- for (const reflection::EnumVal* enum_value : *enum_values) {
- if (value.Equals(StringPiece(enum_value->name()->c_str(),
- enum_value->name()->size()))) {
- const int64 value = enum_value->value();
- switch (type->base_type()) {
- case reflection::BaseType::Byte:
- return Variant(static_cast<int8>(value));
- case reflection::BaseType::UByte:
- return Variant(static_cast<uint8>(value));
- case reflection::BaseType::Short:
- return Variant(static_cast<int16>(value));
- case reflection::BaseType::UShort:
- return Variant(static_cast<uint16>(value));
- case reflection::BaseType::Int:
- return Variant(static_cast<int32>(value));
- case reflection::BaseType::UInt:
- return Variant(static_cast<uint32>(value));
- case reflection::BaseType::Long:
- return Variant(value);
- case reflection::BaseType::ULong:
- return Variant(static_cast<uint64>(value));
- default:
- break;
- }
- }
- }
- return Variant();
-}
-
bool MutableFlatbuffer::SetFromEnumValueName(const reflection::Field* field,
StringPiece value_name) {
if (!IsEnum(field->type())) {
return false;
}
- Variant variant_value = ParseEnumValue(field->type(), value_name);
+ Variant variant_value = ParseEnumValue(schema_, field->type(), value_name);
if (!variant_value.HasValue()) {
return false;
}
@@ -271,7 +233,6 @@
MutableFlatbuffer* MutableFlatbuffer::Mutable(const reflection::Field* field) {
if (field->type()->base_type() != reflection::Obj) {
- TC3_LOG(ERROR) << "Field is not of type Object.";
return nullptr;
}
const auto entry = children_.find(field);
diff --git a/native/utils/flatbuffers/mutable.h b/native/utils/flatbuffers/mutable.h
index c4d681a..8210e2a 100644
--- a/native/utils/flatbuffers/mutable.h
+++ b/native/utils/flatbuffers/mutable.h
@@ -175,9 +175,6 @@
const reflection::Object* type() const { return type_; }
private:
- // Parses an enum value.
- Variant ParseEnumValue(const reflection::Type* type, StringPiece value) const;
-
// Helper function for merging given repeated field from given flatbuffer
// table. Appends the elements.
template <typename T>
diff --git a/native/utils/flatbuffers/mutable_test.cc b/native/utils/flatbuffers/mutable_test.cc
new file mode 100644
index 0000000..a119f1f
--- /dev/null
+++ b/native/utils/flatbuffers/mutable_test.cc
@@ -0,0 +1,394 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/flatbuffers/mutable.h"
+
+#include <map>
+#include <memory>
+#include <string>
+
+#include "utils/flatbuffers/flatbuffers.h"
+#include "utils/flatbuffers/flatbuffers_generated.h"
+#include "utils/flatbuffers/flatbuffers_test_generated.h"
+#include "utils/flatbuffers/test-utils.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/reflection.h"
+#include "flatbuffers/reflection_generated.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::SizeIs;
+
+// Test fixture that owns the reflection schema and a builder over it.
+//
+// The schema must contain the `mystic` field exercised by
+// HandlesUnknownFields, so LoadTestMetadata() presumably returns the schema
+// compiled from flatbuffers_test_extended.fbs -- confirm in test-utils.h.
+class MutableFlatbufferTest : public testing::Test {
+ public:
+  // `schema_` is declared before `builder_`, so it is constructed first and
+  // outlives the builder, which is given a pointer to it.
+  MutableFlatbufferTest()
+      : schema_(LoadTestMetadata()), builder_(schema_.get()) {}
+
+ protected:
+  OwnedFlatbuffer<reflection::Schema, std::string> schema_;
+  MutableFlatbufferBuilder builder_;
+};
+
+// Scalar fields set through reflection are readable by the generated code.
+TEST_F(MutableFlatbufferTest, PrimitiveFieldsAreCorrectlySet) {
+  std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+  ASSERT_NE(buffer, nullptr);
+  EXPECT_TRUE(buffer->Set("an_int_field", 42));
+  EXPECT_TRUE(buffer->Set("a_long_field", int64{84}));
+  EXPECT_TRUE(buffer->Set("a_bool_field", true));
+  EXPECT_TRUE(buffer->Set("a_float_field", 1.f));
+  EXPECT_TRUE(buffer->Set("a_double_field", 1.0));
+
+  // Try to parse with the generated code.
+  std::unique_ptr<test::EntityDataT> entity_data =
+      LoadAndVerifyMutableFlatbuffer<test::EntityData>(buffer->Serialize());
+  ASSERT_NE(entity_data, nullptr);
+  EXPECT_EQ(entity_data->an_int_field, 42);
+  EXPECT_EQ(entity_data->a_long_field, 84);
+  EXPECT_EQ(entity_data->a_bool_field, true);
+  EXPECT_NEAR(entity_data->a_float_field, 1.f, 1e-4);
+  EXPECT_NEAR(entity_data->a_double_field, 1.0, 1e-4);
+}
+
+// Enum fields can be set from the name of an enum value.
+TEST_F(MutableFlatbufferTest, EnumValuesCanBeSpecifiedByName) {
+  std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+  ASSERT_NE(buffer, nullptr);
+
+  const auto* enum_field = buffer->GetFieldOrNull("enum_value");
+  ASSERT_NE(enum_field, nullptr);
+  EXPECT_TRUE(IsEnum(enum_field->type()));
+
+  EXPECT_TRUE(buffer->SetFromEnumValueName("enum_value", "VALUE_1"));
+
+  std::unique_ptr<test::EntityDataT> entity_data =
+      LoadAndVerifyMutableFlatbuffer<test::EntityData>(buffer->Serialize());
+  ASSERT_NE(entity_data, nullptr);
+  EXPECT_EQ(entity_data->enum_value,
+            libtextclassifier3::test::EnumValue_VALUE_1);
+}
+
+// Fields absent from the statically generated code can still be set through
+// reflection and read back from the raw table.
+TEST_F(MutableFlatbufferTest, HandlesUnknownFields) {
+  std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+  ASSERT_NE(buffer, nullptr);
+
+  // Add a field that is not known to the (statically generated) code.
+  EXPECT_TRUE(buffer->Set("mystic", "this is an unknown field."));
+  const auto* mystic_field = buffer->GetFieldOrNull("mystic");
+  ASSERT_NE(mystic_field, nullptr);
+
+  OwnedFlatbuffer<flatbuffers::Table, std::string> extra(buffer->Serialize());
+  const flatbuffers::String* mystic_value =
+      extra->GetPointer<const flatbuffers::String*>(mystic_field->offset());
+  ASSERT_NE(mystic_value, nullptr);
+  EXPECT_EQ(mystic_value->str(), "this is an unknown field.");
+}
+
+// A name path resolves to the innermost field and the table that holds it.
+TEST_F(MutableFlatbufferTest, HandlesNestedFields) {
+  FlatbufferFieldPathT path;
+  path.field.emplace_back(new FlatbufferFieldT);
+  path.field.back()->field_name = "flight_number";
+  path.field.emplace_back(new FlatbufferFieldT);
+  path.field.back()->field_name = "carrier_code";
+  flatbuffers::FlatBufferBuilder path_builder;
+  path_builder.Finish(FlatbufferFieldPath::Pack(path_builder, &path));
+
+  std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+
+  MutableFlatbuffer* parent = nullptr;
+  reflection::Field const* field = nullptr;
+  EXPECT_TRUE(
+      buffer->GetFieldWithParent(flatbuffers::GetRoot<FlatbufferFieldPath>(
+                                     path_builder.GetBufferPointer()),
+                                 &parent, &field));
+  EXPECT_EQ(parent, buffer->Mutable("flight_number"));
+  EXPECT_EQ(field,
+            buffer->Mutable("flight_number")->GetFieldOrNull("carrier_code"));
+}
+
+// Several nested tables can be populated independently.
+TEST_F(MutableFlatbufferTest, HandlesMultipleNestedFields) {
+  std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+  MutableFlatbuffer* flight_info = buffer->Mutable("flight_number");
+  ASSERT_NE(flight_info, nullptr);
+  flight_info->Set("carrier_code", "LX");
+  flight_info->Set("flight_code", 38);
+
+  MutableFlatbuffer* contact_info = buffer->Mutable("contact_info");
+  ASSERT_NE(contact_info, nullptr);
+  EXPECT_TRUE(contact_info->Set("first_name", "Barack"));
+  EXPECT_TRUE(contact_info->Set("last_name", "Obama"));
+  EXPECT_TRUE(contact_info->Set("phone_number", "1-800-TEST"));
+  EXPECT_TRUE(contact_info->Set("score", 1.f));
+
+  // Try to parse with the generated code.
+  std::unique_ptr<test::EntityDataT> entity_data =
+      LoadAndVerifyMutableFlatbuffer<test::EntityData>(buffer->Serialize());
+  ASSERT_NE(entity_data, nullptr);
+  ASSERT_NE(entity_data->flight_number, nullptr);
+  ASSERT_NE(entity_data->contact_info, nullptr);
+  EXPECT_EQ(entity_data->flight_number->carrier_code, "LX");
+  EXPECT_EQ(entity_data->flight_number->flight_code, 38);
+  EXPECT_EQ(entity_data->contact_info->first_name, "Barack");
+  EXPECT_EQ(entity_data->contact_info->last_name, "Obama");
+  EXPECT_EQ(entity_data->contact_info->phone_number, "1-800-TEST");
+  EXPECT_NEAR(entity_data->contact_info->score, 1.f, 1e-4);
+}
+
+// A field can be set through a path of field names.
+TEST_F(MutableFlatbufferTest, HandlesFieldsSetWithNamePath) {
+  FlatbufferFieldPathT path;
+  path.field.emplace_back(new FlatbufferFieldT);
+  path.field.back()->field_name = "flight_number";
+  path.field.emplace_back(new FlatbufferFieldT);
+  path.field.back()->field_name = "carrier_code";
+  flatbuffers::FlatBufferBuilder path_builder;
+  path_builder.Finish(FlatbufferFieldPath::Pack(path_builder, &path));
+
+  std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+  // Test setting value using Set function.
+  buffer->Mutable("flight_number")->Set("flight_code", 38);
+  // Test setting value using FlatbufferFieldPath.
+  buffer->Set(flatbuffers::GetRoot<FlatbufferFieldPath>(
+                  path_builder.GetBufferPointer()),
+              "LX");
+
+  // Try to parse with the generated code.
+  std::unique_ptr<test::EntityDataT> entity_data =
+      LoadAndVerifyMutableFlatbuffer<test::EntityData>(buffer->Serialize());
+  ASSERT_NE(entity_data, nullptr);
+  ASSERT_NE(entity_data->flight_number, nullptr);
+  EXPECT_EQ(entity_data->flight_number->carrier_code, "LX");
+  EXPECT_EQ(entity_data->flight_number->flight_code, 38);
+}
+
+// A field can be set through a path of vtable offsets. A flatbuffers field
+// with id `i` lives at vtable offset 4 + 2 * i, and ids follow declaration
+// order in the schema.
+TEST_F(MutableFlatbufferTest, HandlesFieldsSetWithOffsetPath) {
+  FlatbufferFieldPathT path;
+  path.field.emplace_back(new FlatbufferFieldT);
+  path.field.back()->field_offset = 14;  // flight_number (id 5).
+  path.field.emplace_back(new FlatbufferFieldT);
+  path.field.back()->field_offset = 4;  // carrier_code (id 0).
+  flatbuffers::FlatBufferBuilder path_builder;
+  path_builder.Finish(FlatbufferFieldPath::Pack(path_builder, &path));
+
+  std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+  // Test setting value using Set function.
+  buffer->Mutable("flight_number")->Set("flight_code", 38);
+  // Test setting value using FlatbufferFieldPath.
+  buffer->Set(flatbuffers::GetRoot<FlatbufferFieldPath>(
+                  path_builder.GetBufferPointer()),
+              "LX");
+
+  // Try to parse with the generated code.
+  std::unique_ptr<test::EntityDataT> entity_data =
+      LoadAndVerifyMutableFlatbuffer<test::EntityData>(buffer->Serialize());
+  ASSERT_NE(entity_data, nullptr);
+  ASSERT_NE(entity_data->flight_number, nullptr);
+  EXPECT_EQ(entity_data->flight_number->carrier_code, "LX");
+  EXPECT_EQ(entity_data->flight_number->flight_code, 38);
+}
+
+// Merging a serialized buffer overwrites scalars set in both, keeps fields
+// set only in the target, and appends repeated fields.
+TEST_F(MutableFlatbufferTest, PartialBuffersAreCorrectlyMerged) {
+  std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+  buffer->Set("an_int_field", 42);
+  buffer->Set("a_long_field", int64{84});
+  MutableFlatbuffer* flight_info = buffer->Mutable("flight_number");
+  ASSERT_NE(flight_info, nullptr);
+  flight_info->Set("carrier_code", "LX");
+  flight_info->Set("flight_code", 38);
+  auto* reminders = buffer->Repeated("reminders");
+  MutableFlatbuffer* reminder1 = reminders->Add();
+  reminder1->Set("title", "reminder1");
+  auto* reminder1_notes = reminder1->Repeated("notes");
+  reminder1_notes->Add("note1");
+  reminder1_notes->Add("note2");
+
+  // Create message to merge.
+  test::EntityDataT additional_entity_data;
+  additional_entity_data.an_int_field = 43;
+  additional_entity_data.flight_number.reset(new test::FlightNumberInfoT);
+  additional_entity_data.flight_number->flight_code = 39;
+  additional_entity_data.contact_info.reset(new test::ContactInfoT);
+  additional_entity_data.contact_info->first_name = "Barack";
+  additional_entity_data.reminders.push_back(
+      std::unique_ptr<test::ReminderT>(new test::ReminderT));
+  additional_entity_data.reminders[0]->notes.push_back("additional note1");
+  additional_entity_data.reminders[0]->notes.push_back("additional note2");
+  additional_entity_data.numbers.push_back(9);
+  additional_entity_data.numbers.push_back(10);
+  additional_entity_data.strings.push_back("str1");
+  additional_entity_data.strings.push_back("str2");
+
+  // Merge it.
+  EXPECT_TRUE(buffer->MergeFromSerializedFlatbuffer(
+      PackFlatbuffer<test::EntityData>(&additional_entity_data)));
+
+  // Try to parse it with the generated code.
+  std::unique_ptr<test::EntityDataT> entity_data =
+      LoadAndVerifyMutableFlatbuffer<test::EntityData>(buffer->Serialize());
+  ASSERT_NE(entity_data, nullptr);
+  ASSERT_NE(entity_data->flight_number, nullptr);
+  ASSERT_NE(entity_data->contact_info, nullptr);
+  EXPECT_EQ(entity_data->an_int_field, 43);
+  EXPECT_EQ(entity_data->a_long_field, 84);
+  EXPECT_EQ(entity_data->flight_number->carrier_code, "LX");
+  EXPECT_EQ(entity_data->flight_number->flight_code, 39);
+  EXPECT_EQ(entity_data->contact_info->first_name, "Barack");
+  ASSERT_THAT(entity_data->reminders, SizeIs(2));
+  EXPECT_THAT(entity_data->reminders[1]->notes,
+              ElementsAre("additional note1", "additional note2"));
+  EXPECT_THAT(entity_data->numbers, ElementsAre(9, 10));
+  EXPECT_THAT(entity_data->strings, ElementsAre("str1", "str2"));
+}
+
+// A deeply nested field can be set through a dotted field path.
+TEST_F(MutableFlatbufferTest, MergesNestedFields) {
+  std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+
+  // Set a multiply nested field.
+  std::unique_ptr<FlatbufferFieldPathT> field_path_t =
+      CreateFieldPath("nested.nestedb.nesteda.nestedb.nesteda");
+  OwnedFlatbuffer<FlatbufferFieldPath, std::string> field_path(
+      PackFlatbuffer<FlatbufferFieldPath>(field_path_t.get()));
+  MutableFlatbuffer* innermost = buffer->Mutable(field_path.get());
+  ASSERT_NE(innermost, nullptr);
+  innermost->Set("value", "le value");
+
+  std::unique_ptr<test::EntityDataT> entity_data =
+      LoadAndVerifyMutableFlatbuffer<test::EntityData>(buffer->Serialize());
+  ASSERT_NE(entity_data, nullptr);
+  EXPECT_EQ(entity_data->nested->nestedb->nesteda->nestedb->nesteda->value,
+            "le value");
+}
+
+// AsFlatMap() keys nested fields by their dotted path.
+TEST_F(MutableFlatbufferTest, PrimitiveAndNestedFieldsAreCorrectlyFlattened) {
+  std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+  buffer->Set("an_int_field", 42);
+  buffer->Set("a_long_field", int64{84});
+  MutableFlatbuffer* flight_info = buffer->Mutable("flight_number");
+  ASSERT_NE(flight_info, nullptr);
+  flight_info->Set("carrier_code", "LX");
+  flight_info->Set("flight_code", 38);
+
+  std::map<std::string, Variant> entity_data_map = buffer->AsFlatMap();
+  ASSERT_THAT(entity_data_map, SizeIs(4));
+  EXPECT_EQ(42, entity_data_map["an_int_field"].Value<int>());
+  EXPECT_EQ(84, entity_data_map["a_long_field"].Value<int64>());
+  EXPECT_EQ("LX", entity_data_map["flight_number.carrier_code"]
+                      .ConstRefValue<std::string>());
+  EXPECT_EQ(38, entity_data_map["flight_number.flight_code"].Value<int>());
+}
+
+// ToTextProto() renders the set fields, including nested tables, as text.
+TEST_F(MutableFlatbufferTest, ToTextProtoWorks) {
+  std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+  buffer->Set("an_int_field", 42);
+  buffer->Set("a_long_field", int64{84});
+  MutableFlatbuffer* flight_info = buffer->Mutable("flight_number");
+  ASSERT_NE(flight_info, nullptr);
+  flight_info->Set("carrier_code", "LX");
+  flight_info->Set("flight_code", 38);
+
+  EXPECT_EQ(buffer->ToTextProto(),
+            "a_long_field: 84, an_int_field: 42, flight_number "
+            "{flight_code: 38, carrier_code: 'LX'}");
+}
+
+// Repeated table and string fields built via Repeated()/Add() are readable by
+// the generated code.
+TEST_F(MutableFlatbufferTest, RepeatedFieldSetThroughReflectionCanBeRead) {
+  std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+
+  auto reminders = buffer->Repeated("reminders");
+  {
+    auto reminder = reminders->Add();
+    reminder->Set("title", "test reminder");
+    auto notes = reminder->Repeated("notes");
+    notes->Add("note A");
+    notes->Add("note B");
+  }
+  {
+    auto reminder = reminders->Add();
+    reminder->Set("title", "test reminder 2");
+    reminder->Add("notes", "note i");
+    reminder->Add("notes", "note ii");
+    reminder->Add("notes", "note iii");
+  }
+
+  std::unique_ptr<test::EntityDataT> entity_data =
+      LoadAndVerifyMutableFlatbuffer<test::EntityData>(buffer->Serialize());
+  ASSERT_NE(entity_data, nullptr);
+  // ASSERT (not EXPECT) so the indexing below never runs out of bounds.
+  ASSERT_THAT(entity_data->reminders, SizeIs(2));
+  EXPECT_EQ(entity_data->reminders[0]->title, "test reminder");
+  EXPECT_THAT(entity_data->reminders[0]->notes,
+              ElementsAre("note A", "note B"));
+  EXPECT_EQ(entity_data->reminders[1]->title, "test reminder 2");
+  EXPECT_THAT(entity_data->reminders[1]->notes,
+              ElementsAre("note i", "note ii", "note iii"));
+}
+
+// Adding a value whose type differs from the repeated field's element type
+// (int) is rejected and leaves the field unchanged.
+TEST_F(MutableFlatbufferTest, RepeatedFieldAddMethodWithIncompatibleValues) {
+  std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+  EXPECT_FALSE(buffer->Repeated("numbers")->Add(static_cast<int64>(123)));
+  EXPECT_FALSE(buffer->Repeated("numbers")->Add(static_cast<int8>(9)));
+  EXPECT_TRUE(buffer->Repeated("numbers")->Add(static_cast<int>(999)));
+
+  // Try to parse it with the generated code.
+  std::unique_ptr<test::EntityDataT> entity_data =
+      LoadAndVerifyMutableFlatbuffer<test::EntityData>(buffer->Serialize());
+  ASSERT_NE(entity_data, nullptr);
+  EXPECT_THAT(entity_data->numbers, ElementsAre(999));
+}
+
+// Repeated fields report their size and allow indexed reads before
+// serialization.
+TEST_F(MutableFlatbufferTest, RepeatedFieldGetAndSizeMethods) {
+  std::unique_ptr<MutableFlatbuffer> buffer = builder_.NewRoot();
+  EXPECT_TRUE(buffer->Repeated("numbers")->Add(1));
+  EXPECT_TRUE(buffer->Repeated("numbers")->Add(2));
+  EXPECT_TRUE(buffer->Repeated("numbers")->Add(3));
+
+  EXPECT_EQ(buffer->Repeated("numbers")->Size(), 3);
+  EXPECT_EQ(buffer->Repeated("numbers")->Get<int>(0), 1);
+  EXPECT_EQ(buffer->Repeated("numbers")->Get<int>(1), 2);
+  EXPECT_EQ(buffer->Repeated("numbers")->Get<int>(2), 3);
+}
+
+}  // namespace
+}  // namespace libtextclassifier3
diff --git a/native/utils/flatbuffers/reflection.cc b/native/utils/flatbuffers/reflection.cc
index 4815d87..7d6d3f4 100644
--- a/native/utils/flatbuffers/reflection.cc
+++ b/native/utils/flatbuffers/reflection.cc
@@ -76,6 +76,16 @@
return nullptr;
}
+Optional<int> TypeIdForObject(const reflection::Schema* schema,
+                              const reflection::Object* type) {
+  // Type ids are positions in the schema's object table.
+  const auto* objects = schema->objects();
+  for (int id = 0; id < objects->size(); ++id) {
+    if (objects->Get(id) == type) return Optional<int>(id);
+  }
+  return Optional<int>();
+}
+
Optional<int> TypeIdForName(const reflection::Schema* schema,
const StringPiece type_name) {
for (int i = 0; i < schema->objects()->size(); i++) {
@@ -116,4 +126,42 @@
return true;
}
+Variant ParseEnumValue(const reflection::Schema* schema,
+ const reflection::Type* type, StringPiece value) {
+ TC3_DCHECK(IsEnum(type));
+ TC3_CHECK_NE(schema->enums(), nullptr);
+ const auto* enum_values = schema->enums()->Get(type->index())->values();
+ if (enum_values == nullptr) {
+ TC3_LOG(ERROR) << "Enum has no specified values.";
+ return Variant();
+ }
+ for (const reflection::EnumVal* enum_value : *enum_values) {
+ if (value.Equals(StringPiece(enum_value->name()->c_str(),
+ enum_value->name()->size()))) {
+ const int64 value = enum_value->value();
+ switch (type->base_type()) {
+ case reflection::BaseType::Byte:
+ return Variant(static_cast<int8>(value));
+ case reflection::BaseType::UByte:
+ return Variant(static_cast<uint8>(value));
+ case reflection::BaseType::Short:
+ return Variant(static_cast<int16>(value));
+ case reflection::BaseType::UShort:
+ return Variant(static_cast<uint16>(value));
+ case reflection::BaseType::Int:
+ return Variant(static_cast<int32>(value));
+ case reflection::BaseType::UInt:
+ return Variant(static_cast<uint32>(value));
+ case reflection::BaseType::Long:
+ return Variant(value);
+ case reflection::BaseType::ULong:
+ return Variant(static_cast<uint64>(value));
+ default:
+ break;
+ }
+ }
+ }
+ return Variant();
+}
+
} // namespace libtextclassifier3
diff --git a/native/utils/flatbuffers/reflection.h b/native/utils/flatbuffers/reflection.h
index b7f8990..8650a95 100644
--- a/native/utils/flatbuffers/reflection.h
+++ b/native/utils/flatbuffers/reflection.h
@@ -22,6 +22,7 @@
#include "utils/flatbuffers/flatbuffers_generated.h"
#include "utils/optional.h"
#include "utils/strings/stringpiece.h"
+#include "utils/variant.h"
#include "flatbuffers/reflection.h"
#include "flatbuffers/reflection_generated.h"
@@ -116,6 +117,10 @@
Optional<int> TypeIdForName(const reflection::Schema* schema,
const StringPiece type_name);
+// Gets the type id (index into schema->objects()) of an object type.
+Optional<int> TypeIdForObject(const reflection::Schema* schema,
+ const reflection::Object* type);
+
// Resolves field lookups by name to the concrete field offsets.
bool SwapFieldNamesForOffsetsInPath(const reflection::Schema* schema,
FlatbufferFieldPathT* path);
@@ -125,6 +130,10 @@
return flatbuffers::IsInteger(type->base_type()) && type->index() >= 0;
}
+// Parses an enum value by name; returns an empty Variant if it can't resolve.
+Variant ParseEnumValue(const reflection::Schema* schema,
+ const reflection::Type* type, StringPiece value);
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_REFLECTION_H_
diff --git a/native/utils/flatbuffers/reflection_test.cc b/native/utils/flatbuffers/reflection_test.cc
new file mode 100644
index 0000000..9a56b77
--- /dev/null
+++ b/native/utils/flatbuffers/reflection_test.cc
@@ -0,0 +1,70 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/flatbuffers/reflection.h"
+
+#include "utils/flatbuffers/flatbuffers_generated.h"
+#include "utils/flatbuffers/test-utils.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/reflection_generated.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+TEST(ReflectionTest, ResolvesFieldOffsets) {
+ std::string metadata_buffer = LoadTestMetadata();
+ const reflection::Schema* schema =
+ flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data());
+ FlatbufferFieldPathT path;
+ path.field.emplace_back(new FlatbufferFieldT);
+ path.field.back()->field_name = "flight_number";
+ path.field.emplace_back(new FlatbufferFieldT);
+ path.field.back()->field_name = "carrier_code";
+
+ EXPECT_TRUE(SwapFieldNamesForOffsetsInPath(schema, &path));
+
+ EXPECT_THAT(path.field[0]->field_name, testing::IsEmpty());
+ EXPECT_EQ(14, path.field[0]->field_offset);
+ EXPECT_THAT(path.field[1]->field_name, testing::IsEmpty());
+ EXPECT_EQ(4, path.field[1]->field_offset);
+}
+
+TEST(ReflectionTest, ParseEnumValuesByName) {
+ std::string metadata_buffer = LoadTestMetadata();
+ const reflection::Schema* schema =
+ flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data());
+ const reflection::Field* field =
+ GetFieldOrNull(schema->root_table(), "enum_value");
+ ASSERT_NE(field, nullptr);
+
+ Variant enum_value_0 = ParseEnumValue(schema, field->type(), "VALUE_0");
+ Variant enum_value_1 = ParseEnumValue(schema, field->type(), "VALUE_1");
+ Variant enum_no_value = ParseEnumValue(schema, field->type(), "NO_VALUE");
+
+ EXPECT_TRUE(enum_value_0.HasValue());
+ EXPECT_EQ(enum_value_0.GetType(), Variant::TYPE_INT_VALUE);
+ EXPECT_EQ(enum_value_0.Value<int>(), 0);
+
+ EXPECT_TRUE(enum_value_1.HasValue());
+ EXPECT_EQ(enum_value_1.GetType(), Variant::TYPE_INT_VALUE);
+ EXPECT_EQ(enum_value_1.Value<int>(), 1);
+
+ EXPECT_FALSE(enum_no_value.HasValue());
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/flatbuffers/test-utils.h b/native/utils/flatbuffers/test-utils.h
new file mode 100644
index 0000000..5fd5b24
--- /dev/null
+++ b/native/utils/flatbuffers/test-utils.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Common test utils.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_TEST_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_TEST_UTILS_H_
+
+#include <fstream>
+#include <iterator>
+#include <memory>
+#include <string>
+
+#include "utils/flatbuffers/flatbuffers_generated.h"
+#include "utils/strings/split.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/test-data-test-utils.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+
+// Reads the binary test schema; binary mode avoids newline translation.
+inline std::string LoadTestMetadata() {
+ std::ifstream test_config_stream(
+ GetTestDataPath("utils/flatbuffers/flatbuffers_test_extended.bfbs"),
+ std::ios::binary);
+ return std::string((std::istreambuf_iterator<char>(test_config_stream)),
+ (std::istreambuf_iterator<char>()));
+}
+
+// Creates a flatbuffer field path from a dot separated field path string.
+inline std::unique_ptr<FlatbufferFieldPathT> CreateFieldPath(
+ const StringPiece path) {
+ std::unique_ptr<FlatbufferFieldPathT> field_path(new FlatbufferFieldPathT);
+ for (const StringPiece field : strings::Split(path, '.')) {
+ field_path->field.emplace_back(new FlatbufferFieldT);
+ field_path->field.back()->field_name = field.ToString();
+ }
+ return field_path;
+}
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_TEST_UTILS_H_
diff --git a/native/utils/grammar/next/semantics/expression.fbs b/native/utils/grammar/next/semantics/expression.fbs
index 9d1e5d0..40c1eb1 100755
--- a/native/utils/grammar/next/semantics/expression.fbs
+++ b/native/utils/grammar/next/semantics/expression.fbs
@@ -23,6 +23,7 @@
ComposeExpression,
SpanAsStringExpression,
ParseNumberExpression,
+ MergeValueExpression,
}
// A semantic expression.
@@ -92,3 +93,12 @@
value:SemanticExpression;
}
+// Merges the values of the given semantic expressions into a single result.
+namespace libtextclassifier3.grammar.next;
+table MergeValueExpression {
+ // The id of the type of the result.
+ type:int;
+
+ values:[SemanticExpression];
+}
+
diff --git a/native/utils/grammar/utils/ir.h b/native/utils/grammar/utils/ir.h
index b05b87f..ac15a44 100644
--- a/native/utils/grammar/utils/ir.h
+++ b/native/utils/grammar/utils/ir.h
@@ -105,12 +105,17 @@
const Nonterm nonterminal = ++num_nonterminals_;
if (!name.empty()) {
// Record debug information.
- nonterminal_names_[nonterminal] = name;
- nonterminal_ids_[name] = nonterminal;
+ SetNonterminal(name, nonterminal);
}
return nonterminal;
}
+ // Records `name` for `nonterminal` in both the id->name and name->id maps.
+ void SetNonterminal(const std::string& name, const Nonterm nonterminal) {
+ nonterminal_names_[nonterminal] = name;
+ nonterminal_ids_[name] = nonterminal;
+ }
+
// Defines a nonterminal if not yet defined.
Nonterm DefineNonterminal(Nonterm nonterminal) {
return (nonterminal != kUnassignedNonterm) ? nonterminal : AddNonterminal();
@@ -183,6 +188,12 @@
bool include_debug_information = false) const;
const std::vector<RulesShard>& shards() const { return shards_; }
+ const std::vector<std::pair<std::string, Nonterm>>& regex_rules() const {
+ return regex_rules_;
+ }
+ const std::vector<std::pair<std::string, Nonterm>>& annotations() const {
+ return annotations_;
+ }
private:
template <typename R, typename H>
diff --git a/native/utils/grammar/utils/rules.cc b/native/utils/grammar/utils/rules.cc
index d661a21..c988194 100644
--- a/native/utils/grammar/utils/rules.cc
+++ b/native/utils/grammar/utils/rules.cc
@@ -396,6 +396,17 @@
regex_rules_.push_back(regex_pattern);
}
+bool Rules::UsesFillers() const {
+ for (const Rule& rule : rules_) {
+ for (const RhsElement& rhs_element : rule.rhs) {
+ if (IsNonterminalOfName(rhs_element, kFiller)) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
Ir Rules::Finalize(const std::set<std::string>& predefined_nonterminals) const {
Ir rules(filters_, num_shards_);
std::unordered_map<int, Nonterm> nonterminal_ids;
@@ -455,6 +466,22 @@
rules.AddAnnotation(nonterminal_ids[nonterminal], annotation);
}
+ // Check whether fillers are still referenced, i.e. whether some of them
+ // could not be optimized away.
+ if (UsesFillers()) {
+ TC3_LOG(WARNING) << "Rules use fillers that couldn't be optimized, grammar "
+ "matching performance might be impacted.";
+
+ // Add a right-recursive definition for the filler (one or more tokens):
+ // <filler> = <token>
+ // <filler> = <token> <filler>
+ const Nonterm filler = rules.GetNonterminalForName(kFiller);
+ const Nonterm token =
+ rules.DefineNonterminal(rules.GetNonterminalForName(kTokenNonterm));
+ rules.Add(filler, token);
+ rules.Add(filler, std::vector<Nonterm>{token, filler});
+ }
+
// Now, keep adding eligible rules (rules whose rhs is completely assigned)
// until we can't make any more progress.
// Note: The following code is quadratic in the worst case.
diff --git a/native/utils/grammar/utils/rules.h b/native/utils/grammar/utils/rules.h
index 0c8d7da..b818d39 100644
--- a/native/utils/grammar/utils/rules.h
+++ b/native/utils/grammar/utils/rules.h
@@ -216,6 +216,9 @@
bool IsNonterminalOfName(const RhsElement& element,
const std::string& nonterminal) const;
+ // Checks whether any rule references the filler nonterminal in its rhs.
+ bool UsesFillers() const;
+
const int num_shards_;
// Non-terminal to id map.
diff --git a/native/utils/grammar/utils/rules_test.cc b/native/utils/grammar/utils/rules_test.cc
index 6761118..30be704 100644
--- a/native/utils/grammar/utils/rules_test.cc
+++ b/native/utils/grammar/utils/rules_test.cc
@@ -180,6 +180,27 @@
EXPECT_THAT(frozen_rules.lhs, IsEmpty());
}
+TEST(SerializeRulesTest, HandlesFillers) {
+ Rules rules;
+ rules.Add("<test>", {"<filler>?", "a", "test"});
+ const Ir ir = rules.Finalize();
+ RulesSetT frozen_rules;
+ ir.Serialize(/*include_debug_information=*/false, &frozen_rules);
+
+ EXPECT_THAT(frozen_rules.rules, SizeIs(1));
+ EXPECT_EQ(frozen_rules.terminals, std::string("a\0test\0", 7));
+
+ // The optional filler is not optimized away here. The rule is binarized
+ // into these rules, plus the <filler> rules that define the filler:
+ // <tmp_0> ::= <filler> a
+ // <test> ::= <tmp_0> test
+ // <test> ::= a test
+ // <filler> ::= <token> <filler>
+ EXPECT_THAT(frozen_rules.rules.front()->binary_rules, SizeIs(4));
+ // <filler> ::= <token>
+ EXPECT_THAT(frozen_rules.rules.front()->unary_rules, SizeIs(1));
+}
+
TEST(SerializeRulesTest, HandlesAnnotations) {
Rules rules;
rules.AddAnnotation("phone");
diff --git a/native/utils/intents/intent-generator-test-lib.cc b/native/utils/intents/intent-generator-test-lib.cc
index 6e535dc..4207a3e 100644
--- a/native/utils/intents/intent-generator-test-lib.cc
+++ b/native/utils/intents/intent-generator-test-lib.cc
@@ -23,6 +23,7 @@
#include "utils/intents/intent-generator.h"
#include "utils/intents/remote-action-template.h"
#include "utils/java/jni-helper.h"
+#include "utils/jvm-test-utils.h"
#include "utils/resources_generated.h"
#include "utils/testing/logging_event_listener.h"
#include "utils/variant.h"
@@ -30,9 +31,6 @@
#include "gtest/gtest.h"
#include "flatbuffers/reflection.h"
-extern JNIEnv* g_jenv;
-extern jobject g_context;
-
namespace libtextclassifier3 {
namespace {
@@ -119,7 +117,7 @@
class IntentGeneratorTest : public testing::Test {
protected:
explicit IntentGeneratorTest()
- : jni_cache_(JniCache::Create(g_jenv)),
+ : jni_cache_(JniCache::Create(GetJenv())),
resource_buffer_(BuildTestResources()),
resources_(
flatbuffers::GetRoot<ResourcePool>(resource_buffer_.data())) {}
@@ -152,7 +150,7 @@
BuildTestIntentFactoryModel("test", R"lua(
return {
{
- -- Should fail, as no app g_context is provided.
+ -- Should fail, as no app context is provided.
data = external.android.package_name,
}
})lua");
@@ -162,7 +160,7 @@
ClassificationResult classification = {"test", 1.0};
std::vector<RemoteActionTemplate> intents;
EXPECT_FALSE(generator->GenerateIntents(
- JniHelper::NewStringUTF(g_jenv, "en-US").ValueOrDie().get(),
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
classification,
/*reference_time_ms_utc=*/0, "test", {0, 4}, /*context=*/nullptr,
/*annotations_entity_data_schema=*/nullptr, &intents));
@@ -188,16 +186,17 @@
ClassificationResult classification = {"address", 1.0};
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
- JniHelper::NewStringUTF(g_jenv, "en-US").ValueOrDie().get(),
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
classification,
- /*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20}, g_context,
+ /*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20},
+ GetAndroidContext(),
/*annotations_entity_data_schema=*/nullptr, &intents));
EXPECT_THAT(intents, SizeIs(1));
EXPECT_EQ(intents[0].title_without_entity.value(), "Map");
EXPECT_EQ(intents[0].title_with_entity.value(), "333 E Wonderview Ave");
EXPECT_EQ(intents[0].description.value(), "Locate selected address");
EXPECT_EQ(intents[0].action.value(), "android.intent.action.VIEW");
- EXPECT_EQ(intents[0].data.value(), "geo:0,0?q=333+E+Wonderview+Ave");
+ EXPECT_EQ(intents[0].data.value(), "geo:0,0?q=333%20E%20Wonderview%20Ave");
}
TEST_F(IntentGeneratorTest, HandlesCallbacks) {
@@ -230,12 +229,13 @@
ClassificationResult classification = {"test", 1.0};
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
- JniHelper::NewStringUTF(g_jenv, "en-US").ValueOrDie().get(),
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
classification,
- /*reference_time_ms_utc=*/0, "this is a test", {0, 14}, g_context,
+ /*reference_time_ms_utc=*/0, "this is a test", {0, 14},
+ GetAndroidContext(),
/*annotations_entity_data_schema=*/nullptr, &intents));
EXPECT_THAT(intents, SizeIs(1));
- EXPECT_EQ(intents[0].data.value(), "encoded=this+is+a+test");
+ EXPECT_EQ(intents[0].data.value(), "encoded=this%20is%20a%20test");
EXPECT_THAT(intents[0].category, ElementsAre("test_category"));
EXPECT_THAT(intents[0].extra, SizeIs(6));
EXPECT_EQ(intents[0].extra["package"].ConstRefValue<std::string>(),
@@ -282,8 +282,8 @@
{annotation}};
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
- JniHelper::NewStringUTF(g_jenv, "en-US").ValueOrDie().get(), suggestion,
- conversation, g_context,
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
+ suggestion, conversation, GetAndroidContext(),
/*annotations_entity_data_schema=*/nullptr,
/*actions_entity_data_schema=*/nullptr, &intents));
EXPECT_THAT(intents, SizeIs(1));
@@ -323,8 +323,8 @@
{}};
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
- JniHelper::NewStringUTF(g_jenv, "en-US").ValueOrDie().get(), suggestion,
- conversation, g_context,
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
+ suggestion, conversation, GetAndroidContext(),
/*annotations_entity_data_schema=*/nullptr,
/*actions_entity_data_schema=*/nullptr, &intents));
EXPECT_THAT(intents, SizeIs(1));
@@ -372,8 +372,8 @@
{location_annotation, time_annotation}};
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
- JniHelper::NewStringUTF(g_jenv, "en-US").ValueOrDie().get(), suggestion,
- conversation, g_context,
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
+ suggestion, conversation, GetAndroidContext(),
/*annotations_entity_data_schema=*/nullptr,
/*actions_entity_data_schema=*/nullptr, &intents));
EXPECT_THAT(intents, SizeIs(1));
@@ -418,8 +418,8 @@
{from_annotation, to_annotation}};
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
- JniHelper::NewStringUTF(g_jenv, "en-US").ValueOrDie().get(), suggestion,
- conversation, g_context,
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
+ suggestion, conversation, GetAndroidContext(),
/*annotations_entity_data_schema=*/nullptr,
/*actions_entity_data_schema=*/nullptr, &intents));
EXPECT_THAT(intents, SizeIs(1));
@@ -446,15 +446,16 @@
ClassificationResult classification = {"address", 1.0};
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
- JniHelper::NewStringUTF(g_jenv, "de-DE").ValueOrDie().get(),
+ JniHelper::NewStringUTF(GetJenv(), "de-DE").ValueOrDie().get(),
classification,
- /*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20}, g_context,
+ /*reference_time_ms_utc=*/0, "333 E Wonderview Ave", {0, 20},
+ GetAndroidContext(),
/*annotations_entity_data_schema=*/nullptr, &intents));
EXPECT_THAT(intents, SizeIs(1));
EXPECT_EQ(intents[0].title_without_entity.value(), "Karte");
EXPECT_EQ(intents[0].description.value(), "Ausgewählte Adresse finden");
EXPECT_EQ(intents[0].action.value(), "android.intent.action.VIEW");
- EXPECT_EQ(intents[0].data.value(), "geo:0,0?q=333+E+Wonderview+Ave");
+ EXPECT_EQ(intents[0].data.value(), "geo:0,0?q=333%20E%20Wonderview%20Ave");
}
TEST_F(IntentGeneratorTest, HandlesIteration) {
@@ -491,8 +492,8 @@
{location_annotation, greeting_annotation}};
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
- JniHelper::NewStringUTF(g_jenv, "en-US").ValueOrDie().get(), suggestion,
- conversation, g_context,
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
+ suggestion, conversation, GetAndroidContext(),
/*annotations_entity_data_schema=*/nullptr,
/*actions_entity_data_schema=*/nullptr, &intents));
EXPECT_THAT(intents, SizeIs(1));
@@ -597,15 +598,15 @@
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
- JniHelper::NewStringUTF(g_jenv, "en-US").ValueOrDie().get(),
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
classification,
- /*reference_time_ms_utc=*/0, "highground", {0, 10}, g_context,
+ /*reference_time_ms_utc=*/0, "highground", {0, 10}, GetAndroidContext(),
/*annotations_entity_data_schema=*/entity_data_schema, &intents));
EXPECT_THAT(intents, SizeIs(1));
EXPECT_THAT(intents[0].extra, SizeIs(3));
EXPECT_EQ(intents[0].extra["name"].ConstRefValue<std::string>(), "kenobi");
EXPECT_EQ(intents[0].extra["encoded_phone"].ConstRefValue<std::string>(),
- "1+800+HIGHGROUND");
+ "1%20800%20HIGHGROUND");
EXPECT_EQ(intents[0].extra["age"].Value<int>(), 38);
}
@@ -635,9 +636,9 @@
std::vector<RemoteActionTemplate> intents;
EXPECT_TRUE(generator->GenerateIntents(
- JniHelper::NewStringUTF(g_jenv, "en-US").ValueOrDie().get(),
+ JniHelper::NewStringUTF(GetJenv(), "en-US").ValueOrDie().get(),
classification,
- /*reference_time_ms_utc=*/0, "test", {0, 4}, g_context,
+ /*reference_time_ms_utc=*/0, "test", {0, 4}, GetAndroidContext(),
/*annotations_entity_data_schema=*/nullptr, &intents));
EXPECT_THAT(intents, SizeIs(1));
diff --git a/native/utils/intents/jni-lua.cc b/native/utils/intents/jni-lua.cc
index f151f4d..71a466e 100644
--- a/native/utils/intents/jni-lua.cc
+++ b/native/utils/intents/jni-lua.cc
@@ -245,7 +245,7 @@
return 0;
}
- // Call Java URL encoder.
+ // Call Java Uri encode.
const StatusOr<ScopedLocalRef<jstring>> status_or_input_str =
jni_cache_->ConvertToJavaString(input);
if (!status_or_input_str.ok()) {
@@ -254,12 +254,17 @@
}
+ // uri_encode defaults to nullptr and may stay unset when android.net.Uri
+ // was not resolved by the JniCache (e.g. non-Android builds).
+ if (jni_cache_->uri_encode == nullptr) {
+ TC3_LOG(ERROR) << "Uri.encode is not available";
+ lua_error(state_);
+ return 0;
+ }
StatusOr<ScopedLocalRef<jstring>> status_or_encoded_str =
JniHelper::CallStaticObjectMethod<jstring>(
- jenv_, jni_cache_->urlencoder_class.get(),
- jni_cache_->urlencoder_encode, status_or_input_str.ValueOrDie().get(),
- jni_cache_->string_utf8.get());
+ jenv_, jni_cache_->uri_class.get(), jni_cache_->uri_encode,
+ status_or_input_str.ValueOrDie().get());
if (!status_or_encoded_str.ok()) {
- TC3_LOG(ERROR) << "Error calling UrlEncoder.encode";
+ TC3_LOG(ERROR) << "Error calling Uri.encode";
lua_error(state_);
return 0;
}
diff --git a/native/utils/java/jni-cache.cc b/native/utils/java/jni-cache.cc
index 58d3369..824141a 100644
--- a/native/utils/java/jni-cache.cc
+++ b/native/utils/java/jni-cache.cc
@@ -34,8 +34,7 @@
breakiterator_class(nullptr, jvm),
integer_class(nullptr, jvm),
calendar_class(nullptr, jvm),
- timezone_class(nullptr, jvm),
- urlencoder_class(nullptr, jvm)
+ timezone_class(nullptr, jvm)
#ifdef __ANDROID__
,
context_class(nullptr, jvm),
@@ -222,12 +221,6 @@
TC3_GET_STATIC_METHOD(timezone, get_timezone, "getTimeZone",
"(Ljava/lang/String;)Ljava/util/TimeZone;");
- // URLEncoder.
- TC3_GET_CLASS_OR_RETURN_NULL(urlencoder, "java/net/URLEncoder");
- TC3_GET_STATIC_METHOD(
- urlencoder, encode, "encode",
- "(Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;");
-
#ifdef __ANDROID__
// Context.
TC3_GET_CLASS_OR_RETURN_NULL(context, "android/content/Context");
@@ -242,6 +235,8 @@
"(Ljava/lang/String;)Landroid/net/Uri;");
TC3_GET_METHOD(uri, get_scheme, "getScheme", "()Ljava/lang/String;");
TC3_GET_METHOD(uri, get_host, "getHost", "()Ljava/lang/String;");
+ TC3_GET_STATIC_METHOD(uri, encode, "encode",
+ "(Ljava/lang/String;)Ljava/lang/String;");
// UserManager.
TC3_GET_OPTIONAL_CLASS(usermanager, "android/os/UserManager");
diff --git a/native/utils/java/jni-cache.h b/native/utils/java/jni-cache.h
index ab48419..8754f4c 100644
--- a/native/utils/java/jni-cache.h
+++ b/native/utils/java/jni-cache.h
@@ -107,10 +107,6 @@
ScopedGlobalRef<jclass> timezone_class;
jmethodID timezone_get_timezone = nullptr;
- // java.net.URLEncoder
- ScopedGlobalRef<jclass> urlencoder_class;
- jmethodID urlencoder_encode = nullptr;
-
// android.content.Context
ScopedGlobalRef<jclass> context_class;
jmethodID context_get_package_name = nullptr;
@@ -121,6 +117,7 @@
jmethodID uri_parse = nullptr;
jmethodID uri_get_scheme = nullptr;
jmethodID uri_get_host = nullptr;
+ jmethodID uri_encode = nullptr;
// android.os.UserManager
ScopedGlobalRef<jclass> usermanager_class;
diff --git a/native/utils/jvm-test-utils.h b/native/utils/jvm-test-utils.h
new file mode 100644
index 0000000..4600248
--- /dev/null
+++ b/native/utils/jvm-test-utils.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_JVM_TEST_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_JVM_TEST_UTILS_H_
+
+#include <memory>
+
+#include "utils/calendar/calendar.h"
+#include "utils/utf8/unilib.h"
+
+#if defined(__ANDROID__)
+#include <jni.h>
+
+// Use these with jvm_test_launcher which provides these global variables.
+extern JNIEnv* g_jenv;
+extern jobject g_context;
+#endif
+
+namespace libtextclassifier3 {
+// Creates a UniLib for tests; it is constructed from a JniCache over g_jenv
+// when TC3_UNILIB_JAVAICU is defined, and default-constructed otherwise.
+inline std::unique_ptr<UniLib> CreateUniLibForTesting() {
+#if defined TC3_UNILIB_JAVAICU
+ return std::make_unique<UniLib>(JniCache::Create(g_jenv));
+#else
+ return std::make_unique<UniLib>();
+#endif
+}
+
+// Creates a CalendarLib for tests; it is constructed from a JniCache over
+// g_jenv when TC3_CALENDAR_JAVAICU is defined, and default-constructed otherwise.
+inline std::unique_ptr<CalendarLib> CreateCalendarLibForTesting() {
+#if defined TC3_CALENDAR_JAVAICU
+ return std::make_unique<CalendarLib>(JniCache::Create(g_jenv));
+#else
+ return std::make_unique<CalendarLib>();
+#endif
+}
+
+#if defined(__ANDROID__)
+// Returns the JNIEnv provided by the test launcher (g_jenv).
+inline JNIEnv* GetJenv() { return g_jenv; }
+
+// Returns the Android context provided by the test launcher (g_context).
+inline jobject GetAndroidContext() { return g_context; }
+#endif
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_JVM_TEST_UTILS_H_
diff --git a/native/utils/lua-utils.cc b/native/utils/lua-utils.cc
index 0a7b33f..9117c54 100644
--- a/native/utils/lua-utils.cc
+++ b/native/utils/lua-utils.cc
@@ -81,6 +81,8 @@
const reflection::BaseType field_type = field->type()->base_type();
switch (field_type) {
case reflection::Bool:
+ Push(table->GetField<bool>(field->offset(), field->default_integer()));
+ break;
case reflection::UByte:
Push(table->GetField<uint8>(field->offset(), field->default_integer()));
break;
@@ -93,12 +95,6 @@
case reflection::UInt:
Push(table->GetField<uint32>(field->offset(), field->default_integer()));
break;
- case reflection::Short:
- Push(table->GetField<int16>(field->offset(), field->default_integer()));
- break;
- case reflection::UShort:
- Push(table->GetField<uint16>(field->offset(), field->default_integer()));
- break;
case reflection::Long:
Push(table->GetField<int64>(field->offset(), field->default_integer()));
break;
@@ -139,6 +135,9 @@
}
switch (field->type()->element()) {
case reflection::Bool:
+ PushRepeatedField(table->GetPointer<const flatbuffers::Vector<bool>*>(
+ field->offset()));
+ break;
case reflection::UByte:
PushRepeatedField(
table->GetPointer<const flatbuffers::Vector<uint8>*>(
@@ -158,16 +157,6 @@
table->GetPointer<const flatbuffers::Vector<uint32>*>(
field->offset()));
break;
- case reflection::Short:
- PushRepeatedField(
- table->GetPointer<const flatbuffers::Vector<int16>*>(
- field->offset()));
- break;
- case reflection::UShort:
- PushRepeatedField(
- table->GetPointer<const flatbuffers::Vector<uint16>*>(
- field->offset()));
- break;
case reflection::Long:
PushRepeatedField(
table->GetPointer<const flatbuffers::Vector<int64>*>(
index 98c451c..a76c790 100644
--- a/native/utils/lua-utils.h
+++ b/native/utils/lua-utils.h
@@ -137,42 +137,42 @@
template <>
int64 Read<int64>(const int index) const {
- return static_cast<int64>(lua_tonumber(state_, /*idx=*/index));
+ return static_cast<int64>(lua_tointeger(state_, /*idx=*/index));
}
template <>
uint64 Read<uint64>(const int index) const {
- return static_cast<uint64>(lua_tonumber(state_, /*idx=*/index));
+ return static_cast<uint64>(lua_tointeger(state_, /*idx=*/index));
}
template <>
int32 Read<int32>(const int index) const {
- return static_cast<int32>(lua_tonumber(state_, /*idx=*/index));
+ return static_cast<int32>(lua_tointeger(state_, /*idx=*/index));
}
template <>
uint32 Read<uint32>(const int index) const {
- return static_cast<uint32>(lua_tonumber(state_, /*idx=*/index));
+ return static_cast<uint32>(lua_tointeger(state_, /*idx=*/index));
}
template <>
int16 Read<int16>(const int index) const {
- return static_cast<int16>(lua_tonumber(state_, /*idx=*/index));
+ return static_cast<int16>(lua_tointeger(state_, /*idx=*/index));
}
template <>
uint16 Read<uint16>(const int index) const {
- return static_cast<uint16>(lua_tonumber(state_, /*idx=*/index));
+ return static_cast<uint16>(lua_tointeger(state_, /*idx=*/index));
}
template <>
int8 Read<int8>(const int index) const {
- return static_cast<int8>(lua_tonumber(state_, /*idx=*/index));
+ return static_cast<int8>(lua_tointeger(state_, /*idx=*/index));
}
template <>
uint8 Read<uint8>(const int index) const {
- return static_cast<uint8>(lua_tonumber(state_, /*idx=*/index));
+ return static_cast<uint8>(lua_tointeger(state_, /*idx=*/index));
}
template <>
@@ -507,7 +507,7 @@
// Reads a repeated field from lua.
template <typename T>
void ReadRepeatedField(const int index, RepeatedField* result) const {
- for (const auto& element : ReadVector<T>(index)) {
+ for (const T& element : ReadVector<T>(index)) {
result->Add(element);
}
}
diff --git a/native/utils/lua-utils_test.cc b/native/utils/lua-utils_test.cc
index b4f6181..22a8a87 100644
--- a/native/utils/lua-utils_test.cc
+++ b/native/utils/lua-utils_test.cc
@@ -16,97 +16,33 @@
#include "utils/lua-utils.h"
+#include <algorithm>
+#include <memory>
#include <string>
+#include <vector>
#include "utils/flatbuffers/flatbuffers.h"
#include "utils/flatbuffers/mutable.h"
+#include "utils/lua_utils_tests_generated.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/test-data-test-utils.h"
+#include "utils/testing/test_data_generator.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace libtextclassifier3 {
namespace {
+using testing::DoubleEq;
using testing::ElementsAre;
using testing::Eq;
using testing::FloatEq;
-std::string TestFlatbufferSchema() {
- // Creates a test schema for flatbuffer passing tests.
- // Cannot use the object oriented API here as that is not available for the
- // reflection schema.
- flatbuffers::FlatBufferBuilder schema_builder;
- std::vector<flatbuffers::Offset<reflection::Field>> fields = {
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("float_field"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::Float),
- /*id=*/0,
- /*offset=*/4),
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("nested_field"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::Obj,
- /*element=*/reflection::None,
- /*index=*/0 /* self */),
- /*id=*/1,
- /*offset=*/6),
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("repeated_nested_field"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::Vector,
- /*element=*/reflection::Obj,
- /*index=*/0 /* self */),
- /*id=*/2,
- /*offset=*/8),
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("repeated_string_field"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::Vector,
- /*element=*/reflection::String),
- /*id=*/3,
- /*offset=*/10),
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("string_field"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::String),
- /*id=*/4,
- /*offset=*/12)};
-
- std::vector<flatbuffers::Offset<reflection::Enum>> enums;
- std::vector<flatbuffers::Offset<reflection::Object>> objects = {
- reflection::CreateObject(
- schema_builder,
- /*name=*/schema_builder.CreateString("TestData"),
- /*fields=*/
- schema_builder.CreateVectorOfSortedTables(&fields))};
- schema_builder.Finish(reflection::CreateSchema(
- schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
- schema_builder.CreateVectorOfSortedTables(&enums),
- /*(unused) file_ident=*/0,
- /*(unused) file_ext=*/0,
- /*root_table*/ objects[0]));
- return std::string(
- reinterpret_cast<const char*>(schema_builder.GetBufferPointer()),
- schema_builder.GetSize());
-}
-
class LuaUtilsTest : public testing::Test, protected LuaEnvironment {
protected:
LuaUtilsTest()
- : serialized_flatbuffer_schema_(TestFlatbufferSchema()),
- schema_(flatbuffers::GetRoot<reflection::Schema>(
- serialized_flatbuffer_schema_.data())),
- flatbuffer_builder_(schema_) {
+ : schema_(GetTestFileContent("utils/lua_utils_tests.bfbs")),
+ flatbuffer_builder_(schema_.get()) {
EXPECT_THAT(RunProtected([this] {
LoadDefaultLibraries();
return LUA_OK;
@@ -123,63 +59,121 @@
Eq(LUA_OK));
}
- const std::string serialized_flatbuffer_schema_;
- const reflection::Schema* schema_;
+ OwnedFlatbuffer<reflection::Schema, std::string> schema_;
MutableFlatbufferBuilder flatbuffer_builder_;
+ TestDataGenerator test_data_generator_;
};
-TEST_F(LuaUtilsTest, HandlesVectors) {
- {
- PushVector(std::vector<int64>{1, 2, 3, 4, 5});
- EXPECT_THAT(ReadVector<int64>(), ElementsAre(1, 2, 3, 4, 5));
- }
- {
- PushVector(std::vector<std::string>{"hello", "there"});
- EXPECT_THAT(ReadVector<std::string>(), ElementsAre("hello", "there"));
- }
- {
- PushVector(std::vector<bool>{true, true, false});
- EXPECT_THAT(ReadVector<bool>(), ElementsAre(true, true, false));
- }
+template <typename T>
+class TypedLuaUtilsTest : public LuaUtilsTest {};
+
+using testing::Types;
+using LuaTypes =
+ ::testing::Types<int64, uint64, int32, uint32, int16, uint16, int8, uint8,
+ float, double, bool, std::string>;
+TYPED_TEST_SUITE(TypedLuaUtilsTest, LuaTypes);
+
+TYPED_TEST(TypedLuaUtilsTest, HandlesVectors) {
+ std::vector<TypeParam> elements(5);
+ std::generate(elements.begin(), elements.end(), [&]() {
+ return this->test_data_generator_.template generate<TypeParam>();
+ });
+
+ this->PushVector(elements);
+
+ EXPECT_THAT(this->template ReadVector<TypeParam>(),
+ testing::ContainerEq(elements));
}
-TEST_F(LuaUtilsTest, HandlesVectorIterators) {
- {
- const std::vector<int64> elements = {1, 2, 3, 4, 5};
- PushVectorIterator(&elements);
- EXPECT_THAT(ReadVector<int64>(), ElementsAre(1, 2, 3, 4, 5));
- }
- {
- const std::vector<std::string> elements = {"hello", "there"};
- PushVectorIterator(&elements);
- EXPECT_THAT(ReadVector<std::string>(), ElementsAre("hello", "there"));
- }
- {
- const std::vector<bool> elements = {true, true, false};
- PushVectorIterator(&elements);
- EXPECT_THAT(ReadVector<bool>(), ElementsAre(true, true, false));
- }
+TYPED_TEST(TypedLuaUtilsTest, HandlesVectorIterators) {
+ std::vector<TypeParam> elements(5);
+ std::generate_n(elements.begin(), 5, [&]() {
+ return this->test_data_generator_.template generate<TypeParam>();
+ });
+
+ this->PushVectorIterator(&elements);
+
+ EXPECT_THAT(this->template ReadVector<TypeParam>(),
+ testing::ContainerEq(elements));
}
-TEST_F(LuaUtilsTest, ReadsFlatbufferResults) {
+TEST_F(LuaUtilsTest, PushAndReadsFlatbufferRoundTrip) {
// Setup.
+ test::TestDataT input_data;
+ input_data.byte_field = 1;
+ input_data.ubyte_field = 2;
+ input_data.int_field = 10;
+ input_data.uint_field = 11;
+ input_data.long_field = 20;
+ input_data.ulong_field = 21;
+ input_data.bool_field = true;
+ input_data.float_field = 42.1;
+ input_data.double_field = 12.4;
+ input_data.string_field = "hello there";
+ // Nested field.
+ input_data.nested_field = std::make_unique<test::TestDataT>();
+ input_data.nested_field->float_field = 64;
+ input_data.nested_field->string_field = "hello nested";
+ // Repeated fields.
+ input_data.repeated_byte_field = {1, 2, 1};
+ input_data.repeated_byte_field = {1, 2, 1};
+ input_data.repeated_ubyte_field = {2, 4, 2};
+ input_data.repeated_int_field = {1, 2, 3};
+ input_data.repeated_uint_field = {2, 4, 6};
+ input_data.repeated_long_field = {4, 5, 6};
+ input_data.repeated_ulong_field = {8, 10, 12};
+ input_data.repeated_bool_field = {true, false, true};
+ input_data.repeated_float_field = {1.23, 2.34, 3.45};
+ input_data.repeated_double_field = {1.11, 2.22, 3.33};
+ input_data.repeated_string_field = {"a", "bold", "one"};
+ // Repeated nested fields.
+ input_data.repeated_nested_field.push_back(
+ std::make_unique<test::TestDataT>());
+ input_data.repeated_nested_field.back()->string_field = "a";
+ input_data.repeated_nested_field.push_back(
+ std::make_unique<test::TestDataT>());
+ input_data.repeated_nested_field.back()->string_field = "b";
+ input_data.repeated_nested_field.push_back(
+ std::make_unique<test::TestDataT>());
+ input_data.repeated_nested_field.back()->repeated_string_field = {"nested",
+ "nested2"};
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(test::TestData::Pack(builder, &input_data));
+ const flatbuffers::DetachedBuffer input_buffer = builder.Release();
+ PushFlatbuffer(schema_.get(),
+ flatbuffers::GetRoot<flatbuffers::Table>(input_buffer.data()));
+ lua_setglobal(state_, "arg");
+
RunScript(R"lua(
return {
- float_field = 42.1,
- string_field = "hello there",
-
- -- Nested field.
+ byte_field = arg.byte_field,
+ ubyte_field = arg.ubyte_field,
+ int_field = arg.int_field,
+ uint_field = arg.uint_field,
+ long_field = arg.long_field,
+ ulong_field = arg.ulong_field,
+ bool_field = arg.bool_field,
+ float_field = arg.float_field,
+ double_field = arg.double_field,
+ string_field = arg.string_field,
nested_field = {
- float_field = 64,
- string_field = "hello nested",
+ float_field = arg.nested_field.float_field,
+ string_field = arg.nested_field.string_field,
},
-
- -- Repeated fields.
- repeated_string_field = { "a", "bold", "one" },
+ repeated_byte_field = arg.repeated_byte_field,
+ repeated_ubyte_field = arg.repeated_ubyte_field,
+ repeated_int_field = arg.repeated_int_field,
+ repeated_uint_field = arg.repeated_uint_field,
+ repeated_long_field = arg.repeated_long_field,
+ repeated_ulong_field = arg.repeated_ulong_field,
+ repeated_bool_field = arg.repeated_bool_field,
+ repeated_float_field = arg.repeated_float_field,
+ repeated_double_field = arg.repeated_double_field,
+ repeated_string_field = arg.repeated_string_field,
repeated_nested_field = {
- { string_field = "a" },
- { string_field = "b" },
- { repeated_string_field = { "nested", "nested2" } },
+ { string_field = arg.repeated_nested_field[1].string_field },
+ { string_field = arg.repeated_nested_field[2].string_field },
+ { repeated_string_field = arg.repeated_nested_field[3].repeated_string_field },
},
}
)lua");
@@ -188,89 +180,38 @@
std::unique_ptr<MutableFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
ReadFlatbuffer(/*index=*/-1, buffer.get());
const std::string serialized_buffer = buffer->Serialize();
+ std::unique_ptr<test::TestDataT> test_data =
+ LoadAndVerifyMutableFlatbuffer<test::TestData>(buffer->Serialize());
- // Check fields. As we do not have flatbuffer compiled generated code for the
- // ad hoc generated test schema, we have to read by manually using field
- // offsets.
- const flatbuffers::Table* flatbuffer_data =
- flatbuffers::GetRoot<flatbuffers::Table>(serialized_buffer.data());
- EXPECT_THAT(flatbuffer_data->GetField<float>(/*field=*/4, /*defaultval=*/0),
- FloatEq(42.1));
- EXPECT_THAT(
- flatbuffer_data->GetPointer<const flatbuffers::String*>(/*field=*/12)
- ->str(),
- "hello there");
-
- // Read the nested field.
- const flatbuffers::Table* nested_field =
- flatbuffer_data->GetPointer<const flatbuffers::Table*>(/*field=*/6);
- EXPECT_THAT(nested_field->GetField<float>(/*field=*/4, /*defaultval=*/0),
- FloatEq(64));
- EXPECT_THAT(
- nested_field->GetPointer<const flatbuffers::String*>(/*field=*/12)->str(),
- "hello nested");
-
- // Read the repeated string field.
- auto repeated_strings = flatbuffer_data->GetPointer<
- flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*>(
- /*field=*/10);
- EXPECT_THAT(repeated_strings->size(), Eq(3));
- EXPECT_THAT(repeated_strings->GetAsString(0)->str(), Eq("a"));
- EXPECT_THAT(repeated_strings->GetAsString(1)->str(), Eq("bold"));
- EXPECT_THAT(repeated_strings->GetAsString(2)->str(), Eq("one"));
-
- // Read the repeated nested field.
- auto repeated_nested_fields = flatbuffer_data->GetPointer<
- flatbuffers::Vector<flatbuffers::Offset<flatbuffers::Table>>*>(
- /*field=*/8);
- EXPECT_THAT(repeated_nested_fields->size(), Eq(3));
- EXPECT_THAT(repeated_nested_fields->Get(0)
- ->GetPointer<const flatbuffers::String*>(/*field=*/12)
- ->str(),
- "a");
- EXPECT_THAT(repeated_nested_fields->Get(1)
- ->GetPointer<const flatbuffers::String*>(/*field=*/12)
- ->str(),
- "b");
-}
-
-TEST_F(LuaUtilsTest, HandlesSimpleFlatbufferFields) {
- // Create test flatbuffer.
- std::unique_ptr<MutableFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
- buffer->Set("float_field", 42.f);
- const std::string serialized_buffer = buffer->Serialize();
- PushFlatbuffer(schema_, flatbuffers::GetRoot<flatbuffers::Table>(
- serialized_buffer.data()));
- lua_setglobal(state_, "arg");
-
- // Setup.
- RunScript(R"lua(
- return arg.float_field
- )lua");
-
- EXPECT_THAT(Read<float>(), FloatEq(42));
-}
-
-TEST_F(LuaUtilsTest, HandlesRepeatedFlatbufferFields) {
- // Create test flatbuffer.
- std::unique_ptr<MutableFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
- RepeatedField* repeated_field = buffer->Repeated("repeated_string_field");
- repeated_field->Add("this");
- repeated_field->Add("is");
- repeated_field->Add("a");
- repeated_field->Add("test");
- const std::string serialized_buffer = buffer->Serialize();
- PushFlatbuffer(schema_, flatbuffers::GetRoot<flatbuffers::Table>(
- serialized_buffer.data()));
- lua_setglobal(state_, "arg");
-
- // Return flatbuffer repeated field as vector.
- RunScript(R"lua(
- return arg.repeated_string_field
- )lua");
-
- EXPECT_THAT(ReadVector<std::string>(),
- ElementsAre("this", "is", "a", "test"));
+ EXPECT_THAT(test_data->byte_field, 1);
+ EXPECT_THAT(test_data->ubyte_field, 2);
+ EXPECT_THAT(test_data->int_field, 10);
+ EXPECT_THAT(test_data->uint_field, 11);
+ EXPECT_THAT(test_data->long_field, 20);
+ EXPECT_THAT(test_data->ulong_field, 21);
+ EXPECT_THAT(test_data->bool_field, true);
+ EXPECT_THAT(test_data->float_field, FloatEq(42.1));
+ EXPECT_THAT(test_data->double_field, DoubleEq(12.4));
+ EXPECT_THAT(test_data->string_field, "hello there");
+ EXPECT_THAT(test_data->repeated_byte_field, ElementsAre(1, 2, 1));
+ EXPECT_THAT(test_data->repeated_ubyte_field, ElementsAre(2, 4, 2));
+ EXPECT_THAT(test_data->repeated_int_field, ElementsAre(1, 2, 3));
+ EXPECT_THAT(test_data->repeated_uint_field, ElementsAre(2, 4, 6));
+ EXPECT_THAT(test_data->repeated_long_field, ElementsAre(4, 5, 6));
+ EXPECT_THAT(test_data->repeated_ulong_field, ElementsAre(8, 10, 12));
+ EXPECT_THAT(test_data->repeated_bool_field, ElementsAre(true, false, true));
+ EXPECT_THAT(test_data->repeated_float_field, ElementsAre(1.23, 2.34, 3.45));
+ EXPECT_THAT(test_data->repeated_double_field, ElementsAre(1.11, 2.22, 3.33));
+ EXPECT_THAT(test_data->repeated_string_field,
+ ElementsAre("a", "bold", "one"));
+ // Nested fields.
+ EXPECT_THAT(test_data->nested_field->float_field, FloatEq(64));
+ EXPECT_THAT(test_data->nested_field->string_field, "hello nested");
+ // Repeated nested fields.
+ EXPECT_THAT(test_data->repeated_nested_field[0]->string_field, "a");
+ EXPECT_THAT(test_data->repeated_nested_field[1]->string_field, "b");
+ EXPECT_THAT(test_data->repeated_nested_field[2]->repeated_string_field,
+ ElementsAre("nested", "nested2"));
}
TEST_F(LuaUtilsTest, HandlesRepeatedNestedFlatbufferFields) {
@@ -287,8 +228,8 @@
nested_repeated->Add("are");
repeated_field->Add()->Set("string_field", "you?");
const std::string serialized_buffer = buffer->Serialize();
- PushFlatbuffer(schema_, flatbuffers::GetRoot<flatbuffers::Table>(
- serialized_buffer.data()));
+ PushFlatbuffer(schema_.get(), flatbuffers::GetRoot<flatbuffers::Table>(
+ serialized_buffer.data()));
lua_setglobal(state_, "arg");
RunScript(R"lua(
@@ -312,15 +253,15 @@
std::unique_ptr<MutableFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
buffer->Set("string_field", "first");
const std::string serialized_buffer = buffer->Serialize();
- PushFlatbuffer(schema_, flatbuffers::GetRoot<flatbuffers::Table>(
- serialized_buffer.data()));
+ PushFlatbuffer(schema_.get(), flatbuffers::GetRoot<flatbuffers::Table>(
+ serialized_buffer.data()));
lua_setglobal(state_, "arg");
// The second flatbuffer.
std::unique_ptr<MutableFlatbuffer> buffer2 = flatbuffer_builder_.NewRoot();
buffer2->Set("string_field", "second");
const std::string serialized_buffer2 = buffer2->Serialize();
- PushFlatbuffer(schema_, flatbuffers::GetRoot<flatbuffers::Table>(
- serialized_buffer2.data()));
+ PushFlatbuffer(schema_.get(), flatbuffers::GetRoot<flatbuffers::Table>(
+ serialized_buffer2.data()));
lua_setglobal(state_, "arg2");
RunScript(R"lua(
diff --git a/native/utils/lua_utils_tests.bfbs b/native/utils/lua_utils_tests.bfbs
new file mode 100644
index 0000000..acb731b
--- /dev/null
+++ b/native/utils/lua_utils_tests.bfbs
Binary files differ
diff --git a/native/utils/lua_utils_tests.fbs b/native/utils/lua_utils_tests.fbs
new file mode 100644
index 0000000..6d8ad38
--- /dev/null
+++ b/native/utils/lua_utils_tests.fbs
@@ -0,0 +1,45 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+namespace libtextclassifier3.test;
+
+table TestData {
+ byte_field: byte;
+ ubyte_field: ubyte;
+ int_field: int;
+ uint_field: uint;
+ long_field: int64;
+ ulong_field: uint64;
+ bool_field: bool;
+ float_field: float;
+ double_field: double;
+ string_field: string;
+ nested_field: TestData;
+
+ repeated_byte_field: [byte];
+ repeated_ubyte_field: [ubyte];
+ repeated_int_field: [int];
+ repeated_uint_field: [uint];
+ repeated_long_field: [int64];
+ repeated_ulong_field: [uint64];
+ repeated_bool_field: [bool];
+ repeated_float_field: [float];
+ repeated_double_field: [double];
+ repeated_string_field: [string];
+ repeated_nested_field: [TestData];
+}
+
+root_type libtextclassifier3.test.TestData;
diff --git a/native/utils/resources.fbs b/native/utils/resources.fbs
index b4d9b83..0a05718 100755
--- a/native/utils/resources.fbs
+++ b/native/utils/resources.fbs
@@ -14,8 +14,8 @@
// limitations under the License.
//
-include "utils/i18n/language-tag.fbs";
include "utils/zlib/buffer.fbs";
+include "utils/i18n/language-tag.fbs";
namespace libtextclassifier3;
table Resource {
diff --git a/native/utils/test-data-test-utils.h b/native/utils/test-data-test-utils.h
index 8bafbeb..61f6d97 100644
--- a/native/utils/test-data-test-utils.h
+++ b/native/utils/test-data-test-utils.h
@@ -18,16 +18,21 @@
#ifndef LIBTEXTCLASSIFIER_UTILS_TEST_DATA_TEST_UTILS_H_
#define LIBTEXTCLASSIFIER_UTILS_TEST_DATA_TEST_UTILS_H_
+#include <fstream>
#include "gtest/gtest.h"
-#include "android-base/file.h"
namespace libtextclassifier3 {
// Get the file path to the test data.
inline std::string GetTestDataPath(const std::string& relative_path) {
- return android::base::GetExecutableDirectory() + "/" +
- relative_path;
+ return "/data/local/tmp/" + relative_path;
+}
+
+inline std::string GetTestFileContent(const std::string& relative_path) {
+ const std::string full_path = GetTestDataPath(relative_path);
+ std::ifstream file_stream(full_path);
+ return std::string(std::istreambuf_iterator<char>(file_stream), {});
}
} // namespace libtextclassifier3
diff --git a/native/utils/testing/annotator.cc b/native/utils/testing/annotator.cc
new file mode 100644
index 0000000..47c3fb7
--- /dev/null
+++ b/native/utils/testing/annotator.cc
@@ -0,0 +1,206 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/testing/annotator.h"
+
+#include "utils/flatbuffers/mutable.h"
+#include "flatbuffers/reflection.h"
+
+namespace libtextclassifier3 {
+
+std::string FirstResult(const std::vector<ClassificationResult>& results) {
+ if (results.empty()) {
+ return "<INVALID RESULTS>";
+ }
+ return results[0].collection;
+}
+
+std::string ReadFile(const std::string& file_name) {
+ std::ifstream file_stream(file_name);
+ return std::string(std::istreambuf_iterator<char>(file_stream), {});
+}
+
+std::unique_ptr<RegexModel_::PatternT> MakePattern(
+ const std::string& collection_name, const std::string& pattern,
+ const bool enabled_for_classification, const bool enabled_for_selection,
+ const bool enabled_for_annotation, const float score,
+ const float priority_score) {
+ std::unique_ptr<RegexModel_::PatternT> result(new RegexModel_::PatternT);
+ result->collection_name = collection_name;
+ result->pattern = pattern;
+ // We cannot directly operate with |= on the flag, so use an int here.
+ int enabled_modes = ModeFlag_NONE;
+ if (enabled_for_annotation) enabled_modes |= ModeFlag_ANNOTATION;
+ if (enabled_for_classification) enabled_modes |= ModeFlag_CLASSIFICATION;
+ if (enabled_for_selection) enabled_modes |= ModeFlag_SELECTION;
+ result->enabled_modes = static_cast<ModeFlag>(enabled_modes);
+ result->target_classification_score = score;
+ result->priority_score = priority_score;
+ return result;
+}
+
+// Shortcut function that doesn't need to specify the priority score.
+std::unique_ptr<RegexModel_::PatternT> MakePattern(
+ const std::string& collection_name, const std::string& pattern,
+ const bool enabled_for_classification, const bool enabled_for_selection,
+ const bool enabled_for_annotation, const float score) {
+ return MakePattern(collection_name, pattern, enabled_for_classification,
+ enabled_for_selection, enabled_for_annotation,
+ /*score=*/score,
+ /*priority_score=*/score);
+}
+
+void AddTestRegexModel(ModelT* unpacked_model) {
+ // Add test regex models.
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "person_with_age", "(Barack) (?:(Obama) )?is (\\d+) years old",
+ /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true, 1.0));
+
+ // Use meta data to generate custom serialized entity data.
+ MutableFlatbufferBuilder entity_data_builder(
+ flatbuffers::GetRoot<reflection::Schema>(
+ unpacked_model->entity_data_schema.data()));
+ RegexModel_::PatternT* pattern =
+ unpacked_model->regex_model->patterns.back().get();
+
+ {
+ std::unique_ptr<MutableFlatbuffer> entity_data =
+ entity_data_builder.NewRoot();
+ entity_data->Set("is_alive", true);
+ pattern->serialized_entity_data = entity_data->Serialize();
+ }
+ pattern->capturing_group.emplace_back(new CapturingGroupT);
+ pattern->capturing_group.emplace_back(new CapturingGroupT);
+ pattern->capturing_group.emplace_back(new CapturingGroupT);
+ pattern->capturing_group.emplace_back(new CapturingGroupT);
+ // Group 0 is the full match, capturing groups starting at 1.
+ pattern->capturing_group[1]->entity_field_path.reset(
+ new FlatbufferFieldPathT);
+ pattern->capturing_group[1]->entity_field_path->field.emplace_back(
+ new FlatbufferFieldT);
+ pattern->capturing_group[1]->entity_field_path->field.back()->field_name =
+ "first_name";
+ pattern->capturing_group[2]->entity_field_path.reset(
+ new FlatbufferFieldPathT);
+ pattern->capturing_group[2]->entity_field_path->field.emplace_back(
+ new FlatbufferFieldT);
+ pattern->capturing_group[2]->entity_field_path->field.back()->field_name =
+ "last_name";
+ // Set `former_us_president` field if we match Obama.
+ {
+ std::unique_ptr<MutableFlatbuffer> entity_data =
+ entity_data_builder.NewRoot();
+ entity_data->Set("former_us_president", true);
+ pattern->capturing_group[2]->serialized_entity_data =
+ entity_data->Serialize();
+ }
+ pattern->capturing_group[3]->entity_field_path.reset(
+ new FlatbufferFieldPathT);
+ pattern->capturing_group[3]->entity_field_path->field.emplace_back(
+ new FlatbufferFieldT);
+ pattern->capturing_group[3]->entity_field_path->field.back()->field_name =
+ "age";
+}
+
+std::string CreateEmptyModel(
+ const std::function<void(ModelT* model)> model_update_fn) {
+ ModelT model;
+ model_update_fn(&model);
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, &model));
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+// Create fake entity data schema meta data.
+void AddTestEntitySchemaData(ModelT* unpacked_model) {
+ // Cannot use object oriented API here as that is not available for the
+ // reflection schema.
+ flatbuffers::FlatBufferBuilder schema_builder;
+ std::vector<flatbuffers::Offset<reflection::Field>> fields = {
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("first_name"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::String),
+ /*id=*/0,
+ /*offset=*/4),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("is_alive"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::Bool),
+ /*id=*/1,
+ /*offset=*/6),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("last_name"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::String),
+ /*id=*/2,
+ /*offset=*/8),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("age"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::Int),
+ /*id=*/3,
+ /*offset=*/10),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("former_us_president"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::Bool),
+ /*id=*/4,
+ /*offset=*/12)};
+ std::vector<flatbuffers::Offset<reflection::Enum>> enums;
+ std::vector<flatbuffers::Offset<reflection::Object>> objects = {
+ reflection::CreateObject(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("EntityData"),
+ /*fields=*/
+ schema_builder.CreateVectorOfSortedTables(&fields))};
+ schema_builder.Finish(reflection::CreateSchema(
+ schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
+ schema_builder.CreateVectorOfSortedTables(&enums),
+ /*(unused) file_ident=*/0,
+ /*(unused) file_ext=*/0,
+ /*root_table*/ objects[0]));
+
+ unpacked_model->entity_data_schema.assign(
+ schema_builder.GetBufferPointer(),
+ schema_builder.GetBufferPointer() + schema_builder.GetSize());
+}
+
+AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span,
+ const std::string& collection,
+ const float score,
+ AnnotatedSpan::Source source) {
+ AnnotatedSpan result;
+ result.span = span;
+ result.classification.push_back({collection, score});
+ result.source = source;
+ return result;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/testing/annotator.h b/native/utils/testing/annotator.h
new file mode 100644
index 0000000..794e55f
--- /dev/null
+++ b/native/utils/testing/annotator.h
@@ -0,0 +1,83 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Helper utilities for testing Annotator.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TESTING_ANNOTATOR_H_
+#define LIBTEXTCLASSIFIER_UTILS_TESTING_ANNOTATOR_H_
+
+#include <fstream>
+#include <memory>
+#include <string>
+
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3 {
+
+// Loads FlatBuffer model, unpacks it and passes it to the visitor_fn so that it
+// can modify it. Afterwards the modified unpacked model is serialized back to a
+// flatbuffer.
+template <typename Fn>
+std::string ModifyAnnotatorModel(const std::string& model_flatbuffer,
+ Fn visitor_fn) {
+ std::unique_ptr<ModelT> unpacked_model =
+ UnPackModel(model_flatbuffer.c_str());
+
+ visitor_fn(unpacked_model.get());
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ return std::string(reinterpret_cast<char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+std::string FirstResult(const std::vector<ClassificationResult>& results);
+
+std::string ReadFile(const std::string& file_name);
+
+std::unique_ptr<RegexModel_::PatternT> MakePattern(
+ const std::string& collection_name, const std::string& pattern,
+ const bool enabled_for_classification, const bool enabled_for_selection,
+ const bool enabled_for_annotation, const float score,
+ const float priority_score);
+
+// Shortcut function that doesn't need to specify the priority score.
+std::unique_ptr<RegexModel_::PatternT> MakePattern(
+ const std::string& collection_name, const std::string& pattern,
+ const bool enabled_for_classification, const bool enabled_for_selection,
+ const bool enabled_for_annotation, const float score);
+
+void AddTestRegexModel(ModelT* unpacked_model);
+
+// Creates empty serialized Annotator model. Takes optional function that can
+// modify the unpacked ModelT before the serialization.
+std::string CreateEmptyModel(
+ const std::function<void(ModelT* model)> model_update_fn =
+ [](ModelT* model) {});
+
+// Create fake entity data schema meta data.
+void AddTestEntitySchemaData(ModelT* unpacked_model);
+
+AnnotatedSpan MakeAnnotatedSpan(
+ CodepointSpan span, const std::string& collection, const float score,
+ AnnotatedSpan::Source source = AnnotatedSpan::Source::OTHER);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TESTING_ANNOTATOR_H_
diff --git a/native/utils/testing/test_data_generator.h b/native/utils/testing/test_data_generator.h
new file mode 100644
index 0000000..30c7aed
--- /dev/null
+++ b/native/utils/testing/test_data_generator.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TESTING_TEST_DATA_GENERATOR_H_
+#define LIBTEXTCLASSIFIER_UTILS_TESTING_TEST_DATA_GENERATOR_H_
+
+#include <algorithm>
+#include <iostream>
+#include <random>
+
+#include "utils/strings/stringpiece.h"
+
+// Generates test data randomly.
+class TestDataGenerator {
+ public:
+ explicit TestDataGenerator() : random_engine_(0) {}
+
+ template <typename T,
+ typename std::enable_if_t<std::is_integral<T>::value>* = nullptr>
+ T generate() {
+ std::uniform_int_distribution<T> dist;
+ return dist(random_engine_);
+ }
+
+ template <typename T, typename std::enable_if_t<
+ std::is_floating_point<T>::value>* = nullptr>
+ T generate() {
+ std::uniform_real_distribution<T> dist;
+ return dist(random_engine_);
+ }
+
+ template <typename T, typename std::enable_if_t<
+ std::is_same<std::string, T>::value>* = nullptr>
+ T generate() {
+ std::uniform_int_distribution<> dist(1, 10);
+ return std::string(dist(random_engine_), '*');
+ }
+
+ private:
+ std::default_random_engine random_engine_;
+};
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TESTING_TEST_DATA_GENERATOR_H_
diff --git a/native/utils/tflite-model-executor.cc b/native/utils/tflite-model-executor.cc
index 31ed414..2dbd786 100644
--- a/native/utils/tflite-model-executor.cc
+++ b/native/utils/tflite-model-executor.cc
@@ -59,6 +59,9 @@
TfLiteRegistration* Register_RSQRT();
TfLiteRegistration* Register_LOG_SOFTMAX();
TfLiteRegistration* Register_WHERE();
+TfLiteRegistration* Register_ONE_HOT();
+TfLiteRegistration* Register_POW();
+TfLiteRegistration* Register_TANH();
} // namespace builtin
} // namespace ops
} // namespace tflite
@@ -176,6 +179,18 @@
tflite::ops::builtin::Register_LOG_SOFTMAX());
resolver->AddBuiltin(::tflite::BuiltinOperator_WHERE,
::tflite::ops::builtin::Register_WHERE());
+ resolver->AddBuiltin(tflite::BuiltinOperator_ONE_HOT,
+ tflite::ops::builtin::Register_ONE_HOT(),
+ /*min_version=*/1,
+ /*max_version=*/1);
+ resolver->AddBuiltin(tflite::BuiltinOperator_POW,
+ tflite::ops::builtin::Register_POW(),
+ /*min_version=*/1,
+ /*max_version=*/1);
+ resolver->AddBuiltin(tflite::BuiltinOperator_TANH,
+ tflite::ops::builtin::Register_TANH(),
+ /*min_version=*/1,
+ /*max_version=*/1);
}
#else
void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
diff --git a/native/utils/tflite/text_encoder.cc b/native/utils/tflite/text_encoder.cc
index 78bb51a..80b5854 100644
--- a/native/utils/tflite/text_encoder.cc
+++ b/native/utils/tflite/text_encoder.cc
@@ -100,14 +100,14 @@
config->add_dummy_prefix(), config->remove_extra_whitespaces(),
config->escape_whitespaces()));
- const int num_pieces = config->pieces_scores()->Length();
+ const int num_pieces = config->pieces_scores()->size();
switch (config->matcher_type()) {
case SentencePieceMatcherType_MAPPED_TRIE: {
const TrieNode* pieces_trie_nodes =
reinterpret_cast<const TrieNode*>(config->pieces()->Data());
const int pieces_trie_nodes_length =
- config->pieces()->Length() / sizeof(TrieNode);
+ config->pieces()->size() / sizeof(TrieNode);
encoder_op->matcher.reset(
new DoubleArrayTrie(pieces_trie_nodes, pieces_trie_nodes_length));
break;
@@ -115,7 +115,7 @@
case SentencePieceMatcherType_SORTED_STRING_TABLE: {
encoder_op->matcher.reset(new SortedStringsTable(
num_pieces, config->pieces_offsets()->data(),
- StringPiece(config->pieces()->data(), config->pieces()->Length())));
+ StringPiece(config->pieces()->data(), config->pieces()->size())));
break;
}
default: {
diff --git a/native/utils/utf8/unilib-common.cc b/native/utils/utf8/unilib-common.cc
index 30149af..7423cf3 100644
--- a/native/utils/utf8/unilib-common.cc
+++ b/native/utils/utf8/unilib-common.cc
@@ -395,6 +395,18 @@
constexpr char32 kDots[] = {0x002e, 0xfe52, 0xff0e};
constexpr int kNumDots = ARRAYSIZE(kDots);
+// Source: https://unicode-search.net/unicode-namesearch.pl?term=Apostrophe
+constexpr char32 kApostrophe[] = {0x0027, 0x02BC, 0x02EE, 0x055A,
+ 0x07F4, 0x07F5, 0xFF07};
+constexpr int kNumApostrophe = ARRAYSIZE(kApostrophe);
+
+// Source: https://unicode-search.net/unicode-namesearch.pl?term=Quotation
+constexpr char32 kQuotation[] = {
+ 0x0022, 0x00AB, 0x00BB, 0x2018, 0x2019, 0x201A, 0x201B, 0x201C,
+ 0x201D, 0x201E, 0x201F, 0x2039, 0x203A, 0x275B, 0x275C, 0x275D,
+ 0x275E, 0x276E, 0x276F, 0x2E42, 0x301D, 0x301E, 0x301F, 0xFF02};
+constexpr int kNumQuotation = ARRAYSIZE(kQuotation);
+
#undef ARRAYSIZE
static_assert(kNumOpeningBrackets == kNumClosingBrackets,
@@ -576,6 +588,14 @@
return GetMatchIndex(kDots, kNumDots, codepoint) >= 0;
}
+bool IsApostrophe(char32 codepoint) {  // Membership test against the kApostrophe table.
+ return GetMatchIndex(kApostrophe, kNumApostrophe, codepoint) >= 0;
+}
+
+bool IsQuotation(char32 codepoint) {  // Membership test against the kQuotation table.
+ return GetMatchIndex(kQuotation, kNumQuotation, codepoint) >= 0;
+}
+
bool IsLatinLetter(char32 codepoint) {
return (GetOverlappingRangeIndex(
kLatinLettersRangesStart, kLatinLettersRangesEnd,
diff --git a/native/utils/utf8/unilib-common.h b/native/utils/utf8/unilib-common.h
index eeffe9c..2788f3c 100644
--- a/native/utils/utf8/unilib-common.h
+++ b/native/utils/utf8/unilib-common.h
@@ -35,6 +35,8 @@
bool IsMinus(char32 codepoint);
bool IsNumberSign(char32 codepoint);
bool IsDot(char32 codepoint);
+bool IsApostrophe(char32 codepoint);
+bool IsQuotation(char32 codepoint);
bool IsLatinLetter(char32 codepoint);
bool IsArabicLetter(char32 codepoint);
diff --git a/native/utils/utf8/unilib.h b/native/utils/utf8/unilib.h
index d0e6164..18cc261 100644
--- a/native/utils/utf8/unilib.h
+++ b/native/utils/utf8/unilib.h
@@ -108,6 +108,14 @@
return libtextclassifier3::IsDot(codepoint);
}
+ bool IsApostrophe(char32 codepoint) const {  // Forwards to the unilib-common free function.
+ return libtextclassifier3::IsApostrophe(codepoint);
+ }
+
+ bool IsQuotation(char32 codepoint) const {  // Forwards to the unilib-common free function.
+ return libtextclassifier3::IsQuotation(codepoint);
+ }
+
bool IsLatinLetter(char32 codepoint) const {
return libtextclassifier3::IsLatinLetter(codepoint);
}
diff --git a/native/utils/utf8/unilib_test-include.cc b/native/utils/utf8/unilib_test-include.cc
index de0d2ea..518f4c8 100644
--- a/native/utils/utf8/unilib_test-include.cc
+++ b/native/utils/utf8/unilib_test-include.cc
@@ -25,152 +25,158 @@
using ::testing::ElementsAre;
TEST_F(UniLibTest, CharacterClassesAscii) {
- EXPECT_TRUE(unilib_.IsOpeningBracket('('));
- EXPECT_TRUE(unilib_.IsClosingBracket(')'));
- EXPECT_FALSE(unilib_.IsWhitespace(')'));
- EXPECT_TRUE(unilib_.IsWhitespace(' '));
- EXPECT_FALSE(unilib_.IsDigit(')'));
- EXPECT_TRUE(unilib_.IsDigit('0'));
- EXPECT_TRUE(unilib_.IsDigit('9'));
- EXPECT_FALSE(unilib_.IsUpper(')'));
- EXPECT_TRUE(unilib_.IsUpper('A'));
- EXPECT_TRUE(unilib_.IsUpper('Z'));
- EXPECT_FALSE(unilib_.IsLower(')'));
- EXPECT_TRUE(unilib_.IsLower('a'));
- EXPECT_TRUE(unilib_.IsLower('z'));
- EXPECT_TRUE(unilib_.IsPunctuation('!'));
- EXPECT_TRUE(unilib_.IsPunctuation('?'));
- EXPECT_TRUE(unilib_.IsPunctuation('#'));
- EXPECT_TRUE(unilib_.IsPunctuation('('));
- EXPECT_FALSE(unilib_.IsPunctuation('0'));
- EXPECT_FALSE(unilib_.IsPunctuation('$'));
- EXPECT_TRUE(unilib_.IsPercentage('%'));
- EXPECT_TRUE(unilib_.IsPercentage(u'%'));
- EXPECT_TRUE(unilib_.IsSlash('/'));
- EXPECT_TRUE(unilib_.IsSlash(u'/'));
- EXPECT_TRUE(unilib_.IsMinus('-'));
- EXPECT_TRUE(unilib_.IsMinus(u'-'));
- EXPECT_TRUE(unilib_.IsNumberSign('#'));
- EXPECT_TRUE(unilib_.IsNumberSign(u'#'));
- EXPECT_TRUE(unilib_.IsDot('.'));
- EXPECT_TRUE(unilib_.IsDot(u'.'));
+ EXPECT_TRUE(unilib_->IsOpeningBracket('('));
+ EXPECT_TRUE(unilib_->IsClosingBracket(')'));
+ EXPECT_FALSE(unilib_->IsWhitespace(')'));
+ EXPECT_TRUE(unilib_->IsWhitespace(' '));
+ EXPECT_FALSE(unilib_->IsDigit(')'));
+ EXPECT_TRUE(unilib_->IsDigit('0'));
+ EXPECT_TRUE(unilib_->IsDigit('9'));
+ EXPECT_FALSE(unilib_->IsUpper(')'));
+ EXPECT_TRUE(unilib_->IsUpper('A'));
+ EXPECT_TRUE(unilib_->IsUpper('Z'));
+ EXPECT_FALSE(unilib_->IsLower(')'));
+ EXPECT_TRUE(unilib_->IsLower('a'));
+ EXPECT_TRUE(unilib_->IsLower('z'));
+ EXPECT_TRUE(unilib_->IsPunctuation('!'));
+ EXPECT_TRUE(unilib_->IsPunctuation('?'));
+ EXPECT_TRUE(unilib_->IsPunctuation('#'));
+ EXPECT_TRUE(unilib_->IsPunctuation('('));
+ EXPECT_FALSE(unilib_->IsPunctuation('0'));
+ EXPECT_FALSE(unilib_->IsPunctuation('$'));
+ EXPECT_TRUE(unilib_->IsPercentage('%'));
+ EXPECT_TRUE(unilib_->IsPercentage(u'%'));
+ EXPECT_TRUE(unilib_->IsSlash('/'));
+ EXPECT_TRUE(unilib_->IsSlash(u'/'));
+ EXPECT_TRUE(unilib_->IsMinus('-'));
+ EXPECT_TRUE(unilib_->IsMinus(u'-'));
+ EXPECT_TRUE(unilib_->IsNumberSign('#'));
+ EXPECT_TRUE(unilib_->IsNumberSign(u'#'));
+ EXPECT_TRUE(unilib_->IsDot('.'));
+ EXPECT_TRUE(unilib_->IsDot(u'.'));
+ EXPECT_TRUE(unilib_->IsApostrophe('\''));
+ EXPECT_TRUE(unilib_->IsApostrophe(u'ߴ'));
+ EXPECT_TRUE(unilib_->IsQuotation(u'"'));
+ EXPECT_TRUE(unilib_->IsQuotation(u'”'));
- EXPECT_TRUE(unilib_.IsLatinLetter('A'));
- EXPECT_TRUE(unilib_.IsArabicLetter(u'ب')); // ARABIC LETTER BEH
+ EXPECT_TRUE(unilib_->IsLatinLetter('A'));
+ EXPECT_TRUE(unilib_->IsArabicLetter(u'ب')); // ARABIC LETTER BEH
EXPECT_TRUE(
- unilib_.IsCyrillicLetter(u'ᲀ')); // CYRILLIC SMALL LETTER ROUNDED VE
- EXPECT_TRUE(unilib_.IsChineseLetter(u'豈')); // CJK COMPATIBILITY IDEOGRAPH
- EXPECT_TRUE(unilib_.IsJapaneseLetter(u'ぁ')); // HIRAGANA LETTER SMALL A
- EXPECT_TRUE(unilib_.IsKoreanLetter(u'ㄱ')); // HANGUL LETTER KIYEOK
- EXPECT_TRUE(unilib_.IsThaiLetter(u'ก')); // THAI CHARACTER KO KAI
- EXPECT_TRUE(unilib_.IsCJTletter(u'ก')); // THAI CHARACTER KO KAI
- EXPECT_FALSE(unilib_.IsCJTletter('A'));
+ unilib_->IsCyrillicLetter(u'ᲀ')); // CYRILLIC SMALL LETTER ROUNDED VE
+ EXPECT_TRUE(unilib_->IsChineseLetter(u'豈')); // CJK COMPATIBILITY IDEOGRAPH
+ EXPECT_TRUE(unilib_->IsJapaneseLetter(u'ぁ')); // HIRAGANA LETTER SMALL A
+ EXPECT_TRUE(unilib_->IsKoreanLetter(u'ㄱ')); // HANGUL LETTER KIYEOK
+ EXPECT_TRUE(unilib_->IsThaiLetter(u'ก')); // THAI CHARACTER KO KAI
+ EXPECT_TRUE(unilib_->IsCJTletter(u'ก')); // THAI CHARACTER KO KAI
+ EXPECT_FALSE(unilib_->IsCJTletter('A'));
- EXPECT_TRUE(unilib_.IsLetter('A'));
- EXPECT_TRUE(unilib_.IsLetter(u'A'));
- EXPECT_TRUE(unilib_.IsLetter(u'ト')); // KATAKANA LETTER TO
- EXPECT_TRUE(unilib_.IsLetter(u'ト')); // HALFWIDTH KATAKANA LETTER TO
- EXPECT_TRUE(unilib_.IsLetter(u'豈')); // CJK COMPATIBILITY IDEOGRAPH
+ EXPECT_TRUE(unilib_->IsLetter('A'));
+ EXPECT_TRUE(unilib_->IsLetter(u'\uFF21'));  // FULLWIDTH LATIN CAPITAL LETTER A
+ EXPECT_TRUE(unilib_->IsLetter(u'\u30C8'));  // KATAKANA LETTER TO
+ EXPECT_TRUE(unilib_->IsLetter(u'\uFF84'));  // HALFWIDTH KATAKANA LETTER TO
+ EXPECT_TRUE(unilib_->IsLetter(u'\uF900'));  // CJK COMPATIBILITY IDEOGRAPH
- EXPECT_EQ(unilib_.ToLower('A'), 'a');
- EXPECT_EQ(unilib_.ToLower('Z'), 'z');
- EXPECT_EQ(unilib_.ToLower(')'), ')');
- EXPECT_EQ(unilib_.ToLowerText(UTF8ToUnicodeText("Never gonna give you up."))
+ EXPECT_EQ(unilib_->ToLower('A'), 'a');
+ EXPECT_EQ(unilib_->ToLower('Z'), 'z');
+ EXPECT_EQ(unilib_->ToLower(')'), ')');
+ EXPECT_EQ(unilib_->ToLowerText(UTF8ToUnicodeText("Never gonna give you up."))
.ToUTF8String(),
"never gonna give you up.");
- EXPECT_EQ(unilib_.ToUpper('a'), 'A');
- EXPECT_EQ(unilib_.ToUpper('z'), 'Z');
- EXPECT_EQ(unilib_.ToUpper(')'), ')');
- EXPECT_EQ(unilib_.ToUpperText(UTF8ToUnicodeText("Never gonna let you down."))
+ EXPECT_EQ(unilib_->ToUpper('a'), 'A');
+ EXPECT_EQ(unilib_->ToUpper('z'), 'Z');
+ EXPECT_EQ(unilib_->ToUpper(')'), ')');
+ EXPECT_EQ(unilib_->ToUpperText(UTF8ToUnicodeText("Never gonna let you down."))
.ToUTF8String(),
"NEVER GONNA LET YOU DOWN.");
- EXPECT_EQ(unilib_.GetPairedBracket(')'), '(');
- EXPECT_EQ(unilib_.GetPairedBracket('}'), '{');
+ EXPECT_EQ(unilib_->GetPairedBracket(')'), '(');
+ EXPECT_EQ(unilib_->GetPairedBracket('}'), '{');
}
TEST_F(UniLibTest, CharacterClassesUnicode) {
- EXPECT_TRUE(unilib_.IsOpeningBracket(0x0F3C)); // TIBET ANG KHANG GYON
- EXPECT_TRUE(unilib_.IsClosingBracket(0x0F3D)); // TIBET ANG KHANG GYAS
- EXPECT_FALSE(unilib_.IsWhitespace(0x23F0)); // ALARM CLOCK
- EXPECT_TRUE(unilib_.IsWhitespace(0x2003)); // EM SPACE
- EXPECT_FALSE(unilib_.IsDigit(0xA619)); // VAI SYMBOL JONG
- EXPECT_TRUE(unilib_.IsDigit(0xA620)); // VAI DIGIT ZERO
- EXPECT_TRUE(unilib_.IsDigit(0xA629)); // VAI DIGIT NINE
- EXPECT_FALSE(unilib_.IsDigit(0xA62A)); // VAI SYLLABLE NDOLE MA
- EXPECT_FALSE(unilib_.IsUpper(0x0211)); // SMALL R WITH DOUBLE GRAVE
- EXPECT_TRUE(unilib_.IsUpper(0x0212)); // CAPITAL R WITH DOUBLE GRAVE
- EXPECT_TRUE(unilib_.IsUpper(0x0391)); // GREEK CAPITAL ALPHA
- EXPECT_TRUE(unilib_.IsUpper(0x03AB)); // GREEK CAPITAL UPSILON W DIAL
- EXPECT_FALSE(unilib_.IsUpper(0x03AC)); // GREEK SMALL ALPHA WITH TONOS
- EXPECT_TRUE(unilib_.IsLower(0x03AC)); // GREEK SMALL ALPHA WITH TONOS
- EXPECT_TRUE(unilib_.IsLower(0x03B1)); // GREEK SMALL ALPHA
- EXPECT_TRUE(unilib_.IsLower(0x03CB)); // GREEK SMALL UPSILON
- EXPECT_TRUE(unilib_.IsLower(0x0211)); // SMALL R WITH DOUBLE GRAVE
- EXPECT_TRUE(unilib_.IsLower(0x03C0)); // GREEK SMALL PI
- EXPECT_TRUE(unilib_.IsLower(0x007A)); // SMALL Z
- EXPECT_FALSE(unilib_.IsLower(0x005A)); // CAPITAL Z
- EXPECT_FALSE(unilib_.IsLower(0x0212)); // CAPITAL R WITH DOUBLE GRAVE
- EXPECT_FALSE(unilib_.IsLower(0x0391)); // GREEK CAPITAL ALPHA
- EXPECT_TRUE(unilib_.IsPunctuation(0x055E)); // ARMENIAN QUESTION MARK
- EXPECT_TRUE(unilib_.IsPunctuation(0x066C)); // ARABIC THOUSANDS SEPARATOR
- EXPECT_TRUE(unilib_.IsPunctuation(0x07F7)); // NKO SYMBOL GBAKURUNEN
- EXPECT_TRUE(unilib_.IsPunctuation(0x10AF2)); // DOUBLE DOT WITHIN DOT
- EXPECT_FALSE(unilib_.IsPunctuation(0x00A3)); // POUND SIGN
- EXPECT_FALSE(unilib_.IsPunctuation(0xA838)); // NORTH INDIC RUPEE MARK
- EXPECT_TRUE(unilib_.IsPercentage(0x0025)); // PERCENT SIGN
- EXPECT_TRUE(unilib_.IsPercentage(0xFF05)); // FULLWIDTH PERCENT SIGN
- EXPECT_TRUE(unilib_.IsSlash(0x002F)); // SOLIDUS
- EXPECT_TRUE(unilib_.IsSlash(0xFF0F)); // FULLWIDTH SOLIDUS
- EXPECT_TRUE(unilib_.IsMinus(0x002D)); // HYPHEN-MINUS
- EXPECT_TRUE(unilib_.IsMinus(0xFF0D)); // FULLWIDTH HYPHEN-MINUS
- EXPECT_TRUE(unilib_.IsNumberSign(0x0023)); // NUMBER SIGN
- EXPECT_TRUE(unilib_.IsNumberSign(0xFF03)); // FULLWIDTH NUMBER SIGN
- EXPECT_TRUE(unilib_.IsDot(0x002E)); // FULL STOP
- EXPECT_TRUE(unilib_.IsDot(0xFF0E)); // FULLWIDTH FULL STOP
+ EXPECT_TRUE(unilib_->IsOpeningBracket(0x0F3C)); // TIBET ANG KHANG GYON
+ EXPECT_TRUE(unilib_->IsClosingBracket(0x0F3D)); // TIBET ANG KHANG GYAS
+ EXPECT_FALSE(unilib_->IsWhitespace(0x23F0)); // ALARM CLOCK
+ EXPECT_TRUE(unilib_->IsWhitespace(0x2003)); // EM SPACE
+ EXPECT_FALSE(unilib_->IsDigit(0xA619)); // VAI SYMBOL JONG
+ EXPECT_TRUE(unilib_->IsDigit(0xA620)); // VAI DIGIT ZERO
+ EXPECT_TRUE(unilib_->IsDigit(0xA629)); // VAI DIGIT NINE
+ EXPECT_FALSE(unilib_->IsDigit(0xA62A)); // VAI SYLLABLE NDOLE MA
+ EXPECT_FALSE(unilib_->IsUpper(0x0211)); // SMALL R WITH DOUBLE GRAVE
+ EXPECT_TRUE(unilib_->IsUpper(0x0212)); // CAPITAL R WITH DOUBLE GRAVE
+ EXPECT_TRUE(unilib_->IsUpper(0x0391)); // GREEK CAPITAL ALPHA
+ EXPECT_TRUE(unilib_->IsUpper(0x03AB)); // GREEK CAPITAL UPSILON W DIAL
+ EXPECT_FALSE(unilib_->IsUpper(0x03AC)); // GREEK SMALL ALPHA WITH TONOS
+ EXPECT_TRUE(unilib_->IsLower(0x03AC)); // GREEK SMALL ALPHA WITH TONOS
+ EXPECT_TRUE(unilib_->IsLower(0x03B1)); // GREEK SMALL ALPHA
+ EXPECT_TRUE(unilib_->IsLower(0x03CB)); // GREEK SMALL UPSILON
+ EXPECT_TRUE(unilib_->IsLower(0x0211)); // SMALL R WITH DOUBLE GRAVE
+ EXPECT_TRUE(unilib_->IsLower(0x03C0)); // GREEK SMALL PI
+ EXPECT_TRUE(unilib_->IsLower(0x007A)); // SMALL Z
+ EXPECT_FALSE(unilib_->IsLower(0x005A)); // CAPITAL Z
+ EXPECT_FALSE(unilib_->IsLower(0x0212)); // CAPITAL R WITH DOUBLE GRAVE
+ EXPECT_FALSE(unilib_->IsLower(0x0391)); // GREEK CAPITAL ALPHA
+ EXPECT_TRUE(unilib_->IsPunctuation(0x055E)); // ARMENIAN QUESTION MARK
+ EXPECT_TRUE(unilib_->IsPunctuation(0x066C)); // ARABIC THOUSANDS SEPARATOR
+ EXPECT_TRUE(unilib_->IsPunctuation(0x07F7)); // NKO SYMBOL GBAKURUNEN
+ EXPECT_TRUE(unilib_->IsPunctuation(0x10AF2)); // DOUBLE DOT WITHIN DOT
+ EXPECT_FALSE(unilib_->IsPunctuation(0x00A3)); // POUND SIGN
+ EXPECT_FALSE(unilib_->IsPunctuation(0xA838)); // NORTH INDIC RUPEE MARK
+ EXPECT_TRUE(unilib_->IsPercentage(0x0025)); // PERCENT SIGN
+ EXPECT_TRUE(unilib_->IsPercentage(0xFF05)); // FULLWIDTH PERCENT SIGN
+ EXPECT_TRUE(unilib_->IsSlash(0x002F)); // SOLIDUS
+ EXPECT_TRUE(unilib_->IsSlash(0xFF0F)); // FULLWIDTH SOLIDUS
+ EXPECT_TRUE(unilib_->IsMinus(0x002D)); // HYPHEN-MINUS
+ EXPECT_TRUE(unilib_->IsMinus(0xFF0D)); // FULLWIDTH HYPHEN-MINUS
+ EXPECT_TRUE(unilib_->IsNumberSign(0x0023)); // NUMBER SIGN
+ EXPECT_TRUE(unilib_->IsNumberSign(0xFF03)); // FULLWIDTH NUMBER SIGN
+ EXPECT_TRUE(unilib_->IsDot(0x002E)); // FULL STOP
+ EXPECT_TRUE(unilib_->IsDot(0xFF0E)); // FULLWIDTH FULL STOP
- EXPECT_TRUE(unilib_.IsLatinLetter(0x0041)); // LATIN CAPITAL LETTER A
- EXPECT_TRUE(unilib_.IsArabicLetter(0x0628)); // ARABIC LETTER BEH
+ EXPECT_TRUE(unilib_->IsLatinLetter(0x0041)); // LATIN CAPITAL LETTER A
+ EXPECT_TRUE(unilib_->IsArabicLetter(0x0628)); // ARABIC LETTER BEH
EXPECT_TRUE(
- unilib_.IsCyrillicLetter(0x1C80)); // CYRILLIC SMALL LETTER ROUNDED VE
- EXPECT_TRUE(unilib_.IsChineseLetter(0xF900)); // CJK COMPATIBILITY IDEOGRAPH
- EXPECT_TRUE(unilib_.IsJapaneseLetter(0x3041)); // HIRAGANA LETTER SMALL A
- EXPECT_TRUE(unilib_.IsKoreanLetter(0x3131)); // HANGUL LETTER KIYEOK
- EXPECT_TRUE(unilib_.IsThaiLetter(0x0E01)); // THAI CHARACTER KO KAI
- EXPECT_TRUE(unilib_.IsCJTletter(0x0E01)); // THAI CHARACTER KO KAI
- EXPECT_FALSE(unilib_.IsCJTletter(0x0041)); // LATIN CAPITAL LETTER A
+ unilib_->IsCyrillicLetter(0x1C80)); // CYRILLIC SMALL LETTER ROUNDED VE
+ EXPECT_TRUE(unilib_->IsChineseLetter(0xF900)); // CJK COMPATIBILITY IDEOGRAPH
+ EXPECT_TRUE(unilib_->IsJapaneseLetter(0x3041)); // HIRAGANA LETTER SMALL A
+ EXPECT_TRUE(unilib_->IsKoreanLetter(0x3131)); // HANGUL LETTER KIYEOK
+ EXPECT_TRUE(unilib_->IsThaiLetter(0x0E01)); // THAI CHARACTER KO KAI
+ EXPECT_TRUE(unilib_->IsCJTletter(0x0E01)); // THAI CHARACTER KO KAI
+ EXPECT_FALSE(unilib_->IsCJTletter(0x0041)); // LATIN CAPITAL LETTER A
- EXPECT_TRUE(unilib_.IsLetter(0x0041)); // LATIN CAPITAL LETTER A
- EXPECT_TRUE(unilib_.IsLetter(0xFF21)); // FULLWIDTH LATIN CAPITAL LETTER A
- EXPECT_TRUE(unilib_.IsLetter(0x30C8)); // KATAKANA LETTER TO
- EXPECT_TRUE(unilib_.IsLetter(0xFF84)); // HALFWIDTH KATAKANA LETTER TO
- EXPECT_TRUE(unilib_.IsLetter(0xF900)); // CJK COMPATIBILITY IDEOGRAPH
+ EXPECT_TRUE(unilib_->IsLetter(0x0041)); // LATIN CAPITAL LETTER A
+ EXPECT_TRUE(unilib_->IsLetter(0xFF21)); // FULLWIDTH LATIN CAPITAL LETTER A
+ EXPECT_TRUE(unilib_->IsLetter(0x30C8)); // KATAKANA LETTER TO
+ EXPECT_TRUE(unilib_->IsLetter(0xFF84)); // HALFWIDTH KATAKANA LETTER TO
+ EXPECT_TRUE(unilib_->IsLetter(0xF900)); // CJK COMPATIBILITY IDEOGRAPH
- EXPECT_EQ(unilib_.ToLower(0x0391), 0x03B1); // GREEK ALPHA
- EXPECT_EQ(unilib_.ToLower(0x03AB), 0x03CB); // GREEK UPSILON WITH DIALYTIKA
- EXPECT_EQ(unilib_.ToLower(0x03C0), 0x03C0); // GREEK SMALL PI
- EXPECT_EQ(unilib_.ToLower(0x03A3), 0x03C3); // GREEK CAPITAL LETTER SIGMA
- EXPECT_EQ(unilib_.ToLowerText(UTF8ToUnicodeText("Κανένας άνθρωπος δεν ξέρει"))
- .ToUTF8String(),
- "κανένας άνθρωπος δεν ξέρει");
- EXPECT_TRUE(unilib_.IsLowerText(UTF8ToUnicodeText("ξέρει")));
- EXPECT_EQ(unilib_.ToUpper(0x03B1), 0x0391); // GREEK ALPHA
- EXPECT_EQ(unilib_.ToUpper(0x03CB), 0x03AB); // GREEK UPSILON WITH DIALYTIKA
- EXPECT_EQ(unilib_.ToUpper(0x0391), 0x0391); // GREEK CAPITAL ALPHA
- EXPECT_EQ(unilib_.ToUpper(0x03C3), 0x03A3); // GREEK CAPITAL LETTER SIGMA
- EXPECT_EQ(unilib_.ToUpper(0x03C2), 0x03A3); // GREEK CAPITAL LETTER SIGMA
- EXPECT_EQ(unilib_.ToUpperText(UTF8ToUnicodeText("Κανένας άνθρωπος δεν ξέρει"))
- .ToUTF8String(),
- "ΚΑΝΈΝΑΣ ΆΝΘΡΩΠΟΣ ΔΕΝ ΞΈΡΕΙ");
- EXPECT_TRUE(unilib_.IsUpperText(UTF8ToUnicodeText("ΚΑΝΈΝΑΣ")));
- EXPECT_EQ(unilib_.GetPairedBracket(0x0F3C), 0x0F3D);
- EXPECT_EQ(unilib_.GetPairedBracket(0x0F3D), 0x0F3C);
+ EXPECT_EQ(unilib_->ToLower(0x0391), 0x03B1); // GREEK ALPHA
+ EXPECT_EQ(unilib_->ToLower(0x03AB), 0x03CB); // GREEK UPSILON WITH DIALYTIKA
+ EXPECT_EQ(unilib_->ToLower(0x03C0), 0x03C0); // GREEK SMALL PI
+ EXPECT_EQ(unilib_->ToLower(0x03A3), 0x03C3); // GREEK CAPITAL LETTER SIGMA
+ EXPECT_EQ(
+ unilib_->ToLowerText(UTF8ToUnicodeText("Κανένας άνθρωπος δεν ξέρει"))
+ .ToUTF8String(),
+ "κανένας άνθρωπος δεν ξέρει");
+ EXPECT_TRUE(unilib_->IsLowerText(UTF8ToUnicodeText("ξέρει")));
+ EXPECT_EQ(unilib_->ToUpper(0x03B1), 0x0391); // GREEK ALPHA
+ EXPECT_EQ(unilib_->ToUpper(0x03CB), 0x03AB); // GREEK UPSILON WITH DIALYTIKA
+ EXPECT_EQ(unilib_->ToUpper(0x0391), 0x0391); // GREEK CAPITAL ALPHA
+ EXPECT_EQ(unilib_->ToUpper(0x03C3), 0x03A3); // GREEK CAPITAL LETTER SIGMA
+ EXPECT_EQ(unilib_->ToUpper(0x03C2), 0x03A3); // GREEK CAPITAL LETTER SIGMA
+ EXPECT_EQ(
+ unilib_->ToUpperText(UTF8ToUnicodeText("Κανένας άνθρωπος δεν ξέρει"))
+ .ToUTF8String(),
+ "ΚΑΝΈΝΑΣ ΆΝΘΡΩΠΟΣ ΔΕΝ ΞΈΡΕΙ");
+ EXPECT_TRUE(unilib_->IsUpperText(UTF8ToUnicodeText("ΚΑΝΈΝΑΣ")));
+ EXPECT_EQ(unilib_->GetPairedBracket(0x0F3C), 0x0F3D);
+ EXPECT_EQ(unilib_->GetPairedBracket(0x0F3D), 0x0F3C);
}
TEST_F(UniLibTest, RegexInterface) {
const UnicodeText regex_pattern =
UTF8ToUnicodeText("[0-9]+", /*do_copy=*/true);
std::unique_ptr<UniLib::RegexPattern> pattern =
- unilib_.CreateRegexPattern(regex_pattern);
+ unilib_->CreateRegexPattern(regex_pattern);
const UnicodeText input = UTF8ToUnicodeText("hello 0123", /*do_copy=*/false);
int status;
std::unique_ptr<UniLib::RegexMatcher> matcher = pattern->Matcher(input);
@@ -188,7 +194,7 @@
const UnicodeText regex_pattern =
UTF8ToUnicodeText("[0-9]+😋", /*do_copy=*/false);
std::unique_ptr<UniLib::RegexPattern> pattern =
- unilib_.CreateRegexPattern(regex_pattern);
+ unilib_->CreateRegexPattern(regex_pattern);
int status;
std::unique_ptr<UniLib::RegexMatcher> matcher;
@@ -220,7 +226,7 @@
TEST_F(UniLibTest, RegexLazy) {
std::unique_ptr<UniLib::RegexPattern> pattern =
- unilib_.CreateLazyRegexPattern(
+ unilib_->CreateLazyRegexPattern(
UTF8ToUnicodeText("[a-z][0-9]", /*do_copy=*/false));
int status;
std::unique_ptr<UniLib::RegexMatcher> matcher;
@@ -246,7 +252,7 @@
const UnicodeText regex_pattern =
UTF8ToUnicodeText("([0-9])([0-9]+)😋", /*do_copy=*/false);
std::unique_ptr<UniLib::RegexPattern> pattern =
- unilib_.CreateRegexPattern(regex_pattern);
+ unilib_->CreateRegexPattern(regex_pattern);
int status;
std::unique_ptr<UniLib::RegexMatcher> matcher;
@@ -278,7 +284,7 @@
const UnicodeText regex_pattern =
UTF8ToUnicodeText("([0-9])([a-z])?", /*do_copy=*/false);
std::unique_ptr<UniLib::RegexPattern> pattern =
- unilib_.CreateRegexPattern(regex_pattern);
+ unilib_->CreateRegexPattern(regex_pattern);
int status;
std::unique_ptr<UniLib::RegexMatcher> matcher;
@@ -297,7 +303,7 @@
const UnicodeText regex_pattern =
UTF8ToUnicodeText("(.*)", /*do_copy=*/false);
std::unique_ptr<UniLib::RegexPattern> pattern =
- unilib_.CreateRegexPattern(regex_pattern);
+ unilib_->CreateRegexPattern(regex_pattern);
int status;
std::unique_ptr<UniLib::RegexMatcher> matcher;
@@ -313,7 +319,7 @@
TEST_F(UniLibTest, BreakIterator) {
const UnicodeText text = UTF8ToUnicodeText("some text", /*do_copy=*/false);
std::unique_ptr<UniLib::BreakIterator> iterator =
- unilib_.CreateBreakIterator(text);
+ unilib_->CreateBreakIterator(text);
std::vector<int> break_indices;
int break_index = 0;
while ((break_index = iterator->Next()) != UniLib::BreakIterator::kDone) {
@@ -325,7 +331,7 @@
TEST_F(UniLibTest, BreakIterator4ByteUTF8) {
const UnicodeText text = UTF8ToUnicodeText("😀😂😋", /*do_copy=*/false);
std::unique_ptr<UniLib::BreakIterator> iterator =
- unilib_.CreateBreakIterator(text);
+ unilib_->CreateBreakIterator(text);
std::vector<int> break_indices;
int break_index = 0;
while ((break_index = iterator->Next()) != UniLib::BreakIterator::kDone) {
@@ -336,20 +342,20 @@
TEST_F(UniLibTest, Integer32Parse) {
int result;
- EXPECT_TRUE(
- unilib_.ParseInt32(UTF8ToUnicodeText("123", /*do_copy=*/false), &result));
+ EXPECT_TRUE(unilib_->ParseInt32(UTF8ToUnicodeText("123", /*do_copy=*/false),
+ &result));
EXPECT_EQ(result, 123);
}
TEST_F(UniLibTest, Integer32ParseFloatNumber) {
int result;
- EXPECT_FALSE(unilib_.ParseInt32(UTF8ToUnicodeText("12.3", /*do_copy=*/false),
- &result));
+ EXPECT_FALSE(unilib_->ParseInt32(UTF8ToUnicodeText("12.3", /*do_copy=*/false),
+ &result));
}
TEST_F(UniLibTest, Integer32ParseLongNumber) {
int32 result;
- EXPECT_TRUE(unilib_.ParseInt32(
+ EXPECT_TRUE(unilib_->ParseInt32(
UTF8ToUnicodeText("1000000000", /*do_copy=*/false), &result));
EXPECT_EQ(result, 1000000000);
}
@@ -357,97 +363,97 @@
TEST_F(UniLibTest, Integer32ParseEmptyString) {
int result;
EXPECT_FALSE(
- unilib_.ParseInt32(UTF8ToUnicodeText("", /*do_copy=*/false), &result));
+ unilib_->ParseInt32(UTF8ToUnicodeText("", /*do_copy=*/false), &result));
}
TEST_F(UniLibTest, Integer32ParseFullWidth) {
int result;
// The input string here is full width
- EXPECT_TRUE(unilib_.ParseInt32(UTF8ToUnicodeText("123", /*do_copy=*/false),
- &result));
+ EXPECT_TRUE(unilib_->ParseInt32(
+ UTF8ToUnicodeText("\uFF11\uFF12\uFF13", /*do_copy=*/false), &result));
EXPECT_EQ(result, 123);
}
TEST_F(UniLibTest, Integer32ParseNotNumber) {
int result;
// The input string here is full width
- EXPECT_FALSE(unilib_.ParseInt32(UTF8ToUnicodeText("1a3", /*do_copy=*/false),
- &result));
+ EXPECT_FALSE(unilib_->ParseInt32(
+ UTF8ToUnicodeText("1a3", /*do_copy=*/false), &result));
// Strings starting with "nan" are not numbers.
- EXPECT_FALSE(unilib_.ParseInt32(UTF8ToUnicodeText("Nancy",
- /*do_copy=*/false),
- &result));
+ EXPECT_FALSE(unilib_->ParseInt32(UTF8ToUnicodeText("Nancy",
+ /*do_copy=*/false),
+ &result));
// Strings starting with "inf" are not numbers
- EXPECT_FALSE(unilib_.ParseInt32(
+ EXPECT_FALSE(unilib_->ParseInt32(
UTF8ToUnicodeText("Information", /*do_copy=*/false), &result));
}
TEST_F(UniLibTest, Integer64Parse) {
int64 result;
- EXPECT_TRUE(
- unilib_.ParseInt64(UTF8ToUnicodeText("123", /*do_copy=*/false), &result));
+ EXPECT_TRUE(unilib_->ParseInt64(UTF8ToUnicodeText("123", /*do_copy=*/false),
+ &result));
EXPECT_EQ(result, 123);
}
TEST_F(UniLibTest, Integer64ParseFloatNumber) {
int64 result;
- EXPECT_FALSE(unilib_.ParseInt64(UTF8ToUnicodeText("12.3", /*do_copy=*/false),
- &result));
+ EXPECT_FALSE(unilib_->ParseInt64(UTF8ToUnicodeText("12.3", /*do_copy=*/false),
+ &result));
}
TEST_F(UniLibTest, Integer64ParseLongNumber) {
int64 result;
// The limitation comes from the javaicu implementation: parseDouble does not
// have ICU support and parseInt limit the size of the number.
- EXPECT_TRUE(unilib_.ParseInt64(
+ EXPECT_TRUE(unilib_->ParseInt64(
UTF8ToUnicodeText("1000000000", /*do_copy=*/false), &result));
EXPECT_EQ(result, 1000000000);
}
TEST_F(UniLibTest, Integer64ParseOverflowNumber) {
int64 result;
- EXPECT_FALSE(unilib_.ParseInt64(
+ EXPECT_FALSE(unilib_->ParseInt64(
UTF8ToUnicodeText("92233720368547758099", /*do_copy=*/false), &result));
}
TEST_F(UniLibTest, Integer64ParseOverflowNegativeNumber) {
int64 result;
- EXPECT_FALSE(unilib_.ParseInt64(
+ EXPECT_FALSE(unilib_->ParseInt64(
UTF8ToUnicodeText("-92233720368547758099", /*do_copy=*/false), &result));
}
TEST_F(UniLibTest, Integer64ParseEmptyString) {
int64 result;
EXPECT_FALSE(
- unilib_.ParseInt64(UTF8ToUnicodeText("", /*do_copy=*/false), &result));
+ unilib_->ParseInt64(UTF8ToUnicodeText("", /*do_copy=*/false), &result));
}
TEST_F(UniLibTest, Integer64ParseFullWidth) {
int64 result;
// The input string here is full width
- EXPECT_TRUE(unilib_.ParseInt64(UTF8ToUnicodeText("123", /*do_copy=*/false),
- &result));
+ EXPECT_TRUE(unilib_->ParseInt64(
+ UTF8ToUnicodeText("\uFF11\uFF12\uFF13", /*do_copy=*/false), &result));
EXPECT_EQ(result, 123);
}
TEST_F(UniLibTest, Integer64ParseNotNumber) {
int64 result;
// The input string here is full width
- EXPECT_FALSE(unilib_.ParseInt64(UTF8ToUnicodeText("1a4", /*do_copy=*/false),
- &result));
+ EXPECT_FALSE(unilib_->ParseInt64(
+ UTF8ToUnicodeText("1a4", /*do_copy=*/false), &result));
// Strings starting with "nan" are not numbers.
- EXPECT_FALSE(unilib_.ParseInt64(UTF8ToUnicodeText("Nancy",
- /*do_copy=*/false),
- &result));
+ EXPECT_FALSE(unilib_->ParseInt64(UTF8ToUnicodeText("Nancy",
+ /*do_copy=*/false),
+ &result));
// Strings starting with "inf" are not numbers
- EXPECT_FALSE(unilib_.ParseInt64(
+ EXPECT_FALSE(unilib_->ParseInt64(
UTF8ToUnicodeText("Information", /*do_copy=*/false), &result));
}
TEST_F(UniLibTest, DoubleParse) {
double result;
- EXPECT_TRUE(unilib_.ParseDouble(UTF8ToUnicodeText("1.23", /*do_copy=*/false),
- &result));
+ EXPECT_TRUE(unilib_->ParseDouble(UTF8ToUnicodeText("1.23", /*do_copy=*/false),
+ &result));
EXPECT_EQ(result, 1.23);
}
@@ -455,46 +461,46 @@
double result;
// The limitation comes from the javaicu implementation: parseDouble does not
// have ICU support and parseInt limit the size of the number.
- EXPECT_TRUE(unilib_.ParseDouble(
+ EXPECT_TRUE(unilib_->ParseDouble(
UTF8ToUnicodeText("999999999.999999999", /*do_copy=*/false), &result));
EXPECT_EQ(result, 999999999.999999999);
}
TEST_F(UniLibTest, DoubleParseWithoutFractionalPart) {
double result;
- EXPECT_TRUE(unilib_.ParseDouble(UTF8ToUnicodeText("123", /*do_copy=*/false),
- &result));
+ EXPECT_TRUE(unilib_->ParseDouble(UTF8ToUnicodeText("123", /*do_copy=*/false),
+ &result));
EXPECT_EQ(result, 123);
}
TEST_F(UniLibTest, DoubleParseEmptyString) {
double result;
EXPECT_FALSE(
- unilib_.ParseDouble(UTF8ToUnicodeText("", /*do_copy=*/false), &result));
+ unilib_->ParseDouble(UTF8ToUnicodeText("", /*do_copy=*/false), &result));
}
TEST_F(UniLibTest, DoubleParsePrecedingDot) {
double result;
- EXPECT_FALSE(unilib_.ParseDouble(UTF8ToUnicodeText(".123", /*do_copy=*/false),
- &result));
+ EXPECT_FALSE(unilib_->ParseDouble(
+ UTF8ToUnicodeText(".123", /*do_copy=*/false), &result));
}
TEST_F(UniLibTest, DoubleParseLeadingDot) {
double result;
- EXPECT_FALSE(unilib_.ParseDouble(UTF8ToUnicodeText("123.", /*do_copy=*/false),
- &result));
+ EXPECT_FALSE(unilib_->ParseDouble(
+ UTF8ToUnicodeText("123.", /*do_copy=*/false), &result));
}
TEST_F(UniLibTest, DoubleParseMultipleDots) {
double result;
- EXPECT_FALSE(unilib_.ParseDouble(
+ EXPECT_FALSE(unilib_->ParseDouble(
UTF8ToUnicodeText("1.2.3", /*do_copy=*/false), &result));
}
TEST_F(UniLibTest, DoubleParseFullWidth) {
double result;
// The input string here is full width
- EXPECT_TRUE(unilib_.ParseDouble(
+ EXPECT_TRUE(unilib_->ParseDouble(
+ UTF8ToUnicodeText("\uFF11.\uFF12\uFF13", /*do_copy=*/false), &result));
EXPECT_EQ(result, 1.23);
}
@@ -502,13 +508,13 @@
TEST_F(UniLibTest, DoubleParseNotNumber) {
double result;
// The input string here is full width
- EXPECT_FALSE(unilib_.ParseDouble(
+ EXPECT_FALSE(unilib_->ParseDouble(
UTF8ToUnicodeText("1a5", /*do_copy=*/false), &result));
// Strings starting with "nan" are not numbers.
- EXPECT_FALSE(unilib_.ParseDouble(
+ EXPECT_FALSE(unilib_->ParseDouble(
UTF8ToUnicodeText("Nancy", /*do_copy=*/false), &result));
// Strings starting with "inf" are not numbers
- EXPECT_FALSE(unilib_.ParseDouble(
+ EXPECT_FALSE(unilib_->ParseDouble(
UTF8ToUnicodeText("Information", /*do_copy=*/false), &result));
}
diff --git a/native/utils/utf8/unilib_test-include.h b/native/utils/utf8/unilib_test-include.h
index 342a00c..8ae8a0f 100644
--- a/native/utils/utf8/unilib_test-include.h
+++ b/native/utils/utf8/unilib_test-include.h
@@ -17,28 +17,17 @@
#ifndef LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_TEST_INCLUDE_H_
#define LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_TEST_INCLUDE_H_
+#include "utils/jvm-test-utils.h"
#include "utils/utf8/unilib.h"
#include "gtest/gtest.h"
-#if defined TC3_UNILIB_ICU
-#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR()
-#elif defined TC3_UNILIB_JAVAICU
-#include <jni.h>
-extern JNIEnv* g_jenv;
-#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR(JniCache::Create(g_jenv))
-#elif defined TC3_UNILIB_APPLE
-#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR()
-#elif defined TC3_UNILIB_DUMMY
-#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR()
-#endif
-
namespace libtextclassifier3 {
namespace test_internal {
class UniLibTest : public ::testing::Test {
protected:
- UniLibTest() : TC3_TESTING_CREATE_UNILIB_INSTANCE(unilib_) {}
- UniLib unilib_;
+ UniLibTest() : unilib_(CreateUniLibForTesting()) {}
+ std::unique_ptr<UniLib> unilib_;
};
} // namespace test_internal
diff --git a/native/utils/zlib/zlib_regex.cc b/native/utils/zlib/zlib_regex.cc
index 4822d6f..901bb91 100644
--- a/native/utils/zlib/zlib_regex.cc
+++ b/native/utils/zlib/zlib_regex.cc
@@ -48,7 +48,7 @@
}
unicode_regex_pattern =
UTF8ToUnicodeText(uncompressed_pattern->c_str(),
- uncompressed_pattern->Length(), /*do_copy=*/false);
+ uncompressed_pattern->size(), /*do_copy=*/false);
}
if (result_pattern_text != nullptr) {
diff --git a/notification/tests/src/com/android/textclassifier/notification/NotificationTest.java b/notification/tests/src/com/android/textclassifier/notification/NotificationTest.java
new file mode 100644
index 0000000..215c3bc
--- /dev/null
+++ b/notification/tests/src/com/android/textclassifier/notification/NotificationTest.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.notification;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.app.Notification;
+import android.app.Person;
+import android.content.Context;
+import androidx.test.core.app.ApplicationProvider;
+import java.util.List;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests the {@link Notification} messaging API by round-tripping messages through
+ * {@link Notification.MessagingStyle.Message#getMessagesFromBundleArray} (mainline API coverage).
+ *
+ * <p>Ideally this would be a CTS test, but it was too late to add one for Android R, so the
+ * coverage lives here instead.
+ */
+@RunWith(JUnit4.class)
+public final class NotificationTest {
+
+ private Context context;  // Application context, (re)initialized before every test.
+
+ @Before
+ public void setup() {
+ context = ApplicationProvider.getApplicationContext();
+ }
+
+ @Test
+ public void getMessagesFromBundleArray() {
+ Person sender = new Person.Builder().setName("Sender").build();  // Shared by both messages.
+ Notification.MessagingStyle.Message firstExpectedMessage =
+ new Notification.MessagingStyle.Message("hello", /* timestamp= */ 123, sender);
+ Notification.MessagingStyle.Message secondExpectedMessage =
+ new Notification.MessagingStyle.Message("hello2", /* timestamp= */ 456, sender);
+
+ Notification.MessagingStyle messagingStyle =
+ new Notification.MessagingStyle("self name")
+ .addMessage(firstExpectedMessage)
+ .addMessage(secondExpectedMessage);
+ Notification notification =
+ new Notification.Builder(context, "test id")
+ .setSmallIcon(1)  // NOTE(review): arbitrary resource id; presumably unused by this test — confirm.
+ .setContentTitle("test title")
+ .setStyle(messagingStyle)
+ .build();
+
+ List<Notification.MessagingStyle.Message> actualMessages =
+ Notification.MessagingStyle.Message.getMessagesFromBundleArray(
+ notification.extras.getParcelableArray(Notification.EXTRA_MESSAGES));  // Read back from extras.
+
+ assertThat(actualMessages).hasSize(2);  // Order is expected to match insertion order.
+ assertMessageEquals(firstExpectedMessage, actualMessages.get(0));
+ assertMessageEquals(secondExpectedMessage, actualMessages.get(1));
+ }
+
+ private static void assertMessageEquals(  // Compares text (as String), timestamp and sender.
+ Notification.MessagingStyle.Message expected, Notification.MessagingStyle.Message actual) {
+ assertThat(actual.getText().toString()).isEqualTo(expected.getText().toString());
+ assertThat(actual.getTimestamp()).isEqualTo(expected.getTimestamp());
+ assertThat(actual.getSenderPerson()).isEqualTo(expected.getSenderPerson());
+ }
+}